Spaces:
Running
on
Zero
Running
on
Zero
Lord-Raven
commited on
Commit
·
fbcdba4
1
Parent(s):
bd9a53f
Experimenting with few-shot classification.
Browse files
app.py
CHANGED
@@ -7,8 +7,10 @@ from optimum.onnxruntime import ORTModelForSequenceClassification
|
|
7 |
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
8 |
from fastapi import FastAPI
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
-
from setfit import SetFitModel
|
11 |
from setfit.exporters.utils import mean_pooling
|
|
|
|
|
12 |
|
13 |
# CORS Config
|
14 |
app = FastAPI()
|
@@ -55,9 +57,36 @@ classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=to
|
|
55 |
|
56 |
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
|
57 |
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
|
58 |
-
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
60 |
|
|
|
|
|
61 |
def classify(data_string, request: gradio.Request):
|
62 |
if request:
|
63 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
|
@@ -75,6 +104,7 @@ def zero_shot_classification(data):
|
|
75 |
|
76 |
def few_shot_classification(data):
|
77 |
results = onnx_few_shot_model(data['sequence'])
|
|
|
78 |
response_string = json.dumps(results.tolist())
|
79 |
return response_string
|
80 |
|
|
|
7 |
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
8 |
from fastapi import FastAPI
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
+
from setfit import SetFitModel, Trainer, TrainingArguments
|
11 |
from setfit.exporters.utils import mean_pooling
|
12 |
+
from setfit import get_templated_dataset
|
13 |
+
from datasets import load_dataset
|
14 |
|
15 |
# CORS Config
|
16 |
app = FastAPI()
|
|
|
57 |
|
58 |
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
|
59 |
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
|
60 |
+
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english", multi_target_strategy="multi-output")
|
61 |
+
|
62 |
+
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
|
63 |
+
print(test_dataset)
|
64 |
+
classes = test_dataset.features["label"].names
|
65 |
+
print(classes)
|
66 |
+
train_dataset = get_templated_dataset()
|
67 |
+
print(train_dataset)
|
68 |
+
print(train_dataset[0])
|
69 |
+
|
70 |
+
args = TrainingArguments(
|
71 |
+
batch_size=32,
|
72 |
+
num_epochs=1
|
73 |
+
)
|
74 |
+
|
75 |
+
trainer = Trainer(
|
76 |
+
model=few_shot_model,
|
77 |
+
args=args,
|
78 |
+
train_dataset=train_dataset,
|
79 |
+
eval_dataset=test_dataset
|
80 |
+
)
|
81 |
+
trainer.train()
|
82 |
+
|
83 |
+
metrics = trainer.evaluate()
|
84 |
+
print(metrics)
|
85 |
+
|
86 |
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
87 |
|
88 |
+
|
89 |
+
|
90 |
def classify(data_string, request: gradio.Request):
|
91 |
if request:
|
92 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
|
|
|
104 |
|
105 |
def few_shot_classification(data):
|
106 |
results = onnx_few_shot_model(data['sequence'])
|
107 |
+
print([classes[idx] for idx in results)
|
108 |
response_string = json.dumps(results.tolist())
|
109 |
return response_string
|
110 |
|