# Agent/streamlit_agent/callbacks/capturing_callback_handler.py
# Uploaded by Fer-geo in commit 19b1388 ("first").
"""Callback Handler captures all callbacks in a session for future offline playback."""
from __future__ import annotations
import pickle
import time
from typing import Any, TypedDict
from langchain.callbacks.base import BaseCallbackHandler
# This is intentionally not an enum so that we avoid serializing a
# custom class with pickle.
class CallbackType:
    """String identifiers for the callback events that can be recorded.

    Each value is spelled exactly like the handler method it stands for, and
    every member is a plain ``str`` so that records holding these values can
    be pickled without depending on this class.
    """

    # LLM lifecycle
    ON_LLM_START = "on_llm_start"
    ON_LLM_NEW_TOKEN = "on_llm_new_token"
    ON_LLM_END = "on_llm_end"
    ON_LLM_ERROR = "on_llm_error"

    # Tool lifecycle
    ON_TOOL_START = "on_tool_start"
    ON_TOOL_END = "on_tool_end"
    ON_TOOL_ERROR = "on_tool_error"

    # Free-form text
    ON_TEXT = "on_text"

    # Chain lifecycle
    ON_CHAIN_START = "on_chain_start"
    ON_CHAIN_END = "on_chain_end"
    ON_CHAIN_ERROR = "on_chain_error"

    # Agent events
    ON_AGENT_ACTION = "on_agent_action"
    ON_AGENT_FINISH = "on_agent_finish"
# We use TypedDict, rather than NamedTuple, so that we avoid serializing a
# custom class with pickle. All of this class's members should be basic Python types.
class CallbackRecord(TypedDict):
    """One captured callback invocation, stored as a plain ``dict`` at runtime."""

    # Identifier of the callback that fired.
    callback_type: str
    # Positional arguments the callback was invoked with.
    args: tuple[Any, ...]
    # Keyword arguments the callback was invoked with.
    kwargs: dict[str, Any]
    # Number of seconds between this record and the previous one.
    time_delta: float
