"""Convert small causal language models to ONNX with optional int8 dynamic quantization."""

import gc
import logging
import os
import sys
import time
import traceback
import warnings

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnxruntime.quantization import quantize_dynamic, QuantType

# Log progress to stdout with timestamps.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Silence noisy UserWarnings commonly raised while tracing models for export.
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")

RELIABLE_MODELS = [
    {
        "id": "facebook/opt-350m",
        "description": "Well-balanced model (350M) for RAG and chatbots"
    },
    {
        "id": "gpt2",
        "description": "Very reliable model (124M) with excellent ONNX compatibility"
    },
    {
        "id": "distilgpt2",
        "description": "Lightweight (82M) model with good performance"
    }
]


class ModelWrapper(torch.nn.Module):
    """
    Wrapper to handle ONNX export compatibility issues.

    This wrapper specifically:
    1. Bypasses cache handling (use_cache=False) so no past_key_values enter the graph
    2. Simplifies the forward pass to avoid dynamic operations
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        # Return only the logits tensor so the exported graph has a single, stable output.
        with torch.no_grad():
            return self.model(input_ids=input_ids, use_cache=False, return_dict=False)[0]


def convert_model(model_id, output_dir, quantize=True):
    """Convert a model to ONNX format with maximum compatibility."""
    start_time = time.time()

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX")
    logger.info(f"{'=' * 60}")

    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        logger.info("Step 1/5: Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # GPT-style tokenizers ship without a pad token; reuse EOS so padding works downstream.
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")
logger.info("Step 2/5: Loading model with memory optimizations...") |
|
|
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_id, |
|
torch_dtype=torch.float16, |
|
low_cpu_mem_usage=True |
|
) |
|
|
|
|
|
model.config.save_pretrained(model_dir) |
|
logger.info(f"β Model config saved to {model_dir}") |
|
|
|
|
|
logger.info("Step 3/5: Preparing for export...") |
|
|
|
|
|
wrapped_model = ModelWrapper(model) |
|
wrapped_model.eval() |
|
|
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
logger.info("Step 4/5: Exporting to ONNX format...") |
|
onnx_path = os.path.join(model_dir, "model.onnx") |
|
|
|
|
|
batch_size = 1 |
|
seq_length = 8 |
|
dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long) |
|
|
|
|
|
torch.onnx.export( |
|
wrapped_model, |
|
dummy_input, |
|
onnx_path, |
|
export_params=True, |
|
opset_version=14, |
|
do_constant_folding=True, |
|
input_names=['input_ids'], |
|
output_names=['logits'], |
|
dynamic_axes={ |
|
'input_ids': {0: 'batch_size', 1: 'sequence'}, |
|
'logits': {0: 'batch_size', 1: 'sequence'} |
|
} |
|
) |
|
|
|
|
|
del model |
|
del wrapped_model |
|
gc.collect() |
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
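
        # Optional: validate the exported graph structure (requires the `onnx` package).
        # import onnx
        # onnx.checker.check_model(onnx_path)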

        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            if quantize:
                logger.info("Step 5/5: Applying int8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")

                try:
                    # Dynamic quantization converts weights to int8 without a calibration dataset.
                    quantize_dynamic(
                        model_input=onnx_path,
                        model_output=quant_path,
                        per_channel=False,
                        reduce_range=False,
                        weight_type=QuantType.QInt8
                    )

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size / size_mb) * 100:.1f}%")

                        os.replace(quant_path, onnx_path)
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("✗ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"✗ Quantization error: {str(e)}")
                    logger.info("✓ Using original model without quantization")
            else:
                logger.info("Step 5/5: Skipping quantization (not requested)")

            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")

            return {
                "success": True,
                "model_id": model_id,
                "size_mb": os.path.getsize(onnx_path) / (1024 * 1024),
                "duration_seconds": duration,
                "output_dir": model_dir
            }
        else:
            logger.error(f"✗ ONNX file not created at {onnx_path}")
            return {
                "success": False,
                "model_id": model_id,
                "error": "ONNX file not created"
            }

    except Exception as e:
        logger.error(f"✗ Error converting model: {str(e)}")
        logger.error(traceback.format_exc())

        return {
            "success": False,
            "model_id": model_id,
            "error": str(e)
        }
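

# Optional helper, not invoked anywhere in this script: a minimal sketch of how an
# exported model could be loaded and sanity-checked with onnxruntime. The function
# name and prompt are illustrative; it assumes `onnxruntime` and `numpy` are installed.
def smoke_test_onnx_model(model_dir):
    """Run a single forward pass through an exported model as a basic sanity check."""
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(
        os.path.join(model_dir, "model.onnx"),
        providers=["CPUExecutionProvider"]
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    input_ids = tokenizer("Hello, ONNX!", return_tensors="np")["input_ids"].astype(np.int64)
    (logits,) = session.run(["logits"], {"input_ids": input_ids})
    logger.info(f"Smoke test output shape: {logits.shape}")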


def main():
    """Convert all reliable models."""
    logger.info("\nGUARANTEED ONNX CONVERTER")
    logger.info("=========================")
    logger.info("Using reliable models with proven ONNX compatibility")

    output_dir = "./onnx_models"
    os.makedirs(output_dir, exist_ok=True)

    # A model id passed on the command line converts just that model.
    if len(sys.argv) > 1:
        model_id = sys.argv[1]
        logger.info(f"Converting single model: {model_id}")
        convert_model(model_id, output_dir)
        return

    results = []
    for model_info in RELIABLE_MODELS:
        model_id = model_info["id"]
        logger.info(f"Processing model: {model_id}")
        logger.info(f"Description: {model_info['description']}")

        result = convert_model(model_id, output_dir)
        results.append(result)

    logger.info("\n" + "=" * 60)
    logger.info("CONVERSION SUMMARY")
    logger.info("=" * 60)

    success_count = 0
    for result in results:
        if result.get("success", False):
            success_count += 1
            size_info = f" - Size: {result.get('size_mb', 0):.2f} MB"
            time_info = f" - Time: {result.get('duration_seconds', 0):.2f}s"
            logger.info(f"✓ SUCCESS: {result['model_id']}{size_info}{time_info}")
        else:
            logger.info(f"✗ FAILED: {result['model_id']} - Error: {result.get('error', 'Unknown error')}")

    logger.info(f"\nSuccessfully converted {success_count}/{len(RELIABLE_MODELS)} models")
    logger.info(f"Models saved to: {os.path.abspath(output_dir)}")

    if success_count > 0:
        logger.info("\nThe models are ready for RAG and chatbot applications!")


if __name__ == "__main__":
    main()