YiftachEde committed on
Commit
b266eca
·
1 Parent(s): 776d5b3
Files changed (1) hide show
  1. app.py +173 -126
app.py CHANGED
@@ -237,62 +237,78 @@ class ShapERenderer:
237
  print("Shap-E models initialized!")
238
 
239
  def ensure_models_loaded(self):
240
- if self.model is None:
241
- self.xm = load_model('transmitter', device=self.device)
242
- self.model = load_model('text300M', device=self.device)
243
- self.diffusion = diffusion_from_config(load_config('diffusion'))
 
 
 
 
 
244
 
245
  def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
246
- self.ensure_models_loaded()
247
-
248
- # Generate latents using the text-to-3D model
249
- batch_size = 1
250
- guidance_scale = float(guidance_scale)
251
- latents = sample_latents(
252
- batch_size=batch_size,
253
- model=self.model,
254
- diffusion=self.diffusion,
255
- guidance_scale=guidance_scale,
256
- model_kwargs=dict(texts=[prompt] * batch_size),
257
- progress=True,
258
- clip_denoised=True,
259
- use_fp16=True,
260
- use_karras=True,
261
- karras_steps=num_steps,
262
- sigma_min=1e-3,
263
- sigma_max=160,
264
- s_churn=0,
265
- )
 
 
 
 
266
 
267
- # Render the 6 views we need with specific viewing angles
268
- size = 320 # Size of each rendered image
269
- images = []
270
-
271
- # Define our 6 specific camera positions to match refine.py
272
- azimuths = [30, 90, 150, 210, 270, 330]
273
- elevations = [20, -10, 20, -10, 20, -10]
274
-
275
- for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
276
- cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
277
- rendered_image = decode_latent_images(
278
- self.xm,
279
- latents[0],
280
- cameras=cameras,
281
- rendering_mode='stf'
282
- )
283
- images.append(rendered_image[0])
284
-
285
- # Convert images to uint8
286
- images = [np.array(image) for image in images]
287
-
288
- # Create 2x3 grid layout (640x960) instead of 3x2 (960x640)
289
- layout = np.zeros((960, 640, 3), dtype=np.uint8)
290
- for i, img in enumerate(images):
291
- row = i // 2 # Now 3 images per row
292
- col = i % 2 # Now 3 images per row
293
- layout[row*320:(row+1)*320, col*320:(col+1)*320] = img
 
 
294
 
295
- return Image.fromarray(layout), images
 
 
 
 
 
296
 
297
  class RefinerInterface:
298
  def __init__(self):
@@ -304,70 +320,88 @@ class RefinerInterface:
304
 
305
  def ensure_models_loaded(self):
306
  if self.pipeline is None:
307
- self.pipeline, self.model, self.infer_config = load_models()
 
 
 
 
 
308
 
309
  def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
310
  """Main refinement function"""
311
- self.ensure_models_loaded()
312
-
313
- # Process image and get refined output
314
- input_image = Image.fromarray(input_image)
315
-
316
- # Rotate the layout if needed (if we're getting a 640x960 layout but pipeline expects 960x640)
317
- if input_image.width == 960 and input_image.height == 640:
318
- # Transpose the image to get 960x640 layout
319
- input_array = np.array(input_image)
320
- new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
- # Rearrange from 2x3 to 3x2
 
 
 
 
323
  for i in range(6):
324
- src_row = i // 3
325
- src_col = i % 3
326
  dst_row = i // 2
327
  dst_col = i % 2
328
 
329
- new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
330
- input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
331
 
