"""Chain for applying self-critique using the SmartGPT workflow."""
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.schema import LLMResult, PromptValue
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain_experimental.pydantic_v1 import Extra, root_validator


class SmartLLMChain(Chain):
    """Chain for applying self-critique using the SmartGPT workflow.

    See details at https://youtu.be/wVzuvf9D9BU

    A SmartLLMChain is an LLMChain that, instead of simply passing the prompt to
    the LLM, performs these 3 steps:
    1. Ideate: Pass the user prompt to an ideation LLM n_ideas times;
       each result is an "idea".
    2. Critique: Pass the ideas to a critique LLM, which looks for flaws in the
       ideas and picks the best one.
    3. Resolve: Pass the critique to a resolver LLM, which improves upon the best
       idea and outputs only the (improved version of) the best answer.

    In total, a SmartLLMChain pass uses n_ideas + 2 LLM calls.

    Note that a SmartLLMChain will only improve results (compared to a basic
    LLMChain) when the underlying models have the capability for reflection,
    which smaller models often don't.

    Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly 1 result.
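
    Example (a minimal usage sketch, not taken from the library docs; assumes
    ``langchain_openai`` is installed, and the model name and question are
    purely illustrative):
        .. code-block:: python

            from langchain.prompts import PromptTemplate
            from langchain_openai import ChatOpenAI
            from langchain_experimental.smart_llm import SmartLLMChain

            prompt = PromptTemplate.from_template(
                "I have a 12 liter jug and a 6 liter jug. "
                "How do I measure exactly 6 liters?"
            )
            # Any sufficiently capable chat model should work here.
            chain = SmartLLMChain(
                llm=ChatOpenAI(temperature=0, model="gpt-4"),
                prompt=prompt,
                n_ideas=3,
            )
            result = chain.invoke({})  # the output dict includes "resolution"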
"""
class SmartLLMChainHistory:
question: str = ""
ideas: List[str] = []
critique: str = ""
@property
def n_ideas(self) -> int:
return len(self.ideas)
def ideation_prompt_inputs(self) -> Dict[str, Any]:
return {"question": self.question}
def critique_prompt_inputs(self) -> Dict[str, Any]:
return {
"question": self.question,
**{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
}
def resolve_prompt_inputs(self) -> Dict[str, Any]:
return {
"question": self.question,
**{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
"critique": self.critique,
}

    prompt: BasePromptTemplate
    """Prompt object to use."""
    output_key: str = "resolution"
    ideation_llm: Optional[BaseLanguageModel] = None
    """LLM to use in ideation step. If None given, 'llm' will be used."""
    critique_llm: Optional[BaseLanguageModel] = None
    """LLM to use in critique step. If None given, 'llm' will be used."""
    resolver_llm: Optional[BaseLanguageModel] = None
    """LLM to use in resolve step. If None given, 'llm' will be used."""
    llm: Optional[BaseLanguageModel] = None
    """LLM to use for each step, if no step-specific LLM is given."""
    n_ideas: int = 3
    """Number of ideas to generate in the ideation step."""
    return_intermediate_steps: bool = False
    """Whether to return ideas and critique, in addition to resolution."""
    history: SmartLLMChainHistory = SmartLLMChainHistory()

    class Config:
        extra = Extra.forbid

    # TODO: move away from `root_validator` since it is deprecated in pydantic v2
    # and causes mypy type-checking failures (hence the `type: ignore`)
    @root_validator  # type: ignore[call-overload]
    @classmethod
    def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure we have an LLM for each step."""
        llm = values.get("llm")
        ideation_llm = values.get("ideation_llm")
        critique_llm = values.get("critique_llm")
        resolver_llm = values.get("resolver_llm")
        if not llm and not ideation_llm:
            raise ValueError(
                "Either ideation_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if not llm and not critique_llm:
            raise ValueError(
                "Either critique_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if not llm and not resolver_llm:
            raise ValueError(
                "Either resolver_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if llm and ideation_llm and critique_llm and resolver_llm:
            raise ValueError(
                "LLMs are given for each step (ideation_llm, critique_llm,"
                " resolver_llm), but a backup LLM (llm) is also given, which"
                " would not be used."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Defines the input keys."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Defines the output keys."""
        if self.return_intermediate_steps:
            return ["ideas", "critique", self.output_key]
        return [self.output_key]

    def prep_prompts(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[PromptValue, Optional[List[str]]]:
        """Prepare prompts from inputs."""
        stop = None
        if "stop" in inputs:
            stop = inputs["stop"]
        selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**selected_inputs)
        _colored_text = get_colored_text(prompt.to_string(), "green")
        _text = "Prompt after formatting:\n" + _colored_text
        if run_manager:
            run_manager.on_text(_text, end="\n", verbose=self.verbose)
        if "stop" in inputs and inputs["stop"] != stop:
            raise ValueError(
                "If `stop` is present in any inputs, it should be present in all."
            )
        return prompt, stop

    def _call(
        self,
        input_list: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        prompt, stop = self.prep_prompts(input_list, run_manager=run_manager)
        self.history.question = prompt.to_string()
        ideas = self._ideate(stop, run_manager)
        self.history.ideas = ideas
        critique = self._critique(stop, run_manager)
        self.history.critique = critique
        resolution = self._resolve(stop, run_manager)
        if self.return_intermediate_steps:
            return {"ideas": ideas, "critique": critique, self.output_key: resolution}
        return {self.output_key: resolution}

    def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
        """Between steps, only the LLM result text is passed, not the LLMResult object.

        This function extracts the text from an LLMResult."""
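        # Expected shape (an assumption about how this chain calls the LLM: one
        # prompt in, one generation out), e.g.:
        #   LLMResult(generations=[[Generation(text="...")]])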
        if len(result.generations) != 1:
            raise ValueError(
                f"In SmartLLM the LLM result in step {step} does not contain "
                "exactly 1 element. This should never happen."
            )
        if len(result.generations[0]) != 1:
            raise ValueError(
                f"In SmartLLM the LLM in step {step} returned more than "
                "1 output. SmartLLM only works with LLMs returning "
                "exactly 1 output."
            )
        return result.generations[0][0].text

    def get_prompt_strings(
        self, stage: str
    ) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]:
        role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = []
        role_strings.append(
            (
                HumanMessagePromptTemplate,
                "Question: {question}\nAnswer: Let's work this out in a step by "
                "step way to be sure we have the right answer:",
            )
        )
        if stage == "ideation":
            return role_strings
        role_strings.extend(
            [
                *[
                    (
                        AIMessagePromptTemplate,
                        "Idea " + str(i + 1) + ": {idea_" + str(i + 1) + "}",
                    )
                    for i in range(self.n_ideas)
                ],
                (
                    HumanMessagePromptTemplate,
                    "You are a researcher tasked with investigating the "
                    f"{self.n_ideas} response options provided. List the flaws and "
                    "faulty logic of each answer option. Let's work this out in a step"
                    " by step way to be sure we have all the errors:",
                ),
            ]
        )
        if stage == "critique":
            return role_strings
        role_strings.extend(
            [
                (AIMessagePromptTemplate, "Critique: {critique}"),
                (
                    HumanMessagePromptTemplate,
                    "You are a resolver tasked with 1) finding which of "
                    f"the {self.n_ideas} answer options the researcher thought was "
                    "best, 2) improving that answer and 3) printing the answer in "
                    "full. Don't output anything for step 1 or 2, only the full "
                    "answer in 3. Let's work this out in a step by step way to "
                    "be sure we have the right answer:",
                ),
            ]
        )
        if stage == "resolve":
            return role_strings
        raise ValueError(
            "stage should be either 'ideation', 'critique' or 'resolve',"
            f" but it is '{stage}'. This should never happen."
        )
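
    # For reference, the message sequence built above for each stage is roughly:
    #   ideation: Human(question)
    #   critique: Human(question), AI(idea_1 .. idea_n), Human(critique instructions)
    #   resolve:  Human(question), AI(idea_1 .. idea_n), Human(critique instructions),
    #             AI(critique), Human(resolve instructions)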

    def ideation_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation"))

    def critique_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique"))

    def resolve_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve"))

    def _ideate(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[str]:
        """Generate n_ideas ideas in response to the user prompt."""
        llm = self.ideation_llm if self.ideation_llm else self.llm
        prompt = self.ideation_prompt().format_prompt(
            **self.history.ideation_prompt_inputs()
        )
        callbacks = run_manager.get_child() if run_manager else None
        if llm:
            ideas = [
                self._get_text_from_llm_result(
                    llm.generate_prompt([prompt], stop, callbacks),
                    step="ideate",
                )
                for _ in range(self.n_ideas)
            ]
            for i, idea in enumerate(ideas):
                _colored_text = get_colored_text(idea, "blue")
                _text = f"Idea {i+1}:\n" + _colored_text
                if run_manager:
                    run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return ideas
        else:
            raise ValueError("llm is None, which should never happen")

    def _critique(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> str:
        """Critique each of the ideas from the ideation stage & select the best one."""
        llm = self.critique_llm if self.critique_llm else self.llm
        prompt = self.critique_prompt().format_prompt(
            **self.history.critique_prompt_inputs()
        )
        callbacks = run_manager.handlers if run_manager else None
        if llm:
            critique = self._get_text_from_llm_result(
                llm.generate_prompt([prompt], stop, callbacks), step="critique"
            )
            _colored_text = get_colored_text(critique, "yellow")
            _text = "Critique:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return critique
        else:
            raise ValueError("llm is None, which should never happen")

    def _resolve(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> str:
        """Improve upon the best idea as chosen in the critique step & return it."""
        llm = self.resolver_llm if self.resolver_llm else self.llm
        prompt = self.resolve_prompt().format_prompt(
            **self.history.resolve_prompt_inputs()
        )
        callbacks = run_manager.handlers if run_manager else None
        if llm:
            resolution = self._get_text_from_llm_result(
                llm.generate_prompt([prompt], stop, callbacks), step="resolve"
            )
            _colored_text = get_colored_text(resolution, "green")
            _text = "Resolution:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return resolution
        else:
            raise ValueError("llm is None, which should never happen")