Update src/txagent/txagent.py

src/txagent/txagent.py (CHANGED, +53 -23)
@@ -13,6 +13,7 @@ from gradio import ChatMessage
 from .toolrag import ToolRAGModel
 import torch
 import logging
+from difflib import SequenceMatcher

 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -26,15 +27,15 @@ class TxAgent:
                  enable_finish=True,
                  enable_rag=True,
                  enable_summary=False,
-                 init_rag_num=2,
-                 step_rag_num=4,
+                 init_rag_num=2,
+                 step_rag_num=4,
                  summary_mode='step',
                  summary_skip_last_k=0,
                  summary_context_length=None,
                  force_finish=True,
                  avoid_repeat=True,
                  seed=None,
-                 enable_checker=False,
+                 enable_checker=False,
                  enable_chat=False,
                  additional_default_tools=None):
         self.model_name = model_name
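For reference, a hypothetical construction call exercising the flags this hunk touches; the import path and model id are assumptions, not shown in the diff:

```python
# Hypothetical usage sketch; import path and model id are placeholders.
from txagent.txagent import TxAgent

agent = TxAgent(
    model_name="path/to/model",
    enable_rag=True,
    init_rag_num=2,        # tools retrieved up front
    step_rag_num=4,        # tools retrieved per Tool_RAG step
    enable_checker=False,  # reasoning-trace checker stays opt-in
)
```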
@@ -78,7 +79,7 @@ class TxAgent:
         if model_name:
             self.model_name = model_name

-        self.model = LLM(model=self.model_name, dtype="float16")
+        self.model = LLM(model=self.model_name, dtype="float16")
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         logger.info("Model %s loaded successfully", self.model_name)
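The load path pins `dtype="float16"`. A minimal sketch of the same sequence outside the class, assuming vLLM's `LLM` and a Jinja2 `Template`, which is what the calls shown here suggest:

```python
# Minimal sketch, assuming vLLM and jinja2 as the calls in this hunk suggest.
# float16 halves weight memory vs. float32; bfloat16 is a common alternative on newer GPUs.
from vllm import LLM
from jinja2 import Template

model = LLM(model="path/to/model", dtype="float16")  # placeholder path
tokenizer = model.get_tokenizer()
chat_template = Template(tokenizer.chat_template)
```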
@@ -101,16 +102,17 @@ class TxAgent:

     def initialize_tools_prompt(self, call_agent, call_agent_level, message):
         picked_tools_prompt = []
-
-
-
-
-
-
-
-
-
-
+        # Only add Finish tool unless prompt explicitly requires Tool_RAG or CallAgent
+        if "use external tools" not in message.lower():
+            picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
+        else:
+            picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=call_agent)
+        if call_agent:
+            call_agent_level += 1
+            if call_agent_level >= 2:
+                call_agent = False
+        if self.enable_rag:
+            picked_tools_prompt += self.tool_RAG(message=message, rag_num=self.init_rag_num)
         return picked_tools_prompt, call_agent_level

     def initialize_conversation(self, message, conversation=None, history=None):
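The rebuilt `initialize_tools_prompt` keys everything off a literal phrase in the user message. The gate in isolation, with TxAgent internals stubbed out:

```python
# Standalone sketch of the message gate added above.
def wants_external_tools(message: str) -> bool:
    return "use external tools" in message.lower()

print(wants_external_tools("Summarize this report for me"))        # False -> Finish tool only
print(wants_external_tools("Use external tools to check dosing"))  # True  -> CallAgent/Tool_RAG allowed
```

Note the gate is an exact substring match, so paraphrases like "look this up" will not enable the tools.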
@@ -129,7 +131,7 @@

     def tool_RAG(self, message=None, picked_tool_names=None,
                  existing_tools_prompt=None, rag_num=4, return_call_result=False):
-        extra_factor = 10
+        extra_factor = 10
         if picked_tool_names is None:
             picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)

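`tool_RAG` over-retrieves by `extra_factor` and then trims. The pattern separated from the class; `rank_tools` is a hypothetical stand-in for `self.rag_infer` plus the filtering the method performs:

```python
# Sketch of the over-retrieval pattern; rank_tools is a hypothetical stand-in.
def pick_tools(message, rank_tools, rag_num=4, extra_factor=10, excluded=()):
    candidates = rank_tools(message, top_k=rag_num * extra_factor)  # retrieve wide...
    kept = [t for t in candidates if t not in excluded]             # ...filter...
    return kept[:rag_num]                                           # ...keep the top rag_num

ranked = lambda msg, top_k: [f"tool_{i}" for i in range(top_k)]
print(pick_tools("find drug interactions", ranked, excluded={"tool_0"}))
# -> ['tool_1', 'tool_2', 'tool_3', 'tool_4']
```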
@@ -148,10 +150,10 @@
         if self.enable_finish:
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
             logger.debug("Finish tool added")
-        if call_agent:
+        if call_agent and "use external tools" in self.prompt_multi_step.lower():
             tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
             logger.debug("CallAgent tool added")
-        elif self.enable_rag:
+        elif self.enable_rag and "use external tools" in self.prompt_multi_step.lower():
             tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
             logger.debug("Tool_RAG tool added")
         if self.additional_default_tools:
@@ -301,7 +303,7 @@
         return output

     def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
-                            max_token: int, max_round: int =
+                            max_token: int, max_round: int = 3, call_agent=False, call_agent_level=0):
         logger.debug("Starting multistep agent for message: %s", message[:100])
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
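With the updated signature, a hypothetical call looks like this; the sampling values are placeholders:

```python
# Hypothetical call; argument values are placeholders, names follow the new signature.
answer = agent.run_multistep_agent(
    "Use external tools to check interactions for warfarin.",
    temperature=0.3,
    max_new_tokens=1024,
    max_token=8192,
    # max_round defaults to 3; call_agent/call_agent_level default to False/0
)
```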
@@ -317,6 +319,10 @@
         if self.enable_checker:
             checker = ReasoningTraceChecker(message, conversation)

+        # Check if message contains clinical findings
+        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
+        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
+
         while next_round and current_round < max_round:
             current_round += 1
             if last_outputs:
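The clinical gate is a plain keyword scan. In isolation, with its blind spots visible:

```python
# The keyword gate in isolation; the list mirrors the diff and will miss
# synonyms ("Rx", "dx") and misspellings: it is a plain substring check.
clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']

def has_clinical_data(message: str) -> bool:
    return any(keyword in message.lower() for keyword in clinical_keywords)

print(has_clinical_data("Patient reports a new symptom after the dose change"))  # True
print(has_clinical_data("What does this tool return?"))                          # False
```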
@@ -349,9 +355,11 @@
                     logger.warning("Checker error: %s", wrong_info)
                     break

+            # Skip tool calls if clinical data is present
+            tools = [] if has_clinical_data else picked_tools_prompt
             last_outputs = []
             last_outputs_str, token_overflow = self.llm_infer(
-                messages=conversation, temperature=temperature, tools=
+                messages=conversation, temperature=temperature, tools=tools,
                 max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
             if last_outputs_str is None:
                 if self.force_finish:
@@ -374,7 +382,22 @@
                 m['content'] for m in messages[-3:] if m['role'] == 'assistant'
             ][:2]
             forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
-
+            # Enhance deduplication with similarity check
+            unique_sentences = []
+            for msg in assistant_messages:
+                sentences = msg.split('. ')
+                for s in sentences:
+                    if not s:
+                        continue
+                    is_unique = True
+                    for seen_s in unique_sentences:
+                        if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
+                            is_unique = False
+                            break
+                    if is_unique:
+                        unique_sentences.append(s)
+            forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
+            return [NoRepeatSentenceProcessor(forbidden_ids, 10)]  # Increased penalty
         return None

     def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
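The new dedup pass is worth testing on its own. A self-contained version with the same `'. '` split and 0.9 ratio threshold:

```python
# Self-contained version of the near-duplicate filter above. Splitting on '. ' is
# crude (it breaks on abbreviations like "e.g."), and ratio() > 0.9 only catches
# near-identical sentences.
from difflib import SequenceMatcher

def unique_sentences(messages, threshold=0.9):
    kept = []
    for msg in messages:
        for s in msg.split('. '):
            if s and all(SequenceMatcher(None, s.lower(), k.lower()).ratio() <= threshold
                         for k in kept):
                kept.append(s)
    return kept

print(unique_sentences([
    "The dose was increased. Monitor liver enzymes.",
    "Monitor liver enzymes. Check renal function.",
]))
# -> ['The dose was increased', 'Monitor liver enzymes.', 'Check renal function.']
```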
@@ -407,7 +430,7 @@
         output = model.generate(prompt, sampling_params=sampling_params)
         output = output[0].outputs[0].text
         logger.debug("Inference output: %s", output[:100])
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
         if check_token_status:
             return output, False
         return output
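Calling `torch.cuda.empty_cache()` after each generate returns unused cached blocks to the driver; it does not free live tensors and adds a small synchronization cost. The pattern, guarded for CPU-only runs (a sketch; `model`, `prompt`, and `sampling_params` are placeholders):

```python
# Sketch of the cleanup pattern; model/prompt/sampling_params are placeholders.
import torch

def generate_once(model, prompt, sampling_params):
    text = model.generate(prompt, sampling_params=sampling_params)[0].outputs[0].text
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return unused cached memory to the driver
    return text
```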
@@ -544,7 +567,7 @@ Summarize the function responses in one sentence with all necessary information.

     def run_gradio_chat(self, message: str, history: list, temperature: float,
                         max_new_tokens: int, max_token: int, call_agent: bool,
-                        conversation: gr.State, max_round: int =
+                        conversation: gr.State, max_round: int = 3, seed: int = None,
                         call_agent_level: int = 0, sub_agent_task: str = None,
                         uploaded_files: list = None):
         logger.debug("Chat started, message: %s", message[:100])
@@ -555,6 +578,11 @@ Summarize the function responses in one sentence with all necessary information.
         if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
             return

+        # Check if message contains clinical findings
+        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
+        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
+        call_agent = call_agent and not has_clinical_data  # Disable CallAgent for clinical data
+
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
         conversation = self.initialize_conversation(
@@ -612,8 +640,10 @@ Summarize the function responses in one sentence with all necessary information.
                     logger.warning("Checker error: %s", wrong_info)
                     break

+            # Skip tool calls if clinical data is present
+            tools = [] if has_clinical_data else picked_tools_prompt
             last_outputs_str, token_overflow = self.llm_infer(
-                messages=conversation, temperature=temperature, tools=
+                messages=conversation, temperature=temperature, tools=tools,
                 max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)

             if last_outputs_str is None:
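End to end, a hypothetical Gradio-side call with the updated signature, assuming the method is a generator of chat updates, as Gradio streaming handlers usually are:

```python
# Hypothetical driver; parameter values are placeholders, names follow the new signature.
for update in agent.run_gradio_chat(
    message="Use external tools to check ibuprofen-lisinopril interactions.",
    history=[], temperature=0.3, max_new_tokens=1024, max_token=8192,
    call_agent=False, conversation=None, max_round=3, seed=42,
):
    pass  # each update would be streamed back to the UI
```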