import os
# this is a HF Spaces specific hack, as
#  (i)  building pytorch3d with GPU support is a bit tricky here
#  (ii) installing the wheel via requirements.txt breaks ZeroGPU
import spaces

# Detect the CUDA/PyTorch/Python versions (useful for picking a matching wheel in the fallback below)
import sys
import torch

# Print debug information about the environment
try:
    cuda_version = torch.version.cuda
    torch_version = torch.__version__
    python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
    print(f"CUDA Version: {cuda_version}")
    print(f"PyTorch Version: {torch_version}")
    print(f"Python Version: {python_version}")
except Exception as e:
    print(f"Error detecting environment versions: {e}")

# Install PyTorch3D properly from source
print("Installing PyTorch3D from source...")

# First uninstall any existing PyTorch3D installation to avoid conflicts
os.system("pip uninstall -y pytorch3d")

# Install dependencies required for building PyTorch3D
os.system("apt-get update && apt-get install -y git build-essential libglib2.0-0 libsm6 libxrender-dev libxext6 ninja-build")
os.system("pip install 'imageio>=2.5.0' 'matplotlib>=3.1.2' 'numpy>=1.17.3' 'psutil>=5.6.5' 'scipy>=1.3.2' 'tqdm>=4.42.1' 'trimesh>=3.0.0'")
os.system("pip install fvcore iopath")

# Clone the PyTorch3D repository
os.system("rm -rf pytorch3d")  # Remove any existing directory
os.system("git clone https://github.com/facebookresearch/pytorch3d.git")

# Use a specific release tag that is known to be stable
os.system("cd pytorch3d && git checkout v0.7.4")

# Install PyTorch3D from source with CPU support
os.system("cd pytorch3d && pip install -e .")

# Verify the installation
import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
print(import_result)

# If the installation fails, try a different approach with a specific wheel
if "No module named" in import_result or "Error" in import_result:
    print("Source installation failed, trying with a specific wheel...")
    os.system("pip uninstall -y pytorch3d")
    
    # Try with a specific wheel that's known to work
    os.system("pip install https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cpu_pyt201/pytorch3d-0.7.4-cp310-cp310-linux_x86_64.whl")
    
    # Verify again
    import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
    print(import_result)

# Patch the shap_e renderer to handle PyTorch3D renderer import error if needed
shap_e_renderer_path = "/usr/local/lib/python3.10/site-packages/shap_e/models/stf/renderer.py"
if os.path.exists(shap_e_renderer_path):
    print(f"Patching shap_e renderer at {shap_e_renderer_path}")
    
    # Read the current content
    with open(shap_e_renderer_path, "r") as f:
        content = f.read()
    
    # Create a backup
    os.system(f"cp {shap_e_renderer_path} {shap_e_renderer_path}.bak")
    
    # Modify the content to handle the error more gracefully
    modified_content = content
    
    # Replace the error message
    if "exception rendering with PyTorch3D" in content:
        modified_content = modified_content.replace(
            'warnings.warn(f"exception rendering with PyTorch3D: {exc}")',
            'warnings.warn("Using native PyTorch renderer")'
        )
    
    # Replace the fallback warning
    if "falling back on native PyTorch renderer" in modified_content:
        modified_content = modified_content.replace(
            'warnings.warn("falling back on native PyTorch renderer, which does not support full gradients")',
            'warnings.warn("Using native PyTorch renderer")'
        )
    
    # Write the modified content
    with open(shap_e_renderer_path, "w") as f:
        f.write(modified_content)
    
    print("Successfully patched shap_e renderer")
else:
    print(f"shap_e renderer not found at {shap_e_renderer_path}")

# Add a helper function to ensure PyTorch3D works with ZeroGPU
def ensure_pytorch3d_cuda_compatibility():
    """
    This function ensures PyTorch3D works correctly with CUDA in ZeroGPU environments.
    It should be called at the beginning of any @spaces.GPU decorated function.
    """
    try:
        import pytorch3d
        if torch.cuda.is_available():
            # Check if we can access the renderer module
            from pytorch3d import renderer
            print("PyTorch3D renderer module is available with CUDA")
        else:
            print("CUDA is not available, using CPU version of PyTorch3D")
    except ImportError as e:
        print(f"Error importing PyTorch3D: {e}")
    except Exception as e:
        print(f"Unexpected error with PyTorch3D: {e}")

import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
import math
import time
from requests.exceptions import ReadTimeout, ConnectionError

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics, 
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
    spherical_camera_pose
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground

def create_custom_cameras(size: int, device: torch.device, azimuths: list, elevations: list, 
                          fov_degrees: float, distance: float) -> DifferentiableCameraBatch:
    # The object lives in a 2x2x2 bounding box (-1 to 1 in each dimension);
    # treat the requested distance as the diagonal extent that must fit in view
    object_diagonal = distance
    
    # Back the camera off just far enough that this diagonal fills the FOV
    fov_radians = math.radians(fov_degrees)
    radius = (object_diagonal / 2) / math.tan(fov_radians / 2)
    
    origins = []
    xs = []
    ys = []
    zs = []
    
    for azimuth, elevation in zip(azimuths, elevations):
        azimuth_rad = np.radians(azimuth-90)
        elevation_rad = np.radians(elevation)
        
        # Calculate camera position
        x = radius * np.cos(elevation_rad) * np.cos(azimuth_rad)
        y = radius * np.cos(elevation_rad) * np.sin(azimuth_rad)
        z = radius * np.sin(elevation_rad)
        origin = np.array([x, y, z])
        
        # Calculate camera orientation
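        # z points from the camera toward the origin, x is the horizontal
        # tangent at this azimuth, and y completes the right-handed frame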
        z_axis = -origin / np.linalg.norm(origin)  # Point towards center
        x_axis = np.array([-np.sin(azimuth_rad), np.cos(azimuth_rad), 0])
        y_axis = np.cross(z_axis, x_axis)
        
        origins.append(origin)
        zs.append(z_axis)
        xs.append(x_axis)
        ys.append(y_axis)

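    # Stack the per-view origins and axes into (N, 3) tensors; the batch treats
    # all N views as one flat set of projective cameras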
    return DifferentiableCameraBatch(
        shape=(1, len(origins)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
            width=size,
            height=size,
            x_fov=fov_radians,
            y_fov=fov_radians,
        ),
    )

def load_models():
    """Initialize and load all required models"""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load diffusion pipeline with retry logic
    print('Loading diffusion pipeline...')
    max_retries = 3
    retry_delay = 5
    
    for attempt in range(max_retries):
        try:
            pipeline = DiffusionPipeline.from_pretrained(
                "sudo-ai/zero123plus-v1.2",
                custom_pipeline="zero123plus",
                torch_dtype=torch.float16,
                local_files_only=False,
                resume_download=True,
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download pipeline after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff
    
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )

    # Modify UNet to handle 8 input channels instead of 4
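    # The extra input channels are presumably the conditioning latents that the
    # refinement pipeline concatenates with the noisy latents; zero-initializing
    # the new weights below preserves the pretrained 4-channel behavior at first.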
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, pipeline.unet.conv_in.kernel_size, 
            pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        pipeline.unet.conv_in = new_conv_in

    # Load custom UNet with retry logic
    print('Loading custom UNet...')
    for attempt in range(max_retries):
        try:
            pipeline.unet = pipeline.unet.from_pretrained(
                "YiftachEde/Sharp-It",
                local_files_only=False,
                resume_download=True,
            ).to(torch.float16)
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download UNet after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    pipeline = pipeline.to(device).to(torch_dtype=torch.float16)

    # Load reconstruction model with retry logic
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    
    for attempt in range(max_retries):
        try:
            model_path = hf_hub_download(
                repo_id="TencentARC/InstantMesh",
                filename="instant_nerf_large.ckpt",
                repo_type="model",
                local_files_only=False,
                resume_download=True,
                cache_dir="model_cache"  # Use a specific cache directory
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download model after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

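    # Keep only the LRM generator weights: strip the 'lrm_generator.' prefix
    # (14 characters) and drop the 'source_camera' entries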
    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    state_dict = {k[14:]: v for k, v in state_dict.items() 
                 if k.startswith('lrm_generator.') and 'source_camera' not in k}
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()
    
    return pipeline, model, infer_config

@spaces.GPU(duration=20)
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement"""
    # Ensure PyTorch3D works with CUDA
    ensure_pytorch3d_cuda_compatibility()
    
    device = pipeline.device
    
    if isinstance(input_images, list):
        if len(input_images) == 1:
            # Check if this is a pre-arranged layout
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):
                # This is already a layout, use it directly
                input_image = img
            else:
                # Single view - need 6 copies
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                images = [img_array] * 6
                images = np.stack(images)
                
                # Convert to tensor and create layout
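                # Tiling: (6, 3, 320, 320) views -> 3 rows of two side-by-side
                # views (3, 320, 640) -> one (1, 3, 960, 640) portrait grid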
                images = torch.from_numpy(images).float()
                images = images.permute(0, 3, 1, 2)
                images = images.reshape(3, 2, 3, 320, 320)
                images = images.permute(0, 2, 3, 1, 4)
                images = images.reshape(3, 3, 320, 640)
                images = images.reshape(1, 3, 960, 640)
                
                # Convert back to PIL
                images = images.permute(0, 2, 3, 1)[0]
                images = (images.numpy() * 255).astype(np.uint8)
                input_image = Image.fromarray(images)
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                img = np.array(img) / 255.0
                images.append(img)
            
            # Pad to 6 images if needed
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            images = np.stack(images[:6])
            
            # Convert to tensor and create layout
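            # Same 3-row x 2-column tiling as the single-view branch above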
            images = torch.from_numpy(images).float()
            images = images.permute(0, 3, 1, 2)
            images = images.reshape(3, 2, 3, 320, 320)
            images = images.permute(0, 2, 3, 1, 4)
            images = images.reshape(3, 3, 320, 640)
            images = images.reshape(1, 3, 960, 640)
            
            # Convert back to PIL
            images = images.permute(0, 2, 3, 1)[0]
            images = (images.numpy() * 255).astype(np.uint8)
            input_image = Image.fromarray(images)
    else:
        raise ValueError("Expected a list of images")

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]
    
    return output, input_image

@spaces.GPU(duration=20)
def create_mesh(refined_image, model, infer_config):
    """Generate mesh from refined image"""
    # Ensure PyTorch3D works with CUDA
    ensure_pytorch3d_cuda_compatibility()
    
    # Convert PIL image to tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)
    
    # Reshape to 6 views
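    # Invert the tiling: split the 960x640 portrait grid back into six
    # 320x320 views in row-major order -> (6, 3, 320, 320)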
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)
    image = image.permute(1, 0, 2, 3)
    image = image.reshape(3, 3, 320, 2, 320)
    image = image.permute(0, 3, 1, 2, 4)
    image = image.reshape(6, 3, 320, 320)
    
    # Add batch dimension
    image = image.unsqueeze(0)
    
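    # Cameras at the six fixed zero123plus input viewpoints; radius 4.0 matches
    # the reconstruction model's expected camera distance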
    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")
    
    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out
        
    return vertices, faces, vertex_colors

class ShapERenderer:
    def __init__(self, device):
        print("Initializing Shap-E models...")
        self.device = device
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models initialized!")
    
    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before generation
            
            # Generate latents using the text-to-3D model
            batch_size = 1
            guidance_scale = float(guidance_scale)
            
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                latents = sample_latents(
                    batch_size=batch_size,
                    model=self.model,
                    diffusion=self.diffusion,
                    guidance_scale=guidance_scale,
                    model_kwargs=dict(texts=[prompt] * batch_size),
                    progress=True,
                    clip_denoised=True,
                    use_fp16=True,
                    use_karras=True,
                    karras_steps=num_steps,
                    sigma_min=1e-3,
                    sigma_max=160,
                    s_churn=0,
                )

            # Render the 6 views we need with specific viewing angles
            size = 320  # Size of each rendered image
            images = []
            
            # Define our 6 specific camera positions to match refine.py
            azimuths = [30, 90, 150, 210, 270, 330]
            elevations = [20, -10, 20, -10, 20, -10]
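            # Azimuths step 60 degrees around the object; elevations alternate
            # between 20 and -10 degrees to match the refinement pipeline's views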
            
            for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
                cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
                with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                    rendered_image = decode_latent_images(
                        self.xm,
                        latents[0],
                        cameras=cameras,
                        rendering_mode='stf'
                    )
                images.append(rendered_image[0])
                torch.cuda.empty_cache()  # Clear GPU memory after each view
            
            # Convert images to uint8
            images = [np.array(image) for image in images]
            
            # Tile the views into a 3-row x 2-column grid (640 wide x 960 tall)
            layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i, img in enumerate(images):
                row = i // 2
                col = i % 2
                layout[row*320:(row+1)*320, col*320:(col+1)*320] = img

            return Image.fromarray(layout), images
            
        except Exception as e:
            print(f"Error in generate_views: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise

class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models initialized!")
    
    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function"""
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before processing
            
            # Process image and get refined output
            input_image = Image.fromarray(input_image)
            
            # If the layout arrived as a 2-row x 3-column landscape grid (960 wide
            # x 640 tall), rearrange its tiles into the 3-row x 2-column portrait
            # grid (640 wide x 960 tall) that the refinement pipeline expects
            if input_image.width == 960 and input_image.height == 640:
                input_array = np.array(input_image)
                new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
                
                # Rearrange from 2x3 to 3x2
                for i in range(6):
                    src_row = i // 3
                    src_col = i % 3
                    dst_row = i // 2
                    dst_col = i % 2
                    
                    new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                        input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
                
                input_image = Image.fromarray(new_layout)
            
            # Process with the pipeline (expects the 3-row x 2-column portrait layout)
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                refined_output_960x640 = self.pipeline.refine(
                    input_image,
                    prompt=prompt,
                    num_inference_steps=int(steps),
                    guidance_scale=guidance_scale
                ).images[0]
            
            torch.cuda.empty_cache()  # Clear GPU memory after refinement
            
            # Generate the mesh from the refined portrait layout
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                vertices, faces, vertex_colors = create_mesh(
                    refined_output_960x640, 
                    self.model, 
                    self.infer_config
                )
            
            torch.cuda.empty_cache()  # Clear GPU memory after mesh generation
            
            # Save temporary mesh file
            os.makedirs("temp", exist_ok=True)
            temp_obj = os.path.join("temp", "refined_mesh.obj")
            save_obj(vertices, faces, vertex_colors, temp_obj)
            
            # The refined output is already in the 3-row x 2-column portrait
            # layout used by the display panels, so no tile rearrangement is needed
            refined_output_640x960 = refined_output_960x640
            
            return refined_output_640x960, temp_obj
            
        except Exception as e:
            print(f"Error in refine_model: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise

def create_demo():
    print("Initializing models...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize models at startup
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()
    print("All models initialized!")
    
    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")
        
        # First row: Controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt", 
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1, 
                    maximum=30, 
                    value=15.0, 
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16, 
                    maximum=128, 
                    value=64, 
                    step=16, 
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")
            
            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt", 
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30,
                    maximum=100,
                    value=75,
                    step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")
                error_output = gr.Textbox(label="Status/Error Messages", interactive=False)

        # Second row: Image panels side by side
        with gr.Row():
            # Outputs - Images side by side
            shape_output = gr.Image(
                label="Generated Views", 
                width=640,
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )
        
        # Third row: 3D mesh panel below
        with gr.Row():
            # 3D mesh centered
            mesh_output = gr.Model3D(
                label="3D Mesh", 
                clear_color=[1.0, 1.0, 1.0, 1.0],
            )

        # Set up event handlers
        @spaces.GPU(duration=20)  # 20-second ZeroGPU allocation window
        def generate(prompt, guidance_scale, num_steps):
            try:
                # Ensure PyTorch3D works with CUDA
                ensure_pytorch3d_cuda_compatibility()
                
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                with torch.no_grad():
                    layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
                return layout, None  # Return None for error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during generation: {str(e)}"
                print(error_msg)
                return None, error_msg

        @spaces.GPU(duration=20)  # 20-second ZeroGPU allocation window
        def refine(input_image, prompt, steps, guidance_scale):
            try:
                # Ensure PyTorch3D works with CUDA
                ensure_pytorch3d_cuda_compatibility()
                
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                refined_img, mesh_path = refiner.refine_model(
                    input_image, 
                    prompt,
                    steps,
                    guidance_scale
                )
                return refined_img, mesh_path, None  # Return None for error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during refinement: {str(e)}"
                print(error_msg)
                return None, None, error_msg

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output, error_output]
        )

        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output, error_output]
        )

    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)