File size: 8,201 Bytes
2224132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# main class chaining Planner, Worker and Solver.
import re
import time

from nodes.Planner import Planner
from nodes.Solver import Solver
from nodes.Worker import *
from utils.util import *


class PWS:
    """Plan-Work-Solve pipeline.

    A Planner writes a plan whose steps bind evidence variables (``#E1``, ``#E2``, ...)
    to tool calls of the form ``Tool[input]``; the registered workers execute those
    calls in order; a Solver then answers the question from the collected evidence.
    """

    # Lower-cased substrings that mark a solver answer as a refusal / non-answer.
    # The solver output is lower-cased before matching, so every entry must be lower-case.
    REFUSAL_PHRASES = ("không thể", "cannot", "can't", "unable to", "không biết",
                       "tự tìm", "rất tiếc", "xin lỗi")
    # Solver answers shorter than this many words are retried.
    MIN_ANSWER_WORDS = 20
    # Maximum number of solver attempts (at least one attempt is always made).
    MAX_SOLVER_RETRIES = 5
    # Worker evidence longer than this many words is condensed before being stored.
    MAX_EVIDENCE_WORDS = 400
    # Number of words kept when summarizing long evidence fails.
    FALLBACK_EVIDENCE_WORDS = 300

    # A planner line declaring evidence starts with "#E" followed by a digit.
    _EVIDENCE_LINE_RE = re.compile(r"#E\d")
    # A complete evidence variable token, e.g. "#E1" or "#E12".
    _EVIDENCE_VAR_RE = re.compile(r"#E\d+")

    def __init__(self, available_tools=None, fewshot="\n", planner_model="text-davinci-003",
                 solver_model="text-davinci-003"):
        """Build the planner/solver pair and the pricing table.

        Args:
            available_tools: names of the tools the planner may use and the workers may
                run; ``None`` means ``["Google", "LLM"]``. The sequence is copied.
            fewshot: few-shot examples forwarded to the Planner.
            planner_model: model name for the Planner (also used for its token price).
            solver_model: model name for the Solver (also used for its token price).
        """
        if available_tools is None:
            available_tools = ["Google", "LLM"]
        # Copy so instances never share (or mutate) a caller's or a default list.
        self.workers = list(available_tools)
        self.planner = Planner(workers=self.workers,
                               model_name=planner_model,
                               fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
        self.planner_token_unit_price = get_token_unit_price(planner_model)
        self.solver_token_unit_price = get_token_unit_price(solver_model)
        # Token usage of the LLM / Calculator tools is priced as text-davinci-003.
        self.tool_token_unit_price = get_token_unit_price("text-davinci-003")
        # Cost per Google query.
        self.google_unit_price = 0.01

    def run(self, input):
        """Answer one question end to end.

        Args:
            input: the question line, e.g. "Question: What is the capital of France?"

        Returns:
            dict with keys ``wall_time``, ``input``, ``output``, ``planner_log``,
            ``worker_log``, ``solver_log``, ``tool_usage``, ``steps``,
            ``total_tokens``, ``token_cost``, ``tool_cost`` and ``total_cost``.
            If every solver attempt still reads as a refusal, ``output`` is replaced
            by the raw worker evidence.
        """
        # run is stateless, so per-question state is reset first.
        self._reinitialize()
        st = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]
        planner_log = planner_response["input"] + planner_response["output"]
        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)

        # Work
        self._get_worker_evidences()
        # A Plan line may have no matching "#E<n>" assignment; don't crash on that.
        worker_log = "".join(
            f"{plan_line}\nEvidence:\n{self.worker_evidences.get(f'#E{i}', 'No evidence found')}\n\n"
            for i, plan_line in enumerate(self.plans, start=1)
        )

        # Solve with all evidence at once, retrying refusals and too-short answers.
        solver_response = None
        output = ""
        for _ in range(max(1, self.MAX_SOLVER_RETRIES)):
            solver_response = self.solver.run(input, worker_log, log=True)
            output = solver_response["output"]
            if not self._is_refusal(output) and len(output.split()) >= self.MIN_ANSWER_WORDS:
                break

        # If the last answer is still a refusal, fall back to the raw evidence.
        if self._is_refusal(output):
            output = "Dựa trên thông tin thu thập được:\n\n" + "\n\n".join(self.worker_evidences.values())

        solver_log = solver_response["input"] + solver_response["output"]

        planner_tokens = planner_response["prompt_tokens"] + planner_response["completion_tokens"]
        solver_tokens = solver_response["prompt_tokens"] + solver_response["completion_tokens"]
        tool_tokens = self.tool_counter.get("LLM_token", 0) + self.tool_counter.get("Calculator_token", 0)
        token_cost = self.planner_token_unit_price * planner_tokens \
            + self.solver_token_unit_price * solver_tokens \
            + self.tool_token_unit_price * tool_tokens
        tool_cost = self.tool_counter.get("Google", 0) * self.google_unit_price

        return {
            "wall_time": time.time() - st,
            "input": input,
            "output": output,
            "planner_log": planner_log,
            "worker_log": worker_log,
            "solver_log": solver_log,
            "tool_usage": self.tool_counter,
            "steps": len(self.plans) + 1,
            "total_tokens": planner_tokens + solver_tokens + tool_tokens,
            "token_cost": token_cost,
            "tool_cost": tool_cost,
            "total_cost": token_cost + tool_cost,
        }

    def _is_refusal(self, output):
        """Return True if *output* contains any of ``REFUSAL_PHRASES`` (case-insensitive)."""
        text = output.lower()
        return any(phrase in text for phrase in self.REFUSAL_PHRASES)

    def _parse_plans(self, response):
        """Return the lines of the planner *response* that start with "Plan:", in order."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Extract ``#E<n> = <tool call>`` assignments from the planner *response*.

        Returns a dict, in plan order, mapping each evidence variable (e.g. "#E1")
        to its raw tool call string (e.g. "Google[capital of France]"). A left-hand
        side that is not a bare ``#E<digits>`` token is kept with the placeholder
        "No evidence found". Lines starting with ``#E<digit>`` but containing no
        "=" are ignored instead of raising.
        """
        evidences = {}
        for line in response.splitlines():
            if not self._EVIDENCE_LINE_RE.match(line) or "=" not in line:
                continue
            name, tool_call = (part.strip() for part in line.split("=", 1))
            if self._EVIDENCE_VAR_RE.fullmatch(name):
                evidences[name] = tool_call
            else:
                evidences[name] = "No evidence found"
        return evidences

    def _resolve_references(self, text):
        """Replace each known evidence variable in *text* with ``[<evidence>]``.

        Tokens are matched whole, so "#E1" never rewrites the prefix of "#E10", and
        inserted evidence is not rescanned for further variables. Unknown variables
        are left unchanged.
        """
        def substitute(match):
            var = match.group(0)
            if var in self.worker_evidences:
                return "[" + self.worker_evidences[var] + "]"
            return var

        return self._EVIDENCE_VAR_RE.sub(substitute, text)

    def _condense_evidence(self, evidence):
        """Shorten *evidence* that is longer than ``MAX_EVIDENCE_WORDS`` words.

        It first tries an OpenAI summary at temperature 0. If that fails for any
        reason, including the import, it keeps the first ``FALLBACK_EVIDENCE_WORDS``
        words followed by "...". Shorter evidence is returned unchanged.
        """
        if len(evidence.split()) <= self.MAX_EVIDENCE_WORDS:
            return evidence
        summarize_prompt = (
            "Summarize the following information in about 250-300 words, keeping the most relevant "
            "details and any specific numbers, requirements, or key points:\n\n"
            f"{evidence}\n\nSummary:"
        )
        try:
            # Imported here, in this helper's own scope, so that it does not shadow the
            # globally resolved OpenAI used for Calculator token accounting.
            from langchain_openai import OpenAI
            return OpenAI(temperature=0).invoke(summarize_prompt).strip()
        except Exception:
            return " ".join(evidence.split()[:self.FALLBACK_EVIDENCE_WORDS]) + "..."

    def _get_worker_evidences(self):
        """Run each planned tool call and store its result in ``self.worker_evidences``.

        Calls run in the order of ``self.planner_evidences``. Before a tool runs,
        earlier evidence variables in its input are replaced by ``[<evidence>]``.
        Entries that are not of the form ``Tool[input]`` are stored verbatim.
        Calls to tools not listed in ``self.workers`` yield "No evidence found".
        Usage is recorded in ``self.tool_counter``.
        """
        for e, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                self.worker_evidences[e] = tool_call
                continue
            tool, tool_input = tool_call.split("[", 1)
            tool = tool.strip()
            # Drop the closing bracket of "Tool[input]" only when it is actually there.
            if tool_input.endswith("]"):
                tool_input = tool_input[:-1]
            tool_input = self._resolve_references(tool_input)

            if tool not in self.workers:
                self.worker_evidences[e] = "No evidence found"
                continue

            # Keep evidence concise (long results are summarized or truncated).
            evidence = self._condense_evidence(WORKER_REGISTRY[tool].run(tool_input))
            self.worker_evidences[e] = evidence

            # Token usage is approximated as characters / 4.
            if tool == "Google":
                self.tool_counter["Google"] = self.tool_counter.get("Google", 0) + 1  # number of queries
            elif tool == "LLM":
                self.tool_counter["LLM_token"] = self.tool_counter.get("LLM_token", 0) \
                    + len(tool_input + evidence) // 4
            elif tool == "Calculator":
                prompt_template = LLMMathChain(llm=OpenAI(), verbose=False).prompt.template
                self.tool_counter["Calculator_token"] = self.tool_counter.get("Calculator_token", 0) \
                    + len(prompt_template + tool_input + evidence) // 4

    def _reinitialize(self):
        """Clear all per-question state (plans, evidences, tool usage counters)."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}


class PWS_Base(PWS):
    """PWS preset using the Wikipedia and LLM tools and the ``fewshots.HOTPOTQA_PWS_BASE`` prompt."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_BASE, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Create the preset.

        Args:
            fewshot: few-shot examples for the Planner.
            planner_model: model name for the Planner.
            solver_model: model name for the Solver.
            available_tools: tool names; ``None`` means ``["Wikipedia", "LLM"]``.
        """
        # A fresh list per call, so instances never share a mutable default.
        if available_tools is None:
            available_tools = ["Wikipedia", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)


class PWS_Extra(PWS):
    """PWS preset using the Google, Calculator and LLM tools and the ``fewshots.HOTPOTQA_PWS_EXTRA`` prompt."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_EXTRA, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Create the preset.

        Args:
            fewshot: few-shot examples for the Planner.
            planner_model: model name for the Planner.
            solver_model: model name for the Solver.
            available_tools: tool names; ``None`` means ``["Google", "Calculator", "LLM"]``.
        """
        # A fresh list per call, so instances never share a mutable default.
        if available_tools is None:
            available_tools = ["Google", "Calculator", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)