zakerytclarke committed on
Commit
54213f8
·
verified ·
1 Parent(s): d6dc06a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -20
app.py CHANGED
@@ -80,19 +80,7 @@ class TeapotAI:
80
  if self.settings.verbose:
81
  print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")
82
 
83
- # self.generator = pipeline("text2text-generation", model=self.model, revision=self.model_revision) if model_revision else pipeline("text2text-generation", model=self.model)
84
-
85
- self.tokenizer = AutoTokenizer.from_pretrained(self.model)
86
- model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
87
- model.eval()
88
-
89
- # Quantization settings
90
- quantization_dtype = torch.qint8 # or torch.float16
91
- quantization_config = torch.quantization.get_default_qconfig('fbgemm') # or 'onednn'
92
-
93
- self.quantized_model = torch.quantization.quantize_dynamic(
94
- model, {torch.nn.Linear}, dtype=quantization_dtype
95
- )
96
 
97
  self.documents = documents
98
 
@@ -154,14 +142,8 @@ class TeapotAI:
154
  str: The generated output from the model.
155
  """
156
 
157
- inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
158
-
159
-
160
- with torch.no_grad():
161
- outputs = self.quantized_model.generate(inputs["input_ids"], max_length=512)
162
 
163
-
164
- result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
165
 
166
 
167
  if self.settings.log_level == "debug":
 
80
  if self.settings.verbose:
81
  print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")
82
 
83
+ self.generator = pipeline("text2text-generation", model=self.model, revision=self.model_revision) if model_revision else pipeline("text2text-generation", model=self.model)
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  self.documents = documents
86
 
 
142
  str: The generated output from the model.
143
  """
144
 
 
 
 
 
 
145
 
146
+ result = self.generator(input_text)[0].get("generated_text")
 
147
 
148
 
149
  if self.settings.log_level == "debug":