from copy import deepcopy
from typing import Any, Dict, List, Optional

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import


def import_aim() -> Any:
    """Import the aim python package and raise an error if it is not installed."""
    return guard_import("aim")
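
# Hedged sketch: `guard_import` returns the `aim` module when it is installed
# and raises an ImportError pointing at the missing dependency otherwise, so
# every handler method can call `import_aim()` before touching Aim objects.
#
#     aim = import_aim()          # requires `pip install aim`
#     text = aim.Text("hello")    # Aim's text object, used by the trackers below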

class BaseMetadataCallbackHandler:
    """Callback handler for the metadata and associated function states for callbacks.

    Attributes:
        step (int): The current step.
        starts (int): The number of times the start method has been called.
        ends (int): The number of times the end method has been called.
        errors (int): The number of times the error method has been called.
        text_ctr (int): The number of times the text method has been called.
        ignore_llm_ (bool): Whether to ignore llm callbacks.
        ignore_chain_ (bool): Whether to ignore chain callbacks.
        ignore_agent_ (bool): Whether to ignore agent callbacks.
        ignore_retriever_ (bool): Whether to ignore retriever callbacks.
        always_verbose_ (bool): Whether to always be verbose.
        chain_starts (int): The number of times the chain start method has been called.
        chain_ends (int): The number of times the chain end method has been called.
        llm_starts (int): The number of times the llm start method has been called.
        llm_ends (int): The number of times the llm end method has been called.
        llm_streams (int): The number of times the llm new token method has been called.
        tool_starts (int): The number of times the tool start method has been called.
        tool_ends (int): The number of times the tool end method has been called.
        agent_ends (int): The number of times the agent end method has been called.
    """
    def __init__(self) -> None:
        self.step = 0
        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0
        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.ignore_retriever_ = False
        self.always_verbose_ = False
        self.chain_starts = 0
        self.chain_ends = 0
        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0
        self.tool_starts = 0
        self.tool_ends = 0
        self.agent_ends = 0
    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return self.always_verbose_

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        return self.ignore_retriever_
    def get_custom_callback_meta(self) -> Dict[str, Any]:
        """Return a snapshot of the current callback counters."""
        return {
            "step": self.step,
            "starts": self.starts,
            "ends": self.ends,
            "errors": self.errors,
            "text_ctr": self.text_ctr,
            "chain_starts": self.chain_starts,
            "chain_ends": self.chain_ends,
            "llm_starts": self.llm_starts,
            "llm_ends": self.llm_ends,
            "llm_streams": self.llm_streams,
            "tool_starts": self.tool_starts,
            "tool_ends": self.tool_ends,
            "agent_ends": self.agent_ends,
        }
    def reset_callback_meta(self) -> None:
        """Reset the callback metadata."""
        self.step = 0
        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0
        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.always_verbose_ = False
        self.chain_starts = 0
        self.chain_ends = 0
        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0
        self.tool_starts = 0
        self.tool_ends = 0
        self.agent_ends = 0
        return None
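
# Hedged sketch of how this bookkeeping base class is meant to be used: a
# concrete handler (such as AimCallbackHandler below) bumps the counters inside
# each callback and attaches the snapshot from get_custom_callback_meta() as
# logging context. `DemoHandler` is hypothetical and not part of this module.
#
#     class DemoHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
#         def on_llm_start(self, serialized, prompts, **kwargs):
#             self.step += 1
#             self.llm_starts += 1
#             self.starts += 1
#             print(self.get_custom_callback_meta())  # {"step": 1, "llm_starts": 1, ...}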

class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Aim.

    Parameters:
        repo (:obj:`str`, optional): Aim repository path or Repo object to which
            Run object is bound. If skipped, default Repo is used.
        experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
            'default' if not specified. Can be used later to query runs/sequences.
        system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
            in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
            to disable system metrics tracking.
        log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
            params such as installed packages, git info, environment variables, etc.

    For each callback method it handles, this handler formats the callback's
    inputs together with metadata about the current state of the LLM run and
    logs the result to Aim.
    """
    def __init__(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
    ) -> None:
        """Initialize callback handler."""
        super().__init__()

        aim = import_aim()
        self.repo = repo
        self.experiment_name = experiment_name
        self.system_tracking_interval = system_tracking_interval
        self.log_system_params = log_system_params

        self._run = aim.Run(
            repo=self.repo,
            experiment=self.experiment_name,
            system_tracking_interval=self.system_tracking_interval,
            log_system_params=self.log_system_params,
        )
        self._run_hash = self._run.hash
        self.action_records: list = []
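
    # Hedged usage sketch (not executed here): attach the handler to any LLM or
    # chain that accepts `callbacks`. The model class and prompt are assumptions;
    # only AimCallbackHandler and flush_tracker come from this module.
    #
    #     from langchain_community.callbacks import AimCallbackHandler
    #
    #     aim_callback = AimCallbackHandler(repo=".", experiment_name="scenario 1")
    #     llm = SomeChatModel(callbacks=[aim_callback])   # hypothetical model class
    #     llm.invoke("Tell me a joke")
    #     aim_callback.flush_tracker(langchain_asset=llm, finish=True)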
    def setup(self, **kwargs: Any) -> None:
        """Re-open (or create) the Aim run and set extra run parameters."""
        aim = import_aim()

        if not self._run:
            if self._run_hash:
                self._run = aim.Run(
                    self._run_hash,
                    repo=self.repo,
                    system_tracking_interval=self.system_tracking_interval,
                )
            else:
                self._run = aim.Run(
                    repo=self.repo,
                    experiment=self.experiment_name,
                    system_tracking_interval=self.system_tracking_interval,
                    log_system_params=self.log_system_params,
                )
                self._run_hash = self._run.hash

        if kwargs:
            for key, value in kwargs.items():
                self._run.set(key, value, strict=False)
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        aim = import_aim()
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp = {"action": "on_llm_start"}
        resp.update(self.get_custom_callback_meta())

        prompts_res = deepcopy(prompts)

        self._run.track(
            [aim.Text(prompt) for prompt in prompts_res],
            name="on_llm_start",
            context=resp,
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        aim = import_aim()
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp = {"action": "on_llm_end"}
        resp.update(self.get_custom_callback_meta())

        response_res = deepcopy(response)

        generated = [
            aim.Text(generation.text)
            for generations in response_res.generations
            for generation in generations
        ]
        self._run.track(
            generated,
            name="on_llm_end",
            context=resp,
        )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1
    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        aim = import_aim()
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = {"action": "on_chain_start"}
        resp.update(self.get_custom_callback_meta())

        inputs_res = deepcopy(inputs)

        # Assumes the chain exposes a single "input" key.
        self._run.track(
            aim.Text(inputs_res["input"]), name="on_chain_start", context=resp
        )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        aim = import_aim()
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = {"action": "on_chain_end"}
        resp.update(self.get_custom_callback_meta())

        outputs_res = deepcopy(outputs)

        # Assumes the chain exposes a single "output" key.
        self._run.track(
            aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
        )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1
    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        aim = import_aim()
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = {"action": "on_tool_start"}
        resp.update(self.get_custom_callback_meta())

        self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        output = str(output)
        aim = import_aim()
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = {"action": "on_tool_end"}
        resp.update(self.get_custom_callback_meta())

        self._run.track(aim.Text(output), name="on_tool_end", context=resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1
    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when the callback receives arbitrary text."""
        self.step += 1
        self.text_ctr += 1

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        aim = import_aim()
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = {"action": "on_agent_finish"}
        resp.update(self.get_custom_callback_meta())

        finish_res = deepcopy(finish)
        text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
            finish_res.return_values["output"], finish_res.log
        )
        self._run.track(aim.Text(text), name="on_agent_finish", context=resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        aim = import_aim()
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = {
            "action": "on_agent_action",
            "tool": action.tool,
        }
        resp.update(self.get_custom_callback_meta())

        action_res = deepcopy(action)
        text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(
            action_res.tool_input, action_res.log
        )
        self._run.track(aim.Text(text), name="on_agent_action", context=resp)
    def flush_tracker(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
    ) -> None:
        """Flush the tracker and reset the session.

        Args:
            repo (:obj:`str`, optional): Aim repository path or Repo object to which
                Run object is bound. If skipped, default Repo is used.
            experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
                'default' if not specified. Can be used later to query runs/sequences.
            system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
                in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
                to disable system metrics tracking.
            log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
                params such as installed packages, git info, environment variables, etc.
            langchain_asset: The langchain asset to save.
            reset: Whether to reset the session.
            finish: Whether to finish the run.

        Returns:
            None
        """
        if langchain_asset:
            try:
                for key, value in langchain_asset.dict().items():
                    self._run.set(key, value, strict=False)
            except Exception:
                pass

        if finish or reset:
            self._run.close()
            self.reset_callback_meta()
        if reset:
            aim = import_aim()
            self.repo = repo if repo else self.repo
            self.experiment_name = (
                experiment_name if experiment_name else self.experiment_name
            )
            self.system_tracking_interval = (
                system_tracking_interval
                if system_tracking_interval
                else self.system_tracking_interval
            )
            self.log_system_params = (
                log_system_params if log_system_params else self.log_system_params
            )

            self._run = aim.Run(
                repo=self.repo,
                experiment=self.experiment_name,
                system_tracking_interval=self.system_tracking_interval,
                log_system_params=self.log_system_params,
            )
            self._run_hash = self._run.hash
            self.action_records = []
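
# Hedged sketch of session management with flush_tracker (module-level comment;
# the `llm` and `agent` objects are assumptions): `reset=True` closes the current
# Aim run and opens a fresh one with the possibly updated settings, while
# `finish=True` closes the run without reopening it.
#
#     aim_callback.flush_tracker(
#         langchain_asset=llm,              # logs the asset's params via .dict()
#         experiment_name="scenario 2",     # the next run goes to a new experiment
#         reset=True,
#     )
#     ...
#     aim_callback.flush_tracker(langchain_asset=agent, reset=False, finish=True)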