Spaces:
Runtime error
Runtime error
File size: 5,301 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import logging
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class GooseAI(LLM):
    """GooseAI large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``GOOSEAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be
    passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_community.llms import GooseAI
            gooseai = GooseAI(model_name="gpt-neo-20b")

    """

    client: Any
    """Completion client; set to ``openai.Completion`` by ``validate_environment``."""

    model_name: str = "gpt-neo-20b"
    """Model name to use"""

    temperature: float = 0.7
    """What sampling temperature to use"""

    max_tokens: int = 256
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the model's maximal context size."""

    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""

    min_tokens: int = 1
    """The minimum number of tokens to generate in the completion."""

    frequency_penalty: float = 0
    """Penalizes repeated tokens according to frequency."""

    presence_penalty: float = 0
    """Penalizes repeated tokens."""

    n: int = 1
    """How many completions to generate for each prompt."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""

    logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
    """Adjust the probability of specific tokens being generated."""

    gooseai_api_key: Optional[SecretStr] = None
    """API key; read from ``GOOSEAI_API_KEY`` in the environment when not passed."""

    class Config:
        """Configuration for this pydantic config."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Every key in ``values`` that is not a declared field (matched by field
        alias) is removed from ``values`` and stored under ``model_kwargs``.

        Args:
            values: Raw keyword arguments supplied to the constructor.

        Returns:
            The same ``values`` dict, with undeclared keys moved into
            ``values["model_kwargs"]``.

        Raises:
            ValueError: If an undeclared key is also present in the supplied
                ``model_kwargs``.
        """
        all_required_field_names = {field.alias for field in cls.__fields__.values()}
        # Work on a copy: writing into the caller's dict would leak the moved
        # keys back to the caller and make a second construction with the same
        # ``model_kwargs`` object fail with a spurious "supplied twice" error.
        extra = dict(values.get("model_kwargs", {}))
        # Iterate over a snapshot of the keys because ``values`` is popped from
        # inside the loop.
        for field_name in list(values):
            if field_name in all_required_field_names:
                continue
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            logger.warning(
                f"WARNING! {field_name} is not default parameter.\n"
                f"{field_name} was transferred to model_kwargs.\n"
                f"Please confirm that {field_name} is what you intended."
            )
            extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API key from ``values`` or the ``GOOSEAI_API_KEY``
        environment variable, stores it back as a secret string, and sets
        ``values["client"]`` to ``openai.Completion``.

        Args:
            values: Field values for the instance being constructed.

        Returns:
            ``values`` with ``gooseai_api_key`` and ``client`` populated.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
        """
        gooseai_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "gooseai_api_key", "GOOSEAI_API_KEY")
        )
        values["gooseai_api_key"] = gooseai_api_key

        # Only the import itself is guarded, so an unrelated failure further
        # down is never misreported as a missing package.
        try:
            import openai
        except ImportError as err:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from err

        # NOTE(review): these are attributes on the ``openai`` module itself, so
        # the key and base URL are process-wide and will be seen by any other
        # code that reads them from the module.
        openai.api_key = gooseai_api_key.get_secret_value()
        openai.api_base = "https://api.goose.ai/v1"
        values["client"] = openai.Completion
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling GooseAI API.

        Returns:
            A newly built dict of the sampling fields, overlaid with
            ``model_kwargs`` (entries in ``model_kwargs`` win on key clashes).
        """
        normal_params = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "min_tokens": self.min_tokens,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            "logit_bias": self.logit_bias,
        }
        return {**normal_params, **self.model_kwargs}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        Returns:
            ``model_name`` together with everything in ``_default_params``.
        """
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gooseai"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the GooseAI API.

        Args:
            prompt: Text to send as the completion prompt.
            stop: Optional stop sequences, forwarded as the ``stop`` parameter.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra parameters for ``client.create``; these override
                the default parameters on key clashes.

        Returns:
            The ``text`` of the first choice in the response.

        Raises:
            ValueError: If ``stop`` is given while the default parameters
                already contain a ``stop`` entry.
        """
        # ``_default_params`` builds a fresh dict on every access, so adding
        # ``stop`` below does not touch any state held on the instance.
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop

        params = {**params, **kwargs}
        response = self.client.create(engine=self.model_name, prompt=prompt, **params)
        # Only the first choice is returned, even when ``n`` is greater than 1.
        text = response.choices[0].text
        return text
|