# Hugging Face Space page header (scraped): status "Running on Zero"; file size 1,860 bytes.
# Blame-gutter commits (kept for provenance): 93643d5 040c521 b0d2a02 e83c60c 3554a8b
# bd9482b dc02763 6c40a85 cfd4b0d f4c9eb8 fd25b82 8a243e5 a822923 f3bcef9 47a0109
# daac94f 9704577 0686401 5071704
import gradio
import json
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
# Model selection notes: zero-shot/NLI checkpoints that were tried, with verdicts.
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
# Current choice; loaded through Optimum's ONNX Runtime model class below.
model_name = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
# Quantized ONNX weights inside the repo. NOTE(review): currently unused, because the
# `file_name=` argument on the from_pretrained call below is commented out; confirm
# which ONNX file Optimum actually picks up.
file_name = "onnx/model_quantized.onnx"
# The tokenizer comes from the same repo as the model.
tokenizer_name = model_name
model = ORTModelForSequenceClassification.from_pretrained(model_name) #, file_name=file_name)
# Sets the tokenizer's model_max_length to 512 tokens (presumably the model's input limit — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
# Disabled alternative: download the ONNX file manually and run a raw onnxruntime session.
# file = cached_download("https://huggingface.co/" + model_name + "")
# sess = InferenceSession(file)
# Module-level pipeline, built once at import time and reused by zero_shot_classification().
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
# Browser origins allowed to call the classifier; other origins receive "{}".
# NOTE: the Origin header is supplied by the client, so this only filters
# well-behaved browsers. It is not an authentication mechanism.
ALLOWED_ORIGINS = frozenset({
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://jhuhman-statosphere-backend.hf.space",
})


def zero_shot_classification(data_string, request: gradio.Request):
    """Run zero-shot classification on a JSON-encoded request.

    Args:
        data_string: JSON object text. ``sequence`` and ``candidate_labels``
            are required. ``hypothesis_template`` and ``multi_label`` are
            optional; they are forwarded to the pipeline only when present,
            so the pipeline's own defaults apply otherwise.
        request: The Gradio request object. When it is falsy (e.g. the
            function is called directly from Python), the origin check is
            skipped.

    Returns:
        The pipeline result serialized with ``json.dumps``. Returns ``"{}"``
        when the request's ``Origin`` header is missing or not listed in
        ``ALLOWED_ORIGINS``.

    Raises:
        json.JSONDecodeError: If ``data_string`` is not valid JSON.
        KeyError: If ``sequence`` or ``candidate_labels`` is missing.
    """
    if request:
        print("Request headers dictionary:", request.headers)
        # Use .get(): a request with no Origin header (curl, server-side
        # clients) used to crash with KeyError. It is now rejected like any
        # other unknown origin.
        if request.headers.get("origin") not in ALLOWED_ORIGINS:
            return "{}"
    print(data_string)
    data = json.loads(data_string)
    print(data)
    # Forward optional pipeline options only when the caller supplied them.
    options = {
        key: data[key]
        for key in ("hypothesis_template", "multi_label")
        if key in data
    }
    results = classifier(
        data["sequence"],
        candidate_labels=data["candidate_labels"],
        **options,
    )
    return json.dumps(results)
# Single-textbox UI/API: the JSON request goes in and the JSON-encoded result comes out.
gradio_interface = gradio.Interface(
    fn=zero_shot_classification,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)

# Start the server only when run as a script (as a Space runs app.py), so the
# module can be imported without launching a server.
if __name__ == "__main__":
    gradio_interface.launch()