Ali2206 committed on
Commit 7691fc2 · verified · 1 Parent(s): 9737311

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +278 -356
src/txagent/txagent.py CHANGED
@@ -24,17 +24,17 @@ class TxAgent:
24
  rag_model_name,
25
  tool_files_dict=None,
26
  enable_finish=True,
27
- enable_rag=False,
28
  enable_summary=False,
29
- init_rag_num=0,
30
- step_rag_num=0,
31
  summary_mode='step',
32
  summary_skip_last_k=0,
33
  summary_context_length=None,
34
  force_finish=True,
35
  avoid_repeat=True,
36
  seed=None,
37
- enable_checker=False,
38
  enable_chat=False,
39
  additional_default_tools=None):
40
  self.model_name = model_name
@@ -45,9 +45,9 @@ class TxAgent:
45
  self.model = None
46
  self.rag_model = ToolRAGModel(rag_model_name)
47
  self.tooluniverse = None
48
- self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
49
- self.self_prompt = "Strictly follow the instruction."
50
- self.chat_prompt = "You are a helpful assistant for user chat."
51
  self.enable_finish = enable_finish
52
  self.enable_rag = enable_rag
53
  self.enable_summary = enable_summary
@@ -61,23 +61,28 @@ class TxAgent:
61
  self.seed = seed
62
  self.enable_checker = enable_checker
63
  self.additional_default_tools = additional_default_tools
64
- logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)
65
 
66
  def init_model(self):
67
  self.load_models()
68
  self.load_tooluniverse()
69
 
70
  def load_models(self, model_name=None):
71
- if model_name is not None:
72
- if model_name == self.model_name:
73
- return f"The model {model_name} is already loaded."
74
  self.model_name = model_name
75
 
76
- self.model = LLM(model=self.model_name, dtype="float16", max_model_len=2048, gpu_memory_utilization=0.8)
77
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
78
  self.tokenizer = self.model.get_tokenizer()
79
  logger.info("Model %s loaded successfully", self.model_name)
80
- return f"Model {model_name} loaded successfully."
81
 
82
  def load_tooluniverse(self):
83
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
@@ -88,12 +93,7 @@ class TxAgent:
88
  logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
89
 
90
  def load_tool_desc_embedding(self):
91
- cache_path = os.path.join(os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl")
92
- if os.path.exists(cache_path):
93
- self.rag_model.load_cached_embeddings(cache_path)
94
- else:
95
- self.rag_model.load_tool_desc_embedding(self.tooluniverse)
96
- self.rag_model.save_embeddings(cache_path)
97
  logger.debug("Tool description embeddings loaded")
98
 
99
  def rag_infer(self, query, top_k=5):
@@ -107,43 +107,39 @@ class TxAgent:
107
  call_agent_level += 1
108
  if call_agent_level >= 2:
109
  call_agent = False
110
  return picked_tools_prompt, call_agent_level
111
 
112
  def initialize_conversation(self, message, conversation=None, history=None):
113
  if conversation is None:
114
  conversation = []
115
 
116
- conversation = self.set_system_prompt(
117
- conversation, self.prompt_multi_step)
118
  if history:
119
- for i in range(len(history)):
120
- if history[i]['role'] == 'user':
121
- conversation.append({"role": "user", "content": history[i]['content']})
122
- elif history[i]['role'] == 'assistant':
123
- conversation.append({"role": "assistant", "content": history[i]['content']})
124
  conversation.append({"role": "user", "content": message})
125
  logger.debug("Conversation initialized with %d messages", len(conversation))
126
  return conversation
127
 
128
- def tool_RAG(self, message=None,
129
- picked_tool_names=None,
130
- existing_tools_prompt=[],
131
- rag_num=0,
132
- return_call_result=False):
133
- if not self.enable_rag:
134
- return []
135
- extra_factor = 10
136
  if picked_tool_names is None:
137
- assert picked_tool_names is not None or message is not None
138
- picked_tool_names = self.rag_infer(
139
- message, top_k=rag_num * extra_factor)
140
-
141
- picked_tool_names_no_special = [tool for tool in picked_tool_names if tool not in self.special_tools_name]
142
- picked_tool_names = picked_tool_names_no_special[:rag_num]
143
 
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
145
  picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
146
- logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt))
147
  if return_call_result:
148
  return picked_tools_prompt, picked_tool_names
149
  return picked_tools_prompt
@@ -155,6 +151,15 @@ class TxAgent:
155
  if call_agent:
156
  tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
157
  logger.debug("CallAgent tool added")
158
  return tools
159
 
160
  def add_finish_tools(self, tools):
@@ -169,179 +174,135 @@ class TxAgent:
169
  conversation[0] = {"role": "system", "content": sys_prompt}
170
  return conversation
171
 
172
- def run_function_call(self, fcall_str,
173
- return_message=False,
174
- existing_tools_prompt=None,
175
- message_for_call_agent=None,
176
- call_agent=False,
177
- call_agent_level=None,
178
- temperature=None):
179
- try:
180
- function_call_json, message = self.tooluniverse.extract_function_call_json(
181
- fcall_str, return_message=return_message, verbose=False)
182
- except Exception as e:
183
- logger.error("Tool call parsing failed: %s", e)
184
- function_call_json = []
185
- message = fcall_str
186
-
187
  call_results = []
188
  special_tool_call = ''
189
  if function_call_json:
190
- if isinstance(function_call_json, list):
191
- for i in range(len(function_call_json)):
192
- logger.info("Tool Call: %s", function_call_json[i])
193
- if function_call_json[i]["name"] == 'Finish':
194
- special_tool_call = 'Finish'
195
- break
196
- elif function_call_json[i]["name"] == 'CallAgent':
197
- if call_agent_level < 2 and call_agent:
198
- solution_plan = function_call_json[i]['arguments']['solution']
199
- full_message = (
200
- message_for_call_agent +
201
- "\nYou must follow the following plan to answer the question: " +
202
- str(solution_plan)
203
- )
204
- call_result = self.run_multistep_agent(
205
- full_message, temperature=temperature,
206
- max_new_tokens=512, max_token=2048,
207
- call_agent=False, call_agent_level=call_agent_level)
208
- if call_result is None:
209
- call_result = "⚠️ No content returned from sub-agent."
210
- else:
211
- call_result = call_result.split('[FinalAnswer]')[-1].strip()
212
- else:
213
- call_result = "Error: CallAgent disabled."
214
- else:
215
- call_result = self.tooluniverse.run_one_function(function_call_json[i])
216
- call_id = self.tooluniverse.call_id_gen()
217
- function_call_json[i]["call_id"] = call_id
218
- logger.info("Tool Call Result: %s", call_result)
219
- call_results.append({
220
- "role": "tool",
221
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
222
- })
223
  else:
224
  call_results.append({
225
  "role": "tool",
226
- "content": json.dumps({"content": "Invalid or no function call detected."})
227
  })
228
 
229
  revised_messages = [{
230
  "role": "assistant",
231
- "content": message.strip(),
232
  "tool_calls": json.dumps(function_call_json)
233
  }] + call_results
234
  return revised_messages, existing_tools_prompt, special_tool_call
235
 
236
- def run_function_call_stream(self, fcall_str,
237
- return_message=False,
238
- existing_tools_prompt=None,
239
- message_for_call_agent=None,
240
- call_agent=False,
241
- call_agent_level=None,
242
- temperature=None,
243
- return_gradio_history=True):
244
- try:
245
- function_call_json, message = self.tooluniverse.extract_function_call_json(
246
- fcall_str, return_message=return_message, verbose=False)
247
- except Exception as e:
248
- logger.error("Tool call parsing failed: %s", e)
249
- function_call_json = []
250
- message = fcall_str
251
-
252
  call_results = []
253
  special_tool_call = ''
254
- if return_gradio_history:
255
- gradio_history = []
256
  if function_call_json:
257
- if isinstance(function_call_json, list):
258
- for i in range(len(function_call_json)):
259
- if function_call_json[i]["name"] == 'Finish':
260
- special_tool_call = 'Finish'
261
- break
262
- elif function_call_json[i]["name"] == 'DirectResponse':
263
- call_result = function_call_json[i]['arguments']['respose']
264
- special_tool_call = 'DirectResponse'
265
- elif function_call_json[i]["name"] == 'RequireClarification':
266
- call_result = function_call_json[i]['arguments']['unclear_question']
267
- special_tool_call = 'RequireClarification'
268
- elif function_call_json[i]["name"] == 'CallAgent':
269
- if call_agent_level < 2 and call_agent:
270
- solution_plan = function_call_json[i]['arguments']['solution']
271
- full_message = (
272
- message_for_call_agent +
273
- "\nYou must follow the following plan to answer the question: " +
274
- str(solution_plan)
275
- )
276
- sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
277
- call_result = yield from self.run_gradio_chat(
278
- full_message, history=[], temperature=temperature,
279
- max_new_tokens=512, max_token=2048,
280
- call_agent=False, call_agent_level=call_agent_level,
281
- conversation=None, sub_agent_task=sub_agent_task)
282
- if call_result is not None and isinstance(call_result, str):
283
- call_result = call_result.split('[FinalAnswer]')[-1]
284
- else:
285
- call_result = "⚠️ No content returned from sub-agent."
286
- else:
287
- call_result = "Error: CallAgent disabled."
288
- else:
289
- call_result = self.tooluniverse.run_one_function(function_call_json[i])
290
- call_id = self.tooluniverse.call_id_gen()
291
- function_call_json[i]["call_id"] = call_id
292
- call_results.append({
293
- "role": "tool",
294
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
295
- })
296
- if return_gradio_history and function_call_json[i]["name"] != 'Finish':
297
- metadata = {"title": f"🧰 {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
298
- gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
299
  else:
300
  call_results.append({
301
  "role": "tool",
302
- "content": json.dumps({"content": "Invalid or no function call detected."})
303
  })
304
 
305
  revised_messages = [{
306
  "role": "assistant",
307
- "content": message.strip(),
308
  "tool_calls": json.dumps(function_call_json)
309
  }] + call_results
310
- if return_gradio_history:
311
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
312
- return revised_messages, existing_tools_prompt, special_tool_call
313
 
314
- def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
315
- # Truncate conversation to fit within max_token
316
- tokenized = self.tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
317
- if len(tokenized) > max_token - 100:
318
- logger.warning("Truncating conversation to fit max_token=%d", max_token)
319
- while len(tokenized) > max_token - 100 and len(conversation) > 1:
320
- conversation.pop(1) # Keep system prompt and latest message
321
- tokenized = self.tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
322
  if conversation[-1]['role'] == 'assistant':
323
  conversation.append(
324
- {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
325
  finish_tools_prompt = self.add_finish_tools([])
326
- last_outputs_str = self.llm_infer(
327
- messages=conversation,
328
- temperature=temperature,
329
- tools=finish_tools_prompt,
330
- output_begin_string='[FinalAnswer]',
331
- skip_special_tokens=True,
332
- max_new_tokens=max_new_tokens,
333
- max_token=max_token)
334
- logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100])
335
- return last_outputs_str
336
-
337
- def run_multistep_agent(self, message: str,
338
- temperature: float,
339
- max_new_tokens: int,
340
- max_token: int,
341
- max_round: int = 5,
342
- call_agent=False,
343
- call_agent_level=0):
344
- logger.info("Starting multistep agent for message: %s", message[:100])
345
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
346
  call_agent, call_agent_level, message)
347
  conversation = self.initialize_conversation(message)
@@ -353,24 +314,22 @@ class TxAgent:
353
  enable_summary = False
354
  last_status = {}
355
356
  while next_round and current_round < max_round:
357
  current_round += 1
358
- if len(outputs) > 0:
359
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
360
- last_outputs, return_message=True,
361
- existing_tools_prompt=picked_tools_prompt,
362
- message_for_call_agent=message,
363
- call_agent=call_agent,
364
- call_agent_level=call_agent_level,
365
- temperature=temperature)
366
 
367
  if special_tool_call == 'Finish':
368
  next_round = False
369
  conversation.extend(function_call_messages)
370
  content = function_call_messages[0]['content']
371
- if content is None:
372
- return "❌ No content returned after Finish tool call."
373
- return content.split('[FinalAnswer]')[-1]
374
 
375
  if (self.enable_summary or token_overflow) and not call_agent:
376
  enable_summary = True
@@ -382,28 +341,26 @@ class TxAgent:
382
  outputs.append(tool_result_format(function_call_messages))
383
  else:
384
  next_round = False
385
- conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
386
  return ''.join(last_outputs).replace("</s>", "")
387
 
388
  last_outputs = []
389
- outputs.append("### TxAgent:\n")
390
  last_outputs_str, token_overflow = self.llm_infer(
391
- messages=conversation,
392
- temperature=temperature,
393
- tools=picked_tools_prompt,
394
- skip_special_tokens=False,
395
- max_new_tokens=max_new_tokens,
396
- max_token=max_token,
397
- check_token_status=True)
398
  if last_outputs_str is None:
399
- logger.warning("Token limit exceeded")
400
  if self.force_finish:
401
  return self.get_answer_based_on_unfinished_reasoning(
402
  conversation, temperature, max_new_tokens, max_token)
403
  return "❌ Token limit exceeded."
404
  last_outputs.append(last_outputs_str)
405
 
406
- if max_round == current_round:
407
  logger.warning("Max rounds exceeded")
408
  if self.force_finish:
409
  return self.get_answer_based_on_unfinished_reasoning(
@@ -413,16 +370,16 @@ class TxAgent:
413
  def build_logits_processor(self, messages, llm):
414
  tokenizer = llm.get_tokenizer()
415
  if self.avoid_repeat and len(messages) > 2:
416
- assistant_messages = [msg['content'] for msg in messages[-3:] if msg['role'] == 'assistant'][:2]
 
 
417
  forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
418
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
419
  return None
420
 
421
- def llm_infer(self, messages, temperature=0.1, tools=None,
422
- output_begin_string=None, max_new_tokens=512,
423
- max_token=2048, skip_special_tokens=True,
424
- model=None, tokenizer=None, terminators=None,
425
- seed=None, check_token_status=False):
426
  if model is None:
427
  model = self.model
428
 
@@ -431,108 +388,73 @@ class TxAgent:
431
  temperature=temperature,
432
  max_tokens=max_new_tokens,
433
  seed=seed if seed is not None else self.seed,
 
434
  )
435
 
436
- prompt = self.chat_template.render(
437
- messages=messages, tools=tools, add_generation_prompt=True)
438
- if output_begin_string is not None:
439
  prompt += output_begin_string
440
 
441
- if check_token_status and max_token is not None:
442
- token_overflow = False
443
  num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
444
  if num_input_tokens > max_token:
445
  torch.cuda.empty_cache()
446
  gc.collect()
447
  logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
448
  return None, True
 
449
 
450
  output = model.generate(prompt, sampling_params=sampling_params)
451
  output = output[0].outputs[0].text
452
  logger.debug("Inference output: %s", output[:100])
453
- torch.cuda.empty_cache()
454
- gc.collect()
455
- if check_token_status and max_token is not None:
456
- return output, token_overflow
457
  return output
458
 
459
- def run_self_agent(self, message: str,
460
- temperature: float,
461
- max_new_tokens: int,
462
- max_token: int):
463
- logger.info("Starting self agent")
464
  conversation = self.set_system_prompt([], self.self_prompt)
465
  conversation.append({"role": "user", "content": message})
466
- return self.llm_infer(
467
- messages=conversation,
468
- temperature=temperature,
469
- tools=None,
470
- max_new_tokens=max_new_tokens,
471
- max_token=max_token)
472
-
473
- def run_chat_agent(self, message: str,
474
- temperature: float,
475
- max_new_tokens: int,
476
- max_token: int):
477
- logger.info("Starting chat agent")
478
  conversation = self.set_system_prompt([], self.chat_prompt)
479
  conversation.append({"role": "user", "content": message})
480
- return self.llm_infer(
481
- messages=conversation,
482
- temperature=temperature,
483
- tools=None,
484
- max_new_tokens=max_new_tokens,
485
- max_token=max_token)
486
-
487
- def run_format_agent(self, message: str,
488
- answer: str,
489
- temperature: float,
490
- max_new_tokens: int,
491
- max_token: int):
492
- logger.info("Starting format agent")
493
  if '[FinalAnswer]' in answer:
494
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
495
  elif "\n\n" in answer:
496
  possible_final_answer = answer.split("\n\n")[-1]
497
  else:
498
  possible_final_answer = answer.strip()
499
- if len(possible_final_answer) == 1 and possible_final_answer in ['A', 'B', 'C', 'D', 'E']:
500
- return possible_final_answer
 
501
  elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
502
  return possible_final_answer[0]
503
 
504
  conversation = self.set_system_prompt(
505
- [], "Transform the agent's answer to a single letter: 'A', 'B', 'C', 'D'.")
506
- conversation.append({"role": "user", "content": message +
507
- "\nAgent's answer: " + answer + "\nAnswer (must be a letter):"})
508
- return self.llm_infer(
509
- messages=conversation,
510
- temperature=temperature,
511
- tools=None,
512
- max_new_tokens=max_new_tokens,
513
- max_token=max_token)
514
-
515
- def run_summary_agent(self, thought_calls: str,
516
- function_response: str,
517
- temperature: float,
518
- max_new_tokens: int,
519
- max_token: int):
520
- logger.info("Summarizing tool result")
521
- prompt = f"""Thought and function calls:
522
- {thought_calls}
523
- Function calls' responses:
524
- \"\"\"
525
- {function_response}
526
- \"\"\"
527
- Summarize the function calls' responses in one sentence with all necessary information.
528
- """
529
  conversation = [{"role": "user", "content": prompt}]
530
- output = self.llm_infer(
531
- messages=conversation,
532
- temperature=temperature,
533
- tools=None,
534
- max_new_tokens=max_new_tokens,
535
- max_token=max_token)
536
  if '[' in output:
537
  output = output.split('[')[0]
538
  return output
@@ -540,43 +462,55 @@ Summarize the function calls' responses in one sentence with all necessary infor
540
  def function_result_summary(self, input_list, status, enable_summary):
541
  if 'tool_call_step' not in status:
542
  status['tool_call_step'] = 0
543
  for idx in range(len(input_list)):
544
  pos_id = len(input_list) - idx - 1
545
  if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
 
 
546
  break
547
 
548
- status['step'] = status.get('step', 0) + 1
549
  if not enable_summary:
550
  return status
551
 
552
- status['summarized_index'] = status.get('summarized_index', 0)
553
- status['summarized_step'] = status.get('summarized_step', 0)
554
- status['previous_length'] = status.get('previous_length', 0)
555
- status['history'] = status.get('history', [])
556
 
557
- function_response = ''
558
  idx = status['summarized_index']
 
559
  this_thought_calls = None
560
-
561
  while idx < len(input_list):
562
  if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
563
  (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
564
  if input_list[idx]['role'] == 'assistant':
565
- if function_response:
566
- status['summarized_step'] += 1
567
- result_summary = self.run_summary_agent(
568
- thought_calls=this_thought_calls,
569
- function_response=function_response,
570
- temperature=0.1,
571
- max_new_tokens=512,
572
- max_token=2048)
573
- input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
574
- status['summarized_index'] = last_call_idx + 2
575
- idx += 1
576
- last_call_idx = idx
577
- this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
578
- function_response = ''
579
- elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
 
580
  function_response += input_list[idx]['content']
581
  del input_list[idx]
582
  idx -= 1
@@ -587,16 +521,14 @@ Summarize the function calls' responses in one sentence with all necessary infor
587
  if function_response:
588
  status['summarized_step'] += 1
589
  result_summary = self.run_summary_agent(
590
- thought_calls=this_thought_calls,
591
- function_response=function_response,
592
- temperature=0.1,
593
- max_new_tokens=512,
594
- max_token=2048)
595
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
596
  for tool_call in tool_calls:
597
  del tool_call['call_id']
598
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
599
- input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
 
600
  status['summarized_index'] = last_call_idx + 2
601
 
602
  return status
@@ -607,32 +539,27 @@ Summarize the function calls' responses in one sentence with all necessary infor
607
  if hasattr(self, key):
608
  setattr(self, key, value)
609
  updated_attributes[key] = value
610
- logger.info("Updated parameters: %s", updated_attributes)
611
  return updated_attributes
612
 
613
- def run_gradio_chat(self, message: str,
614
- history: list,
615
- temperature: float,
616
- max_new_tokens: int,
617
- max_token: int,
618
- call_agent: bool,
619
- conversation: gr.State,
620
- max_round: int = 5,
621
- seed: int = None,
622
- call_agent_level: int = 0,
623
- sub_agent_task: str = None,
624
  uploaded_files: list = None):
625
- logger.info("Chat started, message: %s", message[:100])
626
  if not message or len(message.strip()) < 5:
627
  yield "Please provide a valid message or upload files to analyze."
628
  return
629
 
 
 
 
630
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
631
  call_agent, call_agent_level, message)
632
  conversation = self.initialize_conversation(
633
  message, conversation, history)
634
  history = []
635
- last_outputs = [] # Initialize last_outputs to avoid UnboundLocalError
636
 
637
  next_round = True
638
  current_round = 0
@@ -640,17 +567,18 @@ Summarize the function calls' responses in one sentence with all necessary infor
640
  last_status = {}
641
  token_overflow = False
642
 
 
 
 
643
  try:
644
  while next_round and current_round < max_round:
645
  current_round += 1
646
- if last_outputs: # Process previous outputs if any
 
647
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
648
- last_outputs, return_message=True,
649
- existing_tools_prompt=picked_tools_prompt,
650
- message_for_call_agent=message,
651
- call_agent=call_agent,
652
- call_agent_level=call_agent_level,
653
- temperature=temperature)
654
  history.extend(current_gradio_history)
655
 
656
  if special_tool_call == 'Finish':
@@ -659,7 +587,7 @@ Summarize the function calls' responses in one sentence with all necessary infor
659
  conversation.extend(function_call_messages)
660
  return function_call_messages[0]['content']
661
 
662
- elif special_tool_call in ['RequireClarification', 'DirectResponse']:
663
  last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
664
  history.append(ChatMessage(role="assistant", content=last_msg.content))
665
  yield history
@@ -676,22 +604,19 @@ Summarize the function calls' responses in one sentence with all necessary infor
676
  yield history
677
  else:
678
  next_round = False
679
- conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
680
  return ''.join(last_outputs).replace("</s>", "")
681
 
682
- last_outputs = []
 
 
 
 
 
683
  last_outputs_str, token_overflow = self.llm_infer(
684
- messages=conversation,
685
- temperature=temperature,
686
- tools=picked_tools_prompt,
687
- skip_special_tokens=False,
688
- max_new_tokens=max_new_tokens,
689
- max_token=max_token,
690
- seed=seed,
691
- check_token_status=True)
692
 
693
  if last_outputs_str is None:
694
- logger.warning("Token limit exceeded")
695
  if self.force_finish:
696
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
697
  conversation, temperature, max_new_tokens, max_token)
@@ -705,7 +630,7 @@ Summarize the function calls' responses in one sentence with all necessary infor
705
 
706
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
707
  for msg in history:
708
- if msg.metadata is not None:
709
  msg.metadata['status'] = 'done'
710
 
711
  if '[FinalAnswer]' in last_thought:
@@ -721,18 +646,15 @@ Summarize the function calls' responses in one sentence with all necessary infor
721
 
722
  last_outputs.append(last_outputs_str)
723
 
724
- if next_round:
725
- if self.force_finish:
726
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
727
- conversation, temperature, max_new_tokens, max_token)
728
- parts = last_outputs_str.split('[FinalAnswer]', 1)
729
- final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
730
- history.append(ChatMessage(role="assistant", content=final_thought.strip()))
731
- yield history
732
- history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
733
- yield history
734
- else:
735
- yield "Reasoning rounds exceeded limit."
736
 
737
  except Exception as e:
738
  logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
 
24
  rag_model_name,
25
  tool_files_dict=None,
26
  enable_finish=True,
27
+ enable_rag=True,
28
  enable_summary=False,
29
+ init_rag_num=2, # Reduced for faster initial tool selection
30
+ step_rag_num=4, # Reduced for fewer RAG calls
31
  summary_mode='step',
32
  summary_skip_last_k=0,
33
  summary_context_length=None,
34
  force_finish=True,
35
  avoid_repeat=True,
36
  seed=None,
37
+ enable_checker=False, # Disabled by default for speed
38
  enable_chat=False,
39
  additional_default_tools=None):
40
  self.model_name = model_name
 
45
  self.model = None
46
  self.rag_model = ToolRAGModel(rag_model_name)
47
  self.tooluniverse = None
48
+ self.prompt_multi_step = "You are a medical assistant solving clinical oversight issues step-by-step using provided tools."
49
+ self.self_prompt = "Follow instructions precisely."
50
+ self.chat_prompt = "You are a helpful assistant for clinical queries."
51
  self.enable_finish = enable_finish
52
  self.enable_rag = enable_rag
53
  self.enable_summary = enable_summary
 
61
  self.seed = seed
62
  self.enable_checker = enable_checker
63
  self.additional_default_tools = additional_default_tools
64
+ logger.debug("TxAgent initialized with parameters: %s", self.__dict__)
65
 
66
  def init_model(self):
67
  self.load_models()
68
  self.load_tooluniverse()
69
+ self.load_tool_desc_embedding()
70
+
71
+ def print_self_values(self):
72
+ for attr, value in self.__dict__.items():
73
+ logger.debug("%s: %s", attr, value)
74
 
75
  def load_models(self, model_name=None):
76
+ if model_name is not None and model_name == self.model_name:
77
+ return f"The model {model_name} is already loaded."
78
+ if model_name:
79
  self.model_name = model_name
80
 
81
+ self.model = LLM(model=self.model_name, dtype="float16") # Enable FP16
82
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
83
  self.tokenizer = self.model.get_tokenizer()
84
  logger.info("Model %s loaded successfully", self.model_name)
85
+ return f"Model {self.model_name} loaded successfully."
86
 
87
  def load_tooluniverse(self):
88
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
 
93
  logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
94
 
95
  def load_tool_desc_embedding(self):
96
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
97
  logger.debug("Tool description embeddings loaded")
98
 
99
  def rag_infer(self, query, top_k=5):
 
107
  call_agent_level += 1
108
  if call_agent_level >= 2:
109
  call_agent = False
110
+
111
+ if not call_agent and self.enable_rag:
112
+ picked_tools_prompt += self.tool_RAG(
113
+ message=message, rag_num=self.init_rag_num)
114
  return picked_tools_prompt, call_agent_level
115
 
116
  def initialize_conversation(self, message, conversation=None, history=None):
117
  if conversation is None:
118
  conversation = []
119
 
120
+ conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
 
121
  if history:
122
+ conversation.extend(
123
+ {"role": h['role'], "content": h['content']}
124
+ for h in history if h['role'] in ['user', 'assistant']
125
+ )
 
126
  conversation.append({"role": "user", "content": message})
127
  logger.debug("Conversation initialized with %d messages", len(conversation))
128
  return conversation
129
 
130
+ def tool_RAG(self, message=None, picked_tool_names=None,
131
+ existing_tools_prompt=None, rag_num=4, return_call_result=False):
132
+ extra_factor = 10 # Reduced from 30 for efficiency
133
  if picked_tool_names is None:
134
+ picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
135
 
136
+ picked_tool_names = [
137
+ tool for tool in picked_tool_names
138
+ if tool not in self.special_tools_name
139
+ ][:rag_num]
140
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
141
  picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
142
+ logger.debug("RAG selected %d tools: %s", len(picked_tool_names), picked_tool_names)
143
  if return_call_result:
144
  return picked_tools_prompt, picked_tool_names
145
  return picked_tools_prompt
 
151
  if call_agent:
152
  tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
153
  logger.debug("CallAgent tool added")
154
+ elif self.enable_rag:
155
+ tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
156
+ logger.debug("Tool_RAG tool added")
157
+ if self.additional_default_tools:
158
+ for tool_name in self.additional_default_tools:
159
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(tool_name, return_prompt=True)
160
+ if tool_prompt:
161
+ tools.append(tool_prompt)
162
+ logger.debug("%s tool added", tool_name)
163
  return tools
164
 
165
  def add_finish_tools(self, tools):
 
174
  conversation[0] = {"role": "system", "content": sys_prompt}
175
  return conversation
176
 
177
+ def run_function_call(self, fcall_str, return_message=False,
178
+ existing_tools_prompt=None, message_for_call_agent=None,
179
+ call_agent=False, call_agent_level=None, temperature=None):
180
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
181
+ fcall_str, return_message=return_message, verbose=False)
182
  call_results = []
183
  special_tool_call = ''
184
  if function_call_json:
185
+ for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
186
+ logger.debug("Tool Call: %s", func)
187
+ if func["name"] == 'Finish':
188
+ special_tool_call = 'Finish'
189
+ break
190
+ elif func["name"] == 'Tool_RAG':
191
+ new_tools_prompt, call_result = self.tool_RAG(
192
+ message=message, existing_tools_prompt=existing_tools_prompt,
193
+ rag_num=self.step_rag_num, return_call_result=True)
194
+ existing_tools_prompt += new_tools_prompt
195
+ elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
196
+ solution_plan = func['arguments']['solution']
197
+ full_message = (
198
+ message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
199
+ )
200
+ call_result = self.run_multistep_agent(
201
+ full_message, temperature=temperature, max_new_tokens=512,
202
+ max_token=2048, call_agent=False, call_agent_level=call_agent_level)
203
+ call_result = call_result.split('[FinalAnswer]')[-1].strip() if call_result else "⚠️ No content from sub-agent."
204
+ else:
205
+ call_result = self.tooluniverse.run_one_function(func)
206
+
207
+ call_id = self.tooluniverse.call_id_gen()
208
+ func["call_id"] = call_id
209
+ logger.debug("Tool Call Result: %s", call_result)
210
+ call_results.append({
211
+ "role": "tool",
212
+ "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
213
+ })
 
 
 
 
214
  else:
215
  call_results.append({
216
  "role": "tool",
217
+ "content": json.dumps({"content": "Invalid function call format."})
218
  })
219
 
220
  revised_messages = [{
221
  "role": "assistant",
222
+ "content": message.strip() if message else "",
223
  "tool_calls": json.dumps(function_call_json)
224
  }] + call_results
225
  return revised_messages, existing_tools_prompt, special_tool_call
226
 
227
+ def run_function_call_stream(self, fcall_str, return_message=False,
228
+ existing_tools_prompt=None, message_for_call_agent=None,
229
+ call_agent=False, call_agent_level=None, temperature=None,
230
+ return_gradio_history=True):
231
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
232
+ fcall_str, return_message=return_message, verbose=False)
233
  call_results = []
234
  special_tool_call = ''
235
+ gradio_history = [] if return_gradio_history else None
 
236
  if function_call_json:
237
+ for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
238
+ if func["name"] == 'Finish':
239
+ special_tool_call = 'Finish'
240
+ break
241
+ elif func["name"] == 'Tool_RAG':
242
+ new_tools_prompt, call_result = self.tool_RAG(
243
+ message=message, existing_tools_prompt=existing_tools_prompt,
244
+ rag_num=self.step_rag_num, return_call_result=True)
245
+ existing_tools_prompt += new_tools_prompt
246
+ elif func["name"] == 'DirectResponse':
247
+ call_result = func['arguments']['response']
248
+ special_tool_call = 'DirectResponse'
249
+ elif func["name"] == 'RequireClarification':
250
+ call_result = func['arguments']['unclear_question']
251
+ special_tool_call = 'RequireClarification'
252
+ elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
253
+ solution_plan = func['arguments']['solution']
254
+ full_message = (
255
+ message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
256
+ )
257
+ sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
258
+ call_result = yield from self.run_gradio_chat(
259
+ full_message, history=[], temperature=temperature,
260
+ max_new_tokens=512, max_token=2048, call_agent=False,
261
+ call_agent_level=call_agent_level, conversation=None,
262
+ sub_agent_task=sub_agent_task)
263
+ call_result = call_result.split('[FinalAnswer]')[-1] if call_result else "⚠️ No content from sub-agent."
264
+ else:
265
+ call_result = self.tooluniverse.run_one_function(func)
266
+
267
+ call_id = self.tooluniverse.call_id_gen()
268
+ func["call_id"] = call_id
269
+ call_results.append({
270
+ "role": "tool",
271
+ "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
272
+ })
273
+ if return_gradio_history and func["name"] != 'Finish':
274
+ title = f"{'🧰' if func['name'] == 'Tool_RAG' else '⚒️'} {func['name']}"
275
+ gradio_history.append(ChatMessage(
276
+ role="assistant", content=str(call_result),
277
+ metadata={"title": title, "log": str(func['arguments'])}
278
+ ))
279
  else:
280
  call_results.append({
281
  "role": "tool",
282
+ "content": json.dumps({"content": "Invalid function call format."})
283
  })
284
 
285
  revised_messages = [{
286
  "role": "assistant",
287
+ "content": message.strip() if message else "",
288
  "tool_calls": json.dumps(function_call_json)
289
  }] + call_results
290
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
 
 
291
 
292
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token):
293
  if conversation[-1]['role'] == 'assistant':
294
  conversation.append(
295
+ {'role': 'tool', 'content': 'Errors occurred; provide final answer with current info.'})
296
  finish_tools_prompt = self.add_finish_tools([])
297
+ output = self.llm_infer(
298
+ messages=conversation, temperature=temperature, tools=finish_tools_prompt,
299
+ output_begin_string='[FinalAnswer]', max_new_tokens=max_new_tokens, max_token=max_token)
300
+ logger.debug("Unfinished reasoning output: %s", output)
301
+ return output
302
+
303
+ def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
304
+ max_token: int, max_round: int = 10, call_agent=False, call_agent_level=0):
305
+ logger.debug("Starting multistep agent for message: %s", message[:100])
306
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
307
  call_agent, call_agent_level, message)
308
  conversation = self.initialize_conversation(message)
 
314
  enable_summary = False
315
  last_status = {}
316
 
317
+ if self.enable_checker:
318
+ checker = ReasoningTraceChecker(message, conversation)
319
+
320
  while next_round and current_round < max_round:
321
  current_round += 1
322
+ if last_outputs:
323
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
324
+ last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
325
+ message_for_call_agent=message, call_agent=call_agent,
326
+ call_agent_level=call_agent_level, temperature=temperature)
 
 
 
327
 
328
  if special_tool_call == 'Finish':
329
  next_round = False
330
  conversation.extend(function_call_messages)
331
  content = function_call_messages[0]['content']
332
+ return content.split('[FinalAnswer]')[-1] if content else "❌ No content after Finish."
 
 
333
 
334
  if (self.enable_summary or token_overflow) and not call_agent:
335
  enable_summary = True
 
341
  outputs.append(tool_result_format(function_call_messages))
342
  else:
343
  next_round = False
 
344
  return ''.join(last_outputs).replace("</s>", "")
345
 
346
+ if self.enable_checker:
347
+ good_status, wrong_info = checker.check_conversation()
348
+ if not good_status:
349
+ logger.warning("Checker error: %s", wrong_info)
350
+ break
351
+
352
  last_outputs = []
 
353
  last_outputs_str, token_overflow = self.llm_infer(
354
+ messages=conversation, temperature=temperature, tools=picked_tools_prompt,
355
+ max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
356
  if last_outputs_str is None:
 
357
  if self.force_finish:
358
  return self.get_answer_based_on_unfinished_reasoning(
359
  conversation, temperature, max_new_tokens, max_token)
360
  return "❌ Token limit exceeded."
361
  last_outputs.append(last_outputs_str)
362
 
363
+ if current_round >= max_round:
364
  logger.warning("Max rounds exceeded")
365
  if self.force_finish:
366
  return self.get_answer_based_on_unfinished_reasoning(
 
370
  def build_logits_processor(self, messages, llm):
371
  tokenizer = llm.get_tokenizer()
372
  if self.avoid_repeat and len(messages) > 2:
373
+ assistant_messages = [
374
+ m['content'] for m in messages[-3:] if m['role'] == 'assistant'
375
+ ][:2]
376
  forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
377
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
378
  return None
379
 
380
+ def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
381
+ max_new_tokens=512, max_token=2048, skip_special_tokens=True,
382
+ model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
 
 
383
  if model is None:
384
  model = self.model
385
 
 
388
  temperature=temperature,
389
  max_tokens=max_new_tokens,
390
  seed=seed if seed is not None else self.seed,
391
+ logits_processors=logits_processor
392
  )
393
 
394
+ prompt = self.chat_template.render(messages=messages, tools=tools, add_generation_prompt=True)
395
+ if output_begin_string:
 
396
  prompt += output_begin_string
397
 
398
+ if check_token_status and max_token:
 
399
  num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
400
  if num_input_tokens > max_token:
401
  torch.cuda.empty_cache()
402
  gc.collect()
403
  logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
404
  return None, True
405
+ logger.debug("Input tokens: %d", num_input_tokens)
406
 
407
  output = model.generate(prompt, sampling_params=sampling_params)
408
  output = output[0].outputs[0].text
409
  logger.debug("Inference output: %s", output[:100])
410
+ torch.cuda.empty_cache() # Clear CUDA cache
411
+ if check_token_status:
412
+ return output, False
 
413
  return output
414
 
415
+ def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
416
+ logger.debug("Starting self agent")
 
 
 
417
  conversation = self.set_system_prompt([], self.self_prompt)
418
  conversation.append({"role": "user", "content": message})
419
+ return self.llm_infer(messages=conversation, temperature=temperature,
420
+ max_new_tokens=max_new_tokens, max_token=max_token)
421
+
422
+ def run_chat_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
423
+ logger.debug("Starting chat agent")
424
  conversation = self.set_system_prompt([], self.chat_prompt)
425
  conversation.append({"role": "user", "content": message})
426
+ return self.llm_infer(messages=conversation, temperature=temperature,
427
+ max_new_tokens=max_new_tokens, max_token=max_token)
428
+
429
+ def run_format_agent(self, message: str, answer: str, temperature: float, max_new_tokens: int, max_token: int):
430
+ logger.debug("Starting format agent")
431
  if '[FinalAnswer]' in answer:
432
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
433
  elif "\n\n" in answer:
434
  possible_final_answer = answer.split("\n\n")[-1]
435
  else:
436
  possible_final_answer = answer.strip()
437
+
438
+ if len(possible_final_answer) >= 1 and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
439
+ return possible_final_answer[0]
440
  elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
441
  return possible_final_answer[0]
442
 
443
  conversation = self.set_system_prompt(
444
+ [], "Transform the answer to a single letter: 'A', 'B', 'C', 'D', or 'E'.")
445
+ conversation.append({"role": "user", "content": f"Original: {message}\nAnswer: {answer}\nFinal answer (letter):"})
446
+ return self.llm_infer(messages=conversation, temperature=temperature,
447
+ max_new_tokens=max_new_tokens, max_token=max_token)
448
+
449
+ def run_summary_agent(self, thought_calls: str, function_response: str,
450
+ temperature: float, max_new_tokens: int, max_token: int):
451
+ logger.debug("Starting summary agent")
452
+ prompt = f"""Thought and function calls: {thought_calls}
453
+ Function responses: {function_response}
454
+ Summarize the function responses in one sentence with all necessary information."""
455
  conversation = [{"role": "user", "content": prompt}]
456
+ output = self.llm_infer(messages=conversation, temperature=temperature,
457
+ max_new_tokens=max_new_tokens, max_token=max_token)
 
 
 
 
458
  if '[' in output:
459
  output = output.split('[')[0]
460
  return output
 
462
  def function_result_summary(self, input_list, status, enable_summary):
463
  if 'tool_call_step' not in status:
464
  status['tool_call_step'] = 0
465
+ if 'step' not in status:
466
+ status['step'] = 0
467
+ status['step'] += 1
468
+
469
  for idx in range(len(input_list)):
470
  pos_id = len(input_list) - idx - 1
471
  if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
472
+ if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
473
+ status['tool_call_step'] += 1
474
  break
475
 
 
476
  if not enable_summary:
477
  return status
478
 
479
+ if 'summarized_index' not in status:
480
+ status['summarized_index'] = 0
481
+ if 'summarized_step' not in status:
482
+ status['summarized_step'] = 0
483
+ if 'previous_length' not in status:
484
+ status['previous_length'] = 0
485
+ if 'history' not in status:
486
+ status['history'] = []
487
+
488
+ status['history'].append(
489
+ self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
490
 
 
491
  idx = status['summarized_index']
492
+ function_response = ''
493
  this_thought_calls = None
 
494
  while idx < len(input_list):
495
  if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
496
  (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
497
  if input_list[idx]['role'] == 'assistant':
498
+ if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
499
+ this_thought_calls = None
500
+ else:
501
+ if function_response:
502
+ status['summarized_step'] += 1
503
+ result_summary = self.run_summary_agent(
504
+ thought_calls=this_thought_calls, function_response=function_response,
505
+ temperature=0.1, max_new_tokens=512, max_token=2048)
506
+ input_list.insert(
507
+ last_call_idx + 1, {'role': 'tool', 'content': result_summary})
508
+ status['summarized_index'] = last_call_idx + 2
509
+ idx += 1
510
+ last_call_idx = idx
511
+ this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
512
+ function_response = ''
513
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls:
514
  function_response += input_list[idx]['content']
515
  del input_list[idx]
516
  idx -= 1
 
521
  if function_response:
522
  status['summarized_step'] += 1
523
  result_summary = self.run_summary_agent(
524
+ thought_calls=this_thought_calls, function_response=function_response,
525
+ temperature=0.1, max_new_tokens=512, max_token=2048)
 
 
 
526
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
527
  for tool_call in tool_calls:
528
  del tool_call['call_id']
529
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
530
+ input_list.insert(
531
+ last_call_idx + 1, {'role': 'tool', 'content': result_summary})
532
  status['summarized_index'] = last_call_idx + 2
533
 
534
  return status
 
539
  if hasattr(self, key):
540
  setattr(self, key, value)
541
  updated_attributes[key] = value
542
+ logger.debug("Updated parameters: %s", updated_attributes)
543
  return updated_attributes
544
 
545
+ def run_gradio_chat(self, message: str, history: list, temperature: float,
546
+ max_new_tokens: int, max_token: int, call_agent: bool,
547
+ conversation: gr.State, max_round: int = 10, seed: int = None,
548
+ call_agent_level: int = 0, sub_agent_task: str = None,
549
  uploaded_files: list = None):
550
+ logger.debug("Chat started, message: %s", message[:100])
551
  if not message or len(message.strip()) < 5:
552
  yield "Please provide a valid message or upload files to analyze."
553
  return
554
 
555
+ if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
556
+ return
557
+
558
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
559
  call_agent, call_agent_level, message)
560
  conversation = self.initialize_conversation(
561
  message, conversation, history)
562
  history = []
 
563
 
564
  next_round = True
565
  current_round = 0
 
567
  last_status = {}
568
  token_overflow = False
569
 
570
+ if self.enable_checker:
571
+ checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))
572
+
573
  try:
574
  while next_round and current_round < max_round:
575
  current_round += 1
576
+ last_outputs = []
577
+ if last_outputs:
578
  function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
579
+ last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
580
+ message_for_call_agent=message, call_agent=call_agent,
581
+ call_agent_level=call_agent_level, temperature=temperature)
 
 
 
582
  history.extend(current_gradio_history)
583
 
584
  if special_tool_call == 'Finish':
 
587
  conversation.extend(function_call_messages)
588
  return function_call_messages[0]['content']
589
 
590
+ if special_tool_call in ['RequireClarification', 'DirectResponse']:
591
  last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
592
  history.append(ChatMessage(role="assistant", content=last_msg.content))
593
  yield history
 
604
  yield history
605
  else:
606
  next_round = False
 
607
  return ''.join(last_outputs).replace("</s>", "")
608
 
609
+ if self.enable_checker:
610
+ good_status, wrong_info = checker.check_conversation()
611
+ if not good_status:
612
+ logger.warning("Checker error: %s", wrong_info)
613
+ break
614
+
615
  last_outputs_str, token_overflow = self.llm_infer(
616
+ messages=conversation, temperature=temperature, tools=picked_tools_prompt,
617
+ max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
618
 
619
  if last_outputs_str is None:
 
620
  if self.force_finish:
621
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
622
  conversation, temperature, max_new_tokens, max_token)
 
630
 
631
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
632
  for msg in history:
633
+ if msg.metadata:
634
  msg.metadata['status'] = 'done'
635
 
636
  if '[FinalAnswer]' in last_thought:
 
646
 
647
  last_outputs.append(last_outputs_str)
648
 
649
+ if next_round and self.force_finish:
650
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
651
+ conversation, temperature, max_new_tokens, max_token)
652
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
653
+ final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
654
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
655
+ yield history
656
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
657
+ yield history
 
 
 
658
 
659
  except Exception as e:
660
  logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
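Usage sketch (not part of the commit): a minimal example of how the updated class might be driven, assuming only the constructor defaults and method signatures visible in this diff. The import path and the two model identifiers are placeholders, and init_model() is taken as the entry point since this commit makes it call load_tool_desc_embedding().

    # Minimal sketch; placeholder model ids, import path assumed from src/txagent/txagent.py.
    from txagent.txagent import TxAgent

    agent = TxAgent(
        model_name="<your-txagent-llm>",        # placeholder LLM identifier
        rag_model_name="<your-toolrag-model>",  # placeholder embedding model identifier
        enable_rag=True,        # new default in this commit
        init_rag_num=2,         # fewer tools retrieved on the initial pass
        step_rag_num=4,         # fewer tools retrieved per Tool_RAG call
        enable_checker=False,   # checker stays off by default for speed
    )
    agent.init_model()  # loads the model, the ToolUniverse, and tool-description embeddings

    answer = agent.run_multistep_agent(
        "Does co-prescribing drug A with drug B require a dose adjustment?",
        temperature=0.1,
        max_new_tokens=512,
        max_token=2048,
    )
    print(answer)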