import gradio
import json
import torch
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenceClassification
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from setfit import SetFitModel
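# Backend for the Statosphere Space: serves zero-shot and few-shot text
# classification through a single Gradio endpoint that accepts a JSON payload.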
# CORS Config
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Attention-mask-weighted mean pooling of token embeddings into one sentence
# embedding per input (assumed: the standard sentence-transformers formulation,
# since this helper is not defined anywhere else in the file).
def mean_pooling(last_hidden_state, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    return torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# SetFit model whose embedding body runs on ONNX Runtime: the ORT model
# produces token embeddings, mean pooling collapses them into sentence
# embeddings, and the original SetFit head predicts the labels.
class OnnxSetFitModel:
    def __init__(self, ort_model, tokenizer, model_head):
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    def predict(self, inputs):
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        ).to(self.ort_model.device)
        outputs = self.ort_model(**encoded_inputs)
        embeddings = mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )
        return self.model_head.predict(embeddings.cpu())

    def __call__(self, inputs):
        return self.predict(inputs)
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
model_name = "xenova/nli-deberta-v3-small"
file_name = "onnx/model_quantized.onnx"
tokenizer_name = "cross-encoder/nli-deberta-v3-small"
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
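# Few-shot classifier: BAAI/bge-small-en-v1.5 embedding body running on ONNX
# Runtime, combined with the classification head of a pretrained SetFit model.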
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5')
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
# Route an incoming JSON payload to the requested task. Requests from origins
# outside the allow-list get an empty JSON object back.
def classify(data_string, request: gradio.Request):
    if request:
        if request.headers.get("origin") not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
            return "{}"
    data = json.loads(data_string)
    if data['task'] == 'few_shot_classification':
        return few_shot_classification(data)
    else:
        return zero_shot_classification(data)
def zero_shot_classification(data):
    results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
    response_string = json.dumps(results)
    return response_string
def few_shot_classification(data):
    results = onnx_few_shot_model(data['sequence'])
    # The SetFit head typically returns a numpy array, which json.dumps cannot
    # serialize directly, so convert it to a plain Python list first.
    response_string = json.dumps(results.tolist())
    return response_string
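# Example payloads (field names taken from the handlers above; any "task"
# other than "few_shot_classification" falls through to zero-shot):
# {"task": "zero_shot_classification", "sequence": "I loved this movie.",
#  "candidate_labels": ["positive", "negative"],
#  "hypothesis_template": "The sentiment of this text is {}.",
#  "multi_label": false}
# {"task": "few_shot_classification", "sequence": "I loved this movie."}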
gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox()
)
gradio_interface.launch()
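
# Quick local check (a sketch, not part of this file -- assumes gradio_client
# is installed and the app is running on Gradio's default port):
# from gradio_client import Client
# client = Client("http://127.0.0.1:7860")
# print(client.predict('{"task": "few_shot_classification", "sequence": "Great!"}'))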