Lord-Raven committed on
Commit
0cca822
·
1 Parent(s): 19acbf6

Trying to use ONNX model.

Browse files
Files changed (2) hide show
  1. app.py +2 -68
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,69 +1,13 @@
1
  import gradio
2
  import json
3
  import torch
4
- from transformers import pipeline
5
  from transformers import AutoTokenizer
6
  from fastapi import FastAPI
7
  from fastapi.middleware.cors import CORSMiddleware
8
- from onnxruntime import (
9
- InferenceSession, SessionOptions, GraphOptimizationLevel
10
- )
11
- from transformers import (
12
- TokenClassificationPipeline, AutoTokenizer, AutoModelForTokenClassification
13
- )
14
 
15
  class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
16
 
17
- def __init__(self, *args, **kwargs):
18
- super().__init__(*args, **kwargs)
19
-
20
-
21
- def _forward(self, model_inputs):
22
- """
23
- Forward pass through the model. This method is not to be called by the user directly and is only used
24
- by the pipeline to perform the actual predictions.
25
- This is where we will define the actual process to do inference with the ONNX model and the session created
26
- before.
27
- """
28
-
29
- # This comes from the original implementation of the pipeline
30
- special_tokens_mask = model_inputs.pop("special_tokens_mask")
31
- offset_mapping = model_inputs.pop("offset_mapping", None)
32
- sentence = model_inputs.pop("sentence")
33
-
34
- inputs = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()} # dict of numpy arrays
35
- outputs_name = session.get_outputs()[0].name # get the name of the output tensor
36
-
37
- logits = session.run(output_names=[outputs_name], input_feed=inputs)[0] # run the session
38
- logits = torch.tensor(logits) # convert to torch tensor to be compatible with the original implementation
39
-
40
- return {
41
- "logits": logits,
42
- "special_tokens_mask": special_tokens_mask,
43
- "offset_mapping": offset_mapping,
44
- "sentence": sentence,
45
- **model_inputs,
46
- }
47
-
48
- # We need to override the preprocess method because the onnx model is waiting for the attention masks as inputs
49
- # along with the embeddings.
50
- def preprocess(self, sentence, offset_mapping=None):
51
- truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
52
- model_inputs = self.tokenizer(
53
- sentence,
54
- return_attention_mask=True, # This is the only difference from the original implementation
55
- return_tensors=self.framework,
56
- truncation=truncation,
57
- return_special_tokens_mask=True,
58
- return_offsets_mapping=self.tokenizer.is_fast,
59
- )
60
- if offset_mapping:
61
- model_inputs["offset_mapping"] = offset_mapping
62
-
63
- model_inputs["sentence"] = sentence
64
-
65
- return model_inputs
66
-
67
  # CORS Config
68
  app = FastAPI()
69
 
@@ -75,19 +19,9 @@ app.add_middleware(
75
  allow_headers=["*"],
76
  )
77
 
78
- options = SessionOptions()
79
- options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
80
-
81
- session = InferenceSession("onnx/model.onnx", sess_options=options, providers=["CPUExecutionProvider"])
82
-
83
- session.disable_fallback()
84
-
85
  model_name = "xenova/mobilebert-uncased-mnli"
86
 
87
- tokenizer = AutoTokenizer.from_pretrained(model_name)
88
- model = AutoModelForTokenClassification.from_pretrained(model_name)
89
-
90
- classifier = OnnxTokenClassificationPipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer, framework="pt", aggregation_strategy="simple")
91
 
92
  def zero_shot_classification(data_string):
93
  print(data_string)
 
1
  import gradio
2
  import json
3
  import torch
 
4
  from transformers import AutoTokenizer
5
  from fastapi import FastAPI
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from onnx_transformers import pipeline
 
 
 
 
 
8
 
9
  class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # CORS Config
12
  app = FastAPI()
13
 
 
19
  allow_headers=["*"],
20
  )
21
 
 
 
 
 
 
 
 
22
  model_name = "xenova/mobilebert-uncased-mnli"
23
 
24
+ classifier = pipeline(task="zero-shot-classification", model=model_name, onnx=True)
 
 
 
25
 
26
  def zero_shot_classification(data_string):
27
  print(data_string)
requirements.txt CHANGED
@@ -2,6 +2,7 @@ fastapi==0.88.0
2
  json5==0.9.10
3
  numpy==1.23.4
4
  onnxruntime==1.18.1
 
5
  torch==1.12.1
6
  torchvision==0.13.1
7
  transformers==4.44.0
 
2
  json5==0.9.10
3
  numpy==1.23.4
4
  onnxruntime==1.18.1
5
+ onnx_transformers
6
  torch==1.12.1
7
  torchvision==0.13.1
8
  transformers==4.44.0