import os
import gc
import sys
import time
import logging
import traceback
import warnings

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnxruntime.quantization import quantize_dynamic, QuantType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")

# Models that are known to work well with ONNX conversion
RELIABLE_MODELS = [
    {
        "id": "facebook/opt-350m",
        "description": "Well-balanced model (350M) for RAG and chatbots"
    },
    {
        "id": "gpt2",
        "description": "Very reliable model (124M) with excellent ONNX compatibility"
    },
    {
        "id": "distilgpt2",
        "description": "Lightweight (82M) model with good performance"
    }
]


class ModelWrapper(torch.nn.Module):
    """
    Wrapper to handle ONNX export compatibility issues.

    This wrapper specifically:
    1. Bypasses cache handling
    2. Simplifies the forward pass to avoid dynamic operations
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        # Force no cache, no gradient, and no special features
        with torch.no_grad():
            return self.model(input_ids=input_ids, use_cache=False, return_dict=False)[0]


def convert_model(model_id, output_dir, quantize=True):
    """Convert a model to ONNX format with maximum compatibility."""
    start_time = time.time()

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX")
    logger.info(f"{'=' * 60}")

    # Create output directory
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        # Step 1: Load tokenizer
        logger.info("Step 1/5: Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Handle missing pad token
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # Save tokenizer
        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")

        # Step 2: Load model with memory optimizations
        logger.info("Step 2/5: Loading model with memory optimizations...")

        # Clean memory before loading
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,  # Use half precision
            low_cpu_mem_usage=True      # Reduce memory usage
        )

        # Save config for reference
        model.config.save_pretrained(model_dir)
        logger.info(f"✓ Model config saved to {model_dir}")

        # Step 3: Prepare for export
        logger.info("Step 3/5: Preparing for export...")

        # Wrap model to avoid tracing issues
        wrapped_model = ModelWrapper(model)
        wrapped_model.eval()  # Set to evaluation mode

        # Clean memory again
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Step 4: Export to ONNX
        logger.info("Step 4/5: Exporting to ONNX format...")
        onnx_path = os.path.join(model_dir, "model.onnx")

        # Create dummy input
        batch_size = 1
        seq_length = 8  # Small sequence length to reduce memory
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
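        # The dynamic_axes mapping passed to the exporter below marks dimension 0
        # (batch) and dimension 1 (sequence) of both input_ids and logits as
        # symbolic, so the exported graph accepts prompts of any length even
        # though the dummy input is only 1x8. Opset 14 (rather than 13) is
        # presumably used here for broader operator coverage with these models.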
        # Export to ONNX format with new opset version
        torch.onnx.export(
            wrapped_model,                # Use wrapped model
            dummy_input,                  # Model input
            onnx_path,                    # Output path
            export_params=True,           # Store model weights
            opset_version=14,             # ONNX opset version (changed from 13 to 14)
            do_constant_folding=True,     # Optimize constants
            input_names=['input_ids'],    # Input names
            output_names=['logits'],      # Output names
            dynamic_axes={
                'input_ids': {0: 'batch_size', 1: 'sequence'},
                'logits': {0: 'batch_size', 1: 'sequence'}
            }
        )

        # Clean up to save memory
        del model
        del wrapped_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Verify export was successful
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            # Step 5: Quantize
            if quantize:
                logger.info("Step 5/5: Applying int8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")

                try:
                    quantize_dynamic(
                        model_input=onnx_path,
                        model_output=quant_path,
                        per_channel=False,
                        reduce_range=False,
                        weight_type=QuantType.QInt8
                    )

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size / size_mb) * 100:.1f}%")

                        # Replace original with quantized to save space
                        os.replace(quant_path, onnx_path)
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 5/5: Skipping quantization (not requested)")

            # Calculate elapsed time
            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")

            return {
                "success": True,
                "model_id": model_id,
                "size_mb": os.path.getsize(onnx_path) / (1024 * 1024),
                "duration_seconds": duration,
                "output_dir": model_dir
            }
        else:
            logger.error(f"× ONNX file not created at {onnx_path}")
            return {
                "success": False,
                "model_id": model_id,
                "error": "ONNX file not created"
            }

    except Exception as e:
        logger.error(f"× Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        return {
            "success": False,
            "model_id": model_id,
            "error": str(e)
        }
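
# Optional sanity check, not called by main(): a minimal sketch that assumes
# onnxruntime and numpy are importable and that a model was exported to
# <model_dir>/model.onnx by convert_model() above. It runs a single forward
# pass through the ONNX graph and logs the greedy next-token prediction,
# which is a quick way to confirm the exported file is usable.
def smoke_test(model_dir, prompt="Hello, world"):
    """Run one forward pass on an exported ONNX model as a sanity check."""
    import numpy as np
    import onnxruntime as ort

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    session = ort.InferenceSession(os.path.join(model_dir, "model.onnx"))

    input_ids = tokenizer(prompt, return_tensors="np")["input_ids"].astype(np.int64)
    logits = session.run(["logits"], {"input_ids": input_ids})[0]

    # logits has shape (batch_size, sequence, vocab_size); the argmax at the
    # last position is the model's greedy guess for the next token
    next_token_id = int(logits[0, -1].argmax())
    logger.info(f"Smoke test next token: {tokenizer.decode([next_token_id])!r}")
    return logits.shape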

def main():
    """Convert all reliable models."""
    # Print header
    logger.info("\nGUARANTEED ONNX CONVERTER")
    logger.info("======================")
    logger.info("Using reliable models with proven ONNX compatibility")

    # Create output directory
    output_dir = "./onnx_models"
    os.makedirs(output_dir, exist_ok=True)

    # Check if specific model ID provided as argument
    if len(sys.argv) > 1:
        model_id = sys.argv[1]
        logger.info(f"Converting single model: {model_id}")
        convert_model(model_id, output_dir)
        return

    # Convert all reliable models
    results = []
    for model_info in RELIABLE_MODELS:
        model_id = model_info["id"]
        logger.info(f"Processing model: {model_id}")
        logger.info(f"Description: {model_info['description']}")
        result = convert_model(model_id, output_dir)
        results.append(result)

    # Print summary
    logger.info("\n" + "=" * 60)
    logger.info("CONVERSION SUMMARY")
    logger.info("=" * 60)

    success_count = 0
    for result in results:
        if result.get("success", False):
            success_count += 1
            size_info = f" - Size: {result.get('size_mb', 0):.2f} MB"
            time_info = f" - Time: {result.get('duration_seconds', 0):.2f}s"
            logger.info(f"✓ SUCCESS: {result['model_id']}{size_info}{time_info}")
        else:
            logger.info(f"× FAILED: {result['model_id']} - Error: {result.get('error', 'Unknown error')}")

    logger.info(f"\nSuccessfully converted {success_count}/{len(RELIABLE_MODELS)} models")
    logger.info(f"Models saved to: {os.path.abspath(output_dir)}")

    if success_count > 0:
        logger.info("\nThe models are ready for RAG and chatbot applications!")


if __name__ == "__main__":
    main()
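
# Example invocations (the script's file name below is illustrative; use
# whatever name this file is saved under):
#   python convert_to_onnx.py                    # convert all RELIABLE_MODELS
#   python convert_to_onnx.py facebook/opt-350m  # convert a single model by id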