File size: 5,813 Bytes
93643d5
040c521
b0d2a02
e83c60c
3554a8b
bd9482b
f8672fc
fd79eb2
 
a20be7a
46aa75d
fbcdba4
54b3e74
fd79eb2
 
 
 
 
 
a13235b
fd79eb2
 
 
 
b0d2a02
50b814c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61e4431
 
 
 
 
 
 
 
 
 
 
50b814c
 
 
30d670a
6c40a85
dc02763
cfd4b0d
f4c9eb8
39080c2
55d9b22
c44bdaf
55d9b22
1781106
fd25b82
7168d3b
b0d2a02
33afb89
 
 
fbcdba4
523f8fd
f0e8f06
c44bdaf
6df156f
f0e8f06
fbcdba4
 
 
 
16e4efd
4b7b387
4d2ee65
6df156f
 
fbcdba4
 
 
523f8fd
 
fbcdba4
7168d3b
b0d2a02
fbcdba4
 
50b814c
a822923
54d574f
f3bcef9
47a0109
9231215
 
 
 
50b814c
 
9704577
0686401
 
5071704
8ec85f2
cdb9220
d2e06fa
8ec85f2
50b814c
8ec85f2
 
46a3862
ef3a388
cdb9220
fb2ea03
 
 
 
 
 
 
5a3b82a
cdb9220
50b814c
 
93643d5
50b814c
daac94f
0686401
93643d5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio
import json
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.onnxruntime import ORTModelForFeatureExtraction
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from setfit import SetFitModel, SetFitTrainer, Trainer, TrainingArguments
from setfit.exporters.utils import mean_pooling
from setfit import get_templated_dataset
from datasets import load_dataset, Dataset

# CORS Config: browser origins allowed to call the FastAPI app.
# NOTE(review): `app` is not referenced again in this file, and the UI below is
# served by gradio_interface.launch(); confirm this FastAPI app is actually
# mounted/served somewhere, otherwise this CORS policy has no effect.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
        "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
        "https://lord-raven.github.io",
    ],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

class OnnxSetFitModel:
    """SetFit-style classifier whose embedding body runs through ONNX Runtime.

    Texts are tokenized, passed through ``ort_model``, mean-pooled with the
    attention mask, and the pooled embeddings are handed to ``model_head``.
    """

    def __init__(self, ort_model, tokenizer, model_head):
        """Store the ONNX feature extractor, its tokenizer and the head.

        Args:
            ort_model: feature-extraction model whose outputs expose
                ``"last_hidden_state"`` and which has a ``device`` attribute.
            tokenizer: tokenizer matching ``ort_model``.
            model_head: object providing ``predict`` / ``predict_proba`` over
                pooled embeddings.
        """
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    def _embed(self, inputs):
        """Return mean-pooled embeddings for *inputs*, moved to CPU for the head."""
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        ).to(self.ort_model.device)

        outputs = self.ort_model(**encoded_inputs)
        # Pool token embeddings into one vector per text, ignoring padding.
        embeddings = mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )
        return embeddings.cpu()

    def predict(self, inputs):
        """Return the head's predicted labels for *inputs*."""
        return self.model_head.predict(self._embed(inputs))

    def predict_proba(self, inputs):
        """Return the head's class probabilities for *inputs*."""
        return self.model_head.predict_proba(self._embed(inputs))

    def __call__(self, inputs):
        """Alias for :meth:`predict`."""
        return self.predict(inputs)

# Zero-shot NLI models tried so far (kept for reference):
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
# "xenova/nli-deberta-v3-small" "cross-encoder/nli-deberta-v3-small" Was using this for a good while and it was...okay
# Zero-shot classifier: the ONNX file of this model wrapped in the
# "zero-shot-classification" pipeline; used by zero_shot_classification().
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
file_name = "onnx/model.onnx"
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)

# Few-shot classifier parts: tokenizer + ONNX feature extractor for the
# embedding body, and a SetFit model whose head is trained below.
# NOTE(review): `ort_model` is the ONNX file of the base 'BAAI/bge-small-en-v1.5'
# checkpoint, while the head used below belongs to `few_shot_model`, which is
# loaded from a different (setfit-sst2) checkpoint and then trained by
# trainer.train() (which presumably also fine-tunes its body). OnnxSetFitModel
# pairs that head with `ort_model`, so body and head may not match -- confirm,
# and consider exporting the trained SetFit body to ONNX instead.
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")

