import gradio
import json
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTModelForFeatureExtraction
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from setfit import SetFitModel, Trainer, TrainingArguments, get_templated_dataset
from setfit.exporters.utils import mean_pooling
from datasets import load_dataset, Dataset

# CORS config: only the known front ends may call this Space.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
        "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class OnnxSetFitModel:
    """A SetFit model whose body runs on ONNX Runtime: the ORT encoder produces
    token embeddings, which are mean-pooled and fed to the classification head."""

    def __init__(self, ort_model, tokenizer, model_head):
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    def predict(self, inputs):
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        ).to(self.ort_model.device)
        outputs = self.ort_model(**encoded_inputs)
        embeddings = mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )
        return self.model_head.predict(embeddings.cpu())

    def __call__(self, inputs):
        return self.predict(inputs)


# Zero-shot model candidates tried so far:
# "xenova/mobilebert-uncased-mnli" / "typeform/mobilebert-uncased-mnli": fast but small; same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli": not impressed
# "Xenova/bart-large-mnli": a bit slow
# "Xenova/distilbert-base-uncased-mnli" / "typeform/distilbert-base-uncased-mnli": bad answers
# "Xenova/deBERTa-v3-base-mnli" / "MoritzLaurer/DeBERTa-v3-base-mnli": still a bit slow and not great answers
model_name = "xenova/nli-deberta-v3-small"
file_name = "onnx/model_quantized.onnx"
tokenizer_name = "cross-encoder/nli-deberta-v3-small"

# Zero-shot classification pipeline backed by the quantized ONNX NLI model.
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)

# Few-shot setup: a bge-small encoder (ONNX) plus a SetFit classification head.
few_shot_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5", model_max_length=512)
ort_model = ORTModelForFeatureExtraction.from_pretrained("BAAI/bge-small-en-v1.5", file_name="onnx/model.onnx")
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")

# Train the SetFit head on a templated true/false dataset ("This statement is true/false").
candidate_labels = ["true", "false"]
reference_dataset = load_dataset("emotion")
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(
    dummy_dataset,
    candidate_labels=candidate_labels,
    sample_size=8,
    template="This statement is {}",
)
args = TrainingArguments(
    batch_size=32,
    num_epochs=1,
)
trainer = Trainer(
    model=few_shot_model,
    args=args,
    train_dataset=train_dataset,
    # The emotion test split only gives trainer.evaluate() something to score; its
    # labels don't match the two-class head, so treat the printed metric as a smoke test.
    eval_dataset=reference_dataset["test"],
)
trainer.train()
metrics = trainer.evaluate()
print(metrics)

onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)


def classify(data_string, request: gradio.Request):
    # Reject calls from origins other than the known front ends.
    if request:
        if request.headers.get("origin") not in [
            "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
            "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
            "https://ravenok-statosphere-backend.hf.space",
        ]:
            return "{}"
    data = json.loads(data_string)
    # Route on the optional "task" key; zero-shot is the default.
    if data.get("task") == "few_shot_classification":
        return few_shot_classification(data)
    else:
        return zero_shot_classification(data)


def zero_shot_classification(data):
    results = classifier(
        data["sequence"],
        candidate_labels=data["candidate_labels"],
        hypothesis_template=data["hypothesis_template"],
        multi_label=data["multi_label"],
    )
    return json.dumps(results)


def few_shot_classification(data):
    results = onnx_few_shot_model(data["sequence"])
    # The head predicts label indices; map them back to candidate_labels for logging.
    print([candidate_labels[idx] for idx in results])
    return json.dumps(results.tolist())


gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)
gradio_interface.launch()
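
# --- Example payloads (illustrative sketch; the example strings and labels below are
# made up, but the field names mirror exactly what classify() reads above) ---
#
# Zero-shot request body:
#   {
#     "sequence": "The knight sheathed his sword and walked away.",
#     "candidate_labels": ["combat", "dialogue", "travel"],
#     "hypothesis_template": "This text is about {}.",
#     "multi_label": false
#   }
#
# Few-shot request body (routed by the "task" key; the response is a list of
# predicted indices into candidate_labels, i.e. 0 = "true", 1 = "false"):
#   {"task": "few_shot_classification", "sequence": "The sky is green."}
#
# A local smoke test could look like:
#   print(classify(json.dumps({
#       "sequence": "It is raining.",
#       "candidate_labels": ["wet", "dry"],
#       "hypothesis_template": "The ground is {}.",
#       "multi_label": False,
#   }), None))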