Ali2206 committed
Commit cc12a3f · verified · 1 parent: dd006b6

Update src/txagent/txagent.py
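
The visible content change in this diff is in `run_function_call`: each tool-result message now records which tool produced it, adding a `tool_name` field alongside `content` and `call_id` in the JSON payload appended to the conversation. A minimal sketch of the new payload shape (illustrative only; the tool name, result, and call id below are hypothetical stand-ins for the values produced inside the method):

```python
import json

# Hypothetical stand-ins for values produced inside run_function_call.
parsed_call = {"name": "ExampleTool", "arguments": {"query": "aspirin"}}
call_result = "example tool output"   # as returned by tooluniverse.run_one_function
call_id = "call_0"                    # as returned by tooluniverse.call_id_gen()

# New payload shape: the tool message now carries the tool's name as well,
# so later steps can tell which tool produced a given result.
tool_message = {
    "role": "tool",
    "content": json.dumps({
        "tool_name": parsed_call["name"],  # field added in this commit
        "content": call_result,
        "call_id": call_id,
    }),
}
```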

Files changed (1)
  1. src/txagent/txagent.py +945 -945
src/txagent/txagent.py CHANGED
@@ -1,945 +1,945 @@
1
- import gradio as gr
2
- import os
3
- import sys
4
- import json
5
- import gc
6
- import numpy as np
7
- from vllm import LLM, SamplingParams
8
- from jinja2 import Template
9
- from typing import List
10
- import types
11
- from tooluniverse import ToolUniverse
12
- from gradio import ChatMessage
13
- from .toolrag import ToolRAGModel
14
- import torch
15
- # near the top of txagent.py
16
- import logging
17
- logger = logging.getLogger(__name__)
18
- logging.basicConfig(level=logging.INFO)
19
-
20
- from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
-
22
-
23
- class TxAgent:
24
- def __init__(self, model_name,
25
- rag_model_name,
26
- tool_files_dict=None, # None leads to the default tool files in ToolUniverse
27
- enable_finish=True,
28
- enable_rag=True,
29
- enable_summary=False,
30
- init_rag_num=0,
31
- step_rag_num=10,
32
- summary_mode='step',
33
- summary_skip_last_k=0,
34
- summary_context_length=None,
35
- force_finish=True,
36
- avoid_repeat=True,
37
- seed=None,
38
- enable_checker=False,
39
- enable_chat=False,
40
- additional_default_tools=None,
41
- ):
42
- self.model_name = model_name
43
- self.tokenizer = None
44
- self.terminators = None
45
- self.rag_model_name = rag_model_name
46
- self.tool_files_dict = tool_files_dict
47
- self.model = None
48
- self.rag_model = ToolRAGModel(rag_model_name)
49
- self.tooluniverse = None
50
- # self.tool_desc = None
51
- self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
52
- self.self_prompt = "Strictly follow the instruction."
53
- self.chat_prompt = "You are helpful assistant to chat with the user."
54
- self.enable_finish = enable_finish
55
- self.enable_rag = enable_rag
56
- self.enable_summary = enable_summary
57
- self.summary_mode = summary_mode
58
- self.summary_skip_last_k = summary_skip_last_k
59
- self.summary_context_length = summary_context_length
60
- self.init_rag_num = init_rag_num
61
- self.step_rag_num = step_rag_num
62
- self.force_finish = force_finish
63
- self.avoid_repeat = avoid_repeat
64
- self.seed = seed
65
- self.enable_checker = enable_checker
66
- self.additional_default_tools = additional_default_tools
67
- self.print_self_values()
68
-
69
- def init_model(self):
70
- self.load_models()
71
- self.load_tooluniverse()
72
- self.load_tool_desc_embedding()
73
-
74
- def print_self_values(self):
75
- for attr, value in self.__dict__.items():
76
- print(f"{attr}: {value}")
77
-
78
- def load_models(self, model_name=None):
79
- if model_name is not None:
80
- if model_name == self.model_name:
81
- return f"The model {model_name} is already loaded."
82
- self.model_name = model_name
83
-
84
- self.model = LLM(model=self.model_name)
85
- self.chat_template = Template(self.model.get_tokenizer().chat_template)
86
- self.tokenizer = self.model.get_tokenizer()
87
-
88
- return f"Model {model_name} loaded successfully."
89
-
90
- def load_tooluniverse(self):
91
- self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
92
- self.tooluniverse.load_tools()
93
- special_tools = self.tooluniverse.prepare_tool_prompts(
94
- self.tooluniverse.tool_category_dicts["special_tools"])
95
- self.special_tools_name = [tool['name'] for tool in special_tools]
96
-
97
- def load_tool_desc_embedding(self):
98
- self.rag_model.load_tool_desc_embedding(self.tooluniverse)
99
-
100
- def rag_infer(self, query, top_k=5):
101
- return self.rag_model.rag_infer(query, top_k)
102
-
103
- def initialize_tools_prompt(self, call_agent, call_agent_level, message):
104
- picked_tools_prompt = []
105
- picked_tools_prompt = self.add_special_tools(
106
- picked_tools_prompt, call_agent=call_agent)
107
- if call_agent:
108
- call_agent_level += 1
109
- if call_agent_level >= 2:
110
- call_agent = False
111
-
112
- if not call_agent:
113
- picked_tools_prompt += self.tool_RAG(
114
- message=message, rag_num=self.init_rag_num)
115
- return picked_tools_prompt, call_agent_level
116
-
117
- def initialize_conversation(self, message, conversation=None, history=None):
118
- if conversation is None:
119
- conversation = []
120
-
121
- conversation = self.set_system_prompt(
122
- conversation, self.prompt_multi_step)
123
- if history is not None:
124
- if len(history) == 0:
125
- conversation = []
126
- print("clear conversation successfully")
127
- else:
128
- for i in range(len(history)):
129
- if history[i]['role'] == 'user':
130
- if i-1 >= 0 and history[i-1]['role'] == 'assistant':
131
- conversation.append(
132
- {"role": "assistant", "content": history[i-1]['content']})
133
- conversation.append(
134
- {"role": "user", "content": history[i]['content']})
135
- if i == len(history)-1 and history[i]['role'] == 'assistant':
136
- conversation.append(
137
- {"role": "assistant", "content": history[i]['content']})
138
-
139
- conversation.append({"role": "user", "content": message})
140
-
141
- return conversation
142
-
143
- def tool_RAG(self, message=None,
144
- picked_tool_names=None,
145
- existing_tools_prompt=[],
146
- rag_num=5,
147
- return_call_result=False):
148
- extra_factor = 30 # Factor to retrieve more than rag_num
149
- if picked_tool_names is None:
150
- assert picked_tool_names is not None or message is not None
151
- picked_tool_names = self.rag_infer(
152
- message, top_k=rag_num*extra_factor)
153
-
154
- picked_tool_names_no_special = []
155
- for tool in picked_tool_names:
156
- if tool not in self.special_tools_name:
157
- picked_tool_names_no_special.append(tool)
158
- picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
159
- picked_tool_names = picked_tool_names_no_special[:rag_num]
160
-
161
- picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
162
- picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
163
- picked_tools)
164
- if return_call_result:
165
- return picked_tools_prompt, picked_tool_names
166
- return picked_tools_prompt
167
-
168
- def add_special_tools(self, tools, call_agent=False):
169
- if self.enable_finish:
170
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
171
- 'Finish', return_prompt=True))
172
- print("Finish tool is added")
173
- if call_agent:
174
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
- 'CallAgent', return_prompt=True))
176
- print("CallAgent tool is added")
177
- else:
178
- if self.enable_rag:
179
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
180
- 'Tool_RAG', return_prompt=True))
181
- print("Tool_RAG tool is added")
182
-
183
- if self.additional_default_tools is not None:
184
- for each_tool_name in self.additional_default_tools:
185
- tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
186
- each_tool_name, return_prompt=True)
187
- if tool_prompt is not None:
188
- print(f"{each_tool_name} tool is added")
189
- tools.append(tool_prompt)
190
- return tools
191
-
192
- def add_finish_tools(self, tools):
193
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
194
- 'Finish', return_prompt=True))
195
- print("Finish tool is added")
196
- return tools
197
-
198
- def set_system_prompt(self, conversation, sys_prompt):
199
- if len(conversation) == 0:
200
- conversation.append(
201
- {"role": "system", "content": sys_prompt})
202
- else:
203
- conversation[0] = {"role": "system", "content": sys_prompt}
204
- return conversation
205
-
206
- def run_function_call(self, fcall_str,
207
- return_message=False,
208
- existing_tools_prompt=None,
209
- message_for_call_agent=None,
210
- call_agent=False,
211
- call_agent_level=None,
212
- temperature=None):
213
-
214
- function_call_json, message = self.tooluniverse.extract_function_call_json(
215
- fcall_str, return_message=return_message, verbose=False)
216
- call_results = []
217
- special_tool_call = ''
218
- if function_call_json is not None:
219
- if isinstance(function_call_json, list):
220
- for i in range(len(function_call_json)):
221
- print("\033[94mTool Call:\033[0m", function_call_json[i])
222
- if function_call_json[i]["name"] == 'Finish':
223
- special_tool_call = 'Finish'
224
- break
225
- elif function_call_json[i]["name"] == 'Tool_RAG':
226
- new_tools_prompt, call_result = self.tool_RAG(
227
- message=message,
228
- existing_tools_prompt=existing_tools_prompt,
229
- rag_num=self.step_rag_num,
230
- return_call_result=True)
231
- existing_tools_prompt += new_tools_prompt
232
- elif function_call_json[i]["name"] == 'CallAgent':
233
- if call_agent_level < 2 and call_agent:
234
- solution_plan = function_call_json[i]['arguments']['solution']
235
- full_message = (
236
- message_for_call_agent +
237
- "\nYou must follow the following plan to answer the question: " +
238
- str(solution_plan)
239
- )
240
- call_result = self.run_multistep_agent(
241
- full_message, temperature=temperature,
242
- max_new_tokens=1024, max_token=99999,
243
- call_agent=False, call_agent_level=call_agent_level)
244
- call_result = call_result.split(
245
- '[FinalAnswer]')[-1].strip()
246
- else:
247
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
248
- else:
249
- call_result = self.tooluniverse.run_one_function(
250
- function_call_json[i])
251
-
252
- call_id = self.tooluniverse.call_id_gen()
253
- function_call_json[i]["call_id"] = call_id
254
- print("\033[94mTool Call Result:\033[0m", call_result)
255
- call_results.append({
256
- "role": "tool",
257
- "content": json.dumps({"content": call_result, "call_id": call_id})
258
- })
259
- else:
260
- call_results.append({
261
- "role": "tool",
262
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
263
- })
264
-
265
- revised_messages = [{
266
- "role": "assistant",
267
- "content": message.strip(),
268
- "tool_calls": json.dumps(function_call_json)
269
- }] + call_results
270
-
271
- # Yield the final result.
272
- return revised_messages, existing_tools_prompt, special_tool_call
273
-
274
- def run_function_call_stream(self, fcall_str,
275
- return_message=False,
276
- existing_tools_prompt=None,
277
- message_for_call_agent=None,
278
- call_agent=False,
279
- call_agent_level=None,
280
- temperature=None,
281
- return_gradio_history=True):
282
-
283
- function_call_json, message = self.tooluniverse.extract_function_call_json(
284
- fcall_str, return_message=return_message, verbose=False)
285
- call_results = []
286
- special_tool_call = ''
287
- if return_gradio_history:
288
- gradio_history = []
289
- if function_call_json is not None:
290
- if isinstance(function_call_json, list):
291
- for i in range(len(function_call_json)):
292
- if function_call_json[i]["name"] == 'Finish':
293
- special_tool_call = 'Finish'
294
- break
295
- elif function_call_json[i]["name"] == 'Tool_RAG':
296
- new_tools_prompt, call_result = self.tool_RAG(
297
- message=message,
298
- existing_tools_prompt=existing_tools_prompt,
299
- rag_num=self.step_rag_num,
300
- return_call_result=True)
301
- existing_tools_prompt += new_tools_prompt
302
- elif function_call_json[i]["name"] == 'DirectResponse':
303
- call_result = function_call_json[i]['arguments']['respose']
304
- special_tool_call = 'DirectResponse'
305
- elif function_call_json[i]["name"] == 'RequireClarification':
306
- call_result = function_call_json[i]['arguments']['unclear_question']
307
- special_tool_call = 'RequireClarification'
308
- elif function_call_json[i]["name"] == 'CallAgent':
309
- if call_agent_level < 2 and call_agent:
310
- solution_plan = function_call_json[i]['arguments']['solution']
311
- full_message = (
312
- message_for_call_agent +
313
- "\nYou must follow the following plan to answer the question: " +
314
- str(solution_plan)
315
- )
316
- sub_agent_task = "Sub TxAgent plan: " + \
317
- str(solution_plan)
318
- # When streaming, yield responses as they arrive.
319
- call_result = yield from self.run_gradio_chat(
320
- full_message, history=[], temperature=temperature,
321
- max_new_tokens=1024, max_token=99999,
322
- call_agent=False, call_agent_level=call_agent_level,
323
- conversation=None,
324
- sub_agent_task=sub_agent_task)
325
-
326
- call_result = call_result.split(
327
- '[FinalAnswer]')[-1]
328
- else:
329
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
330
- else:
331
- call_result = self.tooluniverse.run_one_function(
332
- function_call_json[i])
333
-
334
- call_id = self.tooluniverse.call_id_gen()
335
- function_call_json[i]["call_id"] = call_id
336
- call_results.append({
337
- "role": "tool",
338
- "content": json.dumps({"content": call_result, "call_id": call_id})
339
- })
340
- if return_gradio_history and function_call_json[i]["name"] != 'Finish':
341
- if function_call_json[i]["name"] == 'Tool_RAG':
342
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
343
- "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
344
-
345
- else:
346
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
347
- "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
348
- else:
349
- call_results.append({
350
- "role": "tool",
351
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
352
- })
353
-
354
- revised_messages = [{
355
- "role": "assistant",
356
- "content": message.strip(),
357
- "tool_calls": json.dumps(function_call_json)
358
- }] + call_results
359
-
360
- # Yield the final result.
361
- if return_gradio_history:
362
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
363
- else:
364
- return revised_messages, existing_tools_prompt, special_tool_call
365
-
366
- def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
367
- if conversation[-1]['role'] == 'assistant':
368
- conversation.append(
369
- {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
370
- finish_tools_prompt = self.add_finish_tools([])
371
-
372
- last_outputs_str = self.llm_infer(messages=conversation,
373
- temperature=temperature,
374
- tools=finish_tools_prompt,
375
- output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
376
- skip_special_tokens=True,
377
- max_new_tokens=max_new_tokens, max_token=max_token)
378
- print(last_outputs_str)
379
- return last_outputs_str
380
-
381
- def run_multistep_agent(self, message: str,
382
- temperature: float,
383
- max_new_tokens: int,
384
- max_token: int,
385
- max_round: int = 20,
386
- call_agent=False,
387
- call_agent_level=0) -> str:
388
- """
389
- Generate a streaming response using the llama3-8b model.
390
- Args:
391
- message (str): The input message.
392
- temperature (float): The temperature for generating the response.
393
- max_new_tokens (int): The maximum number of new tokens to generate.
394
- Returns:
395
- str: The generated response.
396
- """
397
- print("\033[1;32;40mstart\033[0m")
398
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
399
- call_agent, call_agent_level, message)
400
- conversation = self.initialize_conversation(message)
401
-
402
- outputs = []
403
- last_outputs = []
404
- next_round = True
405
- function_call_messages = []
406
- current_round = 0
407
- token_overflow = False
408
- enable_summary = False
409
- last_status = {}
410
-
411
- if self.enable_checker:
412
- checker = ReasoningTraceChecker(message, conversation)
413
- try:
414
- while next_round and current_round < max_round:
415
- current_round += 1
416
- if len(outputs) > 0:
417
- function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
418
- last_outputs, return_message=True,
419
- existing_tools_prompt=picked_tools_prompt,
420
- message_for_call_agent=message,
421
- call_agent=call_agent,
422
- call_agent_level=call_agent_level,
423
- temperature=temperature)
424
-
425
- if special_tool_call == 'Finish':
426
- next_round = False
427
- conversation.extend(function_call_messages)
428
- if isinstance(function_call_messages[0]['content'], types.GeneratorType):
429
- function_call_messages[0]['content'] = next(
430
- function_call_messages[0]['content'])
431
- return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
432
-
433
- if (self.enable_summary or token_overflow) and not call_agent:
434
- if token_overflow:
435
- print("token_overflow, using summary")
436
- enable_summary = True
437
- last_status = self.function_result_summary(
438
- conversation, status=last_status, enable_summary=enable_summary)
439
-
440
- if function_call_messages is not None:
441
- conversation.extend(function_call_messages)
442
- outputs.append(tool_result_format(
443
- function_call_messages))
444
- else:
445
- next_round = False
446
- conversation.extend(
447
- [{"role": "assistant", "content": ''.join(last_outputs)}])
448
- return ''.join(last_outputs).replace("</s>", "")
449
- if self.enable_checker:
450
- good_status, wrong_info = checker.check_conversation()
451
- if not good_status:
452
- next_round = False
453
- print(
454
- "Internal error in reasoning: " + wrong_info)
455
- break
456
- last_outputs = []
457
- outputs.append("### TxAgent:\n")
458
- last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
459
- temperature=temperature,
460
- tools=picked_tools_prompt,
461
- skip_special_tokens=False,
462
- max_new_tokens=max_new_tokens, max_token=max_token,
463
- check_token_status=True)
464
- if last_outputs_str is None:
465
- next_round = False
466
- print(
467
- "The number of tokens exceeds the maximum limit.")
468
- else:
469
- last_outputs.append(last_outputs_str)
470
- if max_round == current_round:
471
- print("The number of rounds exceeds the maximum limit!")
472
- if self.force_finish:
473
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
474
- else:
475
- return None
476
-
477
- except Exception as e:
478
- print(f"Error: {e}")
479
- if self.force_finish:
480
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
481
- else:
482
- return None
483
-
484
- def build_logits_processor(self, messages, llm):
485
- # Use the tokenizer from the LLM instance.
486
- tokenizer = llm.get_tokenizer()
487
- if self.avoid_repeat and len(messages) > 2:
488
- assistant_messages = []
489
- for i in range(1, len(messages) + 1):
490
- if messages[-i]['role'] == 'assistant':
491
- assistant_messages.append(messages[-i]['content'])
492
- if len(assistant_messages) == 2:
493
- break
494
- forbidden_ids = [tokenizer.encode(
495
- msg, add_special_tokens=False) for msg in assistant_messages]
496
- return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
497
- else:
498
- return None
499
-
500
- def llm_infer(self, messages, temperature=0.1, tools=None,
501
- output_begin_string=None, max_new_tokens=2048,
502
- max_token=None, skip_special_tokens=True,
503
- model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
504
-
505
- if model is None:
506
- model = self.model
507
-
508
- logits_processor = self.build_logits_processor(messages, model)
509
- sampling_params = SamplingParams(
510
- temperature=temperature,
511
- max_tokens=max_new_tokens,
512
-
513
- seed=seed if seed is not None else self.seed,
514
- )
515
-
516
- prompt = self.chat_template.render(
517
- messages=messages, tools=tools, add_generation_prompt=True)
518
- if output_begin_string is not None:
519
- prompt += output_begin_string
520
-
521
- if check_token_status and max_token is not None:
522
- token_overflow = False
523
- num_input_tokens = len(self.tokenizer.encode(
524
- prompt, return_tensors="pt")[0])
525
- if max_token is not None:
526
- if num_input_tokens > max_token:
527
- torch.cuda.empty_cache()
528
- gc.collect()
529
- print("Number of input tokens before inference:",
530
- num_input_tokens)
531
- logger.info(
532
- "The number of tokens exceeds the maximum limit!!!!")
533
- token_overflow = True
534
- return None, token_overflow
535
- output = model.generate(
536
- prompt,
537
- sampling_params=sampling_params,
538
- )
539
- output = output[0].outputs[0].text
540
- print("\033[92m" + output + "\033[0m")
541
- if check_token_status and max_token is not None:
542
- return output, token_overflow
543
-
544
- return output
545
-
546
- def run_self_agent(self, message: str,
547
- temperature: float,
548
- max_new_tokens: int,
549
- max_token: int) -> str:
550
-
551
- print("\033[1;32;40mstart self agent\033[0m")
552
- conversation = []
553
- conversation = self.set_system_prompt(conversation, self.self_prompt)
554
- conversation.append({"role": "user", "content": message})
555
- return self.llm_infer(messages=conversation,
556
- temperature=temperature,
557
- tools=None,
558
- max_new_tokens=max_new_tokens, max_token=max_token)
559
-
560
- def run_chat_agent(self, message: str,
561
- temperature: float,
562
- max_new_tokens: int,
563
- max_token: int) -> str:
564
-
565
- print("\033[1;32;40mstart chat agent\033[0m")
566
- conversation = []
567
- conversation = self.set_system_prompt(conversation, self.chat_prompt)
568
- conversation.append({"role": "user", "content": message})
569
- return self.llm_infer(messages=conversation,
570
- temperature=temperature,
571
- tools=None,
572
- max_new_tokens=max_new_tokens, max_token=max_token)
573
-
574
- def run_format_agent(self, message: str,
575
- answer: str,
576
- temperature: float,
577
- max_new_tokens: int,
578
- max_token: int) -> str:
579
-
580
- print("\033[1;32;40mstart format agent\033[0m")
581
- if '[FinalAnswer]' in answer:
582
- possible_final_answer = answer.split("[FinalAnswer]")[-1]
583
- elif "\n\n" in answer:
584
- possible_final_answer = answer.split("\n\n")[-1]
585
- else:
586
- possible_final_answer = answer.strip()
587
- if len(possible_final_answer) == 1:
588
- choice = possible_final_answer[0]
589
- if choice in ['A', 'B', 'C', 'D', 'E']:
590
- return choice
591
- elif len(possible_final_answer) > 1:
592
- if possible_final_answer[1] == ':':
593
- choice = possible_final_answer[0]
594
- if choice in ['A', 'B', 'C', 'D', 'E']:
595
- print("choice", choice)
596
- return choice
597
-
598
- conversation = []
599
- format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
600
- conversation = self.set_system_prompt(conversation, format_prompt)
601
- conversation.append({"role": "user", "content": message +
602
- "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
603
- return self.llm_infer(messages=conversation,
604
- temperature=temperature,
605
- tools=None,
606
- max_new_tokens=max_new_tokens, max_token=max_token)
607
-
608
- def run_summary_agent(self, thought_calls: str,
609
- function_response: str,
610
- temperature: float,
611
- max_new_tokens: int,
612
- max_token: int) -> str:
613
- print("\033[1;32;40mSummarized Tool Result:\033[0m")
614
- generate_tool_result_summary_training_prompt = """Thought and function calls:
615
- {thought_calls}
616
-
617
- Function calls' responses:
618
- \"\"\"
619
- {function_response}
620
- \"\"\"
621
-
622
- Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
623
-
624
- Directly respond with the summarized sentence of the function calls' responses only.
625
-
626
- Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
627
- """.format(thought_calls=thought_calls, function_response=function_response)
628
- conversation = []
629
- conversation.append(
630
- {"role": "user", "content": generate_tool_result_summary_training_prompt})
631
- output = self.llm_infer(messages=conversation,
632
- temperature=temperature,
633
- tools=None,
634
- max_new_tokens=max_new_tokens, max_token=max_token)
635
-
636
- if '[' in output:
637
- output = output.split('[')[0]
638
- return output
639
-
640
- def function_result_summary(self, input_list, status, enable_summary):
641
- """
642
- Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
643
- Supports 'length' and 'step' modes, and skips the last 'k' groups.
644
-
645
- Parameters:
646
- input_list (list): A list of dictionaries containing role and other information.
647
- summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
648
- summary_context_length (int): The context length threshold for the 'length' mode.
649
- last_processed_index (tuple or int): The last processed index.
650
-
651
- Returns:
652
- list: A list of extracted information from valid sequences.
653
- """
654
- if 'tool_call_step' not in status:
655
- status['tool_call_step'] = 0
656
-
657
- for idx in range(len(input_list)):
658
- pos_id = len(input_list)-idx-1
659
- if input_list[pos_id]['role'] == 'assistant':
660
- if 'tool_calls' in input_list[pos_id]:
661
- if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
662
- status['tool_call_step'] += 1
663
- break
664
-
665
- if 'step' in status:
666
- status['step'] += 1
667
- else:
668
- status['step'] = 0
669
-
670
- if not enable_summary:
671
- return status
672
-
673
- if 'summarized_index' not in status:
674
- status['summarized_index'] = 0
675
-
676
- if 'summarized_step' not in status:
677
- status['summarized_step'] = 0
678
-
679
- if 'previous_length' not in status:
680
- status['previous_length'] = 0
681
-
682
- if 'history' not in status:
683
- status['history'] = []
684
-
685
- function_response = ''
686
- idx = 0
687
- current_summarized_index = status['summarized_index']
688
-
689
- status['history'].append(self.summary_mode == 'step' and status['summarized_step']
690
- < status['step']-status['tool_call_step']-self.summary_skip_last_k)
691
-
692
- idx = current_summarized_index
693
- while idx < len(input_list):
694
- if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
695
-
696
- if input_list[idx]['role'] == 'assistant':
697
- if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
698
- this_thought_calls = None
699
- else:
700
- if len(function_response) != 0:
701
- print("internal summary")
702
- status['summarized_step'] += 1
703
- result_summary = self.run_summary_agent(
704
- thought_calls=this_thought_calls,
705
- function_response=function_response,
706
- temperature=0.1,
707
- max_new_tokens=1024,
708
- max_token=99999
709
- )
710
-
711
- input_list.insert(
712
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
713
- status['summarized_index'] = last_call_idx + 2
714
- idx += 1
715
-
716
- last_call_idx = idx
717
- this_thought_calls = input_list[idx]['content'] + \
718
- input_list[idx]['tool_calls']
719
- function_response = ''
720
-
721
- elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
722
- function_response += input_list[idx]['content']
723
- del input_list[idx]
724
- idx -= 1
725
-
726
- else:
727
- break
728
- idx += 1
729
-
730
- if len(function_response) != 0:
731
- status['summarized_step'] += 1
732
- result_summary = self.run_summary_agent(
733
- thought_calls=this_thought_calls,
734
- function_response=function_response,
735
- temperature=0.1,
736
- max_new_tokens=1024,
737
- max_token=99999
738
- )
739
-
740
- tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
741
- for tool_call in tool_calls:
742
- del tool_call['call_id']
743
- input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
744
- input_list.insert(
745
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
746
- status['summarized_index'] = last_call_idx + 2
747
-
748
- return status
749
-
750
- # Following are Gradio related functions
751
-
752
- # General update method that accepts any new arguments through kwargs
753
- def update_parameters(self, **kwargs):
754
- for key, value in kwargs.items():
755
- if hasattr(self, key):
756
- setattr(self, key, value)
757
-
758
- # Return the updated attributes
759
- updated_attributes = {key: value for key,
760
- value in kwargs.items() if hasattr(self, key)}
761
- return updated_attributes
762
-
763
- def run_gradio_chat(self, message: str,
764
- history: list,
765
- temperature: float,
766
- max_new_tokens: int,
767
- max_token: int,
768
- call_agent: bool,
769
- conversation: gr.State,
770
- max_round: int = 20,
771
- seed: int = None,
772
- call_agent_level: int = 0,
773
- sub_agent_task: str = None) -> str:
774
- """
775
- Generate a streaming response using the llama3-8b model.
776
- Args:
777
- message (str): The input message.
778
- history (list): The conversation history used by ChatInterface.
779
- temperature (float): The temperature for generating the response.
780
- max_new_tokens (int): The maximum number of new tokens to generate.
781
- Returns:
782
- str: The generated response.
783
- """
784
- print("\033[1;32;40mstart\033[0m")
785
- print("len(message)", len(message))
786
- if len(message) <= 10:
787
- yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
788
- return "Please provide a valid message."
789
- outputs = []
790
- outputs_str = ''
791
- last_outputs = []
792
-
793
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
794
- call_agent,
795
- call_agent_level,
796
- message)
797
-
798
- conversation = self.initialize_conversation(
799
- message,
800
- conversation=conversation,
801
- history=history)
802
- history = []
803
-
804
- next_round = True
805
- function_call_messages = []
806
- current_round = 0
807
- enable_summary = False
808
- last_status = {} # for summary
809
- token_overflow = False
810
- if self.enable_checker:
811
- checker = ReasoningTraceChecker(
812
- message, conversation, init_index=len(conversation))
813
-
814
- try:
815
- while next_round and current_round < max_round:
816
- current_round += 1
817
- if len(last_outputs) > 0:
818
- function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
819
- last_outputs, return_message=True,
820
- existing_tools_prompt=picked_tools_prompt,
821
- message_for_call_agent=message,
822
- call_agent=call_agent,
823
- call_agent_level=call_agent_level,
824
- temperature=temperature)
825
- history.extend(current_gradio_history)
826
- if special_tool_call == 'Finish':
827
- yield history
828
- next_round = False
829
- conversation.extend(function_call_messages)
830
- return function_call_messages[0]['content']
831
- elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
832
- history.append(
833
- ChatMessage(role="assistant", content=history[-1].content))
834
- yield history
835
- next_round = False
836
- return history[-1].content
837
- if (self.enable_summary or token_overflow) and not call_agent:
838
- if token_overflow:
839
- print("token_overflow, using summary")
840
- enable_summary = True
841
- last_status = self.function_result_summary(
842
- conversation, status=last_status,
843
- enable_summary=enable_summary)
844
- if function_call_messages is not None:
845
- conversation.extend(function_call_messages)
846
- formated_md_function_call_messages = tool_result_format(
847
- function_call_messages)
848
- yield history
849
- else:
850
- next_round = False
851
- conversation.extend(
852
- [{"role": "assistant", "content": ''.join(last_outputs)}])
853
- return ''.join(last_outputs).replace("</s>", "")
854
- if self.enable_checker:
855
- good_status, wrong_info = checker.check_conversation()
856
- if not good_status:
857
- next_round = False
858
- print("Internal error in reasoning: " + wrong_info)
859
- break
860
- last_outputs = []
861
- last_outputs_str, token_overflow = self.llm_infer(
862
- messages=conversation,
863
- temperature=temperature,
864
- tools=picked_tools_prompt,
865
- skip_special_tokens=False,
866
- max_new_tokens=max_new_tokens,
867
- max_token=max_token,
868
- seed=seed,
869
- check_token_status=True)
870
- last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
871
- for each in history:
872
- if each.metadata is not None:
873
- each.metadata['status'] = 'done'
874
- if '[FinalAnswer]' in last_thought:
875
- final_thought, final_answer = last_thought.split(
876
- '[FinalAnswer]')
877
- history.append(
878
- ChatMessage(role="assistant",
879
- content=final_thought.strip())
880
- )
881
- yield history
882
- history.append(
883
- ChatMessage(
884
- role="assistant", content="**Answer**:\n"+final_answer.strip())
885
- )
886
- yield history
887
- else:
888
- history.append(ChatMessage(
889
- role="assistant", content=last_thought))
890
- yield history
891
-
892
- last_outputs.append(last_outputs_str)
893
-
894
- if next_round:
895
- if self.force_finish:
896
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
897
- conversation, temperature, max_new_tokens, max_token)
898
- for each in history:
899
- if each.metadata is not None:
900
- each.metadata['status'] = 'done'
901
- if '[FinalAnswer]' in last_thought:
902
- final_thought, final_answer = last_thought.split(
903
- '[FinalAnswer]')
904
- history.append(
905
- ChatMessage(role="assistant",
906
- content=final_thought.strip())
907
- )
908
- yield history
909
- history.append(
910
- ChatMessage(
911
- role="assistant", content="**Answer**:\n"+final_answer.strip())
912
- )
913
- yield history
914
- else:
915
- yield "The number of rounds exceeds the maximum limit!"
916
-
917
- except Exception as e:
918
- print(f"Error: {e}")
919
- if self.force_finish:
920
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
921
- conversation,
922
- temperature,
923
- max_new_tokens,
924
- max_token)
925
- for each in history:
926
- if each.metadata is not None:
927
- each.metadata['status'] = 'done'
928
- if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
929
- if '[FinalAnswer]' in last_thought:
930
- final_thought, final_answer = last_thought.split('[FinalAnswer]', 1)
931
- else:
932
- final_thought = ""
933
- final_answer = last_thought
934
- history.append(
935
- ChatMessage(role="assistant",
936
- content=final_thought.strip())
937
- )
938
- yield history
939
- history.append(
940
- ChatMessage(
941
- role="assistant", content="**Answer**:\n" + final_answer.strip())
942
- )
943
- yield history
944
- else:
945
- return None
 
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ import json
5
+ import gc
6
+ import numpy as np
7
+ from vllm import LLM, SamplingParams
8
+ from jinja2 import Template
9
+ from typing import List
10
+ import types
11
+ from tooluniverse import ToolUniverse
12
+ from gradio import ChatMessage
13
+ from .toolrag import ToolRAGModel
14
+ import torch
15
+ # near the top of txagent.py
16
+ import logging
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+ from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
+
22
+
23
+ class TxAgent:
24
+ def __init__(self, model_name,
25
+ rag_model_name,
26
+ tool_files_dict=None, # None leads to the default tool files in ToolUniverse
27
+ enable_finish=True,
28
+ enable_rag=True,
29
+ enable_summary=False,
30
+ init_rag_num=0,
31
+ step_rag_num=10,
32
+ summary_mode='step',
33
+ summary_skip_last_k=0,
34
+ summary_context_length=None,
35
+ force_finish=True,
36
+ avoid_repeat=True,
37
+ seed=None,
38
+ enable_checker=False,
39
+ enable_chat=False,
40
+ additional_default_tools=None,
41
+ ):
42
+ self.model_name = model_name
43
+ self.tokenizer = None
44
+ self.terminators = None
45
+ self.rag_model_name = rag_model_name
46
+ self.tool_files_dict = tool_files_dict
47
+ self.model = None
48
+ self.rag_model = ToolRAGModel(rag_model_name)
49
+ self.tooluniverse = None
50
+ # self.tool_desc = None
51
+ self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
52
+ self.self_prompt = "Strictly follow the instruction."
53
+ self.chat_prompt = "You are helpful assistant to chat with the user."
54
+ self.enable_finish = enable_finish
55
+ self.enable_rag = enable_rag
56
+ self.enable_summary = enable_summary
57
+ self.summary_mode = summary_mode
58
+ self.summary_skip_last_k = summary_skip_last_k
59
+ self.summary_context_length = summary_context_length
60
+ self.init_rag_num = init_rag_num
61
+ self.step_rag_num = step_rag_num
62
+ self.force_finish = force_finish
63
+ self.avoid_repeat = avoid_repeat
64
+ self.seed = seed
65
+ self.enable_checker = enable_checker
66
+ self.additional_default_tools = additional_default_tools
67
+ self.print_self_values()
68
+
69
+ def init_model(self):
70
+ self.load_models()
71
+ self.load_tooluniverse()
72
+ self.load_tool_desc_embedding()
73
+
74
+ def print_self_values(self):
75
+ for attr, value in self.__dict__.items():
76
+ print(f"{attr}: {value}")
77
+
78
+ def load_models(self, model_name=None):
79
+ if model_name is not None:
80
+ if model_name == self.model_name:
81
+ return f"The model {model_name} is already loaded."
82
+ self.model_name = model_name
83
+
84
+ self.model = LLM(model=self.model_name)
85
+ self.chat_template = Template(self.model.get_tokenizer().chat_template)
86
+ self.tokenizer = self.model.get_tokenizer()
87
+
88
+ return f"Model {model_name} loaded successfully."
89
+
90
+ def load_tooluniverse(self):
91
+ self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
92
+ self.tooluniverse.load_tools()
93
+ special_tools = self.tooluniverse.prepare_tool_prompts(
94
+ self.tooluniverse.tool_category_dicts["special_tools"])
95
+ self.special_tools_name = [tool['name'] for tool in special_tools]
96
+
97
+ def load_tool_desc_embedding(self):
98
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
99
+
100
+ def rag_infer(self, query, top_k=5):
101
+ return self.rag_model.rag_infer(query, top_k)
102
+
103
+ def initialize_tools_prompt(self, call_agent, call_agent_level, message):
104
+ picked_tools_prompt = []
105
+ picked_tools_prompt = self.add_special_tools(
106
+ picked_tools_prompt, call_agent=call_agent)
107
+ if call_agent:
108
+ call_agent_level += 1
109
+ if call_agent_level >= 2:
110
+ call_agent = False
111
+
112
+ if not call_agent:
113
+ picked_tools_prompt += self.tool_RAG(
114
+ message=message, rag_num=self.init_rag_num)
115
+ return picked_tools_prompt, call_agent_level
116
+
117
+ def initialize_conversation(self, message, conversation=None, history=None):
118
+ if conversation is None:
119
+ conversation = []
120
+
121
+ conversation = self.set_system_prompt(
122
+ conversation, self.prompt_multi_step)
123
+ if history is not None:
124
+ if len(history) == 0:
125
+ conversation = []
126
+ print("clear conversation successfully")
127
+ else:
128
+ for i in range(len(history)):
129
+ if history[i]['role'] == 'user':
130
+ if i-1 >= 0 and history[i-1]['role'] == 'assistant':
131
+ conversation.append(
132
+ {"role": "assistant", "content": history[i-1]['content']})
133
+ conversation.append(
134
+ {"role": "user", "content": history[i]['content']})
135
+ if i == len(history)-1 and history[i]['role'] == 'assistant':
136
+ conversation.append(
137
+ {"role": "assistant", "content": history[i]['content']})
138
+
139
+ conversation.append({"role": "user", "content": message})
140
+
141
+ return conversation
142
+
143
+ def tool_RAG(self, message=None,
144
+ picked_tool_names=None,
145
+ existing_tools_prompt=[],
146
+ rag_num=5,
147
+ return_call_result=False):
148
+ extra_factor = 30 # Factor to retrieve more than rag_num
149
+ if picked_tool_names is None:
150
+ assert picked_tool_names is not None or message is not None
151
+ picked_tool_names = self.rag_infer(
152
+ message, top_k=rag_num*extra_factor)
153
+
154
+ picked_tool_names_no_special = []
155
+ for tool in picked_tool_names:
156
+ if tool not in self.special_tools_name:
157
+ picked_tool_names_no_special.append(tool)
158
+ picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
159
+ picked_tool_names = picked_tool_names_no_special[:rag_num]
160
+
161
+ picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
162
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
163
+ picked_tools)
164
+ if return_call_result:
165
+ return picked_tools_prompt, picked_tool_names
166
+ return picked_tools_prompt
167
+
168
+ def add_special_tools(self, tools, call_agent=False):
169
+ if self.enable_finish:
170
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
171
+ 'Finish', return_prompt=True))
172
+ print("Finish tool is added")
173
+ if call_agent:
174
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
+ 'CallAgent', return_prompt=True))
176
+ print("CallAgent tool is added")
177
+ else:
178
+ if self.enable_rag:
179
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
180
+ 'Tool_RAG', return_prompt=True))
181
+ print("Tool_RAG tool is added")
182
+
183
+ if self.additional_default_tools is not None:
184
+ for each_tool_name in self.additional_default_tools:
185
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
186
+ each_tool_name, return_prompt=True)
187
+ if tool_prompt is not None:
188
+ print(f"{each_tool_name} tool is added")
189
+ tools.append(tool_prompt)
190
+ return tools
191
+
192
+ def add_finish_tools(self, tools):
193
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
194
+ 'Finish', return_prompt=True))
195
+ print("Finish tool is added")
196
+ return tools
197
+
198
+ def set_system_prompt(self, conversation, sys_prompt):
199
+ if len(conversation) == 0:
200
+ conversation.append(
201
+ {"role": "system", "content": sys_prompt})
202
+ else:
203
+ conversation[0] = {"role": "system", "content": sys_prompt}
204
+ return conversation
205
+
206
+ def run_function_call(self, fcall_str,
207
+ return_message=False,
208
+ existing_tools_prompt=None,
209
+ message_for_call_agent=None,
210
+ call_agent=False,
211
+ call_agent_level=None,
212
+ temperature=None):
213
+
214
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
215
+ fcall_str, return_message=return_message, verbose=False)
216
+ call_results = []
217
+ special_tool_call = ''
218
+ if function_call_json is not None:
219
+ if isinstance(function_call_json, list):
220
+ for i in range(len(function_call_json)):
221
+ print("\033[94mTool Call:\033[0m", function_call_json[i])
222
+ if function_call_json[i]["name"] == 'Finish':
223
+ special_tool_call = 'Finish'
224
+ break
225
+ elif function_call_json[i]["name"] == 'Tool_RAG':
226
+ new_tools_prompt, call_result = self.tool_RAG(
227
+ message=message,
228
+ existing_tools_prompt=existing_tools_prompt,
229
+ rag_num=self.step_rag_num,
230
+ return_call_result=True)
231
+ existing_tools_prompt += new_tools_prompt
232
+ elif function_call_json[i]["name"] == 'CallAgent':
233
+ if call_agent_level < 2 and call_agent:
234
+ solution_plan = function_call_json[i]['arguments']['solution']
235
+ full_message = (
236
+ message_for_call_agent +
237
+ "\nYou must follow the following plan to answer the question: " +
238
+ str(solution_plan)
239
+ )
240
+ call_result = self.run_multistep_agent(
241
+ full_message, temperature=temperature,
242
+ max_new_tokens=1024, max_token=99999,
243
+ call_agent=False, call_agent_level=call_agent_level)
244
+ call_result = call_result.split(
245
+ '[FinalAnswer]')[-1].strip()
246
+ else:
247
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
248
+ else:
249
+ call_result = self.tooluniverse.run_one_function(
250
+ function_call_json[i])
251
+
252
+ call_id = self.tooluniverse.call_id_gen()
253
+ function_call_json[i]["call_id"] = call_id
254
+ print("\033[94mTool Call Result:\033[0m", call_result)
255
+ call_results.append({
256
+ "role": "tool",
257
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
258
+ })
259
+ else:
260
+ call_results.append({
261
+ "role": "tool",
262
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
263
+ })
264
+
265
+ revised_messages = [{
266
+ "role": "assistant",
267
+ "content": message.strip(),
268
+ "tool_calls": json.dumps(function_call_json)
269
+ }] + call_results
270
+
271
+ # Yield the final result.
272
+ return revised_messages, existing_tools_prompt, special_tool_call
273
+
274
+ def run_function_call_stream(self, fcall_str,
275
+ return_message=False,
276
+ existing_tools_prompt=None,
277
+ message_for_call_agent=None,
278
+ call_agent=False,
279
+ call_agent_level=None,
280
+ temperature=None,
281
+ return_gradio_history=True):
282
+
283
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
284
+ fcall_str, return_message=return_message, verbose=False)
285
+ call_results = []
286
+ special_tool_call = ''
287
+ if return_gradio_history:
288
+ gradio_history = []
289
+ if function_call_json is not None:
290
+ if isinstance(function_call_json, list):
291
+ for i in range(len(function_call_json)):
292
+ if function_call_json[i]["name"] == 'Finish':
293
+ special_tool_call = 'Finish'
294
+ break
295
+ elif function_call_json[i]["name"] == 'Tool_RAG':
296
+ new_tools_prompt, call_result = self.tool_RAG(
297
+ message=message,
298
+ existing_tools_prompt=existing_tools_prompt,
299
+ rag_num=self.step_rag_num,
300
+ return_call_result=True)
301
+ existing_tools_prompt += new_tools_prompt
302
+ elif function_call_json[i]["name"] == 'DirectResponse':
303
+ call_result = function_call_json[i]['arguments']['respose']
304
+ special_tool_call = 'DirectResponse'
305
+ elif function_call_json[i]["name"] == 'RequireClarification':
306
+ call_result = function_call_json[i]['arguments']['unclear_question']
307
+ special_tool_call = 'RequireClarification'
308
+ elif function_call_json[i]["name"] == 'CallAgent':
309
+ if call_agent_level < 2 and call_agent:
310
+ solution_plan = function_call_json[i]['arguments']['solution']
311
+ full_message = (
312
+ message_for_call_agent +
313
+ "\nYou must follow the following plan to answer the question: " +
314
+ str(solution_plan)
315
+ )
316
+ sub_agent_task = "Sub TxAgent plan: " + \
317
+ str(solution_plan)
318
+ # When streaming, yield responses as they arrive.
319
+ call_result = yield from self.run_gradio_chat(
320
+ full_message, history=[], temperature=temperature,
321
+ max_new_tokens=1024, max_token=99999,
322
+ call_agent=False, call_agent_level=call_agent_level,
323
+ conversation=None,
324
+ sub_agent_task=sub_agent_task)
325
+
326
+ call_result = call_result.split(
327
+ '[FinalAnswer]')[-1]
328
+ else:
329
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
330
+ else:
331
+ call_result = self.tooluniverse.run_one_function(
332
+ function_call_json[i])
333
+
334
+ call_id = self.tooluniverse.call_id_gen()
335
+ function_call_json[i]["call_id"] = call_id
336
+ call_results.append({
337
+ "role": "tool",
338
+ "content": json.dumps({"content": call_result, "call_id": call_id})
339
+ })
340
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
341
+ if function_call_json[i]["name"] == 'Tool_RAG':
342
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
343
+ "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
344
+
345
+ else:
346
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
347
+ "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
348
+ else:
349
+ call_results.append({
350
+ "role": "tool",
351
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
352
+ })
353
+
354
+ revised_messages = [{
355
+ "role": "assistant",
356
+ "content": message.strip(),
357
+ "tool_calls": json.dumps(function_call_json)
358
+ }] + call_results
359
+
360
+ # Yield the final result.
361
+ if return_gradio_history:
362
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
363
+ else:
364
+ return revised_messages, existing_tools_prompt, special_tool_call
365
+
366
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
367
+ if conversation[-1]['role'] == 'assistant':
368
+ conversation.append(
369
+ {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
370
+ finish_tools_prompt = self.add_finish_tools([])
371
+
372
+ last_outputs_str = self.llm_infer(messages=conversation,
373
+ temperature=temperature,
374
+ tools=finish_tools_prompt,
375
+ output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
376
+ skip_special_tokens=True,
377
+ max_new_tokens=max_new_tokens, max_token=max_token)
378
+ print(last_outputs_str)
379
+ return last_outputs_str
380
+
381
+ def run_multistep_agent(self, message: str,
382
+ temperature: float,
383
+ max_new_tokens: int,
384
+ max_token: int,
385
+ max_round: int = 20,
386
+ call_agent=False,
387
+ call_agent_level=0) -> str:
388
+ """
389
+ Generate a streaming response using the llama3-8b model.
390
+ Args:
391
+ message (str): The input message.
392
+ temperature (float): The temperature for generating the response.
393
+ max_new_tokens (int): The maximum number of new tokens to generate.
394
+ Returns:
395
+ str: The generated response.
396
+ """
397
+ print("\033[1;32;40mstart\033[0m")
398
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
399
+ call_agent, call_agent_level, message)
400
+ conversation = self.initialize_conversation(message)
401
+
402
+ outputs = []
403
+ last_outputs = []
404
+ next_round = True
405
+ function_call_messages = []
406
+ current_round = 0
407
+ token_overflow = False
408
+ enable_summary = False
409
+ last_status = {}
410
+
411
+ if self.enable_checker:
412
+ checker = ReasoningTraceChecker(message, conversation)
413
+ try:
414
+ while next_round and current_round < max_round:
415
+ current_round += 1
416
+ if len(outputs) > 0:
417
+ function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
418
+ last_outputs, return_message=True,
419
+ existing_tools_prompt=picked_tools_prompt,
420
+ message_for_call_agent=message,
421
+ call_agent=call_agent,
422
+ call_agent_level=call_agent_level,
423
+ temperature=temperature)
424
+
425
+ if special_tool_call == 'Finish':
426
+ next_round = False
427
+ conversation.extend(function_call_messages)
428
+ if isinstance(function_call_messages[0]['content'], types.GeneratorType):
429
+ function_call_messages[0]['content'] = next(
430
+ function_call_messages[0]['content'])
431
+ return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
432
+
433
+ if (self.enable_summary or token_overflow) and not call_agent:
434
+ if token_overflow:
435
+ print("token_overflow, using summary")
436
+ enable_summary = True
437
+ last_status = self.function_result_summary(
438
+ conversation, status=last_status, enable_summary=enable_summary)
439
+
440
+ if function_call_messages is not None:
441
+ conversation.extend(function_call_messages)
442
+ outputs.append(tool_result_format(
443
+ function_call_messages))
444
+ else:
445
+ next_round = False
446
+ conversation.extend(
447
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
448
+ return ''.join(last_outputs).replace("</s>", "")
449
+ if self.enable_checker:
450
+ good_status, wrong_info = checker.check_conversation()
451
+ if not good_status:
452
+ next_round = False
453
+ print(
454
+ "Internal error in reasoning: " + wrong_info)
455
+ break
456
+ last_outputs = []
457
+ outputs.append("### TxAgent:\n")
458
+ last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
459
+ temperature=temperature,
460
+ tools=picked_tools_prompt,
461
+ skip_special_tokens=False,
462
+ max_new_tokens=max_new_tokens, max_token=max_token,
463
+ check_token_status=True)
464
+ if last_outputs_str is None:
465
+ next_round = False
466
+ print(
467
+ "The number of tokens exceeds the maximum limit.")
468
+ else:
469
+ last_outputs.append(last_outputs_str)
470
+ if max_round == current_round:
471
+ print("The number of rounds exceeds the maximum limit!")
472
+ if self.force_finish:
473
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
474
+ else:
475
+ return None
476
+
477
+ except Exception as e:
478
+ print(f"Error: {e}")
479
+ if self.force_finish:
480
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
481
+ else:
482
+ return None
483
+
484
+ def build_logits_processor(self, messages, llm):
485
+ # Use the tokenizer from the LLM instance.
486
+ tokenizer = llm.get_tokenizer()
487
+ if self.avoid_repeat and len(messages) > 2:
488
+ assistant_messages = []
489
+ for i in range(1, len(messages) + 1):
490
+ if messages[-i]['role'] == 'assistant':
491
+ assistant_messages.append(messages[-i]['content'])
492
+ if len(assistant_messages) == 2:
493
+ break
494
+ forbidden_ids = [tokenizer.encode(
495
+ msg, add_special_tokens=False) for msg in assistant_messages]
496
+ return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
497
+ else:
498
+ return None
499
+
500
+ def llm_infer(self, messages, temperature=0.1, tools=None,
501
+ output_begin_string=None, max_new_tokens=2048,
502
+ max_token=None, skip_special_tokens=True,
503
+ model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
504
+
505
+ if model is None:
506
+ model = self.model
507
+
508
+ logits_processor = self.build_logits_processor(messages, model)
509
+ sampling_params = SamplingParams(
510
+ temperature=temperature,
511
+ max_tokens=max_new_tokens,
512
+
513
+ seed=seed if seed is not None else self.seed,
514
+ )
515
+
516
+ prompt = self.chat_template.render(
517
+ messages=messages, tools=tools, add_generation_prompt=True)
518
+ if output_begin_string is not None:
519
+ prompt += output_begin_string
520
+
521
+ if check_token_status and max_token is not None:
522
+ token_overflow = False
523
+ num_input_tokens = len(self.tokenizer.encode(
524
+ prompt, return_tensors="pt")[0])
525
+ if max_token is not None:
526
+ if num_input_tokens > max_token:
527
+ torch.cuda.empty_cache()
528
+ gc.collect()
529
+ print("Number of input tokens before inference:",
530
+ num_input_tokens)
531
+ logger.info(
532
+ "The number of tokens exceeds the maximum limit!!!!")
533
+ token_overflow = True
534
+ return None, token_overflow
535
+ output = model.generate(
536
+ prompt,
537
+ sampling_params=sampling_params,
538
+ )
539
+ output = output[0].outputs[0].text
540
+ print("\033[92m" + output + "\033[0m")
541
+ if check_token_status and max_token is not None:
542
+ return output, token_overflow
543
+
544
+ return output
545
+
546
+ def run_self_agent(self, message: str,
547
+ temperature: float,
548
+ max_new_tokens: int,
549
+ max_token: int) -> str:
550
+
551
+ print("\033[1;32;40mstart self agent\033[0m")
552
+ conversation = []
553
+ conversation = self.set_system_prompt(conversation, self.self_prompt)
554
+ conversation.append({"role": "user", "content": message})
555
+ return self.llm_infer(messages=conversation,
556
+ temperature=temperature,
557
+ tools=None,
558
+ max_new_tokens=max_new_tokens, max_token=max_token)
559
+
560
+ def run_chat_agent(self, message: str,
561
+ temperature: float,
562
+ max_new_tokens: int,
563
+ max_token: int) -> str:
564
+
565
+ print("\033[1;32;40mstart chat agent\033[0m")
566
+ conversation = []
567
+ conversation = self.set_system_prompt(conversation, self.chat_prompt)
568
+ conversation.append({"role": "user", "content": message})
569
+ return self.llm_infer(messages=conversation,
570
+ temperature=temperature,
571
+ tools=None,
572
+ max_new_tokens=max_new_tokens, max_token=max_token)
573
+
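+ # run_self_agent and run_chat_agent above are thin single-turn wrappers:
+ # both build a two-message conversation (system prompt + user message) and
+ # call llm_infer without any tool schema. Illustrative call:
+ # reply = self.run_chat_agent("What does TxAgent do?", temperature=0.7,
+ #                             max_new_tokens=512, max_token=8192)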
574
+ def run_format_agent(self, message: str,
575
+ answer: str,
576
+ temperature: float,
577
+ max_new_tokens: int,
578
+ max_token: int) -> str:
579
+
580
+ print("\033[1;32;40mstart format agent\033[0m")
581
+ if '[FinalAnswer]' in answer:
582
+ possible_final_answer = answer.split("[FinalAnswer]")[-1]
583
+ elif "\n\n" in answer:
584
+ possible_final_answer = answer.split("\n\n")[-1]
585
+ else:
586
+ possible_final_answer = answer.strip()
587
+ if len(possible_final_answer) == 1:
588
+ choice = possible_final_answer[0]
589
+ if choice in ['A', 'B', 'C', 'D', 'E']:
590
+ return choice
591
+ elif len(possible_final_answer) > 1:
592
+ if possible_final_answer[1] == ':':
593
+ choice = possible_final_answer[0]
594
+ if choice in ['A', 'B', 'C', 'D', 'E']:
595
+ print("choice", choice)
596
+ return choice
597
+
598
+ conversation = []
599
+ format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
600
+ conversation = self.set_system_prompt(conversation, format_prompt)
601
+ conversation.append({"role": "user", "content": message +
602
+ "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
603
+ return self.llm_infer(messages=conversation,
604
+ temperature=temperature,
605
+ tools=None,
606
+ max_new_tokens=max_new_tokens, max_token=max_token)
607
+
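+ # Behaviour sketch for run_format_agent (example strings are assumptions):
+ # if the text after "[FinalAnswer]" is a bare letter, or starts with a
+ # pattern like "B:", that letter is returned directly without another model
+ # call; anything else triggers a one-shot formatting prompt that must reply
+ # with a single letter.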
608
+ def run_summary_agent(self, thought_calls: str,
609
+ function_response: str,
610
+ temperature: float,
611
+ max_new_tokens: int,
612
+ max_token: int) -> str:
613
+ print("\033[1;32;40mSummarized Tool Result:\033[0m")
614
+ generate_tool_result_summary_training_prompt = """Thought and function calls:
615
+ {thought_calls}
616
+
617
+ Function calls' responses:
618
+ \"\"\"
619
+ {function_response}
620
+ \"\"\"
621
+
622
+ Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
623
+
624
+ Directly respond with the summarized sentence of the function calls' responses only.
625
+
626
+ Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
627
+ """.format(thought_calls=thought_calls, function_response=function_response)
628
+ conversation = []
629
+ conversation.append(
630
+ {"role": "user", "content": generate_tool_result_summary_training_prompt})
631
+ output = self.llm_infer(messages=conversation,
632
+ temperature=temperature,
633
+ tools=None,
634
+ max_new_tokens=max_new_tokens, max_token=max_token)
635
+
636
+ if '[' in output:
637
+ output = output.split('[')[0]
638
+ return output
639
+
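+ # Illustrative call (all values invented for documentation purposes):
+ # summary = self.run_summary_agent(
+ #     thought_calls="Check interactions between drug A and drug B. [TOOL_CALLS] ...",
+ #     function_response='{"interactions": [...]}',
+ #     temperature=0.1, max_new_tokens=1024, max_token=99999)
+ # The result is a single summarising sentence; anything from the first '['
+ # onward is stripped by the post-processing above.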
640
+ def function_result_summary(self, input_list, status, enable_summary):
641
+ """
642
+ Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
643
+ Supports 'length' and 'step' modes, and skips the last 'k' groups.
644
+
645
+ Parameters:
646
+ input_list (list): The conversation messages (dicts with 'role', 'content', and optionally 'tool_calls').
647
+ status (dict): Mutable summarization state ('step', 'tool_call_step', 'summarized_step', 'summarized_index', 'previous_length', 'history').
648
+ enable_summary (bool): Whether older tool responses should actually be summarized and collapsed.
649
+
650
+ Returns:
651
+ dict: The updated status dictionary; input_list is modified in place.
652
653
+ """
654
+ if 'tool_call_step' not in status:
655
+ status['tool_call_step'] = 0
656
+
657
+ for idx in range(len(input_list)):
658
+ pos_id = len(input_list)-idx-1
659
+ if input_list[pos_id]['role'] == 'assistant':
660
+ if 'tool_calls' in input_list[pos_id]:
661
+ if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
662
+ status['tool_call_step'] += 1
663
+ break
664
+
665
+ if 'step' in status:
666
+ status['step'] += 1
667
+ else:
668
+ status['step'] = 0
669
+
670
+ if not enable_summary:
671
+ return status
672
+
673
+ if 'summarized_index' not in status:
674
+ status['summarized_index'] = 0
675
+
676
+ if 'summarized_step' not in status:
677
+ status['summarized_step'] = 0
678
+
679
+ if 'previous_length' not in status:
680
+ status['previous_length'] = 0
681
+
682
+ if 'history' not in status:
683
+ status['history'] = []
684
+
685
+ function_response = ''
686
+ idx = 0
687
+ current_summarized_index = status['summarized_index']
688
+
689
+ status['history'].append(self.summary_mode == 'step' and status['summarized_step']
690
+ < status['step']-status['tool_call_step']-self.summary_skip_last_k)
691
+
692
+ idx = current_summarized_index
693
+ while idx < len(input_list):
694
+ if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
695
+
696
+ if input_list[idx]['role'] == 'assistant':
697
+ if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
698
+ this_thought_calls = None
699
+ else:
700
+ if len(function_response) != 0:
701
+ print("internal summary")
702
+ status['summarized_step'] += 1
703
+ result_summary = self.run_summary_agent(
704
+ thought_calls=this_thought_calls,
705
+ function_response=function_response,
706
+ temperature=0.1,
707
+ max_new_tokens=1024,
708
+ max_token=99999
709
+ )
710
+
711
+ input_list.insert(
712
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
713
+ status['summarized_index'] = last_call_idx + 2
714
+ idx += 1
715
+
716
+ last_call_idx = idx
717
+ this_thought_calls = input_list[idx]['content'] + \
718
+ input_list[idx]['tool_calls']
719
+ function_response = ''
720
+
721
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
722
+ function_response += input_list[idx]['content']
723
+ del input_list[idx]
724
+ idx -= 1
725
+
726
+ else:
727
+ break
728
+ idx += 1
729
+
730
+ if len(function_response) != 0:
731
+ status['summarized_step'] += 1
732
+ result_summary = self.run_summary_agent(
733
+ thought_calls=this_thought_calls,
734
+ function_response=function_response,
735
+ temperature=0.1,
736
+ max_new_tokens=1024,
737
+ max_token=99999
738
+ )
739
+
740
+ tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
741
+ for tool_call in tool_calls:
742
+ del tool_call['call_id']
743
+ input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
744
+ input_list.insert(
745
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
746
+ status['summarized_index'] = last_call_idx + 2
747
+
748
+ return status
749
+
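+ # Shape of the status dict maintained above (example values are assumed):
+ # status = {'step': 7, 'tool_call_step': 2, 'summarized_step': 3,
+ #           'summarized_index': 12, 'previous_length': 0, 'history': [...]}
+ # Each summarisation pass replaces one block of tool messages with a single
+ # {'role': 'tool', 'content': <one-sentence summary>} entry in input_list.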
750
+ # Following are Gradio related functions
751
+
752
+ # General update method that accepts any new arguments through kwargs
753
+ def update_parameters(self, **kwargs):
754
+ for key, value in kwargs.items():
755
+ if hasattr(self, key):
756
+ setattr(self, key, value)
757
+
758
+ # Return the updated attributes
759
+ updated_attributes = {key: value for key,
760
+ value in kwargs.items() if hasattr(self, key)}
761
+ return updated_attributes
762
+
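+ # Example (illustrative): self.update_parameters(step_rag_num=8, seed=42)
+ # returns {'step_rag_num': 8, 'seed': 42}; keys that do not match existing
+ # attributes are ignored rather than raising.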
763
+ def run_gradio_chat(self, message: str,
764
+ history: list,
765
+ temperature: float,
766
+ max_new_tokens: int,
767
+ max_token: int,
768
+ call_agent: bool,
769
+ conversation: gr.State,
770
+ max_round: int = 20,
771
+ seed: int = None,
772
+ call_agent_level: int = 0,
773
+ sub_agent_task: str = None) -> str:
774
+ """
775
+ Generate a streaming response for the Gradio chat interface using the loaded model.
776
+ Args:
777
+ message (str): The input message.
778
+ history (list): The conversation history used by ChatInterface.
779
+ temperature (float): The temperature for generating the response.
780
+ max_new_tokens (int): The maximum number of new tokens to generate.
781
+ Yields:
782
+ list: Updated chat history after each reasoning or tool-call step; the final answer string is returned when the loop ends.
783
+ """
784
+ print("\033[1;32;40mstart\033[0m")
785
+ print("len(message)", len(message))
786
+ if len(message) <= 10:
787
+ yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
788
+ return "Please provide a valid message."
789
+ outputs = []
790
+ outputs_str = ''
791
+ last_outputs = []
792
+
793
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
794
+ call_agent,
795
+ call_agent_level,
796
+ message)
797
+
798
+ conversation = self.initialize_conversation(
799
+ message,
800
+ conversation=conversation,
801
+ history=history)
802
+ history = []
803
+
804
+ next_round = True
805
+ function_call_messages = []
806
+ current_round = 0
807
+ enable_summary = False
808
+ last_status = {} # for summary
809
+ token_overflow = False
810
+ if self.enable_checker:
811
+ checker = ReasoningTraceChecker(
812
+ message, conversation, init_index=len(conversation))
813
+
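+ # Round structure of the loop below: run any pending tool calls from the
+ # previous step, optionally summarise older tool results, then make one
+ # llm_infer call; the loop ends on a Finish tool call, a checker failure,
+ # max_round, or token overflow.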
814
+ try:
815
+ while next_round and current_round < max_round:
816
+ current_round += 1
817
+ if len(last_outputs) > 0:
818
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
819
+ last_outputs, return_message=True,
820
+ existing_tools_prompt=picked_tools_prompt,
821
+ message_for_call_agent=message,
822
+ call_agent=call_agent,
823
+ call_agent_level=call_agent_level,
824
+ temperature=temperature)
825
+ history.extend(current_gradio_history)
826
+ if special_tool_call == 'Finish':
827
+ yield history
828
+ next_round = False
829
+ conversation.extend(function_call_messages)
830
+ return function_call_messages[0]['content']
831
+ elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
832
+ history.append(
833
+ ChatMessage(role="assistant", content=history[-1].content))
834
+ yield history
835
+ next_round = False
836
+ return history[-1].content
837
+ if (self.enable_summary or token_overflow) and not call_agent:
838
+ if token_overflow:
839
+ print("token_overflow, using summary")
840
+ enable_summary = True
841
+ last_status = self.function_result_summary(
842
+ conversation, status=last_status,
843
+ enable_summary=enable_summary)
844
+ if function_call_messages is not None:
845
+ conversation.extend(function_call_messages)
846
+ formatted_md_function_call_messages = tool_result_format(
847
+ function_call_messages)
848
+ yield history
849
+ else:
850
+ next_round = False
851
+ conversation.extend(
852
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
853
+ return ''.join(last_outputs).replace("</s>", "")
854
+ if self.enable_checker:
855
+ good_status, wrong_info = checker.check_conversation()
856
+ if not good_status:
857
+ next_round = False
858
+ print("Internal error in reasoning: " + wrong_info)
859
+ break
860
+ last_outputs = []
861
+ last_outputs_str, token_overflow = self.llm_infer(
862
+ messages=conversation,
863
+ temperature=temperature,
864
+ tools=picked_tools_prompt,
865
+ skip_special_tokens=False,
866
+ max_new_tokens=max_new_tokens,
867
+ max_token=max_token,
868
+ seed=seed,
869
+ check_token_status=True)
870
+ if last_outputs_str is None:
+ # Prompt exceeded max_token; exit the loop and let the force_finish fallback below answer.
+ break
+ last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
871
+ for each in history:
872
+ if each.metadata is not None:
873
+ each.metadata['status'] = 'done'
874
+ if '[FinalAnswer]' in last_thought:
875
+ final_thought, final_answer = last_thought.split(
876
+ '[FinalAnswer]')
877
+ history.append(
878
+ ChatMessage(role="assistant",
879
+ content=final_thought.strip())
880
+ )
881
+ yield history
882
+ history.append(
883
+ ChatMessage(
884
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
885
+ )
886
+ yield history
887
+ else:
888
+ history.append(ChatMessage(
889
+ role="assistant", content=last_thought))
890
+ yield history
891
+
892
+ last_outputs.append(last_outputs_str)
893
+
894
+ if next_round:
895
+ if self.force_finish:
896
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
897
+ conversation, temperature, max_new_tokens, max_token)
898
+ for each in history:
899
+ if each.metadata is not None:
900
+ each.metadata['status'] = 'done'
901
+ if '[FinalAnswer]' in last_outputs_str:
902
+ final_thought, final_answer = last_outputs_str.split(
903
+ '[FinalAnswer]')
904
+ history.append(
905
+ ChatMessage(role="assistant",
906
+ content=final_thought.strip())
907
+ )
908
+ yield history
909
+ history.append(
910
+ ChatMessage(
911
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
912
+ )
913
+ yield history
914
+ else:
915
+ yield "The number of rounds exceeds the maximum limit!"
916
+
917
+ except Exception as e:
918
+ print(f"Error: {e}")
919
+ if self.force_finish:
920
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
921
+ conversation,
922
+ temperature,
923
+ max_new_tokens,
924
+ max_token)
925
+ for each in history:
926
+ if each.metadata is not None:
927
+ each.metadata['status'] = 'done'
928
+ if '[FinalAnswer]' in last_outputs_str or '"name": "Finish",' in last_outputs_str:
929
+ if '[FinalAnswer]' in last_outputs_str:
930
+ final_thought, final_answer = last_outputs_str.split('[FinalAnswer]', 1)
931
+ else:
932
+ final_thought = ""
933
+ final_answer = last_outputs_str
934
+ history.append(
935
+ ChatMessage(role="assistant",
936
+ content=final_thought.strip())
937
+ )
938
+ yield history
939
+ history.append(
940
+ ChatMessage(
941
+ role="assistant", content="**Answer**:\n" + final_answer.strip())
942
+ )
943
+ yield history
944
+ else:
945
+ return None
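+ # Usage sketch for run_gradio_chat (the Gradio wiring shown here is an
+ # assumption, not part of this file):
+ # for chat_history in agent.run_gradio_chat(message, [], temperature=0.3,
+ #         max_new_tokens=1024, max_token=8192, call_agent=False,
+ #         conversation=[], max_round=20):
+ #     pass  # each yielded value is the updated list of ChatMessage objects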