YiftachEde committed on
Commit 0f41ba2 · 1 Parent(s): 805a8bb

Adding app.py

Files changed (2)
  1. app.py +419 -0
  2. app2.py +419 -0
app.py ADDED
@@ -0,0 +1,419 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from pytorch_lightning import seed_everything
+ from huggingface_hub import hf_hub_download
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+ from einops import rearrange
+ from shap_e.diffusion.sample import sample_latents
+ from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
+ from shap_e.models.download import load_model, load_config
+ from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, create_custom_cameras
+
+ from src.utils.train_util import instantiate_from_config
+ from src.utils.camera_util import (
+     FOV_to_intrinsics,
+     get_zero123plus_input_cameras,
+     get_circular_camera_poses,
+     spherical_camera_pose
+ )
+ from src.utils.mesh_util import save_obj, save_glb
+ from src.utils.infer_util import remove_background, resize_foreground
+
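+ # Builds both stages: the zero123plus diffusion pipeline (with a widened
+ # 8-channel conv_in and a fine-tuned UNet checkpoint) and the InstantMesh
+ # sparse-view reconstruction model.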
+ def load_models():
+     """Initialize and load all required models"""
+     config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
+     model_config = config.model_config
+     infer_config = config.infer_config
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Load diffusion pipeline
+     print('Loading diffusion pipeline...')
+     pipeline = DiffusionPipeline.from_pretrained(
+         "sudo-ai/zero123plus-v1.2",
+         custom_pipeline="zero123plus",
+         torch_dtype=torch.float16
+     )
+     pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+         pipeline.scheduler.config, timestep_spacing='trailing'
+     )
+
+     # Modify UNet to handle 8 input channels instead of 4
+     in_channels = 8
+     out_channels = pipeline.unet.conv_in.out_channels
+     pipeline.unet.register_to_config(in_channels=in_channels)
+     with torch.no_grad():
+         new_conv_in = nn.Conv2d(
+             in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
+             pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
+         )
+         new_conv_in.weight.zero_()
+         new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
+         pipeline.unet.conv_in = new_conv_in
+
+     # Load custom UNet
+     print('Loading custom UNet...')
+     unet_path = "best_21.ckpt"
+     state_dict = torch.load(unet_path, map_location='cpu')
+
+     # Process the state dict to match the model keys
+     if 'state_dict' in state_dict:
+         new_state_dict = {key.replace('unet.unet.', ''): value for key, value in state_dict['state_dict'].items()}
+         pipeline.unet.load_state_dict(new_state_dict, strict=False)
+     else:
+         pipeline.unet.load_state_dict(state_dict, strict=False)
+
+     pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
+
+     # Load reconstruction model
+     print('Loading reconstruction model...')
+     model = instantiate_from_config(model_config)
+     model_path = hf_hub_download(
+         repo_id="TencentARC/InstantMesh",
+         filename="instant_nerf_large.ckpt",
+         repo_type="model"
+     )
+     state_dict = torch.load(model_path, map_location='cpu')['state_dict']
+     state_dict = {k[14:]: v for k, v in state_dict.items()
+                   if k.startswith('lrm_generator.') and 'source_camera' not in k}
+     model.load_state_dict(state_dict, strict=True)
+     model = model.to(device)
+     model.eval()
+
+     return pipeline, model, infer_config
+
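+ # Standalone refinement helper for uploaded files: accepts a ready-made
+ # 640x960 layout, a single view (tiled six times), or up to six views, and
+ # assembles a 3x2 grid of 320x320 tiles before calling pipeline.refine.
+ # The Gradio demo below goes through RefinerInterface rather than this helper.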
+ def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
+     """Process input images and run refinement"""
+     device = pipeline.device
+
+     if isinstance(input_images, list):
+         if len(input_images) == 1:
+             # Check if this is a pre-arranged layout
+             img = Image.open(input_images[0].name).convert('RGB')
+             if img.size == (640, 960):
+                 # This is already a layout, use it directly
+                 input_image = img
+             else:
+                 # Single view - need 6 copies
+                 img = img.resize((320, 320))
+                 img_array = np.array(img) / 255.0
+                 images = [img_array] * 6
+                 images = np.stack(images)
+
+                 # Convert to tensor and create layout
+                 images = torch.from_numpy(images).float()
+                 images = images.permute(0, 3, 1, 2)
+                 images = images.reshape(3, 2, 3, 320, 320)
+                 images = images.permute(0, 2, 3, 1, 4)
+                 images = images.reshape(3, 3, 320, 640)
+                 images = images.reshape(1, 3, 960, 640)
+
+                 # Convert back to PIL
+                 images = images.permute(0, 2, 3, 1)[0]
+                 images = (images.numpy() * 255).astype(np.uint8)
+                 input_image = Image.fromarray(images)
+         else:
+             # Multiple individual views
+             images = []
+             for img_file in input_images:
+                 img = Image.open(img_file.name).convert('RGB')
+                 img = img.resize((320, 320))
+                 img = np.array(img) / 255.0
+                 images.append(img)
+
+             # Pad to 6 images if needed
+             while len(images) < 6:
+                 images.append(np.zeros_like(images[0]))
+             images = np.stack(images[:6])
+
+             # Convert to tensor and create layout
+             images = torch.from_numpy(images).float()
+             images = images.permute(0, 3, 1, 2)
+             images = images.reshape(3, 2, 3, 320, 320)
+             images = images.permute(0, 2, 3, 1, 4)
+             images = images.reshape(3, 3, 320, 640)
+             images = images.reshape(1, 3, 960, 640)
+
+             # Convert back to PIL
+             images = images.permute(0, 2, 3, 1)[0]
+             images = (images.numpy() * 255).astype(np.uint8)
+             input_image = Image.fromarray(images)
+     else:
+         raise ValueError("Expected a list of images")
+
+     # Generate refined output
+     output = pipeline.refine(
+         input_image,
+         prompt=prompt,
+         num_inference_steps=int(steps),
+         guidance_scale=guidance_scale
+     ).images[0]
+
+     return output, input_image
+
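+ # Splits the refined 3x2 grid back into six 320x320 views, then runs the
+ # reconstruction model: forward_planes builds the triplane features and
+ # extract_mesh pulls out vertices, faces, and per-vertex colors.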
+ def create_mesh(refined_image, model, infer_config):
+     """Generate mesh from refined image"""
+     # Convert PIL image to tensor
+     image = np.array(refined_image) / 255.0
+     image = torch.from_numpy(image).float().permute(2, 0, 1)
+
+     # Reshape to 6 views
+     image = image.reshape(3, 960, 640)
+     image = image.reshape(3, 3, 320, 640)
+     image = image.permute(1, 0, 2, 3)
+     image = image.reshape(3, 3, 320, 2, 320)
+     image = image.permute(0, 3, 1, 2, 4)
+     image = image.reshape(6, 3, 320, 320)
+
+     # Add batch dimension
+     image = image.unsqueeze(0)
+
+     input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
+     image = image.to("cuda")
+
+     with torch.no_grad():
+         planes = model.forward_planes(image, input_cameras)
+         mesh_out = model.extract_mesh(planes, **infer_config)
+         vertices, faces, vertex_colors = mesh_out
+
+     return vertices, faces, vertex_colors
+
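+ # Wraps the Shap-E text-to-3D stage: samples a latent from the prompt, then
+ # renders it from the six fixed camera poses the refinement stage expects.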
+ class ShapERenderer:
+     def __init__(self, device):
+         print("Loading Shap-E models...")
+         self.device = device
+         self.xm = load_model('transmitter', device=device)
+         self.model = load_model('text300M', device=device)
+         self.diffusion = diffusion_from_config(load_config('diffusion'))
+         print("Shap-E models loaded!")
+
+     def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
+         # Generate latents using the text-to-3D model
+         batch_size = 1
+         guidance_scale = float(guidance_scale)
+         latents = sample_latents(
+             batch_size=batch_size,
+             model=self.model,
+             diffusion=self.diffusion,
+             guidance_scale=guidance_scale,
+             model_kwargs=dict(texts=[prompt] * batch_size),
+             progress=True,
+             clip_denoised=True,
+             use_fp16=True,
+             use_karras=True,
+             karras_steps=num_steps,
+             sigma_min=1e-3,
+             sigma_max=160,
+             s_churn=0,
+         )
+
+         # Render the 6 views we need with specific viewing angles
+         size = 320  # Size of each rendered image
+         images = []
+
+         # Define our 6 specific camera positions to match refine.py
+         azimuths = [30, 90, 150, 210, 270, 330]
+         elevations = [20, -10, 20, -10, 20, -10]
+
+         for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
+             cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
+             rendered_image = decode_latent_images(
+                 self.xm,
+                 latents[0],
+                 rendering_mode='stf',
+                 cameras=cameras
+             )
+             images.append(rendered_image.detach().cpu().numpy())
+
+         # Convert images to uint8 (no 0-1 -> 0-255 scaling is applied here)
+         images = [image.astype(np.uint8) for image in images]
+
+         # Assemble a 3x2 grid of 320px tiles (3 rows, 2 columns -> 960x640 array)
+         layout = np.zeros((960, 640, 3), dtype=np.uint8)
+         for i, img in enumerate(images):
+             row = i // 2  # 2 images per row, 3 rows
+             col = i % 2
+             layout[row*320:(row+1)*320, col*320:(col+1)*320] = img
+
+         return Image.fromarray(layout), images
+
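+ # Ties the two stages together for the UI: normalizes the incoming 6-view
+ # grid, runs the zero123plus refinement pass, and reconstructs a mesh saved
+ # to a temporary .obj file.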
+ class RefinerInterface:
+     def __init__(self):
+         print("Initializing InstantMesh models...")
+         self.pipeline, self.model, self.infer_config = load_models()
+         print("InstantMesh models loaded!")
+
+     def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
+         """Main refinement function"""
+         # Process image and get refined output
+         input_image = Image.fromarray(input_image)
+
+         # If the grid arrives as 2 rows x 3 columns (960 wide, 640 tall),
+         # rearrange it into the 3 rows x 2 columns layout the pipeline expects
+         if input_image.width == 960 and input_image.height == 640:
+             input_array = np.array(input_image)
+             new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
+
+             # Move each 320px tile from its 2x3 position to its 3x2 position
+             for i in range(6):
+                 src_row = i // 3
+                 src_col = i % 3
+                 dst_row = i // 2
+                 dst_col = i % 2
+
+                 new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
+                     input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
+
+             input_image = Image.fromarray(new_layout)
+
+         # Process with the pipeline (expects the 3 rows x 2 columns grid)
+         refined_output_960x640 = self.pipeline.refine(
+             input_image,
+             prompt=prompt,
+             num_inference_steps=int(steps),
+             guidance_scale=guidance_scale
+         ).images[0]
+
+         # Generate mesh using the 960x640 format
+         vertices, faces, vertex_colors = create_mesh(
+             refined_output_960x640,
+             self.model,
+             self.infer_config
+         )
+
+         # Save temporary mesh file
+         os.makedirs("temp", exist_ok=True)
+         temp_obj = os.path.join("temp", "refined_mesh.obj")
+         save_obj(vertices, faces, vertex_colors, temp_obj)
+
+         # Prepare the display copy of the refined grid; source and destination
+         # tile indices below are identical, so this loop is a straight copy
+         refined_array = np.array(refined_output_960x640)
+         display_layout = np.zeros((960, 640, 3), dtype=np.uint8)
+
+         for i in range(6):
+             src_row = i // 2
+             src_col = i % 2
+             dst_row = i // 2
+             dst_col = i % 2
+
+             display_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
+                 refined_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
+
+         refined_output_640x960 = Image.fromarray(display_layout)
+
+         return refined_output_640x960, temp_obj
+
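+ # Gradio UI: prompt -> Shap-E views -> zero123plus refinement -> 3D mesh preview.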
+ def create_demo():
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     shap_e = ShapERenderer(device)
+     refiner = RefinerInterface()
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# Shap-E to InstantMesh Pipeline")
+
+         # First row: Controls
+         with gr.Row():
+             with gr.Column():
+                 # Shap-E inputs
+                 shape_prompt = gr.Textbox(
+                     label="Shap-E Prompt",
+                     placeholder="Enter text to generate initial 3D model..."
+                 )
+                 shape_guidance = gr.Slider(
+                     minimum=1,
+                     maximum=30,
+                     value=15.0,
+                     label="Shap-E Guidance Scale"
+                 )
+                 shape_steps = gr.Slider(
+                     minimum=16,
+                     maximum=128,
+                     value=64,
+                     step=16,
+                     label="Shap-E Steps"
+                 )
+                 generate_btn = gr.Button("Generate Views")
+
+             with gr.Column():
+                 # Refinement inputs
+                 refine_prompt = gr.Textbox(
+                     label="Refinement Prompt",
+                     placeholder="Enter prompt to guide refinement..."
+                 )
+                 refine_steps = gr.Slider(
+                     minimum=30,
+                     maximum=100,
+                     value=75,
+                     step=1,
+                     label="Refinement Steps"
+                 )
+                 refine_guidance = gr.Slider(
+                     minimum=1,
+                     maximum=20,
+                     value=7.5,
+                     label="Refinement Guidance Scale"
+                 )
+                 refine_btn = gr.Button("Refine")
+
+         # Second row: Image panels side by side
+         with gr.Row():
+             # Outputs - Images side by side
+             shape_output = gr.Image(
+                 label="Generated Views",
+                 width=640,   # layout images are 640 wide
+                 height=960   # by 960 tall
+             )
+             refined_output = gr.Image(
+                 label="Refined Output",
+                 width=640,
+                 height=960
+             )
+
+         # Third row: 3D mesh panel below
+         with gr.Row():
+             # 3D mesh centered
+             mesh_output = gr.Model3D(
+                 label="3D Mesh",
+                 clear_color=[1.0, 1.0, 1.0, 1.0],
+                 width=1280,  # Full width
+                 height=600   # Taller for better visualization
+             )
+
+         # Set up event handlers
+         def generate(prompt, guidance_scale, num_steps):
+             with torch.no_grad():
+                 layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
+             return layout
+
+         def refine(input_image, prompt, steps, guidance_scale):
+             refined_img, mesh_path = refiner.refine_model(
+                 input_image,
+                 prompt,
+                 steps,
+                 guidance_scale
+             )
+             return refined_img, mesh_path
+
+         generate_btn.click(
+             fn=generate,
+             inputs=[shape_prompt, shape_guidance, shape_steps],
+             outputs=[shape_output]
+         )
+
+         refine_btn.click(
+             fn=refine,
+             inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
+             outputs=[refined_output, mesh_output]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch(share=True)
app2.py ADDED
@@ -0,0 +1,419 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from pytorch_lightning import seed_everything
+ from huggingface_hub import hf_hub_download
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+ from einops import rearrange
+ from shap_e.diffusion.sample import sample_latents
+ from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
+ from shap_e.models.download import load_model, load_config
+ from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, create_custom_cameras
+
+ from src.utils.train_util import instantiate_from_config
+ from src.utils.camera_util import (
+     FOV_to_intrinsics,
+     get_zero123plus_input_cameras,
+     get_circular_camera_poses,
+     spherical_camera_pose
+ )
+ from src.utils.mesh_util import save_obj, save_glb
+ from src.utils.infer_util import remove_background, resize_foreground
+
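+ # Builds both stages: the zero123plus diffusion pipeline (with a widened
+ # 8-channel conv_in and a fine-tuned UNet checkpoint) and the InstantMesh
+ # sparse-view reconstruction model.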
+ def load_models():
+     """Initialize and load all required models"""
+     config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
+     model_config = config.model_config
+     infer_config = config.infer_config
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Load diffusion pipeline
+     print('Loading diffusion pipeline...')
+     pipeline = DiffusionPipeline.from_pretrained(
+         "sudo-ai/zero123plus-v1.2",
+         custom_pipeline="zero123plus",
+         torch_dtype=torch.float16
+     )
+     pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+         pipeline.scheduler.config, timestep_spacing='trailing'
+     )
+
+     # Modify UNet to handle 8 input channels instead of 4
+     in_channels = 8
+     out_channels = pipeline.unet.conv_in.out_channels
+     pipeline.unet.register_to_config(in_channels=in_channels)
+     with torch.no_grad():
+         new_conv_in = nn.Conv2d(
+             in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
+             pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
+         )
+         new_conv_in.weight.zero_()
+         new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
+         pipeline.unet.conv_in = new_conv_in
+
+     # Load custom UNet
+     print('Loading custom UNet...')
+     unet_path = "best_21.ckpt"
+     state_dict = torch.load(unet_path, map_location='cpu')
+
+     # Process the state dict to match the model keys
+     if 'state_dict' in state_dict:
+         new_state_dict = {key.replace('unet.unet.', ''): value for key, value in state_dict['state_dict'].items()}
+         pipeline.unet.load_state_dict(new_state_dict, strict=False)
+     else:
+         pipeline.unet.load_state_dict(state_dict, strict=False)
+
+     pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
+
+     # Load reconstruction model
+     print('Loading reconstruction model...')
+     model = instantiate_from_config(model_config)
+     model_path = hf_hub_download(
+         repo_id="TencentARC/InstantMesh",
+         filename="instant_nerf_large.ckpt",
+         repo_type="model"
+     )
+     state_dict = torch.load(model_path, map_location='cpu')['state_dict']
+     state_dict = {k[14:]: v for k, v in state_dict.items()
+                   if k.startswith('lrm_generator.') and 'source_camera' not in k}
+     model.load_state_dict(state_dict, strict=True)
+     model = model.to(device)
+     model.eval()
+
+     return pipeline, model, infer_config
+
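+ # Standalone refinement helper for uploaded files: accepts a ready-made
+ # 640x960 layout, a single view (tiled six times), or up to six views, and
+ # assembles a 3x2 grid of 320x320 tiles before calling pipeline.refine.
+ # The Gradio demo below goes through RefinerInterface rather than this helper.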
+ def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
+     """Process input images and run refinement"""
+     device = pipeline.device
+
+     if isinstance(input_images, list):
+         if len(input_images) == 1:
+             # Check if this is a pre-arranged layout
+             img = Image.open(input_images[0].name).convert('RGB')
+             if img.size == (640, 960):
+                 # This is already a layout, use it directly
+                 input_image = img
+             else:
+                 # Single view - need 6 copies
+                 img = img.resize((320, 320))
+                 img_array = np.array(img) / 255.0
+                 images = [img_array] * 6
+                 images = np.stack(images)
+
+                 # Convert to tensor and create layout
+                 images = torch.from_numpy(images).float()
+                 images = images.permute(0, 3, 1, 2)
+                 images = images.reshape(3, 2, 3, 320, 320)
+                 images = images.permute(0, 2, 3, 1, 4)
+                 images = images.reshape(3, 3, 320, 640)
+                 images = images.reshape(1, 3, 960, 640)
+
+                 # Convert back to PIL
+                 images = images.permute(0, 2, 3, 1)[0]
+                 images = (images.numpy() * 255).astype(np.uint8)
+                 input_image = Image.fromarray(images)
+         else:
+             # Multiple individual views
+             images = []
+             for img_file in input_images:
+                 img = Image.open(img_file.name).convert('RGB')
+                 img = img.resize((320, 320))
+                 img = np.array(img) / 255.0
+                 images.append(img)
+
+             # Pad to 6 images if needed
+             while len(images) < 6:
+                 images.append(np.zeros_like(images[0]))
+             images = np.stack(images[:6])
+
+             # Convert to tensor and create layout
+             images = torch.from_numpy(images).float()
+             images = images.permute(0, 3, 1, 2)
+             images = images.reshape(3, 2, 3, 320, 320)
+             images = images.permute(0, 2, 3, 1, 4)
+             images = images.reshape(3, 3, 320, 640)
+             images = images.reshape(1, 3, 960, 640)
+
+             # Convert back to PIL
+             images = images.permute(0, 2, 3, 1)[0]
+             images = (images.numpy() * 255).astype(np.uint8)
+             input_image = Image.fromarray(images)
+     else:
+         raise ValueError("Expected a list of images")
+
+     # Generate refined output
+     output = pipeline.refine(
+         input_image,
+         prompt=prompt,
+         num_inference_steps=int(steps),
+         guidance_scale=guidance_scale
+     ).images[0]
+
+     return output, input_image
+
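+ # Splits the refined 3x2 grid back into six 320x320 views, then runs the
+ # reconstruction model: forward_planes builds the triplane features and
+ # extract_mesh pulls out vertices, faces, and per-vertex colors.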
+ def create_mesh(refined_image, model, infer_config):
+     """Generate mesh from refined image"""
+     # Convert PIL image to tensor
+     image = np.array(refined_image) / 255.0
+     image = torch.from_numpy(image).float().permute(2, 0, 1)
+
+     # Reshape to 6 views
+     image = image.reshape(3, 960, 640)
+     image = image.reshape(3, 3, 320, 640)
+     image = image.permute(1, 0, 2, 3)
+     image = image.reshape(3, 3, 320, 2, 320)
+     image = image.permute(0, 3, 1, 2, 4)
+     image = image.reshape(6, 3, 320, 320)
+
+     # Add batch dimension
+     image = image.unsqueeze(0)
+
+     input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
+     image = image.to("cuda")
+
+     with torch.no_grad():
+         planes = model.forward_planes(image, input_cameras)
+         mesh_out = model.extract_mesh(planes, **infer_config)
+         vertices, faces, vertex_colors = mesh_out
+
+     return vertices, faces, vertex_colors
+
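+ # Wraps the Shap-E text-to-3D stage: samples a latent from the prompt, then
+ # renders it from the six fixed camera poses the refinement stage expects.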
+ class ShapERenderer:
+     def __init__(self, device):
+         print("Loading Shap-E models...")
+         self.device = device
+         self.xm = load_model('transmitter', device=device)
+         self.model = load_model('text300M', device=device)
+         self.diffusion = diffusion_from_config(load_config('diffusion'))
+         print("Shap-E models loaded!")
+
+     def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
+         # Generate latents using the text-to-3D model
+         batch_size = 1
+         guidance_scale = float(guidance_scale)
+         latents = sample_latents(
+             batch_size=batch_size,
+             model=self.model,
+             diffusion=self.diffusion,
+             guidance_scale=guidance_scale,
+             model_kwargs=dict(texts=[prompt] * batch_size),
+             progress=True,
+             clip_denoised=True,
+             use_fp16=True,
+             use_karras=True,
+             karras_steps=num_steps,
+             sigma_min=1e-3,
+             sigma_max=160,
+             s_churn=0,
+         )
+
+         # Render the 6 views we need with specific viewing angles
+         size = 320  # Size of each rendered image
+         images = []
+
+         # Define our 6 specific camera positions to match refine.py
+         azimuths = [30, 90, 150, 210, 270, 330]
+         elevations = [20, -10, 20, -10, 20, -10]
+
+         for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
+             cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
+             rendered_image = decode_latent_images(
+                 self.xm,
+                 latents[0],
+                 rendering_mode='stf',
+                 cameras=cameras
+             )
+             images.append(rendered_image.detach().cpu().numpy())
+
+         # Convert images to uint8 (no 0-1 -> 0-255 scaling is applied here)
+         images = [image.astype(np.uint8) for image in images]
+
+         # Assemble a 3x2 grid of 320px tiles (3 rows, 2 columns -> 960x640 array)
+         layout = np.zeros((960, 640, 3), dtype=np.uint8)
+         for i, img in enumerate(images):
+             row = i // 2  # 2 images per row, 3 rows
+             col = i % 2
+             layout[row*320:(row+1)*320, col*320:(col+1)*320] = img
+
+         return Image.fromarray(layout), images
+
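+ # Ties the two stages together for the UI: normalizes the incoming 6-view
+ # grid, runs the zero123plus refinement pass, and reconstructs a mesh saved
+ # to a temporary .obj file.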
+ class RefinerInterface:
+     def __init__(self):
+         print("Initializing InstantMesh models...")
+         self.pipeline, self.model, self.infer_config = load_models()
+         print("InstantMesh models loaded!")
+
+     def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
+         """Main refinement function"""
+         # Process image and get refined output
+         input_image = Image.fromarray(input_image)
+
+         # If the grid arrives as 2 rows x 3 columns (960 wide, 640 tall),
+         # rearrange it into the 3 rows x 2 columns layout the pipeline expects
+         if input_image.width == 960 and input_image.height == 640:
+             input_array = np.array(input_image)
+             new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
+
+             # Move each 320px tile from its 2x3 position to its 3x2 position
+             for i in range(6):
+                 src_row = i // 3
+                 src_col = i % 3
+                 dst_row = i // 2
+                 dst_col = i % 2
+
+                 new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
+                     input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
+
+             input_image = Image.fromarray(new_layout)
+
+         # Process with the pipeline (expects the 3 rows x 2 columns grid)
+         refined_output_960x640 = self.pipeline.refine(
+             input_image,
+             prompt=prompt,
+             num_inference_steps=int(steps),
+             guidance_scale=guidance_scale
+         ).images[0]
+
+         # Generate mesh using the 960x640 format
+         vertices, faces, vertex_colors = create_mesh(
+             refined_output_960x640,
+             self.model,
+             self.infer_config
+         )
+
+         # Save temporary mesh file
+         os.makedirs("temp", exist_ok=True)
+         temp_obj = os.path.join("temp", "refined_mesh.obj")
+         save_obj(vertices, faces, vertex_colors, temp_obj)
+
+         # Prepare the display copy of the refined grid; source and destination
+         # tile indices below are identical, so this loop is a straight copy
+         refined_array = np.array(refined_output_960x640)
+         display_layout = np.zeros((960, 640, 3), dtype=np.uint8)
+
+         for i in range(6):
+             src_row = i // 2
+             src_col = i % 2
+             dst_row = i // 2
+             dst_col = i % 2
+
+             display_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
+                 refined_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
+
+         refined_output_640x960 = Image.fromarray(display_layout)
+
+         return refined_output_640x960, temp_obj
+
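+ # Gradio UI: prompt -> Shap-E views -> zero123plus refinement -> 3D mesh preview.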
+ def create_demo():
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     shap_e = ShapERenderer(device)
+     refiner = RefinerInterface()
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# Shap-E to InstantMesh Pipeline")
+
+         # First row: Controls
+         with gr.Row():
+             with gr.Column():
+                 # Shap-E inputs
+                 shape_prompt = gr.Textbox(
+                     label="Shap-E Prompt",
+                     placeholder="Enter text to generate initial 3D model..."
+                 )
+                 shape_guidance = gr.Slider(
+                     minimum=1,
+                     maximum=30,
+                     value=15.0,
+                     label="Shap-E Guidance Scale"
+                 )
+                 shape_steps = gr.Slider(
+                     minimum=16,
+                     maximum=128,
+                     value=64,
+                     step=16,
+                     label="Shap-E Steps"
+                 )
+                 generate_btn = gr.Button("Generate Views")
+
+             with gr.Column():
+                 # Refinement inputs
+                 refine_prompt = gr.Textbox(
+                     label="Refinement Prompt",
+                     placeholder="Enter prompt to guide refinement..."
+                 )
+                 refine_steps = gr.Slider(
+                     minimum=30,
+                     maximum=100,
+                     value=75,
+                     step=1,
+                     label="Refinement Steps"
+                 )
+                 refine_guidance = gr.Slider(
+                     minimum=1,
+                     maximum=20,
+                     value=7.5,
+                     label="Refinement Guidance Scale"
+                 )
+                 refine_btn = gr.Button("Refine")
+
+         # Second row: Image panels side by side
+         with gr.Row():
+             # Outputs - Images side by side
+             shape_output = gr.Image(
+                 label="Generated Views",
+                 width=640,   # layout images are 640 wide
+                 height=960   # by 960 tall
+             )
+             refined_output = gr.Image(
+                 label="Refined Output",
+                 width=640,
+                 height=960
+             )
+
+         # Third row: 3D mesh panel below
+         with gr.Row():
+             # 3D mesh centered
+             mesh_output = gr.Model3D(
+                 label="3D Mesh",
+                 clear_color=[1.0, 1.0, 1.0, 1.0],
+                 width=1280,  # Full width
+                 height=600   # Taller for better visualization
+             )
+
+         # Set up event handlers
+         def generate(prompt, guidance_scale, num_steps):
+             with torch.no_grad():
+                 layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
+             return layout
+
+         def refine(input_image, prompt, steps, guidance_scale):
+             refined_img, mesh_path = refiner.refine_model(
+                 input_image,
+                 prompt,
+                 steps,
+                 guidance_scale
+             )
+             return refined_img, mesh_path
+
+         generate_btn.click(
+             fn=generate,
+             inputs=[shape_prompt, shape_guidance, shape_steps],
+             outputs=[shape_output]
+         )
+
+         refine_btn.click(
+             fn=refine,
+             inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
+             outputs=[refined_output, mesh_output]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch(share=True)