import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from tqdm import tqdm
from onnxruntime.quantization import quantize_dynamic, QuantType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress unhelpful warnings
warnings.filterwarnings("ignore", category=UserWarning)


class GenerationWrapper(torch.nn.Module):
    """
    Wrapper for model export that handles generation properly.
    This ensures the model can be correctly used for text generation.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config

    def forward(self, input_ids, attention_mask=None):
        # Return only the logits to avoid complex structures
        with torch.no_grad():
            try:
                # Standard approach for most models
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True
                )
                return outputs.logits
            except Exception as e:
                logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
                # Fallback for models with different API
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, 'logits'):
                    return outputs.logits
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    return outputs[0]  # First element is typically logits
                else:
                    raise ValueError("Could not extract logits from model outputs")


def verify_model_generation(model, tokenizer, device="cpu"):
    """Test model generation capabilities before export"""
    model.eval()

    # Use a chat-like prompt for better testing
    prompt = "User: Hello, how are you today?\nAssistant:"

    logger.info("Testing model generation...")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Configure generation parameters
    gen_config = GenerationConfig(
        max_length=100,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )

    try:
        # Try generation
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test generation result: {generated_text}")

        if len(generated_text) <= len(prompt):
            logger.warning("Generation output is not longer than input prompt!")

        return True
    except Exception as e:
        logger.error(f"Generation test failed: {str(e)}")
        return False


def test_onnx_model(onnx_path, tokenizer):
    """Verify the ONNX model can be loaded and run"""
    try:
        import onnxruntime as ort

        logger.info("Testing ONNX model inference...")
        session = ort.InferenceSession(onnx_path)

        # Get input and output names
        input_names = [input.name for input in session.get_inputs()]
        output_names = [output.name for output in session.get_outputs()]

        # Create test input
        prompt = "User: Hello, how are you?\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="np")

        # Prepare input dict
        onnx_inputs = {}
        for name in input_names:
            if name == "input_ids" and "input_ids" in inputs:
                onnx_inputs[name] = inputs["input_ids"]
            elif name == "attention_mask" and "attention_mask" in inputs:
                onnx_inputs[name] = inputs["attention_mask"]

        # Run inference
        outputs = session.run(output_names, onnx_inputs)

        # Check output shape
        logits = outputs[0]
        logger.info(f"ONNX model output shape: {logits.shape}")

        if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
            logger.warning("Output shape doesn't match expected dimensions!")

        # Test next token prediction
        next_token_logits = logits[0, -1, :]
        next_token_id = np.argmax(next_token_logits)
        next_token = tokenizer.decode([next_token_id])
        logger.info(f"Next predicted token: '{next_token}'")

        return True
    except Exception as e:
        logger.error(f"ONNX model test failed: {str(e)}")
        return False


def post_process_onnx_for_unity(onnx_path):
    """
    Post-process ONNX model to be compatible with Unity Sentis
    using only core onnx functionality (no onnxsim)
    """
    try:
        import onnx

        logger.info("Post-processing ONNX model for Unity compatibility...")

        # First, create a backup of the original model
        backup_path = onnx_path.replace(".onnx", "_original.onnx")
        import shutil
        shutil.copy(onnx_path, backup_path)
        logger.info(f"Original model backed up to {backup_path}")

        # Load the model
        model = onnx.load(onnx_path)

        # Basic model checks and optimizations
        try:
            # Check model validity
            onnx.checker.check_model(model)
            logger.info("✓ Model structure validated successfully")

            # Apply shape inference
            inferred_model = onnx.shape_inference.infer_shapes(model)
            onnx.save(inferred_model, onnx_path)
            logger.info("✓ Applied shape inference")
        except Exception as e:
            logger.warning(f"Model validation/optimization error (continuing): {str(e)}")

        return True
    except Exception as e:
        logger.warning(f"ONNX post-processing error (skipping): {str(e)}")
        return False


def is_architecture_compatible(model_id):
    """
    Check if the model architecture is expected to be compatible with ONNX opset 11
    """
    model_id_lower = model_id.lower()

    # Models known to work with opset 11
    compatible_architectures = [
        "gpt2", "distilgpt2", "opt-125m", "opt-350m",
        "pythia-70m", "pythia-160m", "rwkv", "gpt-neo"
    ]

    # Models likely requiring higher opsets (usually 14+)
    incompatible_architectures = [
        "llama", "mistral", "mixtral", "tinyllama",
        "phi-2", "gemma", "falcon", "bloom"
    ]

    # Check for compatibility
    for arch in compatible_architectures:
        if arch in model_id_lower:
            return True, 11

    # Check for known incompatible architectures
    for arch in incompatible_architectures:
        if arch in model_id_lower:
            return False, 14

    # For phi-1 models, use opset 14 but mark as potentially compatible
    if "phi-1" in model_id_lower:
        return True, 14

    # Default to opset 14 for unknown architectures
    return False, 14


def setup_chat_template(model_id, tokenizer):
    """
    Setup appropriate chat template based on model architecture
    """
    model_id_lower = model_id.lower()

    # Try to setup chat template if it doesn't have one
    try:
        if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
            logger.info("Setting up chat template for improved conversations...")

            # Determine chat template based on model
            if "gpt2" in model_id_lower or "pythia" in model_id_lower or "opt" in model_id_lower:
                # Simple template for base models
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAI: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAI: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added simple Human/AI chat template")
            elif "phi" in model_id_lower:
                # Microsoft Phi models template
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAssistant: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added Phi-style Human/Assistant chat template")
            elif "rwkv" in model_id_lower:
                # RWKV template
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nBot: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nBot: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added RWKV-style User/Bot chat template")
    except Exception as e:
        logger.warning(f"Couldn't setup chat template: {str(e)}")
        logger.info("Chat template setup will need to be handled in Unity")
def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True, force_opset=None):
    """
    Convert a model to ONNX format with focus on Unity compatibility.

    Args:
        model_id: HuggingFace model ID or path
        output_dir: Directory to save the model
        seq_length: Input sequence length for export
        quantize: Whether to quantize the model to INT8
        force_opset: Force a specific ONNX opset version

    Returns:
        bool: Success status
    """
    start_time = time.time()

    # Check model architecture for compatibility
    is_compatible, recommended_opset = is_architecture_compatible(model_id)

    # Use forced opset if provided, otherwise use recommended
    opset_version = force_opset if force_opset is not None else recommended_opset

    # Warn if using a model that might not be compatible with Unity
    if not is_compatible and opset_version < 14:
        logger.warning(f"⚠ Model {model_id} may not be compatible with opset {opset_version}")
        logger.warning(f"⚠ Recommended opset for this model: {recommended_opset}")
        logger.warning(f"⚠ You can force a higher opset with --opset {recommended_opset}")

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX for Unity (opset {opset_version})")
    logger.info(f"{'=' * 60}")

    # Create output directory
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        # Step 1: Load tokenizer
        logger.info("Step 1/7: Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # Setup chat template for better conversation formatting
        setup_chat_template(model_id, tokenizer)

        # Save tokenizer
        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")

        # Step 2: Load model with reliability optimizations
        logger.info("Step 2/7: Loading model...")

        # Clean memory
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load model with full precision
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,  # Use full precision for reliability
                low_cpu_mem_usage=True,     # Reduce memory usage
                device_map=device           # Use CUDA if available
            )
        except Exception as e:
            logger.warning(f"Standard loading failed, trying with 'trust_remote_code=True': {str(e)}")
            # Some models (like RWKV) need trust_remote_code
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map=device,
                trust_remote_code=True
            )

        # Save config
        model.config.save_pretrained(model_dir)
        logger.info(f"✓ Model config saved to {model_dir}")

        # Step 3: Verify model can generate chat responses
        logger.info("Step 3/7: Validating chat capabilities...")
        if not verify_model_generation(model, tokenizer, device):
            logger.warning("⚠ Model chat test didn't complete successfully")
            logger.info("Continuing with export anyway...")

        # Step 4: Export to ONNX
        logger.info(f"Step 4/7: Exporting to ONNX format with opset {opset_version}...")

        # Wrap model with generation-optimized interface
        wrapped_model = GenerationWrapper(model)
        wrapped_model.eval()

        # Clean memory again
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Export to ONNX with appropriate opset version
        onnx_path = os.path.join(model_dir, "model.onnx")

        # Create minimal input
        batch_size = 1
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

        # Move tensors to correct device
        dummy_input = dummy_input.to(device)
        attention_mask = attention_mask.to(device)

        # Export to ONNX with required opset
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,                  # Wrapped model
                (dummy_input, attention_mask),  # Input tensors
                onnx_path,                      # Output path
                export_params=True,             # Store weights
                opset_version=opset_version,    # Required opset version
                do_constant_folding=True,       # Optimize constants
                input_names=['input_ids', 'attention_mask'],  # Input names
                output_names=['logits'],        # Output name
                dynamic_axes={                  # Dynamic dimensions
                    'input_ids': {0: 'batch_size', 1: 'sequence'},
                    'attention_mask': {0: 'batch_size', 1: 'sequence'},
                    'logits': {0: 'batch_size', 1: 'sequence'}
                }
            )

        # Clean up to save memory
        del model
        del wrapped_model
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Verify export success
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            # Step 5: Post-process the ONNX model for better Unity compatibility
            logger.info("Step 5/7: Post-processing ONNX model for Unity compatibility...")

            # Try to post-process model for Unity
            try:
                post_process_onnx_for_unity(onnx_path)
            except Exception as e:
                logger.warning(f"Post-processing failed (non-critical): {str(e)}")

            # Test ONNX model
            test_onnx_model(onnx_path, tokenizer)
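
            # Rough sizing intuition for the quantization step below (a hedged
            # back-of-envelope estimate, not a guarantee): dynamic quantization
            # stores weights as INT8 (1 byte) instead of FP32 (4 bytes), so the
            # weight portion of the file shrinks roughly 4x, e.g.:
            #
            #   approx_fp32_mb = num_params * 4 / (1024 * 1024)
            #   approx_int8_mb = num_params * 1 / (1024 * 1024)
            #
            # Activations remain floating point and are quantized on the fly at
            # inference time.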
            # Step 6: Quantize the model (optional)
            if quantize:
                logger.info("Step 6/7: Applying INT8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")

                try:
                    with tqdm(total=100, desc="Quantizing") as pbar:
                        # Update progress callback
                        def update_progress(x):
                            pbar.update(1)

                        # Apply quantization
                        quantize_dynamic(
                            model_input=onnx_path,
                            model_output=quant_path,
                            per_channel=False,
                            reduce_range=False,
                            weight_type=QuantType.QInt8,
                            optimize_model=True,
                            use_external_data_format=False
                        )
                        pbar.update(100)  # Ensure progress reaches 100%

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")

                        # Test the quantized model
                        test_onnx_model(quant_path, tokenizer)

                        # Rename original as backup
                        backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
                        os.rename(onnx_path, backup_path)

                        # Replace original with quantized
                        os.rename(quant_path, onnx_path)
                        logger.info("✓ Original model preserved as *_fp32.onnx")
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 6/7: Skipping quantization as requested")

            # Step 7: Generate Unity integration examples
            logger.info("Step 7/7: Generating Unity integration examples...")

            # Create a Unity integration example
            unity_example_path = os.path.join(model_dir, "unity_integration.cs")
            with open(unity_example_path, 'w') as f:
                f.write("""
using UnityEngine;
using Unity.Sentis;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

public class ONNXChatbot : MonoBehaviour
{
    [SerializeField] private ModelAsset modelAsset;
    [SerializeField] private TextAsset tokenizerVocabJson;
    [SerializeField] private int maxTokens = 50;
    [SerializeField] private float temperature = 0.7f;

    private IWorker worker;
    private Dictionary<string, Tensor> inputs;
    private SimpleTokenizer tokenizer;
    private bool isGenerating = false;

    void Start()
    {
        // Initialize the model
        var model = ModelLoader.Load(modelAsset);
        worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);

        // Initialize tokenizer
        tokenizer = new SimpleTokenizer(tokenizerVocabJson.text);

        // Prepare for inference
        inputs = new Dictionary<string, Tensor>();

        Debug.Log("Model and tokenizer initialized successfully.");
    }

    public async Task<string> GenerateResponseAsync(string userMessage)
    {
        if (isGenerating)
        {
            Debug.LogWarning("Already generating a response. Please wait.");
            return "Already generating a response. Please wait.";
        }

        isGenerating = true;

        try
        {
            // Format prompt with chat template
            string prompt = FormatChatPrompt(userMessage);
            Debug.Log($"Formatted prompt: {prompt}");

            // Tokenize input
            var tokenIds = tokenizer.Encode(prompt);
            Debug.Log($"Encoded to {tokenIds.Length} tokens");

            if (tokenIds.Length > 0)
            {
                // Generate response token by token
                StringBuilder responseBuilder = new StringBuilder();
                List<int> currentIds = tokenIds.ToList();

                for (int i = 0; i < maxTokens; i++)
                {
                    // Make sure we don't exceed the model's context window
                    if (currentIds.Count > 1024)
                    {
                        // If too long, keep only the last 1024 tokens
                        currentIds = currentIds.Skip(currentIds.Count - 1024).Take(1024).ToList();
                    }

                    // Create tensors for current sequence
                    using (var inputIdsTensor = new TensorInt(new TensorShape(1, currentIds.Count), currentIds.ToArray()))
                    using (var attentionMaskTensor = new TensorInt(new TensorShape(1, currentIds.Count), Enumerable.Repeat(1, currentIds.Count).ToArray()))
                    {
                        // Run inference
                        inputs.Clear();
                        inputs["input_ids"] = inputIdsTensor;
                        inputs["attention_mask"] = attentionMaskTensor;

                        worker.Execute(inputs);
                        var logits = worker.PeekOutput() as TensorFloat;

                        // Get next token prediction
                        int nextToken = SampleNextToken(logits, currentIds, temperature);

                        // If we hit the end token or a newline after content, stop
                        if (nextToken == tokenizer.EosToken || (i > 0 && nextToken == tokenizer.NewlineToken))
                        {
                            break;
                        }

                        // Add token to current sequence for next iteration
                        currentIds.Add(nextToken);

                        // Decode the latest token
                        string newToken = tokenizer.Decode(new[] { nextToken });
                        responseBuilder.Append(newToken);

                        // For smoother output, yield every few tokens
                        if (i % 5 == 0)
                        {
                            await Task.Delay(1);
                        }
                    }
                }

                // Return the full response, without the prompt
                string fullResponse = responseBuilder.ToString();
                return CleanResponse(fullResponse);
            }
            else
            {
                Debug.LogError("Tokenization failed: empty token list");
                return "Sorry, I couldn't process that input.";
            }
        }
        catch (System.Exception ex)
        {
            Debug.LogError($"Generation error: {ex.Message}\\n{ex.StackTrace}");
            return "Sorry, an error occurred while generating a response.";
        }
        finally
        {
            isGenerating = false;
        }
    }

    private string FormatChatPrompt(string userMessage)
    {
        // You may need to adjust this template based on your specific model
        return $"User: {userMessage}\\nAssistant:";
    }

    private string CleanResponse(string response)
    {
        // Extract only the Assistant's response
        int assistantPrefix = response.IndexOf("Assistant:");
        if (assistantPrefix >= 0)
        {
            response = response.Substring(assistantPrefix + "Assistant:".Length).Trim();
        }

        // Stop at any "User:" marker if present
        int nextUser = response.IndexOf("User:");
        if (nextUser >= 0)
        {
            response = response.Substring(0, nextUser).Trim();
        }

        return response;
    }

    private int SampleNextToken(TensorFloat logits, List<int> currentInputs, float temp)
    {
        // Get logits for the last position
        int lastPos = currentInputs.Count - 1;
        int vocabSize = logits.shape.channels;

        // Prepare array for logits
        float[] lastLogits = new float[vocabSize];

        // Extract logits for the last token position
        for (int i = 0; i < vocabSize; i++)
        {
            lastLogits[i] = logits[0, lastPos, i];
        }

        // Simple temperature-based sampling
        if (temp <= 0.0f)
        {
            // Greedy sampling (argmax)
            int maxIndex = 0;
            float maxValue = lastLogits[0];
            for (int i = 1; i < vocabSize; i++)
            {
                if (lastLogits[i] > maxValue)
                {
                    maxValue = lastLogits[i];
                    maxIndex = i;
                }
            }
            return maxIndex;
        }
        else
        {
            // Temperature sampling
            // Apply temperature
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= temp;
            }

            // Softmax
            float maxLogit = lastLogits.Max();
            float sum = 0.0f;
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] = Mathf.Exp(lastLogits[i] - maxLogit);
                sum += lastLogits[i];
            }
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= sum;
            }

            // Sample from distribution
            float random = Random.value;
            float cumulativeProb = 0.0f;
            for (int i = 0; i < vocabSize; i++)
            {
                cumulativeProb += lastLogits[i];
                if (random < cumulativeProb)
                {
                    return i;
                }
            }

            // Fallback to last token if sampling fails
            return vocabSize - 1;
        }
    }

    void OnDestroy()
    {
        worker?.Dispose();
    }
}

// Simple tokenizer implementation for Unity
public class SimpleTokenizer
{
    private Dictionary<string, int> vocab;
    private Dictionary<int, string> reversedVocab;

    public int PadToken { get; private set; }
    public int EosToken { get; private set; }
    public int BosToken { get; private set; }
    public int NewlineToken { get; private set; }

    public SimpleTokenizer(string vocabJson)
    {
        // Parse the vocabulary from JSON
        vocab = new Dictionary<string, int>();

        // Simple JSON parsing (you'll need a proper JSON parser in production)
        string[] entries = vocabJson.Split(new[] { '\\n', '{', '}', '\"', ':', ',' }, System.StringSplitOptions.RemoveEmptyEntries);

        for (int i = 0; i < entries.Length - 1; i += 2)
        {
            string token = entries[i].Trim();
            if (int.TryParse(entries[i + 1].Trim(), out int id))
            {
                vocab[token] = id;
            }
        }

        // Create reversed vocabulary for decoding
        reversedVocab = vocab.ToDictionary(kv => kv.Value, kv => kv.Key);

        // Find special tokens
        SetSpecialTokens();

        Debug.Log($"Tokenizer initialized with {vocab.Count} tokens");
    }

    private void SetSpecialTokens()
    {
        // Try to find standard special tokens
        PadToken = FindToken(new[] { "<pad>", "[PAD]", "<|endoftext|>" });
        EosToken = FindToken(new[] { "</s>", "<|endoftext|>", "[EOS]", "<eos>" });
        BosToken = FindToken(new[] { "<s>", "<|startoftext|>", "[BOS]", "<bos>" });

        // Find newline token
        foreach (var entry in vocab)
        {
            if (entry.Key == "\\n" || entry.Key == "Ċ" || entry.Key == "<\\n>")
            {
                NewlineToken = entry.Value;
                break;
            }
        }

        Debug.Log($"Special tokens - PAD: {PadToken}, EOS: {EosToken}, BOS: {BosToken}, NEWLINE: {NewlineToken}");
    }

    private int FindToken(string[] candidates)
    {
        foreach (var candidate in candidates)
        {
            if (vocab.TryGetValue(candidate, out int id))
            {
                return id;
            }
        }

        // Return -1 if not found
        return -1;
    }

    public int[] Encode(string text)
    {
        // Simple character-level tokenization
        // In production, use a proper BPE/WordPiece tokenizer implementation
        List<int> tokens = new List<int>();
        StringBuilder currentToken = new StringBuilder();

        // Add BOS token if available
        if (BosToken != -1)
        {
            tokens.Add(BosToken);
        }

        // Very simple tokenization - in production, this would implement
        // the specific tokenization algorithm for your model
        foreach (char c in text)
        {
            currentToken.Append(c);
            string current = currentToken.ToString();

            if (vocab.TryGetValue(current, out int id))
            {
                tokens.Add(id);
                currentToken.Clear();
            }
            else if (currentToken.Length > 10)
            {
                // If token is too long, add unknown token and reset
                tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
                currentToken.Clear();
                currentToken.Append(c);
            }
        }

        // Handle any remaining text
        if (currentToken.Length > 0)
        {
            tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
        }

        return tokens.ToArray();
    }

    public string Decode(int[] ids)
    {
        StringBuilder result = new StringBuilder();

        foreach (int id in ids)
        {
            if (reversedVocab.TryGetValue(id, out string token))
            {
                // Some tokenizers use special prefixes like "Ġ" for spaces
                string processedToken = token
                    .Replace("Ġ", " ")
                    .Replace("Ċ", "\\n")
                    .Replace("▁", " ");

                result.Append(processedToken);
            }
        }

        return result.ToString();
    }
}
""")

            # Calculate elapsed time
            end_time = time.time()
            duration = end_time - start_time

            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
            logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")

            # Create a Python example usage file
            example_path = os.path.join(model_dir, "example_usage.py")
            with open(example_path, 'w') as f:
                f.write("""
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./")  # Path to model directory
session = ort.InferenceSession("./model.onnx")

def generate_response(user_message, max_length=50):
    # Format as a chat message
    prompt = f"User: {user_message}\\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="np")

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Simple auto-regressive generation loop
    for _ in range(max_length):
        # Run inference for a single step
        outputs = session.run(
            ["logits"],
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        )

        # Get next token prediction from logits
        logits = outputs[0]
        next_token_logits = logits[0, -1, :]

        # Apply temperature sampling
        temperature = 0.7
        next_token_logits = next_token_logits / temperature

        # Apply softmax to get probabilities
        exp_logits = np.exp(next_token_logits - np.max(next_token_logits))
        probs = exp_logits / np.sum(exp_logits)

        # Sample from the distribution
        next_token_id = np.random.choice(probs.shape[0], p=probs)

        # Stop if we hit the end of sequence token
        if next_token_id == tokenizer.eos_token_id:
            break

        # Append new token to the input_ids
        input_ids = np.concatenate([input_ids, [[next_token_id]]], axis=1)
        attention_mask = np.concatenate([attention_mask, [[1]]], axis=1)

    # Decode the entire response
    response = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # Extract only the assistant's response
    if "Assistant:" in response:
        response = response.split("Assistant:")[-1].strip()

    return response

# Example usage
while True:
    user_input = input("You: ")
    if user_input.lower() in ['exit', 'quit']:
        break

    response = generate_response(user_input)
    print(f"Assistant: {response}")
""")

            logger.info(f"✓ Example usage saved to {example_path}")
            logger.info(f"✓ Unity integration example saved to {unity_example_path}")
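
            # Note (a limitation worth keeping in mind): because GenerationWrapper
            # exports with use_cache=False, both the generated example_usage.py
            # loop and the Unity loop above re-run the full sequence for every new
            # token (no KV cache), so per-token cost grows with conversation length.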
            return True
        else:
            logger.error(f"× ONNX file not created at {onnx_path}")
            return False

    except Exception as e:
        logger.error(f"× Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        return False


if __name__ == "__main__":
    # Parse command line arguments
    parser_available = False
    try:
        import argparse
        parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for Unity")
        parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
        parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
                            help="Output directory for the converted model")
        parser.add_argument("--seq_length", "-s", type=int, default=32,
                            help="Sequence length for model export")
        parser.add_argument("--no_quantize", action="store_true",
                            help="Skip INT8 quantization step")
        parser.add_argument("--opset", "-op", type=int, default=None,
                            help="Force a specific ONNX opset version")

        args = parser.parse_args()
        parser_available = True

        model_id = args.model_id
        output_dir = args.output_dir
        seq_length = args.seq_length
        quantize = not args.no_quantize
        force_opset = args.opset
    except (ImportError, NameError):
        # Fallback if argparse is not available
        parser_available = False

    if not parser_available:
        if len(sys.argv) < 2:
            print("Usage: python unity_compatible_converter.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize] [--opset]")
            print("Example: python unity_compatible_converter.py distilgpt2 ./onnx_models 32")
            print("\nRecommended chat models for Unity:")
            print("  - distilgpt2 (smallest, opset 11)")
            print("  - EleutherAI/pythia-70m (better quality, opset 11)")
            print("  - microsoft/phi-1 (high quality, opset 14)")
            print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0 (chat-tuned, opset 14)")
            sys.exit(1)

        model_id = sys.argv[1]
        output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
        seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
        quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
        force_opset = None

        # Check for opset flag
        for i, arg in enumerate(sys.argv):
            if arg == "--opset" and i + 1 < len(sys.argv):
                force_opset = int(sys.argv[i + 1])

    # Check model architecture for automatic opset recommendation
    is_compatible, recommended_opset = is_architecture_compatible(model_id)

    # Print header
    logger.info("\nUNITY-COMPATIBLE ONNX CONVERTER")
    logger.info("===============================")
    logger.info(f"Model: {model_id}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Sequence length: {seq_length}")
    if force_opset is not None:
        logger.info(f"ONNX opset version: {force_opset} (forced)")
    else:
        logger.info(f"Recommended ONNX opset: {recommended_opset}")
        logger.info(f"Architecture compatible with opset 11: {'Yes' if is_compatible else 'No'}")
    logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Convert the model
    success = convert_model(model_id, output_dir, seq_length, quantize, force_opset)

    if success:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION SUCCESSFUL")
        logger.info("=" * 60)
        logger.info(f"Model: {model_id}")
        logger.info(f"Output directory: {os.path.abspath(output_dir)}")
        logger.info("The model is ready for Unity integration!")
        logger.info("\nNext steps:")
        logger.info("1. Import the ONNX model into Unity using the Sentis package")
        logger.info("2. Use the unity_integration.cs file as a starting point")
        logger.info("3. For tokenization in Unity, implement the SimpleTokenizer class")
    else:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION FAILED")
        logger.info("=" * 60)
        logger.info("Please try one of the recommended models that work well with Unity:")
        if is_compatible:
            logger.info("Compatible with Unity (opset 11):")
            logger.info("  - distilgpt2")
            logger.info("  - EleutherAI/pythia-70m")
        logger.info("Advanced models (require opset 14):")
        logger.info("  - microsoft/phi-1 --opset 14")
        logger.info("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0 --opset 14")
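
# Example invocations (illustrative only; the model IDs below are the ones
# recommended elsewhere in this script):
#
#   python unity_compatible_converter.py distilgpt2
#   python unity_compatible_converter.py EleutherAI/pythia-70m -o ./onnx_models -s 64
#   python unity_compatible_converter.py microsoft/phi-1 --opset 14 --no_quantize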