from __future__ import annotations

from textwrap import indent
from typing import Any, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)

from langchain_experimental.pydantic_v1 import Extra
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import (
    BaseThoughtGenerationStrategy,
    ProposePromptStrategy,
)


class ToTChain(Chain):
    """
    Chain implementing the Tree of Thought (ToT).
    """

    llm: BaseLanguageModel
    """
    Language model to use. It must be set to produce different variations for
    the same prompt.
    """
    checker: ToTChecker
    """ToT Checker to use."""
    output_key: str = "response"  #: :meta private:
    k: int = 10
    """The maximum number of thought-generation rounds to attempt."""
    c: int = 3
    """The number of children to explore at each node."""
    tot_memory: ToTDFSMemory = ToTDFSMemory()
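    """Depth-first memory of the thoughts explored so far."""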
    tot_controller: ToTController = ToTController()
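    """Controller that decides which path of thoughts to expand next."""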
    tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
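    """Strategy class used to generate candidate thoughts with the LLM."""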
    verbose_llm: bool = False
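    """Whether the thought-generation LLM chain runs in verbose mode."""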

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
        """
        Create a ToTChain from a language model.

        :param llm: The language model to use.
        :param kwargs: Additional arguments to pass to the ToTChain constructor.
        """
        return cls(llm=llm, **kwargs)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
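        # Keep the controller's branching factor in sync with this chain's ``c``.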
        self.tot_controller.c = self.c

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return ["problem_description"]

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def log_thought(
        self,
        thought: Thought,
        level: int,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> None:
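        """Log a thought via the run manager, color-coded by its validity."""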
        if run_manager:
            colors = {
                ThoughtValidity.VALID_FINAL: "green",
                ThoughtValidity.VALID_INTERMEDIATE: "yellow",
                ThoughtValidity.INVALID: "red",
            }
            text = indent(f"Thought: {thought.text}\n", prefix="    " * level)
            run_manager.on_text(
                text=text, color=colors[thought.validity], verbose=self.verbose
            )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
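        """Run the ToT search loop until a final thought is found or ``k`` rounds pass."""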
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        if run_manager:
            run_manager.on_text(text="Starting the ToT solve procedure.\n")

        problem_description = inputs["problem_description"]
        checker_inputs = {"problem_description": problem_description}
        thoughts_path: tuple[str, ...] = ()
        thought_generator = self.tot_strategy_class(
            llm=self.llm, c=self.c, verbose=self.verbose_llm
        )

        level = 0
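        # Each round: generate the next thought along the current path, ask the
        # checker to classify it, then let the controller pick the path to
        # expand next.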
        for _ in range(self.k):
            level = self.tot_memory.level
            thought_text = thought_generator.next_thought(
                problem_description, thoughts_path, callbacks=_run_manager.get_child()
            )
            checker_inputs["thoughts"] = thoughts_path + (thought_text,)
            thought_validity = self.checker(
                checker_inputs, callbacks=_run_manager.get_child()
            )["validity"]
            thought = Thought(text=thought_text, validity=thought_validity)
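            # A VALID_FINAL thought solves the problem: log it and return.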
            if thought.validity == ThoughtValidity.VALID_FINAL:
                self.log_thought(thought, level, run_manager)
                return {self.output_key: thought.text}
            self.tot_memory.store(thought)
            self.log_thought(thought, level, run_manager)
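            # The controller reads memory and returns the next path to expand,
            # backtracking past invalid thoughts.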
            thoughts_path = self.tot_controller(self.tot_memory)

        return {self.output_key: "No solution found"}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        raise NotImplementedError("Async not implemented yet")

    @property
    def _chain_type(self) -> str:
        return "tot"