Ali2206 committed
Commit b6e9667 · verified · 1 Parent(s): 13df505

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +470 -204
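For orientation, a minimal usage sketch of the class this commit modifies, based on the constructor signature and defaults visible on the new side of the diff below. The import path, model identifiers, and argument values are illustrative assumptions, not part of the commit.

# Hypothetical sketch only: instantiate TxAgent as defined after this commit.
# The model IDs and import path below are assumed placeholders; substitute your own.
from txagent.txagent import TxAgent

agent = TxAgent(
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",        # assumed model id
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",  # assumed model id
    tool_files_dict=None,   # None falls back to ToolUniverse's default tool files (per the diff)
    enable_finish=True,     # default on the new side of the diff
    enable_rag=True,        # default on the new side of the diff
    step_rag_num=10,        # default on the new side of the diff
)
agent.init_model()          # loads the LLM, the ToolUniverse, and tool-description embeddings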
src/txagent/txagent.py CHANGED
@@ -12,22 +12,23 @@ from tooluniverse import ToolUniverse
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
  import torch
 
15
  import logging
16
-
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.INFO)
19
 
20
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
 
22
  class TxAgent:
23
  def __init__(self, model_name,
24
  rag_model_name,
25
- tool_files_dict=None,
26
- enable_finish=False, # MODIFIED: Default to False
27
- enable_rag=False,
28
  enable_summary=False,
29
  init_rag_num=0,
30
- step_rag_num=0,
31
  summary_mode='step',
32
  summary_skip_last_k=0,
33
  summary_context_length=None,
@@ -44,9 +45,10 @@ class TxAgent:
44
  self.rag_model_name = rag_model_name
45
  self.tool_files_dict = tool_files_dict
46
  self.model = None
47
- self.rag_model = ToolRAGModel(rag_model_name) if enable_rag else None
48
  self.tooluniverse = None
49
- self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning."
 
50
  self.self_prompt = "Strictly follow the instruction."
51
  self.chat_prompt = "You are helpful assistant to chat with the user."
52
  self.enable_finish = enable_finish
@@ -66,28 +68,26 @@ class TxAgent:
66
 
67
  def init_model(self):
68
  self.load_models()
69
- if self.enable_rag:
70
- self.load_tooluniverse()
71
- self.load_tool_desc_embedding()
72
 
73
  def print_self_values(self):
74
  for attr, value in self.__dict__.items():
75
- logger.info(f"{attr}: {value}")
76
 
77
  def load_models(self, model_name=None):
78
  if model_name is not None:
79
  if model_name == self.model_name:
80
  return f"The model {model_name} is already loaded."
81
  self.model_name = model_name
82
- self.model = LLM(model=self.model_name, enforce_eager=True, max_model_len=4096) # MODIFIED: Reduce KV cache
 
83
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
84
  self.tokenizer = self.model.get_tokenizer()
 
85
  return f"Model {model_name} loaded successfully."
86
 
87
  def load_tooluniverse(self):
88
- if self.tool_files_dict is None and not self.enable_rag:
89
- logger.info("Skipping tool universe loading: RAG disabled and no tool files.")
90
- return
91
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
92
  self.tooluniverse.load_tools()
93
  special_tools = self.tooluniverse.prepare_tool_prompts(
@@ -95,24 +95,20 @@ class TxAgent:
95
  self.special_tools_name = [tool['name'] for tool in special_tools]
96
 
97
  def load_tool_desc_embedding(self):
98
- if self.rag_model and self.tooluniverse:
99
- self.rag_model.load_tool_desc_embedding(self.tooluniverse)
100
 
101
  def rag_infer(self, query, top_k=5):
102
- if not self.enable_rag or not self.rag_model:
103
- return []
104
  return self.rag_model.rag_infer(query, top_k)
105
 
106
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
107
  picked_tools_prompt = []
108
- if not self.enable_rag:
109
- return picked_tools_prompt, call_agent_level
110
  picked_tools_prompt = self.add_special_tools(
111
  picked_tools_prompt, call_agent=call_agent)
112
  if call_agent:
113
  call_agent_level += 1
114
  if call_agent_level >= 2:
115
  call_agent = False
 
116
  if not call_agent:
117
  picked_tools_prompt += self.tool_RAG(
118
  message=message, rag_num=self.init_rag_num)
@@ -121,12 +117,13 @@ class TxAgent:
121
  def initialize_conversation(self, message, conversation=None, history=None):
122
  if conversation is None:
123
  conversation = []
 
124
  conversation = self.set_system_prompt(
125
  conversation, self.prompt_multi_step)
126
  if history is not None:
127
  if len(history) == 0:
128
  conversation = []
129
- logger.info("clear conversation successfully")
130
  else:
131
  for i in range(len(history)):
132
  if history[i]['role'] == 'user':
@@ -138,7 +135,9 @@ class TxAgent:
138
  if i == len(history)-1 and history[i]['role'] == 'assistant':
139
  conversation.append(
140
  {"role": "assistant", "content": history[i]['content']})
 
141
  conversation.append({"role": "user", "content": message})
 
142
  return conversation
143
 
144
  def tool_RAG(self, message=None,
@@ -146,52 +145,60 @@ class TxAgent:
146
  existing_tools_prompt=[],
147
  rag_num=5,
148
  return_call_result=False):
149
- if not self.enable_rag:
150
- return [] if not return_call_result else ([], [])
151
- extra_factor = 30
152
  if picked_tool_names is None:
153
  assert picked_tool_names is not None or message is not None
154
  picked_tool_names = self.rag_infer(
155
  message, top_k=rag_num*extra_factor)
156
- picked_tool_names_no_special = [tool for tool in picked_tool_names if tool not in self.special_tools_name]
157
  picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
158
  picked_tool_names = picked_tool_names_no_special[:rag_num]
 
159
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
160
- picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
 
161
  if return_call_result:
162
  return picked_tools_prompt, picked_tool_names
163
  return picked_tools_prompt
164
 
165
  def add_special_tools(self, tools, call_agent=False):
166
- if not self.enable_rag and not self.enable_finish:
167
- return tools
168
- if self.enable_finish and self.tooluniverse:
169
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
170
- logger.info("Finish tool is added")
171
- if call_agent and self.tooluniverse:
172
- tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
173
- logger.info("CallAgent tool is added")
174
- elif self.enable_rag and self.tooluniverse:
175
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
176
- logger.info("Tool_RAG tool is added")
177
- if self.additional_default_tools is not None and self.tooluniverse:
178
- for each_tool_name in self.additional_default_tools:
179
- tool_prompt = self.tooluniverse.get_one_tool_by_one_name(each_tool_name, return_prompt=True)
180
- if tool_prompt is not None:
181
- logger.info(f"{each_tool_name} tool is added")
182
- tools.append(tool_prompt)
183
  return tools
184
 
185
  def add_finish_tools(self, tools):
186
- if not self.enable_finish or not self.tooluniverse:
187
- return tools
188
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
189
- logger.info("Finish tool is added")
190
  return tools
191
 
192
  def set_system_prompt(self, conversation, sys_prompt):
193
  if len(conversation) == 0:
194
- conversation.append({"role": "system", "content": sys_prompt})
 
195
  else:
196
  conversation[0] = {"role": "system", "content": sys_prompt}
197
  return conversation
@@ -203,15 +210,15 @@ class TxAgent:
203
  call_agent=False,
204
  call_agent_level=None,
205
  temperature=None):
206
- if not self.enable_rag:
207
- return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], ''
208
- function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
209
  call_results = []
210
  special_tool_call = ''
211
  if function_call_json is not None:
212
  if isinstance(function_call_json, list):
213
  for i in range(len(function_call_json)):
214
- logger.info(f"Tool Call: {function_call_json[i]}")
215
  if function_call_json[i]["name"] == 'Finish':
216
  special_tool_call = 'Finish'
217
  break
@@ -239,12 +246,14 @@ class TxAgent:
239
  else:
240
  call_result = call_result.split('[FinalAnswer]')[-1].strip()
241
  else:
242
- call_result = "Error: The CallAgent has been disabled."
243
  else:
244
- call_result = self.tooluniverse.run_one_function(function_call_json[i])
 
 
245
  call_id = self.tooluniverse.call_id_gen()
246
  function_call_json[i]["call_id"] = call_id
247
- logger.info(f"Tool Call Result: {call_result}")
248
  call_results.append({
249
  "role": "tool",
250
  "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
@@ -252,30 +261,33 @@ class TxAgent:
252
  else:
253
  call_results.append({
254
  "role": "tool",
255
- "content": json.dumps({"content": "Not a valid function call."})
256
  })
 
257
  revised_messages = [{
258
  "role": "assistant",
259
  "content": message.strip(),
260
  "tool_calls": json.dumps(function_call_json)
261
  }] + call_results
 
 
262
  return revised_messages, existing_tools_prompt, special_tool_call
263
 
264
  def run_function_call_stream(self, fcall_str,
265
- return_message=False,
266
- existing_tools_prompt=None,
267
- message_for_call_agent=None,
268
- call_agent=False,
269
- call_agent_level=None,
270
- temperature=None,
271
- return_gradio_history=True):
272
- if not self.enable_rag:
273
- gradio_history = [] if return_gradio_history else None
274
- return [{"role": "assistant", "content": fcall_str.strip()}], existing_tools_prompt or [], '', gradio_history
275
- function_call_json, message = self.tooluniverse.extract_function_call_json(fcall_str, return_message=return_message, verbose=False)
276
  call_results = []
277
  special_tool_call = ''
278
- gradio_history = [] if return_gradio_history else None
 
279
  if function_call_json is not None:
280
  if isinstance(function_call_json, list):
281
  for i in range(len(function_call_json)):
@@ -303,21 +315,25 @@ class TxAgent:
303
  "\nYou must follow the following plan to answer the question: " +
304
  str(solution_plan)
305
  )
306
- sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
 
307
  call_result = yield from self.run_gradio_chat(
308
  full_message, history=[], temperature=temperature,
309
  max_new_tokens=1024, max_token=99999,
310
  call_agent=False, call_agent_level=call_agent_level,
311
  conversation=None,
312
  sub_agent_task=sub_agent_task)
 
313
  if call_result is not None and isinstance(call_result, str):
314
  call_result = call_result.split('[FinalAnswer]')[-1]
315
  else:
316
  call_result = "⚠️ No content returned from sub-agent."
317
  else:
318
- call_result = "Error: The CallAgent has been disabled."
319
  else:
320
- call_result = self.tooluniverse.run_one_function(function_call_json[i])
 
 
321
  call_id = self.tooluniverse.call_id_gen()
322
  function_call_json[i]["call_id"] = call_id
323
  call_results.append({
@@ -334,25 +350,34 @@ class TxAgent:
334
  else:
335
  call_results.append({
336
  "role": "tool",
337
- "content": json.dumps({"content": "Not a valid function call."})
338
  })
 
339
  revised_messages = [{
340
  "role": "assistant",
341
  "content": message.strip(),
342
  "tool_calls": json.dumps(function_call_json)
343
  }] + call_results
344
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
345
 
346
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
347
- if conversation[-1]['role'] == 'assistant':
348
- conversation.append({'role': 'tool', 'content': 'Errors occurred, provide final answer.'})
 
 
 
349
  last_outputs_str = self.llm_infer(messages=conversation,
350
- temperature=temperature,
351
- tools=[],
352
- output_begin_string='[FinalAnswer]',
353
- skip_special_tokens=True,
354
- max_new_tokens=max_new_tokens, max_token=max_token)
355
- logger.info(f"Unfinished reasoning output: {last_outputs_str[:100]}...")
356
  return last_outputs_str
357
 
358
  def run_multistep_agent(self, message: str,
@@ -362,10 +387,20 @@ class TxAgent:
362
  max_round: int = 20,
363
  call_agent=False,
364
  call_agent_level=0) -> str:
365
- logger.debug("Starting multistep agent")
366
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
367
  call_agent, call_agent_level, message)
368
  conversation = self.initialize_conversation(message)
 
369
  outputs = []
370
  last_outputs = []
371
  next_round = True
@@ -374,6 +409,7 @@ class TxAgent:
374
  token_overflow = False
375
  enable_summary = False
376
  last_status = {}
 
377
  if self.enable_checker:
378
  checker = ReasoningTraceChecker(message, conversation)
379
  try:
@@ -387,60 +423,73 @@ class TxAgent:
387
  call_agent=call_agent,
388
  call_agent_level=call_agent_level,
389
  temperature=temperature)
 
390
  if special_tool_call == 'Finish':
391
  next_round = False
392
  conversation.extend(function_call_messages)
 
 
 
393
  content = function_call_messages[0]['content']
394
  if content is None:
395
- return "❌ No content returned after Finish."
396
  return content.split('[FinalAnswer]')[-1]
 
397
  if (self.enable_summary or token_overflow) and not call_agent:
 
 
398
  enable_summary = True
399
  last_status = self.function_result_summary(
400
  conversation, status=last_status, enable_summary=enable_summary)
401
- if function_call_messages:
 
402
  conversation.extend(function_call_messages)
403
- outputs.append(tool_result_format(function_call_messages))
 
404
  else:
405
  next_round = False
406
- conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
 
407
  return ''.join(last_outputs).replace("</s>", "")
408
  if self.enable_checker:
409
  good_status, wrong_info = checker.check_conversation()
410
  if not good_status:
411
- logger.warning(f"Internal error in reasoning: {wrong_info}")
 
 
412
  break
413
  last_outputs = []
414
  outputs.append("### TxAgent:\n")
415
- last_outputs_str, token_overflow = self.llm_infer(
416
- messages=conversation,
417
- temperature=temperature,
418
- tools=picked_tools_prompt,
419
- skip_special_tokens=False,
420
- max_new_tokens=max_new_tokens,
421
- max_token=max_token,
422
- check_token_status=True)
423
  if last_outputs_str is None:
424
- logger.warning("Token overflow detected")
425
  if self.force_finish:
426
- return self.get_answer_based_on_unfinished_reasoning(
427
- conversation, temperature, max_new_tokens, max_token)
428
- return "❌ Token limit exceeded."
429
- last_outputs.append(last_outputs_str)
 
430
  if max_round == current_round:
431
- logger.warning("Max rounds exceeded")
432
  if self.force_finish:
433
- return self.get_answer_based_on_unfinished_reasoning(
434
- conversation, temperature, max_new_tokens, max_token)
435
- return None
 
436
  except Exception as e:
437
- logger.error(f"Error in multistep agent: {e}")
438
  if self.force_finish:
439
- return self.get_answer_based_on_unfinished_reasoning(
440
- conversation, temperature, max_new_tokens, max_token)
441
- return None
442
 
443
  def build_logits_processor(self, messages, llm):
 
444
  tokenizer = llm.get_tokenizer()
445
  if self.avoid_repeat and len(messages) > 2:
446
  assistant_messages = []
@@ -449,49 +498,66 @@ class TxAgent:
449
  assistant_messages.append(messages[-i]['content'])
450
  if len(assistant_messages) == 2:
451
  break
452
- forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
453
- return [NoRepeatSentenceProcessor(forbidden_ids, 3)]
454
- return None
 
 
455
 
456
  def llm_infer(self, messages, temperature=0.1, tools=None,
457
  output_begin_string=None, max_new_tokens=2048,
458
  max_token=None, skip_special_tokens=True,
459
- model=None, tokenizer=None, terminators=None, seed=None,
460
- check_token_status=False):
461
  if model is None:
462
  model = self.model
 
463
  logits_processor = self.build_logits_processor(messages, model)
464
  sampling_params = SamplingParams(
465
  temperature=temperature,
466
  max_tokens=max_new_tokens,
 
467
  seed=seed if seed is not None else self.seed,
468
  )
 
469
  prompt = self.chat_template.render(
470
  messages=messages, tools=tools, add_generation_prompt=True)
471
  if output_begin_string is not None:
472
  prompt += output_begin_string
 
473
  if check_token_status and max_token is not None:
474
  token_overflow = False
475
- num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
476
- if num_input_tokens > max_token:
477
- torch.cuda.empty_cache()
478
- gc.collect()
479
- logger.warning(f"Input tokens: {num_input_tokens}, exceeds max: {max_token}")
480
- token_overflow = True
481
- return None, token_overflow
482
- output = model.generate(prompt, sampling_params=sampling_params)
483
  output = output[0].outputs[0].text
484
- logger.debug(f"Inference output: {output[:100]}...")
485
  if check_token_status and max_token is not None:
486
  return output, token_overflow
 
487
  return output
488
 
489
  def run_self_agent(self, message: str,
490
- temperature: float,
491
- max_new_tokens: int,
492
- max_token: int) -> str:
493
- logger.debug("Starting self agent")
494
- conversation = self.set_system_prompt([], self.self_prompt)
 
 
495
  conversation.append({"role": "user", "content": message})
496
  return self.llm_infer(messages=conversation,
497
  temperature=temperature,
@@ -502,8 +568,10 @@ class TxAgent:
502
  temperature: float,
503
  max_new_tokens: int,
504
  max_token: int) -> str:
505
- logger.debug("Starting chat agent")
506
- conversation = self.set_system_prompt([], self.chat_prompt)
 
 
507
  conversation.append({"role": "user", "content": message})
508
  return self.llm_infer(messages=conversation,
509
  temperature=temperature,
@@ -511,106 +579,155 @@ class TxAgent:
511
  max_new_tokens=max_new_tokens, max_token=max_token)
512
 
513
  def run_format_agent(self, message: str,
514
- answer: str,
515
- temperature: float,
516
- max_new_tokens: int,
517
- max_token: int) -> str:
518
- logger.debug("Starting format agent")
 
519
  if '[FinalAnswer]' in answer:
520
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
521
  elif "\n\n" in answer:
522
  possible_final_answer = answer.split("\n\n")[-1]
523
  else:
524
  possible_final_answer = answer.strip()
525
- if len(possible_final_answer) == 1 and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
526
- return possible_final_answer[0]
527
- elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
528
- return possible_final_answer[0]
529
  conversation = []
530
- format_prompt = "You are helpful assistant to transform the answer to 'A', 'B', 'C', 'D'."
531
  conversation = self.set_system_prompt(conversation, format_prompt)
532
- conversation.append({"role": "user", "content": message + "\nAgent answer: " + answer + "\nAnswer (must be a letter):"})
 
533
  return self.llm_infer(messages=conversation,
534
  temperature=temperature,
535
  tools=None,
536
  max_new_tokens=max_new_tokens, max_token=max_token)
537
 
538
  def run_summary_agent(self, thought_calls: str,
539
- function_response: str,
540
- temperature: float,
541
- max_new_tokens: int,
542
- max_token: int) -> str:
543
- logger.debug("Starting summary agent")
544
  generate_tool_result_summary_training_prompt = """Thought and function calls:
545
  {thought_calls}
546
  Function calls' responses:
547
  \"\"\"
548
  {function_response}
549
  \"\"\"
550
- Generate one summarized sentence about "function calls' responses" with necessary information:
551
- """.format(thought_calls=thought_calls, function_response=function_response)
552
- conversation = [{"role": "user", "content": generate_tool_result_summary_training_prompt}]
553
  output = self.llm_infer(messages=conversation,
554
  temperature=temperature,
555
  tools=None,
556
  max_new_tokens=max_new_tokens, max_token=max_token)
 
557
  if '[' in output:
558
  output = output.split('[')[0]
559
  return output
560
 
561
  def function_result_summary(self, input_list, status, enable_summary):
562
  if 'tool_call_step' not in status:
563
  status['tool_call_step'] = 0
 
564
  for idx in range(len(input_list)):
565
  pos_id = len(input_list)-idx-1
566
- if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
567
- if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
568
- status['tool_call_step'] += 1
 
569
  break
570
- status['step'] = status.get('step', 0) + 1
571
  if not enable_summary:
572
  return status
 
573
  if 'summarized_index' not in status:
574
  status['summarized_index'] = 0
 
575
  if 'summarized_step' not in status:
576
  status['summarized_step'] = 0
 
577
  if 'previous_length' not in status:
578
  status['previous_length'] = 0
 
579
  if 'history' not in status:
580
  status['history'] = []
 
581
  function_response = ''
582
- idx = status['summarized_index']
583
  current_summarized_index = status['summarized_index']
 
584
  status['history'].append(self.summary_mode == 'step' and status['summarized_step']
585
  < status['step']-status['tool_call_step']-self.summary_skip_last_k)
 
 
586
  while idx < len(input_list):
587
  if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
 
588
  if input_list[idx]['role'] == 'assistant':
589
  if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
590
  this_thought_calls = None
591
  else:
592
  if len(function_response) != 0:
593
- logger.debug("Internal summary")
594
  status['summarized_step'] += 1
595
  result_summary = self.run_summary_agent(
596
  thought_calls=this_thought_calls,
597
  function_response=function_response,
598
  temperature=0.1,
599
  max_new_tokens=1024,
600
- max_token=99999)
601
- input_list.insert(last_call_idx+1, {'role': 'tool', 'content': result_summary})
 
 
 
602
  status['summarized_index'] = last_call_idx + 2
603
  idx += 1
 
604
  last_call_idx = idx
605
- this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
 
606
  function_response = ''
 
607
  elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
608
  function_response += input_list[idx]['content']
609
  del input_list[idx]
610
  idx -= 1
 
611
  else:
612
  break
613
  idx += 1
 
614
  if len(function_response) != 0:
615
  status['summarized_step'] += 1
616
  result_summary = self.run_summary_agent(
@@ -618,20 +735,30 @@ Generate one summarized sentence about "function calls' responses" with necessar
618
  function_response=function_response,
619
  temperature=0.1,
620
  max_new_tokens=1024,
621
- max_token=99999)
 
 
622
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
623
  for tool_call in tool_calls:
624
  del tool_call['call_id']
625
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
626
- input_list.insert(last_call_idx+1, {'role': 'tool', 'content': result_summary})
 
627
  status['summarized_index'] = last_call_idx + 2
 
628
  return status
629
 
630
  def update_parameters(self, **kwargs):
631
  for key, value in kwargs.items():
632
  if hasattr(self, key):
633
  setattr(self, key, value)
634
- updated_attributes = {key: value for key, value in kwargs.items() if hasattr(self, key)}
 
 
 
635
  return updated_attributes
636
 
637
  def run_gradio_chat(self, message: str,
@@ -645,53 +772,192 @@ Generate one summarized sentence about "function calls' responses" with necessar
645
  seed: int = None,
646
  call_agent_level: int = 0,
647
  sub_agent_task: str = None,
648
- uploaded_files: list = None):
649
  logger.debug(f"[TxAgent] Chat started, message: {message[:100]}...")
650
  print("\033[1;32;40m[TxAgent] Chat started\033[0m")
 
651
  if not message or len(message.strip()) < 5:
652
  yield "Please provide a valid message or upload files to analyze."
653
  return "Invalid input."
 
654
  if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
655
  return ""
656
- conversation = self.initialize_conversation(message, conversation, history=[])
657
- sampling_params = SamplingParams(
658
- temperature=temperature,
659
- max_tokens=max_new_tokens,
660
- seed=seed if seed is not None else self.seed,
661
- )
662
- prompt = self.chat_template.render(messages=conversation, tools=[], add_generation_prompt=True)
663
- output = self.model.generate([prompt], sampling_params)[0].outputs[0].text # MODIFIED: Direct inference
664
- cleaned = clean_response(output) # MODIFIED: Use clean_response
665
- if '[FinalAnswer]' in cleaned:
666
- parts = cleaned.split('[FinalAnswer]', 1)
667
- final_answer = parts[1] if len(parts) > 1 else cleaned
668
- history.append(ChatMessage(role="assistant", content=final_answer.strip()))
669
- else:
670
- history.append(ChatMessage(role="assistant", content=cleaned.strip()))
671
- yield history
672
- return cleaned
673
-
674
- def clean_response(text: str) -> str: # MODIFIED: Add clean_response for compatibility
675
- text = sanitize_utf8(text)
676
- text = re.sub(r"\[TOOL_CALLS\].*?\n|\[.*?\].*?\n|(?:get_|tool\s|retrieve\s|use\s|rag\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
677
- text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
678
- text = re.sub(
679
- r"(?i)(to\s|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
680
- r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
681
- r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
682
- r"conflict|assessment|follow-up|issue|reasoning|step|prompt|address|rag|thought|try|john\sdoe|nkma).*?\n",
683
- "", text, flags=re.DOTALL
684
- )
685
- text = re.sub(r"\n{2,}", "\n", text).strip()
686
- lines = []
687
- valid_heading = False
688
- for line in text.split("\n"):
689
- line = line.strip()
690
- if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
691
- valid_heading = True
692
- lines.append(f"**{line[:-1]}**:")
693
- elif valid_heading and line.startswith("-"):
694
- lines.append(line)
695
- else:
696
- valid_heading = False
697
- return "\n".join(lines).strip()
 
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
  import torch
15
+ # near the top of txagent.py
16
  import logging
 
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.INFO)
19
 
20
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
 
22
+
23
  class TxAgent:
24
  def __init__(self, model_name,
25
  rag_model_name,
26
+ tool_files_dict=None, # None leads to the default tool files in ToolUniverse
27
+ enable_finish=True,
28
+ enable_rag=True,
29
  enable_summary=False,
30
  init_rag_num=0,
31
+ step_rag_num=10,
32
  summary_mode='step',
33
  summary_skip_last_k=0,
34
  summary_context_length=None,
 
45
  self.rag_model_name = rag_model_name
46
  self.tool_files_dict = tool_files_dict
47
  self.model = None
48
+ self.rag_model = ToolRAGModel(rag_model_name)
49
  self.tooluniverse = None
50
+ # self.tool_desc = None
51
+ self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
52
  self.self_prompt = "Strictly follow the instruction."
53
  self.chat_prompt = "You are helpful assistant to chat with the user."
54
  self.enable_finish = enable_finish
 
68
 
69
  def init_model(self):
70
  self.load_models()
71
+ self.load_tooluniverse()
72
+ self.load_tool_desc_embedding()
 
73
 
74
  def print_self_values(self):
75
  for attr, value in self.__dict__.items():
76
+ print(f"{attr}: {value}")
77
 
78
  def load_models(self, model_name=None):
79
  if model_name is not None:
80
  if model_name == self.model_name:
81
  return f"The model {model_name} is already loaded."
82
  self.model_name = model_name
83
+
84
+ self.model = LLM(model=self.model_name)
85
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
86
  self.tokenizer = self.model.get_tokenizer()
87
+
88
  return f"Model {model_name} loaded successfully."
89
 
90
  def load_tooluniverse(self):
 
 
 
91
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
92
  self.tooluniverse.load_tools()
93
  special_tools = self.tooluniverse.prepare_tool_prompts(
 
95
  self.special_tools_name = [tool['name'] for tool in special_tools]
96
 
97
  def load_tool_desc_embedding(self):
98
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
 
99
 
100
  def rag_infer(self, query, top_k=5):
 
 
101
  return self.rag_model.rag_infer(query, top_k)
102
 
103
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
104
  picked_tools_prompt = []
 
 
105
  picked_tools_prompt = self.add_special_tools(
106
  picked_tools_prompt, call_agent=call_agent)
107
  if call_agent:
108
  call_agent_level += 1
109
  if call_agent_level >= 2:
110
  call_agent = False
111
+
112
  if not call_agent:
113
  picked_tools_prompt += self.tool_RAG(
114
  message=message, rag_num=self.init_rag_num)
 
117
  def initialize_conversation(self, message, conversation=None, history=None):
118
  if conversation is None:
119
  conversation = []
120
+
121
  conversation = self.set_system_prompt(
122
  conversation, self.prompt_multi_step)
123
  if history is not None:
124
  if len(history) == 0:
125
  conversation = []
126
+ print("clear conversation successfully")
127
  else:
128
  for i in range(len(history)):
129
  if history[i]['role'] == 'user':
 
135
  if i == len(history)-1 and history[i]['role'] == 'assistant':
136
  conversation.append(
137
  {"role": "assistant", "content": history[i]['content']})
138
+
139
  conversation.append({"role": "user", "content": message})
140
+
141
  return conversation
142
 
143
  def tool_RAG(self, message=None,
 
145
  existing_tools_prompt=[],
146
  rag_num=5,
147
  return_call_result=False):
148
+ extra_factor = 30 # Factor to retrieve more than rag_num
 
 
149
  if picked_tool_names is None:
150
  assert picked_tool_names is not None or message is not None
151
  picked_tool_names = self.rag_infer(
152
  message, top_k=rag_num*extra_factor)
153
+
154
+ picked_tool_names_no_special = []
155
+ for tool in picked_tool_names:
156
+ if tool not in self.special_tools_name:
157
+ picked_tool_names_no_special.append(tool)
158
  picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
159
  picked_tool_names = picked_tool_names_no_special[:rag_num]
160
+
161
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
162
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
163
+ picked_tools)
164
  if return_call_result:
165
  return picked_tools_prompt, picked_tool_names
166
  return picked_tools_prompt
167
 
168
  def add_special_tools(self, tools, call_agent=False):
169
+ if self.enable_finish:
170
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
171
+ 'Finish', return_prompt=True))
172
+ print("Finish tool is added")
173
+ if call_agent:
174
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
+ 'CallAgent', return_prompt=True))
176
+ print("CallAgent tool is added")
177
+ else:
178
+ if self.enable_rag:
179
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
180
+ 'Tool_RAG', return_prompt=True))
181
+ print("Tool_RAG tool is added")
182
+
183
+ if self.additional_default_tools is not None:
184
+ for each_tool_name in self.additional_default_tools:
185
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
186
+ each_tool_name, return_prompt=True)
187
+ if tool_prompt is not None:
188
+ print(f"{each_tool_name} tool is added")
189
+ tools.append(tool_prompt)
190
  return tools
191
 
192
  def add_finish_tools(self, tools):
193
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
194
+ 'Finish', return_prompt=True))
195
+ print("Finish tool is added")
 
196
  return tools
197
 
198
  def set_system_prompt(self, conversation, sys_prompt):
199
  if len(conversation) == 0:
200
+ conversation.append(
201
+ {"role": "system", "content": sys_prompt})
202
  else:
203
  conversation[0] = {"role": "system", "content": sys_prompt}
204
  return conversation
 
210
  call_agent=False,
211
  call_agent_level=None,
212
  temperature=None):
213
+
214
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
215
+ fcall_str, return_message=return_message, verbose=False)
216
  call_results = []
217
  special_tool_call = ''
218
  if function_call_json is not None:
219
  if isinstance(function_call_json, list):
220
  for i in range(len(function_call_json)):
221
+ print("\033[94mTool Call:\033[0m", function_call_json[i])
222
  if function_call_json[i]["name"] == 'Finish':
223
  special_tool_call = 'Finish'
224
  break
 
246
  else:
247
  call_result = call_result.split('[FinalAnswer]')[-1].strip()
248
  else:
249
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
250
  else:
251
+ call_result = self.tooluniverse.run_one_function(
252
+ function_call_json[i])
253
+
254
  call_id = self.tooluniverse.call_id_gen()
255
  function_call_json[i]["call_id"] = call_id
256
+ print("\033[94mTool Call Result:\033[0m", call_result)
257
  call_results.append({
258
  "role": "tool",
259
  "content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id})
 
261
  else:
262
  call_results.append({
263
  "role": "tool",
264
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
265
  })
266
+
267
  revised_messages = [{
268
  "role": "assistant",
269
  "content": message.strip(),
270
  "tool_calls": json.dumps(function_call_json)
271
  }] + call_results
272
+
273
+ # Yield the final result.
274
  return revised_messages, existing_tools_prompt, special_tool_call
275
 
276
  def run_function_call_stream(self, fcall_str,
277
+ return_message=False,
278
+ existing_tools_prompt=None,
279
+ message_for_call_agent=None,
280
+ call_agent=False,
281
+ call_agent_level=None,
282
+ temperature=None,
283
+ return_gradio_history=True):
284
+
285
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
286
+ fcall_str, return_message=return_message, verbose=False)
 
287
  call_results = []
288
  special_tool_call = ''
289
+ if return_gradio_history:
290
+ gradio_history = []
291
  if function_call_json is not None:
292
  if isinstance(function_call_json, list):
293
  for i in range(len(function_call_json)):
 
315
  "\nYou must follow the following plan to answer the question: " +
316
  str(solution_plan)
317
  )
318
+ sub_agent_task = "Sub TxAgent plan: " + \
319
+ str(solution_plan)
320
  call_result = yield from self.run_gradio_chat(
321
  full_message, history=[], temperature=temperature,
322
  max_new_tokens=1024, max_token=99999,
323
  call_agent=False, call_agent_level=call_agent_level,
324
  conversation=None,
325
  sub_agent_task=sub_agent_task)
326
+
327
  if call_result is not None and isinstance(call_result, str):
328
  call_result = call_result.split('[FinalAnswer]')[-1]
329
  else:
330
  call_result = "⚠️ No content returned from sub-agent."
331
  else:
332
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
333
  else:
334
+ call_result = self.tooluniverse.run_one_function(
335
+ function_call_json[i])
336
+
337
  call_id = self.tooluniverse.call_id_gen()
338
  function_call_json[i]["call_id"] = call_id
339
  call_results.append({
 
350
  else:
351
  call_results.append({
352
  "role": "tool",
353
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
354
  })
355
+
356
  revised_messages = [{
357
  "role": "assistant",
358
  "content": message.strip(),
359
  "tool_calls": json.dumps(function_call_json)
360
  }] + call_results
361
+
362
+ if return_gradio_history:
363
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
364
+ else:
365
+ return revised_messages, existing_tools_prompt, special_tool_call
366
+
367
 
368
  def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
369
+ if conversation[-1]['role'] == 'assisant':
370
+ conversation.append(
371
+ {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
372
+ finish_tools_prompt = self.add_finish_tools([])
373
+
374
  last_outputs_str = self.llm_infer(messages=conversation,
375
+ temperature=temperature,
376
+ tools=finish_tools_prompt,
377
+ output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
378
+ skip_special_tokens=True,
379
+ max_new_tokens=max_new_tokens, max_token=max_token)
380
+ print(last_outputs_str)
381
  return last_outputs_str
382
 
383
  def run_multistep_agent(self, message: str,
 
387
  max_round: int = 20,
388
  call_agent=False,
389
  call_agent_level=0) -> str:
390
+ """
391
+ Generate a streaming response using the llama3-8b model.
392
+ Args:
393
+ message (str): The input message.
394
+ temperature (float): The temperature for generating the response.
395
+ max_new_tokens (int): The maximum number of new tokens to generate.
396
+ Returns:
397
+ str: The generated response.
398
+ """
399
+ print("\033[1;32;40mstart\033[0m")
400
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
401
  call_agent, call_agent_level, message)
402
  conversation = self.initialize_conversation(message)
403
+
404
  outputs = []
405
  last_outputs = []
406
  next_round = True
 
409
  token_overflow = False
410
  enable_summary = False
411
  last_status = {}
412
+
413
  if self.enable_checker:
414
  checker = ReasoningTraceChecker(message, conversation)
415
  try:
 
423
  call_agent=call_agent,
424
  call_agent_level=call_agent_level,
425
  temperature=temperature)
426
+
427
  if special_tool_call == 'Finish':
428
  next_round = False
429
  conversation.extend(function_call_messages)
430
+ if isinstance(function_call_messages[0]['content'], types.GeneratorType):
431
+ function_call_messages[0]['content'] = next(
432
+ function_call_messages[0]['content'])
433
  content = function_call_messages[0]['content']
434
  if content is None:
435
+ return "❌ No content returned after Finish tool call."
436
  return content.split('[FinalAnswer]')[-1]
437
+
438
  if (self.enable_summary or token_overflow) and not call_agent:
439
+ if token_overflow:
440
+ print("token_overflow, using summary")
441
  enable_summary = True
442
  last_status = self.function_result_summary(
443
  conversation, status=last_status, enable_summary=enable_summary)
444
+
445
+ if function_call_messages is not None:
446
  conversation.extend(function_call_messages)
447
+ outputs.append(tool_result_format(
448
+ function_call_messages))
449
  else:
450
  next_round = False
451
+ conversation.extend(
452
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
453
  return ''.join(last_outputs).replace("</s>", "")
454
  if self.enable_checker:
455
  good_status, wrong_info = checker.check_conversation()
456
  if not good_status:
457
+ next_round = False
458
+ print(
459
+ "Internal error in reasoning: " + wrong_info)
460
  break
461
  last_outputs = []
462
  outputs.append("### TxAgent:\n")
463
+ last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
464
+ temperature=temperature,
465
+ tools=picked_tools_prompt,
466
+ skip_special_tokens=False,
467
+ max_new_tokens=max_new_tokens, max_token=max_token,
468
+ check_token_status=True)
 
 
469
  if last_outputs_str is None:
470
+ print("The number of tokens exceeds the maximum limit.")
471
  if self.force_finish:
472
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
473
+ else:
474
+ return "❌ Token limit exceeded — no further steps possible."
475
+ else:
476
+ last_outputs.append(last_outputs_str)
477
  if max_round == current_round:
478
+ print("The number of rounds exceeds the maximum limit!")
479
  if self.force_finish:
480
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
481
+ else:
482
+ return None
483
+
484
  except Exception as e:
485
+ print(f"Error: {e}")
486
  if self.force_finish:
487
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
488
+ else:
489
+ return None
490
 
491
  def build_logits_processor(self, messages, llm):
492
+ # Use the tokenizer from the LLM instance.
493
  tokenizer = llm.get_tokenizer()
494
  if self.avoid_repeat and len(messages) > 2:
495
  assistant_messages = []
 
498
  assistant_messages.append(messages[-i]['content'])
499
  if len(assistant_messages) == 2:
500
  break
501
+ forbidden_ids = [tokenizer.encode(
502
+ msg, add_special_tokens=False) for msg in assistant_messages]
503
+ return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
504
+ else:
505
+ return None
506
 
507
  def llm_infer(self, messages, temperature=0.1, tools=None,
508
  output_begin_string=None, max_new_tokens=2048,
509
  max_token=None, skip_special_tokens=True,
510
+ model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
511
+
512
  if model is None:
513
  model = self.model
514
+
515
  logits_processor = self.build_logits_processor(messages, model)
516
  sampling_params = SamplingParams(
517
  temperature=temperature,
518
  max_tokens=max_new_tokens,
519
+
520
  seed=seed if seed is not None else self.seed,
521
  )
522
+
523
  prompt = self.chat_template.render(
524
  messages=messages, tools=tools, add_generation_prompt=True)
525
  if output_begin_string is not None:
526
  prompt += output_begin_string
527
+
528
  if check_token_status and max_token is not None:
529
  token_overflow = False
530
+ num_input_tokens = len(self.tokenizer.encode(
531
+ prompt, return_tensors="pt")[0])
532
+ if max_token is not None:
533
+ if num_input_tokens > max_token:
534
+ torch.cuda.empty_cache()
535
+ gc.collect()
536
+ print("Number of input tokens before inference:",
537
+ num_input_tokens)
538
+ logger.info(
539
+ "The number of tokens exceeds the maximum limit!!!!")
540
+ token_overflow = True
541
+ return None, token_overflow
542
+ output = model.generate(
543
+ prompt,
544
+ sampling_params=sampling_params,
545
+ )
546
  output = output[0].outputs[0].text
547
+ print("\033[92m" + output + "\033[0m")
548
  if check_token_status and max_token is not None:
549
  return output, token_overflow
550
+
551
  return output
552
 
553
  def run_self_agent(self, message: str,
554
+ temperature: float,
555
+ max_new_tokens: int,
556
+ max_token: int) -> str:
557
+
558
+ print("\033[1;32;40mstart self agent\033[0m")
559
+ conversation = []
560
+ conversation = self.set_system_prompt(conversation, self.self_prompt)
561
  conversation.append({"role": "user", "content": message})
562
  return self.llm_infer(messages=conversation,
563
  temperature=temperature,
 
568
  temperature: float,
569
  max_new_tokens: int,
570
  max_token: int) -> str:
571
+
572
+ print("\033[1;32;40mstart chat agent\033[0m")
573
+ conversation = []
574
+ conversation = self.set_system_prompt(conversation, self.chat_prompt)
575
  conversation.append({"role": "user", "content": message})
576
  return self.llm_infer(messages=conversation,
577
  temperature=temperature,
 
579
  max_new_tokens=max_new_tokens, max_token=max_token)
580
 
581
  def run_format_agent(self, message: str,
582
+ answer: str,
583
+ temperature: float,
584
+ max_new_tokens: int,
585
+ max_token: int) -> str:
586
+
587
+ print("\033[1;32;40mstart format agent\033[0m")
588
  if '[FinalAnswer]' in answer:
589
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
590
  elif "\n\n" in answer:
591
  possible_final_answer = answer.split("\n\n")[-1]
592
  else:
593
  possible_final_answer = answer.strip()
594
+ if len(possible_final_answer) == 1:
595
+ choice = possible_final_answer[0]
596
+ if choice in ['A', 'B', 'C', 'D', 'E']:
597
+ return choice
598
+ elif len(possible_final_answer) > 1:
599
+ if possible_final_answer[1] == ':':
600
+ choice = possible_final_answer[0]
601
+ if choice in ['A', 'B', 'C', 'D', 'E']:
602
+ print("choice", choice)
603
+ return choice
604
+
605
  conversation = []
606
+ format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
607
  conversation = self.set_system_prompt(conversation, format_prompt)
608
+ conversation.append({"role": "user", "content": message +
609
+ "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
610
  return self.llm_infer(messages=conversation,
611
  temperature=temperature,
612
  tools=None,
613
  max_new_tokens=max_new_tokens, max_token=max_token)
614
 
615
  def run_summary_agent(self, thought_calls: str,
616
+ function_response: str,
617
+ temperature: float,
618
+ max_new_tokens: int,
619
+ max_token: int) -> str:
620
+ print("\033[1;32;40mSummarized Tool Result:\033[0m")
621
  generate_tool_result_summary_training_prompt = """Thought and function calls:
622
  {thought_calls}
623
  Function calls' responses:
624
  \"\"\"
625
  {function_response}
626
  \"\"\"
627
+ Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
628
+ Directly respond with the summarized sentence of the function calls' responses only.
629
+ Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
630
+ """.format(thought_calls=thought_calls, function_response=function_response)
631
+ conversation = []
632
+ conversation.append(
633
+ {"role": "user", "content": generate_tool_result_summary_training_prompt})
634
  output = self.llm_infer(messages=conversation,
635
  temperature=temperature,
636
  tools=None,
637
  max_new_tokens=max_new_tokens, max_token=max_token)
638
+
639
  if '[' in output:
640
  output = output.split('[')[0]
641
  return output
642
 
643
  def function_result_summary(self, input_list, status, enable_summary):
644
+ """
645
+ Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
646
+ Supports 'length' and 'step' modes, and skips the last 'k' groups.
647
+ Parameters:
648
+ input_list (list): A list of dictionaries containing role and other information.
649
+ summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
650
+ summary_context_length (int): The context length threshold for the 'length' mode.
651
+ last_processed_index (tuple or int): The last processed index.
652
+ Returns:
653
+ list: A list of extracted information from valid sequences.
654
+ """
655
  if 'tool_call_step' not in status:
656
  status['tool_call_step'] = 0
657
+
658
  for idx in range(len(input_list)):
659
  pos_id = len(input_list)-idx-1
660
+ if input_list[pos_id]['role'] == 'assistant':
661
+ if 'tool_calls' in input_list[pos_id]:
662
+ if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
663
+ status['tool_call_step'] += 1
664
  break
665
+
666
+ if 'step' in status:
667
+ status['step'] += 1
668
+ else:
669
+ status['step'] = 0
670
+
671
  if not enable_summary:
672
  return status
673
+
674
  if 'summarized_index' not in status:
675
  status['summarized_index'] = 0
676
+
677
  if 'summarized_step' not in status:
678
  status['summarized_step'] = 0
679
+
680
  if 'previous_length' not in status:
681
  status['previous_length'] = 0
682
+
683
  if 'history' not in status:
684
  status['history'] = []
685
+
686
  function_response = ''
687
+ idx = 0
688
  current_summarized_index = status['summarized_index']
689
+
690
  status['history'].append(self.summary_mode == 'step' and status['summarized_step']
691
  < status['step']-status['tool_call_step']-self.summary_skip_last_k)
692
+
693
+ idx = current_summarized_index
694
  while idx < len(input_list):
695
  if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
696
+
697
  if input_list[idx]['role'] == 'assistant':
698
  if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
699
  this_thought_calls = None
700
  else:
701
  if len(function_response) != 0:
702
+ print("internal summary")
703
  status['summarized_step'] += 1
704
  result_summary = self.run_summary_agent(
705
  thought_calls=this_thought_calls,
706
  function_response=function_response,
707
  temperature=0.1,
708
  max_new_tokens=1024,
709
+ max_token=99999
710
+ )
711
+
712
+ input_list.insert(
713
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
714
  status['summarized_index'] = last_call_idx + 2
715
  idx += 1
716
+
717
  last_call_idx = idx
718
+ this_thought_calls = input_list[idx]['content'] + \
719
+ input_list[idx]['tool_calls']
720
  function_response = ''
721
+
722
  elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
723
  function_response += input_list[idx]['content']
724
  del input_list[idx]
725
  idx -= 1
726
+
727
  else:
728
  break
729
  idx += 1
730
+
731
  if len(function_response) != 0:
732
  status['summarized_step'] += 1
733
  result_summary = self.run_summary_agent(
 
735
  function_response=function_response,
736
  temperature=0.1,
737
  max_new_tokens=1024,
738
+ max_token=99999
739
+ )
740
+
741
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
742
  for tool_call in tool_calls:
743
  del tool_call['call_id']
744
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
745
+ input_list.insert(
746
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
747
  status['summarized_index'] = last_call_idx + 2
748
+
749
  return status
750
 
751
+ # Following are Gradio related functions
752
+
753
+ # General update method that accepts any new arguments through kwargs
754
  def update_parameters(self, **kwargs):
755
  for key, value in kwargs.items():
756
  if hasattr(self, key):
757
  setattr(self, key, value)
758
+
759
+ # Return the updated attributes
760
+ updated_attributes = {key: value for key,
761
+ value in kwargs.items() if hasattr(self, key)}
762
  return updated_attributes
763
 
764
  def run_gradio_chat(self, message: str,
 
772
  seed: int = None,
773
  call_agent_level: int = 0,
774
  sub_agent_task: str = None,
775
+ uploaded_files: list = None) -> str:
776
+ """
777
+ Generate a streaming response using the loaded model.
778
+ Args:
779
+ message (str): The input message (with file content if uploaded).
780
+ history (list): The conversation history used by ChatInterface.
781
+ temperature (float): Sampling temperature.
782
+ max_new_tokens (int): Max new tokens.
783
+ max_token (int): Max total tokens allowed.
784
+ Returns:
785
+ str: Final assistant message.
786
+ """
787
  logger.debug(f"[TxAgent] Chat started, message: {message[:100]}...")
788
  print("\033[1;32;40m[TxAgent] Chat started\033[0m")
789
+
790
  if not message or len(message.strip()) < 5:
791
  yield "Please provide a valid message or upload files to analyze."
792
  return "Invalid input."
793
+
794
  if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
795
  return ""
796
+
797
+ outputs = []
798
+ outputs_str = ''
799
+ last_outputs = []
800
+
801
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
802
+ call_agent,
803
+ call_agent_level,
804
+ message)
805
+
806
+ conversation = self.initialize_conversation(
807
+ message,
808
+ conversation=conversation,
809
+ history=history)
810
+ history = []
811
+
812
+ next_round = True
813
+ function_call_messages = []
814
+ current_round = 0
815
+ enable_summary = False
816
+ last_status = {}
817
+ token_overflow = False
818
+
819
+ if self.enable_checker:
820
+ checker = ReasoningTraceChecker(
821
+ message, conversation, init_index=len(conversation))
822
+
823
+ try:
824
+ while next_round and current_round < max_round:
825
+ current_round += 1
826
+ logger.debug(f"Round {current_round}, conversation length: {len(conversation)}")
827
+
828
+ if last_outputs:
829
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
830
+ last_outputs, return_message=True,
831
+ existing_tools_prompt=picked_tools_prompt,
832
+ message_for_call_agent=message,
833
+ call_agent=call_agent,
834
+ call_agent_level=call_agent_level,
835
+ temperature=temperature)
836
+
837
+ history.extend(current_gradio_history)
838
+
839
+ if special_tool_call == 'Finish' and function_call_messages:
840
+ yield history
841
+ next_round = False
842
+ conversation.extend(function_call_messages)
843
+ return function_call_messages[0]['content']
844
+
845
+ elif special_tool_call in ['RequireClarification', 'DirectResponse']:
846
+ last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
847
+ history.append(ChatMessage(role="assistant", content=last_msg.content))
848
+ yield history
849
+ next_round = False
850
+ return last_msg.content
851
+
852
+ if (self.enable_summary or token_overflow) and not call_agent:
853
+ enable_summary = True
854
+
855
+ last_status = self.function_result_summary(
856
+ conversation, status=last_status,
857
+ enable_summary=enable_summary)
858
+
859
+ if function_call_messages:
860
+ conversation.extend(function_call_messages)
861
+ yield history
862
+ else:
863
+ next_round = False
864
+ conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
865
+ return ''.join(last_outputs).replace("</s>", "")
866
+
867
+ if self.enable_checker:
868
+ good_status, wrong_info = checker.check_conversation()
869
+ if not good_status:
870
+ print("Checker flagged reasoning error: ", wrong_info)
871
+ break
872
+
873
+ last_outputs = []
874
+ last_outputs_str, token_overflow = self.llm_infer(
875
+ messages=conversation,
876
+ temperature=temperature,
877
+ tools=picked_tools_prompt,
878
+ skip_special_tokens=False,
879
+ max_new_tokens=max_new_tokens,
880
+ max_token=max_token,
881
+ seed=seed,
882
+ check_token_status=True)
883
+
884
+ logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}, token_overflow: {token_overflow}")
885
+
886
+ if last_outputs_str is None:
887
+ logger.warning("llm_infer returned None due to token overflow")
888
+ if self.force_finish:
889
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
890
+ conversation, temperature, max_new_tokens, max_token)
891
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
892
+ yield history
893
+ return last_outputs_str
894
+ else:
895
+ error_msg = "Token limit exceeded. Please reduce input size or increase max_token."
896
+ history.append(ChatMessage(role="assistant", content=error_msg))
897
+ yield history
898
+ return error_msg
899
+
900
+ last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
901
+
902
+ for msg in history:
903
+ if msg.metadata is not None:
904
+ msg.metadata['status'] = 'done'
905
+
906
+ if '[FinalAnswer]' in last_thought:
907
+ parts = last_thought.split('[FinalAnswer]', 1)
908
+ if len(parts) == 2:
909
+ final_thought, final_answer = parts
910
+ else:
911
+ final_thought, final_answer = last_thought, ""
912
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
913
+ yield history
914
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
915
+ yield history
916
+ else:
917
+ history.append(ChatMessage(role="assistant", content=last_thought))
918
+ yield history
919
+
920
+ last_outputs.append(last_outputs_str)
921
+
922
+ if next_round:
923
+ if self.force_finish:
924
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
925
+ conversation, temperature, max_new_tokens, max_token)
926
+ if '[FinalAnswer]' in last_outputs_str:
927
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
928
+ if len(parts) == 2:
929
+ final_thought, final_answer = parts
930
+ else:
931
+ final_thought, final_answer = last_outputs_str, ""
932
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
933
+ yield history
934
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
935
+ yield history
936
+ else:
937
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
938
+ yield history
939
+ else:
940
+ yield "The number of reasoning rounds exceeded the limit."
941
+
942
+ except Exception as e:
943
+ logger.error(f"Exception in run_gradio_chat: {e}", exc_info=True)
944
+ error_msg = f"An error occurred: {e}"
945
+ history.append(ChatMessage(role="assistant", content=error_msg))
946
+ yield history
947
+ if self.force_finish:
948
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
949
+ conversation, temperature, max_new_tokens, max_token)
950
+ if '[FinalAnswer]' in last_outputs_str:
951
+ parts = last_outputs_str.split('[FinalAnswer]', 1)
952
+ if len(parts) == 2:
953
+ final_thought, final_answer = parts
954
+ else:
955
+ final_thought, final_answer = last_outputs_str, ""
956
+ history.append(ChatMessage(role="assistant", content=final_thought.strip()))
957
+ yield history
958
+ history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
959
+ yield history
960
+ else:
961
+ history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
962
+ yield history
963
+ return error_msg
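One of the changes on the new side is the token-budget guard in llm_infer: the prompt is tokenized before generation and, if it exceeds max_token, GPU memory is released and (None, True) is returned so callers with force_finish enabled can fall back to get_answer_based_on_unfinished_reasoning. Below is a minimal standalone sketch of that guard, assuming a Hugging Face tokenizer such as the one returned by vLLM's LLM.get_tokenizer(); the helper name is hypothetical.

import gc
import torch

def prompt_within_budget(tokenizer, prompt: str, max_token: int):
    # Count the prompt tokens before generation, as llm_infer does
    # (the diff uses encode(..., return_tensors="pt"); plain encode() gives the same count).
    num_input_tokens = len(tokenizer.encode(prompt))
    if num_input_tokens > max_token:
        # Mirror the cleanup the new code performs before aborting the call.
        torch.cuda.empty_cache()
        gc.collect()
        return False, num_input_tokens
    return True, num_input_tokens

# Hypothetical use, following llm_infer's (output, token_overflow) contract:
#   ok, n = prompt_within_budget(llm.get_tokenizer(), prompt, max_token=90000)
#   if not ok:
#       return None, True   # caller may then force a final answer via force_finish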