# Train few_shot_model on template-generated examples only.
# Label order here defines the head's class indices used by few_shot_classification().
candidate_labels = ["supported", "refuted"]
# NOTE(review): only used as eval_dataset, and the evaluate() call below is
# commented out, so this download appears unused (SST-2 is also a sentiment
# dataset, unrelated to the supported/refuted labels) -- confirm before removing.
reference_dataset = load_dataset("SetFit/sst2")
# Empty base dataset: the templated examples are the only training data.
dummy_dataset = Dataset.from_dict({})
# Fill the template with each candidate label; sample_size controls how many
# examples are generated (see setfit get_templated_dataset docs).
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="The CONCLUSION is {} by the PASSAGE.")
args = TrainingArguments(
    batch_size=32,
    num_epochs=1
)
trainer = Trainer(
    model=few_shot_model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)
# Runs at import time, so every startup retrains the model.
trainer.train()

# Evaluation is disabled; re-enable to check the trained model.
# metrics = trainer.evaluate()
# print(metrics)

# Few-shot classifier used by few_shot_classification(): ONNX body + trained SetFit head.
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)



# Origins allowed to call classify(); any other origin gets an empty "{}" reply.
_ALLOWED_ORIGINS = frozenset({
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    "https://ravenok-statosphere-backend.hf.space",
    "https://lord-raven.github.io",
})


def classify(data_string, request: gradio.Request):
    """Dispatch a JSON request to the zero-shot or few-shot classifier.

    Args:
        data_string: JSON text. If its ``task`` field equals
            'few_shot_classification' the few-shot handler runs; otherwise the
            zero-shot handler runs.
        request: the incoming request, if any. When present, its Origin header
            must be in _ALLOWED_ORIGINS.

    Returns:
        The selected handler's JSON response string, or "{}" when the origin is
        not allowed.

    Raises:
        json.JSONDecodeError: if *data_string* is not valid JSON.
    """
    # A missing Origin header used to raise KeyError; it is now rejected like
    # any other origin that is not on the allow-list.
    if request and request.headers.get("origin") not in _ALLOWED_ORIGINS:
        return "{}"
    data = json.loads(data_string)
    if data.get('task') == 'few_shot_classification':
        return few_shot_classification(data)
    return zero_shot_classification(data)

def zero_shot_classification(data):
    """Run the zero-shot NLI pipeline on a parsed request and return JSON.

    Args:
        data: parsed request dict. 'sequence' and 'candidate_labels' are
            required; 'hypothesis_template' and 'multi_label' are optional and,
            when omitted, are not passed, so the pipeline's defaults apply.

    Returns:
        The pipeline result serialized with json.dumps.
    """
    # Forward optional arguments only when the client supplied them; previously
    # omitting either one raised KeyError.
    options = {
        key: data[key]
        for key in ('hypothesis_template', 'multi_label')
        if key in data
    }
    results = classifier(
        data['sequence'],
        candidate_labels=data['candidate_labels'],
        **options,
    )
    return json.dumps(results)

def create_sequences(data):
    """Build one "premise + hypothesis" text per candidate label.

    Each result is ``data['sequence']``, a newline, then
    ``data['hypothesis_template']`` formatted with the label, in the same order
    as ``data['candidate_labels']``.
    """
    labels = data['candidate_labels']

    def _join_with_hypothesis(label):
        return data['sequence'] + '\n' + data['hypothesis_template'].format(label)

    return list(map(_join_with_hypothesis, labels))

def few_shot_classification(data):
    """Score each candidate label with the few-shot model and rank them.

    Each candidate label becomes one input text via create_sequences(); the
    label's score is column 0 of onnx_few_shot_model.predict_proba for that text.
    (Presumably the probability of class 0, "supported" in the module's training
    labels -- confirm against the training setup.)

    Args:
        data: parsed request dict with 'sequence', 'hypothesis_template' and
            'candidate_labels'.

    Returns:
        JSON string ``{"scores": [...], "labels": [...]}`` sorted by descending
        score. It is empty when no candidate labels are given.
    """
    candidate_labels = data['candidate_labels']
    if not candidate_labels:
        # Previously `labels, scores = zip(*[])` raised ValueError here.
        return json.dumps({'scores': [], 'labels': []})

    sequences = create_sequences(data)
    print(sequences)
    probs = onnx_few_shot_model.predict_proba(sequences)
    # float(): plain Python floats serialize with json.dumps regardless of
    # whether the head returns numpy (incl. float32) or torch values.
    label_scores = [float(row[0]) for row in probs]

    ranked = sorted(
        zip(label_scores, candidate_labels), key=lambda pair: pair[0], reverse=True
    )
    scores = [score for score, _ in ranked]
    labels = [label for _, label in ranked]

    response_dict = {'scores': scores, 'labels': labels}
    print(response_dict)
    return json.dumps(response_dict)

# Single-textbox UI: the request JSON goes in, classify()'s JSON string comes out.
gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)
gradio_interface.launch()