import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
from urllib.parse import urlparse

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
    Field,
    PrivateAttr,
)
from langchain_core.runnables import RunnableConfig

logger = logging.getLogger(__name__)


class ChatMlflow(BaseChatModel):
    """`MLflow` chat models API.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMlflow

            chat = ChatMlflow(
                target_uri="http://localhost:5000",
                endpoint="chat",
                temperature=0.1,
            )
    """
    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: dict = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        return "[genai]"

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
        }
        return params

    def _prepare_inputs(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        message_dicts = [
            ChatMlflow._convert_message_to_dict(message) for message in messages
        ]
        data: Dict[str, Any] = {
            "messages": message_dicts,
            "temperature": self.temperature,
            "n": self.n,
            **self.extra_params,
            **kwargs,
        }
        # The instance-level `stop` takes precedence over the per-call argument.
        if stop := self.stop or stop:
            data["stop"] = stop
        if self.max_tokens is not None:
            data["max_tokens"] = self.max_tokens
        return data

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        data = self._prepare_inputs(
            messages,
            stop,
            **kwargs,
        )
        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return ChatMlflow._create_chat_result(resp)

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[BaseMessageChunk]:
        # We need to override `stream` to handle the case
        # that `self._client` does not implement `predict_stream`.
        if not hasattr(self._client, "predict_stream"):
            # The MLflow deployment client does not implement streaming,
            # so fall back to a single full response.
            yield cast(
                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
            )
        else:
            yield from super().stream(input, config, stop=stop, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        data = self._prepare_inputs(
            messages,
            stop,
            **kwargs,
        )
        # TODO: check if `_client.predict_stream` is available.
        chunk_iter = self._client.predict_stream(endpoint=self.endpoint, inputs=data)
        first_chunk_role = None
        for chunk in chunk_iter:
            # Each chunk is expected to carry an OpenAI-style delta under
            # `choices[0]["delta"]`; the role may appear only on the first
            # chunk, so it is cached as the default for the following ones.
            choice = chunk["choices"][0]
            chunk_delta = choice["delta"]
            if first_chunk_role is None:
                first_chunk_role = chunk_delta.get("role")
            chunk = ChatMlflow._convert_delta_to_message_chunk(
                chunk_delta, first_chunk_role
            )

            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            if logprobs := choice.get("logprobs"):
                generation_info["logprobs"] = logprobs
            chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info or None
            )

            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
            yield chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mlflow-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _convert_delta_to_message_chunk(
        _dict: Mapping[str, Any], default_role: str
    ) -> BaseMessageChunk:
        role = _dict.get("role", default_role)
        content = _dict["content"]
        if role == "user":
            return HumanMessageChunk(content=content)
        elif role == "assistant":
            return AIMessageChunk(content=content)
        elif role == "system":
            return SystemMessageChunk(content=content)
        else:
            return ChatMessageChunk(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        raise ValueError(
            "Function messages are not supported by Databricks. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            raise ValueError(
                "Function messages are not supported by Databricks. Please"
                " create a feature request at https://github.com/mlflow/mlflow/issues."
            )
        else:
            raise ValueError(f"Got unknown message type: {message}")

        if "function_call" in message.additional_kwargs:
            ChatMlflow._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Databricks"
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for choice in response["choices"]:
            message = ChatMlflow._convert_dict_to_message(choice["message"])
            usage = choice.get("usage", {})
            gen = ChatGeneration(
                message=message,
                generation_info=usage,
            )
            generations.append(gen)
        usage = response.get("usage", {})
        return ChatResult(generations=generations, llm_output=usage)
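

# A minimal usage sketch (not part of the original module). It assumes an MLflow
# deployments server is reachable at http://localhost:5000 and exposes a chat
# endpoint named "chat"; both values are illustrative placeholders.
if __name__ == "__main__":
    chat = ChatMlflow(
        target_uri="http://localhost:5000",
        endpoint="chat",
        temperature=0.1,
    )

    # Single-shot invocation: the response is an AIMessage built by
    # `_create_chat_result` from the endpoint's OpenAI-style payload.
    response = chat.invoke([HumanMessage(content="What is MLflow?")])
    print(response.content)

    # Streaming yields incremental chunks when the deployment client supports
    # `predict_stream`; otherwise `stream` falls back to one full message.
    for chunk in chat.stream([HumanMessage(content="What is MLflow?")]):
        print(chunk.content, end="", flush=True)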