Ali2206 committed on
Commit aba9ae9 · verified · 1 Parent(s): 499e72e

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +322 -511
src/txagent/txagent.py CHANGED
@@ -11,19 +11,23 @@ import types
11
  from tooluniverse import ToolUniverse
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
 
 
14
 
15
- from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
 
16
 
 
17
 
18
  class TxAgent:
19
  def __init__(self, model_name,
20
  rag_model_name,
21
- tool_files_dict=None, # None leads to the default tool files in ToolUniverse
22
  enable_finish=True,
23
- enable_rag=True,
24
  enable_summary=False,
25
  init_rag_num=0,
26
- step_rag_num=10,
27
  summary_mode='step',
28
  summary_skip_last_k=0,
29
  summary_context_length=None,
@@ -32,8 +36,7 @@ class TxAgent:
32
  seed=None,
33
  enable_checker=False,
34
  enable_chat=False,
35
- additional_default_tools=None,
36
- ):
37
  self.model_name = model_name
38
  self.tokenizer = None
39
  self.terminators = None
@@ -42,10 +45,9 @@ class TxAgent:
42
  self.model = None
43
  self.rag_model = ToolRAGModel(rag_model_name)
44
  self.tooluniverse = None
45
- # self.tool_desc = None
46
- self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
47
  self.self_prompt = "Strictly follow the instruction."
48
- self.chat_prompt = "You are helpful assistant to chat with the user."
49
  self.enable_finish = enable_finish
50
  self.enable_rag = enable_rag
51
  self.enable_summary = enable_summary
@@ -59,16 +61,11 @@ class TxAgent:
59
  self.seed = seed
60
  self.enable_checker = enable_checker
61
  self.additional_default_tools = additional_default_tools
62
- self.print_self_values()
63
 
64
  def init_model(self):
65
  self.load_models()
66
  self.load_tooluniverse()
67
- self.load_tool_desc_embedding()
68
-
69
- def print_self_values(self):
70
- for attr, value in self.__dict__.items():
71
- print(f"{attr}: {value}")
72
 
73
  def load_models(self, model_name=None):
74
  if model_name is not None:
@@ -76,10 +73,10 @@ class TxAgent:
76
  return f"The model {model_name} is already loaded."
77
  self.model_name = model_name
78
 
79
- self.model = LLM(model=self.model_name)
80
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
81
  self.tokenizer = self.model.get_tokenizer()
82
-
83
  return f"Model {model_name} loaded successfully."
84
 
85
  def load_tooluniverse(self):
@@ -88,9 +85,16 @@ class TxAgent:
88
  special_tools = self.tooluniverse.prepare_tool_prompts(
89
  self.tooluniverse.tool_category_dicts["special_tools"])
90
  self.special_tools_name = [tool['name'] for tool in special_tools]
 
91
 
92
  def load_tool_desc_embedding(self):
93
- self.rag_model.load_tool_desc_embedding(self.tooluniverse)
94
 
95
  def rag_infer(self, query, top_k=5):
96
  return self.rag_model.rag_infer(query, top_k)
@@ -103,10 +107,6 @@ class TxAgent:
103
  call_agent_level += 1
104
  if call_agent_level >= 2:
105
  call_agent = False
106
-
107
- if not call_agent:
108
- picked_tools_prompt += self.tool_RAG(
109
- message=message, rag_num=self.init_rag_num)
110
  return picked_tools_prompt, call_agent_level
111
 
112
  def initialize_conversation(self, message, conversation=None, history=None):
@@ -115,85 +115,56 @@ class TxAgent:
115
 
116
  conversation = self.set_system_prompt(
117
  conversation, self.prompt_multi_step)
118
- if history is not None:
119
- if len(history) == 0:
120
- conversation = []
121
- print("clear conversation successfully")
122
- else:
123
- for i in range(len(history)):
124
- if history[i]['role'] == 'user':
125
- if i-1 >= 0 and history[i-1]['role'] == 'assistant':
126
- conversation.append(
127
- {"role": "assistant", "content": history[i-1]['content']})
128
- conversation.append(
129
- {"role": "user", "content": history[i]['content']})
130
- if i == len(history)-1 and history[i]['role'] == 'assistant':
131
- conversation.append(
132
- {"role": "assistant", "content": history[i]['content']})
133
-
134
  conversation.append({"role": "user", "content": message})
135
-
136
  return conversation
137
 
138
  def tool_RAG(self, message=None,
139
  picked_tool_names=None,
140
  existing_tools_prompt=[],
141
- rag_num=5,
142
  return_call_result=False):
143
- extra_factor = 30 # Factor to retrieve more than rag_num
 
 
144
  if picked_tool_names is None:
145
  assert picked_tool_names is not None or message is not None
146
  picked_tool_names = self.rag_infer(
147
- message, top_k=rag_num*extra_factor)
148
 
149
- picked_tool_names_no_special = []
150
- for tool in picked_tool_names:
151
- if tool not in self.special_tools_name:
152
- picked_tool_names_no_special.append(tool)
153
- picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
154
  picked_tool_names = picked_tool_names_no_special[:rag_num]
155
 
156
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
157
- picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
158
- picked_tools)
159
  if return_call_result:
160
  return picked_tools_prompt, picked_tool_names
161
  return picked_tools_prompt
162
 
163
  def add_special_tools(self, tools, call_agent=False):
164
  if self.enable_finish:
165
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
166
- 'Finish', return_prompt=True))
167
- print("Finish tool is added")
168
  if call_agent:
169
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
170
- 'CallAgent', return_prompt=True))
171
- print("CallAgent tool is added")
172
- else:
173
- if self.enable_rag:
174
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
- 'Tool_RAG', return_prompt=True))
176
- print("Tool_RAG tool is added")
177
-
178
- if self.additional_default_tools is not None:
179
- for each_tool_name in self.additional_default_tools:
180
- tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
181
- each_tool_name, return_prompt=True)
182
- if tool_prompt is not None:
183
- print(f"{each_tool_name} tool is added")
184
- tools.append(tool_prompt)
185
  return tools
186
 
187
  def add_finish_tools(self, tools):
188
- tools.append(self.tooluniverse.get_one_tool_by_one_name(
189
- 'Finish', return_prompt=True))
190
- print("Finish tool is added")
191
  return tools
192
 
193
  def set_system_prompt(self, conversation, sys_prompt):
194
- if len(conversation) == 0:
195
- conversation.append(
196
- {"role": "system", "content": sys_prompt})
197
  else:
198
  conversation[0] = {"role": "system", "content": sys_prompt}
199
  return conversation
@@ -205,25 +176,23 @@ class TxAgent:
205
  call_agent=False,
206
  call_agent_level=None,
207
  temperature=None):
208
 
209
- function_call_json, message = self.tooluniverse.extract_function_call_json(
210
- fcall_str, return_message=return_message, verbose=False)
211
  call_results = []
212
  special_tool_call = ''