332
- input_image = Image.fromarray(new_layout)
333
-
334
- # Process with the pipeline (expects 960x640)
335
- refined_output_960x640 = self.pipeline.refine(
336
- input_image,
337
- prompt=prompt,
338
- num_inference_steps=int(steps),
339
- guidance_scale=guidance_scale
340
- ).images[0]
341
-
342
- # Generate mesh using the 960x640 format
343
- vertices, faces, vertex_colors = create_mesh(
344
- refined_output_960x640,
345
- self.model,
346
- self.infer_config
347
- )
348
-
349
- # Save temporary mesh file
350
- os.makedirs("temp", exist_ok=True)
351
- temp_obj = os.path.join("temp", "refined_mesh.obj")
352
- save_obj(vertices, faces, vertex_colors, temp_obj)
353
-
354
- # Convert the output to 640x960 for display
355
- refined_array = np.array(refined_output_960x640)
356
- display_layout = np.zeros((960, 640, 3), dtype=np.uint8)
357
-
358
- # Rearrange from 3x2 to 2x3
359
- for i in range(6):
360
- src_row = i // 2
361
- src_col = i % 2
362
- dst_row = i // 2
363
- dst_col = i % 2
364
 
365
- display_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
366
- refined_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
367
-
368
- refined_output_640x960 = Image.fromarray(display_layout)
369
-
370
- return refined_output_640x960, temp_obj
371
 
372
  def create_demo():
373
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -420,19 +454,20 @@ def create_demo():
420
  label="Refinement Guidance Scale"
421
  )
422
  refine_btn = gr.Button("Refine")
 
423
 
424
  # Second row: Image panels side by side
425
  with gr.Row():
426
  # Outputs - Images side by side
427
  shape_output = gr.Image(
428
  label="Generated Views",
429
- width=640, # Swapped dimensions
430
- height=960 # Swapped dimensions
431
  )
432
  refined_output = gr.Image(
433
  label="Refined Output",
434
- width=640, # Swapped dimensions
435
- height=960 # Swapped dimensions
436
  )
437
 
438
  # Third row: 3D mesh panel below
@@ -441,37 +476,49 @@ def create_demo():
441
  mesh_output = gr.Model3D(
442
  label="3D Mesh",
443
  clear_color=[1.0, 1.0, 1.0, 1.0],
444
- # width=1280, # Full width
445
- # height=600 # Taller for better visualization
446
  )
447
 
448
  # Set up event handlers
449
- @spaces.GPU(duration=60)
450
  def generate(prompt, guidance_scale, num_steps):
451
- with torch.no_grad():
452
- layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
453
- return layout
 
 
 
 
 
 
 
454
 
455
- @spaces.GPU(duration=60)
456
  def refine(input_image, prompt, steps, guidance_scale):
457
- refined_img, mesh_path = refiner.refine_model(
458
- input_image,
459
- prompt,
460
- steps,
461
- guidance_scale
462
- )
463
- return refined_img, mesh_path
 
 
 
 
 
 
 
464
 
465
  generate_btn.click(
466
  fn=generate,
467
  inputs=[shape_prompt, shape_guidance, shape_steps],
468
- outputs=[shape_output]
469
  )
470
 
471
  refine_btn.click(
472
  fn=refine,
473
  inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
474
- outputs=[refined_output, mesh_output]
475
  )
476
 
477
  return demo
 
237
  print("Shap-E models initialized!")
238
 
239
  def ensure_models_loaded(self):
240
+ if self.xm is None:
241
+ try:
242
+ torch.cuda.empty_cache() # Clear GPU memory before loading
243
+ self.xm = load_model('transmitter', device=self.device)
244
+ self.model = load_model('text300M', device=self.device)
245
+ self.diffusion = diffusion_from_config(load_config('diffusion'))
246
+ except Exception as e:
247
+ print(f"Error loading models: {e}")
248
+ raise
249
 
250
  def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
251
+ try:
252
+ self.ensure_models_loaded()
253
+ torch.cuda.empty_cache() # Clear GPU memory before generation
254
+
255
+ # Generate latents using the text-to-3D model
256
+ batch_size = 1
257
+ guidance_scale = float(guidance_scale)
258
+
259
+ with torch.cuda.amp.autocast(): # Use automatic mixed precision
260
+ latents = sample_latents(
261
+ batch_size=batch_size,
262
+ model=self.model,
263
+ diffusion=self.diffusion,
264
+ guidance_scale=guidance_scale,
265
+ model_kwargs=dict(texts=[prompt] * batch_size),
266
+ progress=True,
267
+ clip_denoised=True,
268
+ use_fp16=True,
269
+ use_karras=True,
270
+ karras_steps=num_steps,
271
+ sigma_min=1e-3,
272
+ sigma_max=160,
273
+ s_churn=0,
274
+ )
275
 
