# Hugging Face Space "langchain-qa-bot" (status: Runtime error)
# File: docs/langchain/libs/community/langchain_community/callbacks/tracers/comet.py
from types import ModuleType, SimpleNamespace | |
from typing import TYPE_CHECKING, Any, Callable, Dict | |
from langchain_core.tracers import BaseTracer | |
from langchain_core.utils import guard_import | |
if TYPE_CHECKING: | |
from uuid import UUID | |
from comet_llm import Span | |
from comet_llm.chains.chain import Chain | |
from langchain_community.callbacks.tracers.schemas import Run | |
def _get_run_type(run: "Run") -> str: | |
if isinstance(run.run_type, str): | |
return run.run_type | |
elif hasattr(run.run_type, "value"): | |
return run.run_type.value | |
else: | |
return str(run.run_type) | |
def import_comet_llm_api() -> SimpleNamespace:
    """Import the ``comet_llm`` API and raise an error if it is not installed.

    Returns:
        A namespace bundling the pieces of ``comet_llm`` used by the tracer:

        - ``chain``, ``span``: the ``chain`` and ``span`` attributes of
          ``comet_llm.chains``.
        - ``chain_api``: the ``api`` attribute of ``comet_llm.chains``.
        - ``experiment_info``, ``flush``: attributes of ``comet_llm`` itself.

    Raises:
        ImportError: Raised by ``guard_import`` when ``comet_llm`` or
            ``comet_llm.chains`` cannot be imported.
    """
    comet_llm = guard_import("comet_llm")
    # ``comet_llm.chains`` ships inside the ``comet_llm`` distribution. Without
    # ``pip_name`` the install hint would say ``pip install comet_llm.chains``,
    # which is not an installable package.
    comet_llm_chains = guard_import("comet_llm.chains", pip_name="comet_llm")
    return SimpleNamespace(
        chain=comet_llm_chains.chain,
        span=comet_llm_chains.span,
        chain_api=comet_llm_chains.api,
        experiment_info=comet_llm.experiment_info,
        flush=comet_llm.flush,
    )
class CometTracer(BaseTracer):
    """Comet Tracer.

    Mirrors LangChain runs into ``comet_llm`` chains:

    - A root run (one without ``parent_run_id``) opens a Comet ``Chain``.
    - Every nested LLM, chain, tool or retriever run opens a Comet ``Span``.
      The span is started against the chain of the run's root.
    - When a nested run ends, its span receives the run outputs and is closed.
    - When the root run is persisted, its outputs are set on the chain and the
      chain is logged through the ``comet_llm.chains`` API.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Comet Tracer.

        Args:
            **kwargs: Forwarded unchanged to ``BaseTracer.__init__``.

        Raises:
            ImportError: If ``comet_llm`` cannot be imported (raised from
                ``import_comet_llm_api``).
        """
        super().__init__(**kwargs)
        # Run id -> open Comet span, for nested runs that have not ended yet.
        self._span_map: Dict["UUID", "Span"] = {}
        # Run id -> Comet chain of the run's root. Nested runs reuse their
        # root's chain so that their own children can look it up by parent id.
        self._chains_map: Dict["UUID", "Chain"] = {}
        self._initialize_comet_modules()

    def _initialize_comet_modules(self) -> None:
        """Import ``comet_llm`` and keep handles to the parts this tracer uses."""
        comet_llm_api = import_comet_llm_api()
        self._chain: ModuleType = comet_llm_api.chain
        self._span: ModuleType = comet_llm_api.span
        self._chain_api: ModuleType = comet_llm_api.chain_api
        self._experiment_info: ModuleType = comet_llm_api.experiment_info
        self._flush: Callable[[], None] = comet_llm_api.flush

    def _persist_run(self, run: "Run") -> None:
        """Attach a root run's outputs to its Comet chain and log the chain.

        LangChain calls this for root runs. The chain is removed from
        ``_chains_map`` so that finished traces are not kept alive by the
        tracer.
        """
        chain_ = self._chains_map.pop(run.id)
        run_dict: Dict[str, Any] = run.dict()
        chain_.set_outputs(outputs=run_dict["outputs"])
        self._chain_api.log_chain(chain_)

    def _process_start_trace(self, run: "Run") -> None:
        """Open a Comet chain for a root run, or a span for a nested run."""
        run_dict: Dict[str, Any] = run.dict()
        if not run.parent_run_id:
            # The first run of a trace maps to a Comet chain.
            # ``extra`` may be None, so fall back to an empty dict.
            metadata = (run_dict["extra"] or {}).get("metadata", None)
            chain_: "Chain" = self._chain.Chain(
                inputs=run_dict["inputs"],
                metadata=metadata,
                experiment_info=self._experiment_info.get(),
            )
            self._chains_map[run.id] = chain_
            return

        span: "Span" = self._span.Span(
            inputs=run_dict["inputs"],
            category=_get_run_type(run),
            metadata=run_dict["extra"],
            name=run.name,
        )
        parent_chain = self._chains_map[run.parent_run_id]
        span.__api__start__(parent_chain)
        # Propagate the root's chain so this run's own children can find it.
        self._chains_map[run.id] = parent_chain
        self._span_map[run.id] = span

    def _process_end_trace(self, run: "Run") -> None:
        """Close a nested run's span and drop the tracer's entries for the run.

        Root runs are skipped here: LangChain calls ``_persist_run`` for them,
        and that method logs the chain.
        """
        if not run.parent_run_id:
            return
        # Release both map entries so that a long-lived tracer does not retain
        # every span and chain reference it has ever seen. A run's children
        # have already started by the time it ends, so nothing else needs the
        # run's chain entry any more.
        self._chains_map.pop(run.id, None)
        span = self._span_map.pop(run.id)
        run_dict: Dict[str, Any] = run.dict()
        span.set_outputs(outputs=run_dict["outputs"])
        span.__api__end__()

    def flush(self) -> None:
        """Call ``comet_llm.flush()``."""
        self._flush()

    def _on_llm_start(self, run: "Run") -> None:
        """Process the LLM Run upon start."""
        self._process_start_trace(run)

    def _on_llm_end(self, run: "Run") -> None:
        """Process the LLM Run."""
        self._process_end_trace(run)

    def _on_llm_error(self, run: "Run") -> None:
        """Process the LLM Run upon error."""
        self._process_end_trace(run)

    def _on_chain_start(self, run: "Run") -> None:
        """Process the Chain Run upon start."""
        self._process_start_trace(run)

    def _on_chain_end(self, run: "Run") -> None:
        """Process the Chain Run."""
        self._process_end_trace(run)

    def _on_chain_error(self, run: "Run") -> None:
        """Process the Chain Run upon error."""
        self._process_end_trace(run)

    def _on_tool_start(self, run: "Run") -> None:
        """Process the Tool Run upon start."""
        self._process_start_trace(run)

    def _on_tool_end(self, run: "Run") -> None:
        """Process the Tool Run."""
        self._process_end_trace(run)

    def _on_tool_error(self, run: "Run") -> None:
        """Process the Tool Run upon error."""
        self._process_end_trace(run)

    def _on_retriever_start(self, run: "Run") -> None:
        """Process the Retriever Run upon start."""
        # Without this hook, a root retriever run has no chain for
        # ``_persist_run``, and runs nested under a retriever cannot find
        # their parent's chain.
        self._process_start_trace(run)

    def _on_retriever_end(self, run: "Run") -> None:
        """Process the Retriever Run."""
        self._process_end_trace(run)

    def _on_retriever_error(self, run: "Run") -> None:
        """Process the Retriever Run upon error."""
        self._process_end_trace(run)