213
- if function_call_json is not None:
214
  if isinstance(function_call_json, list):
215
  for i in range(len(function_call_json)):
216
- print("\033[94mTool Call:\033[0m", function_call_json[i])
217
  if function_call_json[i]["name"] == 'Finish':
218
  special_tool_call = 'Finish'
219
  break
220
- elif function_call_json[i]["name"] == 'Tool_RAG':
221
- new_tools_prompt, call_result = self.tool_RAG(
222
- message=message,
223
- existing_tools_prompt=existing_tools_prompt,
224
- rag_num=self.step_rag_num,
225
- return_call_result=True)
226
- existing_tools_prompt += new_tools_prompt
227
  elif function_call_json[i]["name"] == 'CallAgent':
228
  if call_agent_level < 2 and call_agent:
229
  solution_plan = function_call_json[i]['arguments']['solution']
@@ -234,27 +203,27 @@ class TxAgent:
234
  )
235
  call_result = self.run_multistep_agent(
236
  full_message, temperature=temperature,
237
- max_new_tokens=1024, max_token=99999,
238
  call_agent=False, call_agent_level=call_agent_level)
239
- call_result = call_result.split(
240
- '[FinalAnswer]')[-1].strip()
 
 
241
  else:
242
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
243
  else:
244
- call_result = self.tooluniverse.run_one_function(
245
- function_call_json[i])
246
-
247
  call_id = self.tooluniverse.call_id_gen()
248
  function_call_json[i]["call_id"] = call_id
249
- print("\033[94mTool Call Result:\033[0m", call_result)
250
  call_results.append({
251
  "role": "tool",
252
- "content": json.dumps({"content": call_result, "call_id": call_id})
253
  })
254
  else:
255
  call_results.append({
256
  "role": "tool",
257
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
258
  })
259
 
260
  revised_messages = [{
@@ -262,8 +231,6 @@ class TxAgent:
262
  "content": message.strip(),
263
  "tool_calls": json.dumps(function_call_json)
264
  }] + call_results
265
-
266
- # Yield the final result.
267
  return revised_messages, existing_tools_prompt, special_tool_call
268
 
269
  def run_function_call_stream(self, fcall_str,
@@ -274,26 +241,24 @@ class TxAgent:
274
  call_agent_level=None,
275
  temperature=None,
276
  return_gradio_history=True):
277
 
278
- function_call_json, message = self.tooluniverse.extract_function_call_json(
279
- fcall_str, return_message=return_message, verbose=False)
280
  call_results = []
281
  special_tool_call = ''
282
  if return_gradio_history:
283
  gradio_history = []
284
- if function_call_json is not None:
285
  if isinstance(function_call_json, list):
286
  for i in range(len(function_call_json)):
287
  if function_call_json[i]["name"] == 'Finish':
288
  special_tool_call = 'Finish'
289
  break
290
- elif function_call_json[i]["name"] == 'Tool_RAG':
291
- new_tools_prompt, call_result = self.tool_RAG(
292
- message=message,
293
- existing_tools_prompt=existing_tools_prompt,
294
- rag_num=self.step_rag_num,
295
- return_call_result=True)
296
- existing_tools_prompt += new_tools_prompt
297
  elif function_call_json[i]["name"] == 'DirectResponse':
298
  call_result = function_call_json[i]['arguments']['respose']
299
  special_tool_call = 'DirectResponse'
@@ -308,42 +273,33 @@ class TxAgent:
308
  "\nYou must follow the following plan to answer the question: " +
309
  str(solution_plan)
310
  )
311
- sub_agent_task = "Sub TxAgent plan: " + \
312
- str(solution_plan)
313
- # When streaming, yield responses as they arrive.
314
  call_result = yield from self.run_gradio_chat(
315
  full_message, history=[], temperature=temperature,
316
- max_new_tokens=1024, max_token=99999,
317
  call_agent=False, call_agent_level=call_agent_level,
318
- conversation=None,
319
- sub_agent_task=sub_agent_task)
320
-
321
- call_result = call_result.split(
322
- '[FinalAnswer]')[-1]
323
  else:
324
- call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
325
  else:
326
- call_result = self.tooluniverse.run_one_function(
327
- function_call_json[i])
328
-
329
  call_id = self.tooluniverse.call_id_gen()
330
  function_call_json[i]["call_id"] = call_id
331
  call_results.append({
332
  "role": "tool",
333
- "content": json.dumps({"content": call_result, "call_id": call_id})
334
  })
335
  if return_gradio_history and function_call_json[i]["name"] != 'Finish':
336
- if function_call_json[i]["name"] == 'Tool_RAG':
337
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
338
- "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
339
-
340
- else:
341
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
342
- "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
343
  else:
344
  call_results.append({
345
  "role": "tool",
346
- "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
347
  })
348
 
349
  revised_messages = [{
@@ -351,152 +307,119 @@ class TxAgent:
351
  "content": message.strip(),
352
  "tool_calls": json.dumps(function_call_json)
353
  }] + call_results
354
-
355
- # Yield the final result.
356
  if return_gradio_history:
357
  return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
358
- else:
359
- return revised_messages, existing_tools_prompt, special_tool_call
360
 
361
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
362
- if conversation[-1]['role'] == 'assisant':
363
  conversation.append(
364
- {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
365
  finish_tools_prompt = self.add_finish_tools([])
366
-
367
- last_outputs_str = self.llm_infer(messages=conversation,
368
- temperature=temperature,
369
- tools=finish_tools_prompt,
370
- output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
371
- skip_special_tokens=True,
372
- max_new_tokens=max_new_tokens, max_token=max_token)
373
- print(last_outputs_str)
 
374
  return last_outputs_str
375
 
376
  def run_multistep_agent(self, message: str,
377
  temperature: float,
378
  max_new_tokens: int,
379
  max_token: int,
380
- max_round: int = 20,
381
  call_agent=False,
382
- call_agent_level=0) -> str:
383
- """
384
- Generate a streaming response using the llama3-8b model.
385
- Args:
386
- message (str): The input message.
387
- temperature (float): The temperature for generating the response.
388
- max_new_tokens (int): The maximum number of new tokens to generate.
389
- Returns:
390
- str: The generated response.
391
- """
392
- print("\033[1;32;40mstart\033[0m")
393
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
394
  call_agent, call_agent_level, message)
395
  conversation = self.initialize_conversation(message)
396
-
397
  outputs = []
398
  last_outputs = []
399
  next_round = True
400
- function_call_messages = []
401
  current_round = 0
402
  token_overflow = False
403
  enable_summary = False
404
  last_status = {}
405
 
406
- if self.enable_checker:
407
- checker = ReasoningTraceChecker(message, conversation)
408
- try:
409
- while next_round and current_round < max_round:
410
- current_round += 1
411
- if len(outputs) > 0:
412
- function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
413
- last_outputs, return_message=True,
414
- existing_tools_prompt=picked_tools_prompt,
415
- message_for_call_agent=message,
416
- call_agent=call_agent,
417
- call_agent_level=call_agent_level,
418
- temperature=temperature)
419
-
420
- if special_tool_call == 'Finish':
421
- next_round = False
422
- conversation.extend(function_call_messages)
423
- if isinstance(function_call_messages[0]['content'], types.GeneratorType):
424
- function_call_messages[0]['content'] = next(
425
- function_call_messages[0]['content'])
426
- return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
427
-
428
- if (self.enable_summary or token_overflow) and not call_agent:
429
- if token_overflow:
430
- print("token_overflow, using summary")
431
- enable_summary = True
432
- last_status = self.function_result_summary(
433
- conversation, status=last_status, enable_summary=enable_summary)
434
-
435
- if function_call_messages is not None:
436
- conversation.extend(function_call_messages)
437
- outputs.append(tool_result_format(
438
- function_call_messages))
439
- else:
440
- next_round = False
441
- conversation.extend(
442
- [{"role": "assistant", "content": ''.join(last_outputs)}])
443
- return ''.join(last_outputs).replace("</s>", "")
444
- if self.enable_checker:
445
- good_status, wrong_info = checker.check_conversation()
446
- if not good_status:
447
- next_round = False
448
- print(
449
- "Internal error in reasoning: " + wrong_info)
450
- break
451
- last_outputs = []
452
- outputs.append("### TxAgent:\n")
453
- last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
454
- temperature=temperature,
455
- tools=picked_tools_prompt,
456
- skip_special_tokens=False,
457
- max_new_tokens=max_new_tokens, max_token=max_token,
458
- check_token_status=True)
459
- if last_outputs_str is None:
460
  next_round = False
