Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -80,19 +80,7 @@ class TeapotAI:
|
|
80 |
if self.settings.verbose:
|
81 |
print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.model)
|
86 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
|
87 |
-
model.eval()
|
88 |
-
|
89 |
-
# Quantization settings
|
90 |
-
quantization_dtype = torch.qint8 # or torch.float16
|
91 |
-
quantization_config = torch.quantization.get_default_qconfig('fbgemm') # or 'onednn'
|
92 |
-
|
93 |
-
self.quantized_model = torch.quantization.quantize_dynamic(
|
94 |
-
model, {torch.nn.Linear}, dtype=quantization_dtype
|
95 |
-
)
|
96 |
|
97 |
self.documents = documents
|
98 |
|
@@ -154,14 +142,8 @@ class TeapotAI:
|
|
154 |
str: The generated output from the model.
|
155 |
"""
|
156 |
|
157 |
-
inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
|
158 |
-
|
159 |
-
|
160 |
-
with torch.no_grad():
|
161 |
-
outputs = self.quantized_model.generate(inputs["input_ids"], max_length=512)
|
162 |
|
163 |
-
|
164 |
-
result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
165 |
|
166 |
|
167 |
if self.settings.log_level == "debug":
|
|
|
80 |
if self.settings.verbose:
|
81 |
print(f"Loading Model: {self.model} Revision: {self.model_revision or 'Latest'}")
|
82 |
|
83 |
+
self.generator = pipeline("text2text-generation", model=self.model, revision=self.model_revision) if model_revision else pipeline("text2text-generation", model=self.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
self.documents = documents
|
86 |
|
|
|
142 |
str: The generated output from the model.
|
143 |
"""
|
144 |
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
+
result = self.generator(input_text)[0].get("generated_text")
|
|
|
147 |
|
148 |
|
149 |
if self.settings.log_level == "debug":
|