Spaces:
Running
on
Zero
Running
on
Zero
Lord-Raven
committed on
Commit
·
50b814c
1
Parent(s):
09439d2
Experimenting with few-shot classification.
Browse files- app.py +40 -2
- requirements.txt +1 -0
app.py
CHANGED
@@ -6,6 +6,7 @@ from transformers import pipeline
|
|
6 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
7 |
from fastapi import FastAPI
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
9 |
|
10 |
# CORS Config
|
11 |
app = FastAPI()
|
@@ -18,6 +19,26 @@ app.add_middleware(
|
|
18 |
allow_headers=["*"],
|
19 |
)
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
22 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
23 |
# "Xenova/bart-large-mnli" A bit slow
|
@@ -29,22 +50,39 @@ tokenizer_name = "cross-encoder/nli-deberta-v3-small"
|
|
29 |
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
|
30 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
# file = cached_download("https://huggingface.co/" + model_name + "")
|
33 |
# sess = InferenceSession(file)
|
34 |
|
35 |
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
|
36 |
|
37 |
-
def
|
38 |
if request:
|
39 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
|
40 |
return "{}"
|
41 |
data = json.loads(data_string)
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
|
43 |
response_string = json.dumps(results)
|
44 |
return response_string
|
45 |
|
|
|
|
|
|
|
|
|
|
|
46 |
gradio_interface = gradio.Interface(
|
47 |
-
fn =
|
48 |
inputs = gradio.Textbox(label="JSON Input"),
|
49 |
outputs = gradio.Textbox()
|
50 |
)
|
|
|
6 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
7 |
from fastapi import FastAPI
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
9 |
+
from setfit import SetFitModel
|
10 |
|
11 |
# CORS Config
|
12 |
app = FastAPI()
|
|
|
19 |
allow_headers=["*"],
|
20 |
)
|
21 |
|
22 |
+
class OnnxSetFitModel:
    """Run SetFit-style inference with an ONNX Runtime encoder and a model head.

    Text is tokenized, embedded by ``ort_model``, mean-pooled over the tokens
    selected by the attention mask, and the pooled embeddings are passed to
    ``model_head.predict``.
    """

    def __init__(self, ort_model, tokenizer, model_head):
        """Store the collaborators used by :meth:`predict`.

        Args:
            ort_model: Callable encoder exposing a ``device`` attribute and
                returning a mapping with a ``"last_hidden_state"`` entry.
            tokenizer: Callable turning text into model inputs that include
                an ``"attention_mask"`` entry.
            model_head: Object whose ``predict`` maps embeddings to labels.
        """
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    @staticmethod
    def _mean_pooling(token_embeddings, attention_mask):
        """Average token embeddings over positions where the mask is non-zero."""
        # NOTE(review): the original called a free function `mean_pooling`
        # that is not defined in the visible part of the file; this helper
        # keeps the class self-contained.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = (token_embeddings * mask).sum(dim=1)
        # Clamp so a fully masked row cannot cause a division by zero.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def predict(self, inputs):
        """Return ``model_head`` predictions for ``inputs``.

        Args:
            inputs: Text (or list of texts) accepted by the tokenizer.

        Returns:
            Whatever ``model_head.predict`` returns for the pooled embeddings.
        """
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        ).to(self.ort_model.device)

        outputs = self.ort_model(**encoded_inputs)
        embeddings = self._mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )
        # The head receives CPU tensors regardless of the encoder's device.
        return self.model_head.predict(embeddings.cpu())

    def __call__(self, inputs):
        """Alias for :meth:`predict` so the instance can be used as a function."""
        return self.predict(inputs)
|
41 |
+
|
42 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
43 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
44 |
# "Xenova/bart-large-mnli" A bit slow
|
|
|
50 |
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
|
51 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
52 |
|
53 |
+
# ORTModelForFeatureExtraction is not among the optimum imports visible at the
# top of the file, so it is imported next to its only use.
from optimum.onnxruntime import ORTModelForFeatureExtraction

