Lord-Raven committed on
Commit
50b814c
·
1 Parent(s): 09439d2

Experimenting with few-shot classification.

Browse files
Files changed (2) hide show
  1. app.py +40 -2
  2. requirements.txt +1 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import pipeline
6
  from optimum.onnxruntime import ORTModelForSequenceClassification
7
  from fastapi import FastAPI
8
  from fastapi.middleware.cors import CORSMiddleware
 
9
 
10
  # CORS Config
11
  app = FastAPI()
@@ -18,6 +19,26 @@ app.add_middleware(
18
  allow_headers=["*"],
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
22
  # "xenova/deberta-v3-base-tasksource-nli" Not impressed
23
  # "Xenova/bart-large-mnli" A bit slow
@@ -29,22 +50,39 @@ tokenizer_name = "cross-encoder/nli-deberta-v3-small"
29
  model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
30
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
31
 
 
 
 
 
 
 
32
  # file = cached_download("https://huggingface.co/" + model_name + "")
33
  # sess = InferenceSession(file)
34
 
35
  classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
36
 
37
- def zero_shot_classification(data_string, request: gradio.Request):
38
  if request:
39
  if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
40
  return "{}"
41
  data = json.loads(data_string)
 
 
 
 
 
 
42
  results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
43
  response_string = json.dumps(results)
44
  return response_string
45
 
 
 
 
 
 
46
  gradio_interface = gradio.Interface(
47
- fn = zero_shot_classification,
48
  inputs = gradio.Textbox(label="JSON Input"),
49
  outputs = gradio.Textbox()
50
  )
 
6
  from optimum.onnxruntime import ORTModelForSequenceClassification
7
  from fastapi import FastAPI
8
  from fastapi.middleware.cors import CORSMiddleware
9
+ from setfit import SetFitModel
10
 
11
  # CORS Config
12
  app = FastAPI()
 
19
  allow_headers=["*"],
20
  )
21
 
22
class OnnxSetFitModel:
    """SetFit-style classifier whose transformer body runs under ONNX Runtime.

    Combines an exported ONNX feature-extraction model, its matching
    tokenizer, and a fitted SetFit classification head into one callable:
    text in -> head predictions out.
    """

    def __init__(self, ort_model, tokenizer, model_head):
        # ort_model: ONNX feature-extraction model whose outputs expose
        #   "last_hidden_state"; must also provide a .device attribute.
        # tokenizer: HF tokenizer matching ort_model.
        # model_head: fitted SetFit head exposing .predict(embeddings).
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    @staticmethod
    def _mean_pooling(last_hidden_state, attention_mask):
        # Standard sentence-embedding mean pooling: average the token
        # embeddings, ignoring padding positions via the attention mask.
        # NOTE(review): the original called an undefined `mean_pooling`
        # helper (NameError at runtime); defined here so the class is
        # self-contained.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid divide-by-zero on all-pad rows
        return summed / counts

    def predict(self, inputs):
        """Return head predictions for a string or list of strings."""
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        ).to(self.ort_model.device)

        outputs = self.ort_model(**encoded_inputs)
        embeddings = self._mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )
        # Head (e.g. sklearn LogisticRegression) expects CPU tensors/arrays.
        return self.model_head.predict(embeddings.cpu())

    def __call__(self, inputs):
        return self.predict(inputs)
42
  # "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
43
  # "xenova/deberta-v3-base-tasksource-nli" Not impressed
44
  # "Xenova/bart-large-mnli" A bit slow
 
50
  model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
51
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
52
 
53
# --- Few-shot (SetFit) model setup -------------------------------------------
# Local import: only this section needs the feature-extraction ORT class
# (the top of the file imports ORTModelForSequenceClassification only, so the
# original raised NameError here).
from optimum.onnxruntime import ORTModelForFeatureExtraction

few_shot_model_name = "moshew/bge-small-en-v1.5_setfit-sst2-english"
# Fitted SetFit model; only its classification head is used for inference.
# (Original said `setFitModel.from_pretrained` — NameError; the import is
# `SetFitModel`.)
few_shot_model = SetFitModel.from_pretrained(few_shot_model_name)
# NOTE(review): the original mixed 'bge_auto_opt_04' (zero) and
# 'bge_auto_opt_O4' (letter O). Normalized to 'O4' (optimum's O4 optimization
# level naming) — confirm this matches the actual local directory name.
few_shot_tokenizer = AutoTokenizer.from_pretrained('bge_auto_opt_O4', model_max_length=512)
ort_model = ORTModelForFeatureExtraction.from_pretrained('bge_auto_opt_O4')
# Wire up the few-shot pieces. The original passed the zero-shot `tokenizer`
# and `model.model_head`, but `model` is an ORTModelForSequenceClassification
# and has no model_head; the head belongs to the SetFit model loaded above.
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
58
+
59
  # file = cached_download("https://huggingface.co/" + model_name + "")
60
  # sess = InferenceSession(file)
61
 
62
  classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
63
 
64
def classify(data_string, request: gradio.Request):
    """Gradio entry point: route a JSON request to the appropriate classifier.

    data_string: JSON text with the task's fields ('sequence', etc.) and an
        optional 'task' key selecting 'few_shot_classification'; anything
        else (or a missing key) falls through to zero-shot.
    request: incoming Gradio request, used for an origin allow-list check.
    Returns a JSON string; "{}" for disallowed origins.
    """
    if request:
        # Only serve the known front-end origins.
        if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
            return "{}"
    data = json.loads(data_string)
    # Original had syntax errors (missing colons on if/else), returned the
    # handler *functions* instead of calling them, and used data['task']
    # which raises KeyError for zero-shot payloads without a 'task' key.
    if data.get('task') == 'few_shot_classification':
        return few_shot_classification(data)
    else:
        return zero_shot_classification(data)
73
+
74
def zero_shot_classification(data):
    """Run the zero-shot pipeline on a parsed request and return JSON text."""
    labels = data['candidate_labels']
    template = data['hypothesis_template']
    multi = data['multi_label']
    results = classifier(
        data['sequence'],
        candidate_labels=labels,
        hypothesis_template=template,
        multi_label=multi,
    )
    return json.dumps(results)
78
 
79
def few_shot_classification(data):
    """Classify the request's sequence with the ONNX SetFit model, as JSON text."""
    predictions = onnx_few_shot_model(data['sequence'])
    # NOTE(review): if the head returns a numpy array/tensor, json.dumps will
    # raise — verify the head's predict() output is JSON-serializable.
    return json.dumps(predictions)
83
+
84
# Gradio UI wiring: a single JSON-in / JSON-out text box backed by classify().
gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)
requirements.txt CHANGED
@@ -2,5 +2,6 @@ fastapi==0.88.0
2
  json5==0.9.10
3
  numpy==1.23.4
4
  optimum[exporters,onnxruntime]==1.21.3
 
5
  torch==1.12.1
6
  torchvision==0.13.1
 
2
  json5==0.9.10
3
  numpy==1.23.4
4
  optimum[exporters,onnxruntime]==1.21.3
5
+ setfit==1.0.3
6
  torch==1.12.1
7
  torchvision==0.13.1