def load_records_from_file(path: str) -> list[CallbackRecord]:
    """Load the list of CallbackRecords from a pickle file at the given path.

    Args:
        path: Location of a pickle file holding a list of record dicts.

    Returns:
        The unpickled list of records.

    Raises:
        RuntimeError: If the unpickled object is not a list, or if any element
            is not a dict carrying the ``callback_type``, ``args``, ``kwargs``
            and ``time_delta`` keys.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # come from a trusted source.
    with open(path, "rb") as file:
        records = pickle.load(file)

    required_keys = ("callback_type", "args", "kwargs", "time_delta")
    # Reject malformed data here, with one clear error, instead of letting it
    # surface later as a KeyError/TypeError while the records are being used.
    well_formed = isinstance(records, list) and all(
        isinstance(record, dict) and all(key in record for key in required_keys)
        for record in records
    )
    if not well_formed:
        raise RuntimeError(f"Bad CallbackRecord data in {path}")

    return records
def _agent_finish_output(record: CallbackRecord) -> str:
    """Return the "output" return value carried by an on_agent_finish record."""
    finish = record["args"][0]
    # Prefer the named `return_values` attribute so this works whether the
    # finish object is an indexable tuple or a non-indexable object; fall back
    # to positional access for plain tuples without that attribute.
    # NOTE(review): assumes the finish object exposes `return_values` in every
    # supported langchain release -- confirm against the pinned version.
    return_values = getattr(finish, "return_values", None)
    if return_values is None:
        return_values = finish[0]
    return return_values["output"]


def playback_callbacks(
    handlers: list[BaseCallbackHandler],
    records_or_filename: list[CallbackRecord] | str,
    max_pause_time: float,
) -> str:
    """Replay recorded callbacks against the given handlers.

    Args:
        handlers: Handlers that receive every replayed callback, in order.
        records_or_filename: Either the records themselves, or the path of a
            pickle file to load them from.
        max_pause_time: Upper bound, in seconds, on the pause taken before
            each record. Pauses that are zero or negative are skipped.

    Returns:
        The "output" return value of the first on_agent_finish record, or
        "[Missing Agent Result]" when no such record exists.
    """
    if isinstance(records_or_filename, list):
        records = records_or_filename
    else:
        records = load_records_from_file(records_or_filename)

    # Every CallbackType value is spelled exactly like the handler method it
    # stands for, so the method can be looked up by name. Restricting the
    # lookup to these values keeps unknown callback types ignored.
    replayable = tuple(
        value
        for name, value in vars(CallbackType).items()
        if name.startswith("ON_")
    )

    for record in records:
        # Reproduce the recorded timing, capped so long gaps do not stall playback.
        pause_time = min(record["time_delta"], max_pause_time)
        if pause_time > 0:
            time.sleep(pause_time)

        callback_type = record["callback_type"]
        if callback_type not in replayable:
            continue
        for handler in handlers:
            getattr(handler, callback_type)(*record["args"], **record["kwargs"])

    # Return the agent's result
    for record in records:
        if record["callback_type"] == CallbackType.ON_AGENT_FINISH:
            return _agent_finish_output(record)

    return "[Missing Agent Result]"
class CapturingCallbackHandler(BaseCallbackHandler):
    """Callback handler that records every callback it receives.

    Each callback is stored, in arrival order, as a CallbackRecord holding the
    callback type, the arguments it was invoked with, and the time elapsed
    since the previous callback. The records can be written to a pickle file
    with ``dump_records_to_file``.
    """

    def __init__(self) -> None:
        self._records: list[CallbackRecord] = []
        # Monotonic-clock reading taken when the previous record was appended;
        # None until the first record arrives.
        self._last_time: float | None = None

    def dump_records_to_file(self, path: str) -> None:
        """Write the list of CallbackRecords to a pickle file at the given path."""
        with open(path, "wb") as file:
            pickle.dump(self._records, file)

    def _append_record(
        self, type: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        """Append one record, stamping it with the time since the previous one."""
        # Use the monotonic clock: wall-clock time can jump backwards or
        # forwards when the system clock is adjusted, which would corrupt the
        # recorded intervals.
        time_now = time.monotonic()
        time_delta = time_now - self._last_time if self._last_time is not None else 0.0
        self._last_time = time_now
        self._records.append(
            CallbackRecord(
                callback_type=type, args=args, kwargs=kwargs, time_delta=time_delta
            )
        )

    def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_llm_start callback."""
        self._append_record(CallbackType.ON_LLM_START, args, kwargs)

    def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_llm_new_token callback."""
        self._append_record(CallbackType.ON_LLM_NEW_TOKEN, args, kwargs)

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_llm_end callback."""
        self._append_record(CallbackType.ON_LLM_END, args, kwargs)

    def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_llm_error callback."""
        self._append_record(CallbackType.ON_LLM_ERROR, args, kwargs)

    def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_tool_start callback."""
        self._append_record(CallbackType.ON_TOOL_START, args, kwargs)

    def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_tool_end callback."""
        self._append_record(CallbackType.ON_TOOL_END, args, kwargs)

    def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_tool_error callback."""
        self._append_record(CallbackType.ON_TOOL_ERROR, args, kwargs)

    def on_text(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_text callback."""
        self._append_record(CallbackType.ON_TEXT, args, kwargs)

    def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_chain_start callback."""
        self._append_record(CallbackType.ON_CHAIN_START, args, kwargs)

    def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_chain_end callback."""
        self._append_record(CallbackType.ON_CHAIN_END, args, kwargs)

    def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_chain_error callback."""
        self._append_record(CallbackType.ON_CHAIN_ERROR, args, kwargs)

    def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
        """Record an on_agent_action callback."""
        self._append_record(CallbackType.ON_AGENT_ACTION, args, kwargs)

    def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
        """Record an on_agent_finish callback."""
        self._append_record(CallbackType.ON_AGENT_FINISH, args, kwargs)