461
- print(
462
- "The number of tokens exceeds the maximum limit.")
463
  else:
464
- last_outputs.append(last_outputs_str)
465
- if max_round == current_round:
466
- print("The number of rounds exceeds the maximum limit!")
467
- if self.force_finish:
468
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
469
- else:
470
- return None
471
 
472
- except Exception as e:
473
- print(f"Error: {e}")
474
- if self.force_finish:
475
- return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
476
- else:
477
- return None
478
 
479
  def build_logits_processor(self, messages, llm):
480
- # Use the tokenizer from the LLM instance.
481
- tokenizer = llm.get_tokenizer()
482
- if self.avoid_repeat and len(messages) > 2:
483
- assistant_messages = []
484
- for i in range(1, len(messages) + 1):
485
- if messages[-i]['role'] == 'assistant':
486
- assistant_messages.append(messages[-i]['content'])
487
- if len(assistant_messages) == 2:
488
- break
489
- forbidden_ids = [tokenizer.encode(
490
- msg, add_special_tokens=False) for msg in assistant_messages]
491
- return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
492
- else:
493
- return None
494
 
495
  def llm_infer(self, messages, temperature=0.1, tools=None,
496
- output_begin_string=None, max_new_tokens=2048,
497
- max_token=None, skip_special_tokens=True,
498
- model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
499
-
500
  if model is None:
501
  model = self.model
502
 
@@ -504,7 +427,6 @@ class TxAgent:
504
  sampling_params = SamplingParams(
505
  temperature=temperature,
506
  max_tokens=max_new_tokens,
507
- logits_processors=logits_processor,
508
  seed=seed if seed is not None else self.seed,
509
  )
510
 
@@ -515,244 +437,174 @@ class TxAgent:
515
 
516
  if check_token_status and max_token is not None:
517
  token_overflow = False
518
- num_input_tokens = len(self.tokenizer.encode(
519
- prompt, return_tensors="pt")[0])
520
- if max_token is not None:
521
- if num_input_tokens > max_token:
522
- torch.cuda.empty_cache()
523
- gc.collect()
524
- print("Number of input tokens before inference:",
525
- num_input_tokens)
526
- logger.info(
527
- "The number of tokens exceeds the maximum limit!!!!")
528
- token_overflow = True
529
- return None, token_overflow
530
- output = model.generate(
531
- prompt,
532
- sampling_params=sampling_params,
533
- )
534
  output = output[0].outputs[0].text
535
- print("\033[92m" + output + "\033[0m")
 
 
536
  if check_token_status and max_token is not None:
537
  return output, token_overflow
538
-
539
  return output
540
 
541
  def run_self_agent(self, message: str,
542
  temperature: float,
543
  max_new_tokens: int,
544
- max_token: int) -> str:
545
-
546
- print("\033[1;32;40mstart self agent\033[0m")
547
- conversation = []
548
- conversation = self.set_system_prompt(conversation, self.self_prompt)
549
  conversation.append({"role": "user", "content": message})
550
- return self.llm_infer(messages=conversation,
551
- temperature=temperature,
552
- tools=None,
553
- max_new_tokens=max_new_tokens, max_token=max_token)
 
 
554
 
555
  def run_chat_agent(self, message: str,
556
  temperature: float,
557
  max_new_tokens: int,
558
- max_token: int) -> str:
559
-
560
- print("\033[1;32;40mstart chat agent\033[0m")
561
- conversation = []
562
- conversation = self.set_system_prompt(conversation, self.chat_prompt)
563
  conversation.append({"role": "user", "content": message})
564
- return self.llm_infer(messages=conversation,
565
- temperature=temperature,
566
- tools=None,
567
- max_new_tokens=max_new_tokens, max_token=max_token)
 
 
568
 
569
  def run_format_agent(self, message: str,
570
  answer: str,
571
  temperature: float,
572
  max_new_tokens: int,
573
- max_token: int) -> str:
574
-
575
- print("\033[1;32;40mstart format agent\033[0m")
576
  if '[FinalAnswer]' in answer:
577
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
578
  elif "\n\n" in answer:
579
  possible_final_answer = answer.split("\n\n")[-1]
580
  else:
581
  possible_final_answer = answer.strip()
582
- if len(possible_final_answer) == 1:
583
- choice = possible_final_answer[0]
584
- if choice in ['A', 'B', 'C', 'D', 'E']:
585
- return choice
586
- elif len(possible_final_answer) > 1:
587
- if possible_final_answer[1] == ':':
588
- choice = possible_final_answer[0]
589
- if choice in ['A', 'B', 'C', 'D', 'E']:
590
- print("choice", choice)
591
- return choice
592
-
593
- conversation = []
594
- format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
595
- conversation = self.set_system_prompt(conversation, format_prompt)
596
  conversation.append({"role": "user", "content": message +
597
- "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
598
- return self.llm_infer(messages=conversation,
599
- temperature=temperature,
600
- tools=None,
601
- max_new_tokens=max_new_tokens, max_token=max_token)
 
 
602
 
603
  def run_summary_agent(self, thought_calls: str,
604
  function_response: str,
605
  temperature: float,
606
  max_new_tokens: int,
607
- max_token: int) -> str:
608
- print("\033[1;32;40mSummarized Tool Result:\033[0m")
609
- generate_tool_result_summary_training_prompt = """Thought and function calls:
610
  {thought_calls}
611
-
612
  Function calls' responses:
613
  \"\"\"
614
  {function_response}
615
  \"\"\"
616
-
617
- Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
618
-
619
- Directly respond with the summarized sentence of the function calls' responses only.
620
-
621
- Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
622
- """.format(thought_calls=thought_calls, function_response=function_response)
623
- conversation = []
624
- conversation.append(
625
- {"role": "user", "content": generate_tool_result_summary_training_prompt})
626
- output = self.llm_infer(messages=conversation,
627
- temperature=temperature,
628
- tools=None,
629
- max_new_tokens=max_new_tokens, max_token=max_token)
630
-
631
  if '[' in output:
632
  output = output.split('[')[0]
633
  return output
634
 
635
  def function_result_summary(self, input_list, status, enable_summary):
636
- """
637
- Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
638
- Supports 'length' and 'step' modes, and skips the last 'k' groups.
639
-
640
- Parameters:
641
- input_list (list): A list of dictionaries containing role and other information.
642
- summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
643
- summary_context_length (int): The context length threshold for the 'length' mode.
644
- last_processed_index (tuple or int): The last processed index.
645
-
646
- Returns:
647
- list: A list of extracted information from valid sequences.
648
- """
649
  if 'tool_call_step' not in status:
650
  status['tool_call_step'] = 0
651
-
652
  for idx in range(len(input_list)):
653
- pos_id = len(input_list)-idx-1
654
- if input_list[pos_id]['role'] == 'assistant':
655
- if 'tool_calls' in input_list[pos_id]:
656
- if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
657
- status['tool_call_step'] += 1
658
  break
659
 
660
- if 'step' in status:
661
- status['step'] += 1
662
- else:
663
- status['step'] = 0
664
-
665
  if not enable_summary:
666
  return status
667
 
668
- if 'summarized_index' not in status:
669
- status['summarized_index'] = 0
670
-
671
- if 'summarized_step' not in status:
672
- status['summarized_step'] = 0
673
-
674
- if 'previous_length' not in status:
675
- status['previous_length'] = 0
676
-
677
- if 'history' not in status:
678
- status['history'] = []
679
 
680
  function_response = ''
681
- idx = 0
682
- current_summarized_index = status['summarized_index']
683
-
684
- status['history'].append(self.summary_mode == 'step' and status['summarized_step']
685
- < status['step']-status['tool_call_step']-self.summary_skip_last_k)
686
 
687
- idx = current_summarized_index
688
  while idx < len(input_list):
689
- if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
690
-
691
  if input_list[idx]['role'] == 'assistant':
692
- if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
693
- this_thought_calls = None
694
- else:
695
- if len(function_response) != 0:
696
- print("internal summary")
697
- status['summarized_step'] += 1
698
- result_summary = self.run_summary_agent(
699
- thought_calls=this_thought_calls,
700
- function_response=function_response,
701
- temperature=0.1,
702
- max_new_tokens=1024,
703
- max_token=99999
704
- )
705
-
706
- input_list.insert(
707
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
708
- status['summarized_index'] = last_call_idx + 2
709
- idx += 1
710
-
711
- last_call_idx = idx
712
- this_thought_calls = input_list[idx]['content'] + \
713
- input_list[idx]['tool_calls']
714
- function_response = ''
715
-
716
  elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
717
  function_response += input_list[idx]['content']
718
  del input_list[idx]
719
  idx -= 1
720
-
721
  else:
722
  break
723
  idx += 1
724
 
725
- if len(function_response) != 0:
726
  status['summarized_step'] += 1
727
  result_summary = self.run_summary_agent(
728
  thought_calls=this_thought_calls,
729
  function_response=function_response,
730
  temperature=0.1,
731
- max_new_tokens=1024,
732
- max_token=99999
733
- )
734
-
735
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
736
  for tool_call in tool_calls:
737
  del tool_call['call_id']
738
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
739
- input_list.insert(
740
- last_call_idx+1, {'role': 'tool', 'content': result_summary})
741
  status['summarized_index'] = last_call_idx + 2
742
 
743
  return status
744
 
745
- # Following are Gradio related functions
746
-
747
- # General update method that accepts any new arguments through kwargs
748
  def update_parameters(self, **kwargs):
 
749
  for key, value in kwargs.items():
750
  if hasattr(self, key):
751
  setattr(self, key, value)
752
-
753
- # Return the updated attributes
754
- updated_attributes = {key: value for key,
755
- value in kwargs.items() if hasattr(self, key)}
756
  return updated_attributes
757
 
758
  def run_gradio_chat(self, message: str,
@@ -762,54 +614,33 @@ Generate **one summarized sentence** about "function calls' responses" with nece
762
  max_token: int,
763
  call_agent: bool,
764
  conversation: gr.State,
765
- max_round: int = 20,
766
  seed: int = None,
767
  call_agent_level: int = 0,
768
- sub_agent_task: str = None) -> str:
769
- """
770
- Generate a streaming response using the llama3-8b model.
771
- Args:
772
- message (str): The input message.
773
- history (list): The conversation history used by ChatInterface.
774
- temperature (float): The temperature for generating the response.
775
- max_new_tokens (int): The maximum number of new tokens to generate.
776
- Returns:
777
- str: The generated response.
778
- """
779
- print("\033[1;32;40mstart\033[0m")
780
- print("len(message)", len(message))
781
- if len(message) <= 10:
782
- yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
783
- return "Please provide a valid message."
784
- outputs = []
785
- outputs_str = ''
786
- last_outputs = []
787
 
788
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
789
- call_agent,
790
- call_agent_level,
791
- message)
792
-
793
  conversation = self.initialize_conversation(
794
- message,
795
- conversation=conversation,
796
- history=history)
797
  history = []
 
798
 
799
  next_round = True
800
- function_call_messages = []
801
  current_round = 0
802
  enable_summary = False
803
- last_status = {} # for summary
804
  token_overflow = False
805
- if self.enable_checker:
806
- checker = ReasoningTraceChecker(
807
- message, conversation, init_index=len(conversation))
808
 
809
  try:
810
  while next_round and current_round < max_round:
811
  current_round += 1
812
- if len(last_outputs) > 0:
813
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
814
  last_outputs, return_message=True,
815
  existing_tools_prompt=picked_tools_prompt,
@@ -818,40 +649,33 @@ Generate **one summarized sentence** about "function calls' responses" with nece
818
  call_agent_level=call_agent_level,
819
  temperature=temperature)
820
  history.extend(current_gradio_history)
 
821
  if special_tool_call == 'Finish':
822
  yield history
823
  next_round = False
824
  conversation.extend(function_call_messages)
825
  return function_call_messages[0]['content']
826
- elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
827
- history.append(
828
- ChatMessage(role="assistant", content=history[-1].content))
 
829
  yield history
830
  next_round = False
831
- return history[-1].content
 
832
  if (self.enable_summary or token_overflow) and not call_agent:
833
- if token_overflow:
834
- print("token_overflow, using summary")
835
  enable_summary = True
836
  last_status = self.function_result_summary(
837
- conversation, status=last_status,
838
- enable_summary=enable_summary)
839
- if function_call_messages is not None:
840
  conversation.extend(function_call_messages)
841
- formated_md_function_call_messages = tool_result_format(
842
- function_call_messages)
843
  yield history
844
  else:
845
  next_round = False
846
- conversation.extend(
847
- [{"role": "assistant", "content": ''.join(last_outputs)}])
848
  return ''.join(last_outputs).replace("</s>", "")
849
- if self.enable_checker:
850
- good_status, wrong_info = checker.check_conversation()
851
- if not good_status:
852
- next_round = False
853
- print("Internal error in reasoning: " + wrong_info)
854
- break
855
  last_outputs = []
856
  last_outputs_str, token_overflow = self.llm_infer(
857
  messages=conversation,
@@ -862,26 +686,34 @@ Generate **one summarized sentence** about "function calls' responses" with nece
862
  max_token=max_token,
863
  seed=seed,
864
  check_token_status=True)
865
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
866
- for each in history:
867
- if each.metadata is not None:
868
- each.metadata['status'] = 'done'
 
869
  if '[FinalAnswer]' in last_thought:
870
- final_thought, final_answer = last_thought.split(
871
- '[FinalAnswer]')
872
- history.append(
873
- ChatMessage(role="assistant",
874
- content=final_thought.strip())
875
- )
876
  yield history
877
- history.append(
878
- ChatMessage(
879
- role="assistant", content="**Answer**:\n"+final_answer.strip())
880
- )
881
  yield history
882
  else:
883
- history.append(ChatMessage(
884
- role="assistant", content=last_thought))
885
  yield history
886
 
887
  last_outputs.append(last_outputs_str)
@@ -890,48 +722,27 @@ Generate **one summarized sentence** about "function calls' responses" with nece
890
  if self.force_finish:
891
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
892
  conversation, temperature, max_new_tokens, max_token)
893
- for each in history:
894
- if each.metadata is not None:
895
- each.metadata['status'] = 'done'
896
- if '[FinalAnswer]' in last_thought:
897
- final_thought, final_answer = last_thought.split(
898
- '[FinalAnswer]')
899
- history.append(
900
- ChatMessage(role="assistant",
901
- content=final_thought.strip())
902
- )
903
- yield history
904
- history.append(
905
- ChatMessage(
906
- role="assistant", content="**Answer**:\n"+final_answer.strip())
907
- )
908
- yield history
909
  else:
910
- yield "The number of rounds exceeds the maximum limit!"
911
 
912
  except Exception as e:
913
- print(f"Error: {e}")
 
 
 
914
  if self.force_finish:
915
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
916
- conversation,
917
- temperature,
918
- max_new_tokens,
919
- max_token)
920
- for each in history:
921
- if each.metadata is not None:
922
- each.metadata['status'] = 'done'
923
- if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
924
- final_thought, final_answer = last_thought.split(
925
- '[FinalAnswer]')
926
- history.append(
927
- ChatMessage(role="assistant",
928
- content=final_thought.strip())
929
- )
930
- yield history
931
- history.append(
932
- ChatMessage(
933
- role="assistant", content="**Answer**:\n"+final_answer.strip())
934
- )
935
- yield history
936
- else:
937
- return None
 
