Ali2206 committed
Commit 0ea3469 · verified · 1 Parent(s): 9345354

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +685 -500
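For orientation, a minimal usage sketch of the TxAgent class that this commit modifies, assuming the constructor arguments and methods visible in the diff below; the import path, model identifiers, and the example question are illustrative placeholders rather than values taken from this commit.

    # Hypothetical usage sketch; names below are placeholders.
    from txagent.txagent import TxAgent

    agent = TxAgent(
        model_name="example-org/txagent-llm",            # placeholder
        rag_model_name="example-org/toolrag-embedding",  # placeholder
        tool_files_dict=None,   # None falls back to ToolUniverse's default tool files
        enable_rag=True,
        step_rag_num=10,        # default shown in the updated signature
    )
    agent.init_model()  # per the diff, this calls load_models(); further tool setup is assumed

    answer = agent.run_multistep_agent(
        "Are there interactions between warfarin and ibuprofen?",  # placeholder question
        temperature=0.1,
        max_new_tokens=1024,
        max_token=8192,
    )
    print(answer)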
src/txagent/txagent.py CHANGED
@@ -11,25 +11,19 @@ import types
11
  from tooluniverse import ToolUniverse
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
14
- import torch
15
- import logging
16
- from difflib import SequenceMatcher
17
- import threading
18
-
19
- logger = logging.getLogger(__name__)
20
- logging.basicConfig(level=logging.INFO)
21
 
22
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
23
 
 
24
  class TxAgent:
25
  def __init__(self, model_name,
26
  rag_model_name,
27
- tool_files_dict=None,
28
  enable_finish=True,
29
  enable_rag=True,
30
  enable_summary=False,
31
- init_rag_num=2,
32
- step_rag_num=4,
33
  summary_mode='step',
34
  summary_skip_last_k=0,
35
  summary_context_length=None,
@@ -38,7 +32,8 @@ class TxAgent:
38
  seed=None,
39
  enable_checker=False,
40
  enable_chat=False,
41
- additional_default_tools=None):
 
42
  self.model_name = model_name
43
  self.tokenizer = None
44
  self.terminators = None
@@ -47,9 +42,10 @@ class TxAgent:
47
  self.model = None
48
  self.rag_model = ToolRAGModel(rag_model_name)
49
  self.tooluniverse = None
50
- self.prompt_multi_step = "You are a medical assistant solving clinical oversight issues step-by-step using provided tools."
51
- self.self_prompt = "Follow instructions precisely."
52
- self.chat_prompt = "You are a helpful assistant for clinical queries."
 
53
  self.enable_finish = enable_finish
54
  self.enable_rag = enable_rag
55
  self.enable_summary = enable_summary
@@ -63,7 +59,7 @@ class TxAgent:
63
  self.seed = seed
64
  self.enable_checker = enable_checker
65
  self.additional_default_tools = additional_default_tools
66
- logger.debug("TxAgent initialized with parameters: %s", self.__dict__)
67
 
68
  def init_model(self):
69
  self.load_models()
@@ -72,29 +68,19 @@ class TxAgent:
72
 
73
  def print_self_values(self):
74
  for attr, value in self.__dict__.items():
75
- logger.debug("%s: %s", attr, value)
76
 
77
  def load_models(self, model_name=None):
78
- if model_name is not None and model_name == self.model_name:
79
- return f"The model {model_name} is already loaded."
80
- if model_name:
81
  self.model_name = model_name
82
 
83
- try:
84
- torch.cuda.empty_cache()
85
- self.model = LLM(
86
- model=self.model_name,
87
- dtype="float16",
88
- max_model_len=131072,
89
- enforce_eager=True # Avoid graph compilation issues
90
- )
91
- self.chat_template = Template(self.model.get_tokenizer().chat_template)
92
- self.tokenizer = self.model.get_tokenizer()
93
- logger.info("Model %s loaded successfully", self.model_name)
94
- return f"Model {self.model_name} loaded successfully."
95
- except Exception as e:
96
- logger.error(f"Model loading error: {e}")
97
- raise
98
 
99
  def load_tooluniverse(self):
100
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
@@ -102,225 +88,316 @@ class TxAgent:
102
  special_tools = self.tooluniverse.prepare_tool_prompts(
103
  self.tooluniverse.tool_category_dicts["special_tools"])
104
  self.special_tools_name = [tool['name'] for tool in special_tools]
105
- logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
106
 
107
  def load_tool_desc_embedding(self):
108
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
109
- logger.debug("Tool description embeddings loaded")
110
 
111
  def rag_infer(self, query, top_k=5):
112
  return self.rag_model.rag_infer(query, top_k)
113
 
114
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
115
  picked_tools_prompt = []
116
- if "use external tools" not in message.lower():
117
- picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
118
- else:
119
- picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=call_agent)
120
- if call_agent:
121
- call_agent_level += 1
122
- if call_agent_level >= 2:
123
- call_agent = False
124
- if self.enable_rag:
125
- picked_tools_prompt += self.tool_RAG(message=message, rag_num=self.init_rag_num)
126
  return picked_tools_prompt, call_agent_level
127
 
128
  def initialize_conversation(self, message, conversation=None, history=None):
129
  if conversation is None:
130
  conversation = []
131
 
132
- conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
133
- if history:
134
- conversation.extend(
135
- {"role": h['role'], "content": h['content']}
136
- for h in history if h['role'] in ['user', 'assistant']
137
- )
138
  conversation.append({"role": "user", "content": message})
139
- logger.debug("Conversation initialized with %d messages", len(conversation))
140
  return conversation
141
 
142
- def tool_RAG(self, message=None, picked_tool_names=None,
143
- existing_tools_prompt=None, rag_num=4, return_call_result=False):
144
- extra_factor = 10
 
 
 
145
  if picked_tool_names is None:
146
- picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
147
 
148
- picked_tool_names = [
149
- tool for tool in picked_tool_names
150
- if tool not in self.special_tools_name
151
- ][:rag_num]
152
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
153
- picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
154
- logger.debug("RAG selected %d tools: %s", len(picked_tool_names), picked_tool_names)
155
  if return_call_result:
156
  return picked_tools_prompt, picked_tool_names
157
  return picked_tools_prompt
158
 
159
  def add_special_tools(self, tools, call_agent=False):
160
  if self.enable_finish:
161
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
162
- logger.debug("Finish tool added")
163
- if call_agent and "use external tools" in self.prompt_multi_step.lower():
164
- tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
165
- logger.debug("CallAgent tool added")
166
- elif self.enable_rag and "use external tools" in self.prompt_multi_step.lower():
167
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
168
- logger.debug("Tool_RAG tool added")
169
- if self.additional_default_tools:
170
- for tool_name in self.additional_default_tools:
171
- tool_prompt = self.tooluniverse.get_one_tool_by_one_name(tool_name, return_prompt=True)
172
- if tool_prompt:
173
- tools.append(tool_prompt)
174
- logger.debug("%s tool added", tool_name)
 
 
 
 
 
 
175
  return tools
176
 
177
  def add_finish_tools(self, tools):
178
- tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
179
- logger.debug("Finish tool added")
 
180
  return tools
181
 
182
  def set_system_prompt(self, conversation, sys_prompt):
183
- if not conversation:
184
- conversation.append({"role": "system", "content": sys_prompt})
 
185
  else:
186
  conversation[0] = {"role": "system", "content": sys_prompt}
187
  return conversation
188
 
189
- def run_function_call(self, fcall_str, return_message=False,
190
- existing_tools_prompt=None, message_for_call_agent=None,
191
- call_agent=False, call_agent_level=None, temperature=None):
 
 
 
 
 
192
  function_call_json, message = self.tooluniverse.extract_function_call_json(
193
  fcall_str, return_message=return_message, verbose=False)
194
  call_results = []
195
  special_tool_call = ''
196
- if function_call_json:
197
- for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
198
- logger.debug("Tool Call: %s", func)
199
- if func["name"] == 'Finish':
200
- special_tool_call = 'Finish'
201
- break
202
- elif func["name"] == 'Tool_RAG':
203
- new_tools_prompt, call_result = self.tool_RAG(
204
- message=message, existing_tools_prompt=existing_tools_prompt,
205
- rag_num=self.step_rag_num, return_call_result=True)
206
- existing_tools_prompt += new_tools_prompt
207
- elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
208
- solution_plan = func['arguments']['solution']
209
- full_message = (
210
- message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
211
- )
212
- call_result = self.run_multistep_agent(
213
- full_message, temperature=temperature, max_new_tokens=512,
214
- max_token=2048, call_agent=False, call_agent_level=call_agent_level)
215
- call_result = call_result.split('[FinalAnswer]')[-1].strip() if call_result else "⚠️ No content from sub-agent."
216
- else:
217
- call_result = self.tooluniverse.run_one_function(func)
218
-
219
- call_id = self.tooluniverse.call_id_gen()
220
- func["call_id"] = call_id
221
- logger.debug("Tool Call Result: %s", call_result)
222
- call_results.append({
223
- "role": "tool",
224
- "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
225
- })
226
  else:
227
  call_results.append({
228
  "role": "tool",
229
- "content": json.dumps({"content": "Invalid function call format."})
230
  })
231
 
232
  revised_messages = [{
233
  "role": "assistant",
234
- "content": message.strip() if message else "",
235
  "tool_calls": json.dumps(function_call_json)
236
  }] + call_results
 
 
237
  return revised_messages, existing_tools_prompt, special_tool_call
238
 
239
- def run_function_call_stream(self, fcall_str, return_message=False,
240
- existing_tools_prompt=None, message_for_call_agent=None,
241
- call_agent=False, call_agent_level=None, temperature=None,
242
- return_gradio_history=True):
 
 
 
 
 
243
  function_call_json, message = self.tooluniverse.extract_function_call_json(
244
  fcall_str, return_message=return_message, verbose=False)
245
  call_results = []
246
  special_tool_call = ''
247
- gradio_history = [] if return_gradio_history else None
248
- if function_call_json:
249
- for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
250
- if func["name"] == 'Finish':
251
- special_tool_call = 'Finish'
252
- break
253
- elif func["name"] == 'Tool_RAG':
254
- new_tools_prompt, call_result = self.tool_RAG(
255
- message=message, existing_tools_prompt=existing_tools_prompt,
256
- rag_num=self.step_rag_num, return_call_result=True)
257
- existing_tools_prompt += new_tools_prompt
258
- elif func["name"] == 'DirectResponse':
259
- call_result = func['arguments']['response']
260
- special_tool_call = 'DirectResponse'
261
- elif func["name"] == 'RequireClarification':
262
- call_result = func['arguments']['unclear_question']
263
- special_tool_call = 'RequireClarification'
264
- elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
265
- solution_plan = func['arguments']['solution']
266
- full_message = (
267
- message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
268
- )
269
- sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
270
- call_result = yield from self.run_gradio_chat(
271
- full_message, history=[], temperature=temperature,
272
- max_new_tokens=512, max_token=2048, call_agent=False,
273
- call_agent_level=call_agent_level, conversation=None,
274
- sub_agent_task=sub_agent_task)
275
- call_result = call_result.split('[FinalAnswer]')[-1] if call_result else "⚠️ No content from sub-agent."
276
- else:
277
- call_result = self.tooluniverse.run_one_function(func)
278
-
279
- call_id = self.tooluniverse.call_id_gen()
280
- func["call_id"] = call_id
281
- call_results.append({
282
- "role": "tool",
283
- "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
284
- })
285
- if return_gradio_history and func["name"] != 'Finish':
286
- title = f"{'🧰' if func['name'] == 'Tool_RAG' else '⚒️'} {func['name']}"
287
- gradio_history.append(ChatMessage(
288
- role="assistant", content=str(call_result),
289
- metadata={"title": title, "log": str(func['arguments'])}
290
- ))
291
  else:
292
  call_results.append({
293
  "role": "tool",
294
- "content": json.dumps({"content": "Invalid function call format."})
295
  })
296
 
297
  revised_messages = [{
298
  "role": "assistant",
299
- "content": message.strip() if message else "",
300
  "tool_calls": json.dumps(function_call_json)
301
  }] + call_results
302
- return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
303
 
304
- def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token):
305
- if conversation[-1]['role'] == 'assistant':
 
 
 
 
 
 
306
  conversation.append(
307
- {'role': 'tool', 'content': 'Errors occurred; provide final answer with current info.'})
308
  finish_tools_prompt = self.add_finish_tools([])
309
- output = self.llm_infer(
310
- messages=conversation, temperature=temperature, tools=finish_tools_prompt,
311
- output_begin_string='[FinalAnswer]', max_new_tokens=max_new_tokens, max_token=max_token)
312
- logger.debug("Unfinished reasoning output: %s", output)
313
- return output
314
 
315
- def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
316
- max_token: int, max_round: int = 3, call_agent=False, call_agent_level=0):
317
- logger.debug("Starting multistep agent for message: %s", message[:100])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
319
  call_agent, call_agent_level, message)
320
  conversation = self.initialize_conversation(message)
 
321
  outputs = []
322
  last_outputs = []
323
  next_round = True
 
324
  current_round = 0
325
  token_overflow = False
326
  enable_summary = False
@@ -328,425 +405,533 @@ class TxAgent:
328
 
329
  if self.enable_checker:
330
  checker = ReasoningTraceChecker(message, conversation)
331
-
332
- clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
333
- has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
334
-
335
- while next_round and current_round < max_round:
336
- current_round += 1
337
- if last_outputs:
338
- function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
339
- last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
340
- message_for_call_agent=message, call_agent=call_agent,
341
- call_agent_level=call_agent_level, temperature=temperature)
342
-
343
- if special_tool_call == 'Finish':
344
- next_round = False
345
- conversation.extend(function_call_messages)
346
- content = function_call_messages[0]['content']
347
- return content.split('[FinalAnswer]')[-1] if content else "❌ No content after Finish."
348
-
349
- if (self.enable_summary or token_overflow) and not call_agent:
350
- enable_summary = True
351
- last_status = self.function_result_summary(
352
- conversation, status=last_status, enable_summary=enable_summary)
353
-
354
- if function_call_messages:
355
- conversation.extend(function_call_messages)
356
- outputs.append(tool_result_format(function_call_messages))
357
- else:
358
- next_round = False
359
- return ''.join(last_outputs).replace("</s>", "")
360
-
361
- if self.enable_checker:
362
- good_status, wrong_info = checker.check_conversation()
363
- if not good_status:
364
- logger.warning("Checker error: %s", wrong_info)
365
- break
366
-
367
- tools = [] if has_clinical_data else picked_tools_prompt
368
- last_outputs = []
369
- last_outputs_str, token_overflow = self.llm_infer(
370
- messages=conversation, temperature=temperature, tools=tools,
371
- max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
372
- if last_outputs_str is None:
373
- if self.force_finish:
374
- return self.get_answer_based_on_unfinished_reasoning(
375
- conversation, temperature, max_new_tokens, max_token)
376
- return "❌ Token limit exceeded."
377
- last_outputs.append(last_outputs_str)
378
-
379
- if current_round >= max_round:
380
- logger.warning("Max rounds exceeded")
381
- if self.force_finish:
382
- return self.get_answer_based_on_unfinished_reasoning(
383
- conversation, temperature, max_new_tokens, max_token)
384
- return None
385
-
386
- def build_logits_processor(self, messages, llm):
387
- tokenizer = llm.get_tokenizer()
388
- if self.avoid_repeat and len(messages) > 2:
389
- assistant_messages = [
390
- m['content'] for m in messages[-3:] if m['role'] == 'assistant'
391
- ][:2]
392
- forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
393
- unique_sentences = []
394
- for msg in assistant_messages:
395
- sentences = msg.split('. ')
396
- for s in sentences:
397
- if not s:
398
- continue
399
- is_unique = True
400
- for seen_s in unique_sentences:
401
- if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
402
- is_unique = False
403
- break
404
- if is_unique:
405
- unique_sentences.append(s)
406
- forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
407
- return [NoRepeatSentenceProcessor(forbidden_ids, 15)] # Increased penalty
408
- return None
409
-
410
- def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
411
- max_new_tokens=512, max_token=2048, skip_special_tokens=True,
412
- model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
413
- if model is None:
414
- model = self.model
415
-
416
- logits_processor = self.build_logits_processor(messages, model)
417
- sampling_params = SamplingParams(
418
- temperature=temperature,
419
- max_tokens=max_new_tokens,
420
- seed=seed if seed is not None else self.seed,
421
- logits_processors=logits_processor
422
- )
423
-
424
- prompt = self.chat_template.render(messages=messages, tools=tools, add_generation_prompt=True)
425
- if output_begin_string:
426
- prompt += output_begin_string
427
-
428
- if len(prompt) > 100000: # Early text length check
429
- logger.error(f"Prompt length ({len(prompt)}) exceeds limit (100000).")
430
- return None, True
431
-
432
- if check_token_status and max_token:
433
- num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
434
- if num_input_tokens > max_token:
435
- logger.warning(f"Input tokens ({num_input_tokens}) exceed max_token ({max_token}). Truncating.")
436
- prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)[:max_token]
437
- prompt = self.tokenizer.decode(prompt_tokens)
438
- if num_input_tokens > 131072:
439
- logger.error(f"Input tokens ({num_input_tokens}) exceed model limit (131072).")
440
- return None, True
441
-
442
- try:
443
- torch.cuda.empty_cache()
444
- output = model.generate(prompt, sampling_params=sampling_params)
445
- output = output[0].outputs[0].text
446
- logger.debug("Inference output: %s", output[:100])
447
- except Exception as e:
448
- logger.error(f"Inference error: {e}")
449
- return None, True
450
-
451
- torch.cuda.empty_cache()
452
- gc.collect()
453
- if check_token_status:
454
- return output, False
455
- return output
456
-
457
- def run_quick_summary(self, message: str, temperature: float = 0.1, max_new_tokens: int = 256, max_token: int = 1024):
458
- """Generate a fast, concise summary of potential missed diagnoses without tool calls"""
459
- logger.debug("Starting quick summary for message: %s", message[:100])
460
- if len(message) > 50000:
461
- logger.warning(f"Message length ({len(message)}) exceeds limit (50000). Truncating.")
462
- message = message[:50000]
463
-
464
- prompt = """
465
- Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results. Provide a concise summary in ONE paragraph without headings or bullet points. ALWAYS treat medications or psychiatric evaluations as potential missed diagnoses, specifying their implications (e.g., 'use of Seroquel may indicate untreated psychosis'). Recommend urgent review for identified findings. Do NOT use external tools or repeat non-clinical data (e.g., name, date of birth). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence.
466
- Patient Record Excerpt:
467
- {chunk}
468
- """
469
- conversation = self.set_system_prompt([], prompt.format(chunk=message))
470
- conversation.append({"role": "user", "content": message})
471
- output, token_overflow = self.llm_infer(
472
- messages=conversation,
473
- temperature=temperature,
474
- max_new_tokens=max_new_tokens,
475
- max_token=max_token,
476
- tools=[] # No tools
477
- )
478
- if token_overflow:
479
- logger.error("Token overflow in quick summary")
480
- return "Error: Input too large for quick summary."
481
- if output and '[FinalAnswer]' in output:
482
- output = output.split('[FinalAnswer]')[-1].strip()
483
- logger.debug("Quick summary output: %s", output[:100] if output else "None")
484
- return output or "No missed diagnoses identified"
485
-
486
- def run_background_report(self, message: str, history: list, temperature: float,
487
- max_new_tokens: int, max_token: int, call_agent: bool,
488
- conversation: gr.State, max_round: int, seed: int,
489
- call_agent_level: int, report_path: str):
490
- """Run detailed report generation in the background and save to file"""
491
- logger.debug("Starting background report for message: %s", message[:100])
492
- if len(message) > 50000:
493
- logger.warning(f"Message length ({len(message)}) exceeds limit (50000). Truncating.")
494
- message = message[:50000]
495
-
496
- combined_response = ""
497
- history_copy = history.copy()
498
-
499
- picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
500
- call_agent, call_agent_level, message)
501
- conversation = self.initialize_conversation(message, conversation, history_copy)
502
-
503
- next_round = True
504
- current_round = 0
505
- enable_summary = False
506
- last_status = {}
507
- token_overflow = False
508
-
509
- if self.enable_checker:
510
- checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))
511
-
512
  try:
513
  while next_round and current_round < max_round:
514
  current_round += 1
515
- last_outputs = []
516
- if last_outputs:
517
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
518
- last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
519
- message_for_call_agent=message, call_agent=call_agent,
520
- call_agent_level=call_agent_level, temperature=temperature)
521
-
 
 
 
522
  if special_tool_call == 'Finish':
523
  next_round = False
524
  conversation.extend(function_call_messages)
525
- combined_response += function_call_messages[0]['content'] + "\n"
526
- break
 
 
527
 
528
  if (self.enable_summary or token_overflow) and not call_agent:
 
 
529
  enable_summary = True
530
  last_status = self.function_result_summary(
531
  conversation, status=last_status, enable_summary=enable_summary)
532
 
533
- if function_call_messages:
534
  conversation.extend(function_call_messages)
535
- combined_response += tool_result_format(function_call_messages) + "\n"
 
536
  else:
537
  next_round = False
538
- combined_response += ''.join(last_outputs).replace("</s>", "") + "\n"
539
- break
540
-
541
  if self.enable_checker:
542
  good_status, wrong_info = checker.check_conversation()
543
  if not good_status:
544
- logger.warning("Checker error: %s", wrong_info)
 
 
545
  break
546
 
547
- tools = picked_tools_prompt
548
- last_outputs_str, token_overflow = self.llm_infer(
549
- messages=conversation, temperature=temperature, tools=tools,
550
- max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
 
 
551
 
552
- if last_outputs_str is None:
553
- if self.force_finish:
554
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
555
- conversation, temperature, max_new_tokens, max_token)
556
- combined_response += last_outputs_str + "\n"
 
 
 
 
557
  break
558
- combined_response += "Token limit exceeded.\n"
559
- break
 
 
 
560
 
561
- combined_response += last_outputs_str + "\n"
562
- last_outputs.append(last_outputs_str)
 
 
563
 
564
- if next_round and self.force_finish:
565
- last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
566
- conversation, temperature, max_new_tokens, max_token)
567
- combined_response += last_outputs_str + "\n"
 
 
 
 
 
 
568
 
569
- # Save report
570
- try:
571
- with open(report_path, "w", encoding="utf-8") as f:
572
- f.write(combined_response)
573
- logger.info("Detailed report saved to %s", report_path)
574
- except Exception as e:
575
- logger.error(f"Failed to save report: {e}")
576
 
577
- except Exception as e:
578
- logger.error(f"Background report error: {e}")
579
- combined_response += f"Error: {e}\n"
580
- with open(report_path, "w", encoding="utf-8") as f:
581
- f.write(combined_response)
582
-
583
- finally:
584
- torch.cuda.empty_cache()
585
- gc.collect()
586
-
587
- def run_gradio_chat(self, message: str, history: list, temperature: float,
588
- max_new_tokens: int, max_token: int, call_agent: bool,
589
- conversation: gr.State, max_round: int = 3, seed: int = None,
590
- call_agent_level: int = 0, sub_agent_task: str = None,
591
- uploaded_files: list = None, report_path: str = None):
592
- logger.debug("Chat started, message: %s", message[:100])
593
- if not message or len(message.strip()) < 5:
594
- yield "Please provide a valid message or upload files to analyze."
595
- return
596
-
597
- if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
598
- return
599
-
600
- clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
601
- has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
602
- call_agent = call_agent and not has_clinical_data
603
-
604
- # Generate quick summary
605
- quick_summary = self.run_quick_summary(
606
- message, temperature=temperature, max_new_tokens=256, max_token=1024)
607
- history.append(ChatMessage(role="assistant", content=f"**Quick Summary:**\n{quick_summary}"))
608
- yield history
609
-
610
- # Start background report generation
611
- if report_path:
612
- threading.Thread(
613
- target=self.run_background_report,
614
- args=(message, history, temperature, max_new_tokens, max_token, call_agent,
615
- conversation, max_round, seed, call_agent_level, report_path),
616
- daemon=True
617
- ).start()
618
- history.append(ChatMessage(
619
- role="assistant",
620
- content="Generating detailed report in the background. Download will be available when ready."
621
- ))
622
- yield history
623
-
624
- def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
625
- logger.debug("Starting self agent")
626
- conversation = self.set_system_prompt([], self.self_prompt)
627
  conversation.append({"role": "user", "content": message})
628
- return self.llm_infer(messages=conversation, temperature=temperature,
 
 
629
  max_new_tokens=max_new_tokens, max_token=max_token)
630
 
631
- def run_chat_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
632
- logger.debug("Starting chat agent")
633
- conversation = self.set_system_prompt([], self.chat_prompt)
 
 
 
 
 
634
  conversation.append({"role": "user", "content": message})
635
- return self.llm_infer(messages=conversation, temperature=temperature,
 
 
636
  max_new_tokens=max_new_tokens, max_token=max_token)
637
 
638
- def run_format_agent(self, message: str, answer: str, temperature: float, max_new_tokens: int, max_token: int):
639
- logger.debug("Starting format agent")
 
 
 
 
 
640
  if '[FinalAnswer]' in answer:
641
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
642
  elif "\n\n" in answer:
643
  possible_final_answer = answer.split("\n\n")[-1]
644
  else:
645
  possible_final_answer = answer.strip()
646
-
647
- if len(possible_final_answer) >= 1 and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
648
- return possible_final_answer[0]
649
- elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
650
- return possible_final_answer[0]
651
-
652
- conversation = self.set_system_prompt(
653
- [], "Transform the answer to a single letter: 'A', 'B', 'C', 'D', or 'E'.")
654
- conversation.append({"role": "user", "content": f"Original: {message}\nAnswer: {answer}\nFinal answer (letter):"})
655
- return self.llm_infer(messages=conversation, temperature=temperature,
656
  max_new_tokens=max_new_tokens, max_token=max_token)
657
 
658
- def run_summary_agent(self, thought_calls: str, function_response: str,
659
- temperature: float, max_new_tokens: int, max_token: int):
660
- logger.debug("Starting summary agent")
661
- prompt = f"""Thought and function calls: {thought_calls}
662
- Function responses: {function_response}
663
- Summarize the function responses in one sentence with all necessary information."""
664
- conversation = [{"role": "user", "content": prompt}]
665
- output = self.llm_infer(messages=conversation, temperature=temperature,
666
  max_new_tokens=max_new_tokens, max_token=max_token)
 
667
  if '[' in output:
668
  output = output.split('[')[0]
669
  return output
670
 
671
  def function_result_summary(self, input_list, status, enable_summary):
672
  if 'tool_call_step' not in status:
673
  status['tool_call_step'] = 0
674
- if 'step' not in status:
675
- status['step'] = 0
676
- status['step'] += 1
677
 
678
  for idx in range(len(input_list)):
679
- pos_id = len(input_list) - idx - 1
680
- if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
681
- if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
682
- status['tool_call_step'] += 1
 
683
  break
684
 
 
 
 
 
 
685
  if not enable_summary:
686
  return status
687
 
688
  if 'summarized_index' not in status:
689
  status['summarized_index'] = 0
 
690
  if 'summarized_step' not in status:
691
  status['summarized_step'] = 0
 
692
  if 'previous_length' not in status:
693
  status['previous_length'] = 0
 
694
  if 'history' not in status:
695
  status['history'] = []
696
 
697
- status['history'].append(
698
- self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
699
-
700
- idx = status['summarized_index']
701
  function_response = ''
702
- this_thought_calls = None
 
 
 
 
 
 
703
  while idx < len(input_list):
704
- if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
705
- (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
706
  if input_list[idx]['role'] == 'assistant':
707
  if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
708
  this_thought_calls = None
709
  else:
710
- if function_response:
 
711
  status['summarized_step'] += 1
712
  result_summary = self.run_summary_agent(
713
- thought_calls=this_thought_calls, function_response=function_response,
714
- temperature=0.1, max_new_tokens=512, max_token=2048)
 
 
 
 
 
715
  input_list.insert(
716
- last_call_idx + 1, {'role': 'tool', 'content': result_summary})
717
  status['summarized_index'] = last_call_idx + 2
718
  idx += 1
 
719
  last_call_idx = idx
720
- this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
 
721
  function_response = ''
722
- elif input_list[idx]['role'] == 'tool' and this_thought_calls:
 
723
  function_response += input_list[idx]['content']
724
  del input_list[idx]
725
  idx -= 1
 
726
  else:
727
  break
728
  idx += 1
729
 
730
- if function_response:
731
  status['summarized_step'] += 1
732
  result_summary = self.run_summary_agent(
733
- thought_calls=this_thought_calls, function_response=function_response,
734
- temperature=0.1, max_new_tokens=512, max_token=2048)
 
 
 
 
 
735
  tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
736
  for tool_call in tool_calls:
737
  del tool_call['call_id']
738
  input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
739
  input_list.insert(
740
- last_call_idx + 1, {'role': 'tool', 'content': result_summary})
741
  status['summarized_index'] = last_call_idx + 2
742
 
743
  return status
744
 
 
 
 
745
  def update_parameters(self, **kwargs):
746
- updated_attributes = {}
747
  for key, value in kwargs.items():
748
  if hasattr(self, key):
749
  setattr(self, key, value)
750
- updated_attributes[key] = value
751
- logger.debug("Updated parameters: %s", updated_attributes)
752
- return updated_attributes
11
  from tooluniverse import ToolUniverse
12
  from gradio import ChatMessage
13
  from .toolrag import ToolRAGModel
 
 
 
 
 
 
 
14
 
15
  from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
16
 
17
+
18
  class TxAgent:
19
  def __init__(self, model_name,
20
  rag_model_name,
21
+ tool_files_dict=None, # None leads to the default tool files in ToolUniverse
22
  enable_finish=True,
23
  enable_rag=True,
24
  enable_summary=False,
25
+ init_rag_num=0,
26
+ step_rag_num=10,
27
  summary_mode='step',
28
  summary_skip_last_k=0,
29
  summary_context_length=None,
 
32
  seed=None,
33
  enable_checker=False,
34
  enable_chat=False,
35
+ additional_default_tools=None,
36
+ ):
37
  self.model_name = model_name
38
  self.tokenizer = None
39
  self.terminators = None
 
42
  self.model = None
43
  self.rag_model = ToolRAGModel(rag_model_name)
44
  self.tooluniverse = None
45
+ # self.tool_desc = None
46
+ self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
47
+ self.self_prompt = "Strictly follow the instruction."
48
+ self.chat_prompt = "You are helpful assistant to chat with the user."
49
  self.enable_finish = enable_finish
50
  self.enable_rag = enable_rag
51
  self.enable_summary = enable_summary
 
59
  self.seed = seed
60
  self.enable_checker = enable_checker
61
  self.additional_default_tools = additional_default_tools
62
+ self.print_self_values()
63
 
64
  def init_model(self):
65
  self.load_models()
 
68
 
69
  def print_self_values(self):
70
  for attr, value in self.__dict__.items():
71
+ print(f"{attr}: {value}")
72
 
73
  def load_models(self, model_name=None):
74
+ if model_name is not None:
75
+ if model_name == self.model_name:
76
+ return f"The model {model_name} is already loaded."
77
  self.model_name = model_name
78
 
79
+ self.model = LLM(model=self.model_name)
80
+ self.chat_template = Template(self.model.get_tokenizer().chat_template)
81
+ self.tokenizer = self.model.get_tokenizer()
82
+
83
+ return f"Model {model_name} loaded successfully."
 
 
 
 
 
 
 
 
 
 
84
 
85
  def load_tooluniverse(self):
86
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
 
88
  special_tools = self.tooluniverse.prepare_tool_prompts(
89
  self.tooluniverse.tool_category_dicts["special_tools"])
90
  self.special_tools_name = [tool['name'] for tool in special_tools]
 
91
 
92
  def load_tool_desc_embedding(self):
93
  self.rag_model.load_tool_desc_embedding(self.tooluniverse)
 
94
 
95
  def rag_infer(self, query, top_k=5):
96
  return self.rag_model.rag_infer(query, top_k)
97
 
98
  def initialize_tools_prompt(self, call_agent, call_agent_level, message):
99
  picked_tools_prompt = []
100
+ picked_tools_prompt = self.add_special_tools(
101
+ picked_tools_prompt, call_agent=call_agent)
102
+ if call_agent:
103
+ call_agent_level += 1
104
+ if call_agent_level >= 2:
105
+ call_agent = False
106
+
107
+ if not call_agent:
108
+ picked_tools_prompt += self.tool_RAG(
109
+ message=message, rag_num=self.init_rag_num)
110
  return picked_tools_prompt, call_agent_level
111
 
112
  def initialize_conversation(self, message, conversation=None, history=None):
113
  if conversation is None:
114
  conversation = []
115
 
116
+ conversation = self.set_system_prompt(
117
+ conversation, self.prompt_multi_step)
118
+ if history is not None:
119
+ if len(history) == 0:
120
+ conversation = []
121
+ print("clear conversation successfully")
122
+ else:
123
+ for i in range(len(history)):
124
+ if history[i]['role'] == 'user':
125
+ if i-1 >= 0 and history[i-1]['role'] == 'assistant':
126
+ conversation.append(
127
+ {"role": "assistant", "content": history[i-1]['content']})
128
+ conversation.append(
129
+ {"role": "user", "content": history[i]['content']})
130
+ if i == len(history)-1 and history[i]['role'] == 'assistant':
131
+ conversation.append(
132
+ {"role": "assistant", "content": history[i]['content']})
133
+
134
  conversation.append({"role": "user", "content": message})
135
+
136
  return conversation
137
 
138
+ def tool_RAG(self, message=None,
139
+ picked_tool_names=None,
140
+ existing_tools_prompt=[],
141
+ rag_num=5,
142
+ return_call_result=False):
143
+ extra_factor = 30 # Factor to retrieve more than rag_num
144
  if picked_tool_names is None:
145
+ assert picked_tool_names is not None or message is not None
146
+ picked_tool_names = self.rag_infer(
147
+ message, top_k=rag_num*extra_factor)
148
+
149
+ picked_tool_names_no_special = []
150
+ for tool in picked_tool_names:
151
+ if tool not in self.special_tools_name:
152
+ picked_tool_names_no_special.append(tool)
153
+ picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
154
+ picked_tool_names = picked_tool_names_no_special[:rag_num]
155
 
 
 
 
 
156
  picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
157
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
158
+ picked_tools)
159
  if return_call_result:
160
  return picked_tools_prompt, picked_tool_names
161
  return picked_tools_prompt
162
 
163
  def add_special_tools(self, tools, call_agent=False):
164
  if self.enable_finish:
165
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
166
+ 'Finish', return_prompt=True))
167
+ print("Finish tool is added")
168
+ if call_agent:
169
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
170
+ 'CallAgent', return_prompt=True))
171
+ print("CallAgent tool is added")
172
+ else:
173
+ if self.enable_rag:
174
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
+ 'Tool_RAG', return_prompt=True))
176
+ print("Tool_RAG tool is added")
177
+
178
+ if self.additional_default_tools is not None:
179
+ for each_tool_name in self.additional_default_tools:
180
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
181
+ each_tool_name, return_prompt=True)
182
+ if tool_prompt is not None:
183
+ print(f"{each_tool_name} tool is added")
184
+ tools.append(tool_prompt)
185
  return tools
186
 
187
  def add_finish_tools(self, tools):
188
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
189
+ 'Finish', return_prompt=True))
190
+ print("Finish tool is added")
191
  return tools
192
 
193
  def set_system_prompt(self, conversation, sys_prompt):
194
+ if len(conversation) == 0:
195
+ conversation.append(
196
+ {"role": "system", "content": sys_prompt})
197
  else:
198
  conversation[0] = {"role": "system", "content": sys_prompt}
199
  return conversation
200
 
201
+ def run_function_call(self, fcall_str,
202
+ return_message=False,
203
+ existing_tools_prompt=None,
204
+ message_for_call_agent=None,
205
+ call_agent=False,
206
+ call_agent_level=None,
207
+ temperature=None):
208
+
209
  function_call_json, message = self.tooluniverse.extract_function_call_json(
210
  fcall_str, return_message=return_message, verbose=False)
211
  call_results = []
212
  special_tool_call = ''
213
+ if function_call_json is not None:
214
+ if isinstance(function_call_json, list):
215
+ for i in range(len(function_call_json)):
216
+ print("\033[94mTool Call:\033[0m", function_call_json[i])
217
+ if function_call_json[i]["name"] == 'Finish':
218
+ special_tool_call = 'Finish'
219
+ break
220
+ elif function_call_json[i]["name"] == 'Tool_RAG':
221
+ new_tools_prompt, call_result = self.tool_RAG(
222
+ message=message,
223
+ existing_tools_prompt=existing_tools_prompt,
224
+ rag_num=self.step_rag_num,
225
+ return_call_result=True)
226
+ existing_tools_prompt += new_tools_prompt
227
+ elif function_call_json[i]["name"] == 'CallAgent':
228
+ if call_agent_level < 2 and call_agent:
229
+ solution_plan = function_call_json[i]['arguments']['solution']
230
+ full_message = (
231
+ message_for_call_agent +
232
+ "\nYou must follow the following plan to answer the question: " +
233
+ str(solution_plan)
234
+ )
235
+ call_result = self.run_multistep_agent(
236
+ full_message, temperature=temperature,
237
+ max_new_tokens=1024, max_token=99999,
238
+ call_agent=False, call_agent_level=call_agent_level)
239
+ call_result = call_result.split(
240
+ '[FinalAnswer]')[-1].strip()
241
+ else:
242
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
243
+ else:
244
+ call_result = self.tooluniverse.run_one_function(
245
+ function_call_json[i])
246
+
247
+ call_id = self.tooluniverse.call_id_gen()
248
+ function_call_json[i]["call_id"] = call_id
249
+ print("\033[94mTool Call Result:\033[0m", call_result)
250
+ call_results.append({
251
+ "role": "tool",
252
+ "content": json.dumps({"content": call_result, "call_id": call_id})
253
+ })
254
  else:
255
  call_results.append({
256
  "role": "tool",
257
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
258
  })
259
 
260
  revised_messages = [{
261
  "role": "assistant",
262
+ "content": message.strip(),
263
  "tool_calls": json.dumps(function_call_json)
264
  }] + call_results
265
+
266
+ # Yield the final result.
267
  return revised_messages, existing_tools_prompt, special_tool_call
268
 
269
+ def run_function_call_stream(self, fcall_str,
270
+ return_message=False,
271
+ existing_tools_prompt=None,
272
+ message_for_call_agent=None,
273
+ call_agent=False,
274
+ call_agent_level=None,
275
+ temperature=None,
276
+ return_gradio_history=True):
277
+
278
  function_call_json, message = self.tooluniverse.extract_function_call_json(
279
  fcall_str, return_message=return_message, verbose=False)
280
  call_results = []
281
  special_tool_call = ''
282
+ if return_gradio_history:
283
+ gradio_history = []
284
+ if function_call_json is not None:
285
+ if isinstance(function_call_json, list):
286
+ for i in range(len(function_call_json)):
287
+ if function_call_json[i]["name"] == 'Finish':
288
+ special_tool_call = 'Finish'
289
+ break
290
+ elif function_call_json[i]["name"] == 'Tool_RAG':
291
+ new_tools_prompt, call_result = self.tool_RAG(
292
+ message=message,
293
+ existing_tools_prompt=existing_tools_prompt,
294
+ rag_num=self.step_rag_num,
295
+ return_call_result=True)
296
+ existing_tools_prompt += new_tools_prompt
297
+ elif function_call_json[i]["name"] == 'DirectResponse':
298
+ call_result = function_call_json[i]['arguments']['respose']
299
+ special_tool_call = 'DirectResponse'
300
+ elif function_call_json[i]["name"] == 'RequireClarification':
301
+ call_result = function_call_json[i]['arguments']['unclear_question']
302
+ special_tool_call = 'RequireClarification'
303
+ elif function_call_json[i]["name"] == 'CallAgent':
304
+ if call_agent_level < 2 and call_agent:
305
+ solution_plan = function_call_json[i]['arguments']['solution']
306
+ full_message = (
307
+ message_for_call_agent +
308
+ "\nYou must follow the following plan to answer the question: " +
309
+ str(solution_plan)
310
+ )
311
+ sub_agent_task = "Sub TxAgent plan: " + \
312
+ str(solution_plan)
313
+ # When streaming, yield responses as they arrive.
314
+ call_result = yield from self.run_gradio_chat(
315
+ full_message, history=[], temperature=temperature,
316
+ max_new_tokens=1024, max_token=99999,
317
+ call_agent=False, call_agent_level=call_agent_level,
318
+ conversation=None,
319
+ sub_agent_task=sub_agent_task)
320
+
321
+ call_result = call_result.split(
322
+ '[FinalAnswer]')[-1]
323
+ else:
324
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
325
+ else:
326
+ call_result = self.tooluniverse.run_one_function(
327
+ function_call_json[i])
328
+
329
+ call_id = self.tooluniverse.call_id_gen()
330
+ function_call_json[i]["call_id"] = call_id
331
+ call_results.append({
332
+ "role": "tool",
333
+ "content": json.dumps({"content": call_result, "call_id": call_id})
334
+ })
335
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
336
+ if function_call_json[i]["name"] == 'Tool_RAG':
337
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
338
+ "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
339
+
340
+ else:
341
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
342
+ "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
343
  else:
344
  call_results.append({
345
  "role": "tool",
346
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
347
  })
348
 
349
  revised_messages = [{
350
  "role": "assistant",
351
+ "content": message.strip(),
352
  "tool_calls": json.dumps(function_call_json)
353
  }] + call_results
 
354
 
355
+ # Yield the final result.
356
+ if return_gradio_history:
357
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
358
+ else:
359
+ return revised_messages, existing_tools_prompt, special_tool_call
360
+
361
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
362
+ if conversation[-1]['role'] == 'assisant':
363
  conversation.append(
364
+ {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
365
  finish_tools_prompt = self.add_finish_tools([])
 
 
 
 
 
366
 
367
+ last_outputs_str = self.llm_infer(messages=conversation,
368
+ temperature=temperature,
369
+ tools=finish_tools_prompt,
370
+ output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
371
+ skip_special_tokens=True,
372
+ max_new_tokens=max_new_tokens, max_token=max_token)
373
+ print(last_outputs_str)
374
+ return last_outputs_str
375
+
376
+ def run_multistep_agent(self, message: str,
377
+ temperature: float,
378
+ max_new_tokens: int,
379
+ max_token: int,
380
+ max_round: int = 20,
381
+ call_agent=False,
382
+ call_agent_level=0) -> str:
383
+ """
384
+ Generate a streaming response using the llama3-8b model.
385
+ Args:
386
+ message (str): The input message.
387
+ temperature (float): The temperature for generating the response.
388
+ max_new_tokens (int): The maximum number of new tokens to generate.
389
+ Returns:
390
+ str: The generated response.
391
+ """
392
+ print("\033[1;32;40mstart\033[0m")
393
  picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
394
  call_agent, call_agent_level, message)
395
  conversation = self.initialize_conversation(message)
396
+
397
  outputs = []
398
  last_outputs = []
399
  next_round = True
400
+ function_call_messages = []
401
  current_round = 0
402
  token_overflow = False
403
  enable_summary = False
 
405
 
406
  if self.enable_checker:
407
  checker = ReasoningTraceChecker(message, conversation)
408
  try:
409
  while next_round and current_round < max_round:
410
  current_round += 1
411
+ if len(outputs) > 0:
 
412
  function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
413
+ last_outputs, return_message=True,
414
+ existing_tools_prompt=picked_tools_prompt,
415
+ message_for_call_agent=message,
416
+ call_agent=call_agent,
417
+ call_agent_level=call_agent_level,
418
+ temperature=temperature)
419
+
420
  if special_tool_call == 'Finish':
421
  next_round = False
422
  conversation.extend(function_call_messages)
423
+ if isinstance(function_call_messages[0]['content'], types.GeneratorType):
424
+ function_call_messages[0]['content'] = next(
425
+ function_call_messages[0]['content'])
426
+ return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
427
 
428
  if (self.enable_summary or token_overflow) and not call_agent:
429
+ if token_overflow:
430
+ print("token_overflow, using summary")
431
  enable_summary = True
432
  last_status = self.function_result_summary(
433
  conversation, status=last_status, enable_summary=enable_summary)
434
 
435
+ if function_call_messages is not None:
436
  conversation.extend(function_call_messages)
437
+ outputs.append(tool_result_format(
438
+ function_call_messages))
439
  else:
440
  next_round = False
441
+ conversation.extend(
442
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
443
+ return ''.join(last_outputs).replace("</s>", "")
444
  if self.enable_checker:
445
  good_status, wrong_info = checker.check_conversation()
446
  if not good_status:
447
+ next_round = False
448
+ print(
449
+ "Internal error in reasoning: " + wrong_info)
450
  break
451
+ last_outputs = []
452
+ outputs.append("### TxAgent:\n")
453
+ last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
454
+ temperature=temperature,
455
+ tools=picked_tools_prompt,
456
+ skip_special_tokens=False,
457
+ max_new_tokens=max_new_tokens, max_token=max_token,
458
+ check_token_status=True)
459
+ if last_outputs_str is None:
460
+ next_round = False
461
+ print(
462
+ "The number of tokens exceeds the maximum limit.")
463
+ else:
464
+ last_outputs.append(last_outputs_str)
465
+ if max_round == current_round:
466
+ print("The number of rounds exceeds the maximum limit!")
467
+ if self.force_finish:
468
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
469
+ else:
470
+ return None
471
 
472
+ except Exception as e:
473
+ print(f"Error: {e}")
474
+ if self.force_finish:
475
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
476
+ else:
477
+ return None
478
 
479
+ def build_logits_processor(self, messages, llm):
480
+ # Use the tokenizer from the LLM instance.
481
+ tokenizer = llm.get_tokenizer()
482
+ if self.avoid_repeat and len(messages) > 2:
483
+ assistant_messages = []
484
+ for i in range(1, len(messages) + 1):
485
+ if messages[-i]['role'] == 'assistant':
486
+ assistant_messages.append(messages[-i]['content'])
487
+ if len(assistant_messages) == 2:
488
  break
489
+ forbidden_ids = [tokenizer.encode(
490
+ msg, add_special_tokens=False) for msg in assistant_messages]
491
+ return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
492
+ else:
493
+ return None
494
 
495
+ def llm_infer(self, messages, temperature=0.1, tools=None,
496
+ output_begin_string=None, max_new_tokens=2048,
497
+ max_token=None, skip_special_tokens=True,
498
+ model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
499
 
500
+ if model is None:
501
+ model = self.model
502
+
503
+ logits_processor = self.build_logits_processor(messages, model)
504
+ sampling_params = SamplingParams(
505
+ temperature=temperature,
506
+ max_tokens=max_new_tokens,
507
+ logits_processors=logits_processor,
508
+ seed=seed if seed is not None else self.seed,
509
+ )
510
 
511
+ prompt = self.chat_template.render(
512
+ messages=messages, tools=tools, add_generation_prompt=True)
513
+ if output_begin_string is not None:
514
+ prompt += output_begin_string
 
 
 
515
 
516
+ if check_token_status and max_token is not None:
517
+ token_overflow = False
518
+ num_input_tokens = len(self.tokenizer.encode(
519
+ prompt, return_tensors="pt")[0])
520
+ if max_token is not None:
521
+ if num_input_tokens > max_token:
522
+ torch.cuda.empty_cache()
523
+ gc.collect()
524
+ print("Number of input tokens before inference:",
525
+ num_input_tokens)
526
+ logger.info(
527
+ "The number of tokens exceeds the maximum limit!!!!")
528
+ token_overflow = True
529
+ return None, token_overflow
530
+ output = model.generate(
531
+ prompt,
532
+ sampling_params=sampling_params,
533
+ )
534
+ output = output[0].outputs[0].text
535
+ print("\033[92m" + output + "\033[0m")
536
+ if check_token_status and max_token is not None:
537
+ return output, token_overflow
538
+
539
+ return output
540
+
541
+ def run_self_agent(self, message: str,
542
+ temperature: float,
543
+ max_new_tokens: int,
544
+ max_token: int) -> str:
545
+
546
+ print("\033[1;32;40mstart self agent\033[0m")
547
+ conversation = []
548
+ conversation = self.set_system_prompt(conversation, self.self_prompt)
549
  conversation.append({"role": "user", "content": message})
550
+ return self.llm_infer(messages=conversation,
551
+ temperature=temperature,
552
+ tools=None,
553
  max_new_tokens=max_new_tokens, max_token=max_token)
554
 
555
+ def run_chat_agent(self, message: str,
556
+ temperature: float,
557
+ max_new_tokens: int,
558
+ max_token: int) -> str:
559
+
560
+ print("\033[1;32;40mstart chat agent\033[0m")
561
+ conversation = []
562
+ conversation = self.set_system_prompt(conversation, self.chat_prompt)
563
  conversation.append({"role": "user", "content": message})
564
+ return self.llm_infer(messages=conversation,
565
+ temperature=temperature,
566
+ tools=None,
567
  max_new_tokens=max_new_tokens, max_token=max_token)
568
 
569
+ def run_format_agent(self, message: str,
570
+ answer: str,
571
+ temperature: float,
572
+ max_new_tokens: int,
573
+ max_token: int) -> str:
574
+
575
+ print("\033[1;32;40mstart format agent\033[0m")
576
  if '[FinalAnswer]' in answer:
577
  possible_final_answer = answer.split("[FinalAnswer]")[-1]
578
  elif "\n\n" in answer:
579
  possible_final_answer = answer.split("\n\n")[-1]
580
  else:
581
  possible_final_answer = answer.strip()
582
+ if len(possible_final_answer) == 1:
583
+ choice = possible_final_answer[0]
584
+ if choice in ['A', 'B', 'C', 'D', 'E']:
585
+ return choice
586
+ elif len(possible_final_answer) > 1:
587
+ if possible_final_answer[1] == ':':
588
+ choice = possible_final_answer[0]
589
+ if choice in ['A', 'B', 'C', 'D', 'E']:
590
+ print("choice", choice)
591
+ return choice
592
+
593
+ conversation = []
594
+ format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
595
+ conversation = self.set_system_prompt(conversation, format_prompt)
596
+ conversation.append({"role": "user", "content": message +
597
+ "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
598
+ return self.llm_infer(messages=conversation,
599
+ temperature=temperature,
600
+ tools=None,
601
  max_new_tokens=max_new_tokens, max_token=max_token)

+    def run_summary_agent(self, thought_calls: str,
+                          function_response: str,
+                          temperature: float,
+                          max_new_tokens: int,
+                          max_token: int) -> str:
+        print("\033[1;32;40mSummarized Tool Result:\033[0m")
+        generate_tool_result_summary_training_prompt = """Thought and function calls:
+{thought_calls}
+
+Function calls' responses:
+\"\"\"
+{function_response}
+\"\"\"
+
+Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
+
+Directly respond with the summarized sentence of the function calls' responses only.
+
+Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
+""".format(thought_calls=thought_calls, function_response=function_response)
+        conversation = []
+        conversation.append(
+            {"role": "user", "content": generate_tool_result_summary_training_prompt})
+        output = self.llm_infer(messages=conversation,
+                                temperature=temperature,
+                                tools=None,
                                 max_new_tokens=max_new_tokens, max_token=max_token)
+
         if '[' in output:
             output = output.split('[')[0]
         return output
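
A minimal sketch of summarizing one tool round with the method above (same `agent` assumption; the thought, tool name, and tool output are illustrative):

    thought = ('I should look up known interactions first. '
               '[{"name": "get_drug_interactions", "arguments": {"drug": "warfarin"}}]')
    tool_output = '{"interactions": ["amiodarone", "fluconazole", "aspirin"]}'
    one_liner = agent.run_summary_agent(thought, tool_output,
                                        temperature=0.1,
                                        max_new_tokens=256,
                                        max_token=8192)
    # Expected to be a single sentence; anything after a '[' in the model output
    # is truncated by the method before it is returned.
    print(one_liner)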

     def function_result_summary(self, input_list, status, enable_summary):
+        """
+        Processes the input list, extracting information from sequences of 'user', 'tool', and 'assistant' roles.
+        Supports the 'step' and 'length' summary modes and skips the last `self.summary_skip_last_k` groups.
+
+        Parameters:
+            input_list (list): The conversation as a list of role/content dictionaries.
+            status (dict): Bookkeeping state (step counters, summarized indices, history) carried across calls.
+            enable_summary (bool): Whether tool-result summarization is enabled for this call.
+
+        Returns:
+            dict: The updated status dictionary.
+        """
         if 'tool_call_step' not in status:
             status['tool_call_step'] = 0

         for idx in range(len(input_list)):
+            pos_id = len(input_list) - idx - 1
+            if input_list[pos_id]['role'] == 'assistant':
+                if 'tool_calls' in input_list[pos_id]:
+                    if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
+                        status['tool_call_step'] += 1
                 break

+        if 'step' in status:
+            status['step'] += 1
+        else:
+            status['step'] = 0
+
         if not enable_summary:
             return status

         if 'summarized_index' not in status:
             status['summarized_index'] = 0
+
         if 'summarized_step' not in status:
             status['summarized_step'] = 0
+
         if 'previous_length' not in status:
             status['previous_length'] = 0
+
         if 'history' not in status:
             status['history'] = []

         function_response = ''
+        this_thought_calls = None  # initialized so a leading tool message cannot be referenced before assignment
+        idx = 0
+        current_summarized_index = status['summarized_index']
+
+        status['history'].append(self.summary_mode == 'step' and status['summarized_step']
+                                 < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
+
+        idx = current_summarized_index
         while idx < len(input_list):
+            if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
+
                 if input_list[idx]['role'] == 'assistant':
                     if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
                         this_thought_calls = None
                     else:
+                        if len(function_response) != 0:
+                            print("internal summary")
                             status['summarized_step'] += 1
                             result_summary = self.run_summary_agent(
+                                thought_calls=this_thought_calls,
+                                function_response=function_response,
+                                temperature=0.1,
+                                max_new_tokens=1024,
+                                max_token=99999
+                            )
+
                             input_list.insert(
+                                last_call_idx + 1, {'role': 'tool', 'content': result_summary})
                             status['summarized_index'] = last_call_idx + 2
                             idx += 1
+
                         last_call_idx = idx
+                        this_thought_calls = input_list[idx]['content'] + \
+                            input_list[idx]['tool_calls']
                         function_response = ''
+
+                elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
                     function_response += input_list[idx]['content']
                     del input_list[idx]
                     idx -= 1
+
             else:
                 break
             idx += 1

+        if len(function_response) != 0:
             status['summarized_step'] += 1
             result_summary = self.run_summary_agent(
+                thought_calls=this_thought_calls,
+                function_response=function_response,
+                temperature=0.1,
+                max_new_tokens=1024,
+                max_token=99999
+            )
+
             tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
             for tool_call in tool_calls:
                 del tool_call['call_id']
             input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
             input_list.insert(
+                last_call_idx + 1, {'role': 'tool', 'content': result_summary})
             status['summarized_index'] = last_call_idx + 2

         return status
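
A minimal sketch of how the summarization pass above might be driven across reasoning rounds (same `agent` assumption, with the default `summary_mode='step'`; the conversation content, tool name, and call_id are illustrative):

    status = {}   # bookkeeping dict, populated lazily by function_result_summary
    conversation = [
        {"role": "user", "content": "Check interactions for a patient on warfarin."},
        {"role": "assistant", "content": "I will query the interaction tool.",
         "tool_calls": '[{"name": "get_drug_interactions", '
                       '"arguments": {"drug": "warfarin"}, "call_id": "0"}]'},
        {"role": "tool", "content": '{"interactions": ["amiodarone", "fluconazole"]}'},
    ]
    for _ in range(2):   # one call per agent round
        status = agent.function_result_summary(conversation, status=status,
                                               enable_summary=True)
    # Once the step counter outruns the summarized-step counter, the raw tool message is
    # replaced in place by a one-sentence summary produced by run_summary_agent.
    print(status)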

+    # The following are Gradio-related helper functions.
+
+    # General update method that accepts any new arguments through kwargs.
     def update_parameters(self, **kwargs):
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
+
+        # Return only the attributes that were actually updated.
+        updated_attributes = {key: value for key,
+                              value in kwargs.items() if hasattr(self, key)}
+        return updated_attributes
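
A quick sketch of adjusting runtime settings through the helper above (same `agent` assumption; the values are illustrative):

    changed = agent.update_parameters(step_rag_num=8,
                                      force_finish=False,
                                      not_an_attribute=123)
    # Only keys that match existing attributes are applied and echoed back,
    # e.g. {'step_rag_num': 8, 'force_finish': False} here.
    print(changed)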

+    def run_gradio_chat(self, message: str,
+                        history: list,
+                        temperature: float,
+                        max_new_tokens: int,
+                        max_token: int,
+                        call_agent: bool,
+                        conversation: gr.State,
+                        max_round: int = 20,
+                        seed: int = None,
+                        call_agent_level: int = 0,
+                        sub_agent_task: str = None) -> str:
+        """
+        Generate a streaming response to the user's message, calling tools as needed.
+
+        Args:
+            message (str): The input message.
+            history (list): The conversation history used by ChatInterface.
+            temperature (float): The sampling temperature for generation.
+            max_new_tokens (int): The maximum number of new tokens to generate per step.
+            max_token (int): The token budget used to detect context overflow.
+            call_agent (bool): Whether sub-agent calls are allowed.
+            conversation (gr.State): The internal conversation state.
+
+        Yields:
+            list: The updated Gradio chat history after each reasoning step.
+        """
+        print("\033[1;32;40mstart\033[0m")
+        print("len(message)", len(message))
+        if len(message) <= 10:
+            yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
+            return "Please provide a valid message."
+        outputs = []
+        outputs_str = ''
+        last_outputs = []
+
+        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+            call_agent,
+            call_agent_level,
+            message)
+
+        conversation = self.initialize_conversation(
+            message,
+            conversation=conversation,
+            history=history)
+        history = []
+
+        next_round = True
+        function_call_messages = []
+        current_round = 0
+        enable_summary = False
+        last_status = {}  # for summary
+        token_overflow = False
+        if self.enable_checker:
+            checker = ReasoningTraceChecker(
+                message, conversation, init_index=len(conversation))
+
+        try:
+            while next_round and current_round < max_round:
+                current_round += 1
+                if len(last_outputs) > 0:
+                    function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
+                        last_outputs, return_message=True,
+                        existing_tools_prompt=picked_tools_prompt,
+                        message_for_call_agent=message,
+                        call_agent=call_agent,
+                        call_agent_level=call_agent_level,
+                        temperature=temperature)
+                    history.extend(current_gradio_history)
+                    if special_tool_call == 'Finish':
+                        yield history
+                        next_round = False
+                        conversation.extend(function_call_messages)
+                        return function_call_messages[0]['content']
+                    elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
+                        history.append(
+                            ChatMessage(role="assistant", content=history[-1].content))
+                        yield history
+                        next_round = False
+                        return history[-1].content
+                    if (self.enable_summary or token_overflow) and not call_agent:
+                        if token_overflow:
+                            print("token_overflow, using summary")
+                        enable_summary = True
+                    last_status = self.function_result_summary(
+                        conversation, status=last_status,
+                        enable_summary=enable_summary)
+                    if function_call_messages is not None:
+                        conversation.extend(function_call_messages)
+                        formated_md_function_call_messages = tool_result_format(
+                            function_call_messages)
+                        yield history
+                    else:
+                        next_round = False
+                        conversation.extend(
+                            [{"role": "assistant", "content": ''.join(last_outputs)}])
+                        return ''.join(last_outputs).replace("</s>", "")
+                if self.enable_checker:
+                    good_status, wrong_info = checker.check_conversation()
+                    if not good_status:
+                        next_round = False
+                        print("Internal error in reasoning: " + wrong_info)
+                        break
+                last_outputs = []
+                last_outputs_str, token_overflow = self.llm_infer(
+                    messages=conversation,
+                    temperature=temperature,
+                    tools=picked_tools_prompt,
+                    skip_special_tokens=False,
+                    max_new_tokens=max_new_tokens,
+                    max_token=max_token,
+                    seed=seed,
+                    check_token_status=True)
+                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
+                for each in history:
+                    if each.metadata is not None:
+                        each.metadata['status'] = 'done'
+                if '[FinalAnswer]' in last_thought:
+                    final_thought, final_answer = last_thought.split(
+                        '[FinalAnswer]')
+                    history.append(
+                        ChatMessage(role="assistant",
+                                    content=final_thought.strip())
+                    )
+                    yield history
+                    history.append(
+                        ChatMessage(
+                            role="assistant", content="**Answer**:\n" + final_answer.strip())
+                    )
+                    yield history
+                else:
+                    history.append(ChatMessage(
+                        role="assistant", content=last_thought))
+                    yield history
+
+                last_outputs.append(last_outputs_str)
+
+            if next_round:
+                if self.force_finish:
+                    # Use the freshly forced completion rather than the stale last_thought.
+                    last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
+                        conversation, temperature, max_new_tokens, max_token)
+                    for each in history:
+                        if each.metadata is not None:
+                            each.metadata['status'] = 'done'
+                    if '[FinalAnswer]' in last_outputs_str:
+                        final_thought, final_answer = last_outputs_str.split(
+                            '[FinalAnswer]')
+                        history.append(
+                            ChatMessage(role="assistant",
+                                        content=final_thought.strip())
+                        )
+                        yield history
+                        history.append(
+                            ChatMessage(
+                                role="assistant", content="**Answer**:\n" + final_answer.strip())
+                        )
+                        yield history
+                else:
+                    yield "The number of rounds exceeds the maximum limit!"
+
+        except Exception as e:
+            print(f"Error: {e}")
+            if self.force_finish:
+                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
+                    conversation,
+                    temperature,
+                    max_new_tokens,
+                    max_token)
+                for each in history:
+                    if each.metadata is not None:
+                        each.metadata['status'] = 'done'
+                if '[FinalAnswer]' in last_outputs_str or '"name": "Finish",' in last_outputs_str:
+                    if '[FinalAnswer]' in last_outputs_str:
+                        final_thought, final_answer = last_outputs_str.split(
+                            '[FinalAnswer]')
+                    else:
+                        final_thought, final_answer = '', last_outputs_str
+                    history.append(
+                        ChatMessage(role="assistant",
+                                    content=final_thought.strip())
+                    )
+                    yield history
+                    history.append(
+                        ChatMessage(
+                            role="assistant", content="**Answer**:\n" + final_answer.strip())
+                    )
+                    yield history
+            else:
+                return None
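
A minimal sketch of consuming the streaming generator above outside of a Gradio UI (same `agent` assumption; a plain list stands in for the `gr.State` conversation object, and the question and limits are illustrative):

    conversation_state = []   # plain list in place of gr.State for a script-level check
    question = ("A 68-year-old on warfarin was newly prescribed fluconazole. "
                "What clinical oversight issues should be flagged?")
    stream = agent.run_gradio_chat(question,
                                   history=[],
                                   temperature=0.3,
                                   max_new_tokens=1024,
                                   max_token=8192,
                                   call_agent=False,
                                   conversation=conversation_state,
                                   max_round=20)
    final_history = None
    for chunk in stream:      # each yield is the updated chat history (or a status string)
        final_history = chunk
    if isinstance(final_history, list) and final_history:
        print(final_history[-1].content)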