Lord-Raven committed on
Commit
8a243e5
·
1 Parent(s): 4d6800e

Trying to use ONNX model.

Browse files
Files changed (2) hide show
  1. app.py +9 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,7 +4,9 @@ import torch
4
  from transformers import AutoTokenizer
5
  from fastapi import FastAPI
6
  from fastapi.middleware.cors import CORSMiddleware
7
- from onnx_transformers import pipeline
 
 
8
 
9
  class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
10
 
@@ -20,8 +22,13 @@ app.add_middleware(
20
  )
21
 
22
  model_name = "xenova/mobilebert-uncased-mnli"
 
 
23
 
24
- classifier = pipeline(task="zero-shot-classification", model=model_name, onnx=True)
 
 
 
25
 
26
  def zero_shot_classification(data_string):
27
  print(data_string)
 
4
  from transformers import AutoTokenizer
5
  from fastapi import FastAPI
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from transformers import pipeline
8
+ from huggingface_hub import cached_download
9
+ from optimum.onnxruntime import ORTModelForQuestionAnswering
10
 
11
  class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
12
 
 
22
  )
23
 
24
  model_name = "xenova/mobilebert-uncased-mnli"
25
+ model = ORTModelForQuestionAnswering.from_pretrained(model_name)
26
+ tokenizer = AutoTokenizer.from_pretrained("typeform/mobilebert-uncased-mnli")
27
 
28
+ # file = cached_download("https://huggingface.co/" + model_name + "")
29
+ # sess = InferenceSession(file)
30
+
31
+ classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
32
 
33
  def zero_shot_classification(data_string):
34
  print(data_string)
requirements.txt CHANGED
@@ -1,8 +1,9 @@
1
  fastapi==0.88.0
 
2
  json5==0.9.10
3
  numpy==1.23.4
4
  onnxruntime==1.18.1
5
- https://github.com/patil-suraj/onnx_transformers
6
  torch==1.12.1
7
  torchvision==0.13.1
8
  transformers==4.44.0
 
1
  fastapi==0.88.0
2
+ huggingface_hub
3
  json5==0.9.10
4
  numpy==1.23.4
5
  onnxruntime==1.18.1
6
+ optimum==1.21.3
7
  torch==1.12.1
8
  torchvision==0.13.1
9
  transformers==4.44.0