11
  from tooluniverse import ToolUniverse
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
+ import torch
15
+ import logging
16
 
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
 
20
+ from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
22
  class TxAgent:
23
  def __init__(self, model_name,
24
  rag_model_name,
25
+ tool_files_dict=None,
26
  enable_finish=True,
27
+ enable_rag=False,
28
  enable_summary=False,
29
  init_rag_num=0,
30
+ step_rag_num=0,
31
  summary_mode='step',
32
  summary_skip_last_k=0,
33
  summary_context_length=None,
 
36
  seed=None,
37
  enable_checker=False,
38
  enable_chat=False,
39
+ additional_default_tools=None):
 
40
  self.model_name = model_name
41
  self.tokenizer = None
42
  self.terminators = None
 
45
  self.model = None
46
  self.rag_model = ToolRAGModel(rag_model_name)
47
  self.tooluniverse = None
48
+ self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
 
49
  self.self_prompt = "Strictly follow the instruction."
50
+ self.chat_prompt = "You are a helpful assistant for user chat."
51
  self.enable_finish = enable_finish
52
  self.enable_rag = enable_rag
53
  self.enable_summary = enable_summary
 
61
  self.seed = seed
62
  self.enable_checker = enable_checker
63
  self.additional_default_tools = additional_default_tools
