import logging
import os
import subprocess
import sys

import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
class ONNXModelConverter:
    """Exports a Hugging Face question-answering model to ONNX and produces
    dynamically quantized variants of it."""

    def __init__(self, model_name: str, output_dir: str):
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading tokenizer {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        self.logger.info(f"Loading model {model_name}...")
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        self.model.eval()
    def setup_logging(self):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Guard against attaching duplicate handlers when the converter is
        # instantiated more than once in the same process.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
    def prepare_dummy_inputs(self):
        """Builds a single tokenized example used to trace the model during export.
        Assumes a BERT-style tokenizer that produces token_type_ids."""
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        return {
            'input_ids': dummy_input['input_ids'],
            'attention_mask': dummy_input['attention_mask'],
            'token_type_ids': dummy_input['token_type_ids'],
        }
    def export_to_onnx(self):
        """Exports the model to ONNX with dynamic batch and sequence axes."""
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        # Mark batch and sequence dimensions as dynamic so the exported graph
        # accepts inputs of any batch size and length, not just the dummy shape.
        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
            'start_logits': {0: 'batch_size', 1: 'sequence_length'},
            'end_logits': {0: 'batch_size', 1: 'sequence_length'},
        }

        # Wrap the model so it returns a plain tuple of tensors, which
        # torch.onnx.export can map onto the declared output names.
        class ModelWrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask, token_type_ids):
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
                return outputs.start_logits, outputs.end_logits

        wrapped_model = ModelWrapper(self.model)

        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']),
                output_path,
                export_params=True,
                opset_version=14,  # or any opset supported by the target runtime
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask', 'token_type_ids'],
                output_names=['start_logits', 'end_logits'],
                dynamic_axes=dynamic_axes,
                verbose=False,
            )
            self.logger.info(f"Model exported to {output_path}")
            return output_path
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise
    def verify_model(self, model_path: str):
        """Runs the ONNX checker on the exported graph; returns True on success."""
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            self.logger.info("ONNX model verification successful")
            return True
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False
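
    def verify_outputs_match(self, model_path: str, atol: float = 1e-4) -> bool:
        """Optional numerical parity check (an illustrative sketch, assuming
        onnxruntime and numpy are installed): runs the dummy inputs through both
        the original PyTorch model and the exported ONNX graph and reports
        whether the logits agree within `atol`."""
        import numpy as np
        import onnxruntime as ort

        inputs = self.prepare_dummy_inputs()
        with torch.no_grad():
            torch_out = self.model(**inputs)

        session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        onnx_start, onnx_end = session.run(
            None, {name: tensor.numpy() for name, tensor in inputs.items()}
        )

        ok = (np.allclose(torch_out.start_logits.numpy(), onnx_start, atol=atol)
              and np.allclose(torch_out.end_logits.numpy(), onnx_end, atol=atol))
        self.logger.info(f"PyTorch/ONNX logits match within {atol}: {ok}")
        return ok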
    def preprocess_model(self, model_path: str) -> str:
        """Runs onnxruntime's shape-inference preprocessing step, which some
        quantization paths expect; returns the path of the preprocessed model."""
        preprocessed_path = os.path.join(self.output_dir, "model-infer.onnx")
        command = [
            sys.executable, "-m", "onnxruntime.quantization.preprocess",
            "--input", model_path,
            "--output", preprocessed_path,
        ]
        try:
            # check=True raises CalledProcessError on a nonzero exit code, so no
            # separate return-code check is needed.
            subprocess.run(command, check=True, capture_output=True, text=True)
            self.logger.info(f"Model preprocessing successful. Output saved to {preprocessed_path}")
            return preprocessed_path
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Preprocessing failed: {e.stderr}")
            raise
        except Exception as e:
            self.logger.error(f"Preprocessing failed: {str(e)}")
            raise
    def quantize_model(self, model_path: str):
        """Produces dynamically quantized variants of the model, one per weight
        type. Weight types unsupported by the installed onnxruntime are skipped
        with a warning so the remaining variants are still produced."""
        weight_types = {
            'int4': QuantType.QInt4,
            'uint4': QuantType.QUInt4,
            'int8': QuantType.QInt8,
            'uint8': QuantType.QUInt8,
            'int16': QuantType.QInt16,
            'uint16': QuantType.QUInt16,
        }
        all_quantized_paths = []
        for name, weight_type in weight_types.items():
            quantized_path = os.path.join(self.output_dir, f"model_{name}.onnx")
            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=weight_type,
                )
                self.logger.info(f"Model quantized ({name}) and saved to {quantized_path}")
                all_quantized_paths.append(quantized_path)
            except Exception as e:
                self.logger.warning(f"Quantization ({name}) failed, skipping: {str(e)}")
        return all_quantized_paths
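
    def report_model_sizes(self):
        """Optional helper (an illustrative sketch, not required by the
        pipeline): logs the on-disk size of every .onnx file in output_dir,
        which makes it easy to compare the quantized variants."""
        for name in sorted(os.listdir(self.output_dir)):
            if name.endswith('.onnx'):
                size_mb = os.path.getsize(os.path.join(self.output_dir, name)) / (1024 * 1024)
                self.logger.info(f"{name}: {size_mb:.1f} MiB")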
    def convert(self):
        """Runs the full pipeline: export, verify, quantize, and save the tokenizer."""
        try:
            onnx_path = self.export_to_onnx()
            if not self.verify_model(onnx_path):
                raise RuntimeError("Model verification failed")

            # Optionally preprocess before quantization:
            # preprocessed_path = self.preprocess_model(onnx_path)
            # quantized_paths = self.quantize_model(preprocessed_path)
            quantized_paths = self.quantize_model(onnx_path)

            tokenizer_path = os.path.join(self.output_dir, "tokenizer")
            self.tokenizer.save_pretrained(tokenizer_path)
            self.logger.info(f"Tokenizer saved to {tokenizer_path}")

            return {
                'onnx_model': onnx_path,
                'quantized_models': quantized_paths,  # list of quantized model paths
                'tokenizer': tokenizer_path,
            }
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise
if __name__ == "__main__":
    MODEL_NAME = "Intel/dynamic_tinybert"  # or any other BERT-style QA model
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()
        print("\nConversion completed successfully!")
        print(f"ONNX model path: {results['onnx_model']}")
        print(f"Quantized model paths: {results['quantized_models']}")
        print(f"Tokenizer path: {results['tokenizer']}")
    except Exception as e:
        print(f"Conversion failed: {str(e)}")