File size: 1,410 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

from langchain.chains.base import Chain
from langchain_core.callbacks.manager import CallbackManagerForChainRun

from langchain_experimental.tot.thought import ThoughtValidity


class ToTChecker(Chain, ABC):
    """
    Abstract checker for Tree of Thought (ToT) reasoning.

    Users must subclass this and implement :meth:`evaluate`. The
    implementation can be anything from a few hand-written rules to a
    trained neural classifier. When run as a chain, the checker reads a
    ``problem_description`` and ``thoughts`` from its inputs and emits the
    resulting :class:`ThoughtValidity` under ``output_key``.
    """

    output_key: str = "validity"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this checker consumes.

        :meta private:
        """
        expected: List[str] = ["problem_description", "thoughts"]
        return expected

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this checker produces (just ``output_key``).

        :meta private:
        """
        produced: List[str] = [self.output_key]
        return produced

    @abstractmethod
    def evaluate(
        self,
        problem_description: str,
        thoughts: Tuple[str, ...] = (),
    ) -> ThoughtValidity:
        """
        Judge the given thoughts against the problem description.

        Args:
            problem_description: The problem being solved.
            thoughts: The thoughts produced so far (presumably in the order
                they were generated — confirm against the caller).

        Returns:
            A ``ThoughtValidity`` value classifying the thoughts.
        """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, ThoughtValidity]:
        # Input keys are forwarded verbatim as keyword arguments, so each key
        # must match a parameter name of ``evaluate``.
        verdict = self.evaluate(**inputs)
        result: Dict[str, ThoughtValidity] = {}
        result[self.output_key] = verdict
        return result