ammarnasr committed
Commit 715997e · 1 Parent(s): 5f48ffc

Update handler.py

Files changed (1): handler.py (+46 -5)
handler.py CHANGED
@@ -1,3 +1,46 @@
+Hugging Face's logo
+Hugging Face
+Search models, datasets, users...
+Models
+Datasets
+Spaces
+Docs
+Solutions
+Pricing
+
+
+
+
+ammarnasr
+/
+CodeGen2_1B_merged
+
+like
+0
+Text Generation
+Transformers
+PyTorch
+codegen
+custom_code
+Inference Endpoints
+Model card
+Files and versions
+Community
+Settings
+CodeGen2_1B_merged
+/
+handler.py
+ammarnasr's picture
+ammarnasr
+Update handler.py
+5f48ffc
+7 minutes ago
+raw
+history
+blame
+edit
+delete
+1.06 kB
 from typing import Any, Dict, List
 import torch
 import transformers
@@ -13,7 +56,7 @@ class EndpointHandler:
         self.model = self.model.to(self.device)


-    def __call__(self, data: Dict[str, Any]) -> List[str]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         prompt = data["inputs"]
         if "config" in data:
             config = data.pop("config", None)
@@ -21,7 +64,5 @@ class EndpointHandler:
             config = {'max_new_tokens':100}
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
         generated_ids = self.model.generate(input_ids, **config)
-        return [self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)]
-
-
-
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        return [{"generated_text": generated_text}]
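
Note that the first hunk pastes the rendered model-page text (navigation, tags, file-viewer controls) into the top of handler.py; that text is not valid Python, so the committed file would raise a SyntaxError on import. For reference, below is a minimal sketch of the handler with the pasted text removed. The __init__ body falls mostly outside the diff, so the loading code is an assumption (standard transformers AutoTokenizer/AutoModelForCausalLM calls with trust_remote_code=True, suggested by the repo's custom_code tag), the else: branch at old line 20 is inferred from context, and the annotation List[Dict[str, Any]] corrects the committed Dict[str, Any], which does not match the list the method actually returns.

from typing import Any, Dict, List
import torch
import transformers


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Assumed loading logic; the original __init__ is not shown in the diff.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        prompt = data["inputs"]
        if "config" in data:
            # Generation parameters supplied by the caller take precedence.
            config = data.pop("config", None)
        else:
            config = {"max_new_tokens": 100}
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generated_ids = self.model.generate(input_ids, **config)
        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return [{"generated_text": generated_text}]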
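
A hypothetical local smoke test of the payload and response shapes (the path argument and a local checkout are assumptions; Inference Endpoints normally instantiates the handler itself):

# Hypothetical usage; assumes handler.py sits next to this script and the
# model weights are reachable at the given path.
handler = EndpointHandler(path="ammarnasr/CodeGen2_1B_merged")
result = handler({"inputs": "def fibonacci(n):", "config": {"max_new_tokens": 64}})
print(result[0]["generated_text"])

Returning [{"generated_text": ...}] mirrors the output shape of the transformers text-generation pipeline, which is what Inference Endpoints clients and the Hub widget expect for the Text Generation task; that is presumably the motivation for changing the return value from a bare list of strings.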