# langchain/libs/community/langchain_community/callbacks/promptlayer_callback.py

"""Callback handler for promptlayer.""" | |
from __future__ import annotations | |
import datetime | |
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple | |
from uuid import UUID | |
from langchain_core.callbacks import BaseCallbackHandler | |
from langchain_core.messages import ( | |
AIMessage, | |
BaseMessage, | |
ChatMessage, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain_core.outputs import ( | |
ChatGeneration, | |
LLMResult, | |
) | |
if TYPE_CHECKING:
    import promptlayer


def _lazy_import_promptlayer() -> promptlayer:
    """Lazily import promptlayer so it remains an optional dependency."""
    try:
        import promptlayer
    except ImportError:
        raise ImportError(
            "The PromptLayerCallbackHandler requires the promptlayer package. "
            "Please install it with `pip install promptlayer`."
        )
    return promptlayer
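

# Note on the lazy import above (illustrative sketch, not from the source):
# the module itself always imports cleanly; a missing ``promptlayer`` package
# only surfaces when the handler is constructed:
#
#     handler = PromptLayerCallbackHandler()
#     # ImportError: The PromptLayerCallbackHandler requires the promptlayer
#     # package. Please install it with `pip install promptlayer`.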


class PromptLayerCallbackHandler(BaseCallbackHandler):
    """Callback handler for promptlayer."""

    def __init__(
        self,
        pl_id_callback: Optional[Callable[..., Any]] = None,
        pl_tags: Optional[List[str]] = None,
    ) -> None:
        """Initialize the PromptLayerCallbackHandler."""
        _lazy_import_promptlayer()
        self.pl_id_callback = pl_id_callback
        self.pl_tags = pl_tags or []
        self.runs: Dict[UUID, Dict[str, Any]] = {}
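
    # Construction sketch (illustrative, not from the source): tags are
    # attached to every PromptLayer request this handler logs, and
    # ``pl_id_callback`` receives each PromptLayer request id as it is
    # created.
    #
    #     handler = PromptLayerCallbackHandler(
    #         pl_id_callback=print,  # or e.g. collected_ids.append
    #         pl_tags=["langchain", "my-app"],
    #     )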

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        # Record the chat run so on_llm_end can log it to PromptLayer later.
        self.runs[run_id] = {
            "messages": [self._create_message_dicts(m)[0] for m in messages],
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        # Record the completion-style run for on_llm_end to log.
        self.runs[run_id] = {
            "prompts": prompts,
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }
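
    # Shape sketch of a recorded run (values are illustrative assumptions,
    # not taken from the source; serialized["id"] is the serialized class
    # path, e.g. ["langchain", "llms", "openai", "OpenAI"]):
    #
    #     {
    #         "prompts": ["Tell me a joke."],
    #         "invocation_params": {"model_name": "gpt-3.5-turbo-instruct"},
    #         "name": "langchain.llms.openai.OpenAI",
    #         "request_start_time": 1700000000.0,
    #         "tags": ["my-tag"],
    #     }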

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        from promptlayer.utils import get_api_key, promptlayer_api_request

        run_info = self.runs.get(run_id, {})
        if not run_info:
            return
        run_info["request_end_time"] = datetime.datetime.now().timestamp()
        # Log one PromptLayer request per generation in the batch.
        for i in range(len(response.generations)):
            generation = response.generations[i][0]

            resp = {
                "text": generation.text,
                "llm_output": response.llm_output,
            }
            model_params = run_info.get("invocation_params", {})
            # Runs recorded by on_chat_model_start carry "messages";
            # runs recorded by on_llm_start carry "prompts".
            is_chat_model = run_info.get("messages", None) is not None
            model_input = (
                run_info.get("messages", [])[i]
                if is_chat_model
                else [run_info.get("prompts", [])[i]]
            )
            model_response = (
                [self._convert_message_to_dict(generation.message)]
                if is_chat_model and isinstance(generation, ChatGeneration)
                else resp
            )

            pl_request_id = promptlayer_api_request(
                run_info.get("name"),
                "langchain",
                model_input,
                model_params,
                self.pl_tags,
                model_response,
                run_info.get("request_start_time"),
                run_info.get("request_end_time"),
                get_api_key(),
                return_pl_id=bool(self.pl_id_callback is not None),
                metadata={
                    "_langchain_run_id": str(run_id),
                    "_langchain_parent_run_id": str(parent_run_id),
                    "_langchain_tags": str(run_info.get("tags", [])),
                },
            )

            if self.pl_id_callback:
                self.pl_id_callback(pl_request_id)
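
    # What ``pl_id_callback`` receives (hedged sketch): with
    # return_pl_id=True, promptlayer_api_request returns the PromptLayer
    # request id, which can later be used to attach feedback, e.g. via the
    # promptlayer SDK's ``track`` module (assumed available in the SDK
    # version this handler targets):
    #
    #     def on_pl_id(pl_request_id):
    #         import promptlayer
    #         promptlayer.track.score(request_id=pl_request_id, score=100)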

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        # Map a LangChain message to an OpenAI-style role/content dict.
        if isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        if "name" in message.additional_kwargs:
            message_dict["name"] = message.additional_kwargs["name"]
        return message_dict
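
    # Conversion examples (illustrative):
    #     HumanMessage(content="hi")  -> {"role": "user", "content": "hi"}
    #     AIMessage(content="hey")    -> {"role": "assistant", "content": "hey"}
    #     ChatMessage(role="tool", content="ok")
    #                                 -> {"role": "tool", "content": "ok"}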

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params: Dict[str, Any] = {}
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params
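

# End-to-end usage sketch (hedged, not part of the original module):
# ``ChatOpenAI`` from the separate ``langchain_openai`` package is an
# assumption, and a PromptLayer API key must be configured (e.g. via the
# PROMPTLAYER_API_KEY environment variable).
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI

    pl_ids: List[Any] = []
    handler = PromptLayerCallbackHandler(
        pl_id_callback=pl_ids.append,
        pl_tags=["example"],
    )
    llm = ChatOpenAI(callbacks=[handler])
    llm.invoke("Tell me a joke.")
    # pl_ids now holds the PromptLayer request id for the call above.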