cronos3k commited on
Commit
9173005
·
verified ·
1 Parent(s): 0bfb596

Update app.py

Browse files

let's see if this works a bit simpler implementation

Files changed (1) hide show
  1. app.py +47 -96
app.py CHANGED
@@ -110,9 +110,9 @@ def image_to_3d(
110
  slat_guidance_strength: float,
111
  slat_sampling_steps: int,
112
  req: gr.Request,
113
- ) -> Tuple[dict, str, str]:
114
  """
115
- Convert an image to a 3D model and generate a GLB file.
116
 
117
  Args:
118
  image (Image.Image): The input image.
@@ -124,7 +124,7 @@ def image_to_3d(
124
  req (gr.Request): Gradio request object.
125
 
126
  Returns:
127
- Tuple[dict, str, str]: The state dictionary, path to the generated video, and path to the standard GLB file.
128
  """
129
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
130
  outputs = pipeline.run(
@@ -148,20 +148,8 @@ def image_to_3d(
148
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
149
  imageio.mimsave(video_path, video, fps=15)
150
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
151
-
152
- # Generate standard GLB file with default simplification and texture size
153
- glb = postprocessing_utils.to_glb(
154
- outputs['gaussian'][0],
155
- outputs['mesh'][0],
156
- simplify=0.95, # Default simplification
157
- texture_size=1024, # Default texture size
158
- verbose=False
159
- )
160
- glb_path = os.path.join(user_dir, f"{trial_id}.glb")
161
- glb.export(glb_path)
162
-
163
  torch.cuda.empty_cache()
164
- return state, video_path, glb_path
165
 
166
  # Existing GLB Extraction Function
167
  @spaces.GPU
@@ -170,7 +158,7 @@ def extract_glb(
170
  mesh_simplify: float,
171
  texture_size: int,
172
  req: gr.Request,
173
- ) -> Tuple[dict, bytes]:
174
  """
175
  Extract a GLB file from the 3D model.
176
 
@@ -181,31 +169,22 @@ def extract_glb(
181
  req (gr.Request): Gradio request object.
182
 
183
  Returns:
184
- Tuple[dict, bytes]: The model state and the GLB file bytes.
185
  """
186
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
187
  gs, mesh, trial_id = unpack_state(state)
188
- glb = postprocessing_utils.to_glb(
189
- gs,
190
- mesh,
191
- simplify=mesh_simplify,
192
- texture_size=texture_size,
193
- verbose=False
194
- )
195
  glb_path = os.path.join(user_dir, f"{trial_id}.glb")
196
  glb.export(glb_path)
197
- # Read the GLB file as bytes
198
- with open(glb_path, "rb") as f:
199
- glb_bytes = f.read()
200
  torch.cuda.empty_cache()
201
- return state, glb_bytes
202
 
203
- # New High-Quality GLB Extraction Function
204
  @spaces.GPU
205
  def extract_glb_high_quality(
206
  state: dict,
207
  req: gr.Request,
208
- ) -> Tuple[dict, bytes]:
209
  """
210
  Extract a high-quality GLB file from the 3D model without polygon reduction.
211
 
@@ -214,37 +193,27 @@ def extract_glb_high_quality(
214
  req (gr.Request): Gradio request object.
215
 
216
  Returns:
217
- Tuple[dict, bytes]: The model state and the high-quality GLB file bytes.
218
  """
219
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
220
  gs, mesh, trial_id = unpack_state(state)
221
  # Set simplify to 0.0 to disable polygon reduction
222
  # Set texture_size to 2048 for maximum texture quality
223
- glb = postprocessing_utils.to_glb(
224
- gs,
225
- mesh,
226
- simplify=0.0,
227
- texture_size=2048,
228
- verbose=False
229
- )
230
  glb_path = os.path.join(user_dir, f"{trial_id}_high_quality.glb")
231
  glb.export(glb_path)
232
- # Read the GLB file as bytes
233
- with open(glb_path, "rb") as f:
234
- glb_bytes = f.read()
235
  torch.cuda.empty_cache()
236
- return state, glb_bytes
237
 
238
  # Gradio Interface Definition
239
  with gr.Blocks(delete_cache=(600, 600)) as demo:
240
  gr.Markdown("""
241
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
242
- * **Generate:** Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, the background will be removed automatically.
243
- * **Extract GLB:** If the generated 3D asset is satisfactory, click "Extract GLB" to extract the GLB file based on your chosen settings and download it.
244
- * **Download High Quality GLB:** Click this button to download a high-quality GLB file without any polygon reduction and with maximum texture quality.
245
- * **Status:** View messages and feedback about your actions below.
246
  """)
247
-
248
  with gr.Row():
249
  with gr.Column():
250
  # Image Input
@@ -322,12 +291,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
322
  step=512
323
  )
324
 
325
- # Extract GLB Button
326
- extract_glb_btn = gr.Button("Extract GLB")
327
 
328
- # Download High Quality GLB Button
329
- download_glb_high_quality_btn = gr.Button("Download High Quality GLB")
330
-
331
  with gr.Column():
332
  # Video Output
333
  video_output = gr.Video(
@@ -342,26 +311,21 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
342
  exposure=20.0,
343
  height=300
344
  )
345
- # Download GLB Buttons
346
  download_glb = gr.DownloadButton(
347
- label="Download GLB"
 
348
  )
 
349
  download_high_quality_glb = gr.DownloadButton(
350
- label="Download High Quality GLB"
351
- )
352
-
353
- # Status Message
354
- status_message = gr.Textbox(
355
- label="Status",
356
- value="Awaiting your action...",
357
- interactive=False,
358
- lines=2
359
  )
360
 
361
  # State Variables
362
  output_buf = gr.State()
363
- glb_bytes_state = gr.State() # For standard GLB
364
- glb_high_quality_bytes_state = gr.State() # For high-quality GLB
365
 
366
  # Example Images
367
  with gr.Row():
@@ -403,57 +367,44 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
403
  slat_guidance_strength,
404
  slat_sampling_steps
405
  ],
