# onnx-models/old_scripts/convert_single_model.py
import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from tqdm import tqdm
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# Suppress unhelpful warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*The model does not use GenerationMixin.*")
class GenerationWrapper(torch.nn.Module):
"""
Wrapper for model export that handles generation properly.
This ensures the model can be correctly used for text generation.
"""
def __init__(self, model):
super().__init__()
self.model = model
self.config = model.config
def forward(self, input_ids, attention_mask=None):
# Return only the logits to avoid complex structures
with torch.no_grad():
try:
# Standard approach for most models
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
return_dict=True
)
return outputs.logits
except Exception as e:
logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
# Fallback for models with different API
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
if hasattr(outputs, 'logits'):
return outputs.logits
elif isinstance(outputs, tuple) and len(outputs) > 0:
return outputs[0] # First element is typically logits
else:
raise ValueError("Could not extract logits from model outputs")
def verify_model_generation(model, tokenizer, device="cpu"):
"""Test model generation capabilities before export"""
model.eval()
prompt = "Hello, how are you today? I am"
logger.info("Testing model generation...")
inputs = tokenizer(prompt, return_tensors="pt").to(device)
# Configure generation parameters
gen_config = GenerationConfig(
max_length=30,
do_sample=True,
temperature=0.7,
num_return_sequences=1,
)
try:
# Try generation
with torch.no_grad():
outputs = model.generate(
**inputs,
generation_config=gen_config
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
logger.info(f"Test generation result: {generated_text}")
if len(generated_text) <= len(prompt):
logger.warning("Generation output is not longer than input prompt!")
return True
except Exception as e:
logger.error(f"Generation test failed: {str(e)}")
return False
def test_onnx_model(onnx_path, tokenizer):
"""Verify the ONNX model can be loaded and run"""
try:
import onnxruntime as ort
logger.info("Testing ONNX model inference...")
session = ort.InferenceSession(onnx_path)
# Get input and output names
        input_names = [inp.name for inp in session.get_inputs()]
        output_names = [out.name for out in session.get_outputs()]
# Create test input
prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="np")
# Prepare input dict
onnx_inputs = {}
for name in input_names:
if name == "input_ids" and "input_ids" in inputs:
onnx_inputs[name] = inputs["input_ids"]
elif name == "attention_mask" and "attention_mask" in inputs:
onnx_inputs[name] = inputs["attention_mask"]
# Run inference
outputs = session.run(output_names, onnx_inputs)
# Check output shape
logits = outputs[0]
logger.info(f"ONNX model output shape: {logits.shape}")
if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
logger.warning("Output shape doesn't match expected dimensions!")
# Test next token prediction
next_token_logits = logits[0, -1, :]
next_token_id = np.argmax(next_token_logits)
next_token = tokenizer.decode([next_token_id])
logger.info(f"Next predicted token: '{next_token}'")
return True
except Exception as e:
logger.error(f"ONNX model test failed: {str(e)}")
return False
def optimize_onnx_model(onnx_path):
"""Apply ONNX optimizations to improve performance"""
try:
logger.info("Optimizing ONNX model...")
# Load the model
model = onnx.load(onnx_path)
# Apply optimizations
from onnxruntime.transformers import optimizer
# Get model type from path
model_path = os.path.dirname(onnx_path)
model_name = os.path.basename(model_path).lower()
# Determine model type for optimization
if "gpt" in model_name:
model_type = "gpt2"
elif "opt" in model_name:
model_type = "opt"
elif "pythia" in model_name:
model_type = "gpt_neox"
else:
model_type = "gpt2" # Default fallback
logger.info(f"Using optimization profile for model type: {model_type}")
# Try to optimize the model
try:
optimized_model = optimizer.optimize_model(
onnx_path,
model_type=model_type,
                num_heads=8, # Generic default; replace with the model's actual value when known
                hidden_size=768, # Generic default; replace with the model's actual value when known
optimization_options=None
)
optimized_model.save_model_to_file(onnx_path)
logger.info("✓ ONNX model optimized")
return True
except Exception as e:
logger.warning(f"Optimization failed (non-critical): {str(e)}")
return False
except Exception as e:
logger.warning(f"ONNX optimization error (skipping): {str(e)}")
return False
def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True):
"""
Convert a model to ONNX format with focus on reliability for generation.
Args:
model_id: HuggingFace model ID or path
output_dir: Directory to save the model
seq_length: Input sequence length for export
quantize: Whether to quantize the model to INT8
Returns:
bool: Success status
"""
start_time = time.time()
logger.info(f"\n{'=' * 60}")
logger.info(f"Converting {model_id} to ONNX (optimized for generation)")
logger.info(f"{'=' * 60}")
# Create output directory
model_name = model_id.split("/")[-1]
model_dir = os.path.join(output_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
try:
# Step 1: Load tokenizer
logger.info("Step 1/6: Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
logger.info("Adding pad_token = eos_token")
tokenizer.pad_token = tokenizer.eos_token
# Save tokenizer
tokenizer.save_pretrained(model_dir)
logger.info(f"✓ Tokenizer saved to {model_dir}")
# Step 2: Load model with reliability optimizations
logger.info("Step 2/6: Loading model...")
# Clean memory
gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model with full precision
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float32, # Use full precision for reliability
low_cpu_mem_usage=True, # Reduce memory usage
device_map=device # Use CUDA if available
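            # Note: passing a device string via device_map requires the accelerate package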
)
# Save config
model.config.save_pretrained(model_dir)
logger.info(f"✓ Model config saved to {model_dir}")
# Step 3: Verify model can generate text
logger.info("Step 3/6: Validating generation capabilities...")
if not verify_model_generation(model, tokenizer, device):
logger.warning("⚠ Model generation test didn't complete successfully")
logger.info("Continuing with export anyway...")
# Step 4: Wrap and prepare for export
logger.info("Step 4/6: Preparing for export...")
# Wrap model with generation-optimized interface
wrapped_model = GenerationWrapper(model)
wrapped_model.eval()
# Clean memory again
gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Step 5: Export to ONNX
logger.info("Step 5/6: Exporting to ONNX format...")
onnx_path = os.path.join(model_dir, "model.onnx")
# Create minimal input
batch_size = 1
dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
# Move tensors to correct device
dummy_input = dummy_input.to(device)
attention_mask = attention_mask.to(device)
# Export to ONNX with required opset for transformer models
with torch.no_grad():
torch.onnx.export(
wrapped_model, # Wrapped model
(dummy_input, attention_mask), # Input tensors
onnx_path, # Output path
export_params=True, # Store weights
opset_version=14, # Required for transformer models
do_constant_folding=True, # Optimize constants
input_names=['input_ids', 'attention_mask'], # Input names
output_names=['logits'], # Output name
dynamic_axes={ # Dynamic dimensions
'input_ids': {0: 'batch_size', 1: 'sequence'},
'attention_mask': {0: 'batch_size', 1: 'sequence'},
'logits': {0: 'batch_size', 1: 'sequence'}
}
)
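        # Optionally validate the exported graph (note: loads the whole model into memory):
        #   onnx.checker.check_model(onnx.load(onnx_path))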
# Clean up to save memory
del model
del wrapped_model
gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Verify export success
if os.path.exists(onnx_path):
size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
logger.info(f"✓ ONNX model saved to {onnx_path}")
logger.info(f"✓ Original size: {size_mb:.2f} MB")
# Test ONNX model
test_onnx_model(onnx_path, tokenizer)
# Optimize the ONNX model
optimize_onnx_model(onnx_path)
# Step 6: Quantize the model (optional)
if quantize:
logger.info("Step 6/6: Applying INT8 quantization...")
quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
try:
with tqdm(total=100, desc="Quantizing") as pbar:
                        # quantize_dynamic does not expose a progress callback; the bar
                        # below simply marks the start and completion of the blocking call.
quantize_dynamic(
model_input=onnx_path,
model_output=quant_path,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QInt8,
optimize_model=True,
use_external_data_format=False
)
pbar.update(100) # Ensure progress reaches 100%
if os.path.exists(quant_path):
quant_size = os.path.getsize(quant_path) / (1024 * 1024)
logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
# Test the quantized model
test_onnx_model(quant_path, tokenizer)
# Rename original as backup
backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
os.rename(onnx_path, backup_path)
# Replace original with quantized
os.rename(quant_path, onnx_path)
logger.info("✓ Original model preserved as *_fp32.onnx")
logger.info("✓ Replaced original with quantized version")
else:
logger.warning("⚠ Quantized file not created, using original")
except Exception as e:
logger.error(f"⚠ Quantization error: {str(e)}")
logger.info("⚠ Using original model without quantization")
else:
logger.info("Step 6/6: Skipping quantization as requested")
# Calculate elapsed time
end_time = time.time()
duration = end_time - start_time
logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
# Create a simple example usage file
example_path = os.path.join(model_dir, "example_usage.py")
with open(example_path, 'w') as f:
f.write("""
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./") # Path to model directory
session = ort.InferenceSession("./model.onnx")
# Prepare input
prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="np")
# Run inference for a single step
outputs = session.run(
["logits"],
{
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"]
}
)
# Get next token prediction
logits = outputs[0]
next_token_id = np.argmax(logits[0, -1, :])
next_token = tokenizer.decode([next_token_id])
print(f"Next predicted token: {next_token}")
# For full generation, you'd typically run in a loop, adding tokens one by one
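# A minimal greedy-decoding sketch: because this export has no KV cache
# (the wrapper sets use_cache=False), the full sequence is re-run at every step.
generated = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
for _ in range(20):  # generate up to 20 new tokens
    step_logits = session.run(
        ["logits"],
        {"input_ids": generated, "attention_mask": attention_mask}
    )[0]
    next_id = int(np.argmax(step_logits[0, -1, :]))
    if tokenizer.eos_token_id is not None and next_id == tokenizer.eos_token_id:
        break
    generated = np.concatenate([generated, [[next_id]]], axis=1)
    attention_mask = np.concatenate([attention_mask, [[1]]], axis=1)
print(tokenizer.decode(generated[0], skip_special_tokens=True))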
""")
logger.info(f"✓ Example usage saved to {example_path}")
return True
else:
logger.error(f"× ONNX file not created at {onnx_path}")
return False
except Exception as e:
logger.error(f"× Error converting model: {str(e)}")
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
# Parse command line arguments
parser_available = False
try:
import argparse
parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for generation")
parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
help="Output directory for the converted model")
parser.add_argument("--seq_length", "-s", type=int, default=32,
help="Sequence length for model export")
parser.add_argument("--no_quantize", action="store_true",
help="Skip INT8 quantization step")
args = parser.parse_args()
parser_available = True
model_id = args.model_id
output_dir = args.output_dir
seq_length = args.seq_length
quantize = not args.no_quantize
except (ImportError, NameError):
# Fallback if argparse is not available
parser_available = False
if not parser_available:
if len(sys.argv) < 2:
print("Usage: python convert_model.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize]")
print("Example: python convert_model.py facebook/opt-125m ./onnx_models 32")
print("\nRecommended models for small hardware:")
print(" - facebook/opt-125m")
print(" - distilgpt2")
print(" - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(" - EleutherAI/pythia-70m")
sys.exit(1)
model_id = sys.argv[1]
output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
# Print header
logger.info("\nENHANCED ONNX CONVERTER FOR LANGUAGE MODELS")
logger.info("============================================")
logger.info(f"Model: {model_id}")
logger.info(f"Output directory: {output_dir}")
logger.info(f"Sequence length: {seq_length}")
logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Convert the model
success = convert_model(model_id, output_dir, seq_length, quantize)
if success:
logger.info("\n" + "=" * 60)
logger.info("CONVERSION SUCCESSFUL")
logger.info("=" * 60)
logger.info(f"Model: {model_id}")
logger.info(f"Output directory: {os.path.abspath(output_dir)}")
logger.info("The model is ready for generation!")
logger.info("\nTo use the model:")
logger.info("1. See the example_usage.py file in the model directory")
logger.info("2. For chatbot applications, implement token-by-token generation")
else:
logger.error("\n" + "=" * 60)
logger.error("CONVERSION FAILED")
logger.error("=" * 60)
logger.error("Please try one of the recommended models:")
logger.error(" - facebook/opt-125m")
logger.error(" - distilgpt2")
logger.error(" - EleutherAI/pythia-70m")