Ali2206 committed
Commit e0669ce · verified · 1 Parent(s): f5365bc

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +53 -23
src/txagent/txagent.py CHANGED
@@ -13,6 +13,7 @@ from gradio import ChatMessage
 from .toolrag import ToolRAGModel
 import torch
 import logging
+from difflib import SequenceMatcher
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -26,15 +27,15 @@ class TxAgent:
                  enable_finish=True,
                  enable_rag=True,
                  enable_summary=False,
-                 init_rag_num=2,  # Reduced for faster initial tool selection
-                 step_rag_num=4,  # Reduced for fewer RAG calls
+                 init_rag_num=2,
+                 step_rag_num=4,
                  summary_mode='step',
                  summary_skip_last_k=0,
                  summary_context_length=None,
                  force_finish=True,
                  avoid_repeat=True,
                  seed=None,
-                 enable_checker=False,  # Disabled by default for speed
+                 enable_checker=False,
                  enable_chat=False,
                  additional_default_tools=None):
         self.model_name = model_name
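
Note: this hunk only drops the inline comments; the tuned defaults (init_rag_num=2, step_rag_num=4, enable_checker=False) are unchanged. A hypothetical construction sketch, assuming the package exports TxAgent and that model_name is the constructor's first parameter (neither is visible in this diff):

    from txagent import TxAgent  # assumed import path

    agent = TxAgent(
        "mims-harvard/TxAgent-T1-Llama-3.1-8B",  # assumed model id
        enable_rag=True,
        init_rag_num=2,        # tools retrieved up front
        step_rag_num=4,        # tools retrieved per later Tool_RAG call
        enable_checker=False,  # reasoning-trace checker stays off for speed
    )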
@@ -78,7 +79,7 @@
         if model_name:
             self.model_name = model_name
 
-        self.model = LLM(model=self.model_name, dtype="float16")  # Enable FP16
+        self.model = LLM(model=self.model_name, dtype="float16")
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         logger.info("Model %s loaded successfully", self.model_name)
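
Note: the model is served through vLLM's LLM wrapper with dtype="float16", which roughly halves weight memory relative to fp32. A minimal standalone sketch of the same loading pattern, with a placeholder model path:

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", dtype="float16")  # placeholder model
    params = SamplingParams(temperature=0.1, max_tokens=64)
    outputs = llm.generate(["What is aspirin used for?"], sampling_params=params)
    print(outputs[0].outputs[0].text)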
@@ -101,16 +102,17 @@
 
     def initialize_tools_prompt(self, call_agent, call_agent_level, message):
         picked_tools_prompt = []
-        picked_tools_prompt = self.add_special_tools(
-            picked_tools_prompt, call_agent=call_agent)
-        if call_agent:
-            call_agent_level += 1
-            if call_agent_level >= 2:
-                call_agent = False
-
-        if not call_agent and self.enable_rag:
-            picked_tools_prompt += self.tool_RAG(
-                message=message, rag_num=self.init_rag_num)
+        # Only add Finish tool unless prompt explicitly requires Tool_RAG or CallAgent
+        if "use external tools" not in message.lower():
+            picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
+        else:
+            picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=call_agent)
+            if call_agent:
+                call_agent_level += 1
+                if call_agent_level >= 2:
+                    call_agent = False
+            if self.enable_rag:
+                picked_tools_prompt += self.tool_RAG(message=message, rag_num=self.init_rag_num)
         return picked_tools_prompt, call_agent_level
 
     def initialize_conversation(self, message, conversation=None, history=None):
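
Note: initialize_tools_prompt now keys tool selection on a literal phrase in the user message. An isolated, runnable sketch of that gate with stubbed tool names (illustrative only, not from the commit):

    def pick_special_tools(message: str, call_agent: bool, enable_rag: bool) -> list:
        tools = ['Finish']  # always present so the agent can terminate
        if "use external tools" in message.lower():
            if call_agent:
                tools.append('CallAgent')
            elif enable_rag:
                tools.append('Tool_RAG')
        return tools

    assert pick_special_tools("Summarize these labs", True, True) == ['Finish']
    assert pick_special_tools("Please use external tools", False, True) == ['Finish', 'Tool_RAG']

The match is case-insensitive but exact, so paraphrases such as "use a tool" will not enable RAG.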
@@ -129,7 +131,7 @@
 
     def tool_RAG(self, message=None, picked_tool_names=None,
                  existing_tools_prompt=None, rag_num=4, return_call_result=False):
-        extra_factor = 10  # Reduced from 30 for efficiency
+        extra_factor = 10
         if picked_tool_names is None:
             picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
 
@@ -148,10 +150,10 @@
         if self.enable_finish:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
             logger.debug("Finish tool added")
-        if call_agent:
+        if call_agent and "use external tools" in self.prompt_multi_step.lower():
             tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
             logger.debug("CallAgent tool added")
-        elif self.enable_rag:
+        elif self.enable_rag and "use external tools" in self.prompt_multi_step.lower():
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
             logger.debug("Tool_RAG tool added")
         if self.additional_default_tools:
@@ -301,7 +303,7 @@
         return output
 
     def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
-                            max_token: int, max_round: int = 10, call_agent=False, call_agent_level=0):
+                            max_token: int, max_round: int = 3, call_agent=False, call_agent_level=0):
         logger.debug("Starting multistep agent for message: %s", message[:100])
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
@@ -317,6 +319,10 @@
         if self.enable_checker:
             checker = ReasoningTraceChecker(message, conversation)
 
+        # Check if message contains clinical findings
+        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
+        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
+
         while next_round and current_round < max_round:
             current_round += 1
             if last_outputs:
@@ -349,9 +355,11 @@
                     logger.warning("Checker error: %s", wrong_info)
                     break
 
+            # Skip tool calls if clinical data is present
+            tools = [] if has_clinical_data else picked_tools_prompt
             last_outputs = []
             last_outputs_str, token_overflow = self.llm_infer(
-                messages=conversation, temperature=temperature, tools=picked_tools_prompt,
+                messages=conversation, temperature=temperature, tools=tools,
                 max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
             if last_outputs_str is None:
                 if self.force_finish:
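
Note: together with the clinical_keywords probe added before the loop, this makes the agent answer clinical messages from the model alone. A self-contained sketch of the two additions (the message and the stand-in schema are invented):

    clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
    picked_tools_prompt = [{'name': 'Tool_RAG'}]  # stand-in tool schema

    message = "Patient symptom: persistent cough after a new medication"
    has_clinical_data = any(k in message.lower() for k in clinical_keywords)
    tools = [] if has_clinical_data else picked_tools_prompt
    print(tools)  # [] -- llm_infer then runs without any tool schemas

Substring matching is deliberately loose: it also fires on words like "symptomatic".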
@@ -374,7 +382,22 @@
                 m['content'] for m in messages[-3:] if m['role'] == 'assistant'
             ][:2]
             forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
-            return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
+            # Enhance deduplication with similarity check
+            unique_sentences = []
+            for msg in assistant_messages:
+                sentences = msg.split('. ')
+                for s in sentences:
+                    if not s:
+                        continue
+                    is_unique = True
+                    for seen_s in unique_sentences:
+                        if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
+                            is_unique = False
+                            break
+                    if is_unique:
+                        unique_sentences.append(s)
+            forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
+            return [NoRepeatSentenceProcessor(forbidden_ids, 10)]  # Increased penalty
         return None
 
     def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
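
Note: the new filter drops near-identical sentences before building forbidden token sequences, and it overwrites the forbidden_ids computed from the raw messages two lines earlier. A standalone demo of the 0.9-ratio similarity check, with invented sentences:

    from difflib import SequenceMatcher

    seen = []
    for s in ["The patient should stop aspirin",
              "The patient should stop aspirin.",  # near-duplicate, ratio ~0.98
              "Order a renal function panel"]:
        if all(SequenceMatcher(None, s.lower(), t.lower()).ratio() <= 0.9 for t in seen):
            seen.append(s)
    print(seen)  # the second sentence is filtered out

The pairwise comparison is quadratic, which is fine for the two most recent assistant messages it operates on but would not scale to long histories.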
@@ -407,7 +430,7 @@
         output = model.generate(prompt, sampling_params=sampling_params)
         output = output[0].outputs[0].text
         logger.debug("Inference output: %s", output[:100])
-        torch.cuda.empty_cache()  # Clear CUDA cache
+        torch.cuda.empty_cache()
         if check_token_status:
             return output, False
         return output
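
Note: calling torch.cuda.empty_cache() after every generation returns cached allocator blocks to the driver, which caps fragmentation but slows the next allocation; vLLM also pre-reserves most GPU memory for its KV cache, so the effect is limited to PyTorch-side scratch buffers. A throttled variant is a common compromise (a sketch, not what this commit does; the interval is arbitrary):

    import torch

    _calls = 0

    def maybe_empty_cache(every: int = 8) -> None:
        # Release cached CUDA blocks only every `every` calls to limit overhead.
        global _calls
        _calls += 1
        if torch.cuda.is_available() and _calls % every == 0:
            torch.cuda.empty_cache()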
@@ -544,7 +567,7 @@ Summarize the function responses in one sentence with all necessary information.
 
     def run_gradio_chat(self, message: str, history: list, temperature: float,
                         max_new_tokens: int, max_token: int, call_agent: bool,
-                        conversation: gr.State, max_round: int = 10, seed: int = None,
+                        conversation: gr.State, max_round: int = 3, seed: int = None,
                         call_agent_level: int = 0, sub_agent_task: str = None,
                         uploaded_files: list = None):
         logger.debug("Chat started, message: %s", message[:100])
@@ -555,6 +578,11 @@ Summarize the function responses in one sentence with all necessary information.
         if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
             return
 
+        # Check if message contains clinical findings
+        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
+        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
+        call_agent = call_agent and not has_clinical_data  # Disable CallAgent for clinical data
+
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
         conversation = self.initialize_conversation(
@@ -612,8 +640,10 @@ Summarize the function responses in one sentence with all necessary information.
                     logger.warning("Checker error: %s", wrong_info)
                     break
 
+            # Skip tool calls if clinical data is present
+            tools = [] if has_clinical_data else picked_tools_prompt
             last_outputs_str, token_overflow = self.llm_infer(
-                messages=conversation, temperature=temperature, tools=picked_tools_prompt,
+                messages=conversation, temperature=temperature, tools=tools,
                 max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
 
             if last_outputs_str is None: