# onnx-models/old_scripts/convert_to_onnx.py
import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnxruntime.quantization import quantize_dynamic, QuantType
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
# Models that are known to work well with ONNX conversion
RELIABLE_MODELS = [
{
"id": "facebook/opt-350m",
"description": "Well-balanced model (350M) for RAG and chatbots"
},
{
"id": "gpt2",
"description": "Very reliable model (124M) with excellent ONNX compatibility"
},
{
"id": "distilgpt2",
"description": "Lightweight (82M) model with good performance"
}
]
class ModelWrapper(torch.nn.Module):
"""
Wrapper to handle ONNX export compatibility issues.
This wrapper specifically:
1. Bypasses cache handling
2. Simplifies the forward pass to avoid dynamic operations
"""
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids):
# Disable the KV cache and gradients and request plain tuple output so only the logits tensor is traced
with torch.no_grad():
return self.model(input_ids=input_ids, use_cache=False, return_dict=False)[0]
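# A minimal sketch of how the wrapper can be smoke-tested before export
# (illustrative only, not executed by this script):
#
#   model = AutoModelForCausalLM.from_pretrained("distilgpt2")
#   wrapped = ModelWrapper(model).eval()
#   logits = wrapped(torch.ones(1, 8, dtype=torch.long))
#   assert logits.shape[-1] == model.config.vocab_size  # a single logits tensor, no cache tuples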
def convert_model(model_id, output_dir, quantize=True):
"""Convert a model to ONNX format with maximum compatibility."""
start_time = time.time()
logger.info(f"\n{'=' * 60}")
logger.info(f"Converting {model_id} to ONNX")
logger.info(f"{'=' * 60}")
# Create output directory
model_name = model_id.split("/")[-1]
model_dir = os.path.join(output_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
try:
# Step 1: Load tokenizer
logger.info("Step 1/5: Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Handle missing pad token
if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
logger.info("Adding pad_token = eos_token")
tokenizer.pad_token = tokenizer.eos_token
# Save tokenizer
tokenizer.save_pretrained(model_dir)
logger.info(f"βœ“ Tokenizer saved to {model_dir}")
# Step 2: Load model with memory optimizations
logger.info("Step 2/5: Loading model with memory optimizations...")
# Clean memory before loading
gc.collect()
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Load model with optimizations
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16, # Use half precision
low_cpu_mem_usage=True # Reduce memory usage
)
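# Note: the weights stay in float16, so the traced ONNX graph is also exported in fp16.
# If tracing fails on CPU with half precision, reloading with torch_dtype=torch.float32
# is a common fallback (at roughly double the memory cost).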
# Save config for reference
model.config.save_pretrained(model_dir)
logger.info(f"βœ“ Model config saved to {model_dir}")
# Step 3: Prepare for export
logger.info("Step 3/5: Preparing for export...")
# Wrap model to avoid tracing issues
wrapped_model = ModelWrapper(model)
wrapped_model.eval() # Set to evaluation mode
# Clean memory again
gc.collect()
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Step 4: Export to ONNX
logger.info("Step 4/5: Exporting to ONNX format...")
onnx_path = os.path.join(model_dir, "model.onnx")
# Create dummy input
batch_size = 1
seq_length = 8 # Small sequence length to reduce memory
dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
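# The dummy tensor only provides example shapes for tracing; the dynamic_axes mapping
# below declares batch and sequence as symbolic, so the exported model accepts
# arbitrary batch sizes and sequence lengths at inference time.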
# Export to ONNX format with new opset version
torch.onnx.export(
wrapped_model, # Use wrapped model
dummy_input, # Model input
onnx_path, # Output path
export_params=True, # Store model weights
opset_version=14, # ONNX opset version (changed from 13 to 14)
do_constant_folding=True, # Optimize constants
input_names=['input_ids'], # Input names
output_names=['logits'], # Output names
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence'},
'logits': {0: 'batch_size', 1: 'sequence'}
}
)
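# A minimal sketch of validating the export (assumes the optional `onnx` package is
# installed alongside `onnxruntime`; illustrative only, not run by this script):
#
#   import onnx, onnxruntime as ort
#   onnx.checker.check_model(onnx.load(onnx_path))
#   sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
#   out = sess.run(["logits"], {"input_ids": dummy_input.numpy()})
#   print(out[0].shape)  # expected: (1, 8, vocab_size)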
# Clean up to save memory
del model
del wrapped_model
gc.collect()
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Verify export was successful
if os.path.exists(onnx_path):
size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
logger.info(f"βœ“ ONNX model saved to {onnx_path}")
logger.info(f"βœ“ Original size: {size_mb:.2f} MB")
# Step 5: Quantize
if quantize:
logger.info("Step 5/5: Applying int8 quantization...")
quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
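# Dynamic quantization converts the weights to int8 ahead of time and computes
# activation scales at runtime, so no calibration data is needed. per_channel and
# reduce_range are left at their conservative defaults here; enabling per_channel
# can improve accuracy for some models at a small cost in file size.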
try:
quantize_dynamic(
model_input=onnx_path,
model_output=quant_path,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QInt8
)
if os.path.exists(quant_path):
quant_size = os.path.getsize(quant_path) / (1024 * 1024)
logger.info(f"βœ“ Quantized size: {quant_size:.2f} MB")
logger.info(f"βœ“ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
# Replace original with quantized to save space
os.replace(quant_path, onnx_path)
logger.info("βœ“ Replaced original with quantized version")
else:
logger.warning("⚠ Quantized file not created, using original")
except Exception as e:
logger.error(f"⚠ Quantization error: {str(e)}")
logger.info("⚠ Using original model without quantization")
else:
logger.info("Step 5/5: Skipping quantization (not requested)")
# Calculate elapsed time
end_time = time.time()
duration = end_time - start_time
logger.info(f"βœ“ Conversion completed in {duration:.2f} seconds")
return {
"success": True,
"model_id": model_id,
"size_mb": os.path.getsize(onnx_path) / (1024 * 1024),
"duration_seconds": duration,
"output_dir": model_dir
}
else:
logger.error(f"Γ— ONNX file not created at {onnx_path}")
return {
"success": False,
"model_id": model_id,
"error": "ONNX file not created"
}
except Exception as e:
logger.error(f"Γ— Error converting model: {str(e)}")
logger.error(traceback.format_exc())
return {
"success": False,
"model_id": model_id,
"error": str(e)
}
def main():
"""Convert all reliable models."""
# Print header
logger.info("\nGUARANTEED ONNX CONVERTER")
logger.info("======================")
logger.info("Using reliable models with proven ONNX compatibility")
# Create output directory
output_dir = "./onnx_models"
os.makedirs(output_dir, exist_ok=True)
# Check if specific model ID provided as argument
if len(sys.argv) > 1:
model_id = sys.argv[1]
logger.info(f"Converting single model: {model_id}")
convert_model(model_id, output_dir)
return
# Convert all reliable models
results = []
for model_info in RELIABLE_MODELS:
model_id = model_info["id"]
logger.info(f"Processing model: {model_id}")
logger.info(f"Description: {model_info['description']}")
result = convert_model(model_id, output_dir)
results.append(result)
# Print summary
logger.info("\n" + "=" * 60)
logger.info("CONVERSION SUMMARY")
logger.info("=" * 60)
success_count = 0
for result in results:
if result.get("success", False):
success_count += 1
size_info = f" - Size: {result.get('size_mb', 0):.2f} MB"
time_info = f" - Time: {result.get('duration_seconds', 0):.2f}s"
logger.info(f"βœ“ SUCCESS: {result['model_id']}{size_info}{time_info}")
else:
logger.info(f"Γ— FAILED: {result['model_id']} - Error: {result.get('error', 'Unknown error')}")
logger.info(f"\nSuccessfully converted {success_count}/{len(RELIABLE_MODELS)} models")
logger.info(f"Models saved to: {os.path.abspath(output_dir)}")
if success_count > 0:
logger.info("\nThe models are ready for RAG and chatbot applications!")
if __name__ == "__main__":
main()
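# ----------------------------------------------------------------------------
# Minimal usage sketch for a converted model (illustrative only; the greedy
# decoding loop below is an assumption about how the output might be consumed,
# not part of this script):
#
#   import numpy as np
#   import onnxruntime as ort
#   from transformers import AutoTokenizer
#
#   model_dir = "./onnx_models/distilgpt2"
#   tokenizer = AutoTokenizer.from_pretrained(model_dir)
#   sess = ort.InferenceSession(f"{model_dir}/model.onnx", providers=["CPUExecutionProvider"])
#
#   ids = tokenizer("Hello, world", return_tensors="np")["input_ids"].astype(np.int64)
#   for _ in range(20):  # greedy decoding; the full sequence is recomputed each step since the cache is disabled
#       logits = sess.run(["logits"], {"input_ids": ids})[0]
#       next_id = logits[0, -1].argmax()
#       ids = np.concatenate([ids, [[next_id]]], axis=1)
#   print(tokenizer.decode(ids[0]))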