64
+ logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)
65
 
66
  def init_model(self):
67
  self.load_models()
68
  self.load_tooluniverse()
69
 
70
  def load_models(self, model_name=None):
71
  if model_name is not None:
 
73
  return f"The model {model_name} is already loaded."
74
  self.model_name = model_name
75
 
76
+ self.model = LLM(model=self.model_name, dtype="float16", max_model_len=2048, gpu_memory_utilization=0.8)
77
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
78
  self.tokenizer = self.model.get_tokenizer()
79
+ logger.info("Model %s loaded successfully", self.model_name)
80
  return f"Model {model_name} loaded successfully."
81
 
82
  def load_tooluniverse(self):
 
85
  special_tools = self.tooluniverse.prepare_tool_prompts(
86
  self.tooluniverse.tool_category_dicts["special_tools"])
87
  self.special_tools_name = [tool['name'] for tool in special_tools]
88
+ logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
89
 
90
  def load_tool_desc_embedding(self):
91
+ cache_path = os.path.join(os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl")
92
+ if os.path.exists(cache_path):
93
+ self.rag_model.load_cached_embeddings(cache_path)
94
+ else:
95
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
96
+ self.rag_model.save_embeddings(cache_path)
97
+ logger.debug("Tool description embeddings loaded")
98
 
99
  def rag_infer(self, query, top_k=5):
100
  return self.rag_model.rag_infer(query, top_k)
 
107
  call_agent_level += 1
108
  if call_agent_level >= 2:
109
  call_agent = False
 
 
 
 
110
  return picked_tools_prompt, call_agent_level
111
 
112
  def initialize_conversation(self, message, conversation=None, history=None):
 
115
 
116
  conversation = self.set_system_prompt(
117
  conversation, self.prompt_multi_step)
118
+ if history:
119
+ for i in range(len(history)):
120
+ if history[i]['role'] == 'user':
121
+ conversation.append({"role": "user", "content": history[i]['content']})
122
+ elif history[i]['role'] == 'assistant':
123
+ conversation.append({"role": "assistant", "content": history[i]['content']})
124
  conversation.append({"role": "user", "content": message})
125
+ logger.debug("Conversation initialized with %d messages", len(conversation))
126
  return conversation
127
 
128
  def tool_RAG(self, message=None,
129
  picked_tool_names=None,
130
  existing_tools_prompt=[],
131
+ rag_num=0,
132
  return_call_result=False):
133
+ if not self.enable_rag:
134
+ return []
135
+ extra_factor = 10
136
  if picked_tool_names is None:
137
  assert picked_tool_names is not None or message is not None
138
  picked_tool_names = self.rag_infer(
139
+ message, top_k=rag_num * extra_factor)
140
 
141
+ picked_tool_names_no_special = [tool for tool in picked_tool_names if tool not in self.special_tools_name]
 
 
 
 
142
  picked_tool_names = picked_tool_names_no_special[:rag_num]
143
 
144
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
145
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
146
+ logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt))
147
  if return_call_result:
