Update src/txagent/txagent.py

This commit pins the vLLM engine configuration in the model-loading path (max_model_len=131072, max_num_batched_tokens=32768, gpu_memory_utilization=0.9), threads an explicit max_token=131072 budget through llm_infer and its callers, adds prompt/output token logging and overflow handling, gives run_gradio_chat real default argument values, and introduces a run_gradio_chat_batch helper for batch inference.

src/txagent/txagent.py  CHANGED  (+72 -20)
@@ -74,10 +74,20 @@ class TxAgent:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name

-        self.model = LLM(
+        self.model = LLM(
+            model=self.model_name,
+            dtype="float16",
+            max_model_len=131072,
+            max_num_batched_tokens=32768,  # Increased for A100 80GB
+            gpu_memory_utilization=0.9,  # Higher utilization for better performance
+            trust_remote_code=True
+        )
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
-        logger.info(
+        logger.info(
+            "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d, gpu_memory_utilization=%.2f",
+            self.model_name, 131072, 32768, 0.9
+        )
         return f"Model {model_name} loaded successfully."

     def load_tooluniverse(self):
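For context on how the 131072 figure configured here interacts with the budgets used elsewhere in this commit: max_token caps the prompt that llm_infer will accept, max_new_tokens caps the generation, and both together have to fit inside max_model_len. A small arithmetic sketch; the constant names below are illustrative and do not exist in txagent.py:

# Illustrative arithmetic only; none of these names are defined in the codebase.
MAX_MODEL_LEN = 131072    # context window passed to LLM(...) above
MAX_NEW_TOKENS = 2048     # largest generation budget used in this file
MAX_TOKEN = 131072        # prompt cap threaded through llm_infer in this commit

# llm_infer rejects a prompt once it exceeds MAX_TOKEN, but to also guarantee
# room for the generated tokens, the prompt would need to stay under this figure:
prompt_budget_with_headroom = MAX_MODEL_LEN - MAX_NEW_TOKENS
print(prompt_budget_with_headroom)  # 129024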
@@ -204,7 +214,7 @@ class TxAgent:
                 )
                 call_result = self.run_multistep_agent(
                     full_message, temperature=temperature,
-                    max_new_tokens=512, max_token=
+                    max_new_tokens=512, max_token=131072,
                     call_agent=False, call_agent_level=call_agent_level)
                 if call_result is None:
                     call_result = "⚠️ No content returned from sub-agent."
@@ -277,7 +287,7 @@
                 sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
                 call_result = yield from self.run_gradio_chat(
                     full_message, history=[], temperature=temperature,
-                    max_new_tokens=512, max_token=
+                    max_new_tokens=512, max_token=131072,
                     call_agent=False, call_agent_level=call_agent_level,
                     conversation=None, sub_agent_task=sub_agent_task)
                 if call_result is not None and isinstance(call_result, str):
@@ -387,7 +397,7 @@
                 tools=picked_tools_prompt,
                 skip_special_tokens=False,
                 max_new_tokens=2048,
-                max_token=
+                max_token=131072,
                 check_token_status=True)
             if last_outputs_str is None:
                 logger.warning("Token limit exceeded")
@@ -410,7 +420,7 @@

     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=512,
-                  max_token=
+                  max_token=131072, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None,
                   seed=None, check_token_status=False):
         if model is None:
@@ -430,21 +440,23 @@

         if check_token_status and max_token is not None:
             token_overflow = False
-            num_input_tokens = len(self.tokenizer.encode(prompt,
+            num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
+            logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
             if num_input_tokens > max_token:
                 torch.cuda.empty_cache()
                 gc.collect()
-                logger.
+                logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
                 return None, True

         output = model.generate(prompt, sampling_params=sampling_params)
-
-
+        output_text = output[0].outputs[0].text
+        output_tokens = len(self.tokenizer.encode(output_text, add_special_tokens=False))
+        logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
         torch.cuda.empty_cache()
         gc.collect()
         if check_token_status and max_token is not None:
-            return
-        return
+            return output_text, token_overflow
+        return output_text

     def run_self_agent(self, message: str,
                        temperature: float,
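The practical consequence of this hunk for callers is the dual return shape: with check_token_status=True, llm_infer returns a (text, token_overflow) tuple whose text is None when the prompt exceeds max_token; without the flag it returns the generated string directly. A minimal caller sketch, assuming an already-constructed TxAgent instance named agent (a placeholder, as is the conversation content):

# `agent` stands in for a loaded TxAgent instance; the message is illustrative.
conversation = [{"role": "user", "content": "Which enzymes metabolize ibuprofen?"}]

text, overflow = agent.llm_infer(
    messages=conversation,
    temperature=0.1,
    max_new_tokens=512,
    max_token=131072,
    check_token_status=True,
)
if text is None:      # overflow is True on this branch
    print("Prompt exceeded max_token; nothing was generated.")
else:
    print(text)

# Without check_token_status the method returns just the string.
text_only = agent.llm_infer(messages=conversation, temperature=0.1)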
@@ -514,7 +526,7 @@ Function calls' responses:
 \"\"\"
 {function_response}
 \"\"\"
-Summarize the function calls' responses in one sentence with all necessary information.
+Summarize the function calls' responses in one sentence with all necessary information.
 """
         conversation = [{"role": "user", "content": prompt}]
         output = self.llm_infer(
@@ -559,7 +571,7 @@ Summarize the function calls' responses in one sentence with all necessary information.
                     function_response=function_response,
                     temperature=0.1,
                     max_new_tokens=512,
-                    max_token=
+                    max_token=131072)
                 input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
                 status['summarized_index'] = last_call_idx + 2
                 idx += 1
@@ -581,7 +593,7 @@ Summarize the function calls' responses in one sentence with all necessary information.
                     function_response=function_response,
                     temperature=0.1,
                     max_new_tokens=512,
-                    max_token=
+                    max_token=131072)
                 tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
                 for tool_call in tool_calls:
                     del tool_call['call_id']
@@ -603,10 +615,10 @@ Summarize the function calls' responses in one sentence with all necessary information.
     def run_gradio_chat(self, message: str,
                         history: list,
                         temperature: float,
-                        max_new_tokens: 2048,
-                        max_token:
-                        call_agent: bool,
-                        conversation: gr.State,
+                        max_new_tokens: int = 2048,
+                        max_token: int = 131072,
+                        call_agent: bool = False,
+                        conversation: gr.State = None,
                         max_round: int = 5,
                         seed: int = None,
                         call_agent_level: int = 0,
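The signature fix here matters beyond style: in the old version, parameters such as max_new_tokens: 2048 and conversation: gr.State carry bare annotations, which Python records but never applies as defaults, so a call that omitted those arguments would raise a TypeError. A short stand-alone illustration; both function names are made up:

def old_style(max_new_tokens: 2048):        # 2048 is only an annotation here
    return max_new_tokens

def new_style(max_new_tokens: int = 2048):  # an actual default value
    return max_new_tokens

print(new_style())   # 2048
try:
    old_style()      # no default was ever defined
except TypeError as exc:
    print(exc)       # missing 1 required positional argument: 'max_new_tokens'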
@@ -755,4 +767,44 @@ Summarize the function calls' responses in one sentence with all necessary information.
             logger.info("Forced final answer after error: %s", final_answer[:100])
             yield history
             return final_answer
-            return error_msg
+            return error_msg
+
+    def run_gradio_chat_batch(self, messages: List[str],
+                              temperature: float,
+                              max_new_tokens: int = 2048,
+                              max_token: int = 131072,
+                              call_agent: bool = False,
+                              conversation: List = None,
+                              max_round: int = 5,
+                              seed: int = None,
+                              call_agent_level: int = 0):
+        """Run batch inference for multiple messages."""
+        logger.info("Starting batch chat for %d messages", len(messages))
+        batch_results = []
+
+        for message in messages:
+            # Initialize conversation for each message
+            conv = self.initialize_conversation(message, conversation, history=None)
+            picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+                call_agent, call_agent_level, message)
+
+            # Run single inference for simplicity (extend for multi-round if needed)
+            output, token_overflow = self.llm_infer(
+                messages=conv,
+                temperature=temperature,
+                tools=picked_tools_prompt,
+                max_new_tokens=max_new_tokens,
+                max_token=max_token,
+                skip_special_tokens=False,
+                seed=seed,
+                check_token_status=True
+            )
+
+            if output is None:
+                logger.warning("Token limit exceeded for message: %s", message[:100])
+                batch_results.append("Token limit exceeded.")
+            else:
+                batch_results.append(output)
+
+        logger.info("Batch chat completed for %d messages", len(messages))
+        return batch_results
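A possible way to exercise the new batch entry point, assuming a TxAgent instance named agent on which load_models() has already been called; the questions and parameter values are illustrative:

# Illustrative usage of run_gradio_chat_batch; `agent` is a loaded TxAgent.
questions = [
    "What are the contraindications of warfarin?",
    "Which enzymes metabolize ibuprofen?",
]

results = agent.run_gradio_chat_batch(
    messages=questions,
    temperature=0.3,
    max_new_tokens=1024,
    max_token=131072,
    call_agent=False,
    seed=100,
)

for question, answer in zip(questions, results):
    print(question)
    print(answer)  # either the model's output or "Token limit exceeded."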