406
- outputs=[output_buf, video_output, glb_path := gr.State()]
407
  ).then(
408
- lambda state, video, glb: "Generation successful! You can now extract and download GLB files.",
409
- inputs=[output_buf, video_output, glb_path],
410
- outputs=[status_message]
411
  )
412
 
413
- # Extract GLB Button Click Handler
414
  extract_glb_btn.click(
415
  extract_glb,
416
  inputs=[output_buf, mesh_simplify, texture_size],
417
- outputs=[output_buf, glb_bytes_state],
418
  ).then(
419
- # Map the GLB bytes to the DownloadButton with filename
420
- lambda state, glb_bytes: (glb_bytes, "model.glb"),
421
- inputs=[output_buf, glb_bytes_state],
422
  outputs=[download_glb],
423
- ).then(
424
- # Update status message
425
- lambda: "GLB extraction successful! Click 'Download GLB' to save your model.",
426
- inputs=None,
427
- outputs=[status_message]
428
  )
429
 
430
- # Download High Quality GLB Button Click Handler
431
- download_glb_high_quality_btn.click(
432
  extract_glb_high_quality,
433
  inputs=[output_buf],
434
- outputs=[output_buf, glb_high_quality_bytes_state],
435
  ).then(
436
- # Map the high-quality GLB bytes to the DownloadButton with filename
437
- lambda state, glb_bytes: (glb_bytes, "model_high_quality.glb"),
438
- inputs=[output_buf, glb_high_quality_bytes_state],
439
  outputs=[download_high_quality_glb],
440
- ).then(
441
- # Update status message
442
- lambda: "High-quality GLB extraction successful! Click 'Download High Quality GLB' to save your model.",
443
- inputs=None,
444
- outputs=[status_message]
445
  )
446
 
447
  # Handle Clearing of Video Output
448
  video_output.clear(
449
- lambda: "Video output cleared. Please generate a new 3D asset.",
450
- outputs=[status_message],
451
  )
452
 
453
  # Handle Clearing of Model Output
454
  model_output.clear(
455
- lambda: "Model output cleared. Please extract and download GLB files again.",
456
- outputs=[status_message],
457
  )
458
 
459
  # Launch the Gradio app
@@ -465,4 +416,4 @@ if __name__ == "__main__":
465
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
466
  except:
467
  pass
468
- demo.launch()
 
110
  slat_guidance_strength: float,
111
  slat_sampling_steps: int,
112
  req: gr.Request,
113
+ ) -> Tuple[dict, str]:
114
  """
115
+ Convert an image to a 3D model.
116
 
117
  Args:
118
  image (Image.Image): The input image.
 
124
  req (gr.Request): Gradio request object.
125
 
126
  Returns:
127
+ Tuple[dict, str]: The state dictionary and the path to the generated video.
128
  """
129
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
130
  outputs = pipeline.run(
 
148
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
149
  imageio.mimsave(video_path, video, fps=15)
150
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
 
 
 
 
 
 
 
 
 
 
 
 
151
  torch.cuda.empty_cache()
152
+ return state, video_path
153
 
154
  # Existing GLB Extraction Function
155
  @spaces.GPU
 
158
  mesh_simplify: float,
159
  texture_size: int,
160
  req: gr.Request,
161
+ ) -> Tuple[str, str]:
162
  """
163
  Extract a GLB file from the 3D model.
164
 
 
169
  req (gr.Request): Gradio request object.
170
 
171
  Returns:
172
+ Tuple[str, str]: The path to the extracted GLB file.
173
  """
