Ali2206 committed
Commit 5caebdc · verified · 1 Parent(s): f260d4a

Update src/txagent/txagent.py

Files changed (1):
  1. src/txagent/txagent.py +72 -20
src/txagent/txagent.py CHANGED
@@ -74,10 +74,20 @@ class TxAgent:
                 return f"The model {model_name} is already loaded."
             self.model_name = model_name
 
-        self.model = LLM(model=self.model_name, dtype="float16", max_model_len=32768, gpu_memory_utilization=0.8)
+        self.model = LLM(
+            model=self.model_name,
+            dtype="float16",
+            max_model_len=131072,
+            max_num_batched_tokens=32768,  # Increased for A100 80GB
+            gpu_memory_utilization=0.9,  # Higher utilization for better performance
+            trust_remote_code=True
+        )
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
-        logger.info("Model %s loaded successfully", self.model_name)
+        logger.info(
+            "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d, gpu_memory_utilization=%.2f",
+            self.model_name, 131072, 32768, 0.9
+        )
         return f"Model {model_name} loaded successfully."
 
     def load_tooluniverse(self):
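
Note: the hunk above raises the vLLM context window from 32,768 to 131,072 tokens and expands the engine arguments. A minimal standalone sketch of the same vLLM settings for reference; the checkpoint name and prompt are illustrative assumptions, not values taken from this commit.

from vllm import LLM, SamplingParams

# Sketch only: load a model with the engine arguments used in load_models above.
llm = LLM(
    model="mims-harvard/TxAgent-T1-Llama-3.1-8B",  # assumed example checkpoint
    dtype="float16",
    max_model_len=131072,           # full 128K context window
    max_num_batched_tokens=32768,   # per-step batching budget
    gpu_memory_utilization=0.9,
    trust_remote_code=True,
)
params = SamplingParams(temperature=0.1, max_tokens=128)
print(llm.generate(["What is aspirin used for?"], params)[0].outputs[0].text)
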
@@ -204,7 +214,7 @@ class TxAgent:
                     )
                     call_result = self.run_multistep_agent(
                         full_message, temperature=temperature,
-                        max_new_tokens=512, max_token=2048,
+                        max_new_tokens=512, max_token=131072,
                         call_agent=False, call_agent_level=call_agent_level)
                     if call_result is None:
                         call_result = "⚠️ No content returned from sub-agent."
@@ -277,7 +287,7 @@ class TxAgent:
                     sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
                     call_result = yield from self.run_gradio_chat(
                         full_message, history=[], temperature=temperature,
-                        max_new_tokens=512, max_token=2048,
+                        max_new_tokens=512, max_token=131072,
                         call_agent=False, call_agent_level=call_agent_level,
                         conversation=None, sub_agent_task=sub_agent_task)
                     if call_result is not None and isinstance(call_result, str):
@@ -387,7 +397,7 @@ class TxAgent:
                 tools=picked_tools_prompt,
                 skip_special_tokens=False,
                 max_new_tokens=2048,
-                max_token=32768,
+                max_token=131072,
                 check_token_status=True)
             if last_outputs_str is None:
                 logger.warning("Token limit exceeded")
@@ -410,7 +420,7 @@ class TxAgent:
 
     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=512,
-                  max_token=2048, skip_special_tokens=True,
+                  max_token=131072, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None,
                   seed=None, check_token_status=False):
         if model is None:
@@ -430,21 +440,23 @@ class TxAgent:
 
         if check_token_status and max_token is not None:
             token_overflow = False
-            num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
+            num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
+            logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
             if num_input_tokens > max_token:
                 torch.cuda.empty_cache()
                 gc.collect()
-                logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
+                logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
                 return None, True
 
         output = model.generate(prompt, sampling_params=sampling_params)
-        output = output[0].outputs[0].text
-        logger.debug("Inference output: %s", output[:100])
+        output_text = output[0].outputs[0].text
+        output_tokens = len(self.tokenizer.encode(output_text, add_special_tokens=False))
+        logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
         torch.cuda.empty_cache()
         gc.collect()
         if check_token_status and max_token is not None:
-            return output, token_overflow
-        return output
+            return output_text, token_overflow
+        return output_text
 
     def run_self_agent(self, message: str,
                        temperature: float,
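
Note: the revised guard counts prompt tokens with the tokenizer (add_special_tokens=False) before generation and logs the budget. A small isolated sketch of the same check, assuming any Hugging Face tokenizer; the model id below is an assumption, not from this commit.

from transformers import AutoTokenizer

def fits_in_context(prompt: str, max_token: int = 131072,
                    model_id: str = "meta-llama/Llama-3.1-8B-Instruct") -> bool:
    # Mirrors the overflow check in llm_infer: count only the raw prompt tokens.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    num_input_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
    return num_input_tokens <= max_token

print(fits_in_context("What drugs interact with warfarin?"))
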
@@ -514,7 +526,7 @@ Function calls' responses:
 \"\"\"
 {function_response}
 \"\"\"
-Summarize the function calls' responses in one sentence with all necessary information.
+Summarize the function calls' l responses in one sentence with all necessary information.
 """
         conversation = [{"role": "user", "content": prompt}]
         output = self.llm_infer(
@@ -559,7 +571,7 @@ Summarize the function calls' responses in one sentence with all necessary infor
                     function_response=function_response,
                     temperature=0.1,
                     max_new_tokens=512,
-                    max_token=2048)
+                    max_token=131072)
                 input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
                 status['summarized_index'] = last_call_idx + 2
                 idx += 1
@@ -581,7 +593,7 @@ Summarize the function calls' responses in one sentence with all necessary infor
                     function_response=function_response,
                     temperature=0.1,
                     max_new_tokens=512,
-                    max_token=2048)
+                    max_token=131072)
                 tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
                 for tool_call in tool_calls:
                     del tool_call['call_id']
@@ -603,10 +615,10 @@ Summarize the function calls' responses in one sentence with all necessary infor
     def run_gradio_chat(self, message: str,
                         history: list,
                         temperature: float,
-                        max_new_tokens: 2048,
-                        max_token: 32768,
-                        call_agent: bool,
-                        conversation: gr.State,
+                        max_new_tokens: int = 2048,
+                        max_token: int = 131072,
+                        call_agent: bool = False,
+                        conversation: gr.State = None,
                         max_round: int = 5,
                         seed: int = None,
                         call_agent_level: int = 0,
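
Note: with defaults added to the signature, run_gradio_chat can be driven with just a message, a history list, and a temperature. A hedged usage sketch; it assumes a TxAgent instance named agent whose model and tools are already loaded, and the question is illustrative.

# Assumption: `agent` is a loaded TxAgent instance.
history = []
for history in agent.run_gradio_chat(
        message="Are there known interactions between warfarin and ibuprofen?",
        history=history,
        temperature=0.3):
    # run_gradio_chat is a generator; each yield is the updated chat history.
    pass
print(history[-1])
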
@@ -755,4 +767,44 @@ Summarize the function calls' responses in one sentence with all necessary infor
                 logger.info("Forced final answer after error: %s", final_answer[:100])
                 yield history
                 return final_answer
-        return error_msg
+        return error_msg
+
+    def run_gradio_chat_batch(self, messages: List[str],
+                              temperature: float,
+                              max_new_tokens: int = 2048,
+                              max_token: int = 131072,
+                              call_agent: bool = False,
+                              conversation: List = None,
+                              max_round: int = 5,
+                              seed: int = None,
+                              call_agent_level: int = 0):
+        """Run batch inference for multiple messages."""
+        logger.info("Starting batch chat for %d messages", len(messages))
+        batch_results = []
+
+        for message in messages:
+            # Initialize conversation for each message
+            conv = self.initialize_conversation(message, conversation, history=None)
+            picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+                call_agent, call_agent_level, message)
+
+            # Run single inference for simplicity (extend for multi-round if needed)
+            output, token_overflow = self.llm_infer(
+                messages=conv,
+                temperature=temperature,
+                tools=picked_tools_prompt,
+                max_new_tokens=max_new_tokens,
+                max_token=max_token,
+                skip_special_tokens=False,
+                seed=seed,
+                check_token_status=True
+            )
+
+            if output is None:
+                logger.warning("Token limit exceeded for message: %s", message[:100])
+                batch_results.append("Token limit exceeded.")
+            else:
+                batch_results.append(output)
+
+        logger.info("Batch chat completed for %d messages", len(messages))
+        return batch_results
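
Note: the new run_gradio_chat_batch method performs a single llm_infer pass per message and returns a plain list of strings rather than yielding history updates. A hedged usage sketch, again assuming a loaded TxAgent instance named agent; the questions are illustrative.

# Assumption: `agent` is a loaded TxAgent instance.
questions = [
    "What is the mechanism of action of metformin?",
    "List contraindications for ibuprofen.",
]
results = agent.run_gradio_chat_batch(
    messages=questions,
    temperature=0.3,
    max_new_tokens=1024,
    max_token=131072,
)
for question, answer in zip(questions, results):
    print(question, "->", answer[:200])
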