Lord-Raven commited on
Commit
fbcdba4
·
1 Parent(s): bd9a53f

Experimenting with few-shot classification.

Browse files
Files changed (1) hide show
  1. app.py +32 -2
app.py CHANGED
@@ -7,8 +7,10 @@ from optimum.onnxruntime import ORTModelForSequenceClassification
7
  from optimum.onnxruntime import ORTModelForFeatureExtraction
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
10
- from setfit import SetFitModel
11
  from setfit.exporters.utils import mean_pooling
 
 
12
 
13
  # CORS Config
14
  app = FastAPI()
@@ -55,9 +57,36 @@ classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=to
55
 
56
  few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
57
  ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
58
- few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
60
 
 
 
61
  def classify(data_string, request: gradio.Request):
62
  if request:
63
  if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
@@ -75,6 +104,7 @@ def zero_shot_classification(data):
75
 
76
  def few_shot_classification(data):
77
  results = onnx_few_shot_model(data['sequence'])
 
78
  response_string = json.dumps(results.tolist())
79
  return response_string
80
 
 
7
  from optimum.onnxruntime import ORTModelForFeatureExtraction
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
10
+ from setfit import SetFitModel, Trainer, TrainingArguments
11
  from setfit.exporters.utils import mean_pooling
12
+ from setfit import get_templated_dataset
13
+ from datasets import load_dataset
14
 
15
  # CORS Config
16
  app = FastAPI()
 
57
 
58
  few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
59
  ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
60
+ few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english", multi_target_strategy="multi-output")
61
+
62
+ test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
63
+ print(test_dataset)
64
+ classes = test_dataset.features["label"].names
65
+ print(classes)
66
+ train_dataset = get_templated_dataset()
67
+ print(train_dataset)
68
+ print(train_dataset[0])
69
+
70
+ args = TrainingArguments(
71
+ batch_size=32,
72
+ num_epochs=1
73
+ )
74
+
75
+ trainer = Trainer(
76
+ model=few_shot_model,
77
+ args=args,
78
+ train_dataset=train_dataset,
79
+ eval_dataset=test_dataset
80
+ )
81
+ trainer.train()
82
+
83
+ metrics = trainer.evaluate()
84
+ print(metrics)
85
+
86
  onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
87
 
88
+
89
+
90
  def classify(data_string, request: gradio.Request):
91
  if request:
92
  if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
 
104
 
105
  def few_shot_classification(data):
106
  results = onnx_few_shot_model(data['sequence'])
107
+ print([classes[idx] for idx in results)
108
  response_string = json.dumps(results.tolist())
109
  return response_string
110