Hieucyber2208 committed (verified)
Commit 2e37701 · 1 Parent(s): f269e0c

Update src/generation/llm.py

Files changed (1)
src/generation/llm.py +10 -7
src/generation/llm.py CHANGED
@@ -70,7 +70,7 @@ class LLM:
             input_variables=["cuisines", "dishes", "price_ranges", "query"]
         )
 
-    def generate(self, prompt: str, max_length: int = 1000) -> str:
+    def generate(self, prompt: str, max_length: int = 100) -> str:
         """
         Generate text using the LLM.
 
@@ -91,12 +91,15 @@ class LLM:
         inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
         # Generate text
         outputs = self.llm.generate(
-            **inputs,
-            max_new_tokens=max_length,
-            temperature=0.7,
-            do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id
-        )
+            **inputs,
+            max_new_tokens=max_length,
+            temperature=0.3,
+            do_sample=False,
+            top_p=1.0,
+            top_k=1,
+            pad_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )
         # Decode the generated tokens
         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         print("Response generated successfully!")