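"""Convert a small set of ONNX-friendly causal language models to ONNX.

With no arguments the script converts every entry in RELIABLE_MODELS; passing a
Hugging Face model id converts just that model. Exported (and, by default,
int8-quantized) models are written to ./onnx_models/<model_name>/.

Example invocation (the filename below is illustrative; use whatever name this
file is saved under):

    python convert_to_onnx.py                    # convert all reliable models
    python convert_to_onnx.py facebook/opt-350m  # convert a single model
"""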
import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnxruntime.quantization import quantize_dynamic, QuantType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")

# Models that are known to work well with ONNX conversion
RELIABLE_MODELS = [
    {
        "id": "facebook/opt-350m",
        "description": "Well-balanced model (350M) for RAG and chatbots"
    },
    {
        "id": "gpt2",
        "description": "Very reliable model (124M) with excellent ONNX compatibility"
    },
    {
        "id": "distilgpt2",
        "description": "Lightweight (82M) model with good performance"
    }
]

class ModelWrapper(torch.nn.Module):
    """
    Wrapper to handle ONNX export compatibility issues.
    This wrapper specifically:
    1. Bypasses cache handling
    2. Simplifies the forward pass to avoid dynamic operations
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    def forward(self, input_ids):
        # Disable caching and gradients; with return_dict=False the first
        # element of the returned tuple is the logits tensor.
        with torch.no_grad():
            return self.model(input_ids=input_ids, use_cache=False, return_dict=False)[0]

def convert_model(model_id, output_dir, quantize=True):
    """Convert a model to ONNX format with maximum compatibility."""
    start_time = time.time()
    
    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX")
    logger.info(f"{'=' * 60}")
    
    # Create output directory
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
    
    try:
        # Step 1: Load tokenizer
        logger.info("Step 1/5: Loading tokenizer...")
        
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        
        # Handle missing pad token
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            logger.info("Setting pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token
        
        # Save tokenizer
        tokenizer.save_pretrained(model_dir)
        logger.info(f"βœ“ Tokenizer saved to {model_dir}")
        
        # Step 2: Load model with memory optimizations
        logger.info("Step 2/5: Loading model with memory optimizations...")
        
        # Clean memory before loading
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,  # Half precision to cut memory (may need float32 on CPU-only machines)
            low_cpu_mem_usage=True      # Reduce memory usage
        )
        
        # Save config for reference
        model.config.save_pretrained(model_dir)
        logger.info(f"βœ“ Model config saved to {model_dir}")
        
        # Step 3: Prepare for export
        logger.info("Step 3/5: Preparing for export...")
        
        # Wrap model to avoid tracing issues
        wrapped_model = ModelWrapper(model)
        wrapped_model.eval()  # Set to evaluation mode
        
        # Clean memory again
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Step 4: Export to ONNX
        logger.info("Step 4/5: Exporting to ONNX format...")
        onnx_path = os.path.join(model_dir, "model.onnx")
        
        # Create dummy input
        batch_size = 1
        seq_length = 8  # Small sequence length to reduce memory
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
        
        # Export to ONNX format with new opset version
        torch.onnx.export(
            wrapped_model,             # Use wrapped model
            dummy_input,               # Model input
            onnx_path,                 # Output path
            export_params=True,        # Store model weights
            opset_version=14,          # ONNX opset version (14 adds ops such as Trilu used by some models)
            do_constant_folding=True,  # Optimize constants
            input_names=['input_ids'], # Input names
            output_names=['logits'],   # Output names
            dynamic_axes={
                'input_ids': {0: 'batch_size', 1: 'sequence'},
                'logits': {0: 'batch_size', 1: 'sequence'}
            }
        )
        
        # Clean up to save memory
        del model
        del wrapped_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Verify export was successful
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"βœ“ ONNX model saved to {onnx_path}")
            logger.info(f"βœ“ Original size: {size_mb:.2f} MB")
            
            # Step 5: Quantize
            if quantize:
                logger.info("Step 5/5: Applying int8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
                
                try:
                    quantize_dynamic(
                        model_input=onnx_path,
                        model_output=quant_path,
                        per_channel=False,
                        reduce_range=False,
                        weight_type=QuantType.QInt8
                    )
                    
                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"βœ“ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"βœ“ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
                        
                        # Replace original with quantized to save space
                        os.replace(quant_path, onnx_path)
                        logger.info("βœ“ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 5/5: Skipping quantization (not requested)")
            
            # Calculate elapsed time
            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"βœ“ Conversion completed in {duration:.2f} seconds")
            
            return {
                "success": True,
                "model_id": model_id,
                "size_mb": os.path.getsize(onnx_path) / (1024 * 1024),
                "duration_seconds": duration,
                "output_dir": model_dir
            }
        else:
            logger.error(f"Γ— ONNX file not created at {onnx_path}")
            return {
                "success": False,
                "model_id": model_id,
                "error": "ONNX file not created"
            }
    
    except Exception as e:
        logger.error(f"Γ— Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        
        return {
            "success": False,
            "model_id": model_id,
            "error": str(e)
        }

def main():
    """Convert all reliable models."""
    # Print header
    logger.info("\nGUARANTEED ONNX CONVERTER")
    logger.info("======================")
    logger.info("Using reliable models with proven ONNX compatibility")
    
    # Create output directory
    output_dir = "./onnx_models"
    os.makedirs(output_dir, exist_ok=True)
    
    # Check if specific model ID provided as argument
    if len(sys.argv) > 1:
        model_id = sys.argv[1]
        logger.info(f"Converting single model: {model_id}")
        convert_model(model_id, output_dir)
        return
    
    # Convert all reliable models
    results = []
    for model_info in RELIABLE_MODELS:
        model_id = model_info["id"]
        logger.info(f"Processing model: {model_id}")
        logger.info(f"Description: {model_info['description']}")
        
        result = convert_model(model_id, output_dir)
        results.append(result)
    
    # Print summary
    logger.info("\n" + "=" * 60)
    logger.info("CONVERSION SUMMARY")
    logger.info("=" * 60)
    
    success_count = 0
    for result in results:
        if result.get("success", False):
            success_count += 1
            size_info = f" - Size: {result.get('size_mb', 0):.2f} MB"
            time_info = f" - Time: {result.get('duration_seconds', 0):.2f}s"
            logger.info(f"βœ“ SUCCESS: {result['model_id']}{size_info}{time_info}")
        else:
            logger.info(f"Γ— FAILED: {result['model_id']} - Error: {result.get('error', 'Unknown error')}")
    
    logger.info(f"\nSuccessfully converted {success_count}/{len(RELIABLE_MODELS)} models")
    logger.info(f"Models saved to: {os.path.abspath(output_dir)}")
    
    if success_count > 0:
        logger.info("\nThe models are ready for RAG and chatbot applications!")

if __name__ == "__main__":
    main()
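
# Optional smoke test (a sketch, not executed by this script): assuming
# onnxruntime is installed and "gpt2" has already been converted, the exported
# graph can be checked roughly like this:
#
#   import onnxruntime as ort
#   from transformers import AutoTokenizer
#
#   model_dir = "./onnx_models/gpt2"
#   session = ort.InferenceSession(model_dir + "/model.onnx")
#   tokenizer = AutoTokenizer.from_pretrained(model_dir)
#   input_ids = tokenizer("Hello, world!", return_tensors="np")["input_ids"]
#   logits = session.run(["logits"], {"input_ids": input_ids})[0]
#   print(logits.shape)  # (batch_size, sequence, vocab_size)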