Ali2206 committed on
Commit 7524766 · verified · 1 Parent(s): f7cda5c

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +252 -292
src/txagent/txagent.py CHANGED
@@ -12,17 +12,18 @@ from tooluniverse import ToolUniverse
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
  import torch
 
15
  import logging
16
-
17
  logger = logging.getLogger(__name__)
18
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
19
 
20
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
 
22
  class TxAgent:
23
  def __init__(self, model_name,
24
  rag_model_name,
25
- tool_files_dict=None,
26
  enable_finish=True,
27
  enable_rag=True,
28
  enable_summary=False,
@@ -46,13 +47,10 @@ class TxAgent:
46
  self.model = None
47
  self.rag_model = ToolRAGModel(rag_model_name)
48
  self.tooluniverse = None
49
- self.prompt_multi_step = ("You are a highly skilled medical assistant tasked with analyzing medical records in detail. "
50
- "Provide comprehensive, step-by-step reasoning to identify oversights, including specific diagnoses, "
51
- "medication conflicts, incomplete assessments, and abnormal results. For each point, include clinical "
52
- "rationale, standardized screening tools (e.g., PCL-5, SCID-5-PD), and actionable recommendations for "
53
- "follow-up, ensuring a thorough and precise response.")
54
  self.self_prompt = "Strictly follow the instruction."
55
- self.chat_prompt = "You are a helpful assistant to chat with the user."
56
  self.enable_finish = enable_finish
57
  self.enable_rag = enable_rag
58
  self.enable_summary = enable_summary
@@ -67,57 +65,42 @@ class TxAgent:
67
  self.enable_checker = enable_checker
68
  self.additional_default_tools = additional_default_tools
69
  self.print_self_values()
70
- logger.info("TxAgent initialized with model_name=%s, rag_model_name=%s", model_name, rag_model_name)
71
 
72
  def init_model(self):
73
- logger.info("Initializing model: %s", self.model_name)
74
- try:
75
- self.load_models()
76
- self.load_tooluniverse()
77
- self.load_tool_desc_embedding()
78
- logger.info("Model initialization complete")
79
- except Exception as e:
80
- logger.error("Failed to initialize model: %s", e, exc_info=True)
81
- raise
82
 
83
  def print_self_values(self):
84
  for attr, value in self.__dict__.items():
85
- logger.debug("%s: %s", attr, value)
86
 
87
  def load_models(self, model_name=None):
88
  if model_name is not None:
89
  if model_name == self.model_name:
90
- logger.debug("Model %s already loaded", model_name)
91
  return f"The model {model_name} is already loaded."
92
  self.model_name = model_name
93
 
94
- logger.debug("Loading model %s", self.model_name)
95
  self.model = LLM(model=self.model_name)
96
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
97
  self.tokenizer = self.model.get_tokenizer()
98
- logger.info("Model %s loaded successfully", self.model_name)
99
  return f"Model {model_name} loaded successfully."
100
 
101
  def load_tooluniverse(self):
102
- logger.debug("Loading tool universe")
103
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
104
  self.tooluniverse.load_tools()
105
  special_tools = self.tooluniverse.prepare_tool_prompts(
106
  self.tooluniverse.tool_category_dicts["special_tools"])
107
  self.special_tools_name = [tool['name'] for tool in special_tools]
108
- logger.debug("Tool universe loaded with %d special tools", len(self.special_tools_name))
109
 
110
  def load_tool_desc_embedding(self):
111
- logger.debug("Loading tool description embeddings")
112
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
113
- logger.debug("Tool description embeddings loaded")
114
 
115
  def rag_infer(self, query, top_k=5):
116
- logger.debug("Running RAG inference with query: %s", query[:50])
117
  return self.rag_model.rag_infer(query, top_k)
118
 
119
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
120
- logger.debug("Initializing tools prompt, call_agent=%s, level=%d", call_agent, call_agent_level)
121
  picked_tools_prompt = []
122
  picked_tools_prompt = self.add_special_tools(
123
  picked_tools_prompt, call_agent=call_agent)
@@ -129,11 +112,9 @@ class TxAgent:
129
  if not call_agent:
130
  picked_tools_prompt += self.tool_RAG(
131
  message=message, rag_num=self.init_rag_num)
132
- logger.debug("Tools prompt initialized with %d tools", len(picked_tools_prompt))
133
  return picked_tools_prompt, call_agent_level
134
 
135
  def initialize_conversation(self, message, conversation=None, history=None):
136
- logger.debug("Initializing conversation with message: %s", message[:50])
137
  if conversation is None:
138
  conversation = []
139
 
@@ -142,7 +123,7 @@ class TxAgent:
142
  if history is not None:
143
  if len(history) == 0:
144
  conversation = []
145
- logger.debug("Cleared conversation")
146
  else:
147
  for i in range(len(history)):
148
  if history[i]['role'] == 'user':
@@ -156,7 +137,7 @@ class TxAgent:
156
  {"role": "assistant", "content": history[i]['content']})
157
 
158
  conversation.append({"role": "user", "content": message})
159
- logger.debug("Conversation initialized with %d messages", len(conversation))
160
  return conversation
161
 
162
  def tool_RAG(self, message=None,
@@ -164,8 +145,7 @@ class TxAgent:
164
  existing_tools_prompt=[],
165
  rag_num=5,
166
  return_call_result=False):
167
- logger.debug("Running tool RAG, message=%s, rag_num=%d", message[:50] if message else None, rag_num)
168
- extra_factor = 30
169
  if picked_tool_names is None:
170
  assert picked_tool_names is not None or message is not None
171
  picked_tool_names = self.rag_infer(
@@ -183,43 +163,39 @@ class TxAgent:
183
  picked_tools)
184
  if return_call_result:
185
  return picked_tools_prompt, picked_tool_names
186
- logger.debug("Tool RAG returned %d tools", len(picked_tools_prompt))
187
  return picked_tools_prompt
188
 
189
  def add_special_tools(self, tools, call_agent=False):
190
- logger.debug("Adding special tools, call_agent=%s", call_agent)
191
  if self.enable_finish:
192
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
193
  'Finish', return_prompt=True))
194
- logger.debug("Finish tool added")
195
  if call_agent:
196
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
197
  'CallAgent', return_prompt=True))
198
- logger.debug("CallAgent tool added")
199
  else:
200
  if self.enable_rag:
201
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
202
  'Tool_RAG', return_prompt=True))
203
- logger.debug("Tool_RAG tool added")
204
 
205
  if self.additional_default_tools is not None:
206
  for each_tool_name in self.additional_default_tools:
207
  tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
208
  each_tool_name, return_prompt=True)
209
  if tool_prompt is not None:
210
- logger.debug("%s tool added", each_tool_name)
211
  tools.append(tool_prompt)
212
  return tools
213
 
214
  def add_finish_tools(self, tools):
215
- logger.debug("Adding finish tools")
216
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
217
  'Finish', return_prompt=True))
218
- logger.debug("Finish tool added")
219
  return tools
220
 
221
  def set_system_prompt(self, conversation, sys_prompt):
222
- logger.debug("Setting system prompt")
223
  if len(conversation) == 0:
224
  conversation.append(
225
  {"role": "system", "content": sys_prompt})
@@ -235,7 +211,6 @@ class TxAgent:
235
  call_agent_level=None,
236
  temperature=None):
237
 
238
- logger.debug("Running function call with input: %s", fcall_str[:50])
239
  function_call_json, message = self.tooluniverse.extract_function_call_json(
240
  fcall_str, return_message=return_message, verbose=False)
241
  call_results = []
@@ -243,7 +218,7 @@ class TxAgent:
243
  if function_call_json is not None:
244
  if isinstance(function_call_json, list):
245
  for i in range(len(function_call_json)):
246
- logger.debug("Tool Call: %s", function_call_json[i])
247
  if function_call_json[i]["name"] == 'Finish':
248
  special_tool_call = 'Finish'
249
  break
@@ -264,7 +239,7 @@ class TxAgent:
264
  )
265
  call_result = self.run_multistep_agent(
266
  full_message, temperature=temperature,
267
- max_new_tokens=1024, max_token=8192,
268
  call_agent=False, call_agent_level=call_agent_level)
269
  if call_result is None:
270
  call_result = "⚠️ No content returned from sub-agent."
@@ -278,7 +253,7 @@ class TxAgent:
278
 
279
  call_id = self.tooluniverse.call_id_gen()
280
  function_call_json[i]["call_id"] = call_id
281
- logger.debug("Tool Call Result: %s", call_result)
282
  call_results.append({
283
  "role": "tool",
284
  "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
@@ -286,15 +261,16 @@ class TxAgent:
286
  else:
287
  call_results.append({
288
  "role": "tool",
289
- "content": json.dumps({"content": "No valid tool call detected; proceeding with analysis."})
290
  })
291
 
292
  revised_messages = [{
293
  "role": "assistant",
294
- "content": message.strip() if message else "Processing...",
295
- "tool_calls": json.dumps(function_call_json) if function_call_json else None
296
  }] + call_results
297
- logger.debug("Function call completed, returning %d messages", len(revised_messages))
 
298
  return revised_messages, existing_tools_prompt, special_tool_call
299
 
300
  def run_function_call_stream(self, fcall_str,
@@ -306,104 +282,102 @@ class TxAgent:
306
  temperature=None,
307
  return_gradio_history=True):
308
 
309
- logger.debug("Running function call stream with input: %s", fcall_str[:50])
310
  function_call_json, message = self.tooluniverse.extract_function_call_json(
311
  fcall_str, return_message=return_message, verbose=False)
312
  call_results = []
313
  special_tool_call = ''
314
- gradio_history = []
315
 
316
- if function_call_json is None:
317
- logger.warning("No valid function call JSON extracted")
 
 
 
 
 
 
 
 
 
 
 
 
318
  call_results.append({
319
  "role": "tool",
320
- "content": json.dumps({"content": "No tool call detected; continuing analysis."})
321
  })
322
- if return_gradio_history:
323
- gradio_history.append({"role": "assistant", "content": "No specific tool call identified. Proceeding with medical record analysis."})
324
- yield [{"role": "assistant", "content": "Processing..."}], existing_tools_prompt or [], special_tool_call, gradio_history
325
- return
326
-
327
- if isinstance(function_call_json, list):
328
- for i in range(len(function_call_json)):
329
- logger.debug("Processing tool call: %s", function_call_json[i])
330
- if function_call_json[i]["name"] == 'Finish':
331
- special_tool_call = 'Finish'
332
- break
333
- elif function_call_json[i]["name"] == 'Tool_RAG':
334
- new_tools_prompt, call_result = self.tool_RAG(
335
- message=message,
336
- existing_tools_prompt=existing_tools_prompt,
337
- rag_num=self.step_rag_num,
338
- return_call_result=True)
339
- existing_tools_prompt = (existing_tools_prompt or []) + new_tools_prompt
340
- elif function_call_json[i]["name"] == 'DirectResponse':
341
- call_result = function_call_json[i]['arguments']['respose']
342
- special_tool_call = 'DirectResponse'
343
- elif function_call_json[i]["name"] == 'RequireClarification':
344
- call_result = function_call_json[i]['arguments']['unclear_question']
345
- special_tool_call = 'RequireClarification'
346
- elif function_call_json[i]["name"] == 'CallAgent':
347
- if call_agent_level < 2 and call_agent:
348
- solution_plan = function_call_json[i]['arguments']['solution']
349
- full_message = (
350
- message_for_call_agent +
351
- "\nYou must follow the following plan to answer the question: " +
352
- str(solution_plan)
353
- )
354
- sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
355
- sub_result = yield from self.run_gradio_chat(
356
- full_message, history=[], temperature=temperature,
357
- max_new_tokens=1024, max_token=8192,
358
- call_agent=False, call_agent_level=call_agent_level,
359
- conversation=None,
360
- sub_agent_task=sub_agent_task)
361
- call_result = sub_result if isinstance(sub_result, str) else "No content from sub-agent."
362
- if '[FinalAnswer]' in call_result:
363
- call_result = call_result.split('[FinalAnswer]')[-1].strip()
364
- else:
365
- call_result = "CallAgent disabled. Proceeding with reasoning."
366
- else:
367
- call_result = self.tooluniverse.run_one_function(function_call_json[i])
368
-
369
- call_id = self.tooluniverse.call_id_gen()
370
- function_call_json[i]["call_id"] = call_id
371
- call_results.append({
372
- "role": "tool",
373
- "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
374
- })
375
-
376
- if return_gradio_history and function_call_json[i]["name"] != 'Finish':
377
- metadata = {"title": f"⚒️ {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
378
- gradio_history.append({"role": "assistant", "content": str(call_result), "metadata": metadata})
379
 
380
  revised_messages = [{
381
  "role": "assistant",
382
- "content": message.strip() if message else "Processing...",
383
- "tool_calls": json.dumps(function_call_json) if function_call_json else None
384
  }] + call_results
385
 
386
  if return_gradio_history:
387
- logger.debug("Yielding gradio history with %d entries", len(gradio_history))
388
- yield revised_messages, existing_tools_prompt or [], special_tool_call, gradio_history
389
  else:
390
- yield revised_messages, existing_tools_prompt or [], special_tool_call
391
- logger.debug("Function call stream completed")
392
 
393
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
394
- logger.debug("Forcing answer due to unfinished reasoning")
395
- if conversation[-1]['role'] == 'assistant':
396
  conversation.append(
397
- {'role': 'tool', 'content': 'Errors occurred; provide a detailed final answer based on current information.'})
398
  finish_tools_prompt = self.add_finish_tools([])
399
 
400
  last_outputs_str = self.llm_infer(messages=conversation,
401
  temperature=temperature,
402
  tools=finish_tools_prompt,
403
- output_begin_string='[FinalAnswer]',
404
  skip_special_tokens=True,
405
  max_new_tokens=max_new_tokens, max_token=max_token)
406
- logger.debug("Forced finish output: %s", last_outputs_str[:100])
407
  return last_outputs_str
408
 
409
  def run_multistep_agent(self, message: str,
@@ -413,7 +387,16 @@ class TxAgent:
413
  max_round: int = 20,
414
  call_agent=False,
415
  call_agent_level=0) -> str:
416
- logger.info("Starting multistep agent with message: %s", message[:50])
 
 
 
 
 
 
 
 
 
417
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
418
  call_agent, call_agent_level, message)
419
  conversation = self.initialize_conversation(message)
@@ -432,7 +415,6 @@ class TxAgent:
432
  try:
433
  while next_round and current_round < max_round:
434
  current_round += 1
435
- logger.debug("Round %d", current_round)
436
  if len(outputs) > 0:
437
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
438
  last_outputs, return_message=True,
@@ -450,12 +432,12 @@ class TxAgent:
450
  function_call_messages[0]['content'])
451
  content = function_call_messages[0]['content']
452
  if content is None:
453
- logger.warning("No content after Finish tool call")
454
  return "❌ No content returned after Finish tool call."
455
- logger.debug("Returning final content: %s", content[:50])
456
  return content.split('[FinalAnswer]')[-1]
457
 
458
  if (self.enable_summary or token_overflow) and not call_agent:
 
 
459
  enable_summary = True
460
  last_status = self.function_result_summary(
461
  conversation, status=last_status, enable_summary=enable_summary)
@@ -466,14 +448,15 @@ class TxAgent:
466
  function_call_messages))
467
  else:
468
  next_round = False
469
- content = ''.join(last_outputs).replace("</s>", "")
470
- logger.debug("Returning content: %s", content[:50])
471
- return content
472
-
473
  if self.enable_checker:
474
  good_status, wrong_info = checker.check_conversation()
475
  if not good_status:
476
- logger.warning("Internal error in reasoning: %s", wrong_info)
 
 
477
  break
478
  last_outputs = []
479
  outputs.append("### TxAgent:\n")
@@ -484,7 +467,7 @@ class TxAgent:
484
  max_new_tokens=max_new_tokens, max_token=max_token,
485
  check_token_status=True)
486
  if last_outputs_str is None:
487
- logger.warning("Token limit exceeded")
488
  if self.force_finish:
489
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
490
  else:
@@ -492,22 +475,21 @@ class TxAgent:
492
  else:
493
  last_outputs.append(last_outputs_str)
494
  if max_round == current_round:
495
- logger.warning("Max rounds exceeded")
496
  if self.force_finish:
497
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
498
  else:
499
- logger.debug("No output due to max rounds")
500
  return None
501
 
502
  except Exception as e:
503
- logger.error("Error in multistep agent: %s", e, exc_info=True)
504
  if self.force_finish:
505
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
506
  else:
507
  return None
508
 
509
  def build_logits_processor(self, messages, llm):
510
- logger.debug("Building logits processor")
511
  tokenizer = llm.get_tokenizer()
512
  if self.avoid_repeat and len(messages) > 2:
513
  assistant_messages = []
@@ -519,14 +501,14 @@ class TxAgent:
519
  forbidden_ids = [tokenizer.encode(
520
  msg, add_special_tokens=False) for msg in assistant_messages]
521
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
522
- return None
 
523
 
524
  def llm_infer(self, messages, temperature=0.1, tools=None,
525
  output_begin_string=None, max_new_tokens=2048,
526
- max_token=8192, skip_special_tokens=True,
527
  model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
528
 
529
- logger.debug("Running LLM inference with %d messages", len(messages))
530
  if model is None:
531
  model = self.model
532
 
@@ -534,6 +516,7 @@ class TxAgent:
534
  sampling_params = SamplingParams(
535
  temperature=temperature,
536
  max_tokens=max_new_tokens,
 
537
  seed=seed if seed is not None else self.seed,
538
  )
539
 
@@ -544,38 +527,24 @@ class TxAgent:
544
 
545
  if check_token_status and max_token is not None:
546
  token_overflow = False
547
- input_tokens = self.tokenizer.encode(prompt, return_tensors="pt")[0]
548
- num_input_tokens = len(input_tokens)
549
- if num_input_tokens > max_token:
550
- logger.info("Input tokens: %d, max_token: %d", num_input_tokens, max_token)
551
- max_prompt_tokens = max_token - max_new_tokens - 100
552
- if max_prompt_tokens > 0:
553
- truncated_input = self.tokenizer.decode(input_tokens[:max_prompt_tokens])
554
- prompt = truncated_input
555
- logger.info("Truncated to %d tokens", len(self.tokenizer.encode(prompt, return_tensors='pt')[0]))
556
- token_overflow = True
557
- else:
558
- logger.warning("Cannot truncate effectively")
559
  torch.cuda.empty_cache()
560
  gc.collect()
561
  return None, token_overflow
562
-
563
  output = model.generate(
564
  prompt,
565
  sampling_params=sampling_params,
566
  )
567
  output = output[0].outputs[0].text
568
- # Deduplicate repetitive output
569
- if output:
570
- lines = output.split('\n')
571
- seen = set()
572
- deduped_lines = []
573
- for line in lines:
574
- if line.strip() and line not in seen:
575
- seen.add(line)
576
- deduped_lines.append(line)
577
- output = '\n'.join(deduped_lines)
578
- logger.debug("LLM output: %s", output[:50])
579
  if check_token_status and max_token is not None:
580
  return output, token_overflow
581
 
@@ -586,7 +555,7 @@ class TxAgent:
586
  max_new_tokens: int,
587
  max_token: int) -> str:
588
 
589
- logger.info("Starting self agent with message: %s", message[:50])
590
  conversation = []
591
  conversation = self.set_system_prompt(conversation, self.self_prompt)
592
  conversation.append({"role": "user", "content": message})
@@ -600,7 +569,7 @@ class TxAgent:
600
  max_new_tokens: int,
601
  max_token: int) -> str:
602
 
603
- logger.info("Starting chat agent with message: %s", message[:50])
604
  conversation = []
605
  conversation = self.set_system_prompt(conversation, self.chat_prompt)
606
  conversation.append({"role": "user", "content": message})
@@ -615,7 +584,7 @@ class TxAgent:
615
  max_new_tokens: int,
616
  max_token: int) -> str:
617
 
618
- logger.info("Starting format agent")
619
  if '[FinalAnswer]' in answer:
620
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
621
  elif "\n\n" in answer:
@@ -625,13 +594,12 @@ class TxAgent:
625
  if len(possible_final_answer) == 1:
626
  choice = possible_final_answer[0]
627
  if choice in ['A', 'B', 'C', 'D', 'E']:
628
- logger.debug("Returning choice: %s", choice)
629
  return choice
630
  elif len(possible_final_answer) > 1:
631
  if possible_final_answer[1] == ':':
632
  choice = possible_final_answer[0]
633
  if choice in ['A', 'B', 'C', 'D', 'E']:
634
- logger.debug("Returning choice: %s", choice)
635
  return choice
636
 
637
  conversation = []
@@ -649,7 +617,7 @@ class TxAgent:
649
  temperature: float,
650
  max_new_tokens: int,
651
  max_token: int) -> str:
652
- logger.info("Running summary agent")
653
  generate_tool_result_summary_training_prompt = """Thought and function calls:
654
  {thought_calls}
655
  Function calls' responses:
@@ -670,11 +638,20 @@ Generate **one summarized sentence** about "function calls' responses" with nece
670
 
671
  if '[' in output:
672
  output = output.split('[')[0]
673
- logger.debug("Summary output: %s", output)
674
  return output
675
 
676
  def function_result_summary(self, input_list, status, enable_summary):
677
- logger.debug("Running function result summary, enable_summary=%s", enable_summary)
 
 
 
 
 
 
 
 
 
 
678
  if 'tool_call_step' not in status:
679
  status['tool_call_step'] = 0
680
 
@@ -722,14 +699,14 @@ Generate **one summarized sentence** about "function calls' responses" with nece
722
  this_thought_calls = None
723
  else:
724
  if len(function_response) != 0:
725
- logger.debug("Generating internal summary")
726
  status['summarized_step'] += 1
727
  result_summary = self.run_summary_agent(
728
  thought_calls=this_thought_calls,
729
  function_response=function_response,
730
  temperature=0.1,
731
  max_new_tokens=1024,
732
- max_token=8192
733
  )
734
 
735
  input_list.insert(
@@ -758,7 +735,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
758
  function_response=function_response,
759
  temperature=0.1,
760
  max_new_tokens=1024,
761
- max_token=8192
762
  )
763
 
764
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
@@ -769,18 +746,19 @@ Generate **one summarized sentence** about "function calls' responses" with nece
769
  last_call_idx+1, {'role': 'tool', 'content': result_summary})
770
  status['summarized_index'] = last_call_idx + 2
771
 
772
- logger.debug("Function result summary completed")
773
  return status
774
 
 
 
 
775
  def update_parameters(self, **kwargs):
776
- logger.debug("Updating parameters: %s", kwargs)
777
  for key, value in kwargs.items():
778
  if hasattr(self, key):
779
  setattr(self, key, value)
780
 
 
781
  updated_attributes = {key: value for key,
782
  value in kwargs.items() if hasattr(self, key)}
783
- logger.debug("Updated attributes: %s", updated_attributes)
784
  return updated_attributes
785
 
786
  def run_gradio_chat(self, message: str,
@@ -806,117 +784,90 @@ Generate **one summarized sentence** about "function calls' responses" with nece
806
  Returns:
807
  str: Final assistant message.
808
  """
809
- logger.info("[TxAgent] Chat started with message: %s", message[:100])
810
- logger.debug("Initial history: %s", [msg["content"][:50] for msg in history] if history else [])
811
 
812
- # Yield initial message to ensure UI updates
813
- history.append({"role": "assistant", "content": "Starting analysis..."})
814
- yield history
815
- logger.debug("Yielded initial history")
816
 
817
- try:
818
- if not message or len(message.strip()) < 5:
819
- logger.warning("Invalid message detected")
820
- history.append({"role": "assistant", "content": "Please provide a valid message or upload files to analyze."})
821
- yield history
822
- return "Invalid input."
823
-
824
- if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
825
- logger.debug("Skipping tool-related message")
826
- yield history
827
- return ""
828
-
829
- outputs = []
830
- last_outputs = []
831
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
832
- call_agent, call_agent_level, message)
833
- conversation = self.initialize_conversation(
834
- message, conversation=conversation, history=history)
835
- history = [] # Reset history to avoid duplication
836
- logger.debug("Conversation initialized with %d messages", len(conversation))
837
-
838
- next_round = True
839
- function_call_messages = []
840
- current_round = 0
841
- enable_summary = False
842
- last_status = {}
843
- token_overflow = False
844
 
845
- if self.enable_checker:
846
- checker = ReasoningTraceChecker(
847
- message, conversation, init_index=len(conversation))
848
 
 
849
  while next_round and current_round < max_round:
850
  current_round += 1
851
- logger.debug("Round %d, conversation length: %d", current_round, len(conversation))
852
 
853
  if last_outputs:
854
- function_call_result = yield from self.run_function_call_stream(
855
- last_outputs[0], return_message=True,
856
  existing_tools_prompt=picked_tools_prompt,
857
  message_for_call_agent=message,
858
  call_agent=call_agent,
859
  call_agent_level=call_agent_level,
860
  temperature=temperature)
861
 
862
- if not function_call_result:
863
- logger.warning("Empty result from run_function_call_stream")
864
- history.append({"role": "assistant", "content": "Error: Unable to process tool response. Continuing analysis."})
865
- yield history
866
- last_outputs = []
867
- continue
868
-
869
- function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
870
-
871
- # Convert history to dicts and deduplicate
872
- unique_history = []
873
- seen_contents = set()
874
- for msg in current_gradio_history:
875
- content = msg["content"] if isinstance(msg, dict) else msg.content
876
- if content not in seen_contents:
877
- unique_history.append({"role": "assistant", "content": content})
878
- seen_contents.add(content)
879
- history.extend(unique_history)
880
- logger.debug("Extended history with %d unique messages", len(unique_history))
881
 
882
  if special_tool_call == 'Finish' and function_call_messages:
883
- content = function_call_messages[0]['content']
884
- history.append({"role": "assistant", "content": content})
885
- logger.debug("Yielding final history after Finish: %s", content[:50])
886
  yield history
887
  next_round = False
888
  conversation.extend(function_call_messages)
889
- return content
890
 
891
  elif special_tool_call in ['RequireClarification', 'DirectResponse']:
892
- last_msg = history[-1] if history else {"role": "assistant", "content": "Response needed."}
893
- history.append({"role": "assistant", "content": last_msg["content"]})
894
- logger.debug("Yielding history for special tool: %s", last_msg["content"][:50])
895
  yield history
896
  next_round = False
897
- return last_msg["content"]
898
 
899
  if (self.enable_summary or token_overflow) and not call_agent:
900
  enable_summary = True
901
 
902
  last_status = self.function_result_summary(
903
- conversation, status=last_status, enable_summary=enable_summary)
 
904
 
905
  if function_call_messages:
906
  conversation.extend(function_call_messages)
 
907
  else:
908
  next_round = False
909
- content = ''.join(last_outputs).replace("</s>", "")
910
- history.append({"role": "assistant", "content": content})
911
- conversation.append({"role": "assistant", "content": content})
912
- logger.debug("Yielding history with content: %s", content[:50])
913
- yield history
914
- return content
915
 
916
  if self.enable_checker:
917
  good_status, wrong_info = checker.check_conversation()
918
  if not good_status:
919
- logger.warning("Checker flagged error: %s", wrong_info)
920
  break
921
 
922
  last_outputs = []
@@ -926,78 +877,87 @@ Generate **one summarized sentence** about "function calls' responses" with nece
926
  tools=picked_tools_prompt,
927
  skip_special_tokens=False,
928
  max_new_tokens=max_new_tokens,
929
- max_token=8192,
930
  seed=seed,
931
  check_token_status=True)
932
 
933
- logger.debug("llm_infer output: %s, token_overflow: %s",
934
- last_outputs_str[:50] if last_outputs_str else None, token_overflow)
935
 
936
  if last_outputs_str is None:
937
- logger.warning("llm_infer returned None")
938
- error_msg = "Error: Unable to generate response due to token limit. Please reduce input size."
939
- history.append({"role": "assistant", "content": error_msg})
940
- yield history
941
- return error_msg
942
 
943
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
944
 
945
  for msg in history:
946
- if isinstance(msg, dict) and "metadata" in msg and msg["metadata"] is not None:
947
- msg["metadata"]['status'] = 'done'
948
 
949
  if '[FinalAnswer]' in last_thought:
950
  parts = last_thought.split('[FinalAnswer]', 1)
951
- final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
952
- history.append({"role": "assistant", "content": final_thought.strip()})
953
- history.append({"role": "assistant", "content": "**🧠 Final Analysis:**\n" + final_answer.strip()})
954
- logger.debug("Yielding final analysis: %s", final_answer[:50])
 
 
 
955
  yield history
956
- next_round = False
957
  else:
958
- history.append({"role": "assistant", "content": last_thought})
959
- logger.debug("Yielding intermediate history: %s", last_thought[:50])
960
  yield history
961
 
962
  last_outputs.append(last_outputs_str)
963
 
964
  if next_round:
965
- logger.info("Max rounds reached")
966
  if self.force_finish:
967
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
968
  conversation, temperature, max_new_tokens, max_token)
969
  if '[FinalAnswer]' in last_outputs_str:
970
  parts = last_outputs_str.split('[FinalAnswer]', 1)
971
- final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
972
- history.append({"role": "assistant", "content": final_thought.strip()})
973
- history.append({"role": "assistant", "content": "**🧠 Final Analysis:**\n" + final_answer.strip()})
 
 
 
 
 
974
  else:
975
- history.append({"role": "assistant", "content": last_outputs_str.strip()})
976
- logger.debug("Yielding forced final history")
977
- yield history
978
  else:
979
- error_msg = "The number of reasoning rounds exceeded the limit."
980
- history.append({"role": "assistant", "content": error_msg})
981
- logger.debug("Yielding max rounds error")
982
- yield history
983
- return error_msg
984
 
985
  except Exception as e:
986
- logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
987
  error_msg = f"An error occurred: {e}"
988
- history.append({"role": "assistant", "content": error_msg})
989
- logger.debug("Yielding error history: %s", error_msg)
990
  yield history
991
  if self.force_finish:
992
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
993
  conversation, temperature, max_new_tokens, max_token)
994
  if '[FinalAnswer]' in last_outputs_str:
995
  parts = last_outputs_str.split('[FinalAnswer]', 1)
996
- final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
997
- history.append({"role": "assistant", "content": final_thought.strip()})
998
- history.append({"role": "assistant", "content": "**🧠 Final Analysis:**\n" + final_answer.strip()})
 
 
 
 
 
999
  else:
1000
- history.append({"role": "assistant", "content": last_outputs_str.strip()})
1001
- logger.debug("Yielding forced final history after error")
1002
- yield history
1003
  return error_msg
 
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
  import torch
15
+ # near the top of txagent.py
16
  import logging
 
17
  logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
 
20
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
22
+
23
  class TxAgent:
24
  def __init__(self, model_name,
25
  rag_model_name,
26
+ tool_files_dict=None, # None leads to the default tool files in ToolUniverse
27
  enable_finish=True,
28
  enable_rag=True,
29
  enable_summary=False,
 
47
  self.model = None
48
  self.rag_model = ToolRAGModel(rag_model_name)
49
  self.tooluniverse = None
50
+ # self.tool_desc = None
51
+ self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
52
  self.self_prompt = "Strictly follow the instruction."
53
+ self.chat_prompt = "You are a helpful assistant to chat with the user."
54
  self.enable_finish = enable_finish
55
  self.enable_rag = enable_rag
56
  self.enable_summary = enable_summary
 
65
  self.enable_checker = enable_checker
66
  self.additional_default_tools = additional_default_tools
67
  self.print_self_values()
 
68
 
69
  def init_model(self):
70
+ self.load_models()
71
+ self.load_tooluniverse()
72
+ self.load_tool_desc_embedding()
73
 
74
  def print_self_values(self):
75
  for attr, value in self.__dict__.items():
76
+ print(f"{attr}: {value}")
77
 
78
  def load_models(self, model_name=None):
79
  if model_name is not None:
80
  if model_name == self.model_name:
 
81
  return f"The model {model_name} is already loaded."
82
  self.model_name = model_name
83
 
 
84
  self.model = LLM(model=self.model_name)
85
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
86
  self.tokenizer = self.model.get_tokenizer()
87
+
88
  return f"Model {model_name} loaded successfully."
89
 
90
  def load_tooluniverse(self):
 
91
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
92
  self.tooluniverse.load_tools()
93
  special_tools = self.tooluniverse.prepare_tool_prompts(
94
  self.tooluniverse.tool_category_dicts["special_tools"])
95
  self.special_tools_name = [tool['name'] for tool in special_tools]
 
96
 
97
  def load_tool_desc_embedding(self):
 
98
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
 
99
 
100
  def rag_infer(self, query, top_k=5):
 
101
  return self.rag_model.rag_infer(query, top_k)
102
 
103
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
 
104
  picked_tools_prompt = []
105
  picked_tools_prompt = self.add_special_tools(
106
  picked_tools_prompt, call_agent=call_agent)
 
112
  if not call_agent:
113
  picked_tools_prompt += self.tool_RAG(
114
  message=message, rag_num=self.init_rag_num)
 
115
  return picked_tools_prompt, call_agent_level
116
 
117
  def initialize_conversation(self, message, conversation=None, history=None):
 
118
  if conversation is None:
119
  conversation = []
120
 
 
123
  if history is not None:
124
  if len(history) == 0:
125
  conversation = []
126
+ print("clear conversation successfully")
127
  else:
128
  for i in range(len(history)):
129
  if history[i]['role'] == 'user':
 
137
  {"role": "assistant", "content": history[i]['content']})
138
 
139
  conversation.append({"role": "user", "content": message})
140
+
141
  return conversation
142
 
143
  def tool_RAG(self, message=None,
 
145
  existing_tools_prompt=[],
146
  rag_num=5,
147
  return_call_result=False):
148
+ extra_factor = 30 # Factor to retrieve more than rag_num
 
149
  if picked_tool_names is None:
150
  assert picked_tool_names is not None or message is not None
151
  picked_tool_names = self.rag_infer(
 
163
  picked_tools)
164
  if return_call_result:
165
  return picked_tools_prompt, picked_tool_names
 
166
  return picked_tools_prompt
167
 
168
  def add_special_tools(self, tools, call_agent=False):
 
169
  if self.enable_finish:
170
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
171
  'Finish', return_prompt=True))
172
+ print("Finish tool is added")
173
  if call_agent:
174
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
  'CallAgent', return_prompt=True))
176
+ print("CallAgent tool is added")
177
  else:
178
  if self.enable_rag:
179
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
180
  'Tool_RAG', return_prompt=True))
181
+ print("Tool_RAG tool is added")
182
 
183
  if self.additional_default_tools is not None:
184
  for each_tool_name in self.additional_default_tools:
185
  tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
186
  each_tool_name, return_prompt=True)
187
  if tool_prompt is not None:
188
+ print(f"{each_tool_name} tool is added")
189
  tools.append(tool_prompt)
190
  return tools
191
 
192
  def add_finish_tools(self, tools):
 
193
  tools.append(self.tooluniverse.get_one_tool_by_one_name(
194
  'Finish', return_prompt=True))
195
+ print("Finish tool is added")
196
  return tools
197
 
198
  def set_system_prompt(self, conversation, sys_prompt):
 
199
  if len(conversation) == 0:
200
  conversation.append(
201
  {"role": "system", "content": sys_prompt})
 
211
  call_agent_level=None,
212
  temperature=None):
213
 
 
214
  function_call_json, message = self.tooluniverse.extract_function_call_json(
215
  fcall_str, return_message=return_message, verbose=False)
216
  call_results = []
 
218
  if function_call_json is not None:
219
  if isinstance(function_call_json, list):
220
  for i in range(len(function_call_json)):
221
+ print("\033[94mTool Call:\033[0m", function_call_json[i])
222
  if function_call_json[i]["name"] == 'Finish':
223
  special_tool_call = 'Finish'
224
  break
 
239
  )
240
  call_result = self.run_multistep_agent(
241
  full_message, temperature=temperature,
242
+ max_new_tokens=1024, max_token=99999,
243
  call_agent=False, call_agent_level=call_agent_level)
244
  if call_result is None:
245
  call_result = "⚠️ No content returned from sub-agent."
 
253
 
254
  call_id = self.tooluniverse.call_id_gen()
255
  function_call_json[i]["call_id"] = call_id
256
+ print("\033[94mTool Call Result:\033[0m", call_result)
257
  call_results.append({
258
  "role": "tool",
259
  "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
 
261
  else:
262
  call_results.append({
263
  "role": "tool",
264
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
265
  })
266
 
267
  revised_messages = [{
268
  "role": "assistant",
269
+ "content": message.strip(),
270
+ "tool_calls": json.dumps(function_call_json)
271
  }] + call_results
272
+
273
+ # Yield the final result.
274
  return revised_messages, existing_tools_prompt, special_tool_call
275
 
276
  def run_function_call_stream(self, fcall_str,
 
282
  temperature=None,
283
  return_gradio_history=True):
284
 
 
285
  function_call_json, message = self.tooluniverse.extract_function_call_json(
286
  fcall_str, return_message=return_message, verbose=False)
287
  call_results = []
288
  special_tool_call = ''
289
+ if return_gradio_history:
290
+ gradio_history = []
291
+ if function_call_json is not None:
292
+ if isinstance(function_call_json, list):
293
+ for i in range(len(function_call_json)):
294
+ if function_call_json[i]["name"] == 'Finish':
295
+ special_tool_call = 'Finish'
296
+ break
297
+ elif function_call_json[i]["name"] == 'Tool_RAG':
298
+ new_tools_prompt, call_result = self.tool_RAG(
299
+ message=message,
300
+ existing_tools_prompt=existing_tools_prompt,
301
+ rag_num=self.step_rag_num,
302
+ return_call_result=True)
303
+ existing_tools_prompt += new_tools_prompt
304
+ elif function_call_json[i]["name"] == 'DirectResponse':
305
+ call_result = function_call_json[i]['arguments']['respose']
306
+ special_tool_call = 'DirectResponse'
307
+ elif function_call_json[i]["name"] == 'RequireClarification':
308
+ call_result = function_call_json[i]['arguments']['unclear_question']
309
+ special_tool_call = 'RequireClarification'
310
+ elif function_call_json[i]["name"] == 'CallAgent':
311
+ if call_agent_level < 2 and call_agent:
312
+ solution_plan = function_call_json[i]['arguments']['solution']
313
+ full_message = (
314
+ message_for_call_agent +
315
+ "\nYou must follow the following plan to answer the question: " +
316
+ str(solution_plan)
317
+ )
318
+ sub_agent_task = "Sub TxAgent plan: " + \
319
+ str(solution_plan)
320
+ call_result = yield from self.run_gradio_chat(
321
+ full_message, history=[], temperature=temperature,
322
+ max_new_tokens=1024, max_token=99999,
323
+ call_agent=False, call_agent_level=call_agent_level,
324
+ conversation=None,
325
+ sub_agent_task=sub_agent_task)
326
+
327
+ if call_result is not None and isinstance(call_result, str):
328
+ call_result = call_result.split('[FinalAnswer]')[-1]
329
+ else:
330
+ call_result = "⚠️ No content returned from sub-agent."
331
+ else:
332
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
333
+ else:
334
+ call_result = self.tooluniverse.run_one_function(
335
+ function_call_json[i])
336
 
337
+ call_id = self.tooluniverse.call_id_gen()
338
+ function_call_json[i]["call_id"] = call_id
339
+ call_results.append({
340
+ "role": "tool",
341
+ "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
342
+ })
343
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
344
+ if function_call_json[i]["name"] == 'Tool_RAG':
345
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
346
+ "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
347
+ else:
348
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
349
+ "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
350
+ else:
351
  call_results.append({
352
  "role": "tool",
353
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
354
  })
355
 
356
  revised_messages = [{
357
  "role": "assistant",
358
+ "content": message.strip(),
359
+ "tool_calls": json.dumps(function_call_json)
360
  }] + call_results
361
 
362
  if return_gradio_history:
363
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
 
364
  else:
365
+ return revised_messages, existing_tools_prompt, special_tool_call
366
+
367
 
368
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
369
+ if conversation[-1]['role'] == 'assistant':
 
370
  conversation.append(
371
+ {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
372
  finish_tools_prompt = self.add_finish_tools([])
373
 
374
  last_outputs_str = self.llm_infer(messages=conversation,
375
  temperature=temperature,
376
  tools=finish_tools_prompt,
377
+ output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
378
  skip_special_tokens=True,
379
  max_new_tokens=max_new_tokens, max_token=max_token)
380
+ print(last_outputs_str)
381
  return last_outputs_str
382
 
383
  def run_multistep_agent(self, message: str,
 
387
  max_round: int = 20,
388
  call_agent=False,
389
  call_agent_level=0) -> str:
390
+ """
391
+ Generate a streaming response using the llama3-8b model.
392
+ Args:
393
+ message (str): The input message.
394
+ temperature (float): The temperature for generating the response.
395
+ max_new_tokens (int): The maximum number of new tokens to generate.
396
+ Returns:
397
+ str: The generated response.
398
+ """
399
+ print("\033[1;32;40mstart\033[0m")
400
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
401
  call_agent, call_agent_level, message)
402
  conversation = self.initialize_conversation(message)
 
415
  try:
416
  while next_round and current_round < max_round:
417
  current_round += 1
 
418
  if len(outputs) > 0:
419
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
420
  last_outputs, return_message=True,
 
432
  function_call_messages[0]['content'])
433
  content = function_call_messages[0]['content']
434
  if content is None:
 
435
  return "❌ No content returned after Finish tool call."
 
436
  return content.split('[FinalAnswer]')[-1]
437
 
438
  if (self.enable_summary or token_overflow) and not call_agent:
439
+ if token_overflow:
440
+ print("token_overflow, using summary")
441
  enable_summary = True
442
  last_status = self.function_result_summary(
443
  conversation, status=last_status, enable_summary=enable_summary)
 
448
  function_call_messages))
449
  else:
450
  next_round = False
451
+ conversation.extend(
452
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
453
+ return ''.join(last_outputs).replace("</s>", "")
 
454
  if self.enable_checker:
455
  good_status, wrong_info = checker.check_conversation()
456
  if not good_status:
457
+ next_round = False
458
+ print(
459
+ "Internal error in reasoning: " + wrong_info)
460
  break
461
  last_outputs = []
462
  outputs.append("### TxAgent:\n")
 
467
  max_new_tokens=max_new_tokens, max_token=max_token,
468
  check_token_status=True)
469
  if last_outputs_str is None:
470
+ print("The number of tokens exceeds the maximum limit.")
471
  if self.force_finish:
472
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
473
  else:
 
475
  else:
476
  last_outputs.append(last_outputs_str)
477
  if max_round == current_round:
478
+ print("The number of rounds exceeds the maximum limit!")
479
  if self.force_finish:
480
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
481
  else:
 
482
  return None
483
 
484
  except Exception as e:
485
+ print(f"Error: {e}")
486
  if self.force_finish:
487
  return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
488
  else:
489
  return None
490
 
491
  def build_logits_processor(self, messages, llm):
492
+ # Use the tokenizer from the LLM instance.
493
  tokenizer = llm.get_tokenizer()
494
  if self.avoid_repeat and len(messages) > 2:
495
  assistant_messages = []
 
501
  forbidden_ids = [tokenizer.encode(
502
  msg, add_special_tokens=False) for msg in assistant_messages]
503
  return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
504
+ else:
505
+ return None
506
 
507
  def llm_infer(self, messages, temperature=0.1, tools=None,
508
  output_begin_string=None, max_new_tokens=2048,
509
+ max_token=None, skip_special_tokens=True,
510
  model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
511
 
 
512
  if model is None:
513
  model = self.model
514
 
 
516
  sampling_params = SamplingParams(
517
  temperature=temperature,
518
  max_tokens=max_new_tokens,
519
+
520
  seed=seed if seed is not None else self.seed,
521
  )
522
 
 
527
 
528
  if check_token_status and max_token is not None:
529
  token_overflow = False
530
+ num_input_tokens = len(self.tokenizer.encode(
531
+ prompt, return_tensors="pt")[0])
532
+ if max_token is not None:
533
+ if num_input_tokens > max_token:
534
  torch.cuda.empty_cache()
535
  gc.collect()
536
+ print("Number of input tokens before inference:",
537
+ num_input_tokens)
538
+ logger.info(
539
+ "The number of tokens exceeds the maximum limit!!!!")
540
+ token_overflow = True
541
  return None, token_overflow
 
542
  output = model.generate(
543
  prompt,
544
  sampling_params=sampling_params,
545
  )
546
  output = output[0].outputs[0].text
547
+ print("\033[92m" + output + "\033[0m")
 
 
 
 
 
 
 
 
 
 
548
  if check_token_status and max_token is not None:
549
  return output, token_overflow
550
 
 
555
  max_new_tokens: int,
556
  max_token: int) -> str:
557
 
558
+ print("\033[1;32;40mstart self agent\033[0m")
559
  conversation = []
560
  conversation = self.set_system_prompt(conversation, self.self_prompt)
561
  conversation.append({"role": "user", "content": message})
 
569
  max_new_tokens: int,
570
  max_token: int) -> str:
571
 
572
+ print("\033[1;32;40mstart chat agent\033[0m")
573
  conversation = []
574
  conversation = self.set_system_prompt(conversation, self.chat_prompt)
575
  conversation.append({"role": "user", "content": message})
 
584
  max_new_tokens: int,
585
  max_token: int) -> str:
586
 
587
+ print("\033[1;32;40mstart format agent\033[0m")
588
  if '[FinalAnswer]' in answer:
589
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
590
  elif "\n\n" in answer:
 
594
  if len(possible_final_answer) == 1:
595
  choice = possible_final_answer[0]
596
  if choice in ['A', 'B', 'C', 'D', 'E']:
 
597
  return choice
598
  elif len(possible_final_answer) > 1:
599
  if possible_final_answer[1] == ':':
600
  choice = possible_final_answer[0]
601
  if choice in ['A', 'B', 'C', 'D', 'E']:
602
+ print("choice", choice)
603
  return choice
604
 
605
  conversation = []
 
617
  temperature: float,
618
  max_new_tokens: int,
619
  max_token: int) -> str:
620
+ print("\033[1;32;40mSummarized Tool Result:\033[0m")
621
  generate_tool_result_summary_training_prompt = """Thought and function calls:
622
  {thought_calls}
623
  Function calls' responses:
 
638
 
639
  if '[' in output:
640
  output = output.split('[')[0]
 
641
  return output
642
 
643
  def function_result_summary(self, input_list, status, enable_summary):
644
+ """
645
+ Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
646
+ Supports 'length' and 'step' modes, and skips the last 'k' groups.
647
+ Parameters:
648
+ input_list (list): A list of dictionaries containing role and other information.
649
+ summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
650
+ summary_context_length (int): The context length threshold for the 'length' mode.
651
+ last_processed_index (tuple or int): The last processed index.
652
+ Returns:
653
+ list: A list of extracted information from valid sequences.
654
+ """
655
  if 'tool_call_step' not in status:
656
  status['tool_call_step'] = 0
657
 
 
699
  this_thought_calls = None
700
  else:
701
  if len(function_response) != 0:
702
+ print("internal summary")
703
  status['summarized_step'] += 1
704
  result_summary = self.run_summary_agent(
705
  thought_calls=this_thought_calls,
706
  function_response=function_response,
707
  temperature=0.1,
708
  max_new_tokens=1024,
709
+ max_token=99999
710
  )
711
 
712
  input_list.insert(
 
735
  function_response=function_response,
736
  temperature=0.1,
737
  max_new_tokens=1024,
738
+ max_token=99999
739
  )
740
 
741
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
 
746
  last_call_idx+1, {'role': 'tool', 'content': result_summary})
747
  status['summarized_index'] = last_call_idx + 2
748
 
 
749
  return status
750
 
751
+ # Following are Gradio related functions
752
+
753
+ # General update method that accepts any new arguments through kwargs
754
  def update_parameters(self, **kwargs):
 
755
  for key, value in kwargs.items():
756
  if hasattr(self, key):
757
  setattr(self, key, value)
758
 
759
+ # Return the updated attributes
760
  updated_attributes = {key: value for key,
761
  value in kwargs.items() if hasattr(self, key)}
 
762
  return updated_attributes
763
 
764
  def run_gradio_chat(self, message: str,
 
784
  Returns:
785
  str: Final assistant message.
786
  """
787
+ logger.debug(f"[TxAgent] Chat started, message: {message[:100]}...")
788
+ print("\033[1;32;40m[TxAgent] Chat started\033[0m")
789
 
790
+ if not message or len(message.strip()) < 5:
791
+ yield "Please provide a valid message or upload files to analyze."
792
+ return "Invalid input."
 
793
 
794
+ if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
795
+ return ""
796
+
797
+ outputs = []
798
+ outputs_str = ''
799
+ last_outputs = []
800
+
801
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
802
+ call_agent,
803
+ call_agent_level,
804
+ message)
805
+
806
+ conversation = self.initialize_conversation(
807
+ message,
808
+ conversation=conversation,
809
+ history=history)
810
+ history = []
811
+
812
+ next_round = True
813
+ function_call_messages = []
814
+ current_round = 0
815
+ enable_summary = False
816
+ last_status = {}
817
+ token_overflow = False
 
 
 
818
 
819
+ if self.enable_checker:
820
+ checker = ReasoningTraceChecker(
821
+ message, conversation, init_index=len(conversation))
822
 
823
+ try:
824
  while next_round and current_round < max_round:
825
  current_round += 1
826
+ logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
827
 
828
  if last_outputs:
829
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
830
+ last_outputs, return_message=True,
831
  existing_tools_prompt=picked_tools_prompt,
832
  message_for_call_agent=message,
833
  call_agent=call_agent,
834
  call_agent_level=call_agent_level,
835
  temperature=temperature)
836
 
837
+ history.extend(current_gradio_history)
838
 
839
  if special_tool_call == 'Finish' and function_call_messages:
 
 
 
840
  yield history
841
  next_round = False
842
  conversation.extend(function_call_messages)
843
+ return function_call_messages[0]['content']
844
 
845
  elif special_tool_call in ['RequireClarification', 'DirectResponse']:
846
+ last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
847
+ history.append(ChatMessage(role="assistant", content=last_msg.content))
 
848
  yield history
849
  next_round = False
850
+ return last_msg.content
851
 
852
  if (self.enable_summary or token_overflow) and not call_agent:
853
  enable_summary = True
854
 
855
  last_status = self.function_result_summary(
856
+ conversation, status=last_status,
857
+ enable_summary=enable_summary)
858
 
859
  if function_call_messages:
860
  conversation.extend(function_call_messages)
861
+ yield history
862
  else:
863
  next_round = False
864
+ conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
865
+ return ''.join(last_outputs).replace("</s>", "")
866
 
867
  if self.enable_checker:
868
  good_status, wrong_info = checker.check_conversation()
869
  if not good_status:
870
+ print("Checker flagged reasoning error: ", wrong_info)
871
  break
872
 
873
  last_outputs = []
 
877
  tools=picked_tools_prompt,
878
  skip_special_tokens=False,
879
  max_new_tokens=max_new_tokens,
880
+ max_token=max_token,
881
  seed=seed,
882
  check_token_status=True)
883
 
884
+ logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}, token_overflow: {token_overflow}")
 
885
 
886
  if last_outputs_str is None:
887
+ logger.warning("llm_infer returned None due to token overflow")
888
+ if self.force_finish:
889
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
890
+ conversation, temperature, max_new_tokens, max_token)
891
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
892
+ yield history
893
+ return last_outputs_str
894
+ else:
895
+ error_msg = "Token limit exceeded. Please reduce input size or increase max_token."
896
+ history.append(ChatMessage(role="assistant", content=error_msg))
897
+ yield history
898
+ return error_msg
899
 
900
  last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
901
 
902
  for msg in history:
903
+ if msg.metadata is not None:
904
+ msg.metadata['status'] = 'done'
905
 
906
  if '[FinalAnswer]' in last_thought:
907
  parts = last_thought.split('[FinalAnswer]', 1)
908
+ if len(parts) == 2:
909
+ final_thought, final_answer = parts
910
+ else:
911
+ final_thought, final_answer = last_thought, ""
912
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
913
+ yield history
914
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
915
  yield history
 
916
  else:
917
+ history.append(ChatMessage(role="assistant", content=last_thought))
 
918
  yield history
919
 
920
  last_outputs.append(last_outputs_str)
921
 
922
  if next_round:
 
923
  if self.force_finish:
924
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
925
  conversation, temperature, max_new_tokens, max_token)
926
  if '[FinalAnswer]' in last_outputs_str:
927
  parts = last_outputs_str.split('[FinalAnswer]', 1)
928
+ if len(parts) == 2:
929
+ final_thought, final_answer = parts
930
+ else:
931
+ final_thought, final_answer = last_outputs_str, ""
932
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
933
+ yield history
934
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
935
+ yield history
936
  else:
937
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
938
+ yield history
 
939
  else:
940
+ yield "The number of reasoning rounds exceeded the limit."
 
 
 
 
941
 
942
  except Exception as e:
943
+ logger.error(f"Exception in run_gradio_chat: {e}", exc_info=True)
944
  error_msg = f"An error occurred: {e}"
945
+ history.append(ChatMessage(role="assistant", content=error_msg))
 
946
  yield history
947
  if self.force_finish:
948
  last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
949
  conversation, temperature, max_new_tokens, max_token)
950
  if '[FinalAnswer]' in last_outputs_str:
951
  parts = last_outputs_str.split('[FinalAnswer]', 1)
952
+ if len(parts) == 2:
953
+ final_thought, final_answer = parts
954
+ else:
955
+ final_thought, final_answer = last_outputs_str, ""
956
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
957
+ yield history
958
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
959
+ yield history
960
  else:
961
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
962
+ yield history
 
963
  return error_msg
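
For orientation, below is a minimal usage sketch of the class as it stands after this commit. TxAgent, init_model, run_multistep_agent and run_gradio_chat are taken from the file above; the import path, the checkpoint names and the example query are illustrative assumptions, not values from this repository.

# Minimal sketch (not part of the commit); checkpoint names are placeholders.
from txagent.txagent import TxAgent  # assumes the package is importable as `txagent`

agent = TxAgent(
    model_name="path/to/txagent-llm",           # placeholder vLLM checkpoint
    rag_model_name="path/to/toolrag-embedder",  # placeholder ToolRAG embedding model
    tool_files_dict=None,   # None falls back to the default ToolUniverse tool files
    enable_finish=True,
    enable_rag=True)
agent.init_model()  # loads the LLM, the ToolUniverse, and tool-description embeddings

# Non-streaming, multi-round reasoning loop:
answer = agent.run_multistep_agent(
    "Which drugs interact with warfarin?",
    temperature=0.3, max_new_tokens=1024, max_token=99999,
    call_agent=False, call_agent_level=0)
print(answer)

# Streaming variant used by the Gradio UI: run_gradio_chat is a generator that
# yields the growing chat history (ChatMessage entries) and returns the final text.
for history in agent.run_gradio_chat(
        "Which drugs interact with warfarin?", history=[], temperature=0.3,
        max_new_tokens=1024, max_token=99999, call_agent=False,
        call_agent_level=0, conversation=None):
    latest = history[-1] if history else None  # inspect or render each update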