import os
import gc
import sys
import time
import logging
import traceback
import warnings

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from tqdm import tqdm
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*The model does not use GenerationMixin.*")
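

# Export strategy: the wrapper below runs a plain forward pass with
# use_cache=False and returns only the logits tensor. Dropping the KV cache
# keeps the exported graph to a single static input/output signature
# (input_ids + attention_mask -> logits); generation with the ONNX model is
# then done token by token by re-running the full sequence each step.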
class GenerationWrapper(torch.nn.Module):
    """
    Wrapper around a causal LM for ONNX export.

    Runs a plain forward pass (no KV cache) and returns only the logits,
    so the exported graph can be used directly for text generation.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config

    def forward(self, input_ids, attention_mask=None):
        with torch.no_grad():
            try:
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True
                )
                return outputs.logits
            except Exception as e:
                logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
                # Fallback: call the model without export-specific kwargs and
                # pull the logits out of whatever output structure it returns.
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, 'logits'):
                    return outputs.logits
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    return outputs[0]
                else:
                    raise ValueError("Could not extract logits from model outputs")


def verify_model_generation(model, tokenizer, device="cpu"):
    """Test model generation capabilities before export."""
    model.eval()
    prompt = "Hello, how are you today? I am"

    logger.info("Testing model generation...")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    gen_config = GenerationConfig(
        max_length=30,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test generation result: {generated_text}")

        if len(generated_text) <= len(prompt):
            logger.warning("Generation output is not longer than input prompt!")

        return True
    except Exception as e:
        logger.error(f"Generation test failed: {str(e)}")
        return False
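

# Smoke test for the exported graph: run a single forward pass through
# onnxruntime, check that the logits come back shaped [batch, sequence, vocab],
# and decode the argmax at the last position as a quick sanity check.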
def test_onnx_model(onnx_path, tokenizer):
    """Verify the ONNX model can be loaded and run."""
    try:
        import onnxruntime as ort

        logger.info("Testing ONNX model inference...")
        session = ort.InferenceSession(onnx_path)

        input_names = [inp.name for inp in session.get_inputs()]
        output_names = [out.name for out in session.get_outputs()]

        prompt = "Hello, how are you?"
        inputs = tokenizer(prompt, return_tensors="np")

        onnx_inputs = {}
        for name in input_names:
            if name == "input_ids" and "input_ids" in inputs:
                onnx_inputs[name] = inputs["input_ids"]
            elif name == "attention_mask" and "attention_mask" in inputs:
                onnx_inputs[name] = inputs["attention_mask"]

        outputs = session.run(output_names, onnx_inputs)

        logits = outputs[0]
        logger.info(f"ONNX model output shape: {logits.shape}")

        if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
            logger.warning("Output shape doesn't match expected dimensions!")

        next_token_logits = logits[0, -1, :]
        next_token_id = np.argmax(next_token_logits)
        next_token = tokenizer.decode([next_token_id])
        logger.info(f"Next predicted token: '{next_token}'")

        return True
    except Exception as e:
        logger.error(f"ONNX model test failed: {str(e)}")
        return False
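

# Best-effort post-processing: ONNX Runtime's transformer optimizer applies
# graph-level fusions for known architectures. The profile is guessed from the
# output directory name, and any failure here is logged and ignored, since the
# unoptimized model is still usable.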
def optimize_onnx_model(onnx_path):
    """Apply ONNX optimizations to improve performance."""
    try:
        logger.info("Optimizing ONNX model...")

        # Load the exported file first to confirm it parses as a valid ONNX graph.
        onnx.load(onnx_path)

        from onnxruntime.transformers import optimizer

        # Guess an optimization profile from the model directory name.
        model_path = os.path.dirname(onnx_path)
        model_name = os.path.basename(model_path).lower()

        if "gpt" in model_name:
            model_type = "gpt2"
        elif "opt" in model_name:
            model_type = "opt"
        elif "pythia" in model_name:
            model_type = "gpt_neox"
        else:
            model_type = "gpt2"

        logger.info(f"Using optimization profile for model type: {model_type}")

        try:
            # num_heads/hidden_size below are hard-coded defaults and may not
            # match the actual architecture of the exported model.
            optimized_model = optimizer.optimize_model(
                onnx_path,
                model_type=model_type,
                num_heads=8,
                hidden_size=768,
                optimization_options=None
            )
            optimized_model.save_model_to_file(onnx_path)
            logger.info("✓ ONNX model optimized")
            return True
        except Exception as e:
            logger.warning(f"Optimization failed (non-critical): {str(e)}")
            return False

    except Exception as e:
        logger.warning(f"ONNX optimization error (skipping): {str(e)}")
        return False
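

# Conversion pipeline, in the order the steps are logged below:
#   1. load and save the tokenizer
#   2. load the PyTorch model (fp32) and save its config
#   3. sanity-check generation with the original model
#   4. wrap the model for export (logits-only forward)
#   5. export to ONNX with dynamic batch/sequence axes
#   6. optionally quantize the exported model to INT8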
def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True):
    """
    Convert a model to ONNX format with focus on reliability for generation.

    Args:
        model_id: HuggingFace model ID or path
        output_dir: Directory to save the model
        seq_length: Input sequence length for export
        quantize: Whether to quantize the model to INT8

    Returns:
        bool: Success status
    """
    start_time = time.time()

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX (optimized for generation)")
    logger.info(f"{'=' * 60}")

    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
logger.info("Step 1/6: Loading tokenizer...") |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id) |
|
if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'): |
|
logger.info("Adding pad_token = eos_token") |
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
tokenizer.save_pretrained(model_dir) |
|
logger.info(f"✓ Tokenizer saved to {model_dir}") |
|
|
|
|
|
logger.info("Step 2/6: Loading model...") |
|
|
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_id, |
|
torch_dtype=torch.float32, |
|
low_cpu_mem_usage=True, |
|
device_map=device |
|
) |
|
|
|
|
|
model.config.save_pretrained(model_dir) |
|
logger.info(f"✓ Model config saved to {model_dir}") |
|
|
|
|
|
logger.info("Step 3/6: Validating generation capabilities...") |
|
|
|
if not verify_model_generation(model, tokenizer, device): |
|
logger.warning("⚠ Model generation test didn't complete successfully") |
|
logger.info("Continuing with export anyway...") |
|
|
|
|
|
logger.info("Step 4/6: Preparing for export...") |
|
|
|
|
|
wrapped_model = GenerationWrapper(model) |
|
wrapped_model.eval() |
|
|
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
logger.info("Step 5/6: Exporting to ONNX format...") |
|
onnx_path = os.path.join(model_dir, "model.onnx") |
|
|
|
|
|
batch_size = 1 |
|
dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long) |
|
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long) |
|
|
|
|
|
dummy_input = dummy_input.to(device) |
|
attention_mask = attention_mask.to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
torch.onnx.export( |
|
wrapped_model, |
|
(dummy_input, attention_mask), |
|
onnx_path, |
|
export_params=True, |
|
opset_version=14, |
|
do_constant_folding=True, |
|
input_names=['input_ids', 'attention_mask'], |
|
output_names=['logits'], |
|
dynamic_axes={ |
|
'input_ids': {0: 'batch_size', 1: 'sequence'}, |
|
'attention_mask': {0: 'batch_size', 1: 'sequence'}, |
|
'logits': {0: 'batch_size', 1: 'sequence'} |
|
} |
|
) |
|
|
|
|
|
del model |
|
del wrapped_model |
|
gc.collect() |
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|

        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            test_onnx_model(onnx_path, tokenizer)
            optimize_onnx_model(onnx_path)

            if quantize:
                logger.info("Step 6/6: Applying INT8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")

                try:
                    # quantize_dynamic runs synchronously; the bar is cosmetic
                    # and completes when the call returns.
                    with tqdm(total=100, desc="Quantizing") as pbar:
                        quantize_dynamic(
                            model_input=onnx_path,
                            model_output=quant_path,
                            per_channel=False,
                            reduce_range=False,
                            weight_type=QuantType.QInt8,
                            optimize_model=True,
                            use_external_data_format=False
                        )
                        pbar.update(100)

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size / size_mb) * 100:.1f}%")

                        test_onnx_model(quant_path, tokenizer)

                        # Keep the fp32 export as a backup and promote the
                        # quantized file to the canonical model.onnx name.
                        backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
                        os.rename(onnx_path, backup_path)
                        os.rename(quant_path, onnx_path)
                        logger.info("✓ Original model preserved as *_fp32.onnx")
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 6/6: Skipping quantization as requested")

            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
            logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")

            # Write a small standalone usage example next to the model.
            example_path = os.path.join(model_dir, "example_usage.py")
            with open(example_path, 'w') as f:
                f.write("""
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./")  # Path to model directory
session = ort.InferenceSession("./model.onnx")

# Prepare input
prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="np")

# Run inference for a single step
outputs = session.run(
    ["logits"],
    {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"]
    }
)

# Get next token prediction
logits = outputs[0]
next_token_id = np.argmax(logits[0, -1, :])
next_token = tokenizer.decode([next_token_id])
print(f"Next predicted token: {next_token}")

# For full generation, you'd typically run in a loop, adding tokens one by one
""")
            logger.info(f"✓ Example usage saved to {example_path}")

            return True
        else:
            logger.error(f"× ONNX file not created at {onnx_path}")
            return False

    except Exception as e:
        logger.error(f"× Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        return False
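

# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# the token-by-token loop that example_usage.py alludes to, written as a minimal
# greedy decoder over the exported ONNX model. Because the export disables the
# KV cache, each step re-runs the full sequence. Assumes a model exported by
# convert_model() above, with inputs 'input_ids'/'attention_mask' and output 'logits'.
def greedy_generate_onnx(onnx_path, tokenizer, prompt, max_new_tokens=20):
    """Minimal greedy decoding loop over an exported ONNX model (example only)."""
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path)
    input_ids = tokenizer(prompt, return_tensors="np")["input_ids"]
    attention_mask = np.ones_like(input_ids)

    for _ in range(max_new_tokens):
        logits = session.run(
            ["logits"],
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )[0]
        # Greedy choice: highest-probability token at the last position.
        next_id = int(np.argmax(logits[0, -1, :]))
        if tokenizer.eos_token_id is not None and next_id == tokenizer.eos_token_id:
            break
        # Append the new token and rerun over the full, extended sequence.
        next_col = np.array([[next_id]], dtype=input_ids.dtype)
        input_ids = np.concatenate([input_ids, next_col], axis=1)
        attention_mask = np.ones_like(input_ids)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)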


if __name__ == "__main__":
    # argparse ships with the standard library, so the manual sys.argv fallback
    # below is only a safety net for unusual environments.
    parser_available = False
    try:
        import argparse

        parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for generation")
        parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
        parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
                            help="Output directory for the converted model")
        parser.add_argument("--seq_length", "-s", type=int, default=32,
                            help="Sequence length for model export")
        parser.add_argument("--no_quantize", action="store_true",
                            help="Skip INT8 quantization step")

        args = parser.parse_args()
        parser_available = True

        model_id = args.model_id
        output_dir = args.output_dir
        seq_length = args.seq_length
        quantize = not args.no_quantize

    except (ImportError, NameError):
        parser_available = False

    if not parser_available:
        if len(sys.argv) < 2:
            print("Usage: python convert_model.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize]")
            print("Example: python convert_model.py facebook/opt-125m ./onnx_models 32")
            print("\nRecommended models for small hardware:")
            print(" - facebook/opt-125m")
            print(" - distilgpt2")
            print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            print(" - EleutherAI/pythia-70m")
            sys.exit(1)

        model_id = sys.argv[1]
        output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
        seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
        quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv

    logger.info("\nENHANCED ONNX CONVERTER FOR LANGUAGE MODELS")
    logger.info("============================================")
    logger.info(f"Model: {model_id}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Sequence length: {seq_length}")
    logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")

    os.makedirs(output_dir, exist_ok=True)

    success = convert_model(model_id, output_dir, seq_length, quantize)

    if success:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION SUCCESSFUL")
        logger.info("=" * 60)
        logger.info(f"Model: {model_id}")
        logger.info(f"Output directory: {os.path.abspath(output_dir)}")
        logger.info("The model is ready for generation!")
        logger.info("\nTo use the model:")
        logger.info("1. See the example_usage.py file in the model directory")
        logger.info("2. For chatbot applications, implement token-by-token generation")
    else:
        logger.error("\n" + "=" * 60)
        logger.error("CONVERSION FAILED")
        logger.error("=" * 60)
        logger.error("Please try one of the recommended models:")
        logger.error(" - facebook/opt-125m")
        logger.error(" - distilgpt2")
        logger.error(" - EleutherAI/pythia-70m")