276
+ # Render the 6 views we need with specific viewing angles
277
+ size = 320 # Size of each rendered image
278
+ images = []
279
+
280
+ # Define our 6 specific camera positions to match refine.py
281
+ azimuths = [30, 90, 150, 210, 270, 330]
282
+ elevations = [20, -10, 20, -10, 20, -10]
283
+
284
+ for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
285
+ cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
286
+ with torch.cuda.amp.autocast(): # Use automatic mixed precision
287
+ rendered_image = decode_latent_images(
288
+ self.xm,
289
+ latents[0],
290
+ cameras=cameras,
291
+ rendering_mode='stf'
292
+ )
293
+ images.append(rendered_image[0])
294
+ torch.cuda.empty_cache() # Clear GPU memory after each view
295
+
296
+ # Convert images to uint8
297
+ images = [np.array(image) for image in images]
298
+
299
+ # Create 2x3 grid layout (640x960)
300
+ layout = np.zeros((960, 640, 3), dtype=np.uint8)
301
+ for i, img in enumerate(images):
302
+ row = i // 2
303
+ col = i % 2
304
+ layout[row*320:(row+1)*320, col*320:(col+1)*320] = img
305
 
306
+ return Image.fromarray(layout), images
307
+
308
+ except Exception as e:
309
+ print(f"Error in generate_views: {e}")
310
+ torch.cuda.empty_cache() # Clear GPU memory on error
311
+ raise
312
 
313
  class RefinerInterface:
314
  def __init__(self):
 
320
 
321
  def ensure_models_loaded(self):
322
  if self.pipeline is None:
323
+ try:
324
+ torch.cuda.empty_cache() # Clear GPU memory before loading
325
+ self.pipeline, self.model, self.infer_config = load_models()
326
+ except Exception as e:
327
+ print(f"Error loading models: {e}")
328
+ raise
329
 
330
  def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
331
  """Main refinement function"""
332
+ try:
333
+ self.ensure_models_loaded()
334
+ torch.cuda.empty_cache() # Clear GPU memory before processing
335
+
336
+ # Process image and get refined output
337
+ input_image = Image.fromarray(input_image)
338
+
339
+ # Rearrange the layout if needed (if we're getting a 960x640 (3x2) layout, convert it to the 640x960 (2x3) layout)
340
+ if input_image.width == 960 and input_image.height == 640:
341
+ # Rearrange the tiles to get the 640x960 layout
342
+ input_array = np.array(input_image)
343
+ new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
344
+
345
+ # Rearrange from 3x2 to 2x3
346
+ for i in range(6):
347
+ src_row = i // 3
348
+ src_col = i % 3
349
+ dst_row = i // 2
350
+ dst_col = i % 2
351
+
352
+ new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
353
+ input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
354
+
355
+ input_image = Image.fromarray(new_layout)
356
+
357
+ # Process with the pipeline (NOTE(review): the image passed here is the 640x960 layout — confirm the pipeline's expected orientation)
358
+ with torch.cuda.amp.autocast(): # Use automatic mixed precision
359
+ refined_output_960x640 = self.pipeline.refine(
360
+ input_image,
361
+ prompt=prompt,
362
+ num_inference_steps=int(steps),
363
+ guidance_scale=guidance_scale
364
+ ).images[0]
365
+
366
+ torch.cuda.empty_cache() # Clear GPU memory after refinement
367
+
368
+ # Generate mesh using the 960x640 format
369
+ with torch.cuda.amp.autocast(): # Use automatic mixed precision
370
+ vertices, faces, vertex_colors = create_mesh(
371
+ refined_output_960x640,
372
+ self.model,
373
+ self.infer_config
374
+ )
375
+
376
+ torch.cuda.empty_cache() # Clear GPU memory after mesh generation
377
+
378
+ # Save temporary mesh file
379
+ os.makedirs("temp", exist_ok=True)
380
+ temp_obj = os.path.join("temp", "refined_mesh.obj")
381
+ save_obj(vertices, faces, vertex_colors, temp_obj)
382
 
