Spaces:
Runtime error
Runtime error
langchain-qa-bot
/
docs
/langchain
/libs
/community
/langchain_community
/callbacks
/context_callback.py
"""Callback handler for Context AI""" | |
import os | |
from typing import Any, Dict, List | |
from uuid import UUID | |
from langchain_core.callbacks import BaseCallbackHandler | |
from langchain_core.messages import BaseMessage | |
from langchain_core.outputs import LLMResult | |
from langchain_core.utils import guard_import | |
def import_context() -> Any:
    """Import the optional ``getcontext`` SDK and return the objects this module uses.

    Every import goes through ``guard_import`` with
    ``pip_name="python-context"``, presumably so that a missing dependency
    produces an error naming the package to install.

    Returns:
        A 6-tuple ``(getcontext, Credential, Conversation, Message,
        MessageRole, Rating)``:
        - the top-level ``getcontext`` module;
        - the ``Credential`` class from ``getcontext.token``;
        - four model classes from ``getcontext.generated.models``.
    """
    pip_name = "python-context"

    sdk = guard_import("getcontext", pip_name=pip_name)
    credential_cls = guard_import("getcontext.token", pip_name=pip_name).Credential
    # Resolve the generated-models module once and read all four classes from it.
    models = guard_import("getcontext.generated.models", pip_name=pip_name)

    return (
        sdk,
        credential_cls,
        models.Conversation,
        models.Message,
        models.MessageRole,
        models.Rating,
    )
class ContextCallbackHandler(BaseCallbackHandler):
    """Callback Handler that records transcripts to the Context service.

    (https://context.ai).

    Messages sent to a chat model are buffered.
    - Outside a chain, the first generation of the LLM result is appended as
      the assistant reply, and the conversation is upserted when the LLM
      finishes.
    - While a chain is running, the chain's ``"text"`` output is used as the
      assistant reply instead, and the upsert happens when the chain ends.

    Keyword Args:
        token (optional): The token with which to authenticate requests to Context.
            Visit https://with.context.ai/settings to generate a token.
            If not provided, the value of the `CONTEXT_TOKEN` environment
            variable will be used.
        verbose (optional): Accepted for interface compatibility; not used
            by this handler.

    Raises:
        ImportError: if the `python-context` package is not installed.

    Chat Example:
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain_community.callbacks import ContextCallbackHandler
        >>> from langchain_core.messages import HumanMessage, SystemMessage
        >>> context_callback = ContextCallbackHandler(
        ...     token="<CONTEXT_TOKEN_HERE>",
        ... )
        >>> chat = ChatOpenAI(
        ...     temperature=0,
        ...     headers={"user_id": "123"},
        ...     callbacks=[context_callback],
        ...     openai_api_key="API_KEY_HERE",
        ... )
        >>> messages = [
        ...     SystemMessage(content="You translate English to French."),
        ...     HumanMessage(content="I love programming with LangChain."),
        ... ]
        >>> chat.invoke(messages)

    Chain Example:
        >>> from langchain.chains import LLMChain
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain_community.callbacks import ContextCallbackHandler
        >>> from langchain_core.prompts import (
        ...     ChatPromptTemplate,
        ...     HumanMessagePromptTemplate,
        ...     PromptTemplate,
        ... )
        >>> context_callback = ContextCallbackHandler(
        ...     token="<CONTEXT_TOKEN_HERE>",
        ... )
        >>> human_message_prompt = HumanMessagePromptTemplate(
        ...     prompt=PromptTemplate(
        ...         template="What is a good name for a company that makes {product}?",
        ...         input_variables=["product"],
        ...     ),
        ... )
        >>> chat_prompt_template = ChatPromptTemplate.from_messages(
        ...     [human_message_prompt]
        ... )
        >>> # Note: the same callback object must be shared between the
        >>> # LLM and the chain.
        >>> chat = ChatOpenAI(temperature=0.9, callbacks=[context_callback])
        >>> chain = LLMChain(
        ...     llm=chat,
        ...     prompt=chat_prompt_template,
        ...     callbacks=[context_callback],
        ... )
        >>> chain.run("colorful socks")
    """

    def __init__(self, token: str = "", verbose: bool = False, **kwargs: Any) -> None:
        # `verbose` and any extra keyword arguments are accepted but unused.
        (
            self.context,
            self.credential,
            self.conversation_model,
            self.message_model,
            self.message_role_model,
            self.rating_model,
        ) = import_context()

        # An explicit token wins; otherwise fall back to the environment.
        token = token or os.environ.get("CONTEXT_TOKEN") or ""
        self.client = self.context.ContextAPI(credential=self.credential(token))

        # Run id of the chain in progress. It is None when no chain is
        # active; while it is set, on_llm_end defers logging to on_chain_end.
        self.chain_run_id = None
        # Not read by this class; kept for backward compatibility.
        self.llm_model = None
        # Context `Message` objects buffered for the next conversation upsert.
        self.messages: List[Any] = []
        # Metadata (currently only "model") attached to the next upsert.
        self.metadata: Dict[str, str] = {}

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> Any:
        """Buffer the prompt messages sent to a chat model.

        Records ``invocation_params["model"]`` (when present) as conversation
        metadata. Then converts each message of the first prompt in the batch
        into a Context ``Message``.
        """
        # `invocation_params` may be missing or explicitly None.
        invocation_params = kwargs.get("invocation_params") or {}
        llm_model = invocation_params.get("model")
        if llm_model is not None:
            self.metadata["model"] = llm_model

        if not messages:
            return

        roles = self.message_role_model
        role_by_type = {
            "human": roles.USER,
            "system": roles.SYSTEM,
            "ai": roles.ASSISTANT,
        }
        # Only the first prompt of a batched call is recorded. Message types
        # other than human/system/ai are recorded with the SYSTEM role.
        for message in messages[0]:
            self.messages.append(
                self.message_model(
                    message=message.content,
                    role=role_by_type.get(message.type, roles.SYSTEM),
                )
            )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Record the LLM reply and upsert the conversation when no chain is active.

        Does nothing if the result has no generations. While a chain is in
        progress, the reply and upsert are left to ``on_chain_end``.
        """
        if not response.generations or not response.generations[0]:
            return
        if self.chain_run_id:
            return

        generation = response.generations[0][0]
        self.messages.append(
            self.message_model(
                message=generation.text,
                role=self.message_role_model.ASSISTANT,
            )
        )
        self._log_conversation()

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Remember the run id of the chain that just started."""
        self.chain_run_id = kwargs.get("run_id", None)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Record the chain's ``"text"`` output and upsert the conversation.

        Sometimes no assistant message is added:
        - the output is not a dict (as with many runnables), or
        - the output has no ``"text"`` key.
        The buffered messages are still logged in those cases, and the
        chain-in-progress state is always reset, even if logging raises.
        """
        text = outputs.get("text") if isinstance(outputs, dict) else None
        if text is not None:
            self.messages.append(
                self.message_model(
                    message=text,
                    role=self.message_role_model.ASSISTANT,
                )
            )
        try:
            self._log_conversation()
        finally:
            self.chain_run_id = None

    def _log_conversation(self) -> None:
        """Upsert the buffered messages and metadata as one Context conversation.

        No-op when no messages are buffered. The buffers are cleared even if
        the API call raises, so a failed upsert cannot leak its messages into
        the next conversation.
        """
        if not self.messages:
            return
        try:
            self.client.log.conversation_upsert(
                body={
                    "conversation": self.conversation_model(
                        messages=self.messages,
                        metadata=self.metadata,
                    )
                }
            )
        finally:
            self.messages = []
            self.metadata = {}