148
  return picked_tools_prompt, picked_tool_names
149
  return picked_tools_prompt
150
 
151
  def add_special_tools(self, tools, call_agent=False):
152
  if self.enable_finish:
153
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
154
+ logger.debug("Finish tool added")
 
155
  if call_agent:
156
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
157
+ logger.debug("CallAgent tool added")
158
  return tools
159
 
160
  def add_finish_tools(self, tools):
161
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
162
+ logger.debug("Finish tool added")
 
163
  return tools
164
 
165
  def set_system_prompt(self, conversation, sys_prompt):
166
+ if not conversation:
167
+ conversation.append({"role": "system", "content": sys_prompt})
 
168
  else:
169
  conversation[0] = {"role": "system", "content": sys_prompt}
170
  return conversation
 
176
  call_agent=False,
177
  call_agent_level=None,
178
  temperature=None):
179
+ try:
180
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
181
+ fcall_str, return_message=return_message, verbose=False)
182
+ except Exception as e:
183
+ logger.error("Tool call parsing failed: %s", e)
184
+ function_call_json = []
185
+ message = fcall_str
186
 
 
 
187
  call_results = []
188
  special_tool_call = ''
189
+ if function_call_json:
190
  if isinstance(function_call_json, list):
191
  for i in range(len(function_call_json)):
192
+ logger.info("Tool Call: %s", function_call_json[i])
193
  if function_call_json[i]["name"] == 'Finish':
194
  special_tool_call = 'Finish'
195
  break
196
  elif function_call_json[i]["name"] == 'CallAgent':
197
  if call_agent_level < 2 and call_agent:
198
  solution_plan = function_call_json[i]['arguments']['solution']
 
203
  )
204
  call_result = self.run_multistep_agent(
205
  full_message, temperature=temperature,
206
+ max_new_tokens=512, max_token=2048,
207
  call_agent=False, call_agent_level=call_agent_level)
208
+ if call_result is None:
209
+ call_result = "⚠️ No content returned from sub-agent."
210
+ else:
211
+ call_result = call_result.split('[FinalAnswer]')[-1].strip()
212
  else:
213
+ call_result = "Error: CallAgent disabled."
214
  else:
215
+ call_result = self.tooluniverse.run_one_function(function_call_json[i])
 
 
216
  call_id = self.tooluniverse.call_id_gen()
217
  function_call_json[i]["call_id"] = call_id
218
+ logger.info("Tool Call Result: %s", call_result)
219
  call_results.append({
220
  "role": "tool",
221
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
222
  })
223
  else:
224
  call_results.append({
225
  "role": "tool",
226
+ "content": json.dumps({"content": "Invalid or no function call detected."})
227
  })
228
 
229
  revised_messages = [{
 
231
  "content": message.strip(),
232
  "tool_calls": json.dumps(function_call_json)
233
  }] + call_results
 
 
234
  return revised_messages, existing_tools_prompt, special_tool_call
235
 
236
  def run_function_call_stream(self, fcall_str,
 
241
  call_agent_level=None,
242
  temperature=None,
243
  return_gradio_history=True):
244
+ try:
245
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
246
+ fcall_str, return_message=return_message, verbose=False)
247
+ except Exception as e:
248
+ logger.error("Tool call parsing failed: %s", e)
249
+ function_call_json = []
250
+ message = fcall_str
251
 
 
 
252
  call_results = []
253
  special_tool_call = ''
254
  if return_gradio_history:
255
  gradio_history = []
256
+ if function_call_json:
257
  if isinstance(function_call_json, list):
258
  for i in range(len(function_call_json)):
259
  if function_call_json[i]["name"] == 'Finish':
260
  special_tool_call = 'Finish'
261
262
  elif function_call_json[i]["name"] == 'DirectResponse':
263
  call_result = function_call_json[i]['arguments']['respose']
264
  special_tool_call = 'DirectResponse'
 
273
  "\nYou must follow the following plan to answer the question: " +
274
  str(solution_plan)
275
  )
276
+ sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
 
 
277
  call_result = yield from self.run_gradio_chat(
278
  full_message, history=[], temperature=temperature,
279
+ max_new_tokens=512, max_token=2048,
280
  call_agent=False, call_agent_level=call_agent_level,
281
+ conversation=None, sub_agent_task=sub_agent_task)
282
+ if call_result is not None and isinstance(call_result, str):
283
+ call_result = call_result.split('[FinalAnswer]')[-1]
284
+ else:
285
+ call_result = "⚠️ No content returned from sub-agent."
286
  else:
287
+ call_result = "Error: CallAgent disabled."
288
  else:
289
+ call_result = self.tooluniverse.run_one_function(function_call_json[i])
 
 
290
  call_id = self.tooluniverse.call_id_gen()
291
  function_call_json[i]["call_id"] = call_id
292
  call_results.append({
293
  "role": "tool",
294
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
295
  })
296
  if return_gradio_history and function_call_json[i]["name"] != 'Finish':
297
+ metadata = {"title": f"🧰 {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
298
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
299
  else:
300
  call_results.append({
301
  "role": "tool",
302
+ "content": json.dumps({"content": "Invalid or no function call detected."})
303
  })
304
 
305
  revised_messages = [{
 
307
  "content": message.strip(),
308
  "tool_calls": json.dumps(function_call_json)
309
  }] + call_results
 
 
310
  if return_gradio_history:
311
  return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
312
+ return revised_messages, existing_tools_prompt, special_tool_call
 
313
 
314
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
315
+ if conversation[-1]['role'] == 'assistant':
316
  conversation.append(
317
+ {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
318
  finish_tools_prompt = self.add_finish_tools([])
319
+ last_outputs_str = self.llm_infer(
320
+ messages=conversation,
321
+ temperature=temperature,
322
+ tools=finish_tools_prompt,
323
+ output_begin_string='[FinalAnswer]',
324
+ skip_special_tokens=True,
325
+ max_new_tokens=max_new_tokens,
326
+ max_token=max_token)
327
+ logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100])
328
  return last_outputs_str
329
 
330
  def run_multistep_agent(self, message: str,
331
  temperature: float,
332
  max_new_tokens: int,
333
  max_token: int,
334
+ max_round: int = 5,
335
  call_agent=False,
336
+ call_agent_level=0):
337
+ logger.info("Starting multistep agent for message: %s", message[:100])
338
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
339
  call_agent, call_agent_level, message)
