File size: 4,615 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from types import ModuleType, SimpleNamespace
from typing import TYPE_CHECKING, Any, Callable, Dict

from langchain_core.tracers import BaseTracer
from langchain_core.utils import guard_import

if TYPE_CHECKING:
    from uuid import UUID

    from comet_llm import Span
    from comet_llm.chains.chain import Chain

    from langchain_community.callbacks.tracers.schemas import Run


def _get_run_type(run: "Run") -> str:
    """Return the run's type as a plain string.

    ``run.run_type`` is returned unchanged when it is already a ``str``.
    Otherwise, if it exposes a ``value`` attribute, that attribute is
    returned; failing both, the ``str()`` form of the object is used.
    """
    kind = run.run_type
    if isinstance(kind, str):
        return kind
    if hasattr(kind, "value"):
        return kind.value
    return str(kind)


def import_comet_llm_api() -> SimpleNamespace:
    """Import comet_llm api and raise an error if it is not installed.

    Returns:
        A namespace exposing ``chain``, ``span`` and ``chain_api`` (taken from
        ``comet_llm.chains``) plus ``experiment_info`` and ``flush`` (taken
        from ``comet_llm``).
    """
    # Both lookups go through guard_import; the top-level package is
    # resolved first, then the chains sub-package.
    root_pkg = guard_import("comet_llm")
    chains_pkg = guard_import("comet_llm.chains")

    api = SimpleNamespace()
    api.chain = chains_pkg.chain
    api.span = chains_pkg.span
    api.chain_api = chains_pkg.api
    api.experiment_info = root_pkg.experiment_info
    api.flush = root_pkg.flush
    return api


class CometTracer(BaseTracer):
    """Comet Tracer.

    A run without a parent is represented by a ``comet_llm`` ``Chain``; every
    nested run becomes a ``Span`` started on the chain its parent belongs to.
    The chain is logged from ``_persist_run``.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Comet Tracer.

        Args:
            **kwargs: Forwarded unchanged to the base tracer's ``__init__``.
        """
        super().__init__(**kwargs)
        # Map from run id to the span opened for that (non-root) run.
        # Entries are removed when the run ends.
        self._span_map: Dict["UUID", "Span"] = {}
        # Map from run id to the chain that run belongs to. Nested runs point
        # at the same chain object as their parent. Entries are removed once
        # the chain has been logged.
        self._chains_map: Dict["UUID", "Chain"] = {}
        self._initialize_comet_modules()

    def _initialize_comet_modules(self) -> None:
        """Bind the comet_llm entry points returned by the import helper."""
        comet_llm_api = import_comet_llm_api()
        self._chain: ModuleType = comet_llm_api.chain
        self._span: ModuleType = comet_llm_api.span
        self._chain_api: ModuleType = comet_llm_api.chain_api
        self._experiment_info: ModuleType = comet_llm_api.experiment_info
        self._flush: Callable[[], None] = comet_llm_api.flush

    def _persist_run(self, run: "Run") -> None:
        """Attach the root run's outputs to its chain and log the chain.

        After the logging attempt, every run id that pointed at this chain
        is dropped from ``_chains_map`` so that a long-lived tracer does not
        keep finished chains alive indefinitely.

        Args:
            run: The finished root run.

        Raises:
            KeyError: If no chain is registered for ``run.id``.
        """
        run_dict: Dict[str, Any] = run.dict()
        chain_ = self._chains_map[run.id]
        try:
            chain_.set_outputs(outputs=run_dict["outputs"])
            self._chain_api.log_chain(chain_)
        finally:
            # The root run is over whether or not logging succeeded; release
            # the chain and the ids of all runs that were attached to it.
            self._chains_map = {
                run_id: known_chain
                for run_id, known_chain in self._chains_map.items()
                if known_chain is not chain_
            }

    def _process_start_trace(self, run: "Run") -> None:
        """Create the chain (root run) or span (nested run) for ``run``.

        Args:
            run: The run that has just started.

        Raises:
            KeyError: For a nested run whose parent has no chain registered.
        """
        run_dict: Dict[str, Any] = run.dict()
        if not run.parent_run_id:
            # This is the first run, which maps to a chain
            # ``extra`` may be None; treat that the same as "no metadata".
            metadata = (run_dict["extra"] or {}).get("metadata", None)

            chain_: "Chain" = self._chain.Chain(
                inputs=run_dict["inputs"],
                metadata=metadata,
                experiment_info=self._experiment_info.get(),
            )
            self._chains_map[run.id] = chain_
        else:
            parent_chain = self._chains_map[run.parent_run_id]
            span: "Span" = self._span.Span(
                inputs=run_dict["inputs"],
                category=_get_run_type(run),
                metadata=run_dict["extra"],
                name=run.name,
            )
            span.__api__start__(parent_chain)
            # Register this run under the same chain so that its own
            # children can find the chain through their parent id.
            self._chains_map[run.id] = parent_chain
            self._span_map[run.id] = span

    def _process_end_trace(self, run: "Run") -> None:
        """Close the span of a nested run; root runs are left untouched.

        Args:
            run: The run that has just ended (successfully or with an error).

        Raises:
            KeyError: For a nested run that has no span registered.
        """
        if not run.parent_run_id:
            # Langchain will call _persist_run for us
            return

        # The span is no longer needed by the tracer once it is closed, so
        # remove it from the map instead of only reading it.
        span = self._span_map.pop(run.id)
        run_dict: Dict[str, Any] = run.dict()
        span.set_outputs(outputs=run_dict["outputs"])
        span.__api__end__()

    def flush(self) -> None:
        """Call the ``flush`` function obtained from ``comet_llm``."""
        self._flush()

    def _on_llm_start(self, run: "Run") -> None:
        """Process the LLM Run upon start."""
        self._process_start_trace(run)

    def _on_llm_end(self, run: "Run") -> None:
        """Process the LLM Run."""
        self._process_end_trace(run)

    def _on_llm_error(self, run: "Run") -> None:
        """Process the LLM Run upon error."""
        self._process_end_trace(run)

    def _on_chain_start(self, run: "Run") -> None:
        """Process the Chain Run upon start."""
        self._process_start_trace(run)

    def _on_chain_end(self, run: "Run") -> None:
        """Process the Chain Run."""
        self._process_end_trace(run)

    def _on_chain_error(self, run: "Run") -> None:
        """Process the Chain Run upon error."""
        self._process_end_trace(run)

    def _on_tool_start(self, run: "Run") -> None:
        """Process the Tool Run upon start."""
        self._process_start_trace(run)

    def _on_tool_end(self, run: "Run") -> None:
        """Process the Tool Run."""
        self._process_end_trace(run)

    def _on_tool_error(self, run: "Run") -> None:
        """Process the Tool Run upon error."""
        self._process_end_trace(run)