import gradio
import json
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenceClassification
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from setfit import SetFitModel, Trainer, TrainingArguments, get_templated_dataset
from setfit.exporters.utils import mean_pooling
from datasets import load_dataset, Dataset
# CORS config: only allow the known front ends to call this Space.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://lord-raven.github.io"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class OnnxSetFitModel:
    """SetFit-style wrapper that pairs an ONNX Runtime encoder with a trained model head."""

    def __init__(self, ort_model, tokenizer, model_head):
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    def _embed(self, inputs):
        # Tokenize, run the ONNX encoder, and mean-pool token embeddings
        # into one sentence embedding per input.
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        ).to(self.ort_model.device)
        outputs = self.ort_model(**encoded_inputs)
        return mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )

    def predict(self, inputs):
        return self.model_head.predict(self._embed(inputs).cpu())

    def predict_proba(self, inputs):
        return self.model_head.predict_proba(self._embed(inputs).cpu())

    def __call__(self, inputs):
        return self.predict(inputs)
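# Illustrative usage, mirroring how the wrapper is constructed further below:
#   model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
#   probs = model.predict_proba(["Some passage.\nThe CONCLUSION is supported by the PASSAGE."])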
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
# "xenova/nli-deberta-v3-small" "cross-encoder/nli-deberta-v3-small" Was using this for a good while and it was...okay
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
file_name = "onnx/model.onnx"
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
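# Illustrative zero-shot call; the transformers pipeline returns a dict with
# "sequence", "labels" (sorted by score, descending), and "scores":
#   classifier("I ate a taco.", candidate_labels=["food", "travel"],
#              hypothesis_template="This message is about {}.", multi_label=False)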
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")
# Train few_shot_model on a synthetic, templated dataset so its head predicts
# whether a conclusion is "supported" or "refuted" by a passage.
candidate_labels = ["supported", "refuted"]
reference_dataset = load_dataset("SetFit/sst2")
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="The CONCLUSION is {} by the PASSAGE.")
args = TrainingArguments(
    batch_size=32,
    num_epochs=1
)
trainer = Trainer(
    model=few_shot_model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)
trainer.train()
# metrics = trainer.evaluate()
# print(metrics)
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
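# onnx_few_shot_model.predict_proba(list_of_texts) returns one probability row
# per text, with columns ordered by label id (0 = "supported", 1 = "refuted"),
# assuming get_templated_dataset assigns label ids in candidate_labels order.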
def classify(data_string, request: gradio.Request):
    # Reject requests from origins other than the known front ends. Use .get()
    # so a request with no Origin header is rejected rather than raising KeyError.
    if request:
        if request.headers.get("origin") not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
            return "{}"
    data = json.loads(data_string)
    if 'task' in data and data['task'] == 'few_shot_classification':
        return few_shot_classification(data)
    else:
        return zero_shot_classification(data)
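# Example zero-shot payload (illustrative values; omit "task", or set it to
# anything other than "few_shot_classification", to take the zero-shot path):
#   {"sequence": "...", "candidate_labels": ["happy", "sad"],
#    "hypothesis_template": "The speaker is {}.", "multi_label": false}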
def zero_shot_classification(data):
    results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
    return json.dumps(results)
def create_sequences(data):
    # Pair the passage with each formatted hypothesis: one sequence per candidate label.
    # return ['###Given:\n' + data['sequence'] + '\n###End Given\n###Hypothesis:\n' + data['hypothesis_template'].format(label) + "\n###End Hypothesis" for label in data['candidate_labels']]
    return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
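# Illustrative output for candidate_labels ["supported", "refuted"] and
# hypothesis_template "The CONCLUSION is {} by the PASSAGE.":
#   ["<sequence>\nThe CONCLUSION is supported by the PASSAGE.",
#    "<sequence>\nThe CONCLUSION is refuted by the PASSAGE."]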
def few_shot_classification(data):
    sequences = create_sequences(data)
    print(sequences)
    # results = onnx_few_shot_model(sequences)
    probs = onnx_few_shot_model.predict_proba(sequences)
    # Column 0 of each probability row is the "supported" class (label 0 above).
    scores = [prob[0] for prob in probs]
    composite = list(zip(scores, data['candidate_labels']))
    composite = sorted(composite, key=lambda x: x[0], reverse=True)
    # Each composite entry is (score, label), so unpack in that order.
    scores, labels = zip(*composite)
    response_dict = {'scores': scores, 'labels': labels}
    print(response_dict)
    return json.dumps(response_dict)
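# Illustrative response for candidate_labels ["supported", "refuted"]
# (made-up scores, sorted descending):
#   {"scores": [0.83, 0.17], "labels": ["supported", "refuted"]}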
gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox()
)
gradio_interface.launch()
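# Note: as written, gradio_interface.launch() starts Gradio's own server, so the
# FastAPI `app` (and its CORS middleware) is never actually served. If the
# middleware is meant to apply, one option is to mount the interface instead:
#   app = gradio.mount_gradio_app(app, gradio_interface, path="/")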