import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from tqdm import tqdm
from onnxruntime.quantization import quantize_dynamic, QuantType
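
# Third-party packages assumed by the imports above (usual pip package names;
# adjust to your environment):
#   pip install torch transformers onnx onnxruntime numpy tqdm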

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress unhelpful warnings
warnings.filterwarnings("ignore", category=UserWarning)


class GenerationWrapper(torch.nn.Module):
    """
    Wrapper for model export that handles generation properly.
    This ensures the model can be correctly used for text generation.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config
        
    def forward(self, input_ids, attention_mask=None):
        # Return only the logits to avoid complex structures
        with torch.no_grad():
            try:
                # Standard approach for most models
                outputs = self.model(
                    input_ids=input_ids, 
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True
                )
                return outputs.logits
            except Exception as e:
                logger.warning(f"Standard forward pass failed, trying fallback: {str(e)}")
                # Fallback for models with different API
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, 'logits'):
                    return outputs.logits
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    return outputs[0]  # First element is typically logits
                else:
                    raise ValueError("Could not extract logits from model outputs")
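
# Usage sketch for the wrapper (illustrative; "distilgpt2" is only an example model id):
#   base = AutoModelForCausalLM.from_pretrained("distilgpt2")
#   wrapper = GenerationWrapper(base).eval()
#   ids = torch.ones(1, 8, dtype=torch.long)
#   logits = wrapper(ids, torch.ones_like(ids))   # shape: (batch, sequence, vocab_size)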
                
def verify_model_generation(model, tokenizer, device="cpu"):
    """Test model generation capabilities before export"""
    model.eval()

    # Use a chat-like prompt for better testing
    prompt = "User: Hello, how are you today?\nAssistant:"

    logger.info("Testing model generation...")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Configure generation parameters
    gen_config = GenerationConfig(
        max_length=100,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )

    try:
        # Try generation
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Test generation result: {generated_text}")

        if len(generated_text) <= len(prompt):
            logger.warning("Generation output is not longer than input prompt!")

        return True
    except Exception as e:
        logger.error(f"Generation test failed: {str(e)}")
        return False

def test_onnx_model(onnx_path, tokenizer):
    """Verify the ONNX model can be loaded and run"""
    try:
        import onnxruntime as ort

        logger.info("Testing ONNX model inference...")
        session = ort.InferenceSession(onnx_path)

        # Get input and output names
        input_names = [inp.name for inp in session.get_inputs()]
        output_names = [out.name for out in session.get_outputs()]

        # Create test input
        prompt = "User: Hello, how are you?\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="np")

        # Prepare input dict
        onnx_inputs = {}
        for name in input_names:
            if name == "input_ids" and "input_ids" in inputs:
                onnx_inputs[name] = inputs["input_ids"]
            elif name == "attention_mask" and "attention_mask" in inputs:
                onnx_inputs[name] = inputs["attention_mask"]

        # Run inference
        outputs = session.run(output_names, onnx_inputs)

        # Check output shape
        logits = outputs[0]
        logger.info(f"ONNX model output shape: {logits.shape}")

        if logits.shape[0] != 1 or logits.shape[1] != inputs["input_ids"].shape[1]:
            logger.warning("Output shape doesn't match expected dimensions!")

        # Test next token prediction
        next_token_logits = logits[0, -1, :]
        next_token_id = np.argmax(next_token_logits)
        next_token = tokenizer.decode([next_token_id])
        logger.info(f"Next predicted token: '{next_token}'")

        return True
    except Exception as e:
        logger.error(f"ONNX model test failed: {str(e)}")
        return False

def post_process_onnx_for_unity(onnx_path):
    """
    Post-process ONNX model to be compatible with Unity Sentis
    using only core onnx functionality (no onnxsim)
    """
    try:
        import onnx
        
        logger.info("Post-processing ONNX model for Unity compatibility...")
        
        # First, create a backup of the original model
        backup_path = onnx_path.replace(".onnx", "_original.onnx")
        import shutil
        shutil.copy(onnx_path, backup_path)
        logger.info(f"Original model backed up to {backup_path}")
        
        # Load the model
        model = onnx.load(onnx_path)
        
        # Basic model checks and optimizations
        try:
            # Check model validity
            onnx.checker.check_model(model)
            logger.info("βœ“ Model structure validated successfully")
            
            # Apply shape inference
            inferred_model = onnx.shape_inference.infer_shapes(model)
            onnx.save(inferred_model, onnx_path)
            logger.info("βœ“ Applied shape inference")
            
        except Exception as e:
            logger.warning(f"Model validation/optimization error (continuing): {str(e)}")
            
        return True
            
    except Exception as e:
        logger.warning(f"ONNX post-processing error (skipping): {str(e)}")
        return False
    
def is_architecture_compatible(model_id):
    """
    Check if the model architecture is expected to be compatible with ONNX opset 11
    """
    model_id_lower = model_id.lower()

    # Models known to work with opset 11
    compatible_architectures = [
        "gpt2", "distilgpt2", "opt-125m", "opt-350m",
        "pythia-70m", "pythia-160m", "rwkv", "gpt-neo"
    ]

    # Models likely requiring higher opsets (usually 14+)
    incompatible_architectures = [
        "llama", "mistral", "mixtral", "tinyllama", "phi-2",
        "gemma", "falcon", "bloom"
    ]

    # Check for compatibility
    for arch in compatible_architectures:
        if arch in model_id_lower:
            return True, 11

    # Check for known incompatible architectures
    for arch in incompatible_architectures:
        if arch in model_id_lower:
            return False, 14

    # For phi-1 models, use opset 14 but mark as potentially compatible
    if "phi-1" in model_id_lower:
        return True, 14

    # Default to opset 14 for unknown architectures
    return False, 14
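
# Illustrative expectations from the heuristic above:
#   is_architecture_compatible("distilgpt2")                         -> (True, 11)
#   is_architecture_compatible("TinyLlama/TinyLlama-1.1B-Chat-v1.0") -> (False, 14)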

def setup_chat_template(model_id, tokenizer):
    """
    Setup appropriate chat template based on model architecture
    """
    model_id_lower = model_id.lower()

    # Try to setup chat template if it doesn't have one
    try:
        if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
            logger.info("Setting up chat template for improved conversations...")

            # Determine chat template based on model
            if "gpt2" in model_id_lower or "pythia" in model_id_lower or "opt" in model_id_lower:
                # Simple template for base models
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAI: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAI: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added simple Human/AI chat template")

            elif "phi" in model_id_lower:
                # Microsoft Phi models template
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nHuman: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nAssistant: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added Phi-style Human/Assistant chat template")

            elif "rwkv" in model_id_lower:
                # RWKV template
                chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nBot: {{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\nBot: {% endif %}"
                tokenizer.chat_template = chat_template
                logger.info("✓ Added RWKV-style User/Bot chat template")

    except Exception as e:
        logger.warning(f"Couldn't setup chat template: {str(e)}")
        logger.info("Chat template setup will need to be handled in Unity")
            
def convert_model(model_id, output_dir="./onnx_models", seq_length=32, quantize=True, force_opset=None):
    """
    Convert a model to ONNX format with focus on Unity compatibility.
    
    Args:
        model_id: HuggingFace model ID or path
        output_dir: Directory to save the model
        seq_length: Input sequence length for export
        quantize: Whether to quantize the model to INT8
        force_opset: Force a specific ONNX opset version
        
    Returns:
        bool: Success status
    """
    start_time = time.time()
    
    # Check model architecture for compatibility
    is_compatible, recommended_opset = is_architecture_compatible(model_id)
    
    # Use forced opset if provided, otherwise use recommended
    opset_version = force_opset if force_opset is not None else recommended_opset
        
    # Warn if using a model that might not be compatible with Unity
    if not is_compatible and opset_version < 14:
        logger.warning(f"⚠ Model {model_id} may not be compatible with opset {opset_version}")
        logger.warning(f"⚠ Recommended opset for this model: {recommended_opset}")
        logger.warning(f"⚠ You can force a higher opset with --opset {recommended_opset}")
    
    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX for Unity (opset {opset_version})")
    logger.info(f"{'=' * 60}")
    
    # Create output directory
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
        
    try:
        # Step 1: Load tokenizer
        logger.info("Step 1/7: Loading tokenizer...")
        
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None and hasattr(tokenizer, 'eos_token'):
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token
        
        # Setup chat template for better conversation formatting
        setup_chat_template(model_id, tokenizer)
        
        # Save tokenizer
        tokenizer.save_pretrained(model_dir)
        logger.info(f"βœ“ Tokenizer saved to {model_dir}")
        
        # Step 2: Load model with reliability optimizations
        logger.info("Step 2/7: Loading model...")
        
        # Clean memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"
            
        # Load model with full precision
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,  # Use full precision for reliability
                low_cpu_mem_usage=True,     # Reduce memory usage
                device_map=device          # Use CUDA if available
            )
        except Exception as e:
            logger.warning(f"Standard loading failed, trying with 'trust_remote_code=True': {str(e)}")
            # Some models (like RWKV) need trust_remote_code
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map=device,
                trust_remote_code=True
            )
        
        # Save config
        model.config.save_pretrained(model_dir)
        logger.info(f"βœ“ Model config saved to {model_dir}")
        
        # Step 3: Verify model can generate chat responses
        logger.info("Step 3/7: Validating chat capabilities...")
        
        if not verify_model_generation(model, tokenizer, device):
            logger.warning("⚠ Model chat test didn't complete successfully")
            logger.info("Continuing with export anyway...")
        
        # Step 4: Export to ONNX
        logger.info(f"Step 4/7: Exporting to ONNX format with opset {opset_version}...")
            
        # Wrap model with generation-optimized interface
        wrapped_model = GenerationWrapper(model)
        wrapped_model.eval()
        
        # Clean memory again
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
        # Export to ONNX with appropriate opset version
        onnx_path = os.path.join(model_dir, "model.onnx")
        
        # Create minimal input
        batch_size = 1
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
            
        # Move tensors to correct device
        dummy_input = dummy_input.to(device)
        attention_mask = attention_mask.to(device)
        
        # Export to ONNX with required opset
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,                # Wrapped model
                (dummy_input, attention_mask), # Input tensors
                onnx_path,                    # Output path
                export_params=True,           # Store weights
                opset_version=opset_version,  # Required opset version
                do_constant_folding=True,     # Optimize constants
                input_names=['input_ids', 'attention_mask'],  # Input names
                output_names=['logits'],      # Output name
                dynamic_axes={                # Dynamic dimensions
                    'input_ids': {0: 'batch_size', 1: 'sequence'},
                    'attention_mask': {0: 'batch_size', 1: 'sequence'},
                    'logits': {0: 'batch_size', 1: 'sequence'}
                }
            )
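        # Note: the exported graph is a single stateless forward pass
        # (input_ids, attention_mask) -> logits. The KV cache is disabled in the
        # wrapper (use_cache=False), so downstream generation re-feeds the whole
        # sequence on every step, as the generated Unity example below does.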
            
        # Clean up to save memory
        del model
        del wrapped_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Verify export success
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"βœ“ ONNX model saved to {onnx_path}")
            logger.info(f"βœ“ Original size: {size_mb:.2f} MB")
            
            # Step 5: Post-process the ONNX model for better Unity compatibility
            logger.info("Step 5/7: Post-processing ONNX model for Unity compatibility...")
                
            # Try to post-process model for Unity
            try:
                post_process_onnx_for_unity(onnx_path)
            except Exception as e:
                logger.warning(f"Post-processing failed (non-critical): {str(e)}")
            
            # Test ONNX model
            test_onnx_model(onnx_path, tokenizer)
                
            # Step 6: Quantize the model (optional)
            if quantize:
                logger.info("Step 6/7: Applying INT8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")
                
                try:
                    # quantize_dynamic does not report progress, so the bar below is
                    # only an indicator that this (potentially slow) step is running.
                    with tqdm(total=100, desc="Quantizing") as pbar:
                        # Apply dynamic INT8 quantization to the weights
                        quantize_dynamic(
                            model_input=onnx_path,
                            model_output=quant_path,
                            per_channel=False,
                            reduce_range=False,
                            weight_type=QuantType.QInt8,
                            optimize_model=True,
                            use_external_data_format=False
                        )

                        pbar.update(100)  # Mark the quantization step as complete
                    
                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"βœ“ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"βœ“ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")
                        
                        # Test the quantized model
                        test_onnx_model(quant_path, tokenizer)
                        
                        # Rename original as backup
                        backup_path = onnx_path.replace(".onnx", "_fp32.onnx")
                        os.rename(onnx_path, backup_path)
                        
                        # Replace original with quantized
                        os.rename(quant_path, onnx_path)
                        logger.info("βœ“ Original model preserved as *_fp32.onnx")
                        logger.info("βœ“ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 6/7: Skipping quantization as requested")
            
            # Step 7: Generate Unity integration examples
            logger.info("Step 7/7: Generating Unity integration examples...")
            
            # Create a Unity integration example
            unity_example_path = os.path.join(model_dir, "unity_integration.cs")
            with open(unity_example_path, 'w') as f:
                f.write("""
using UnityEngine;
using Unity.Sentis;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

public class ONNXChatbot : MonoBehaviour
{
    [SerializeField] private ModelAsset modelAsset;
    [SerializeField] private TextAsset tokenizerVocabJson;
    [SerializeField] private int maxTokens = 50;
    [SerializeField] private float temperature = 0.7f;
    
    private IWorker worker;
    private Dictionary<string, Tensor> inputs;
    private SimpleTokenizer tokenizer;
    private bool isGenerating = false;

    void Start()
    {
        // Initialize the model
        var model = ModelLoader.Load(modelAsset);
        worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);
        
        // Initialize tokenizer
        tokenizer = new SimpleTokenizer(tokenizerVocabJson.text);
        
        // Prepare for inference
        inputs = new Dictionary<string, Tensor>();
        
        Debug.Log("Model and tokenizer initialized successfully.");
    }

    public async Task<string> GenerateResponseAsync(string userMessage)
    {
        if (isGenerating)
        {
            Debug.LogWarning("Already generating a response. Please wait.");
            return "Already generating a response. Please wait.";
        }
        
        isGenerating = true;
        
        try
        {
            // Format prompt with chat template
            string prompt = FormatChatPrompt(userMessage);
            Debug.Log($"Formatted prompt: {prompt}");
            
            // Tokenize input
            var tokenIds = tokenizer.Encode(prompt);
            Debug.Log($"Encoded to {tokenIds.Length} tokens");
            
            if (tokenIds.Length > 0)
            {
                // Generate response token by token
                StringBuilder responseBuilder = new StringBuilder();
                List<int> currentIds = tokenIds.ToList();
                
                for (int i = 0; i < maxTokens; i++)
                {
                    // Make sure we don't exceed the model's context window
                    if (currentIds.Count > 1024)
                    {
                        // If too long, keep only the last 1024 tokens
                        currentIds = currentIds.Skip(currentIds.Count - 1024).Take(1024).ToList();
                    }
                    
                    // Create tensors for current sequence
                    using (var inputIdsTensor = new TensorInt(new TensorShape(1, currentIds.Count), currentIds.ToArray()))
                    using (var attentionMaskTensor = new TensorInt(new TensorShape(1, currentIds.Count), Enumerable.Repeat(1, currentIds.Count).ToArray()))
                    {
                        // Run inference
                        inputs.Clear();
                        inputs["input_ids"] = inputIdsTensor;
                        inputs["attention_mask"] = attentionMaskTensor;
                        
                        worker.Execute(inputs);
                        var logits = worker.PeekOutput() as TensorFloat;
                        
                        // Get next token prediction
                        int nextToken = SampleNextToken(logits, currentIds, temperature);
                        
                        // If we hit the end token or a newline after content, stop
                        if (nextToken == tokenizer.EosToken || 
                            (i > 0 && nextToken == tokenizer.NewlineToken))
                        {
                            break;
                        }
                        
                        // Add token to current sequence for next iteration
                        currentIds.Add(nextToken);
                        
                        // Decode the latest token
                        string newToken = tokenizer.Decode(new[] { nextToken });
                        responseBuilder.Append(newToken);
                        
                        // For smoother output, yield every few tokens
                        if (i % 5 == 0)
                        {
                            await Task.Delay(1);
                        }
                    }
                }
                
                // Return the full response, without the prompt
                string fullResponse = responseBuilder.ToString();
                return CleanResponse(fullResponse);
            }
            else
            {
                Debug.LogError("Tokenization failed: empty token list");
                return "Sorry, I couldn't process that input.";
            }
        }
        catch (System.Exception ex)
        {
            Debug.LogError($"Generation error: {ex.Message}\\n{ex.StackTrace}");
            return "Sorry, an error occurred while generating a response.";
        }
        finally
        {
            isGenerating = false;
        }
    }
    
    private string FormatChatPrompt(string userMessage)
    {
        // You may need to adjust this template based on your specific model
        return $"User: {userMessage}\\nAssistant:";
    }
    
    private string CleanResponse(string response)
    {
        // Extract only the Assistant's response
        int assistantPrefix = response.IndexOf("Assistant:");
        if (assistantPrefix >= 0)
        {
            response = response.Substring(assistantPrefix + "Assistant:".Length).Trim();
        }
        
        // Stop at any "User:" marker if present
        int nextUser = response.IndexOf("User:");
        if (nextUser >= 0)
        {
            response = response.Substring(0, nextUser).Trim();
        }
        
        return response;
    }
    
    private int SampleNextToken(TensorFloat logits, List<int> currentInputs, float temp)
    {
        // Get logits for the last position
        int lastPos = currentInputs.Count - 1;
        int vocabSize = logits.shape.channels;
        
        // Prepare array for logits
        float[] lastLogits = new float[vocabSize];
        
        // Extract logits for the last token position
        for (int i = 0; i < vocabSize; i++)
        {
            lastLogits[i] = logits[0, lastPos, i];
        }
        
        // Simple temperature-based sampling
        if (temp <= 0.0f)
        {
            // Greedy sampling (argmax)
            int maxIndex = 0;
            float maxValue = lastLogits[0];
            
            for (int i = 1; i < vocabSize; i++)
            {
                if (lastLogits[i] > maxValue)
                {
                    maxValue = lastLogits[i];
                    maxIndex = i;
                }
            }
            
            return maxIndex;
        }
        else
        {
            // Temperature sampling
            // Apply temperature
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= temp;
            }
            
            // Softmax
            float maxLogit = lastLogits.Max();
            float sum = 0.0f;
            
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] = Mathf.Exp(lastLogits[i] - maxLogit);
                sum += lastLogits[i];
            }
            
            for (int i = 0; i < vocabSize; i++)
            {
                lastLogits[i] /= sum;
            }
            
            // Sample from distribution
            float random = Random.value;
            float cumulativeProb = 0.0f;
            
            for (int i = 0; i < vocabSize; i++)
            {
                cumulativeProb += lastLogits[i];
                if (random < cumulativeProb)
                {
                    return i;
                }
            }
            
            // Fallback to last token if sampling fails
            return vocabSize - 1;
        }
    }
    
    void OnDestroy()
    {
        worker?.Dispose();
    }
}

// Simple tokenizer implementation for Unity
public class SimpleTokenizer
{
    private Dictionary<string, int> vocab;
    private Dictionary<int, string> reversedVocab;
    
    public int PadToken { get; private set; }
    public int EosToken { get; private set; }
    public int BosToken { get; private set; }
    public int NewlineToken { get; private set; }
    
    public SimpleTokenizer(string vocabJson)
    {
        // Parse the vocabulary from JSON
        vocab = new Dictionary<string, int>();
        
        // Simple JSON parsing (you'll need a proper JSON parser in production)
        string[] entries = vocabJson.Split(new[] { '\\n', '{', '}', '\"', ':', ',' }, 
                                        System.StringSplitOptions.RemoveEmptyEntries);
        
        for (int i = 0; i < entries.Length - 1; i += 2)
        {
            string token = entries[i].Trim();
            if (int.TryParse(entries[i + 1].Trim(), out int id))
            {
                vocab[token] = id;
            }
        }
        
        // Create reversed vocabulary for decoding
        reversedVocab = vocab.ToDictionary(kv => kv.Value, kv => kv.Key);
        
        // Find special tokens
        SetSpecialTokens();
        
        Debug.Log($"Tokenizer initialized with {vocab.Count} tokens");
    }
    
    private void SetSpecialTokens()
    {
        // Try to find standard special tokens
        PadToken = FindToken(new[] { "<pad>", "[PAD]", "<|endoftext|>" });
        EosToken = FindToken(new[] { "</s>", "<|endoftext|>", "[EOS]", "<eos>" });
        BosToken = FindToken(new[] { "<s>", "<|startoftext|>", "[BOS]", "<bos>" });
        
        // Find newline token
        foreach (var entry in vocab)
        {
            if (entry.Key == "\\n" || entry.Key == "<\\n>" || entry.Key == "Ċ")
            {
                NewlineToken = entry.Value;
                break;
            }
        }
        
        Debug.Log($"Special tokens - PAD: {PadToken}, EOS: {EosToken}, BOS: {BosToken}, NEWLINE: {NewlineToken}");
    }
    
    private int FindToken(string[] candidates)
    {
        foreach (var candidate in candidates)
        {
            if (vocab.TryGetValue(candidate, out int id))
            {
                return id;
            }
        }
        
        // Return -1 if not found
        return -1;
    }
    
    public int[] Encode(string text)
    {
    // Simple character-level tokenization
    // In production, use a proper BPE/WordPiece tokenizer implementation
    List<int> tokens = new List<int>();
    StringBuilder currentToken = new StringBuilder();
    
    // Add BOS token if available
    if (BosToken != -1)
    {
        tokens.Add(BosToken);
    }
    
    // Very simple tokenization - in production, this would implement
    // the specific tokenization algorithm for your model
    foreach (char c in text)
    {
        currentToken.Append(c);
        string current = currentToken.ToString();
        
        if (vocab.TryGetValue(current, out int id))
        {
            tokens.Add(id);
            currentToken.Clear();
        }
        else if (currentToken.Length > 10)
        {
            // If token is too long, add unknown token and reset
            tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
            currentToken.Clear();
            currentToken.Append(c);
        }
    }
    
    // Handle any remaining text
    if (currentToken.Length > 0)
    {
        tokens.Add(vocab.ContainsKey("<unk>") ? vocab["<unk>"] : 0);
    }
    
    return tokens.ToArray();
}

public string Decode(int[] ids)
{
    StringBuilder result = new StringBuilder();
    
    foreach (int id in ids)
    {
        if (reversedVocab.TryGetValue(id, out string token))
        {
            // Some tokenizers use special prefixes like "Ġ" for spaces
            string processedToken = token
                .Replace("Ġ", " ")
                .Replace("Ċ", "\n")
                .Replace("▁", " ");
                
            result.Append(processedToken);
        }
    }
    
    return result.ToString();
}
}
""")

            # Calculate elapsed time
            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"βœ“ Conversion completed in {duration:.2f} seconds")
            logger.info(f"βœ“ Final model size: {os.path.getsize(onnx_path) / (1024 * 1024):.2f} MB")
            
            # Create a Python example usage file
            example_path = os.path.join(model_dir, "example_usage.py")
            with open(example_path, 'w') as f:
                f.write("""
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./")  # Path to model directory
session = ort.InferenceSession("./model.onnx")

def generate_response(user_message, max_length=50):
    # Format as a chat message
    prompt = f"User: {user_message}\\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="np")
    
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    
    # Simple auto-regressive generation loop
    for _ in range(max_length):
        # Run inference for a single step
        outputs = session.run(
            ["logits"], 
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        )
        
        # Get next token prediction from logits
        logits = outputs[0]
        next_token_logits = logits[0, -1, :]
        
        # Apply temperature sampling
        temperature = 0.7
        next_token_logits = next_token_logits / temperature
        
        # Apply softmax to get probabilities
        exp_logits = np.exp(next_token_logits - np.max(next_token_logits))
        probs = exp_logits / np.sum(exp_logits)
        
        # Sample from the distribution
        next_token_id = np.random.choice(probs.shape[0], p=probs)
        
        # Stop if we hit the end of sequence token
        if next_token_id == tokenizer.eos_token_id:
            break
            
        # Append new token to the input_ids
        input_ids = np.concatenate([input_ids, [[next_token_id]]], axis=1)
        attention_mask = np.concatenate([attention_mask, [[1]]], axis=1)
    
    # Decode the entire response
    response = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    
    # Extract only the assistant's response
    if "Assistant:" in response:
        response = response.split("Assistant:")[-1].strip()
    
    return response

# Example usage
while True:
    user_input = input("You: ")
    if user_input.lower() in ['exit', 'quit']:
        break
    response = generate_response(user_input)
    print(f"Assistant: {response}")
""")
        
            logger.info(f"βœ“ Example usage saved to {example_path}")
            logger.info(f"βœ“ Unity integration example saved to {unity_example_path}")    
            return True
    
        else:
            logger.error(f"Γ— ONNX file not created at {onnx_path}")
            return False

    except Exception as e:
        logger.error(f"Γ— Error converting model: {str(e)}")
        logger.error(traceback.format_exc())
        return False

if __name__ == "__main__":
    # Parse command line arguments
    parser_available = False
    try:
        import argparse
        parser = argparse.ArgumentParser(description="Convert HuggingFace models to ONNX for Unity")
        parser.add_argument("model_id", type=str, help="HuggingFace model ID or path")
        parser.add_argument("--output_dir", "-o", type=str, default="./onnx_models", 
                          help="Output directory for the converted model")
        parser.add_argument("--seq_length", "-s", type=int, default=32,
                          help="Sequence length for model export")
        parser.add_argument("--no_quantize", action="store_true",
                          help="Skip INT8 quantization step")
        parser.add_argument("--opset", "-op", type=int, default=None,
                          help="Force a specific ONNX opset version")
        
        args = parser.parse_args()
        parser_available = True
        
        model_id = args.model_id
        output_dir = args.output_dir
        seq_length = args.seq_length
        quantize = not args.no_quantize
        force_opset = args.opset
        
    except (ImportError, NameError):
        # Fallback if argparse is not available
        parser_available = False
    
    if not parser_available:
        if len(sys.argv) < 2:
            print("Usage: python unity_compatible_converter.py MODEL_ID [OUTPUT_DIR] [SEQ_LENGTH] [--no-quantize] [--opset]")
            print("Example: python unity_compatible_converter.py distilgpt2 ./onnx_models 32")
            print("\nRecommended chat models for Unity:")
            print("  - distilgpt2 (smallest, opset 11)")
            print("  - EleutherAI/pythia-70m (better quality, opset 11)")
            print("  - microsoft/phi-1 (high quality, opset 14)")
            print("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0 (chat-tuned, opset 14)")
            sys.exit(1)
        
        model_id = sys.argv[1]
        output_dir = sys.argv[2] if len(sys.argv) > 2 else "./onnx_models"
        seq_length = int(sys.argv[3]) if len(sys.argv) > 3 else 32
        quantize = "--no-quantize" not in sys.argv and "--no_quantize" not in sys.argv
        force_opset = None
        
        # Check for opset flag
        for i, arg in enumerate(sys.argv):
            if arg == "--opset" and i + 1 < len(sys.argv):
                force_opset = int(sys.argv[i + 1])
    
    # Check model architecture for automatic opset recommendation
    is_compatible, recommended_opset = is_architecture_compatible(model_id)
    
    # Print header
    logger.info("\nUNITY-COMPATIBLE ONNX CONVERTER")
    logger.info("===============================")
    logger.info(f"Model: {model_id}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Sequence length: {seq_length}")
    
    if force_opset is not None:
        logger.info(f"ONNX opset version: {force_opset} (forced)")
    else:
        logger.info(f"Recommended ONNX opset: {recommended_opset}")
        logger.info(f"Architecture compatible with opset 11: {'Yes' if is_compatible else 'No'}")
        
    logger.info(f"Quantization: {'Enabled' if quantize else 'Disabled'}")
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Convert the model
    success = convert_model(model_id, output_dir, seq_length, quantize, force_opset)
    
    if success:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION SUCCESSFUL")
        logger.info("=" * 60)
        logger.info(f"Model: {model_id}")
        logger.info(f"Output directory: {os.path.abspath(output_dir)}")
        logger.info("The model is ready for Unity integration!")
        logger.info("\nNext steps:")
        logger.info("1. Import the ONNX model into Unity using the Sentis package")
        logger.info("2. Use the unity_integration.cs file as a starting point")
        logger.info("3. For tokenization in Unity, implement the SimpleTokenizer class")
    else:
        logger.info("\n" + "=" * 60)
        logger.info("CONVERSION FAILED")
        logger.info("=" * 60)
        logger.info("Please try one of the recommended models that work well with Unity:")
        
        if is_compatible:
            logger.info("Compatible with Unity (opset 11):")
            logger.info("  - distilgpt2")
            logger.info("  - EleutherAI/pythia-70m")
        
        logger.info("Advanced models (require opset 14):")
        logger.info("  - microsoft/phi-1 --opset 14")
        logger.info("  - TinyLlama/TinyLlama-1.1B-Chat-v1.0 --opset 14")