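"""
Convert a Hugging Face causal language model to ONNX for use with Unity Sentis.

The script loads a model and tokenizer, exports the model with torch.onnx.export,
optionally applies dynamic INT8 quantization, and writes Unity (C#) and Python
usage examples alongside the converted model.
"""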
import os
import gc
import sys
import time
import logging
import traceback
import warnings

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from onnxruntime.quantization import quantize_dynamic, QuantType


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", category=UserWarning)


class GenerationWrapper(torch.nn.Module):
    """
    Export wrapper: runs a plain forward pass (no KV cache) and returns the
    logits tensor, so the exported ONNX graph can be used for text generation.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config

    def forward(self, input_ids, attention_mask=None):
        with torch.no_grad():
            try:
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True
                )
                return outputs.logits
            except Exception as e:
                logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, 'logits'):
                    return outputs.logits
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    return outputs[0]
                else:
                    raise ValueError("Could not extract logits from model outputs")
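
# Quick sanity check of the wrapper (hypothetical snippet, kept as a comment so it
# does not run on import; "distilgpt2" is just an example model ID):
#   model = AutoModelForCausalLM.from_pretrained("distilgpt2")
#   wrapper = GenerationWrapper(model).eval()
#   logits = wrapper(torch.ones(1, 8, dtype=torch.long))  # -> shape (1, 8, vocab_size)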


def verify_model_generation(model, tokenizer, device="cpu"):
    """Test the model's generation capability before export."""
    model.eval()

    prompt = "User: Hello, how are you today?\nAssistant:"

    logger.info("Testing model generation...")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    gen_config = GenerationConfig(
        max_length=100,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test generation result: {generated_text}")

        if len(generated_text) <= len(prompt):
            logger.warning("Generation output is not longer than the input prompt!")

        return True
    except Exception as e:
        logger.error(f"Generation test failed: {str(e)}")
        return False


def test_onnx_model(onnx_path, tokenizer):
    """Verify that the exported ONNX model can be loaded and run."""
    try:
        import onnxruntime as ort

        logger.info("Testing ONNX model inference...")
        session = ort.InferenceSession(onnx_path)

        input_names = [inp.name for inp in session.get_inputs()]
        output_names = [out.name for out in session.get_outputs()]

        prompt = "User: Hello, how are you?\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="np")

        onnx_inputs = {}
        for name in input_names:
            if name == "input_ids" and "input_ids" in inputs:
                onnx_inputs[name] = inputs["input_ids"]
            elif name == "attention_mask" and "attention_mask" in inputs:
                onnx_inputs[name] = inputs["attention_mask"]

        outputs = session.run(output_names, onnx_inputs)

        logits = outputs[0]
        logger.info(f"ONNX model output shape: {logits.shape}")

        if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
            logger.warning("Output shape doesn't match the expected dimensions!")

        next_token_logits = logits[0, -1, :]
        next_token_id = int(np.argmax(next_token_logits))
        next_token = tokenizer.decode([next_token_id])
        logger.info(f"Next predicted token: '{next_token}'")

        return True
    except Exception as e:
        logger.error(f"ONNX model test failed: {str(e)}")
        return False


def post_process_onnx_for_unity(onnx_path):
    """
    Post-process the ONNX model for Unity Sentis compatibility using only
    core onnx functionality (no onnxsim).
    """
    try:
        import onnx
        import shutil

        logger.info("Post-processing ONNX model for Unity compatibility...")

        backup_path = onnx_path.replace(".onnx", "_original.onnx")
        shutil.copy(onnx_path, backup_path)
        logger.info(f"Original model backed up to {backup_path}")

        model = onnx.load(onnx_path)

        try:
            onnx.checker.check_model(model)
            logger.info("✓ Model structure validated successfully")

            inferred_model = onnx.shape_inference.infer_shapes(model)
            onnx.save(inferred_model, onnx_path)
            logger.info("✓ Applied shape inference")
        except Exception as e:
            logger.warning(f"Model validation/optimization error (continuing): {str(e)}")

        return True

    except Exception as e:
        logger.warning(f"ONNX post-processing error (skipping): {str(e)}")
        return False


def is_architecture_compatible(model_id):
    """
    Check whether the model architecture is expected to be compatible with
    ONNX opset 11. Returns a tuple (is_compatible, recommended_opset).
    """
    model_id_lower = model_id.lower()

    # Architectures known to export cleanly with opset 11.
    compatible_architectures = [
        "gpt2", "distilgpt2", "opt-125m", "opt-350m",
        "pythia-70m", "pythia-160m", "rwkv", "gpt-neo"
    ]

    # Architectures that generally require a newer opset.
    incompatible_architectures = [
        "llama", "mistral", "mixtral", "tinyllama", "phi-2",
        "gemma", "falcon", "bloom"
    ]

    for arch in compatible_architectures:
        if arch in model_id_lower:
            return True, 11

    for arch in incompatible_architectures:
        if arch in model_id_lower:
            return False, 14

    if "phi-1" in model_id_lower:
        return True, 14

    # Unknown architecture: be conservative and recommend opset 14.
    return False, 14
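
# Example return values (these follow the lists above):
#   is_architecture_compatible("distilgpt2")                         -> (True, 11)
#   is_architecture_compatible("TinyLlama/TinyLlama-1.1B-Chat-v1.0") -> (False, 14)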


def setup_chat_template(model_id, tokenizer):
    """
    Set up an appropriate chat template based on the model architecture.
    """
    model_id_lower = model_id.lower()

    try:
        if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
            logger.info("Setting up chat template for improved conversations...")

            if "gpt2" in model_id_lower or "pythia" in model_id_lower or "opt" in model_id_lower:
                # Simple Human/AI template for base GPT-2-style models.
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAI: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAI: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added simple Human/AI chat template")

            elif "phi" in model_id_lower:
                # Human/Assistant template for Phi models.
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAssistant: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added Phi-style Human/Assistant chat template")

            elif "rwkv" in model_id_lower:
                # User/Bot template for RWKV models.
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nBot: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nBot: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added RWKV-style User/Bot chat template")

    except Exception as e:
        logger.warning(f"Couldn't set up chat template: {str(e)}")
        logger.info("Chat template setup will need to be handled in Unity")


def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True, force_opset=None):
    """
    Convert a model to ONNX format with a focus on Unity compatibility.

    Args:
        model_id: HuggingFace model ID or path
        output_dir: Directory to save the model
        seq_length: Input sequence length for export
        quantize: Whether to quantize the model to INT8
        force_opset: Force a specific ONNX opset version

    Returns:
        bool: Success status
    """
    start_time = time.time()

    is_compatible, recommended_opset = is_architecture_compatible(model_id)

    opset_version = force_opset if force_opset is not None else recommended_opset

    if not is_compatible and opset_version < 14:
        logger.warning(f"⚠ Model {model_id} may not be compatible with opset {opset_version}")
        logger.warning(f"⚠ Recommended opset for this model: {recommended_opset}")
        logger.warning(f"⚠ You can force a higher opset with --opset {recommended_opset}")

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX for Unity (opset {opset_version})")
    logger.info(f"{'=' * 60}")

    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        logger.info("Step 1/7: Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        setup_chat_template(model_id, tokenizer)

        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")

        logger.info("Step 2/7: Loading model...")

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        device = "cuda" if torch.cuda.is_available() else "cpu"

        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map=device
            )
        except Exception as e:
            logger.warning(f"Standard loading failed, trying with 'trust_remote_code=True': {str(e)}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map=device,
                trust_remote_code=True
            )

        model.config.save_pretrained(model_dir)
        logger.info(f"✓ Model config saved to {model_dir}")

        logger.info("Step 3/7: Validating chat capabilities...")
        if not verify_model_generation(model, tokenizer, device):
            logger.warning("⚠ Model chat test didn't complete successfully")
            logger.info("Continuing with export anyway...")

        logger.info(f"Step 4/7: Exporting to ONNX format with opset {opset_version}...")

        wrapped_model = GenerationWrapper(model)
        wrapped_model.eval()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        onnx_path = os.path.join(model_dir, "model.onnx")

        batch_size = 1
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

        dummy_input = dummy_input.to(device)
        attention_mask = attention_mask.to(device)

        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,
                (dummy_input, attention_mask),
                onnx_path,
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask'],
                output_names=['logits'],
                dynamic_axes={
                    'input_ids': {0: 'batch_size', 1: 'sequence'},
                    'attention_mask': {0: 'batch_size', 1: 'sequence'},
                    'logits': {0: 'batch_size', 1: 'sequence'}
                }
            )

        del model
        del wrapped_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            logger.info("Step 5/7: Post-processing ONNX model for Unity compatibility...")

            try:
                post_process_onnx_for_unity(onnx_path)
            except Exception as e:
                logger.warning(f"Post-processing failed (non-critical): {str(e)}")

            test_onnx_model(onnx_path, tokenizer)

            if quantize:
                logger.info("Step 6/7: Applying INT8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")

                try:
                    # Note: accepted keyword arguments can vary between onnxruntime
                    # releases; any failure here is caught and the FP32 model is kept.
                    with tqdm(total=100, desc="Quantizing") as pbar:
                        quantize_dynamic(
                            model_input=onnx_path,
                            model_output=quant_path,
                            per_channel=False,
                            reduce_range=False,
                            weight_type=QuantType.QInt8,
                            optimize_model=True,
                            use_external_data_format=False
                        )
                        pbar.update(100)

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size / size_mb) * 100:.1f}%")

                        test_onnx_model(quant_path, tokenizer)

                        backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
                        os.rename(onnx_path, backup_path)

                        os.rename(quant_path, onnx_path)
                        logger.info("✓ Original model preserved as *_fp32.onnx")
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 6/7: Skipping quantization as requested")

            logger.info("Step 7/7: Generating Unity integration examples...")

            unity_example_path = os.path.join(model_dir, "unity_integration.cs")
            with open(unity_example_path, 'w') as f:
                f.write("""
using UnityEngine;
using Unity.Sentis;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

public class ONNXChatbot : MonoBehaviour
{
    [SerializeField] private ModelAsset modelAsset;
    [SerializeField] private TextAsset tokenizerVocabJson;
    [SerializeField] private int maxTokens = 50;
    [SerializeField] private float temperature = 0.7f;

    private IWorker worker;
    private Dictionary<string, Tensor> inputs;
    private SimpleTokenizer tokenizer;
    private bool isGenerating = false;

    void Start()
    {
        // Initialize the model
        var model = ModelLoader.Load(modelAsset);
        worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);

        // Initialize tokenizer
        tokenizer = new SimpleTokenizer(tokenizerVocabJson.text);

        // Prepare for inference
        inputs = new Dictionary<string, Tensor>();

        Debug.Log("Model and tokenizer initialized successfully.");
    }

    public async Task<string> GenerateResponseAsync(string userMessage)
    {
        if (isGenerating)
        {
            Debug.LogWarning("Already generating a response. Please wait.");
            return "Already generating a response. Please wait.";
        }

        isGenerating = true;

        try
        {
            // Format prompt with chat template
            string prompt = FormatChatPrompt(userMessage);
            Debug.Log($"Formatted prompt: {prompt}");

            // Tokenize input
            var tokenIds = tokenizer.Encode(prompt);
            Debug.Log($"Encoded to {tokenIds.Length} tokens");

            if (tokenIds.Length > 0)
            {
                // Generate response token by token
                StringBuilder responseBuilder = new StringBuilder();
                List<int> currentIds = tokenIds.ToList();

                for (int i = 0; i < maxTokens; i++)
                {
                    // Make sure we don't exceed the model's context window
                    if (currentIds.Count > 1024)
                    {
                        // If too long, keep only the last 1024 tokens
                        currentIds = currentIds.Skip(currentIds.Count - 1024).Take(1024).ToList();
                    }

                    // Create tensors for current sequence
                    using (var inputIdsTensor = new TensorInt(new TensorShape(1, currentIds.Count), currentIds.ToArray()))
                    using (var attentionMaskTensor = new TensorInt(new TensorShape(1, currentIds.Count), Enumerable.Repeat(1, currentIds.Count).ToArray()))
                    {
                        // Run inference
                        inputs.Clear();
                        inputs["input_ids"] = inputIdsTensor;
                        inputs["attention_mask"] = attentionMaskTensor;

                        worker.Execute(inputs);
                        var logits = worker.PeekOutput() as TensorFloat;

                        // Get next token prediction
                        int nextToken = SampleNextToken(logits, currentIds, temperature);

                        // If we hit the end token or a newline after content, stop
                        if (nextToken == tokenizer.EosToken ||
                            (i > 0 && nextToken == tokenizer.NewlineToken))
                        {
                            break;
                        }

                        // Add token to current sequence for next iteration
                        currentIds.Add(nextToken);

                        // Decode the latest token
                        string newToken = tokenizer.Decode(new[] { nextToken });
                        responseBuilder.Append(newToken);

                        // For smoother output, yield every few tokens
                        if (i % 5 == 0)
                        {
                            await Task.Delay(1);
                        }
                    }
                }

                // Return the full response, without the prompt
                string fullResponse = responseBuilder.ToString();
                return CleanResponse(fullResponse);
            }
            else
            {
                Debug.LogError("Tokenization failed: empty token list");
                return "Sorry, I couldn't process that input.";
            }
        }
        catch (System.Exception ex)
        {
            Debug.LogError($"Generation error: {ex.Message}\\n{ex.StackTrace}");
            return "Sorry, an error occurred while generating a response.";
        }
        finally
        {
            isGenerating = false;
        }
    }

    private string FormatChatPrompt(string userMessage)
    {
        // You may need to adjust this template based on your specific model
        return $"User: {userMessage}\\nAssistant:";
    }

    private string CleanResponse(string response)
    {
        // Extract only the Assistant's response
        int assistantPrefix = response.IndexOf("Assistant:");
        if (assistantPrefix >= 0)
        {
            response = response.Substring(assistantPrefix + "Assistant:".Length).Trim();
        }

        // Stop at any "User:" marker if present
        int nextUser = response.IndexOf("User:");
        if (nextUser >= 0)
        {
            response = response.Substring(0, nextUser).Trim();
        }

        return response;
    }

    private int SampleNextToken(TensorFloat logits, List<int> currentInputs, float temp)
    {
        // Get logits for the last position
        int lastPos = currentInputs.Count - 1;
        int vocabSize = logits.shape.channels;

        // Prepare array for logits
        float[] lastLogits = new float[vocabSize];

        // Extract logits for the last token position
        for (int i = 0; i < vocabSize; i++)
        {
            lastLogits[i] = logits[0, lastPos, i];
        }

        // Simple temperature-based sampling
        if (temp <= 0.0f)
        {
            // Greedy sampling (argmax)
            int maxIndex = 0;
            float maxValue = lastLogits[0];

            for (int i = 1; i < vocabSize; i++)
            {
                if (lastLogits[i] > maxValue)
                {
                    maxValue = lastLogits[i];
                    maxIndex = i;
                }
            }

            return maxIndex;
        }
        else
        {
            // Temperature sampling: apply temperature
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= temp;
            }

            // Softmax
            float maxLogit = lastLogits.Max();
            float sum = 0.0f;

            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] = Mathf.Exp(lastLogits[i] - maxLogit);
                sum += lastLogits[i];
            }

            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= sum;
            }

            // Sample from distribution
            float random = Random.value;
            float cumulativeProb = 0.0f;

            for (int i = 0; i < vocabSize; i++)
            {
                cumulativeProb += lastLogits[i];
                if (random < cumulativeProb)
                {
                    return i;
                }
            }

            // Fallback to last token if sampling fails
            return vocabSize - 1;
        }
    }

    void OnDestroy()
    {
        worker?.Dispose();
    }
}

// Simple tokenizer implementation for Unity
public class SimpleTokenizer
{
    private Dictionary<string, int> vocab;
    private Dictionary<int, string> reversedVocab;

    public int PadToken { get; private set; }
    public int EosToken { get; private set; }
    public int BosToken { get; private set; }
    public int NewlineToken { get; private set; }

    public SimpleTokenizer(string vocabJson)
    {
        // Parse the vocabulary from JSON
        vocab = new Dictionary<string, int>();

        // Simple JSON parsing (you'll need a proper JSON parser in production)
        string[] entries = vocabJson.Split(new[] { '\\n', '{', '}', '\"', ':', ',' },
            System.StringSplitOptions.RemoveEmptyEntries);

        for (int i = 0; i < entries.Length - 1; i += 2)
        {
            string token = entries[i].Trim();
            if (int.TryParse(entries[i + 1].Trim(), out int id))
            {
                vocab[token] = id;
            }
        }

        // Create reversed vocabulary for decoding
        reversedVocab = vocab.ToDictionary(kv => kv.Value, kv => kv.Key);

        // Find special tokens
        SetSpecialTokens();

        Debug.Log($"Tokenizer initialized with {vocab.Count} tokens");
    }

    private void SetSpecialTokens()
    {
        // Try to find standard special tokens
        PadToken = FindToken(new[] { "<pad>", "[PAD]", "<|endoftext|>" });
        EosToken = FindToken(new[] { "</s>", "<|endoftext|>", "[EOS]", "<eos>" });
        BosToken = FindToken(new[] { "<s>", "<|startoftext|>", "[BOS]", "<bos>" });

        // Find newline token ("Ċ" is the GPT-2 byte-level encoding of a newline)
        foreach (var entry in vocab)
        {
            if (entry.Key == "\\n" || entry.Key == "<\\n>" || entry.Key == "Ċ")
            {
                NewlineToken = entry.Value;
                break;
            }
        }

        Debug.Log($"Special tokens - PAD: {PadToken}, EOS: {EosToken}, BOS: {BosToken}, NEWLINE: {NewlineToken}");
    }

    private int FindToken(string[] candidates)
    {
        foreach (var candidate in candidates)
        {
            if (vocab.TryGetValue(candidate, out int id))
            {
                return id;
            }
        }

        // Return -1 if not found
        return -1;
    }

    public int[] Encode(string text)
    {
        // Simple character-level tokenization.
        // In production, use a proper BPE/WordPiece tokenizer implementation.
        List<int> tokens = new List<int>();
        StringBuilder currentToken = new StringBuilder();

        // Add BOS token if available
        if (BosToken != -1)
        {
            tokens.Add(BosToken);
        }

        // Very simple tokenization - in production, this would implement
        // the specific tokenization algorithm for your model
        foreach (char c in text)
        {
            currentToken.Append(c);
            string current = currentToken.ToString();

            if (vocab.TryGetValue(current, out int id))
            {
                tokens.Add(id);
                currentToken.Clear();
            }
            else if (currentToken.Length > 10)
            {
                // If token is too long, add unknown token and reset
                tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
                currentToken.Clear();
                currentToken.Append(c);
            }
        }

        // Handle any remaining text
        if (currentToken.Length > 0)
        {
            tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
        }

        return tokens.ToArray();
    }

    public string Decode(int[] ids)
    {
        StringBuilder result = new StringBuilder();

        foreach (int id in ids)
        {
            if (reversedVocab.TryGetValue(id, out string token))
            {
                // Some tokenizers use special prefixes like "Ġ" for spaces
                string processedToken = token
                    .Replace("Ġ", " ")
                    .Replace("Ċ", "\\n")
                    .Replace("▁", " ");

                result.Append(processedToken);
            }
        }

        return result.ToString();
    }
}
""")

            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")
            logger.info(f"✓ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")

            example_path = os.path.join(model_dir, "example_usage.py")
            with open(example_path, 'w') as f:
                f.write("""
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./")  # Path to the model directory
session = ort.InferenceSession("./model.onnx")


def generate_response(user_message, max_length=50):
    # Format as a chat message
    prompt = f"User: {user_message}\\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="np")

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Simple auto-regressive generation loop
    for _ in range(max_length):
        # Run inference for a single step
        outputs = session.run(
            ["logits"],
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        )

        # Get next token prediction from logits (float64 keeps the
        # probabilities well normalized for np.random.choice)
        logits = outputs[0]
        next_token_logits = logits[0, -1, :].astype(np.float64)

        # Apply temperature sampling
        temperature = 0.7
        next_token_logits = next_token_logits / temperature

        # Apply softmax to get probabilities
        exp_logits = np.exp(next_token_logits - np.max(next_token_logits))
        probs = exp_logits / np.sum(exp_logits)

        # Sample from the distribution
        next_token_id = np.random.choice(probs.shape[0], p=probs)

        # Stop if we hit the end of sequence token
        if next_token_id == tokenizer.eos_token_id:
            break

        # Append the new token to input_ids and the attention mask
        input_ids = np.concatenate([input_ids, [[next_token_id]]], axis=1)
        attention_mask = np.concatenate([attention_mask, [[1]]], axis=1)

    # Decode the entire response
    response = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # Extract only the assistant's response
    if "Assistant:" in response:
        response = response.split("Assistant:")[-1].strip()

    return response


# Example usage
while True:
    user_input = input("You: ")
    if user_input.lower() in ['exit', 'quit']:
        break
    response = generate_response(user_input)
    print(f"Assistant: {response}")
""")

            logger.info(f"✓ Example usage saved to {example_path}")
            logger.info(f"✓ Unity integration example saved to {unity_example_path}")
            return True

        else:
            logger.error(f"× ONNX file not created at {onnx_path}")
            return False

    except Exception as e:
        logger.error(f"× Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        return False


if __name__ == "__main__":
    parser_available = False
    try:
        import argparse

        parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for Unity")
        parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
        parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models",
                            help="Output directory for the converted model")
        parser.add_argument("--seq_length", "-s", type=int, default=32,
                            help="Sequence length for model export")
        parser.add_argument("--no_quantize", action="store_true",
                            help="Skip INT8 quantization step")
        parser.add_argument("--opset", "-op", type=int, default=None,
                            help="Force a specific ONNX opset version")

        args = parser.parse_args()
        parser_available = True

        model_id = args.model_id
        output_dir = args.output_dir
        seq_length = args.seq_length
        quantize = not args.no_quantize
        force_opset = args.opset

    except (ImportError, NameError):
        parser_available = False

    if not parser_available:
        if len(sys.argv) < 2:
            print("Usage: python unity_compatible_converter.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize] [--opset]")
            print("Example: python unity_compatible_converter.py distilgpt2 ./onnx_models 32")
            print("\nRecommended chat models for Unity:")
            print("  - distilgpt2 (smallest, opset 11)")
            print("  - EleutherAI/pythia-70m (better quality, opset 11)")
            print("  - microsoft/phi-1 (high quality, opset 14)")
            print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0 (chat-tuned, opset 14)")
            sys.exit(1)

        model_id = sys.argv[1]
        output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
        seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
        quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
        force_opset = None

        for i, arg in enumerate(sys.argv):
            if arg == "--opset" and i + 1 < len(sys.argv):
                force_opset = int(sys.argv[i + 1])

    is_compatible, recommended_opset = is_architecture_compatible(model_id)

    logger.info("\nUNITY-COMPATIBLE ONNX CONVERTER")
    logger.info("===============================")
    logger.info(f"Model: {model_id}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Sequence length: {seq_length}")

    if force_opset is not None:
        logger.info(f"ONNX opset version: {force_opset} (forced)")
    else:
        logger.info(f"Recommended ONNX opset: {recommended_opset}")
        logger.info(f"Architecture compatible with opset 11: {'Yes' if is_compatible else 'No'}")

    logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")

    os.makedirs(output_dir, exist_ok=True)

    success = convert_model(model_id, output_dir, seq_length, quantize, force_opset)

    if success:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION SUCCESSFUL")
        logger.info("=" * 60)
        logger.info(f"Model: {model_id}")
        logger.info(f"Output directory: {os.path.abspath(output_dir)}")
        logger.info("The model is ready for Unity integration!")
        logger.info("\nNext steps:")
        logger.info("1. Import the ONNX model into Unity using the Sentis package")
        logger.info("2. Use the unity_integration.cs file as a starting point")
        logger.info("3. For tokenization in Unity, implement the SimpleTokenizer class")
    else:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION FAILED")
        logger.info("=" * 60)
        logger.info("Please try one of the recommended models that work well with Unity:")

        if is_compatible:
            logger.info("Compatible with Unity (opset 11):")
            logger.info("  - distilgpt2")
            logger.info("  - EleutherAI/pythia-70m")

        logger.info("Advanced models (require opset 14):")
        logger.info("  - microsoft/phi-1 --opset 14")
        logger.info("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0 --opset 14")