Ali2206 committed on
Commit
589b0c2
·
verified ·
1 Parent(s): 2639902

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +3 -3
src/txagent/txagent.py CHANGED
@@ -78,8 +78,8 @@ class TxAgent:
78
  model=self.model_name,
79
  dtype="float16",
80
  max_model_len=131072,
81
- max_num_batched_tokens=32768, # Increased for A100 80GB
82
- gpu_memory_utilization=0.9, # Higher utilization for better performance
83
  trust_remote_code=True
84
  )
85
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
@@ -428,7 +428,7 @@ class TxAgent:
428
 
429
  logits_processor = self.build_logits_processor(messages, model)
430
  sampling_params = SamplingParams(
431
- temperature=temperature,
432
  max_tokens=max_new_tokens,
433
  seed=seed if seed is not None else self.seed,
434
  )
 
78
  model=self.model_name,
79
  dtype="float16",
80
  max_model_len=131072,
81
+ max_num_batched_tokens=65536, # Increased for A100 80GB
82
+ gpu_memory_utilization=0.95, # Higher utilization for better performance
83
  trust_remote_code=True
84
  )
85
  self.chat_template = Template(self.model.get_tokenizer().chat_template)
 
428
 
429
  logits_processor = self.build_logits_processor(messages, model)
430
  sampling_params = SamplingParams(
431
+ temperature=temperature if temperature is not None else 0.0,
432
  max_tokens=max_new_tokens,
433
  seed=seed if seed is not None else self.seed,
434
  )