Spaces:
Runtime error
Runtime error
File size: 1,410 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from langchain.chains.base import Chain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_experimental.tot.thought import ThoughtValidity
class ToTChecker(Chain, ABC):
    """
    Tree of Thought (ToT) checker.

    This is an abstract ToT checker that must be implemented by the user. You
    can implement a simple rule-based checker or a more sophisticated
    neural network based classifier.

    Subclasses implement :meth:`evaluate`. When the chain runs, the
    ``problem_description`` and ``thoughts`` inputs are forwarded to
    :meth:`evaluate`, and its result is returned under :attr:`output_key`.
    """

    output_key: str = "validity"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """The checker input keys.

        :meta private:
        """
        return ["problem_description", "thoughts"]

    @property
    def output_keys(self) -> List[str]:
        """The checker output keys.

        :meta private:
        """
        return [self.output_key]

    @abstractmethod
    def evaluate(
        self,
        problem_description: str,
        thoughts: Tuple[str, ...] = (),
    ) -> ThoughtValidity:
        """
        Evaluate the thoughts against the problem description and return
        their validity.

        Args:
            problem_description: The problem the thoughts attempt to solve.
            thoughts: The thoughts generated so far; defaults to an empty
                tuple. (Presumably ordered from first to latest along the
                current branch — confirm against the calling ToT chain.)

        Returns:
            A ``ThoughtValidity`` value classifying ``thoughts``.
        """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, ThoughtValidity]:
        """Run :meth:`evaluate` on this checker's declared inputs.

        Args:
            inputs: Chain inputs; entries named in ``input_keys`` are passed
                to :meth:`evaluate` as keyword arguments, other entries are
                ignored.
            run_manager: Callback manager for this run; not used here.

        Returns:
            A single-entry dict mapping ``output_key`` to the validity
            returned by :meth:`evaluate`.
        """
        # Forward only the declared input keys. ``inputs`` may carry extra
        # entries (e.g. memory variables, which the base Chain merges into the
        # inputs when memory is configured); splatting those into ``evaluate``
        # would raise TypeError for unexpected keyword arguments. Reading
        # ``self.input_keys`` keeps subclasses that extend the key list working,
        # and skipping absent keys lets ``evaluate``'s own defaults apply.
        evaluate_kwargs = {
            key: inputs[key] for key in self.input_keys if key in inputs
        }
        return {self.output_key: self.evaluate(**evaluate_kwargs)}
|