Spaces:
Runtime error
Runtime error
langchain-qa-bot
/
docs
/langchain
/libs
/community
/langchain_community
/callbacks
/trubrics_callback.py
import os | |
from typing import Any, Dict, List, Optional | |
from uuid import UUID | |
from langchain_core.callbacks import BaseCallbackHandler | |
from langchain_core.messages import ( | |
AIMessage, | |
BaseMessage, | |
ChatMessage, | |
FunctionMessage, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain_core.outputs import LLMResult | |
def _convert_message_to_dict(message: BaseMessage) -> dict: | |
message_dict: Dict[str, Any] | |
if isinstance(message, ChatMessage): | |
message_dict = {"role": message.role, "content": message.content} | |
elif isinstance(message, HumanMessage): | |
message_dict = {"role": "user", "content": message.content} | |
elif isinstance(message, AIMessage): | |
message_dict = {"role": "assistant", "content": message.content} | |
if "function_call" in message.additional_kwargs: | |
message_dict["function_call"] = message.additional_kwargs["function_call"] | |
# If function call only, content is None not empty string | |
if message_dict["content"] == "": | |
message_dict["content"] = None | |
elif isinstance(message, SystemMessage): | |
message_dict = {"role": "system", "content": message.content} | |
elif isinstance(message, FunctionMessage): | |
message_dict = { | |
"role": "function", | |
"content": message.content, | |
"name": message.name, | |
} | |
else: | |
raise TypeError(f"Got unknown type {message}") | |
if "name" in message.additional_kwargs: | |
message_dict["name"] = message.additional_kwargs["name"] | |
return message_dict | |
class TrubricsCallbackHandler(BaseCallbackHandler): | |
""" | |
Callback handler for Trubrics. | |
Args: | |
project: a trubrics project, default project is "default" | |
email: a trubrics account email, can equally be set in env variables | |
password: a trubrics account password, can equally be set in env variables | |
**kwargs: all other kwargs are parsed and set to trubrics prompt variables, | |
or added to the `metadata` dict | |
""" | |
def __init__( | |
self, | |
project: str = "default", | |
email: Optional[str] = None, | |
password: Optional[str] = None, | |
**kwargs: Any, | |
) -> None: | |
super().__init__() | |
try: | |
from trubrics import Trubrics | |
except ImportError: | |
raise ImportError( | |
"The TrubricsCallbackHandler requires installation of " | |
"the trubrics package. " | |
"Please install it with `pip install trubrics`." | |
) | |
self.trubrics = Trubrics( | |
project=project, | |
email=email or os.environ["TRUBRICS_EMAIL"], | |
password=password or os.environ["TRUBRICS_PASSWORD"], | |
) | |
self.config_model: dict = {} | |
self.prompt: Optional[str] = None | |
self.messages: Optional[list] = None | |
self.trubrics_kwargs: Optional[dict] = kwargs if kwargs else None | |
def on_llm_start( | |
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |
) -> None: | |
self.prompt = prompts[0] | |
def on_chat_model_start( | |
self, | |
serialized: Dict[str, Any], | |
messages: List[List[BaseMessage]], | |
**kwargs: Any, | |
) -> None: | |
self.messages = [_convert_message_to_dict(message) for message in messages[0]] | |
self.prompt = self.messages[-1]["content"] | |
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None: | |
tags = ["langchain"] | |
user_id = None | |
session_id = None | |
metadata: dict = {"langchain_run_id": run_id} | |
if self.messages: | |
metadata["messages"] = self.messages | |
if self.trubrics_kwargs: | |
if self.trubrics_kwargs.get("tags"): | |
tags.append(*self.trubrics_kwargs.pop("tags")) | |
user_id = self.trubrics_kwargs.pop("user_id", None) | |
session_id = self.trubrics_kwargs.pop("session_id", None) | |
metadata.update(self.trubrics_kwargs) | |
for generation in response.generations: | |
self.trubrics.log_prompt( | |
config_model={ | |
"model": response.llm_output.get("model_name") | |
if response.llm_output | |
else "NA" | |
}, | |
prompt=self.prompt, | |
generation=generation[0].text, | |
user_id=user_id, | |
session_id=session_id, | |
tags=tags, | |
metadata=metadata, | |
) | |