Ali2206 committed on
Commit b316f9e · verified · 1 Parent(s): 32df88c

Create txagent.py

Files changed (1)
  1. src/txagent/txagent.py +937 -0
src/txagent/txagent.py ADDED
@@ -0,0 +937 @@
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ import json
5
+ import gc
+ import logging
+ import torch
6
+ import numpy as np
7
+ from vllm import LLM, SamplingParams
8
+ from jinja2 import Template
9
+ from typing import List
10
+ import types
11
+ from tooluniverse import ToolUniverse
12
+ from gradio import ChatMessage
13
+ from .toolrag import ToolRAGModel
14
+
15
+ from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
16
+
+ logger = logging.getLogger(__name__)
17
+
18
+ class TxAgent:
19
+ def __init__(self, model_name,
20
+ rag_model_name,
21
+ tool_files_dict=None, # None leads to the default tool files in ToolUniverse
22
+ enable_finish=True,
23
+ enable_rag=True,
24
+ enable_summary=False,
25
+ init_rag_num=0,
26
+ step_rag_num=10,
27
+ summary_mode='step',
28
+ summary_skip_last_k=0,
29
+ summary_context_length=None,
30
+ force_finish=True,
31
+ avoid_repeat=True,
32
+ seed=None,
33
+ enable_checker=False,
34
+ enable_chat=False,
35
+ additional_default_tools=None,
36
+ ):
37
+ self.model_name = model_name
38
+ self.tokenizer = None
39
+ self.terminators = None
40
+ self.rag_model_name = rag_model_name
41
+ self.tool_files_dict = tool_files_dict
42
+ self.model = None
43
+ self.rag_model = ToolRAGModel(rag_model_name)
44
+ self.tooluniverse = None
45
+ # self.tool_desc = None
46
+ self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
47
+ self.self_prompt = "Strictly follow the instruction."
48
+ self.chat_prompt = "You are a helpful assistant to chat with the user."
49
+ self.enable_finish = enable_finish
50
+ self.enable_rag = enable_rag
51
+ self.enable_summary = enable_summary
52
+ self.summary_mode = summary_mode
53
+ self.summary_skip_last_k = summary_skip_last_k
54
+ self.summary_context_length = summary_context_length
55
+ self.init_rag_num = init_rag_num
56
+ self.step_rag_num = step_rag_num
57
+ self.force_finish = force_finish
58
+ self.avoid_repeat = avoid_repeat
59
+ self.seed = seed
60
+ self.enable_checker = enable_checker
61
+ self.additional_default_tools = additional_default_tools
62
+ self.print_self_values()
63
+
64
+ def init_model(self):
65
+ self.load_models()
66
+ self.load_tooluniverse()
67
+ self.load_tool_desc_embedding()
68
+
69
+ def print_self_values(self):
70
+ for attr, value in self.__dict__.items():
71
+ print(f"{attr}: {value}")
72
+
73
+ def load_models(self, model_name=None):
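+ # Loads (or swaps in) a vLLM engine for self.model_name and caches its tokenizer and chat template.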
74
+ if model_name is not None:
75
+ if model_name == self.model_name:
76
+ return f"The model {model_name} is already loaded."
77
+ self.model_name = model_name
78
+
79
+ self.model = LLM(model=self.model_name)
80
+ self.chat_template = Template(self.model.get_tokenizer().chat_template)
81
+ self.tokenizer = self.model.get_tokenizer()
82
+
83
+ return f"Model {model_name} loaded successfully."
84
+
85
+ def load_tooluniverse(self):
86
+ self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
87
+ self.tooluniverse.load_tools()
88
+ special_tools = self.tooluniverse.prepare_tool_prompts(
89
+ self.tooluniverse.tool_category_dicts["special_tools"])
90
+ self.special_tools_name = [tool['name'] for tool in special_tools]
91
+
92
+ def load_tool_desc_embedding(self):
93
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
94
+
95
+ def rag_infer(self, query, top_k=5):
96
+ return self.rag_model.rag_infer(query, top_k)
97
+
98
+ def initialize_tools_prompt(self, call_agent, call_agent_level, message):
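+ # Builds the initial tool prompt: special tools (Finish, plus CallAgent or Tool_RAG), and, when not delegating to a sub-agent, tools retrieved by RAG for the user message. call_agent_level tracks nesting depth; at depth >= 2 CallAgent is disabled and RAG-selected tools are used instead.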
99
+ picked_tools_prompt = []
100
+ picked_tools_prompt = self.add_special_tools(
101
+ picked_tools_prompt, call_agent=call_agent)
102
+ if call_agent:
103
+ call_agent_level += 1
104
+ if call_agent_level >= 2:
105
+ call_agent = False
106
+
107
+ if not call_agent:
108
+ picked_tools_prompt += self.tool_RAG(
109
+ message=message, rag_num=self.init_rag_num)
110
+ return picked_tools_prompt, call_agent_level
111
+
112
+ def initialize_conversation(self, message, conversation=None, history=None):
113
+ if conversation is None:
114
+ conversation = []
115
+
116
+ conversation = self.set_system_prompt(
117
+ conversation, self.prompt_multi_step)
118
+ if history is not None:
119
+ if len(history) == 0:
120
+ conversation = []
121
+ print("clear conversation successfully")
122
+ else:
123
+ for i in range(len(history)):
124
+ if history[i]['role'] == 'user':
125
+ if i-1 >= 0 and history[i-1]['role'] == 'assistant':
126
+ conversation.append(
127
+ {"role": "assistant", "content": history[i-1]['content']})
128
+ conversation.append(
129
+ {"role": "user", "content": history[i]['content']})
130
+ if i == len(history)-1 and history[i]['role'] == 'assistant':
131
+ conversation.append(
132
+ {"role": "assistant", "content": history[i]['content']})
133
+
134
+ conversation.append({"role": "user", "content": message})
135
+
136
+ return conversation
137
+
138
+ def tool_RAG(self, message=None,
139
+ picked_tool_names=None,
140
+ existing_tools_prompt=[],
141
+ rag_num=5,
142
+ return_call_result=False):
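+ # Retrieves rag_num * extra_factor candidate tool names for the message, drops special tools, keeps the top rag_num, and returns their prompt-ready descriptions (optionally with the selected names).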
143
+ extra_factor = 30 # Factor to retrieve more than rag_num
144
+ if picked_tool_names is None:
145
+ assert picked_tool_names is not None or message is not None
146
+ picked_tool_names = self.rag_infer(
147
+ message, top_k=rag_num*extra_factor)
148
+
149
+ picked_tool_names_no_special = []
150
+ for tool in picked_tool_names:
151
+ if tool not in self.special_tools_name:
152
+ picked_tool_names_no_special.append(tool)
153
+ picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
154
+ picked_tool_names = picked_tool_names_no_special[:rag_num]
155
+
156
+ picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
157
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
158
+ picked_tools)
159
+ if return_call_result:
160
+ return picked_tools_prompt, picked_tool_names
161
+ return picked_tools_prompt
162
+
163
+ def add_special_tools(self, tools, call_agent=False):
164
+ if self.enable_finish:
165
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
166
+ 'Finish', return_prompt=True))
167
+ print("Finish tool is added")
168
+ if call_agent:
169
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
170
+ 'CallAgent', return_prompt=True))
171
+ print("CallAgent tool is added")
172
+ else:
173
+ if self.enable_rag:
174
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
+ 'Tool_RAG', return_prompt=True))
176
+ print("Tool_RAG tool is added")
177
+
178
+ if self.additional_default_tools is not None:
179
+ for each_tool_name in self.additional_default_tools:
180
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
181
+ each_tool_name, return_prompt=True)
182
+ if tool_prompt is not None:
183
+ print(f"{each_tool_name} tool is added")
184
+ tools.append(tool_prompt)
185
+ return tools
186
+
187
+ def add_finish_tools(self, tools):
188
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
189
+ 'Finish', return_prompt=True))
190
+ print("Finish tool is added")
191
+ return tools
192
+
193
+ def set_system_prompt(self, conversation, sys_prompt):
194
+ if len(conversation) == 0:
195
+ conversation.append(
196
+ {"role": "system", "content": sys_prompt})
197
+ else:
198
+ conversation[0] = {"role": "system", "content": sys_prompt}
199
+ return conversation
200
+
201
+ def run_function_call(self, fcall_str,
202
+ return_message=False,
203
+ existing_tools_prompt=None,
204
+ message_for_call_agent=None,
205
+ call_agent=False,
206
+ call_agent_level=None,
207
+ temperature=None):
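+ # Parses the model's tool-call JSON: 'Finish' ends the loop, 'Tool_RAG' retrieves additional tools, 'CallAgent' spawns a nested multistep agent, and every other call is executed through ToolUniverse; results are returned as 'tool' messages appended after the assistant turn.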
208
+
209
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
210
+ fcall_str, return_message=return_message, verbose=False)
211
+ call_results = []
212
+ special_tool_call = ''
213
+ if function_call_json is not None:
214
+ if isinstance(function_call_json, list):
215
+ for i in range(len(function_call_json)):
216
+ print("\033[94mTool Call:\033[0m", function_call_json[i])
217
+ if function_call_json[i]["name"] == 'Finish':
218
+ special_tool_call = 'Finish'
219
+ break
220
+ elif function_call_json[i]["name"] == 'Tool_RAG':
221
+ new_tools_prompt, call_result = self.tool_RAG(
222
+ message=message,
223
+ existing_tools_prompt=existing_tools_prompt,
224
+ rag_num=self.step_rag_num,
225
+ return_call_result=True)
226
+ existing_tools_prompt += new_tools_prompt
227
+ elif function_call_json[i]["name"] == 'CallAgent':
228
+ if call_agent_level < 2 and call_agent:
229
+ solution_plan = function_call_json[i]['arguments']['solution']
230
+ full_message = (
231
+ message_for_call_agent +
232
+ "\nYou must follow the following plan to answer the question: " +
233
+ str(solution_plan)
234
+ )
235
+ call_result = self.run_multistep_agent(
236
+ full_message, temperature=temperature,
237
+ max_new_tokens=1024, max_token=99999,
238
+ call_agent=False, call_agent_level=call_agent_level)
239
+ call_result = call_result.split(
240
+ '[FinalAnswer]')[-1].strip()
241
+ else:
242
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
243
+ else:
244
+ call_result = self.tooluniverse.run_one_function(
245
+ function_call_json[i])
246
+
247
+ call_id = self.tooluniverse.call_id_gen()
248
+ function_call_json[i]["call_id"] = call_id
249
+ print("\033[94mTool Call Result:\033[0m", call_result)
250
+ call_results.append({
251
+ "role": "tool",
252
+ "content": json.dumps({"content": call_result, "call_id": call_id})
253
+ })
254
+ else:
255
+ call_results.append({
256
+ "role": "tool",
257
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
258
+ })
259
+
260
+ revised_messages = [{
261
+ "role": "assistant",
262
+ "content": message.strip(),
263
+ "tool_calls": json.dumps(function_call_json)
264
+ }] + call_results
265
+
266
+ # Yield the final result.
267
+ return revised_messages, existing_tools_prompt, special_tool_call
268
+
269
+ def run_function_call_stream(self, fcall_str,
270
+ return_message=False,
271
+ existing_tools_prompt=None,
272
+ message_for_call_agent=None,
273
+ call_agent=False,
274
+ call_agent_level=None,
275
+ temperature=None,
276
+ return_gradio_history=True):
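+ # Generator variant of run_function_call used by run_gradio_chat: additionally handles DirectResponse and RequireClarification, and can return gradio ChatMessage entries describing each tool call.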
277
+
278
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
279
+ fcall_str, return_message=return_message, verbose=False)
280
+ call_results = []
281
+ special_tool_call = ''
282
+ if return_gradio_history:
283
+ gradio_history = []
284
+ if function_call_json is not None:
285
+ if isinstance(function_call_json, list):
286
+ for i in range(len(function_call_json)):
287
+ if function_call_json[i]["name"] == 'Finish':
288
+ special_tool_call = 'Finish'
289
+ break
290
+ elif function_call_json[i]["name"] == 'Tool_RAG':
291
+ new_tools_prompt, call_result = self.tool_RAG(
292
+ message=message,
293
+ existing_tools_prompt=existing_tools_prompt,
294
+ rag_num=self.step_rag_num,
295
+ return_call_result=True)
296
+ existing_tools_prompt += new_tools_prompt
297
+ elif function_call_json[i]["name"] == 'DirectResponse':
298
+ call_result = function_call_json[i]['arguments']['respose']
299
+ special_tool_call = 'DirectResponse'
300
+ elif function_call_json[i]["name"] == 'RequireClarification':
301
+ call_result = function_call_json[i]['arguments']['unclear_question']
302
+ special_tool_call = 'RequireClarification'
303
+ elif function_call_json[i]["name"] == 'CallAgent':
304
+ if call_agent_level < 2 and call_agent:
305
+ solution_plan = function_call_json[i]['arguments']['solution']
306
+ full_message = (
307
+ message_for_call_agent +
308
+ "\nYou must follow the following plan to answer the question: " +
309
+ str(solution_plan)
310
+ )
311
+ sub_agent_task = "Sub TxAgent plan: " + \
312
+ str(solution_plan)
313
+ # When streaming, yield responses as they arrive.
314
+ call_result = yield from self.run_gradio_chat(
315
+ full_message, history=[], temperature=temperature,
316
+ max_new_tokens=1024, max_token=99999,
317
+ call_agent=False, call_agent_level=call_agent_level,
318
+ conversation=None,
319
+ sub_agent_task=sub_agent_task)
320
+
321
+ call_result = call_result.split(
322
+ '[FinalAnswer]')[-1]
323
+ else:
324
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
325
+ else:
326
+ call_result = self.tooluniverse.run_one_function(
327
+ function_call_json[i])
328
+
329
+ call_id = self.tooluniverse.call_id_gen()
330
+ function_call_json[i]["call_id"] = call_id
331
+ call_results.append({
332
+ "role": "tool",
333
+ "content": json.dumps({"content": call_result, "call_id": call_id})
334
+ })
335
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
336
+ if function_call_json[i]["name"] == 'Tool_RAG':
337
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
338
+ "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
339
+
340
+ else:
341
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
342
+ "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
343
+ else:
344
+ call_results.append({
345
+ "role": "tool",
346
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
347
+ })
348
+
349
+ revised_messages = [{
350
+ "role": "assistant",
351
+ "content": message.strip(),
352
+ "tool_calls": json.dumps(function_call_json)
353
+ }] + call_results
354
+
355
+ # Yield the final result.
356
+ if return_gradio_history:
357
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
358
+ else:
359
+ return revised_messages, existing_tools_prompt, special_tool_call
360
+
361
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
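+ # Forces a final answer from the current conversation when reasoning is cut short (round/token limits or errors), seeding generation with a '[FinalAnswer]' preamble.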
362
+ if conversation[-1]['role'] == 'assistant':
363
+ conversation.append(
364
+ {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
365
+ finish_tools_prompt = self.add_finish_tools([])
366
+
367
+ last_outputs_str = self.llm_infer(messages=conversation,
368
+ temperature=temperature,
369
+ tools=finish_tools_prompt,
370
+ output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
371
+ skip_special_tokens=True,
372
+ max_new_tokens=max_new_tokens, max_token=max_token)
373
+ print(last_outputs_str)
374
+ return last_outputs_str
375
+
376
+ def run_multistep_agent(self, message: str,
377
+ temperature: float,
378
+ max_new_tokens: int,
379
+ max_token: int,
380
+ max_round: int = 20,
381
+ call_agent=False,
382
+ call_agent_level=0) -> str:
383
+ """
384
+ Run the multi-step reasoning and tool-calling loop and return the final answer.
385
+ Args:
386
+ message (str): The input message.
387
+ temperature (float): The temperature for generating the response.
388
+ max_new_tokens (int): The maximum number of new tokens to generate.
389
+ Returns:
390
+ str: The generated response.
391
+ """
392
+ print("\033[1;32;40mstart\033[0m")
393
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
394
+ call_agent, call_agent_level, message)
395
+ conversation = self.initialize_conversation(message)
396
+
397
+ outputs = []
398
+ last_outputs = []
399
+ next_round = True
400
+ function_call_messages = []
401
+ current_round = 0
402
+ token_overflow = False
403
+ enable_summary = False
404
+ last_status = {}
405
+
406
+ if self.enable_checker:
407
+ checker = ReasoningTraceChecker(message, conversation)
408
+ try:
409
+ while next_round and current_round < max_round:
410
+ current_round += 1
411
+ if len(outputs) > 0:
412
+ function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
413
+ last_outputs, return_message=True,
414
+ existing_tools_prompt=picked_tools_prompt,
415
+ message_for_call_agent=message,
416
+ call_agent=call_agent,
417
+ call_agent_level=call_agent_level,
418
+ temperature=temperature)
419
+
420
+ if special_tool_call == 'Finish':
421
+ next_round = False
422
+ conversation.extend(function_call_messages)
423
+ if isinstance(function_call_messages[0]['content'], types.GeneratorType):
424
+ function_call_messages[0]['content'] = next(
425
+ function_call_messages[0]['content'])
426
+ return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
427
+
428
+ if (self.enable_summary or token_overflow) and not call_agent:
429
+ if token_overflow:
430
+ print("token_overflow, using summary")
431
+ enable_summary = True
432
+ last_status = self.function_result_summary(
433
+ conversation, status=last_status, enable_summary=enable_summary)
434
+
435
+ if function_call_messages is not None:
436
+ conversation.extend(function_call_messages)
437
+ outputs.append(tool_result_format(
438
+ function_call_messages))
439
+ else:
440
+ next_round = False
441
+ conversation.extend(
442
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
443
+ return ''.join(last_outputs).replace("</s>", "")
444
+ if self.enable_checker:
445
+ good_status, wrong_info = checker.check_conversation()
446
+ if not good_status:
447
+ next_round = False
448
+ print(
449
+ "Internal error in reasoning: " + wrong_info)
450
+ break
451
+ last_outputs = []
452
+ outputs.append("### TxAgent:\n")
453
+ last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
454
+ temperature=temperature,
455
+ tools=picked_tools_prompt,
456
+ skip_special_tokens=False,
457
+ max_new_tokens=max_new_tokens, max_token=max_token,
458
+ check_token_status=True)
459
+ if last_outputs_str is None:
460
+ next_round = False
461
+ print(
462
+ "The number of tokens exceeds the maximum limit.")
463
+ else:
464
+ last_outputs.append(last_outputs_str)
465
+ if max_round == current_round:
466
+ print("The number of rounds exceeds the maximum limit!")
467
+ if self.force_finish:
468
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
469
+ else:
470
+ return None
471
+
472
+ except Exception as e:
473
+ print(f"Error: {e}")
474
+ if self.force_finish:
475
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
476
+ else:
477
+ return None
478
+
479
+ def build_logits_processor(self, messages, llm):
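+ # Returns a NoRepeatSentenceProcessor that penalizes verbatim repetition of the last two assistant messages, or None when avoid_repeat is disabled.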
480
+ # Use the tokenizer from the LLM instance.
481
+ tokenizer = llm.get_tokenizer()
482
+ if self.avoid_repeat and len(messages) > 2:
483
+ assistant_messages = []
484
+ for i in range(1, len(messages) + 1):
485
+ if messages[-i]['role'] == 'assistant':
486
+ assistant_messages.append(messages[-i]['content'])
487
+ if len(assistant_messages) == 2:
488
+ break
489
+ forbidden_ids = [tokenizer.encode(
490
+ msg, add_special_tokens=False) for msg in assistant_messages]
491
+ return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
492
+ else:
493
+ return None
494
+
495
+ def llm_infer(self, messages, temperature=0.1, tools=None,
496
+ output_begin_string=None, max_new_tokens=2048,
497
+ max_token=None, skip_special_tokens=True,
498
+ model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
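+ # Renders the chat template with the picked tools, optionally checks the prompt length against max_token (returning (None, True) on overflow when check_token_status is set), and generates with vLLM.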
499
+
500
+ if model is None:
501
+ model = self.model
502
+
503
+ logits_processor = self.build_logits_processor(messages, model)
504
+ sampling_params = SamplingParams(
505
+ temperature=temperature,
506
+ max_tokens=max_new_tokens,
507
+ logits_processors=logits_processor,
508
+ seed=seed if seed is not None else self.seed,
509
+ )
510
+
511
+ prompt = self.chat_template.render(
512
+ messages=messages, tools=tools, add_generation_prompt=True)
513
+ if output_begin_string is not None:
514
+ prompt += output_begin_string
515
+
516
+ if check_token_status and max_token is not None:
517
+ token_overflow = False
518
+ num_input_tokens = len(self.tokenizer.encode(
519
+ prompt, return_tensors="pt")[0])
520
+ if max_token is not None:
521
+ if num_input_tokens > max_token:
522
+ torch.cuda.empty_cache()
523
+ gc.collect()
524
+ print("Number of input tokens before inference:",
525
+ num_input_tokens)
526
+ logger.info(
527
+ "The number of tokens exceeds the maximum limit!!!!")
528
+ token_overflow = True
529
+ return None, token_overflow
530
+ output = model.generate(
531
+ prompt,
532
+ sampling_params=sampling_params,
533
+ )
534
+ output = output[0].outputs[0].text
535
+ print("\033[92m" + output + "\033[0m")
536
+ if check_token_status and max_token is not None:
537
+ return output, token_overflow
538
+
539
+ return output
540
+
541
+ def run_self_agent(self, message: str,
542
+ temperature: float,
543
+ max_new_tokens: int,
544
+ max_token: int) -> str:
545
+
546
+ print("\033[1;32;40mstart self agent\033[0m")
547
+ conversation = []
548
+ conversation = self.set_system_prompt(conversation, self.self_prompt)
549
+ conversation.append({"role": "user", "content": message})
550
+ return self.llm_infer(messages=conversation,
551
+ temperature=temperature,
552
+ tools=None,
553
+ max_new_tokens=max_new_tokens, max_token=max_token)
554
+
555
+ def run_chat_agent(self, message: str,
556
+ temperature: float,
557
+ max_new_tokens: int,
558
+ max_token: int) -> str:
559
+
560
+ print("\033[1;32;40mstart chat agent\033[0m")
561
+ conversation = []
562
+ conversation = self.set_system_prompt(conversation, self.chat_prompt)
563
+ conversation.append({"role": "user", "content": message})
564
+ return self.llm_infer(messages=conversation,
565
+ temperature=temperature,
566
+ tools=None,
567
+ max_new_tokens=max_new_tokens, max_token=max_token)
568
+
569
+ def run_format_agent(self, message: str,
570
+ answer: str,
571
+ temperature: float,
572
+ max_new_tokens: int,
573
+ max_token: int) -> str:
574
+
575
+ print("\033[1;32;40mstart format agent\033[0m")
576
+ if '[FinalAnswer]' in answer:
577
+ possible_final_answer = answer.split("[FinalAnswer]")[-1]
578
+ elif "\n\n" in answer:
579
+ possible_final_answer = answer.split("\n\n")[-1]
580
+ else:
581
+ possible_final_answer = answer.strip()
582
+ if len(possible_final_answer) == 1:
583
+ choice = possible_final_answer[0]
584
+ if choice in ['A', 'B', 'C', 'D', 'E']:
585
+ return choice
586
+ elif len(possible_final_answer) > 1:
587
+ if possible_final_answer[1] == ':':
588
+ choice = possible_final_answer[0]
589
+ if choice in ['A', 'B', 'C', 'D', 'E']:
590
+ print("choice", choice)
591
+ return choice
592
+
593
+ conversation = []
594
+ format_prompt = "You are a helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
595
+ conversation = self.set_system_prompt(conversation, format_prompt)
596
+ conversation.append({"role": "user", "content": message +
597
+ "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
598
+ return self.llm_infer(messages=conversation,
599
+ temperature=temperature,
600
+ tools=None,
601
+ max_new_tokens=max_new_tokens, max_token=max_token)
602
+
603
+ def run_summary_agent(self, thought_calls: str,
604
+ function_response: str,
605
+ temperature: float,
606
+ max_new_tokens: int,
607
+ max_token: int) -> str:
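+ # Asks the model to compress a block of tool responses into one summary sentence; used by function_result_summary to keep long tool outputs from exhausting the context.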
608
+ print("\033[1;32;40mSummarized Tool Result:\033[0m")
609
+ generate_tool_result_summary_training_prompt = """Thought and function calls:
610
+ {thought_calls}
611
+
612
+ Function calls' responses:
613
+ \"\"\"
614
+ {function_response}
615
+ \"\"\"
616
+
617
+ Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
618
+
619
+ Directly respond with the summarized sentence of the function calls' responses only.
620
+
621
+ Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
622
+ """.format(thought_calls=thought_calls, function_response=function_response)
623
+ conversation = []
624
+ conversation.append(
625
+ {"role": "user", "content": generate_tool_result_summary_training_prompt})
626
+ output = self.llm_infer(messages=conversation,
627
+ temperature=temperature,
628
+ tools=None,
629
+ max_new_tokens=max_new_tokens, max_token=max_token)
630
+
631
+ if '[' in output:
632
+ output = output.split('[')[0]
633
+ return output
634
+
635
+ def function_result_summary(self, input_list, status, enable_summary):
636
+ """
637
+ Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
638
+ Supports 'length' and 'step' modes, and skips the last 'k' groups.
639
+
640
+ Parameters:
641
+ input_list (list): A list of dictionaries containing role and other information.
642
+ summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
643
+ summary_context_length (int): The context length threshold for the 'length' mode.
644
+ last_processed_index (tuple or int): The last processed index.
645
+
646
+ Returns:
647
+ list: A list of extracted information from valid sequences.
648
+ """
649
+ if 'tool_call_step' not in status:
650
+ status['tool_call_step'] = 0
651
+
652
+ for idx in range(len(input_list)):
653
+ pos_id = len(input_list)-idx-1
654
+ if input_list[pos_id]['role'] == 'assistant':
655
+ if 'tool_calls' in input_list[pos_id]:
656
+ if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
657
+ status['tool_call_step'] += 1
658
+ break
659
+
660
+ if 'step' in status:
661
+ status['step'] += 1
662
+ else:
663
+ status['step'] = 0
664
+
665
+ if not enable_summary:
666
+ return status
667
+
668
+ if 'summarized_index' not in status:
669
+ status['summarized_index'] = 0
670
+
671
+ if 'summarized_step' not in status:
672
+ status['summarized_step'] = 0
673
+
674
+ if 'previous_length' not in status:
675
+ status['previous_length'] = 0
676
+
677
+ if 'history' not in status:
678
+ status['history'] = []
679
+
680
+ function_response = ''
681
+ idx = 0
682
+ current_summarized_index = status['summarized_index']
683
+
684
+ status['history'].append(self.summary_mode == 'step' and status['summarized_step']
685
+ < status['step']-status['tool_call_step']-self.summary_skip_last_k)
686
+
687
+ idx = current_summarized_index
688
+ while idx < len(input_list):
689
+ if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
690
+
691
+ if input_list[idx]['role'] == 'assistant':
692
+ if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
693
+ this_thought_calls = None
694
+ else:
695
+ if len(function_response) != 0:
696
+ print("internal summary")
697
+ status['summarized_step'] += 1
698
+ result_summary = self.run_summary_agent(
699
+ thought_calls=this_thought_calls,
700
+ function_response=function_response,
701
+ temperature=0.1,
702
+ max_new_tokens=1024,
703
+ max_token=99999
704
+ )
705
+
706
+ input_list.insert(
707
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
708
+ status['summarized_index'] = last_call_idx + 2
709
+ idx += 1
710
+
711
+ last_call_idx = idx
712
+ this_thought_calls = input_list[idx]['content'] + \
713
+ input_list[idx]['tool_calls']
714
+ function_response = ''
715
+
716
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
717
+ function_response += input_list[idx]['content']
718
+ del input_list[idx]
719
+ idx -= 1
720
+
721
+ else:
722
+ break
723
+ idx += 1
724
+
725
+ if len(function_response) != 0:
726
+ status['summarized_step'] += 1
727
+ result_summary = self.run_summary_agent(
728
+ thought_calls=this_thought_calls,
729
+ function_response=function_response,
730
+ temperature=0.1,
731
+ max_new_tokens=1024,
732
+ max_token=99999
733
+ )
734
+
735
+ tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
736
+ for tool_call in tool_calls:
737
+ del tool_call['call_id']
738
+ input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
739
+ input_list.insert(
740
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
741
+ status['summarized_index'] = last_call_idx + 2
742
+
743
+ return status
744
+
745
+ # Following are Gradio related functions
746
+
747
+ # General update method that accepts any new arguments through kwargs
748
+ def update_parameters(self, **kwargs):
749
+ for key, value in kwargs.items():
750
+ if hasattr(self, key):
751
+ setattr(self, key, value)
752
+
753
+ # Return the updated attributes
754
+ updated_attributes = {key: value for key,
755
+ value in kwargs.items() if hasattr(self, key)}
756
+ return updated_attributes
757
+
758
+ def run_gradio_chat(self, message: str,
759
+ history: list,
760
+ temperature: float,
761
+ max_new_tokens: int,
762
+ max_token: int,
763
+ call_agent: bool,
764
+ conversation: gr.State,
765
+ max_round: int = 20,
766
+ seed: int = None,
767
+ call_agent_level: int = 0,
768
+ sub_agent_task: str = None) -> str:
769
+ """
770
+ Stream the multi-step reasoning loop for the Gradio UI, yielding the updated chat history as a generator.
771
+ Args:
772
+ message (str): The input message.
773
+ history (list): The conversation history used by ChatInterface.
774
+ temperature (float): The temperature for generating the response.
775
+ max_new_tokens (int): The maximum number of new tokens to generate.
776
+ Returns:
777
+ str: The generated response.
778
+ """
779
+ print("\033[1;32;40mstart\033[0m")
780
+ print("len(message)", len(message))
781
+ if len(message) <= 10:
782
+ yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
783
+ return "Please provide a valid message."
784
+ outputs = []
785
+ outputs_str = ''
786
+ last_outputs = []
787
+
788
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
789
+ call_agent,
790
+ call_agent_level,
791
+ message)
792
+
793
+ conversation = self.initialize_conversation(
794
+ message,
795
+ conversation=conversation,
796
+ history=history)
797
+ history = []
798
+
799
+ next_round = True
800
+ function_call_messages = []
801
+ current_round = 0
802
+ enable_summary = False
803
+ last_status = {} # for summary
804
+ token_overflow = False
805
+ if self.enable_checker:
806
+ checker = ReasoningTraceChecker(
807
+ message, conversation, init_index=len(conversation))
808
+
809
+ try:
810
+ while next_round and current_round < max_round:
811
+ current_round += 1
812
+ if len(last_outputs) > 0:
813
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
814
+ last_outputs, return_message=True,
815
+ existing_tools_prompt=picked_tools_prompt,
816
+ message_for_call_agent=message,
817
+ call_agent=call_agent,
818
+ call_agent_level=call_agent_level,
819
+ temperature=temperature)
820
+ history.extend(current_gradio_history)
821
+ if special_tool_call == 'Finish':
822
+ yield history
823
+ next_round = False
824
+ conversation.extend(function_call_messages)
825
+ return function_call_messages[0]['content']
826
+ elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
827
+ history.append(
828
+ ChatMessage(role="assistant", content=history[-1].content))
829
+ yield history
830
+ next_round = False
831
+ return history[-1].content
832
+ if (self.enable_summary or token_overflow) and not call_agent:
833
+ if token_overflow:
834
+ print("token_overflow, using summary")
835
+ enable_summary = True
836
+ last_status = self.function_result_summary(
837
+ conversation, status=last_status,
838
+ enable_summary=enable_summary)
839
+ if function_call_messages is not None:
840
+ conversation.extend(function_call_messages)
841
+ formated_md_function_call_messages = tool_result_format(
842
+ function_call_messages)
843
+ yield history
844
+ else:
845
+ next_round = False
846
+ conversation.extend(
847
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
848
+ return ''.join(last_outputs).replace("</s>", "")
849
+ if self.enable_checker:
850
+ good_status, wrong_info = checker.check_conversation()
851
+ if not good_status:
852
+ next_round = False
853
+ print("Internal error in reasoning: " + wrong_info)
854
+ break
855
+ last_outputs = []
856
+ last_outputs_str, token_overflow = self.llm_infer(
857
+ messages=conversation,
858
+ temperature=temperature,
859
+ tools=picked_tools_prompt,
860
+ skip_special_tokens=False,
861
+ max_new_tokens=max_new_tokens,
862
+ max_token=max_token,
863
+ seed=seed,
864
+ check_token_status=True)
865
+ if last_outputs_str is None:
+ print("The number of tokens exceeds the maximum limit.")
+ break
+ last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
866
+ for each in history:
867
+ if each.metadata is not None:
868
+ each.metadata['status'] = 'done'
869
+ if '[FinalAnswer]' in last_thought:
870
+ final_thought, final_answer = last_thought.split(
871
+ '[FinalAnswer]')
872
+ history.append(
873
+ ChatMessage(role="assistant",
874
+ content=final_thought.strip())
875
+ )
876
+ yield history
877
+ history.append(
878
+ ChatMessage(
879
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
880
+ )
881
+ yield history
882
+ else:
883
+ history.append(ChatMessage(
884
+ role="assistant", content=last_thought))
885
+ yield history
886
+
887
+ last_outputs.append(last_outputs_str)
888
+
889
+ if next_round:
890
+ if self.force_finish:
891
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
892
+ conversation, temperature, max_new_tokens, max_token)
893
+ for each in history:
894
+ if each.metadata is not None:
895
+ each.metadata['status'] = 'done'
896
+ if '[FinalAnswer]' in last_outputs_str:
897
+ final_thought, final_answer = last_outputs_str.split(
898
+ '[FinalAnswer]')
899
+ history.append(
900
+ ChatMessage(role="assistant",
901
+ content=final_thought.strip())
902
+ )
903
+ yield history
904
+ history.append(
905
+ ChatMessage(
906
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
907
+ )
908
+ yield history
909
+ else:
910
+ yield "The number of rounds exceeds the maximum limit!"
911
+
912
+ except Exception as e:
913
+ print(f"Error: {e}")
914
+ if self.force_finish:
915
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
916
+ conversation,
917
+ temperature,
918
+ max_new_tokens,
919
+ max_token)
920
+ for each in history:
921
+ if each.metadata is not None:
922
+ each.metadata['status'] = 'done'
923
+ if last_outputs_str is not None and '[FinalAnswer]' in last_outputs_str:
924
+ final_thought, final_answer = last_outputs_str.split(
925
+ '[FinalAnswer]')
926
+ history.append(
927
+ ChatMessage(role="assistant",
928
+ content=final_thought.strip())
929
+ )
930
+ yield history
931
+ history.append(
932
+ ChatMessage(
933
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
934
+ )
935
+ yield history
936
+ else:
937
+ return None
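
A minimal usage sketch of the class added in this commit. It assumes the package's __init__ re-exports TxAgent (otherwise import from txagent.txagent); the checkpoint identifiers and the example question are placeholders, not values taken from this file.

from txagent import TxAgent

agent = TxAgent(
    model_name="path/or/hub-id-of-txagent-model",      # placeholder vLLM-loadable checkpoint
    rag_model_name="path/or/hub-id-of-toolrag-model",  # placeholder embedding model for ToolRAGModel
)
agent.init_model()  # loads the vLLM engine, ToolUniverse tools, and tool-description embeddings

answer = agent.run_multistep_agent(
    "What are the contraindications of aspirin?",  # illustrative question
    temperature=0.3,
    max_new_tokens=1024,
    max_token=8192,   # prompt-length budget checked in llm_infer
)
print(answer)

For the Gradio interface, run_gradio_chat is the generator counterpart and yields ChatMessage history updates instead of returning a single string.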