174
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
175
  gs, mesh, trial_id = unpack_state(state)
176
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
 
 
 
 
 
 
177
  glb_path = os.path.join(user_dir, f"{trial_id}.glb")
178
  glb.export(glb_path)
 
 
 
179
  torch.cuda.empty_cache()
180
+ return glb_path, glb_path
181
 
182
+ # **New High-Quality GLB Extraction Function**
183
  @spaces.GPU
184
  def extract_glb_high_quality(
185
  state: dict,
186
  req: gr.Request,
187
+ ) -> Tuple[str, str]:
188
  """
189
  Extract a high-quality GLB file from the 3D model without polygon reduction.
190
 
 
193
  req (gr.Request): Gradio request object.
194
 
195
  Returns:
196
+ Tuple[str, str]: The path to the high-quality GLB file.
197
  """
198
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
199
  gs, mesh, trial_id = unpack_state(state)
200
  # Set simplify to 0.0 to disable polygon reduction
201
  # Set texture_size to 2048 for maximum texture quality
202
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.0, texture_size=2048, verbose=False)
 
 
 
 
 
 
203
  glb_path = os.path.join(user_dir, f"{trial_id}_high_quality.glb")
204
  glb.export(glb_path)
 
 
 
205
  torch.cuda.empty_cache()
206
+ return glb_path, glb_path
207
 
208
  # Gradio Interface Definition
209
  with gr.Blocks(delete_cache=(600, 600)) as demo:
210
  gr.Markdown("""
211
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
212
+ * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, the background will be removed automatically.
213
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
214
+ * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
 
215
  """)
216
+
217
  with gr.Row():
218
  with gr.Column():
219
  # Image Input
 
291
  step=512
292
  )
293
 
294
+ # Existing Extract GLB Button
295
+ extract_glb_btn = gr.Button("Extract GLB", interactive=True)
296
 
297
+ # **New Download High Quality GLB Button**
298
+ download_high_quality_glb_btn = gr.Button("Download High Quality GLB", interactive=True)
299
+
300
  with gr.Column():
301
  # Video Output
302
  video_output = gr.Video(
 
311
  exposure=20.0,
312
  height=300
313
  )
314
+ # Existing Download GLB Button
315
  download_glb = gr.DownloadButton(
316
+ label="Download GLB",
317
+ file_count="single",
318
  )
319
+ # **New Download High Quality GLB Button**
320
  download_high_quality_glb = gr.DownloadButton(
321
+ label="Download High Quality GLB",
322
+ file_count="single",
 
 
 
 
 
 
 
323
  )
324
 
325
  # State Variables
326
  output_buf = gr.State()
327
+ glb_path_state = gr.State() # For standard GLB
328
+ glb_high_quality_path_state = gr.State() # For high-quality GLB
329
 
330
  # Example Images
331
  with gr.Row():
 
367
  slat_guidance_strength,
368
  slat_sampling_steps
369
  ],
370
+ outputs=[output_buf, video_output],
371
  ).then(
372
+ lambda: gr.Button.update(interactive=True),
373
+ outputs=[extract_glb_btn, download_high_quality_glb_btn],
 
374
  )
375
 
376
+ # Existing Extract GLB Button Click Handler
377
  extract_glb_btn.click(
378
  extract_glb,
379
  inputs=[output_buf, mesh_simplify, texture_size],
380
+ outputs=[model_output, glb_path_state],
381
  ).then(
382
+ lambda glb_path: glb_path if glb_path else "",
383
+ inputs=[glb_path_state],
 
384
  outputs=[download_glb],
 
 
 
 
 
385
  )
386
 
387
+ # **New Download High Quality GLB Button Click Handler**
388
+ download_high_quality_glb_btn.click(
389
  extract_glb_high_quality,
390
  inputs=[output_buf],
391
+ outputs=[model_output, glb_high_quality_path_state],
392
  ).then(
393
+ lambda glb_path: glb_path if glb_path else "",
394
+ inputs=[glb_high_quality_path_state],
 
395
  outputs=[download_high_quality_glb],
 
 
 
 
 
396
  )
397
 
398
  # Handle Clearing of Video Output
399
  video_output.clear(
400
+ lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True)),
401
+ outputs=[extract_glb_btn, download_high_quality_glb_btn],
402
  )
403
 
404
  # Handle Clearing of Model Output
405
  model_output.clear(
406
+ lambda: (gr.File.update(value=None), gr.File.update(value=None)),
407
+ outputs=[download_glb, download_high_quality_glb],
408
  )
409
 
410
  # Launch the Gradio app
 
416
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
417
  except:
418
  pass
419
+ demo.queue(concurrency_count=1, max_size=10).launch()