Lord-Raven
Trying to use ONNX model.
b0d2a02
raw
history blame
3.99 kB
import json

import gradio
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
)
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    TokenClassificationPipeline,
    ZeroShotClassificationPipeline,
    pipeline,
)
class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline whose model call runs on ONNX Runtime.

    Tokenization and post-processing are inherited from
    ``TokenClassificationPipeline``. ``preprocess`` is overridden to always
    request an attention mask, and ``_forward`` runs an ONNX Runtime
    ``InferenceSession`` instead of the PyTorch model.

    Args:
        session: optional ONNX Runtime ``InferenceSession``. When omitted, the
            module-level ``session`` is looked up at inference time (the
            previous behaviour).
        *args, **kwargs: forwarded unchanged to ``TokenClassificationPipeline``.
    """

    def __init__(self, *args, session=None, **kwargs):
        # Stored before super().__init__ so the attribute always exists.
        self.session = session
        super().__init__(*args, **kwargs)

    def _forward(self, model_inputs):
        """Run the ONNX graph on one tokenized sentence.

        Called by the pipeline machinery, not by users. Pops the bookkeeping
        entries ``preprocess`` added, feeds only the tensors the ONNX graph
        declares as inputs, and returns the logits (as a torch tensor) plus
        the remaining entries for the parent's post-processing.
        """
        # These entries come from ``preprocess`` and are not graph inputs.
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")

        onnx_session = self.session if self.session is not None else session
        graph_inputs = {node.name for node in onnx_session.get_inputs()}
        # ONNX Runtime rejects feed names the graph does not declare, so drop
        # tokenizer outputs (e.g. token_type_ids) the exported model lacks.
        feed = {
            name: tensor.cpu().detach().numpy()
            for name, tensor in model_inputs.items()
            if name in graph_inputs
        }
        # Assumes the first graph output is the logits tensor, as before.
        output_name = onnx_session.get_outputs()[0].name
        logits = onnx_session.run(output_names=[output_name], input_feed=feed)[0]
        return {
            # Converted back to torch to stay compatible with the parent's postprocess.
            "logits": torch.tensor(logits),
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "sentence": sentence,
            **model_inputs,
        }

    def preprocess(self, sentence, offset_mapping=None):
        """Tokenize ``sentence`` for the ONNX graph.

        Differs from the parent only in forcing ``return_attention_mask=True``.
        The exported graph takes the attention mask as an input alongside the
        token ids. Truncation is enabled whenever the tokenizer reports a
        positive ``model_max_length``.

        Returns:
            The tokenizer output dict plus ``sentence`` (and ``offset_mapping``
            when a truthy one was supplied), consumed by ``_forward``.
        """
        max_length = self.tokenizer.model_max_length
        model_inputs = self.tokenizer(
            sentence,
            return_attention_mask=True,
            return_tensors=self.framework,
            truncation=bool(max_length and max_length > 0),
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
        )
        if offset_mapping:
            model_inputs["offset_mapping"] = offset_mapping
        model_inputs["sentence"] = sentence
        return model_inputs
# FastAPI application; browsers may call it cross-origin only from jhuhman.com.
app = FastAPI()
_cors_settings = dict(
    allow_origins=["https://jhuhman.com"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
class OnnxZeroShotClassificationPipeline(ZeroShotClassificationPipeline):
    """Zero-shot (NLI) classification pipeline whose model call runs on ONNX Runtime.

    Building the premise/hypothesis pairs and scoring the candidate labels
    are inherited from ``ZeroShotClassificationPipeline``. Only ``_forward``
    is replaced, so the logits come from ``session`` instead of the PyTorch
    model.

    Args:
        session: ONNX Runtime ``InferenceSession``. Its first output is assumed
            to be the classification logits (as the previous code assumed);
            confirm for the exported graph.
        *args, **kwargs: forwarded to ``ZeroShotClassificationPipeline``.
    """

    def __init__(self, *args, session, **kwargs):
        self.session = session
        # Resolve the graph's I/O names once instead of on every forward pass.
        self._onnx_input_names = {node.name for node in session.get_inputs()}
        self._onnx_output_name = session.get_outputs()[0].name
        super().__init__(*args, **kwargs)

    def _forward(self, inputs):
        """Run one tokenized premise/hypothesis input through the ONNX session.

        Tensors the graph declares as inputs are converted to numpy and fed to
        the session. Every entry that is not a tokenizer output (the
        label/sequence bookkeeping ``preprocess`` attached) is passed through
        next to ``logits``. Presumably these are the keys the parent's
        ``postprocess`` expects; verify against the installed transformers
        version.
        """
        tokenizer_keys = set(self.tokenizer.model_input_names)
        feed = {
            name: value.cpu().detach().numpy()
            for name, value in inputs.items()
            if name in self._onnx_input_names
        }
        logits = self.session.run(
            output_names=[self._onnx_output_name], input_feed=feed
        )[0]
        passthrough = {
            key: value for key, value in inputs.items() if key not in tokenizer_keys
        }
        return {**passthrough, "logits": torch.tensor(logits)}


options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
session = InferenceSession("onnx/model.onnx", sess_options=options, providers=["CPUExecutionProvider"])
session.disable_fallback()
model_name = "xenova/mobilebert-uncased-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# An MNLI checkpoint is a sequence-classification model. Inference runs in
# ONNX Runtime, so this object presumably only supplies its config
# (label2id / entailment label) to the pipeline.
# NOTE(review): confirm the repo ships PyTorch weights for from_pretrained.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# A zero-shot pipeline is required: the caller passes candidate_labels,
# hypothesis_template and multi_label, which a token-classification
# pipeline does not accept.
classifier = OnnxZeroShotClassificationPipeline(
    task="zero-shot-classification",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    session=session,
)
def zero_shot_classification(data_string):
    """Classify a sequence against caller-supplied labels (Gradio handler).

    Args:
        data_string: text of a JSON object with required keys ``sequence``
            and ``candidate_labels``, and optional keys ``hypothesis_template``
            and ``multi_label``. Optional keys that are absent are not passed
            to ``classifier``, so its own defaults apply.

    Returns:
        ``json.dumps`` of the value returned by ``classifier``.

    Raises:
        ValueError: if ``data_string`` is not valid JSON (``JSONDecodeError``)
            or does not decode to a JSON object.
        KeyError: if ``sequence`` or ``candidate_labels`` is missing.
    """
    data = json.loads(data_string)
    if not isinstance(data, dict):
        raise ValueError("JSON input must be an object")
    # Forward the optional pipeline arguments only when the caller supplied them.
    optional_args = {
        key: data[key]
        for key in ("hypothesis_template", "multi_label")
        if key in data
    }
    results = classifier(
        data["sequence"],
        candidate_labels=data["candidate_labels"],
        **optional_args,
    )
    return json.dumps(results)
# Minimal Gradio UI: JSON request text in, JSON result text out.
# NOTE(review): the FastAPI `app` (and its CORS middleware) is never passed
# to Gradio here, so as far as this file shows it is not served. Confirm
# whether mounting it (e.g. gradio.mount_gradio_app) was intended.
request_box = gradio.Textbox(label="JSON Input")
response_box = gradio.Textbox()
gradio_interface = gradio.Interface(zero_shot_classification, request_box, response_box)
gradio_interface.launch()