Spaces:
Runtime error
Runtime error
File size: 4,884 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import json
import logging
from typing import Any, List, Optional, Union
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.pydantic_v1 import Field
from langchain_community.llms.utils import enforce_stop_tokens
# Module-level logger, named after this module; used for debug output in
# ChatGLM3._call.
logger = logging.getLogger(__name__)
# HTTP headers sent with the POST request in ChatGLM3._call (the payload is
# passed as JSON).
HEADERS = {"Content-Type": "application/json"}
# Default for ChatGLM3.timeout, which is handed to httpx.Client(timeout=...).
# NOTE(review): presumably seconds, following httpx's convention -- confirm.
DEFAULT_TIMEOUT = 30
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into a chat-completions message dict.

    Args:
        message: The message to convert.

    Returns:
        A dict with a ``role`` key (``user``, ``assistant``, ``system`` or
        ``function``) and a ``content`` key holding ``message.content``.

    Raises:
        ValueError: If ``message`` is not one of the four supported types.
    """
    # Checked in order; the first matching message class decides the role.
    role_by_type = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
        (FunctionMessage, "function"),
    )
    for message_type, role in role_by_type:
        if isinstance(message, message_type):
            return {"role": role, "content": message.content}
    raise ValueError(f"Got unknown type {message}")
class ChatGLM3(LLM):
    """ChatGLM3 LLM service.

    Sends the prompt (preceded by ``prefix_messages``) as a JSON POST request
    to ``endpoint_url`` and returns the message content of the first choice
    in the response.
    """

    model_name: str = Field(default="chatglm3-6b", alias="model")
    """Model name, sent as ``model`` in the request payload."""
    endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
    """Endpoint URL to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    max_tokens: int = 20000
    """Max token allowed to pass to the model."""
    temperature: float = 0.1
    """LLM model temperature from 0 to 10."""
    top_p: float = 0.7
    """Top P for nucleus sampling from 0 to 1"""
    prefix_messages: List[BaseMessage] = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""
    # NOTE(review): ``_call`` parses the response body as a single JSON
    # document; confirm the endpoint's ``stream=True`` output is compatible.
    http_client: Union[Any, None] = None
    """Optional pre-built HTTP client; when unset, an httpx client is created."""
    timeout: int = DEFAULT_TIMEOUT
    """Timeout passed to the httpx client created when ``http_client`` is unset."""

    @property
    def _llm_type(self) -> str:
        """Return the identifier string of this LLM type."""
        return "chat_glm_3"

    @property
    def _invocation_params(self) -> dict:
        """Get the parameters used to invoke the model.

        Returns a fresh dict on every access; entries in ``model_kwargs``
        override the built-in parameters of the same name.
        """
        params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "stream": self.streaming,
        }
        return {**params, **(self.model_kwargs or {})}

    @property
    def client(self) -> Any:
        """Return ``http_client`` if set, otherwise a new ``httpx.Client``.

        A new client is built on every access when ``http_client`` is unset,
        so whoever obtains one this way is responsible for closing it.
        """
        import httpx

        return self.http_client or httpx.Client(timeout=self.timeout)

    def _get_payload(self, prompt: str) -> dict:
        """Build the JSON payload: invocation params plus the message list.

        Args:
            prompt: The user prompt, appended after ``prefix_messages``.

        Returns:
            The invocation parameters with an added ``messages`` entry.
        """
        params = self._invocation_params
        messages = self.prefix_messages + [HumanMessage(content=prompt)]
        params["messages"] = [_convert_message_to_dict(m) for m in messages]
        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM3 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: On a network error, a non-200 status code, a body
                that is not valid JSON, a JSON body that is not an object,
                or a response with no (or empty) ``choices``.

        Example:
            .. code-block:: python

                response = chatglm_llm.invoke("Who are you?")
        """
        import httpx

        payload = self._get_payload(prompt)
        logger.debug("ChatGLM3 payload: %s", payload)

        client = self.client
        # Close only a client that was created for this call; a client
        # supplied through ``http_client`` belongs to the caller.
        owns_client = client is not self.http_client
        try:
            response = client.post(
                self.endpoint_url, headers=HEADERS, json=payload
            )
        except httpx.NetworkError as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e
        finally:
            if owns_client:
                client.close()

        logger.debug("ChatGLM3 response: %s", response)

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            ) from e

        if not isinstance(parsed_response, dict):
            raise ValueError(f"Unexpected response type: {parsed_response}")

        # A missing ``choices`` key and an empty list are both "no content";
        # either would otherwise leave ``text`` unassigned below.
        choices = parsed_response.get("choices")
        if not choices:
            raise ValueError(f"No content in response : {parsed_response}")
        text = choices[0]["message"]["content"]

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text
|