import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from tqdm import tqdm
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress unhelpful warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*The model does not use GenerationMixin.*")


class GenerationWrapper(torch.nn.Module):
    """
    Wrapper for model export that handles generation properly.
    This ensures the model can be correctly used for text generation.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config
        
    def forward(self, input_ids, attention_mask=None):
        # Return only the logits to avoid complex structures
        with torch.no_grad():
            try:
                # Standard approach for most models
                outputs = self.model(
                    input_ids=input_ids, 
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True
                )
                return outputs.logits
            except Exception as e:
                logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
                # Fallback for models with different API
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, 'logits'):
                    return outputs.logits
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    return outputs[0]  # First element is typically logits
                else:
                    raise ValueError("Could not extract logits from model outputs")


def verify_model_generation(model, tokenizer, device="cpu"):
    """Test model generation capabilities before export"""
    model.eval()
    prompt = "Hello, how are you today? I am"
    
    logger.info("Testing model generation...")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    # Configure generation parameters
    gen_config = GenerationConfig(
        max_length=30,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )
    
    try:
        # Try generation
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test generation result: {generated_text}")
        
        if len(generated_text) <= len(prompt):
            logger.warning("Generation output is not longer than input prompt!")
            
        return True
    except Exception as e:
        logger.error(f"Generation test failed: {str(e)}")
        return False


def test_onnx_model(onnx_path, tokenizer):
    """Verify the ONNX model can be loaded and run"""
    try:
        import onnxruntime as ort
        
        logger.info("Testing ONNX model inference...")
        session = ort.InferenceSession(onnx_path)
        
        # Get input and output names
        input_names = [inp.name for inp in session.get_inputs()]
        output_names = [out.name for out in session.get_outputs()]
        
        # Create test input
        prompt = "Hello, how are you?"
        inputs = tokenizer(prompt, return_tensors="np")
        
        # Prepare input dict
        onnx_inputs = {}
        for name in input_names:
            if name == "input_ids" and "input_ids" in inputs:
                onnx_inputs[name] = inputs["input_ids"]
            elif name == "attention_mask" and "attention_mask" in inputs:
                onnx_inputs[name] = inputs["attention_mask"]
        
        # Run inference
        outputs = session.run(output_names, onnx_inputs)
        
        # Check output shape
        logits = outputs[0]
        logger.info(f"ONNX model output shape: {logits.shape}")
        
        if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
            logger.warning("Output shape doesn't match expected dimensions!")
        
        # Test next token prediction
        next_token_logits = logits[0, -1, :]
        next_token_id = np.argmax(next_token_logits)
        next_token = tokenizer.decode([next_token_id])
        logger.info(f"Next predicted token: '{next_token}'")
        
        return True
    except Exception as e:
        logger.error(f"ONNX model test failed: {str(e)}")
        return False


def optimize_onnx_model(onnx_path):
    """Apply ONNX optimizations to improve performance"""
    try:
        logger.info("Optimizing ONNX model...")
        
        # Parse the exported graph once as a basic sanity check
        onnx.load(onnx_path)
        
        # Apply optimizations
        from onnxruntime.transformers import optimizer
        
        # Get model type from path
        model_path = os.path.dirname(onnx_path)
        model_name = os.path.basename(model_path).lower()
        
        # Determine model type for optimization
        if "gpt" in model_name:
            model_type = "gpt2"
        elif "opt" in model_name:
            model_type = "opt"
        elif "pythia" in model_name:
            model_type = "gpt_neox"
        else:
            model_type = "gpt2"  # Default fallback
            
        logger.info(f"Using optimization profile for model type: {model_type}")
        
        # Try to optimize the model
        try:
            optimized_model = optimizer.optimize_model(
                onnx_path,
                model_type=model_type,
                num_heads=8,      # Generic default; ideally read from the model's config
                hidden_size=768,  # Generic default; ideally read from the model's config
                optimization_options=None
            )
            optimized_model.save_model_to_file(onnx_path)
            logger.info("✓ ONNX model optimized")
            return True
        except Exception as e:
            logger.warning(f"Optimization failed (non-critical): {str(e)}")
            return False
            
    except Exception as e:
        logger.warning(f"ONNX optimization error (skipping): {str(e)}")
        return False


def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True):
    """
    Convert a model to ONNX format with focus on reliability for generation.
    
    Args:
        model_id: HuggingFace model ID or path
        output_dir: Directory to save the model
        seq_length: Input sequence length for export
        quantize: Whether to quantize the model to INT8
        
    Returns:
        bool: Success status
    """
    start_time = time.time()
    
    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX (optimized for generation)")
    logger.info(f"{'=' * 60}")
    
    # Create output directory
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
    
    try:
        # Step 1: Load tokenizer
        logger.info("Step 1/6: Loading tokenizer...")
        
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token
        
        # Save tokenizer
        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")
        
        # Step 2: Load model with reliability optimizations
        logger.info("Step 2/6: Loading model...")
        
        # Clean memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # Load model with full precision
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,  # Use full precision for reliability
            low_cpu_mem_usage=True,    # Reduce memory usage
            device_map=device         # Use CUDA if available
        )
        
        # Save config
        model.config.save_pretrained(model_dir)
        logger.info(f"✓ Model config saved to {model_dir}")
        
        # Step 3: Verify model can generate text
        logger.info("Step 3/6: Validating generation capabilities...")
        
        if not verify_model_generation(model, tokenizer, device):
            logger.warning("⚠ Model generation test didn't complete successfully")
            logger.info("Continuing with export anyway...")
        
        # Step 4: Wrap and prepare for export
        logger.info("Step 4/6: Preparing for export...")
        
        # Wrap model with generation-optimized interface
        wrapped_model = GenerationWrapper(model)
        wrapped_model.eval()
        
        # Clean memory again
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Step 5: Export to ONNX
        logger.info("Step 5/6: Exporting to ONNX format...")
        onnx_path = os.path.join(model_dir, "model.onnx")
        
        # Create minimal input
        batch_size = 1
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
        
        # Move tensors to correct device
        dummy_input = dummy_input.to(device)
        attention_mask = attention_mask.to(device)
        
        # Export to ONNX with required opset for transformer models
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,                # Wrapped model
                (dummy_input, attention_mask), # Input tensors
                onnx_path,                    # Output path
                export_params=True,           # Store weights
                opset_version=14,             # Opset 14 covers ops used by most transformer models
                do_constant_folding=True,     # Optimize constants
                input_names=['input_ids', 'attention_mask'],  # Input names
                output_names=['logits'],      # Output name
                dynamic_axes={                # Dynamic dimensions
                    'input_ids': {0: 'batch_size', 1: 'sequence'},
                    'attention_mask': {0: 'batch_size', 1: 'sequence'},
                    'logits': {0: 'batch_size', 1: 'sequence'}
                }
            )
        
        # Clean up to save memory
        del model
        del wrapped_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Verify export success
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")
            
            # Test ONNX model
            test_onnx_model(onnx_path, tokenizer)
            
            # Optimize the ONNX model
            optimize_onnx_model(onnx_path)
            
            # Step 6: Quantize the model (optional)
            if quantize:
                logger.info("Step 6/6: Applying INT8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
                
                try:
                    # quantize_dynamic exposes no progress callback, so the bar
                    # jumps straight to 100% once quantization finishes.
                    # Note: the deprecated optimize_model argument is omitted here;
                    # recent onnxruntime releases no longer accept it.
                    with tqdm(total=100, desc="Quantizing") as pbar:
                        quantize_dynamic(
                            model_input=onnx_path,
                            model_output=quant_path,
                            per_channel=False,
                            reduce_range=False,
                            weight_type=QuantType.QInt8,
                            use_external_data_format=False
                        )
                        pbar.update(100)
                    
                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
                        
                        # Test the quantized model
                        test_onnx_model(quant_path, tokenizer)
                        
                        # Rename original as backup
                        backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
                        os.rename(onnx_path, backup_path)
                        
                        # Replace original with quantized
                        os.rename(quant_path, onnx_path)
                        logger.info("✓ Original model preserved as *_fp32.onnx")
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 6/6: Skipping quantization as requested")
            
            # Calculate elapsed time
            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
            logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
            
            # Create a simple example usage file
            example_path = os.path.join(model_dir, "example_usage.py")
            with open(example_path, 'w') as f:
                f.write("""
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./")  # Path to model directory
session = ort.InferenceSession("./model.onnx")

# Prepare input
prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="np")

# Run inference for a single step
outputs = session.run(
    ["logits"], 
    {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"]
    }
)

# Get next token prediction
logits = outputs[0]
next_token_id = np.argmax(logits[0, -1, :])
next_token = tokenizer.decode([next_token_id])
print(f"Next predicted token: {next_token}")

# For full generation, you'd typically run in a loop, adding tokens one by one
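
# Below is a minimal greedy-decoding sketch (illustrative only; sampling or beam
# search would work similarly). The exported graph has no KV cache, so each step
# re-runs the whole sequence, which is fine for short outputs on small models.
max_new_tokens = 20
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
for _ in range(max_new_tokens):
    step_logits = session.run(
        ["logits"],
        {"input_ids": input_ids, "attention_mask": attention_mask},
    )[0]
    token_id = int(np.argmax(step_logits[0, -1, :]))
    if token_id == tokenizer.eos_token_id:
        break
    token = np.array([[token_id]], dtype=input_ids.dtype)
    input_ids = np.concatenate([input_ids, token], axis=1)
    attention_mask = np.concatenate([attention_mask, np.ones_like(token)], axis=1)

print("Greedy continuation:", tokenizer.decode(input_ids[0], skip_special_tokens=True))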
""")
            logger.info(f"✓ Example usage saved to {example_path}")
            
            return True
        else:
            logger.error(f"× ONNX file not created at {onnx_path}")
            return False
    
    except Exception as e:
        logger.error(f"× Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        return False


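# Example invocations (illustrative; any of the recommended small models works):
#   python convert_model.py distilgpt2
#   python convert_model.py facebook/opt-125m --output_dir ./onnx_models --seq_length 32
#   python convert_model.py EleutherAI/pythia-70m --no_quantize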
if __name__ == "__main__":
    # Parse command line arguments
    parser_available = False
    try:
        import argparse
        parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for generation")
        parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
        parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models", 
                          help="Output directory for the converted model")
        parser.add_argument("--seq_length", "-s", type=int, default=32,
                          help="Sequence length for model export")
        parser.add_argument("--no_quantize", action="store_true",
                          help="Skip INT8 quantization step")
        
        args = parser.parse_args()
        parser_available = True
        
        model_id = args.model_id
        output_dir = args.output_dir
        seq_length = args.seq_length
        quantize = not args.no_quantize
        
    except (ImportError, NameError):
        # Fallback if argparse is not available
        parser_available = False
    
    if not parser_available:
        if len(sys.argv) < 2:
            print("Usage: python convert_model.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize]")
            print("Example: python convert_model.py facebook/opt-125m ./onnx_models 32")
            print("\nRecommended models for small hardware:")
            print("  - facebook/opt-125m")
            print("  - distilgpt2")
            print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            print("  - EleutherAI/pythia-70m")
            sys.exit(1)
        
        model_id = sys.argv[1]
        output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
        seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
        quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
    
    # Print header
    logger.info("\nENHANCED ONNX CONVERTER FOR LANGUAGE MODELS")
    logger.info("============================================")
    logger.info(f"Model: {model_id}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Sequence length: {seq_length}")
    logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Convert the model
    success = convert_model(model_id, output_dir, seq_length, quantize)
    
    if success:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION SUCCESSFUL")
        logger.info("=" * 60)
        logger.info(f"Model: {model_id}")
        logger.info(f"Output directory: {os.path.abspath(output_dir)}")
        logger.info("The model is ready for generation!")
        logger.info("\nTo use the model:")
        logger.info("1. See the example_usage.py file in the model directory")
        logger.info("2. For chatbot applications, implement token-by-token generation")
    else:
        logger.error("\n" + "=" * 60)
        logger.error("CONVERSION FAILED")
        logger.error("=" * 60)
        logger.error("Please try one of the recommended models:")
        logger.error("  - facebook/opt-125m")
        logger.error("  - distilgpt2")
        logger.error("  - EleutherAI/pythia-70m")