340
  conversation = self.initialize_conversation(message)
 
341
  outputs = []
342
  last_outputs = []
343
  next_round = True
 
344
  current_round = 0
345
  token_overflow = False
346
  enable_summary = False
347
  last_status = {}
348
 
349
+ while next_round and current_round < max_round:
350
+ current_round += 1
351
+ if len(outputs) > 0:
352
+ function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
353
+ last_outputs, return_message=True,
354
+ existing_tools_prompt=picked_tools_prompt,
355
+ message_for_call_agent=message,
356
+ call_agent=call_agent,
357
+ call_agent_level=call_agent_level,
358
+ temperature=temperature)
359
+
360
+ if special_tool_call == 'Finish':
361
  next_round = False
362
+ conversation.extend(function_call_messages)
363
+ content = function_call_messages[0]['content']
364
+ if content is None:
365
+ return "❌ No content returned after Finish tool call."
366
+ return content.split('[FinalAnswer]')[-1]
367
+
368
+ if (self.enable_summary or token_overflow) and not call_agent:
369
+ enable_summary = True
370
+ last_status = self.function_result_summary(
371
+ conversation, status=last_status, enable_summary=enable_summary)
372
+
373
+ if function_call_messages:
374
+ conversation.extend(function_call_messages)
375
+ outputs.append(tool_result_format(function_call_messages))
376
  else:
377
+ next_round = False
378
+ conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
379
+ return ''.join(last_outputs).replace("</s>", "")
380
+
381
+ last_outputs = []
382
+ outputs.append("### TxAgent:\n")
383
+ last_outputs_str, token_overflow = self.llm_infer(
384
+ messages=conversation,
385
+ temperature=temperature,
386
+ tools=picked_tools_prompt,
387
+ skip_special_tokens=False,
388
+ max_new_tokens=max_new_tokens,
389
+ max_token=max_token,
390
+ check_token_status=True)
391
+ if last_outputs_str is None:
392
+ logger.warning("Token limit exceeded")
393
+ if self.force_finish:
394
+ return self.get_answer_based_on_unfinished_reasoning(
395
+ conversation, temperature, max_new_tokens, max_token)
396
+ return "❌ Token limit exceeded."
397
+ last_outputs.append(last_outputs_str)
398
 
399
+ if max_round == current_round:
400
+ logger.warning("Max rounds exceeded")
401
+ if self.force_finish:
402
+ return self.get_answer_based_on_unfinished_reasoning(
403
+ conversation, temperature, max_new_tokens, max_token)
404
+ return None
405
 
406
  def build_logits_processor(self, messages, llm):
407
+ # Disabled logits processor due to vLLM V1 limitation
408
+ logger.warning("Logits processor disabled due to vLLM V1 limitation")
409
+ return None
410
+ # Original code (commented out):
411
+ # tokenizer = llm.get_tokenizer()
412
+ # if self.avoid_repeat and len(messages) > 2:
413
+ # assistant_messages = [msg['content'] for msg in messages[-3:] if msg['role'] == 'assistant'][:2]
414
+ # forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
415
+ # return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
416
+ # return None
 
 
 
 
417
 
418
  def llm_infer(self, messages, temperature=0.1, tools=None,
419
+ output_begin_string=None, max_new_tokens=512,
420
+ max_token=2048, skip_special_tokens=True,
421
+ model=None, tokenizer=None, terminators=None,
422
+ seed=None, check_token_status=False):
423
  if model is None:
424
  model = self.model
425
 
 
427
  sampling_params = SamplingParams(
428
  temperature=temperature,
429
  max_tokens=max_new_tokens,
 
430
  seed=seed if seed is not None else self.seed,
431
  )
432
 
 
437
 
438
  if check_token_status and max_token is not None:
439
  token_overflow = False
440
+ num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
441
+ if num_input_tokens > max_token:
442
+ torch.cuda.empty_cache()
443
+ gc.collect()
444
+ logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
445
+ return None, True
446
+
447
+ output = model.generate(prompt, sampling_params=sampling_params)
448
  output = output[0].outputs[0].text
449
+ logger.debug("Inference output: %s", output[:100])
450
+ torch.cuda.empty_cache()
451
+ gc.collect()
452
  if check_token_status and max_token is not None:
453
  return output, token_overflow
 
454
  return output
455
 
456
  def run_self_agent(self, message: str,
457
  temperature: float,
458
  max_new_tokens: int,
459
+ max_token: int):
460
+ logger.info("Starting self agent")
461
+ conversation = self.set_system_prompt([], self.self_prompt)
 
 
462
  conversation.append({"role": "user", "content": message})
463
+ return self.llm_infer(
464
+ messages=conversation,
465
+ temperature=temperature,
466
+ tools=None,
467
+ max_new_tokens=max_new_tokens,
468
+ max_token=max_token)
469
 
470
  def run_chat_agent(self, message: str,
471
  temperature: float,
472
  max_new_tokens: int,
473
+ max_token: int):
474
+ logger.info("Starting chat agent")
475
+ conversation = self.set_system_prompt([], self.chat_prompt)
 
 
476
  conversation.append({"role": "user", "content": message})
477
+ return self.llm_infer(
478
+ messages=conversation,
479
+ temperature=temperature,
480
+ tools=None,
481
+ max_new_tokens=max_new_tokens,
482
+ max_token=max_token)
483
 
484
  def run_format_agent(self, message: str,
485
  answer: str,
486
  temperature: float,
487
  max_new_tokens: int,
488
+ max_token: int):
489
+ logger.info("Starting format agent")
 
490
  if '[FinalAnswer]' in answer:
491
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
492
  elif "\n\n" in answer:
493
  possible_final_answer = answer.split("\n\n")[-1]
494
  else:
495
  possible_final_answer = answer.strip()
496
+ if len(possible_final_answer) == 1 and possible_final_answer in ['A', 'B', 'C', 'D', 'E']:
497
+ return possible_final_answer
498
+ elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
499
+ return possible_final_answer[0]
500
+
501
+ conversation = self.set_system_prompt(
502
+ [], "Transform the agent's answer to a single letter: 'A', 'B', 'C', 'D'.")
503
  conversation.append({"role": "user", "content": message +
504
+ "\nAgent's answer: " + answer + "\nAnswer (must be a letter):"})
505
+ return self.llm_infer(
506
+ messages=conversation,
507
+ temperature=temperature,
508
+ tools=None,
509
+ max_new_tokens=max_new_tokens,
510
+ max_token=max_token)
511
 
512
  def run_summary_agent(self, thought_calls: str,
513
  function_response: str,
514
  temperature: float,
515
  max_new_tokens: int,
516
+ max_token: int):
517
+ logger.info("Summarizing tool result")
518
+ prompt = f"""Thought and function calls:
519
  {thought_calls}
 
520
  Function calls' responses:
521
  \"\"\"
522
  {function_response}
523
  \"\"\"
524
+ Summarize the function calls' responses in one sentence with all necessary information.
525
+ """
526
+ conversation = [{"role": "user", "content": prompt}]
527
+ output = self.llm_infer(
528
+ messages=conversation,
529
+ temperature=temperature,
530
+ tools=None,
531
+ max_new_tokens=max_new_tokens,
532
+ max_token=max_token)
533
  if '[' in output:
