# Lord-Raven: "Trying to use ONNX model." (8a243e5)
# raw / history blame / 1.53 kB
import json

import gradio
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import cached_download
from optimum.onnxruntime import ORTModelForQuestionAnswering
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import TokenClassificationPipeline
from transformers import pipeline
class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
    """Placeholder subclass of ``TokenClassificationPipeline``.

    No overrides are defined yet, and nothing else visible in this file
    references the class. The docstring doubles as the class body: a bare
    ``class`` header with nothing under it is a syntax error that prevents
    the whole module from loading.
    """
# CORS Config
# Cross-origin requests are accepted from a single origin, with credentials,
# for any HTTP method and any request header.
# Alternative origin kept from the original inline comment:
#   ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win"]
# NOTE(review): nothing visible in this file serves `app` (the Gradio
# interface is launched on its own) -- confirm this middleware takes effect.
app = FastAPI()
_cors_settings = {
    "allow_origins": ["https://jhuhman.com"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Model setup.
# The zero-shot-classification pipeline scores each candidate label with an
# NLI sequence-classification head, so the ONNX weights are loaded through
# ORTModelForSequenceClassification. A question-answering wrapper exposes
# start/end span logits instead of the per-class logits the pipeline reads.
model_name = "xenova/mobilebert-uncased-mnli"
# NOTE(review): presumably this repo keeps its ONNX files in a subfolder;
# depending on the optimum version, from_pretrained may need an explicit
# `subfolder=` / `file_name=` argument -- confirm against the repo layout.
model = ORTModelForSequenceClassification.from_pretrained(model_name)
# NOTE(review): the tokenizer is loaded from a different repo id than the
# model; presumably it is the checkpoint the ONNX export came from -- confirm.
tokenizer = AutoTokenizer.from_pretrained("typeform/mobilebert-uncased-mnli")
classifier = pipeline(
    task="zero-shot-classification",
    model=model,
    tokenizer=tokenizer,
)
def zero_shot_classification(data_string):
    """Run the zero-shot classifier on a JSON-encoded request.

    Args:
        data_string: Text of a JSON object. ``sequence`` and
            ``candidate_labels`` are required. ``hypothesis_template`` and
            ``multi_label`` are optional; when omitted they are not passed
            on, so the pipeline's own defaults apply.

    Returns:
        The classifier's result serialized as a JSON string.

    Raises:
        json.JSONDecodeError: If ``data_string`` is not valid JSON.
        KeyError: If ``sequence`` or ``candidate_labels`` is missing.
    """
    # Echo the raw and the parsed request to stdout for debugging.
    print(data_string)
    data = json.loads(data_string)
    print(data)

    # Forward only the optional settings the caller actually supplied, so a
    # minimal request does not fail with KeyError.
    optional_kwargs = {
        key: data[key]
        for key in ("hypothesis_template", "multi_label")
        if key in data
    }
    results = classifier(
        data["sequence"],
        candidate_labels=data["candidate_labels"],
        **optional_kwargs,
    )
    return json.dumps(results)
# Gradio front end: one text box takes the JSON request, another shows the
# JSON response produced by zero_shot_classification.
_request_box = gradio.Textbox(label="JSON Input")
_response_box = gradio.Textbox()
gradio_interface = gradio.Interface(
    fn=zero_shot_classification,
    inputs=_request_box,
    outputs=_response_box,
)

# Start serving as soon as the module is executed.
gradio_interface.launch()