ammarnasr commited on
Commit
bbec848
·
1 Parent(s): 6f0551a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -2
handler.py CHANGED
@@ -15,9 +15,13 @@ class EndpointHandler:
15
 
16
  def __call__(self, data: Dict[str, Any]) -> List[str]:
17
  prompt = data["inputs"]
 
 
 
 
18
  input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
19
- generated_ids = self.model.generate(input_ids)
20
- return [self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)]
21
 
22
 
23
 
 
15
 
16
  def __call__(self, data: Dict[str, Any]) -> List[str]:
17
  prompt = data["inputs"]
18
+ if "config" in data:
19
+ config = data.pop("config", None)
20
+ else:
21
+ config = {'max_new_tokens':100}
22
  input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
23
+ generated_ids = self.model.generate(input_ids, **config)
24
+ return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
25
 
26
 
27