383
+ # Convert the output to 640x960 for display
384
+ refined_array = np.array(refined_output_960x640)
385
+ display_layout = np.zeros((960, 640, 3), dtype=np.uint8)
386
+
387
+ # Copy tiles into the display layout (NOTE(review): src and dst indices below are identical, so this is a straight copy, not a 3x2-to-2x3 rearrangement — confirm intent)
388
  for i in range(6):
389
+ src_row = i // 2
390
+ src_col = i % 2
391
  dst_row = i // 2
392
  dst_col = i % 2
393
 
394
+ display_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
395
+ refined_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
396
 
397
+ refined_output_640x960 = Image.fromarray(display_layout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
+ return refined_output_640x960, temp_obj
400
+
401
+ except Exception as e:
402
+ print(f"Error in refine_model: {e}")
403
+ torch.cuda.empty_cache() # Clear GPU memory on error
404
+ raise
405
 
406
  def create_demo():
407
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
454
  label="Refinement Guidance Scale"
455
  )
456
  refine_btn = gr.Button("Refine")
457
+ error_output = gr.Textbox(label="Status/Error Messages", interactive=False)
458
 
459
  # Second row: Image panels side by side
460
  with gr.Row():
461
  # Outputs - Images side by side
462
  shape_output = gr.Image(
463
  label="Generated Views",
464
+ width=640,
465
+ height=960
466
  )
467
  refined_output = gr.Image(
468
  label="Refined Output",
469
+ width=640,
470
+ height=960
471
  )
472
 
473
  # Third row: 3D mesh panel below
 
476
  mesh_output = gr.Model3D(
477
  label="3D Mesh",
478
  clear_color=[1.0, 1.0, 1.0, 1.0],
 
 
479
  )
480
 
481
  # Set up event handlers
482
+ @spaces.GPU(duration=120) # Increased duration to 120 seconds
483
  def generate(prompt, guidance_scale, num_steps):
484
+ try:
485
+ torch.cuda.empty_cache() # Clear GPU memory before starting
486
+ with torch.no_grad():
487
+ layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
488
+ return layout, None # Return None for error message
489
+ except Exception as e:
490
+ torch.cuda.empty_cache() # Clear GPU memory on error
491
+ error_msg = f"Error during generation: {str(e)}"
492
+ print(error_msg)
493
+ return None, error_msg
494
 
495
+ @spaces.GPU(duration=120) # Increased duration to 120 seconds
496
  def refine(input_image, prompt, steps, guidance_scale):
497
+ try:
498
+ torch.cuda.empty_cache() # Clear GPU memory before starting
499
+ refined_img, mesh_path = refiner.refine_model(
500
+ input_image,
501
+ prompt,
502
+ steps,
503
+ guidance_scale
504
+ )
505
+ return refined_img, mesh_path, None # Return None for error message
506
+ except Exception as e:
507
+ torch.cuda.empty_cache() # Clear GPU memory on error
508
+ error_msg = f"Error during refinement: {str(e)}"
509
+ print(error_msg)
510
+ return None, None, error_msg
511
 
512
  generate_btn.click(
513
  fn=generate,
514
  inputs=[shape_prompt, shape_guidance, shape_steps],
515
+ outputs=[shape_output, error_output]
516
  )
517
 
518
  refine_btn.click(
519
  fn=refine,
520
  inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
521
+ outputs=[refined_output, mesh_output, error_output]
522
  )
523
 
524
  return demo