import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from tqdm import tqdm
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress unhelpful warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*The model does not use GenerationMixin.*")


class GenerationWrapper(torch.nn.Module):
    """
    Wrapper for model export that handles generation properly.
    This ensures the model can be correctly used for text generation.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config

    def forward(self, input_ids, attention_mask=None):
        # Return only the logits to avoid complex output structures
        with torch.no_grad():
            try:
                # Standard approach for most models
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True
                )
                return outputs.logits
            except Exception as e:
                logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
                # Fallback for models with a different API
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, 'logits'):
                    return outputs.logits
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    return outputs[0]  # First element is typically the logits
                else:
                    raise ValueError("Could not extract logits from model outputs")


def verify_model_generation(model, tokenizer, device="cpu"):
    """Test model generation capabilities before export"""
    model.eval()
    prompt = "Hello, how are you today? I am"
    logger.info("Testing model generation...")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Configure generation parameters
    gen_config = GenerationConfig(
        max_length=30,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )

    try:
        # Try generation
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test generation result: {generated_text}")
        if len(generated_text) <= len(prompt):
            logger.warning("Generation output is not longer than input prompt!")
        return True
    except Exception as e:
        logger.error(f"Generation test failed: {str(e)}")
        return False


def test_onnx_model(onnx_path, tokenizer):
    """Verify the ONNX model can be loaded and run"""
    try:
        import onnxruntime as ort

        logger.info("Testing ONNX model inference...")
        session = ort.InferenceSession(onnx_path)

        # Get input and output names
        input_names = [inp.name for inp in session.get_inputs()]
        output_names = [out.name for out in session.get_outputs()]

        # Create test input
        prompt = "Hello, how are you?"
        inputs = tokenizer(prompt, return_tensors="np")

        # Prepare input dict
        onnx_inputs = {}
        for name in input_names:
            if name == "input_ids" and "input_ids" in inputs:
                onnx_inputs[name] = inputs["input_ids"]
            elif name == "attention_mask" and "attention_mask" in inputs:
                onnx_inputs[name] = inputs["attention_mask"]

        # Run inference
        outputs = session.run(output_names, onnx_inputs)

        # Check output shape
        logits = outputs[0]
        logger.info(f"ONNX model output shape: {logits.shape}")
        if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
            logger.warning("Output shape doesn't match expected dimensions!")

        # Test next token prediction
        next_token_logits = logits[0, -1, :]
        next_token_id = np.argmax(next_token_logits)
        next_token = tokenizer.decode([next_token_id])
        logger.info(f"Next predicted token: '{next_token}'")

        return True
    except Exception as e:
        logger.error(f"ONNX model test failed: {str(e)}")
        return False


def optimize_onnx_model(onnx_path):
    """Apply ONNX optimizations to improve performance"""
    try:
        logger.info("Optimizing ONNX model...")

        # Load the model as a sanity check that the exported file is a readable ONNX graph
        model = onnx.load(onnx_path)

        # Apply optimizations
        from onnxruntime.transformers import optimizer

        # Get model type from path
        model_path = os.path.dirname(onnx_path)
        model_name = os.path.basename(model_path).lower()

        # Determine model type for optimization
        if "gpt" in model_name:
            model_type = "gpt2"
        elif "opt" in model_name:
            model_type = "opt"
        elif "pythia" in model_name:
            model_type = "gpt_neox"
        else:
            model_type = "gpt2"  # Default fallback

        logger.info(f"Using optimization profile for model type: {model_type}")

        # Try to optimize the model
        try:
            optimized_model = optimizer.optimize_model(
                onnx_path,
                model_type=model_type,
                num_heads=8,      # Heuristic default; set to the model's actual head count for best fusion
                hidden_size=768,  # Heuristic default; set to the model's actual hidden size for best fusion
                optimization_options=None
            )
            optimized_model.save_model_to_file(onnx_path)
            logger.info("✓ ONNX model optimized")
            return True
        except Exception as e:
            logger.warning(f"Optimization failed (non-critical): {str(e)}")
            return False
    except Exception as e:
        logger.warning(f"ONNX optimization error (skipping): {str(e)}")
        return False


def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True):
    """
    Convert a model to ONNX format with focus on reliability for generation.
    Args:
        model_id: HuggingFace model ID or path
        output_dir: Directory to save the model
        seq_length: Input sequence length for export
        quantize: Whether to quantize the model to INT8

    Returns:
        bool: Success status
    """
    start_time = time.time()

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX (optimized for generation)")
    logger.info(f"{'=' * 60}")

    # Create output directory
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        # Step 1: Load tokenizer
        logger.info("Step 1/6: Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # Save tokenizer
        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")

        # Step 2: Load model with reliability optimizations
        logger.info("Step 2/6: Loading model...")

        # Clean memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load model with full precision
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,  # Use full precision for reliability
            low_cpu_mem_usage=True,     # Reduce memory usage
            device_map=device           # Use CUDA if available
        )

        # Save config
        model.config.save_pretrained(model_dir)
        logger.info(f"✓ Model config saved to {model_dir}")

        # Step 3: Verify model can generate text
        logger.info("Step 3/6: Validating generation capabilities...")
        if not verify_model_generation(model, tokenizer, device):
            logger.warning("⚠ Model generation test didn't complete successfully")
            logger.info("Continuing with export anyway...")

        # Step 4: Wrap and prepare for export
        logger.info("Step 4/6: Preparing for export...")

        # Wrap model with generation-optimized interface
        wrapped_model = GenerationWrapper(model)
        wrapped_model.eval()

        # Clean memory again
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Step 5: Export to ONNX
        logger.info("Step 5/6: Exporting to ONNX format...")
        onnx_path = os.path.join(model_dir, "model.onnx")

        # Create minimal input
        batch_size = 1
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

        # Move tensors to correct device
        dummy_input = dummy_input.to(device)
        attention_mask = attention_mask.to(device)

        # Export to ONNX with required opset for transformer models
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,                  # Wrapped model
                (dummy_input, attention_mask),  # Input tensors
                onnx_path,                      # Output path
                export_params=True,             # Store weights
                opset_version=14,               # Required for transformer models
                do_constant_folding=True,       # Optimize constants
                input_names=['input_ids', 'attention_mask'],  # Input names
                output_names=['logits'],        # Output name
                dynamic_axes={                  # Dynamic dimensions
                    'input_ids': {0: 'batch_size', 1: 'sequence'},
                    'attention_mask': {0: 'batch_size', 1: 'sequence'},
                    'logits': {0: 'batch_size', 1: 'sequence'}
                }
            )

        # Clean up to save memory
        del model
        del wrapped_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Verify export success
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            # Test ONNX model
            test_onnx_model(onnx_path, tokenizer)

            # Optimize the ONNX model
            optimize_onnx_model(onnx_path)

            # Step 6: Quantize the model (optional)
            if quantize:
                logger.info("Step 6/6: Applying INT8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
                try:
                    # quantize_dynamic reports no incremental progress, so the bar is indeterminate
                    with tqdm(total=100, desc="Quantizing") as pbar:
                        quantize_dynamic(
                            model_input=onnx_path,
                            model_output=quant_path,
                            per_channel=False,
                            reduce_range=False,
                            weight_type=QuantType.QInt8,
                            use_external_data_format=False
                            # Note: the deprecated optimize_model flag is omitted;
                            # graph optimization is handled separately above.
                        )
                        pbar.update(100)  # Ensure progress reaches 100%

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size / size_mb) * 100:.1f}%")

                        # Test the quantized model
                        test_onnx_model(quant_path, tokenizer)

                        # Rename original as backup
                        backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
                        os.rename(onnx_path, backup_path)

                        # Replace original with quantized
                        os.rename(quant_path, onnx_path)
                        logger.info("✓ Original model preserved as *_fp32.onnx")
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 6/6: Skipping quantization as requested")

            # Calculate elapsed time
            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
            logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")

            # Create a simple example usage file
            example_path = os.path.join(model_dir, "example_usage.py")
            with open(example_path, 'w') as f:
                f.write("""
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./")  # Path to model directory
session = ort.InferenceSession("./model.onnx")

# Prepare input
prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="np")

# Run inference for a single step
outputs = session.run(
    ["logits"],
    {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"]
    }
)

# Get next token prediction
logits = outputs[0]
next_token_id = np.argmax(logits[0, -1, :])
next_token = tokenizer.decode([next_token_id])
print(f"Next predicted token: {next_token}")

# For full generation, you'd typically run in a loop, adding tokens one by one
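
# A minimal greedy-decoding sketch along those lines. This export has no KV cache,
# so the whole sequence is re-fed every step (fine for short outputs, quadratic overall).
# max_new_tokens is an arbitrary illustrative value; adjust as needed.
max_new_tokens = 20
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

for _ in range(max_new_tokens):
    step_logits = session.run(
        ["logits"],
        {"input_ids": input_ids, "attention_mask": attention_mask}
    )[0]
    next_id = int(np.argmax(step_logits[0, -1, :]))
    # Stop at end-of-sequence if the tokenizer defines one
    if tokenizer.eos_token_id is not None and next_id == tokenizer.eos_token_id:
        break
    # Append the new token and extend the attention mask
    input_ids = np.concatenate([input_ids, np.array([[next_id]], dtype=input_ids.dtype)], axis=1)
    attention_mask = np.concatenate([attention_mask, np.ones((1, 1), dtype=attention_mask.dtype)], axis=1)

print("Generated:", tokenizer.decode(input_ids[0], skip_special_tokens=True))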
""")
            logger.info(f"✓ Example usage saved to {example_path}")

            return True
        else:
            logger.error(f"× ONNX file not created at {onnx_path}")
            return False
    except Exception as e:
        logger.error(f"× Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        return False


if __name__ == "__main__":
    # Parse command line arguments
    parser_available = False
    try:
        import argparse

        parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for generation")
        parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
        parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
                            help="Output directory for the converted model")
        parser.add_argument("--seq_length", "-s", type=int, default=32,
                            help="Sequence length for model export")
        parser.add_argument("--no_quantize", action="store_true",
                            help="Skip INT8 quantization step")
        args = parser.parse_args()
        parser_available = True

        model_id = args.model_id
        output_dir = args.output_dir
        seq_length = args.seq_length
        quantize = not args.no_quantize
    except (ImportError, NameError):
        # Fallback if argparse is not available
        parser_available = False

    if not parser_available:
        if len(sys.argv) < 2:
            print("Usage: python convert_model.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize]")
            print("Example: python convert_model.py facebook/opt-125m ./onnx_models 32")
            print("\nRecommended models for small hardware:")
            print("  - facebook/opt-125m")
            print("  - distilgpt2")
            print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            print("  - EleutherAI/pythia-70m")
            sys.exit(1)

        model_id = sys.argv[1]
        output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
        seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
        quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv

    # Print header
    logger.info("\nENHANCED ONNX CONVERTER FOR LANGUAGE MODELS")
    logger.info("============================================")
    logger.info(f"Model: {model_id}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Sequence length: {seq_length}")
    logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Convert the model
    success = convert_model(model_id, output_dir, seq_length, quantize)

    if success:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION SUCCESSFUL")
        logger.info("=" * 60)
        logger.info(f"Model: {model_id}")
        logger.info(f"Output directory: {os.path.abspath(output_dir)}")
        logger.info("The model is ready for generation!")
        logger.info("\nTo use the model:")
        logger.info("1. See the example_usage.py file in the model directory")
        logger.info("2. For chatbot applications, implement token-by-token generation")
    else:
        logger.error("\n" + "=" * 60)
        logger.error("CONVERSION FAILED")
        logger.error("=" * 60)
        logger.error("Please try one of the recommended models:")
        logger.error("  - facebook/opt-125m")
        logger.error("  - distilgpt2")
        logger.error("  - EleutherAI/pythia-70m")
        sys.exit(1)  # Signal failure to the caller