534
  output = output.split('[')[0]
535
  return output
536
 
537
  def function_result_summary(self, input_list, status, enable_summary):
538
  if 'tool_call_step' not in status:
539
  status['tool_call_step'] = 0
 
540
  for idx in range(len(input_list)):
541
+ pos_id = len(input_list) - idx - 1
542
+ if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
 
 
 
543
  break
544
 
545
+ status['step'] = status.get('step', 0) + 1
 
 
 
 
546
  if not enable_summary:
547
  return status
548
 
549
+ status['summarized_index'] = status.get('summarized_index', 0)
550
+ status['summarized_step'] = status.get('summarized_step', 0)
551
+ status['previous_length'] = status.get('previous_length', 0)
552
+ status['history'] = status.get('history', [])
553
 
554
  function_response = ''
555
+ idx = status['summarized_index']
556
+ this_thought_calls = None
 
 
 
557
 
 
558
  while idx < len(input_list):
559
+ if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
560
+ (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
561
  if input_list[idx]['role'] == 'assistant':
562
+ if function_response:
563
+ status['summarized_step'] += 1
564
+ result_summary = self.run_summary_agent(
565
+ thought_calls=this_thought_calls,
566
+ function_response=function_response,
567
+ temperature=0.1,
568
+ max_new_tokens=512,
569
+ max_token=2048)
570
+ input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
571
+ status['summarized_index'] = last_call_idx + 2
572
+ idx += 1
573
+ last_call_idx = idx
574
+ this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
575
+ function_response = ''
576
  elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
577
  function_response += input_list[idx]['content']
578
  del input_list[idx]
579
  idx -= 1
 
580
  else:
581
  break
582
  idx += 1
583
 
584
+ if function_response:
585
  status['summarized_step'] += 1
586
  result_summary = self.run_summary_agent(
587
  thought_calls=this_thought_calls,
588
  function_response=function_response,
589
  temperature=0.1,
590
+ max_new_tokens=512,
591
+ max_token=2048)
 
 
592
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
593
  for tool_call in tool_calls:
594
  del tool_call['call_id']
595
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
596
+ input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
 
597
  status['summarized_index'] = last_call_idx + 2
598
 
599
  return status
600
 
 
 
 
601
  def update_parameters(self, **kwargs):
602
+ updated_attributes = {}
603
  for key, value in kwargs.items():
604
  if hasattr(self, key):
605
  setattr(self, key, value)
606
+ updated_attributes[key] = value
607
+ logger.info("Updated parameters: %s", updated_attributes)
 
 
608
  return updated_attributes
609
 
610
  def run_gradio_chat(self, message: str,
 
614
  max_token: int,
615
  call_agent: bool,
616
  conversation: gr.State,
617
+ max_round: int = 5,
618
  seed: int = None,
619
  call_agent_level: int = 0,
620
+ sub_agent_task: str = None,
621
+ uploaded_files: list = None):
622
+ logger.info("Chat started, message: %s", message[:100])
623
+ if not message or len(message.strip()) < 5:
624
+ yield "Please provide a valid message or upload files to analyze."
625
+ return
626
 
627
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
628
+ call_agent, call_agent_level, message)
 
 
 
629
  conversation = self.initialize_conversation(
630
+ message, conversation, history)
 
 
631
  history = []
632
+ last_outputs = []
633
 
634
  next_round = True
 
635
  current_round = 0
636
  enable_summary = False
637
+ last_status = {}
638
  token_overflow = False
 
 
 
639
 
640
  try:
641
  while next_round and current_round < max_round:
642
  current_round += 1
643
+ if last_outputs:
644
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
645
  last_outputs, return_message=True,
646
  existing_tools_prompt=picked_tools_prompt,
 
649
  call_agent_level=call_agent_level,
650
  temperature=temperature)
651
  history.extend(current_gradio_history)
652
+
653
  if special_tool_call == 'Finish':
654
  yield history
655
  next_round = False
656
  conversation.extend(function_call_messages)
657
  return function_call_messages[0]['content']
658
+
659
+ elif special_tool_call in ['RequireClarification', 'DirectResponse']:
660
+ last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
661
+ history.append(ChatMessage(role="assistant", content=last_msg.content))
662
  yield history
663
  next_round = False
664
+ return last_msg.content
665
+
666
  if (self.enable_summary or token_overflow) and not call_agent:
 
 
667
  enable_summary = True
668
  last_status = self.function_result_summary(
669
+ conversation, status=last_status, enable_summary=enable_summary)
670
+
671
+ if function_call_messages:
672
  conversation.extend(function_call_messages)
 
 
673
  yield history
674
  else:
675
  next_round = False
676
+ conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
 
677
  return ''.join(last_outputs).replace("</s>", "")
678
+
679
  last_outputs = []
680
  last_outputs_str, token_overflow = self.llm_infer(
681
  messages=conversation,
 
686
  max_token=max_token,
687
  seed=seed,
688
  check_token_status=True)
689
+
690
+ if last_outputs_str is None:
691
+ logger.warning("Token limit exceeded")
692
+ if self.force_finish:
693
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
694
+ conversation, temperature, max_new_tokens, max_token)
695
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
696
+ yield history
697
+ return last_outputs_str
698
+ error_msg = "Token limit exceeded."
699
+ history.append(ChatMessage(role="assistant", content=error_msg))
700
+ yield history
701
+ return error_msg
702
+
703
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
704
+ for msg in history:
705
+ if msg.metadata is not None:
706
+ msg.metadata['status'] = 'done'
707
+
708
  if '[FinalAnswer]' in last_thought:
709
+ parts = last_thought.split('[FinalAnswer]', 1)
710
+ final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
711
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
 
 
 
712
  yield history
713
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
 
 
 
714
  yield history
715
  else:
716
+ history.append(ChatMessage(role="assistant", content=last_thought))
 
717
  yield history
718
 
719
  last_outputs.append(last_outputs_str)
 
722
  if self.force_finish:
723
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
724
  conversation, temperature, max_new_tokens, max_token)
725
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
726
+ final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
727
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
728
+ yield history
729
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
730
+ yield history
731
  else:
732
+ yield "Reasoning rounds exceeded limit."
733
 
734
  except Exception as e:
735
+ logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
736
+ error_msg = f"Error: {e}"
737
+ history.append(ChatMessage(role="assistant", content=error_msg))
738
+ yield history
739
  if self.force_finish:
740
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
741
+ conversation, temperature, max_new_tokens, max_token)
742
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
743
+ final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
744
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
745
+ yield history
746
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
747
+ yield history
748
+ return error_msg