tikendraw commited on
Commit
507dfcb
·
1 Parent(s): 7161273

cot or not

Browse files
Files changed (1) hide show
  1. app/utils.py +93 -24
app/utils.py CHANGED
@@ -1,12 +1,15 @@
1
  import json
2
  import re
 
 
3
  from typing import Generator
4
  from textwrap import dedent
5
  from litellm.types.utils import ModelResponse
6
  from pydantic import ValidationError
7
  from core.llms.base_llm import BaseLLM
8
- from core.types import ThoughtSteps
9
- from core.prompts.cot import REVIEW_PROMPT, SYSTEM_PROMPT ,FINAL_ANSWER_PROMPT
 
10
  import os
11
  import time
12
  from core.utils import parse_with_fallback
@@ -16,40 +19,106 @@ from core.llms.litellm_llm import LLM
16
  from core.llms.utils import user_message_with_images
17
  from PIL import Image
18
  from streamlit.runtime.uploaded_file_manager import UploadedFile
 
19
 
20
 
21
 
22
 
23
- def generate_answer(messages: list[dict], max_steps: int = 20, llm: BaseLLM = None, sleeptime: float = 0.0, force_max_steps: bool = False, **kwargs):
24
- thoughts = []
25
 
26
- for i in range(max_steps):
27
- raw_response = llm.chat(messages, **kwargs)
28
- response = raw_response.choices[0].message.content
29
- thought = response_parser(response)
30
-
31
- print(colored(f"{i+1} - {response}", 'yellow'))
 
 
 
 
 
 
 
 
 
 
32
 
33
- thoughts.append(thought)
34
- messages.append({"role": "assistant", "content": thought.model_dump_json()})
35
- yield thought
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- if thought.is_final_answer and not thought.next_step and not force_max_steps:
38
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- messages.append({"role": "user", "content": REVIEW_PROMPT})
 
41
 
42
- time.sleep(sleeptime)
43
 
44
- # Get the final answer after all thoughts are processed
45
- messages += [{"role": "user", "content": FINAL_ANSWER_PROMPT}]
46
- raw_final_answers = llm.chat(messages=messages, **kwargs)
47
- final_answer = raw_final_answers.choices[0].message.content
 
48
 
49
- print(colored(f"final answer - {final_answer}", 'green'))
 
 
 
 
50
 
51
- final_thought = response_parser(final_answer)
52
- yield final_thought
53
 
54
  def response_parser(response:str) -> ThoughtSteps:
55
  if isinstance(response, str):
 
1
  import json
2
  import re
3
+ import sys
4
+ from turtle import color
5
  from typing import Generator
6
  from textwrap import dedent
7
  from litellm.types.utils import ModelResponse
8
  from pydantic import ValidationError
9
  from core.llms.base_llm import BaseLLM
10
+ from core.prompts import cot
11
+ from core.types import ThoughtSteps, ThoughtStepsDisplay
12
+ from core.prompts import REVIEW_PROMPT, SYSTEM_PROMPT ,FINAL_ANSWER_PROMPT, HELPFUL_ASSISTANT_PROMPT
13
  import os
14
  import time
15
  from core.utils import parse_with_fallback
 
19
  from core.llms.utils import user_message_with_images
20
  from PIL import Image
21
  from streamlit.runtime.uploaded_file_manager import UploadedFile
22
+ from core.prompts.decision_prompt import COT_OR_DA_PROMPT, COTorDAPromptOutput, Decision
23
 
24
 
25
 
26
 
27
+ def cot_or_da_func(problem: str, llm: BaseLLM = None, **kwargs) -> COTorDAPromptOutput:
 
28
 
29
+ cot_decision_message = [
30
+ {"role": "system", "content": COT_OR_DA_PROMPT},
31
+ {"role": "user", "content": problem}]
32
+
33
+ raw_decision_response = llm.chat(messages=cot_decision_message, **kwargs)
34
+ print(colored(f"Decision Response: {raw_decision_response.choices[0].message.content}", 'blue', 'on_black'))
35
+ decision_response = raw_decision_response.choices[0].message.content
36
+
37
+ try:
38
+ decision = json.loads(decision_response)
39
+ cot_or_da = COTorDAPromptOutput(**decision)
40
+ except (json.JSONDecodeError, ValidationError, KeyError):
41
+ print(colored("Error parsing LLM's CoT decision. Defaulting to Chain of thought.", 'red'))
42
+ cot_or_da = COTorDAPromptOutput(problem=problem, decision="Chain-of-Thought", reasoning="Defaulting to Chain-of-Thought")
43
+
44
+ return cot_or_da
45
 
