Spaces:
Runtime error
Runtime error
"""ChatModel wrapper which returns user input as the response."""
from io import StringIO | |
from typing import Any, Callable, Dict, List, Mapping, Optional | |
import yaml | |
from langchain_core.callbacks import ( | |
CallbackManagerForLLMRun, | |
) | |
from langchain_core.language_models.chat_models import BaseChatModel | |
from langchain_core.messages import ( | |
BaseMessage, | |
HumanMessage, | |
_message_from_dict, | |
messages_to_dict, | |
) | |
from langchain_core.outputs import ChatGeneration, ChatResult | |
from langchain_core.pydantic_v1 import Field | |
from langchain_community.llms.utils import enforce_stop_tokens | |
def _display_messages(messages: List[BaseMessage]) -> None:
    """Print every message to stdout as a YAML document framed by banner lines.

    Args:
        messages: The chat messages to show to the person supplying input.
    """
    # Block style, original key order, unescaped non-ASCII text, and a very
    # wide preferred width so long values are not folded across lines.
    dump_options = dict(
        default_flow_style=False,
        sort_keys=False,
        allow_unicode=True,
        width=10000,
        line_break=None,
    )
    for serialized in messages_to_dict(messages):
        rendered = yaml.dump(serialized, **dump_options)
        print("\n", "======= start of message =======", "\n\n")  # noqa: T201
        print(rendered)  # noqa: T201
        print("======= end of message =======", "\n\n")  # noqa: T201
def _collect_yaml_input(
    messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> BaseMessage:
    """Read a YAML-encoded chat message from stdin and return it as a message.

    Lines are read with ``input()`` until either condition below is met:

    * a blank (whitespace-only) line is entered, or
    * a line containing one of the ``stop`` sequences is entered; that line
      is discarded.

    The collected lines are joined with newlines, parsed with
    ``yaml.safe_load`` and converted with ``_message_from_dict``. The input
    is presumably expected in the same dict shape that ``_display_messages``
    prints (a ``messages_to_dict`` entry) — confirm against
    ``_message_from_dict``.

    Args:
        messages: The conversation so far. Not used here; accepted because
            the chat model calls ``input_func(messages, stop=..., ...)``.
        stop: Optional stop sequences. When given, the parsed message's
            string content is truncated with ``enforce_stop_tokens``.

    Returns:
        The parsed message, or ``HumanMessage(content="")`` when no input
        was entered.

    Raises:
        ValueError: "Invalid YAML string entered." if the text is not valid
            YAML.
        ValueError: "Invalid message entered." if the YAML cannot be
            converted into a message.
        ValueError: "Cannot use when output is not a string." if ``stop``
            is given but the message content is not a string.

    ``EOFError`` from ``input()`` is not handled and propagates.
    """
    lines = []
    while True:
        line = input()
        if not line.strip():
            break
        if stop and any(seq in line for seq in stop):
            break
        lines.append(line)
    yaml_string = "\n".join(lines)

    try:
        parsed = yaml.safe_load(StringIO(yaml_string))
    except yaml.YAMLError as e:
        raise ValueError("Invalid YAML string entered.") from e

    # An empty document parses to None. Handle it before _message_from_dict,
    # which needs a mapping and would otherwise fail with a raw TypeError.
    if parsed is None:
        return HumanMessage(content="")

    try:
        message = _message_from_dict(parsed)
    except (ValueError, KeyError, TypeError) as e:
        # Non-mapping documents or mappings missing expected keys may raise
        # TypeError/KeyError; unknown message types raise ValueError.
        raise ValueError("Invalid message entered.") from e

    # Kept outside the try blocks so this specific error is not rewritten
    # into the generic "Invalid message entered." message.
    if stop:
        if not isinstance(message.content, str):
            raise ValueError("Cannot use when output is not a string.")
        message.content = enforce_stop_tokens(message.content, stop)
    return message
class HumanInputChatModel(BaseChatModel):
    """ChatModel which returns user input as the response.

    Each call first shows the incoming messages via ``message_func``. It then
    uses whatever ``input_func`` returns as the generated message.
    """

    # Called as input_func(messages, stop=stop, **input_kwargs); its return
    # value is used directly as the reply message.
    input_func: Callable = Field(default_factory=lambda: _collect_yaml_input)
    # Called as message_func(messages, **message_kwargs) before input is read.
    message_func: Callable = Field(default_factory=lambda: _display_messages)
    # NOTE(review): not read anywhere in this class — confirm whether callers
    # rely on it before removing.
    separator: str = "\n"
    # Extra keyword arguments forwarded to input_func.
    input_kwargs: Mapping[str, Any] = {}
    # Extra keyword arguments forwarded to message_func.
    message_kwargs: Mapping[str, Any] = {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identify the model by the names of its input and display callables.

        Falls back to the callable's type name for objects without
        ``__name__`` (e.g. ``functools.partial``).
        """
        return {
            "input_func": getattr(
                self.input_func, "__name__", type(self.input_func).__name__
            ),
            "message_func": getattr(
                self.message_func, "__name__", type(self.message_func).__name__
            ),
        }

    @property
    def _llm_type(self) -> str:
        """Returns the type of LLM."""
        return "human-input-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Displays the messages to the user and returns their input as a response.

        Args:
            messages (List[BaseMessage]): The messages to be displayed to the user.
            stop (Optional[List[str]]): A list of stop strings, forwarded to
                ``input_func``.
            run_manager (Optional[CallbackManagerForLLMRun]): Currently not used.
            **kwargs: Currently not used.

        Returns:
            ChatResult: The user's input as a response.
        """
        self.message_func(messages, **self.message_kwargs)
        user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
        return ChatResult(generations=[ChatGeneration(message=user_input)])