Hieucyber2208 committed on
Commit
752fe56
·
verified ·
1 Parent(s): 034a76b

Update src/generation/llm.py

Browse files
Files changed (1) hide show
  1. src/generation/llm.py +5 -3
src/generation/llm.py CHANGED
@@ -99,10 +99,12 @@ class LLM:
99
  )
100
  # Decode the generated tokens
101
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
102
  print("Response generated successfully!")
103
- return response.strip()
104
- except Exception as e:
105
- raise RuntimeError(f"Failed to generate response: {str(e)}")
106
 
107
  def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
108
  """
 
99
  )
100
  # Decode the generated tokens
101
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+ # Strip any system/user metadata
103
+ response = response.replace("system", "").replace("user", "").replace("assistant", "")
104
+ # Remove any extra whitespace or unwanted tokens
105
+ response = " ".join(response.split()).strip()
106
  print("Response generated successfully!")
107
+ return response
 
 
108
 
109
  def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
110
  """