File size: 2,203 Bytes
93643d5
040c521
b0d2a02
e83c60c
3554a8b
bd9482b
fd79eb2
 
 
 
 
 
 
 
9cd8fce
fd79eb2
 
 
 
b0d2a02
30d670a
6c40a85
dc02763
cfd4b0d
f4c9eb8
30d670a
3e5a168
30d670a
1781106
fd25b82
b0d2a02
8a243e5
 
 
 
b0d2a02
a822923
 
 
f3bcef9
 
fd25b82
47a0109
daac94f
9704577
0686401
 
5071704
93643d5
 
daac94f
0686401
93643d5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio
import json
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# CORS configuration: only the Statosphere front-end origin is allowed to make
# credentialed cross-origin requests, with any method and any header.
# NOTE(review): `app` is not referenced anywhere else in the visible code, and
# the Gradio interface is started with its own `launch()` call, so this
# middleware may never actually be applied — confirm how `app` is served.
_cors_options = dict(
    allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)

# NLI models previously tried for zero-shot classification, with notes:
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers

# Chosen model: the quantized ONNX weights from this repository, run through
# ONNX Runtime via Optimum's ORTModelForSequenceClassification.
model_name = "xenova/nli-deberta-v3-small"
file_name = "onnx/model_quantized.onnx"
# The tokenizer comes from a different repository than the ONNX weights.
# NOTE(review): presumably the ONNX repo is an export of this model, so the
# vocabularies match — confirm if either repository changes.
tokenizer_name = "cross-encoder/nli-deberta-v3-small"
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
# Sets the tokenizer's maximum sequence length to 512 tokens.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)

# Shared zero-shot pipeline used by the request handler.
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)

def zero_shot_classification(data_string, request: gradio.Request):
    """Run zero-shot classification on a JSON-encoded request.

    Args:
        data_string: JSON object text. Required keys are ``sequence`` and
            ``candidate_labels``. The optional keys ``hypothesis_template``
            and ``multi_label`` are forwarded to the pipeline only when
            present, so the pipeline's own defaults apply otherwise.
        request: The Gradio request, used to check the ``Origin`` header.
            When falsy, e.g. when the function is called directly rather
            than through Gradio, the origin check is skipped.

    Returns:
        The classifier output serialized with ``json.dumps``. Returns
        ``"{}"`` when the request's origin is not allowed, including
        requests that send no ``Origin`` header at all.

    Raises:
        json.JSONDecodeError: If ``data_string`` is not valid JSON.
        KeyError: If ``sequence`` or ``candidate_labels`` is missing.
    """
    if request:
        # Use .get(): requests without an Origin header (e.g. non-browser
        # clients) must be rejected, not crash the handler with KeyError.
        # Origin is client-supplied, so this check only filters browser
        # traffic; it is not an authentication mechanism.
        origin = request.headers.get("origin")
        # Log only the origin; the full headers can carry cookies/credentials.
        print("Request origin:", origin)
        if origin not in {
            "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
            "https://jhuhman-statosphere-backend.hf.space",
        }:
            return "{}"
    print(data_string)
    data = json.loads(data_string)
    # Forward the optional pipeline settings only when the caller supplied
    # them; missing ones fall back to the pipeline's defaults.
    options = {
        key: data[key]
        for key in ("hypothesis_template", "multi_label")
        if key in data
    }
    results = classifier(
        data["sequence"],
        candidate_labels=data["candidate_labels"],
        **options,
    )
    return json.dumps(results)

# Expose the classifier as a single text-in / text-out Gradio endpoint: the
# caller submits the JSON request string and gets back the JSON result string.
# The server is launched immediately when this module runs.
_json_input_box = gradio.Textbox(label="JSON Input")
_json_output_box = gradio.Textbox()
gradio_interface = gradio.Interface(
    fn=zero_shot_classification,
    inputs=_json_input_box,
    outputs=_json_output_box,
)
gradio_interface.launch()