"""Chain for applying self-critique using the SmartGPT workflow."""
from typing import Any, Dict, List, Optional, Tuple, Type

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.schema import LLMResult, PromptValue
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import (
    AIMessagePromptTemplate,
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)

from langchain_experimental.pydantic_v1 import Extra, root_validator


class SmartLLMChain(Chain):
    """Chain for applying self-critique using the SmartGPT workflow.

    See details at https://youtu.be/wVzuvf9D9BU

    A SmartLLMChain is an LLMChain that, instead of simply passing the prompt to
    the LLM, performs these 3 steps:
    1. Ideate: Pass the user prompt to an ideation LLM n_ideas times;
       each result is an "idea".
    2. Critique: Pass the ideas to a critique LLM, which looks for flaws in the
       ideas and picks the best one.
    3. Resolve: Pass the critique to a resolver LLM, which improves upon the best
       idea and outputs only the (improved version of the) best answer.

    In total, one SmartLLMChain pass uses n_ideas + 2 LLM calls.

    Note that SmartLLMChain only improves results (compared to a basic LLMChain)
    when the underlying models are capable of reflection, which smaller models
    often are not.

    Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly
    1 result.
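
    Example:
        A minimal usage sketch (``ChatOpenAI`` from the ``langchain_openai``
        package is only an illustration here; any chat model capable of
        reflection can be used, and ``invoke`` follows the standard Chain
        interface):

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate
            from langchain_experimental.smart_llm import SmartLLMChain
            from langchain_openai import ChatOpenAI

            prompt = PromptTemplate.from_template(
                "Answer the following question: {question}"
            )
            chain = SmartLLMChain(llm=ChatOpenAI(), prompt=prompt, n_ideas=3)
            result = chain.invoke({"question": "What is the capital of France?"})
            print(result["resolution"])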
    """

    class SmartLLMChainHistory:
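        """Scratchpad shared between the ideation, critique and resolve steps
        of a single run: it stores the question, the generated ideas and the
        critique so that later prompts can reference them."""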
        question: str = ""
        ideas: List[str] = []
        critique: str = ""

        @property
        def n_ideas(self) -> int:
            return len(self.ideas)

        def ideation_prompt_inputs(self) -> Dict[str, Any]:
            return {"question": self.question}

        def critique_prompt_inputs(self) -> Dict[str, Any]:
            return {
                "question": self.question,
                **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
            }

        def resolve_prompt_inputs(self) -> Dict[str, Any]:
            return {
                "question": self.question,
                **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
                "critique": self.critique,
            }

    prompt: BasePromptTemplate
    """Prompt object to use."""
    output_key: str = "resolution"
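    """Key under which the final answer (resolution) is returned."""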
    ideation_llm: Optional[BaseLanguageModel] = None
    """LLM to use in the ideation step. If None is given, 'llm' will be used."""
    critique_llm: Optional[BaseLanguageModel] = None
    """LLM to use in the critique step. If None is given, 'llm' will be used."""
    resolver_llm: Optional[BaseLanguageModel] = None
    """LLM to use in the resolve step. If None is given, 'llm' will be used."""
    llm: Optional[BaseLanguageModel] = None
    """LLM to use for each step, if no step-specific llm is given."""
    n_ideas: int = 3
    """Number of ideas to generate in the ideation step."""
    return_intermediate_steps: bool = False
    """Whether to return ideas and critique, in addition to resolution."""
    history: SmartLLMChainHistory = SmartLLMChainHistory()
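    """Intermediate results (question, ideas, critique) from the most recent run."""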

    class Config:
        extra = Extra.forbid

    # TODO: move away from `root_validator` since it is deprecated in pydantic v2
    #       and causes mypy type-checking failures (hence the `type: ignore`)
    @root_validator  # type: ignore[call-overload]
    @classmethod
    def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure we have an LLM for each step."""
        llm = values.get("llm")
        ideation_llm = values.get("ideation_llm")
        critique_llm = values.get("critique_llm")
        resolver_llm = values.get("resolver_llm")

        if not llm and not ideation_llm:
            raise ValueError(
                "Either ideation_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if not llm and not critique_llm:
            raise ValueError(
                "Either critique_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if not llm and not resolver_llm:
            raise ValueError(
                "Either resolve_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if llm and ideation_llm and critique_llm and resolver_llm:
            raise ValueError(
                "LLMs are given for each step (ideation_llm, critique_llm,"
                " resolver_llm), but backup LLM (llm) is also given, which"
                " would not be used."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Defines the input keys."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Defines the output keys."""
        if self.return_intermediate_steps:
            return ["ideas", "critique", self.output_key]
        return [self.output_key]

    def prep_prompts(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[PromptValue, Optional[List[str]]]:
        """Prepare prompts from inputs."""
        stop = None
        if "stop" in inputs:
            stop = inputs["stop"]
        selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**selected_inputs)
        _colored_text = get_colored_text(prompt.to_string(), "green")
        _text = "Prompt after formatting:\n" + _colored_text
        if run_manager:
            run_manager.on_text(_text, end="\n", verbose=self.verbose)
        if "stop" in inputs and inputs["stop"] != stop:
            raise ValueError(
                "If `stop` is present in any inputs, should be present in all."
            )
        return prompt, stop

    def _call(
        self,
        input_list: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
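        """Run the ideate -> critique -> resolve pipeline on a single input."""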
        prompt, stop = self.prep_prompts(input_list, run_manager=run_manager)
        self.history.question = prompt.to_string()
        ideas = self._ideate(stop, run_manager)
        self.history.ideas = ideas
        critique = self._critique(stop, run_manager)
        self.history.critique = critique
        resolution = self._resolve(stop, run_manager)
        if self.return_intermediate_steps:
            return {"ideas": ideas, "critique": critique, self.output_key: resolution}
        return {self.output_key: resolution}

    def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
        """Between steps, only the LLM result text is passed, not the LLMResult object.
        This function extracts the text from an LLMResult."""
        if len(result.generations) != 1:
            raise ValueError(
                f"In SmartLLM the LLM result in step {step} is not "
                "exactly 1 element. This should never happen"
            )
        if len(result.generations[0]) != 1:
            raise ValueError(
                f"In SmartLLM the LLM in step {step} returned more than "
                "1 output. SmartLLM only works with LLMs returning "
                "exactly 1 output."
            )
        return result.generations[0][0].text

    def get_prompt_strings(
        self, stage: str
    ) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]:
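        """Assemble the chat messages used for the given stage.

        Prompts are built cumulatively: the critique prompt contains the
        ideation messages plus the generated ideas, and the resolve prompt
        additionally contains the critique.
        """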
        role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = []
        role_strings.append(
            (
                HumanMessagePromptTemplate,
                "Question: {question}\nAnswer: Let's work this out in a step by "
                "step way to be sure we have the right answer:",
            )
        )
        if stage == "ideation":
            return role_strings
        role_strings.extend(
            [
                *[
                    (
                        AIMessagePromptTemplate,
                        "Idea " + str(i + 1) + ": {idea_" + str(i + 1) + "}",
                    )
                    for i in range(self.n_ideas)
                ],
                (
                    HumanMessagePromptTemplate,
                    "You are a researcher tasked with investigating the "
                    f"{self.n_ideas} response options provided. List the flaws and "
                    "faulty logic of each answer option. Let's work this out in a step"
                    " by step way to be sure we have all the errors:",
                ),
            ]
        )
        if stage == "critique":
            return role_strings
        role_strings.extend(
            [
                (AIMessagePromptTemplate, "Critique: {critique}"),
                (
                    HumanMessagePromptTemplate,
                    "You are a resolver tasked with 1) finding which of "
                    f"the {self.n_ideas} answer options the researcher thought was  "
                    "best, 2) improving that answer and 3) printing the answer in "
                    "full. Don't output anything for step 1 or 2, only the full "
                    "answer in 3. Let's work this out in a step by step way to "
                    "be sure we have the right answer:",
                ),
            ]
        )
        if stage == "resolve":
            return role_strings
        raise ValueError(
            "stage should be either 'ideation', 'critique' or 'resolve',"
            f" but it is '{stage}'. This should never happen."
        )

    def ideation_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation"))

    def critique_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique"))

    def resolve_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve"))

    def _ideate(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[str]:
        """Generate n_ideas ideas as response to user prompt."""
        llm = self.ideation_llm if self.ideation_llm else self.llm
        prompt = self.ideation_prompt().format_prompt(
            **self.history.ideation_prompt_inputs()
        )
        callbacks = run_manager.get_child() if run_manager else None
        if llm:
            ideas = [
                self._get_text_from_llm_result(
                    llm.generate_prompt([prompt], stop, callbacks),
                    step="ideate",
                )
                for _ in range(self.n_ideas)
            ]
            for i, idea in enumerate(ideas):
                _colored_text = get_colored_text(idea, "blue")
                _text = f"Idea {i+1}:\n" + _colored_text
                if run_manager:
                    run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return ideas
        else:
            raise ValueError("llm is none, which should never happen")

    def _critique(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> str:
        """Critique each of the ideas from ideation stage & select best one."""
        llm = self.critique_llm if self.critique_llm else self.llm
        prompt = self.critique_prompt().format_prompt(
            **self.history.critique_prompt_inputs()
        )
        callbacks = run_manager.get_child() if run_manager else None
        if llm:
            critique = self._get_text_from_llm_result(
                llm.generate_prompt([prompt], stop, callbacks), step="critique"
            )
            _colored_text = get_colored_text(critique, "yellow")
            _text = "Critique:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return critique
        else:
            raise ValueError("llm is none, which should never happen")

    def _resolve(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> str:
        """Improve upon the best idea as chosen in critique step & return it."""
        llm = self.resolver_llm if self.resolver_llm else self.llm
        prompt = self.resolve_prompt().format_prompt(
            **self.history.resolve_prompt_inputs()
        )
        callbacks = run_manager.get_child() if run_manager else None
        if llm:
            resolution = self._get_text_from_llm_result(
                llm.generate_prompt([prompt], stop, callbacks), step="resolve"
            )
            _colored_text = get_colored_text(resolution, "green")
            _text = "Resolution:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return resolution
        else:
            raise ValueError("llm is none, which should never happen")