46
+
47
+ def get_system_prompt(decision: Decision) -> str:
48
+ if decision == Decision.CHAIN_OF_THOUGHT:
49
+ return cot.SYSTEM_PROMPT
50
+ elif decision == Decision.DIRECT_ANSWER:
51
+ return HELPFUL_ASSISTANT_PROMPT
52
+ else:
53
+ raise ValueError(f"Invalid decision: {decision}")
54
+
55
+ def set_system_message(messages: list[dict], cot_or_da: COTorDAPromptOutput) -> list[dict]:
56
+ system_prompt = get_system_prompt(cot_or_da.decision)
57
+ #check if any system message already exists
58
+ if any(message['role'] == 'system' for message in messages):
59
+ for i, message in enumerate(messages):
60
+ if message['role'] == 'system':
61
+ messages[i]['content'] = system_prompt
62
+ else:
63
+ # add a dict at the beginning of the list
64
+ messages.insert(0, {"role": "system", "content": system_prompt})
65
+ return messages
66
+
67
+
68
+ def generate_answer(messages: list[dict], max_steps: int = 20, llm: BaseLLM = None, sleeptime: float = 0.0, force_max_steps: bool = False, **kwargs) -> Generator[ThoughtStepsDisplay, None, None]:
69
+
70
+ user_message = messages[-1]['content']
71
+ cot_or_da = cot_or_da_func(user_message, llm=llm, **kwargs)
72
+ print(colored(f"LLM Decision: {cot_or_da.decision} - Justification: {cot_or_da.reasoning}", 'magenta'))
73
+
74
+ MESSAGES = set_system_message(messages, cot_or_da)
75
+
76
+
77
+ if cot_or_da.decision == Decision.CHAIN_OF_THOUGHT:
78
 
79
+ print(colored(f" {MESSAGES}", 'red'))
80
+ for i in range(max_steps):
81
+ print(i)
82
+ raw_response = llm.chat(messages=MESSAGES, **kwargs)
83
+ print(colored(f"{i+1} - {raw_response.choices[0].message.content}", 'blue', 'on_black'))
84
+ response = raw_response.choices[0].message.content
85
+ thought = response_parser(response)
86
+
87
+ print(colored(f"{i+1} - {response}", 'yellow'))
88
+
89
+ MESSAGES.append({"role": "assistant", "content": thought.model_dump_json()})
90
+
91
+ yield thought.to_thought_steps_display()
92
+
93
+ if thought.is_final_answer and not thought.next_step and not force_max_steps:
94
+ break
95
+
96
+ MESSAGES.append({"role": "user", "content": cot.REVIEW_PROMPT})
97
+
98
+ time.sleep(sleeptime)
99
+
100
+ # Get the final answer after all thoughts are processed
101
+ MESSAGES += [{"role": "user", "content": cot.FINAL_ANSWER_PROMPT}]
102
 
103
+ raw_final_answers = llm.chat(messages=MESSAGES, **kwargs)
104
+ final_answer = raw_final_answers.choices[0].message.content
105
 
106
+ print(colored(f"final answer - {final_answer}", 'green'))
107
 
108
+ final_thought = response_parser(final_answer)
109
+
110
+ yield final_thought.to_thought_steps_display()
111
+
112
+ else:
113
 
114
+ raw_response = llm.chat(messages=MESSAGES, **kwargs) #
115
+ response = raw_response.choices[0].message.content
116
+ thought = response_parser(response)
117
+
118
+ print(colored(f"Direct Answer - {response}", 'blue'))
119
 
120
+ yield thought.to_thought_steps_display()
121
+
122
 
123
  def response_parser(response:str) -> ThoughtSteps:
124
  if isinstance(response, str):