Spaces:
Sleeping
Sleeping
File size: 8,201 Bytes
2224132 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# main class chaining Planner, Worker and Solver.
import re
import time
from nodes.Planner import Planner
from nodes.Solver import Solver
from nodes.Worker import *
from utils.util import *
class PWS:
    """Pipeline chaining a Planner, tool Workers and a Solver.

    The planner output is parsed into ``Plan:`` lines and ``#E<n> = Tool[input]``
    evidence assignments. Each assignment is executed by the matching worker
    (earlier evidence is substituted into later tool inputs), and the solver
    is then asked to answer the question from the plans plus the evidence.
    """

    # Lower-cased phrases. An answer containing any of them is retried and,
    # if still present after the last attempt, replaced by the raw evidence.
    _REFUSAL_PHRASES = ("không thể", "cannot", "can't", "unable to", "không biết", "tự tìm")
    # Lower-cased apology phrases that only trigger a retry. They must be
    # lower-case because they are compared against ``output.lower()``.
    _RETRY_ONLY_PHRASES = ("rất tiếc", "xin lỗi")
    _MIN_ANSWER_WORDS = 20
    _MAX_SOLVER_ATTEMPTS = 5
    _MAX_EVIDENCE_WORDS = 400
    _TRUNCATED_EVIDENCE_WORDS = 300
    _MISSING_EVIDENCE = "No evidence found"
    _EVIDENCE_VAR = re.compile(r"#E\d+")

    def __init__(self, available_tools=None, fewshot="\n", planner_model="text-davinci-003",
                 solver_model="text-davinci-003"):
        """Build the planner and solver and reset the per-run state.

        Args:
            available_tools: Names of the tools the workers may use. Defaults
                to ``["Google", "LLM"]`` when ``None``.
            fewshot: Few-shot text handed to the planner.
            planner_model: Model name used by the planner.
            solver_model: Model name used by the solver.
        """
        # Copy so the instance never shares a list with the caller (or with
        # other instances through a shared default).
        self.workers = ["Google", "LLM"] if available_tools is None else list(available_tools)
        self.planner = Planner(workers=self.workers,
                               model_name=planner_model,
                               fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
        self.planner_token_unit_price = get_token_unit_price(planner_model)
        self.solver_token_unit_price = get_token_unit_price(solver_model)
        self.tool_token_unit_price = get_token_unit_price("text-davinci-003")
        self.google_unit_price = 0.01

    # input: the question line. e.g. "Question: What is the capital of France?"
    def run(self, input):
        """Answer ``input`` by planning, gathering evidence and solving.

        Args:
            input: The question line passed to the planner and the solver.

        Returns:
            dict with the keys ``wall_time``, ``input``, ``output``,
            ``planner_log``, ``worker_log``, ``solver_log``, ``tool_usage``,
            ``steps``, ``total_tokens``, ``token_cost``, ``tool_cost`` and
            ``total_cost``.
        """
        # State from a previous call must not leak into this one.
        self._reinitialize()
        result = {}
        started = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]
        planner_log = planner_response["input"] + planner_response["output"]
        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)

        # Work
        self._get_worker_evidences()
        log_parts = []
        for step_no, step in enumerate(self.plans, start=1):
            # A plan line without a matching "#E<n>" entry must not abort the
            # run, so fall back to the placeholder instead of raising KeyError.
            evidence = self.worker_evidences.get(f"#E{step_no}", self._MISSING_EVIDENCE)
            log_parts.append(f"{step}\nEvidence:\n{evidence}\n\n")
        worker_log = "".join(log_parts)

        # Solve with all evidence at once, retrying refusals and answers that
        # are too short.
        # NOTE: only the last attempt's token counts enter the totals below.
        retry_phrases = self._REFUSAL_PHRASES + self._RETRY_ONLY_PHRASES
        for _ in range(self._MAX_SOLVER_ATTEMPTS):
            solver_response = self.solver.run(input, worker_log, log=True)
            output = solver_response["output"]
            too_short = len(output.split()) < self._MIN_ANSWER_WORDS
            if not too_short and not self._contains_any(output, retry_phrases):
                break

        # If the last attempt is still a refusal, hand back the raw evidence.
        if self._contains_any(output, self._REFUSAL_PHRASES):
            output = "Dựa trên thông tin thu thập được:\n\n" + "\n\n".join(self.worker_evidences.values())
        solver_log = solver_response["input"] + solver_response["output"]

        planner_tokens = planner_response["prompt_tokens"] + planner_response["completion_tokens"]
        solver_tokens = solver_response["prompt_tokens"] + solver_response["completion_tokens"]
        tool_tokens = self.tool_counter.get("LLM_token", 0) + self.tool_counter.get("Calculator_token", 0)

        result["wall_time"] = time.time() - started
        result["input"] = input
        result["output"] = output
        result["planner_log"] = planner_log
        result["worker_log"] = worker_log
        result["solver_log"] = solver_log
        result["tool_usage"] = self.tool_counter
        result["steps"] = len(self.plans) + 1
        result["total_tokens"] = planner_tokens + solver_tokens + tool_tokens
        result["token_cost"] = self.planner_token_unit_price * planner_tokens \
                               + self.solver_token_unit_price * solver_tokens \
                               + self.tool_token_unit_price * tool_tokens
        result["tool_cost"] = self.tool_counter.get("Google", 0) * self.google_unit_price
        result["total_cost"] = result["token_cost"] + result["tool_cost"]
        return result

    @staticmethod
    def _contains_any(text, phrases):
        """Return True if the lower-cased ``text`` contains any of ``phrases``."""
        lowered = text.lower()
        return any(phrase in lowered for phrase in phrases)

    def _parse_plans(self, response):
        """Return the lines of ``response`` that start with ``"Plan:"``."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Extract ``#E<n> = tool_call`` assignments from the planner output.

        Returns:
            dict mapping each evidence label to its tool call. A label that
            is not exactly ``#E`` followed by digits is kept as a key but
            mapped to ``"No evidence found"``. Lines without ``=`` are
            skipped.
        """
        evidences = {}
        for line in response.splitlines():
            # Must start with "#E<digit>"; the regex also handles lines that
            # are too short to index safely.
            if not re.match(r"#E\d", line):
                continue
            label, sep, tool_call = line.partition("=")
            if not sep:
                continue
            label = label.strip()
            # fullmatch accepts multi-digit labels such as "#E10".
            if self._EVIDENCE_VAR.fullmatch(label):
                evidences[label] = tool_call.strip()
            else:
                evidences[label] = self._MISSING_EVIDENCE
        return evidences

    # use planner evidences to assign tasks to respective workers.
    def _get_worker_evidences(self):
        """Run every planned tool call and store its result in ``worker_evidences``.

        Entries without ``[`` are copied through unchanged. A tool that is
        not in ``self.workers`` yields ``"No evidence found"``. Usage of each
        tool is tallied in ``self.tool_counter``.
        """
        for label, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                self.worker_evidences[label] = tool_call
                continue
            tool, tool_input = tool_call.split("[", 1)
            tool = tool.strip()
            # Drop the closing bracket only when it is actually there.
            if tool_input.endswith("]"):
                tool_input = tool_input[:-1]
            tool_input = self._substitute_evidence(tool_input)
            if tool not in self.workers:
                self.worker_evidences[label] = self._MISSING_EVIDENCE
                continue
            evidence = WORKER_REGISTRY[tool].run(tool_input)
            # Keep evidence concise - around 300 tokens (roughly 400 words)
            if len(evidence.split()) > self._MAX_EVIDENCE_WORDS:
                evidence = self._summarize_evidence(evidence)
            self.worker_evidences[label] = evidence
            self._count_tool_usage(tool, tool_input, evidence)

    def _substitute_evidence(self, tool_input):
        """Replace each known ``#E<n>`` in ``tool_input`` with ``[evidence]``."""
        def _expand(match):
            var = match.group(0)
            if var in self.worker_evidences:
                return "[" + self.worker_evidences[var] + "]"
            return var

        # One regex pass over whole tokens, so "#E1" never rewrites the
        # prefix of "#E10" and substituted text is not scanned again.
        return self._EVIDENCE_VAR.sub(_expand, tool_input)

    def _summarize_evidence(self, evidence):
        """Shorten over-long evidence with an LLM, or truncate it on failure."""
        # The import stays in this helper so it cannot turn ``OpenAI`` into a
        # local name of the method that does the Calculator accounting.
        from langchain_openai import OpenAI
        llm = OpenAI(temperature=0)
        summarize_prompt = (
            "Summarize the following information in about 250-300 words, keeping the most relevant details "
            "and any specific numbers, requirements, or key points:\n"
            f"{evidence}\n"
            "Summary:"
        )
        try:
            return llm.invoke(summarize_prompt).strip()
        except Exception:
            # If summarization fails, take first 300 words
            return " ".join(evidence.split()[:self._TRUNCATED_EVIDENCE_WORDS]) + "..."

    def _count_tool_usage(self, tool, tool_input, evidence):
        """Update ``tool_counter``: query count for Google, ~chars/4 tokens otherwise."""
        if tool == "Google":
            self.tool_counter["Google"] = self.tool_counter.get("Google", 0) + 1  # number of query
        elif tool == "LLM":
            self.tool_counter["LLM_token"] = self.tool_counter.get("LLM_token", 0) \
                                             + len(tool_input + evidence) // 4
        elif tool == "Calculator":
            # ``OpenAI`` and ``LLMMathChain`` are module-level names here
            # (presumably from the star imports - TODO confirm).
            template = LLMMathChain(llm=OpenAI(), verbose=False).prompt.template
            self.tool_counter["Calculator_token"] = self.tool_counter.get("Calculator_token", 0) \
                                                    + len(template + tool_input + evidence) // 4

    def _reinitialize(self):
        """Clear plans, evidences and tool counters before a new run."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
class PWS_Base(PWS):
    """PWS preset using the ``HOTPOTQA_PWS_BASE`` few-shot prompt.

    The default tool set is ``["Wikipedia", "LLM"]``.
    """

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_BASE, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Configure the base pipeline.

        Args:
            fewshot: Few-shot text handed to the planner.
            planner_model: Model name used by the planner.
            solver_model: Model name used by the solver.
            available_tools: Tool names the workers may use. Defaults to
                ``["Wikipedia", "LLM"]`` when ``None``.
        """
        # Build the default list per call so instances never share one
        # mutable default object.
        if available_tools is None:
            available_tools = ["Wikipedia", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)
class PWS_Extra(PWS):
    """PWS preset using the ``HOTPOTQA_PWS_EXTRA`` few-shot prompt.

    The default tool set is ``["Google", "Calculator", "LLM"]``.
    """

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_EXTRA, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Configure the extended pipeline.

        Args:
            fewshot: Few-shot text handed to the planner.
            planner_model: Model name used by the planner.
            solver_model: Model name used by the solver.
            available_tools: Tool names the workers may use. Defaults to
                ``["Google", "Calculator", "LLM"]`` when ``None``.
        """
        # Build the default list per call so instances never share one
        # mutable default object.
        if available_tools is None:
            available_tools = ["Google", "Calculator", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)
|