# Lord-Raven
# Experimenting with few-shot classification.
# 46aa75d
# raw
# history blame
# 3.64 kB
import gradio
import json
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.onnxruntime import ORTModelForFeatureExtraction
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from setfit import SetFitModel
from setfit.exporters.utils import mean_pooling
# CORS configuration: only the two front-end origins listed below are allowed
# to make credentialed cross-origin requests; every method and header is
# accepted from them.
# NOTE(review): nothing visible in this file hands `app` to Gradio (the
# interface is started through its own `launch()` call), so this middleware
# may not apply to the endpoint that is actually served -- confirm.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
        "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    ],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class OnnxSetFitModel:
    """Combine an ONNX Runtime encoder, its tokenizer and a classification head.

    The three constructor arguments are stored unchanged:

    * ``ort_model`` -- called with the tokenized batch; its output is indexed
      with ``"last_hidden_state"`` and its ``device`` attribute decides where
      the batch is moved.
    * ``tokenizer`` -- callable that turns the raw inputs into PyTorch tensors.
    * ``model_head`` -- object whose ``predict`` receives the pooled embeddings.
    """

    def __init__(self, ort_model, tokenizer, model_head):
        self.model_head = model_head
        self.tokenizer = tokenizer
        self.ort_model = ort_model

    def predict(self, inputs):
        """Return ``model_head.predict`` applied to mean-pooled embeddings.

        ``inputs`` is handed straight to the tokenizer (with padding and
        truncation enabled), so it may be whatever the tokenizer accepts.
        """
        batch = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        batch = batch.to(self.ort_model.device)

        model_output = self.ort_model(**batch)
        token_embeddings = model_output["last_hidden_state"]
        # Pool over tokens using the attention mask produced by the tokenizer.
        pooled = mean_pooling(token_embeddings, batch["attention_mask"])

        # The head is given the embeddings after they are moved to the CPU.
        return self.model_head.predict(pooled.cpu())

    def __call__(self, inputs):
        """Make the instance callable; delegates to :meth:`predict`."""
        return self.predict(inputs)
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
# --- Zero-shot classifier --------------------------------------------------
# The ONNX weights come from `model_name` (file `file_name`), while the
# tokenizer is loaded from the separate `tokenizer_name` checkpoint.
model_name = "xenova/nli-deberta-v3-small"
file_name = "onnx/model_quantized.onnx"
tokenizer_name = "cross-encoder/nli-deberta-v3-small"
model = ORTModelForSequenceClassification.from_pretrained(
    model_name,
    file_name=file_name,
)
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name,
    model_max_length=512,
)
classifier = pipeline(
    task="zero-shot-classification",
    model=model,
    tokenizer=tokenizer,
)

# --- Few-shot classifier ---------------------------------------------------
# Tokenizer and ONNX encoder are loaded from "BAAI/bge-small-en-v1.5"; only
# the classification head is taken from the SetFit checkpoint.
# NOTE(review): the head comes from a different checkpoint than the encoder
# weights -- confirm the head was trained on embeddings compatible with the
# base encoder used here.
few_shot_tokenizer = AutoTokenizer.from_pretrained(
    "BAAI/bge-small-en-v1.5",
    model_max_length=512,
)
ort_model = ORTModelForFeatureExtraction.from_pretrained(
    "BAAI/bge-small-en-v1.5",
    file_name="onnx/model.onnx",
)
few_shot_model = SetFitModel.from_pretrained(
    "moshew/bge-small-en-v1.5_setfit-sst2-english"
)
onnx_few_shot_model = OnnxSetFitModel(
    ort_model,
    few_shot_tokenizer,
    few_shot_model.model_head,
)
# Origins allowed to call `classify`; built once instead of on every request.
_ALLOWED_ORIGINS = frozenset(
    (
        "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
        "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
        "https://ravenok-statosphere-backend.hf.space",
    )
)


def classify(data_string, request: gradio.Request):
    """Dispatch a JSON request to the few-shot or zero-shot classifier.

    Args:
        data_string: JSON text. When it contains ``"task"`` equal to
            ``"few_shot_classification"`` the few-shot path is used;
            otherwise the zero-shot path is used.
        request: The incoming request. When it is truthy, its ``origin``
            header must be one of the allowed origins.

    Returns:
        The JSON string produced by the selected classifier, or ``"{}"`` when
        the request's origin is missing or not allowed.

    Raises:
        json.JSONDecodeError: If ``data_string`` is not valid JSON.
    """
    if request:
        # A request without an Origin header used to raise KeyError here;
        # treat it the same as a request from a disallowed origin.
        try:
            origin = request.headers["origin"]
        except KeyError:
            origin = None
        if origin not in _ALLOWED_ORIGINS:
            return "{}"

    data = json.loads(data_string)
    if 'task' in data and data['task'] == 'few_shot_classification':
        return few_shot_classification(data)
    return zero_shot_classification(data)
def zero_shot_classification(data):
    """Run the zero-shot pipeline on a request payload and return JSON text.

    Args:
        data: Mapping that must contain ``sequence`` and ``candidate_labels``.
            ``hypothesis_template`` and ``multi_label`` are forwarded to the
            pipeline only when present, so a payload that omits them falls
            back to the pipeline's own defaults instead of raising
            ``KeyError``.

    Returns:
        ``json.dumps`` of whatever the pipeline returns.

    Raises:
        KeyError: If ``sequence`` or ``candidate_labels`` is missing.
    """
    optional_keys = ("hypothesis_template", "multi_label")
    overrides = {key: data[key] for key in optional_keys if key in data}
    results = classifier(
        data["sequence"],
        candidate_labels=data["candidate_labels"],
        **overrides,
    )
    return json.dumps(results)
def few_shot_classification(data):
    """Run the few-shot model on ``data['sequence']`` and return JSON text.

    Args:
        data: Mapping that must contain ``sequence``, which is passed
            unchanged to the few-shot model.

    Returns:
        ``json.dumps`` of the model's predictions.

    Raises:
        KeyError: If ``sequence`` is missing.
    """
    predictions = onnx_few_shot_model(data["sequence"])
    # Array-like results (NumPy arrays, tensors) are not JSON serializable and
    # would make json.dumps raise TypeError; convert them to plain lists.
    if hasattr(predictions, "tolist"):
        predictions = predictions.tolist()
    return json.dumps(predictions)
# One text box in, one text box out: `classify` receives the text entered in
# the "JSON Input" box and its return value is shown in the output box.
gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)

# Start serving as soon as this module is executed.
gradio_interface.launch()