Spaces:
Runtime error
Runtime error
File size: 4,615 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from types import ModuleType, SimpleNamespace
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

from langchain_core.tracers import BaseTracer
from langchain_core.utils import guard_import
if TYPE_CHECKING:
from uuid import UUID
from comet_llm import Span
from comet_llm.chains.chain import Chain
from langchain_community.callbacks.tracers.schemas import Run
def _get_run_type(run: "Run") -> str:
    """Return ``run.run_type`` normalized to a plain string.

    Strings are returned unchanged, objects exposing a ``value`` attribute
    (e.g. enum members) yield that value, and anything else is passed
    through ``str()``.
    """
    raw_type = run.run_type
    if isinstance(raw_type, str):
        return raw_type
    # Enum-like objects carry their string form in ``.value``.
    return raw_type.value if hasattr(raw_type, "value") else str(raw_type)
def import_comet_llm_api() -> SimpleNamespace:
    """Import comet_llm api and raise an error if it is not installed.

    Returns:
        A namespace exposing ``chain``, ``span``, ``chain_api``,
        ``experiment_info`` and ``flush`` taken from ``comet_llm`` and
        ``comet_llm.chains``.
    """
    comet_llm = guard_import("comet_llm")
    comet_llm_chains = guard_import("comet_llm.chains")
    # Attributes are read in the same order as they are exposed.
    api_members = {
        "chain": comet_llm_chains.chain,
        "span": comet_llm_chains.span,
        "chain_api": comet_llm_chains.api,
        "experiment_info": comet_llm.experiment_info,
        "flush": comet_llm.flush,
    }
    return SimpleNamespace(**api_members)
class CometTracer(BaseTracer):
    """Comet Tracer.

    Mirrors LangChain runs into Comet LLM traces: the root run of a trace
    becomes a ``comet_llm`` chain, and every nested LLM, chain, or tool run
    becomes a span attached to that root's chain. The chain is logged to
    Comet when LangChain persists the root run.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Comet Tracer.

        Args:
            **kwargs: Forwarded unchanged to ``BaseTracer.__init__``.

        Raises:
            ImportError: If ``comet_llm`` is not installed (raised through
                ``guard_import``).
        """
        super().__init__(**kwargs)
        # Map from run id to the span created for that nested run.
        self._span_map: Dict["UUID", "Span"] = {}
        # Map from run id to the chain the run belongs to; every tracked run
        # of a trace maps to the chain created for its root run.
        self._chains_map: Dict["UUID", "Chain"] = {}
        self._initialize_comet_modules()

    def _initialize_comet_modules(self) -> None:
        """Import ``comet_llm`` lazily and bind the pieces this tracer uses."""
        comet_llm_api = import_comet_llm_api()
        self._chain: ModuleType = comet_llm_api.chain
        self._span: ModuleType = comet_llm_api.span
        self._chain_api: ModuleType = comet_llm_api.chain_api
        self._experiment_info: ModuleType = comet_llm_api.experiment_info
        self._flush: Callable[[], None] = comet_llm_api.flush

    def _find_chain(self, parent_run_id: Optional["UUID"]) -> Optional["Chain"]:
        """Return the chain of the nearest tracked ancestor, or ``None``.

        Runs whose start hook this class does not override (anything other
        than LLM, chain, and tool runs) never get an entry in
        ``_chains_map``. For their children, keep climbing the parent links
        until a tracked ancestor is found.
        """
        while parent_run_id is not None:
            chain_ = self._chains_map.get(parent_run_id)
            if chain_ is not None:
                return chain_
            # NOTE: relies on BaseTracer.run_map, keyed by str(run id), which
            # holds every run that has started and not yet ended.
            parent_run = self.run_map.get(str(parent_run_id))
            if parent_run is None:
                return None
            parent_run_id = parent_run.parent_run_id
        return None

    def _persist_run(self, run: "Run") -> None:
        """Log the chain built for a finished root run and release it.

        Args:
            run: The finished root run; its ``outputs`` become the chain's
                outputs.
        """
        # Pop before logging so the entry is released even if logging fails.
        chain_ = self._chains_map.pop(run.id, None)
        if chain_ is None:
            # The root's start was never processed (its run type has no start
            # hook here), so there is no chain to log.
            return
        run_dict: Dict[str, Any] = run.dict()
        chain_.set_outputs(outputs=run_dict["outputs"])
        self._chain_api.log_chain(chain_)

    def _process_start_trace(self, run: "Run") -> None:
        """Create the Comet chain (root run) or span (nested run) for ``run``."""
        run_dict: Dict[str, Any] = run.dict()
        if not run.parent_run_id:
            # This is the first run, which maps to a chain.
            extra = run_dict["extra"] or {}
            chain_: "Chain" = self._chain.Chain(
                inputs=run_dict["inputs"],
                metadata=extra.get("metadata", None),
                experiment_info=self._experiment_info.get(),
            )
            self._chains_map[run.id] = chain_
            return

        parent_chain = self._find_chain(run.parent_run_id)
        if parent_chain is None:
            # No tracked ancestor, so there is no chain to attach a span to.
            return
        span: "Span" = self._span.Span(
            inputs=run_dict["inputs"],
            category=_get_run_type(run),
            metadata=run_dict["extra"],
            name=run.name,
        )
        span.__api__start__(parent_chain)
        # Descendants of this run attach to the same root chain.
        self._chains_map[run.id] = parent_chain
        self._span_map[run.id] = span

    def _process_end_trace(self, run: "Run") -> None:
        """Close the span of a finished nested run and release its entries.

        Root runs are skipped: Langchain will call ``_persist_run`` for them.
        """
        if not run.parent_run_id:
            return
        # Release this run's bookkeeping; child runs are expected to have
        # finished before their parent.
        self._chains_map.pop(run.id, None)
        span = self._span_map.pop(run.id, None)
        if span is None:
            # No span was created at start (no tracked ancestor).
            return
        run_dict: Dict[str, Any] = run.dict()
        span.set_outputs(outputs=run_dict["outputs"])
        span.__api__end__()

    def flush(self) -> None:
        """Flush pending Comet data via ``comet_llm.flush``."""
        self._flush()

    def _on_llm_start(self, run: "Run") -> None:
        """Process the LLM Run upon start."""
        self._process_start_trace(run)

    def _on_llm_end(self, run: "Run") -> None:
        """Process the LLM Run."""
        self._process_end_trace(run)

    def _on_llm_error(self, run: "Run") -> None:
        """Process the LLM Run upon error."""
        self._process_end_trace(run)

    def _on_chain_start(self, run: "Run") -> None:
        """Process the Chain Run upon start."""
        self._process_start_trace(run)

    def _on_chain_end(self, run: "Run") -> None:
        """Process the Chain Run."""
        self._process_end_trace(run)

    def _on_chain_error(self, run: "Run") -> None:
        """Process the Chain Run upon error."""
        self._process_end_trace(run)

    def _on_tool_start(self, run: "Run") -> None:
        """Process the Tool Run upon start."""
        self._process_start_trace(run)

    def _on_tool_end(self, run: "Run") -> None:
        """Process the Tool Run."""
        self._process_end_trace(run)

    def _on_tool_error(self, run: "Run") -> None:
        """Process the Tool Run upon error."""
        self._process_end_trace(run)
|