# Few-shot (SetFit) classifier: the SetFit checkpoint supplies the trained
# head, while the encoder and its tokenizer are loaded from a local ONNX export.
few_shot_model_name = "moshew/bge-small-en-v1.5_setfit-sst2-english"
# The class imported above is `SetFitModel` (capital S).
few_shot_model = SetFitModel.from_pretrained(few_shot_model_name)
# NOTE(review): the tokenizer and ONNX model now use the same directory name,
# 'bge_auto_opt_O4' (letter O); the original tokenizer path used a zero.
# Confirm the export directory on disk.
few_shot_tokenizer = AutoTokenizer.from_pretrained('bge_auto_opt_O4', model_max_length=512)
ort_model = ORTModelForFeatureExtraction.from_pretrained('bge_auto_opt_O4')
# Pair the ONNX encoder with the few-shot tokenizer and the SetFit head; the
# zero-shot `tokenizer` / `model` objects belong to a different model.
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
58 |
+
|
59 |
# file = cached_download("https://huggingface.co/" + model_name + "")
|
60 |
# sess = InferenceSession(file)
|
61 |
|
62 |
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
|
63 |
|
64 |
+
def classify(data_string, request: gradio.Request):
    """Dispatch a JSON classification request to the matching classifier.

    Args:
        data_string: JSON document. A ``"task"`` value of
            ``"few_shot_classification"`` selects the few-shot path; any other
            value, or no ``"task"`` key, selects zero-shot classification.
        request: Incoming request; when present its ``origin`` header must be
            one of the allowed front-end origins.

    Returns:
        The JSON string produced by the selected classifier, or ``"{}"`` when
        the request origin is not allowed.

    Raises:
        json.JSONDecodeError: If ``data_string`` is not valid JSON.
    """
    if request:
        # .get() makes a request without an Origin header get rejected rather
        # than raising KeyError.
        if request.headers.get("origin") not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
            return "{}"
    data = json.loads(data_string)
    # Payloads without a "task" key fall through to zero-shot, so existing
    # clients keep working.
    if data.get('task') == 'few_shot_classification':
        return few_shot_classification(data)
    return zero_shot_classification(data)
|
73 |
+
|
74 |
+
def zero_shot_classification(data):
    """Run the zero-shot pipeline on ``data`` and return its result as JSON.

    ``data`` must provide ``sequence``, ``candidate_labels``,
    ``hypothesis_template`` and ``multi_label``; a missing key raises
    ``KeyError``.
    """
    sequence = data['sequence']
    options = {
        'candidate_labels': data['candidate_labels'],
        'hypothesis_template': data['hypothesis_template'],
        'multi_label': data['multi_label'],
    }
    return json.dumps(classifier(sequence, **options))
|
78 |
|
79 |
+
def few_shot_classification(data):
    """Classify ``data['sequence']`` with the few-shot model and return JSON.

    Args:
        data: Parsed request payload; must contain a ``"sequence"`` entry.

    Returns:
        JSON string of the model's predictions.

    Raises:
        KeyError: If ``data`` has no ``"sequence"`` entry.
    """
    results = onnx_few_shot_model(data['sequence'])
    # json.dumps cannot encode array/tensor objects. The head's predict()
    # presumably returns one (TODO confirm), so convert to plain lists first.
    if hasattr(results, "tolist"):
        results = results.tolist()
    return json.dumps(results)
|
83 |
+
|
84 |
# Gradio front end: one JSON text box in, one text box out, handled by classify.
gradio_interface = gradio.Interface(
    classify,
    gradio.Textbox(label="JSON Input"),
    gradio.Textbox(),
)
|
requirements.txt
CHANGED
@@ -2,5 +2,6 @@ fastapi==0.88.0
|
|
2 |
json5==0.9.10
|
3 |
numpy==1.23.4
|
4 |
optimum[exporters,onnxruntime]==1.21.3
|
|
|
5 |
torch==1.12.1
|
6 |
torchvision==0.13.1
|
|
|
2 |
json5==0.9.10
|
3 |
numpy==1.23.4
|
4 |
optimum[exporters,onnxruntime]==1.21.3
|
5 |
+
setfit==1.0.3
|
6 |
torch==1.12.1
|
7 |
torchvision==0.13.1
|