tikendraw committed on
Commit
867bdf1
·
1 Parent(s): 662f11c

gets default prompt from cot or da func

Browse files
Files changed (1) hide show
  1. app/utils.py +17 -11
app/utils.py CHANGED
@@ -19,15 +19,17 @@ from core.llms.litellm_llm import LLM
19
  from core.llms.utils import user_message_with_images
20
  from PIL import Image
21
  from streamlit.runtime.uploaded_file_manager import UploadedFile
22
- from core.prompts.decision_prompt import COT_OR_DA_PROMPT, COTorDAPromptOutput, Decision
 
23
 
24
 
25
 
26
 
 
27
  def cot_or_da_func(problem: str, llm: BaseLLM = None, **kwargs) -> COTorDAPromptOutput:
28
 
29
  cot_decision_message = [
30
- {"role": "system", "content": COT_OR_DA_PROMPT},
31
  {"role": "user", "content": problem}]
32
 
33
  raw_decision_response = llm.chat(messages=cot_decision_message, **kwargs)
@@ -37,10 +39,9 @@ def cot_or_da_func(problem: str, llm: BaseLLM = None, **kwargs) -> COTorDAPrompt
37
  try:
38
  decision = json.loads(decision_response)
39
  cot_or_da = COTorDAPromptOutput(**decision)
40
- except (json.JSONDecodeError, ValidationError, KeyError):
41
- print(colored("Error parsing LLM's CoT decision. Defaulting to Chain of thought.", 'red'))
42
- cot_or_da = COTorDAPromptOutput(problem=problem, decision="Chain-of-Thought", reasoning="Defaulting to Chain-of-Thought")
43
-
44
  return cot_or_da
45
 
46
 
@@ -52,8 +53,7 @@ def get_system_prompt(decision: Decision) -> str:
52
  else:
53
  raise ValueError(f"Invalid decision: {decision}")
54
 
55
- def set_system_message(messages: list[dict], cot_or_da: COTorDAPromptOutput) -> list[dict]:
56
- system_prompt = get_system_prompt(cot_or_da.decision)
57
  #check if any system message already exists
58
  if any(message['role'] == 'system' for message in messages):
59
  for i, message in enumerate(messages):
@@ -71,7 +71,13 @@ def generate_answer(messages: list[dict], max_steps: int = 20, llm: BaseLLM = No
71
  cot_or_da = cot_or_da_func(user_message, llm=llm, **kwargs)
72
  print(colored(f"LLM Decision: {cot_or_da.decision} - Justification: {cot_or_da.reasoning}", 'magenta'))
73
 
74
- MESSAGES = set_system_message(messages, cot_or_da)
 
 
 
 
 
 
75
 
76
 
77
  if cot_or_da.decision == Decision.CHAIN_OF_THOUGHT:
@@ -93,12 +99,12 @@ def generate_answer(messages: list[dict], max_steps: int = 20, llm: BaseLLM = No
93
  if thought.is_final_answer and not thought.next_step and not force_max_steps:
94
  break
95
 
96
- MESSAGES.append({"role": "user", "content": cot.REVIEW_PROMPT})
97
 
98
  time.sleep(sleeptime)
99
 
100
  # Get the final answer after all thoughts are processed
101
- MESSAGES += [{"role": "user", "content": cot.FINAL_ANSWER_PROMPT}]
102
 
103
  raw_final_answers = llm.chat(messages=MESSAGES, **kwargs)
104
  final_answer = raw_final_answers.choices[0].message.content
 
19
  from core.llms.utils import user_message_with_images
20
  from PIL import Image
21
  from streamlit.runtime.uploaded_file_manager import UploadedFile
22
+ from core.prompts.decision_prompt import PLAN_SYSTEM_PROMPT, COTorDAPromptOutput, Decision
23
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
24
 
25
 
26
 
27
 
28
+ @retry(stop=stop_after_attempt(3))
29
  def cot_or_da_func(problem: str, llm: BaseLLM = None, **kwargs) -> COTorDAPromptOutput:
30
 
31
  cot_decision_message = [
32
+ {"role": "system", "content": PLAN_SYSTEM_PROMPT},
33
  {"role": "user", "content": problem}]
34
 
35
  raw_decision_response = llm.chat(messages=cot_decision_message, **kwargs)
 
39
  try:
40
  decision = json.loads(decision_response)
41
  cot_or_da = COTorDAPromptOutput(**decision)
42
+ except (json.JSONDecodeError, ValidationError, KeyError) as e:
43
+ raise e
44
+
 
45
  return cot_or_da
46
 
47
 
 
53
  else:
54
  raise ValueError(f"Invalid decision: {decision}")
55
 
56
+ def set_system_message(messages: list[dict], system_prompt: str) -> list[dict]:
 
57
  #check if any system message already exists
58
  if any(message['role'] == 'system' for message in messages):
59
  for i, message in enumerate(messages):
 
71
  cot_or_da = cot_or_da_func(user_message, llm=llm, **kwargs)
72
  print(colored(f"LLM Decision: {cot_or_da.decision} - Justification: {cot_or_da.reasoning}", 'magenta'))
73
 
74
+ system_prompt, review_prompt, final_answer_prompt = cot_or_da.prompts.system_prompt, cot_or_da.prompts.review_prompt, cot_or_da.prompts.final_answer_prompt
75
+
76
+ system_prompt += f" , {cot.SYSTEM_PROMPT_EXAMPLE_JSON}"
77
+ review_prompt += f" , {cot.REVIEW_PROMPT_EXAMPLE_JSON}"
78
+ final_answer_prompt += f" , {cot.FINAL_ANSWER_PROMPT}"
79
+
80
+ MESSAGES = set_system_message(messages, system_prompt)
81
 
82
 
83
  if cot_or_da.decision == Decision.CHAIN_OF_THOUGHT:
 
99
  if thought.is_final_answer and not thought.next_step and not force_max_steps:
100
  break
101
 
102
+ MESSAGES.append({"role": "user", "content": f"{thought.critic} {review_prompt} {cot.REVIEW_PROMPT_EXAMPLE_JSON}"})
103
 
104
  time.sleep(sleeptime)
105
 
106
  # Get the final answer after all thoughts are processed
107
+ MESSAGES += [{"role": "user", "content": f"{final_answer_prompt} {cot.FINAL_ANSWER_PROMPT}"}]
108
 
109
  raw_final_answers = llm.chat(messages=MESSAGES, **kwargs)
110
  final_answer = raw_final_answers.choices[0].message.content