Commit 9df9f29 · verified · 1 Parent(s): 1809fe4
Author: dkatz2391

another fucking try

Files changed (1):
  1. app.py (+124 / -149)

app.py CHANGED
@@ -1,14 +1,12 @@
1
- # Version: 1.1.2 - API State Fix + DEBUG (Video Disabled) + Import Fix (2025-05-04)
2
  # Changes:
 
3
  # - ENSURED `import spaces` is present for the @spaces.GPU decorator.
4
  # - TEMPORARY DEBUGGING STEP: Commented out video rendering in `text_to_3d`
5
  # and return None for video_path to isolate the "Session not found" error.
6
- # - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`
7
- # as the first return value. This ensures the dictionary is available via the API.
8
- # - Modified `extract_glb` and `extract_gaussian` to accept `state_dict: dict` as their first argument
9
- # instead of relying on the implicit `gr.State` object type when called via API.
10
- # - Kept Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`)
11
- # so the UI continues to function by passing the dictionary through output_buf.
12
  # - Added minor safety checks and logging.
13
 
14
  import gradio as gr
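
For API callers, the change summarized in the changelog above means `text_to_3d` now returns a plain, serializable dict rather than a `gr.State` handle, so it can be captured and fed back into the extraction endpoints. A minimal sketch of that call pattern with `gradio_client`, assuming a Space id and an `extract_glb` endpoint name that are not shown in this diff:

    from gradio_client import Client

    client = Client("user/trellis-text-to-3d")  # hypothetical Space id
    # text_to_3d is exposed with api_name="text_to_3d" (see the click chain below)
    state_dict, video_path = client.predict(
        "a small wooden chair",  # prompt
        42,                      # seed
        7.5, 25,                 # sparse-structure guidance / steps (example values)
        7.5, 25,                 # SLAT guidance / steps (example values)
        api_name="/text_to_3d",
    )
    # The dict is serializable, so it can go straight back into extraction
    # (endpoint name assumed; adjust to however the extract button is exposed):
    glb_path, _ = client.predict(state_dict, 0.95, 1024, api_name="/extract_glb")
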
@@ -17,8 +15,6 @@ import spaces # <<<--- ENSURE THIS IMPORT IS PRESENT
17
  import os
18
  import shutil
19
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
20
- # Fix potential SpConv issue if needed, try 'hash' or 'native'
21
- # os.environ.setdefault('SPCONV_ALGO', 'native') # Use setdefault to avoid overwriting if already set
22
  os.environ['SPCONV_ALGO'] = 'native' # Direct set as per original
23
 
24
  from typing import *
@@ -35,33 +31,43 @@ import sys
35
 
36
 
37
  MAX_SEED = np.iinfo(np.int32).max
38
- # Ensure TMP_DIR is correctly defined relative to the script location
39
- # Using /tmp/ directly might be more robust in some container environments
40
- # TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
41
- TMP_DIR = '/tmp/gradio_sessions' # Use standard /tmp directory
42
  print(f"Using temporary directory: {TMP_DIR}")
43
- os.makedirs(TMP_DIR, exist_ok=True)
44
 
45
 
46
  def start_session(req: gr.Request):
47
  """Creates a temporary directory for the user session."""
 
48
  try:
49
  session_hash = req.session_hash
50
  if not session_hash:
51
- # Fallback or generate a temporary ID if session_hash is missing (might happen on first load?)
52
  session_hash = f"no_session_{np.random.randint(10000, 99999)}"
53
  print(f"Warning: No session_hash in request, using temporary ID: {session_hash}")
54
55
  user_dir = os.path.join(TMP_DIR, str(session_hash))
56
  os.makedirs(user_dir, exist_ok=True)
57
- print(f"Started session, created directory: {user_dir}")
58
  except Exception as e:
59
- print(f"Error in start_session: {e}", file=sys.stderr)
60
- # Decide if this is critical - maybe raise to prevent further issues?
61
 
62
 
63
  def end_session(req: gr.Request):
64
  """Removes the temporary directory for the user session."""
 
65
  try:
66
  session_hash = req.session_hash
67
  if not session_hash:
@@ -69,16 +75,16 @@ def end_session(req: gr.Request):
69
  return
70
 
71
  user_dir = os.path.join(TMP_DIR, str(session_hash))
72
- if os.path.exists(user_dir):
73
  try:
74
  shutil.rmtree(user_dir)
75
  print(f"Ended session, removed directory: {user_dir}")
76
  except OSError as e:
77
  print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
78
  else:
79
- print(f"Ended session, directory already removed or hash mismatch: {user_dir}")
80
  except Exception as e:
81
- print(f"Error in end_session: {e}", file=sys.stderr)
82
 
83
 
84
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -87,7 +93,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
87
  try:
88
  packed_data = {
89
  'gaussian': {
90
- **{k: v for k, v in gs.init_params.items()}, # Ensure init_params are included
91
  '_xyz': gs._xyz.detach().cpu().numpy(),
92
  '_features_dc': gs._features_dc.detach().cpu().numpy(),
93
  '_scaling': gs._scaling.detach().cpu().numpy(),
@@ -104,7 +110,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
104
  except Exception as e:
105
  print(f"Error during pack_state: {e}", file=sys.stderr)
106
  traceback.print_exc()
107
- raise # Re-raise the error to be caught upstream
108
 
109
 
110
  def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
@@ -114,23 +120,20 @@ def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
114
  if not isinstance(state_dict, dict) or 'gaussian' not in state_dict or 'mesh' not in state_dict:
115
  raise ValueError("Invalid state_dict structure passed to unpack_state.")
116
 
117
- # Ensure the device is correctly set when unpacking
118
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
119
  print(f"[unpack_state] Using device: {device}")
120
 
121
  gauss_data = state_dict['gaussian']
122
  mesh_data = state_dict['mesh']
123
 
124
- # Recreate Gaussian object using parameters stored during packing
125
  gs = Gaussian(
126
- aabb=gauss_data.get('aabb'), # Use .get for safety
127
  sh_degree=gauss_data.get('sh_degree'),
128
  mininum_kernel_size=gauss_data.get('mininum_kernel_size'),
129
  scaling_bias=gauss_data.get('scaling_bias'),
130
  opacity_bias=gauss_data.get('opacity_bias'),
131
  scaling_activation=gauss_data.get('scaling_activation'),
132
  )
133
- # Load tensors, ensuring they are created on the correct device
134
  gs._xyz = torch.tensor(gauss_data['_xyz'], device=device, dtype=torch.float32)
135
  gs._features_dc = torch.tensor(gauss_data['_features_dc'], device=device, dtype=torch.float32)
136
  gs._scaling = torch.tensor(gauss_data['_scaling'], device=device, dtype=torch.float32)
@@ -138,10 +141,9 @@ def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
138
  gs._opacity = torch.tensor(gauss_data['_opacity'], device=device, dtype=torch.float32)
139
  print(f"[unpack_state] Gaussian unpacked. Points: {gs.get_xyz.shape[0]}")
140
 
141
- # Recreate mesh object using edict for compatibility if needed elsewhere
142
  mesh = edict(
143
  vertices=torch.tensor(mesh_data['vertices'], device=device, dtype=torch.float32),
144
- faces=torch.tensor(mesh_data['faces'], device=device, dtype=torch.int64), # Faces are typically long/int64
145
  )
146
  print(f"[unpack_state] Mesh unpacked. Vertices: {mesh.vertices.shape[0]}, Faces: {mesh.faces.shape[0]}")
147
 
@@ -149,14 +151,14 @@ def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
149
  except Exception as e:
150
  print(f"Error during unpack_state: {e}", file=sys.stderr)
151
  traceback.print_exc()
152
- raise # Re-raise the error
153
 
154
 
155
  def get_seed(randomize_seed: bool, seed: int) -> int:
156
  """Gets a seed value, randomizing if requested."""
157
  new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
158
  print(f"[get_seed] Randomize: {randomize_seed}, Input Seed: {seed}, Output Seed: {new_seed}")
159
- return int(new_seed) # Ensure it's a standard int
160
 
161
 
162
  @spaces.GPU
@@ -168,73 +170,57 @@ def text_to_3d(
168
  slat_guidance_strength: float,
169
  slat_sampling_steps: int,
170
  req: gr.Request,
171
- ) -> Tuple[dict, Optional[str]]: # Return type changed Optional[str] for video path
172
  """
173
  Generates a 3D model (Gaussian and Mesh) from text and returns a
174
  serializable state dictionary and potentially a video preview path.
175
  >>> TEMPORARILY DISABLED VIDEO RENDERING FOR DEBUGGING <<<
176
  """
177
  print(f"[text_to_3d - DEBUG MODE] Received prompt: '{prompt}', Seed: {seed}")
178
- session_hash = req.session_hash
179
- if not session_hash:
180
- session_hash = f"no_session_{np.random.randint(10000, 99999)}" # Use consistent fallback
181
- print(f"Warning: No session_hash in text_to_3d request, using temporary ID: {session_hash}")
182
- user_dir = os.path.join(TMP_DIR, str(session_hash))
183
- os.makedirs(user_dir, exist_ok=True) # Ensure it exists for this request
184
- print(f"[text_to_3d - DEBUG MODE] User directory: {user_dir}")
185
-
186
- # --- Generation Pipeline ---
187
  try:
188
  print("[text_to_3d - DEBUG MODE] Running Trellis pipeline...")
189
- # Add more specific pipeline settings if needed based on Trellis docs
190
  outputs = pipeline.run(
191
  prompt=prompt,
192
  seed=seed,
193
- formats=["gaussian", "mesh"], # Ensure both are generated
194
  sparse_structure_sampler_params={
195
- "steps": int(ss_sampling_steps), # Ensure steps are int
196
  "cfg_strength": float(ss_guidance_strength),
197
  },
198
  slat_sampler_params={
199
- "steps": int(slat_sampling_steps), # Ensure steps are int
200
  "cfg_strength": float(slat_guidance_strength),
201
  },
202
- # device='cuda' # Explicitly specify device if needed
203
  )
204
  print("[text_to_3d - DEBUG MODE] Pipeline run completed.")
205
- except Exception as e:
206
- print(f"❌ [text_to_3d - DEBUG MODE] Pipeline error: {e}", file=sys.stderr)
207
- traceback.print_exc()
208
- raise gr.Error(f"Trellis pipeline failed during generation: {e}") # More specific error
209
 
210
- # --- Create Serializable State Dictionary ---
211
- try:
212
  state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
 
213
  except Exception as e:
214
- print(f"❌ [text_to_3d - DEBUG MODE] pack_state error: {e}", file=sys.stderr)
215
  traceback.print_exc()
216
- raise gr.Error(f"Failed to pack state after generation: {e}")
 
217
 
218
  # --- Render Video Preview (TEMPORARILY DISABLED FOR DEBUGGING) ---
219
- video_path = None # Explicitly set path to None for this debug version
220
  print("[text_to_3d - DEBUG MODE] Skipping video rendering.")
221
- # --- Original Video Code Block Start (Keep commented for now) ---
222
- # try:
223
- # print("[text_to_3d] Rendering video preview...")
224
- # video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
225
- # video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
226
- # # Ensure video frames are uint8
227
- # video = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
228
- # video_path_tmp = os.path.join(user_dir, 'sample.mp4') # Use temp name
229
- # imageio.mimsave(video_path_tmp, video, fps=15, quality=8) # Added quality setting
230
- # print(f"[text_to_3d] Video saved to: {video_path_tmp}")
231
- # video_path = video_path_tmp # Assign if successful
232
- # except Exception as e:
233
- # print(f"❌ [text_to_3d] Video rendering/saving error: {e}", file=sys.stderr)
234
- # traceback.print_exc()
235
- # # Still return state_dict, but maybe signal video error? Return None for path.
236
- # video_path = None # Indicate video failure
237
- # --- Original Video Code Block End ---
238
 
239
  # --- Cleanup and Return ---
240
  if torch.cuda.is_available():
@@ -243,15 +229,16 @@ def text_to_3d(
243
 
244
  # --- Return Serializable Dictionary and None Video Path ---
245
  print("[text_to_3d - DEBUG MODE] Returning state dictionary and None video path.")
246
- # Ensure state_dict is not None before returning
247
  if state_dict is None:
248
- raise gr.Error("Failed to create state dictionary.")
 
 
249
  return state_dict, video_path
250
 
251
 
252
- @spaces.GPU(duration=120) # Increased duration slightly
253
  def extract_glb(
254
- state_dict: dict, # <-- Accepts the dictionary directly
255
  mesh_simplify: float,
256
  texture_size: int,
257
  req: gr.Request,
@@ -260,31 +247,27 @@ def extract_glb(
260
  Extracts a GLB file from the provided 3D model state dictionary.
261
  """
262
  print(f"[extract_glb] Received request. Simplify: {mesh_simplify}, Texture Size: {texture_size}")
263
- session_hash = req.session_hash
264
- if not session_hash:
265
- session_hash = f"no_session_{np.random.randint(10000, 99999)}"
266
- print(f"Warning: No session_hash in extract_glb request, using temporary ID: {session_hash}")
267
 
268
- if not isinstance(state_dict, dict):
269
- print("❌ [extract_glb] Error: Invalid state_dict received (not a dictionary).")
270
- raise gr.Error("Invalid state data received. Please generate the model first.")
271
 
272
- user_dir = os.path.join(TMP_DIR, str(session_hash))
273
- os.makedirs(user_dir, exist_ok=True) # Ensure it exists
274
- print(f"[extract_glb] User directory: {user_dir}")
275
 
276
- # --- Unpack state from the dictionary ---
277
- try:
278
  gs, mesh = unpack_state(state_dict)
279
- except Exception as e:
280
- print(f"❌ [extract_glb] unpack_state error: {e}", file=sys.stderr)
281
- traceback.print_exc()
282
- raise gr.Error(f"Failed to unpack state during GLB extraction: {e}")
283
 
284
- # --- Postprocessing and Export ---
285
- try:
286
  print("[extract_glb] Converting to GLB...")
287
- # Ensure parameters have correct types
288
  simplify_factor = float(mesh_simplify)
289
  tex_size = int(texture_size)
290
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=simplify_factor, texture_size=tex_size, verbose=True)
@@ -292,77 +275,77 @@ def extract_glb(
292
  print(f"[extract_glb] Exporting GLB to: {glb_path}")
293
  glb.export(glb_path)
294
  print("[extract_glb] GLB exported successfully.")
 
295
  except Exception as e:
296
- print(f"❌ [extract_glb] GLB conversion/export error: {e}", file=sys.stderr)
297
  traceback.print_exc()
298
- raise gr.Error(f"Failed to extract GLB: {e}")
299
 
300
  # --- Cleanup and Return ---
301
  if torch.cuda.is_available():
302
  torch.cuda.empty_cache()
303
  print("[extract_glb] Cleared CUDA cache.")
304
 
305
- # Return path twice for both Model3D and DownloadButton components
306
  print("[extract_glb] Returning GLB path.")
307
- # Ensure path is returned, even if export failed somehow (though error should raise first)
 
 
308
  return glb_path, glb_path
309
 
310
 
311
  @spaces.GPU
312
  def extract_gaussian(
313
- state_dict: dict, # <-- Accepts the dictionary directly
314
  req: gr.Request
315
  ) -> Tuple[str, str]:
316
  """
317
  Extracts a PLY (Gaussian) file from the provided 3D model state dictionary.
318
  """
319
  print("[extract_gaussian] Received request.")
320
- session_hash = req.session_hash
321
- if not session_hash:
322
- session_hash = f"no_session_{np.random.randint(10000, 99999)}"
323
- print(f"Warning: No session_hash in extract_gaussian request, using temporary ID: {session_hash}")
324
 
325
- if not isinstance(state_dict, dict):
326
- print("❌ [extract_gaussian] Error: Invalid state_dict received (not a dictionary).")
327
- raise gr.Error("Invalid state data received. Please generate the model first.")
328
 
329
- user_dir = os.path.join(TMP_DIR, str(session_hash))
330
- os.makedirs(user_dir, exist_ok=True) # Ensure it exists
331
- print(f"[extract_gaussian] User directory: {user_dir}")
332
 
333
- # --- Unpack state from the dictionary ---
334
- try:
335
- gs, _ = unpack_state(state_dict) # Only need Gaussian part
336
- except Exception as e:
337
- print(f"❌ [extract_gaussian] unpack_state error: {e}", file=sys.stderr)
338
- traceback.print_exc()
339
- raise gr.Error(f"Failed to unpack state during Gaussian extraction: {e}")
340
 
341
- # --- Export PLY ---
342
- try:
343
  gaussian_path = os.path.join(user_dir, 'sample.ply')
344
  print(f"[extract_gaussian] Saving PLY to: {gaussian_path}")
345
  gs.save_ply(gaussian_path)
346
  print("[extract_gaussian] PLY saved successfully.")
 
347
  except Exception as e:
348
- print(f"❌ [extract_gaussian] PLY saving error: {e}", file=sys.stderr)
349
  traceback.print_exc()
350
- raise gr.Error(f"Failed to extract Gaussian PLY: {e}")
351
 
352
  # --- Cleanup and Return ---
353
  if torch.cuda.is_available():
354
  torch.cuda.empty_cache()
355
  print("[extract_gaussian] Cleared CUDA cache.")
356
 
357
- # Return path twice for both Model3D and DownloadButton components
358
  print("[extract_gaussian] Returning PLY path.")
359
- # Ensure path is returned
 
 
360
  return gaussian_path, gaussian_path
361
 
362
 
363
  # --- Gradio UI Definition ---
364
  print("Setting up Gradio Blocks interface...")
365
- # Define the interface layout
366
  with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
367
  gr.Markdown("""
368
  # Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
@@ -373,7 +356,6 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
373
  """)
374
 
375
  # --- State Buffer ---
376
- # This hidden component holds the dictionary linking generation and extraction.
377
  output_buf = gr.State()
378
 
379
  with gr.Row():
@@ -394,7 +376,7 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
394
 
395
  generate_btn = gr.Button("Generate 3D Preview", variant="primary")
396
 
397
- with gr.Accordion(label="GLB Extraction Settings", open=True): # Open by default
398
  mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01, info="Higher value = less simplification (more polys)")
399
  texture_size = gr.Slider(512, 2048, label="Texture Size (pixels)", value=1024, step=512, info="Size of the generated texture map")
400
 
@@ -408,7 +390,7 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
408
  with gr.Column(scale=1): # Output column
409
  # Video component remains for layout but won't show anything in this debug version
410
  video_output = gr.Video(label="Generated 3D Preview (DISABLED FOR DEBUG)", autoplay=False, loop=False, value=None, height=350)
411
- model_output = gr.Model3D(label="Extracted Model Preview", height=350, clear_color=[0.95, 0.95, 0.95, 1.0]) # Light background
412
 
413
  with gr.Row():
414
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
@@ -418,8 +400,10 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
418
  print("Defining Gradio event handlers...")
419
 
420
  # Handle session start/end
 
421
  demo.load(start_session, inputs=None, outputs=None)
422
- demo.unload(end_session, inputs=None, outputs=None)
 
423
 
424
  # --- Generate Button Click Flow ---
425
  generate_event = generate_btn.click(
@@ -430,18 +414,16 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
430
  ).then(
431
  text_to_3d,
432
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
433
- # Output state_dict to buffer, output None to video component
434
- outputs=[output_buf, video_output],
435
  api_name="text_to_3d"
436
  ).then(
437
- # Function to update button interactivity after generation attempt
438
  lambda: (
439
  gr.Button(interactive=True),
440
  gr.Button(interactive=True),
441
  gr.DownloadButton(interactive=False),
442
  gr.DownloadButton(interactive=False)
443
  ),
444
- inputs=None, # No inputs needed for the lambda
445
  outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
446
  )
447
 
@@ -475,7 +457,6 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
475
  inputs=None,
476
  outputs=[download_glb, download_gs]
477
  )
478
- # Also disable buttons if the (currently disabled) video output is cleared
479
  video_output.clear(
480
  lambda: (
481
  gr.Button(interactive=False),
@@ -491,18 +472,15 @@ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
491
 
492
 
493
  # --- Launch the Gradio app ---
494
- # Main execution block
495
  if __name__ == "__main__":
496
  print("Loading Trellis pipeline...")
497
  pipeline_loaded = False
 
498
  try:
499
- # Ensure model/variant matches requirements, use revision if needed
500
  pipeline = TrellisTextTo3DPipeline.from_pretrained(
501
  "JeffreyXiang/TRELLIS-text-xlarge",
502
- # revision="main", # Specify if needed
503
- torch_dtype=torch.float16 # Use float16 if GPU supports it for less memory
504
  )
505
- # Move to GPU if available
506
  if torch.cuda.is_available():
507
  pipeline = pipeline.to("cuda")
508
  print("✅ Trellis pipeline loaded successfully to GPU.")
@@ -513,23 +491,20 @@ if __name__ == "__main__":
513
  except Exception as e:
514
  print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
515
  traceback.print_exc()
516
- # Exit if pipeline is critical for the app to run
517
  print("❌ Exiting due to pipeline load failure.")
518
- sys.exit(1) # Exit if pipeline fails
519
 
520
  if pipeline_loaded:
521
  print("Launching Gradio demo...")
522
- # Set share=True if you need a public link (e.g., for testing from outside local network)
523
- # Set server_name="0.0.0.0" to allow access from local network IP
524
- # Increased concurrency_limit and timeout for queue might help
525
  demo.queue(
526
- # default_concurrency_limit=5, # Adjust based on expected load and space resources
527
- # api_open=True # Keep API accessible
528
  ).launch(
529
- # server_name="0.0.0.0", # Make accessible on local network
530
- # share=False, # Set to True for public link if needed
531
- debug=True, # Enable Gradio debug mode for more detailed logs
532
- # prevent_thread_lock=True # May help with async issues in some cases
533
  )
534
  print("Gradio demo launched.")
535
  else:
 
1
+ # Version: 1.1.3 - API State Fix + DEBUG (Video Disabled) + unload() Fix (2025-05-04)
2
  # Changes:
3
+ # - FIXED TypeError in demo.unload() by removing incorrect 'inputs'/'outputs' arguments.
4
  # - ENSURED `import spaces` is present for the @spaces.GPU decorator.
5
  # - TEMPORARY DEBUGGING STEP: Commented out video rendering in `text_to_3d`
6
  # and return None for video_path to isolate the "Session not found" error.
7
+ # - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`.
8
+ # - Modified `extract_glb`/`extract_gaussian` to accept `state_dict: dict`.
9
+ # - Kept Gradio UI bindings using `output_buf`.
10
  # - Added minor safety checks and logging.
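
The unload() fix called out above is a signature change only; both forms appear verbatim further down in this diff:

    # Before (raised a TypeError):
    # demo.unload(end_session, inputs=None, outputs=None)
    # After — unload takes just the callback:
    demo.unload(end_session)
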
11
 
12
  import gradio as gr
 
15
  import os
16
  import shutil
17
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
 
 
18
  os.environ['SPCONV_ALGO'] = 'native' # Direct set as per original
19
 
20
  from typing import *
 
31
 
32
 
33
  MAX_SEED = np.iinfo(np.int32).max
34
+ # Use standard /tmp directory which is usually available in container environments
35
+ TMP_DIR = '/tmp/gradio_sessions'
 
 
36
  print(f"Using temporary directory: {TMP_DIR}")
37
+ # Ensure the base temp directory exists
38
+ try:
39
+ os.makedirs(TMP_DIR, exist_ok=True)
40
+ except OSError as e:
41
+ print(f"Warning: Could not create base temp directory {TMP_DIR}: {e}", file=sys.stderr)
42
+ # Potentially fall back or exit if temp dir is critical
43
+ TMP_DIR = '.' # Fallback to current directory (less ideal)
44
+ print(f"Warning: Falling back to use current directory for temp files: {os.path.abspath(TMP_DIR)}")
45
 
46
 
47
  def start_session(req: gr.Request):
48
  """Creates a temporary directory for the user session."""
49
+ user_dir = None # Initialize
50
  try:
51
  session_hash = req.session_hash
52
  if not session_hash:
 
53
  session_hash = f"no_session_{np.random.randint(10000, 99999)}"
54
  print(f"Warning: No session_hash in request, using temporary ID: {session_hash}")
55
 
56
+ # Ensure TMP_DIR exists before joining path
57
+ if not os.path.exists(TMP_DIR):
58
+ os.makedirs(TMP_DIR, exist_ok=True)
59
+
60
  user_dir = os.path.join(TMP_DIR, str(session_hash))
61
  os.makedirs(user_dir, exist_ok=True)
62
+ print(f"Started session, ensured directory exists: {user_dir}")
63
  except Exception as e:
64
+ print(f"Error in start_session creating directory '{user_dir}': {e}", file=sys.stderr)
65
+ traceback.print_exc()
66
 
67
 
68
  def end_session(req: gr.Request):
69
  """Removes the temporary directory for the user session."""
70
+ user_dir = None # Initialize
71
  try:
72
  session_hash = req.session_hash
73
  if not session_hash:
 
75
  return
76
 
77
  user_dir = os.path.join(TMP_DIR, str(session_hash))
78
+ if os.path.exists(user_dir) and os.path.isdir(user_dir): # Extra check if it's a directory
79
  try:
80
  shutil.rmtree(user_dir)
81
  print(f"Ended session, removed directory: {user_dir}")
82
  except OSError as e:
83
  print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
84
  else:
85
+ print(f"Ended session, directory not found or not a directory: {user_dir}")
86
  except Exception as e:
87
+ print(f"Error in end_session cleaning directory '{user_dir}': {e}", file=sys.stderr)
88
 
89
 
90
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
 
93
  try:
94
  packed_data = {
95
  'gaussian': {
96
+ **{k: v for k, v in gs.init_params.items()},
97
  '_xyz': gs._xyz.detach().cpu().numpy(),
98
  '_features_dc': gs._features_dc.detach().cpu().numpy(),
99
  '_scaling': gs._scaling.detach().cpu().numpy(),
 
110
  except Exception as e:
111
  print(f"Error during pack_state: {e}", file=sys.stderr)
112
  traceback.print_exc()
113
+ raise
114
 
115
 
116
  def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
 
120
  if not isinstance(state_dict, dict) or 'gaussian' not in state_dict or 'mesh' not in state_dict:
121
  raise ValueError("Invalid state_dict structure passed to unpack_state.")
122
 
 
123
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
124
  print(f"[unpack_state] Using device: {device}")
125
 
126
  gauss_data = state_dict['gaussian']
127
  mesh_data = state_dict['mesh']
128
 
 
129
  gs = Gaussian(
130
+ aabb=gauss_data.get('aabb'),
131
  sh_degree=gauss_data.get('sh_degree'),
132
  mininum_kernel_size=gauss_data.get('mininum_kernel_size'),
133
  scaling_bias=gauss_data.get('scaling_bias'),
134
  opacity_bias=gauss_data.get('opacity_bias'),
135
  scaling_activation=gauss_data.get('scaling_activation'),
136
  )
 
137
  gs._xyz = torch.tensor(gauss_data['_xyz'], device=device, dtype=torch.float32)
138
  gs._features_dc = torch.tensor(gauss_data['_features_dc'], device=device, dtype=torch.float32)
139
  gs._scaling = torch.tensor(gauss_data['_scaling'], device=device, dtype=torch.float32)
 
141
  gs._opacity = torch.tensor(gauss_data['_opacity'], device=device, dtype=torch.float32)
142
  print(f"[unpack_state] Gaussian unpacked. Points: {gs.get_xyz.shape[0]}")
143
 
 
144
  mesh = edict(
145
  vertices=torch.tensor(mesh_data['vertices'], device=device, dtype=torch.float32),
146
+ faces=torch.tensor(mesh_data['faces'], device=device, dtype=torch.int64),
147
  )
148
  print(f"[unpack_state] Mesh unpacked. Vertices: {mesh.vertices.shape[0]}, Faces: {mesh.faces.shape[0]}")
149
 
 
151
  except Exception as e:
152
  print(f"Error during unpack_state: {e}", file=sys.stderr)
153
  traceback.print_exc()
154
+ raise
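
For reference, the dictionary that pack_state produces and unpack_state consumes has roughly this shape; the field names come from the two functions above, while the array shapes and placeholder values below are illustrative assumptions (fields elided by the diff may also be present):

    import numpy as np

    N, V, F = 1000, 500, 900  # placeholder point/vertex/face counts

    example_state_dict = {
        'gaussian': {
            # init_params captured via gs.init_params in pack_state (placeholder values):
            'aabb': None, 'sh_degree': None, 'mininum_kernel_size': None,
            'scaling_bias': None, 'opacity_bias': None, 'scaling_activation': None,
            # tensors detached to CPU numpy arrays (shapes assumed):
            '_xyz': np.zeros((N, 3), dtype=np.float32),
            '_features_dc': np.zeros((N, 1, 3), dtype=np.float32),
            '_scaling': np.zeros((N, 3), dtype=np.float32),
            '_opacity': np.zeros((N, 1), dtype=np.float32),
        },
        'mesh': {
            'vertices': np.zeros((V, 3), dtype=np.float32),
            'faces': np.zeros((F, 3), dtype=np.int64),
        },
    }
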
155
 
156
 
157
  def get_seed(randomize_seed: bool, seed: int) -> int:
158
  """Gets a seed value, randomizing if requested."""
159
  new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
160
  print(f"[get_seed] Randomize: {randomize_seed}, Input Seed: {seed}, Output Seed: {new_seed}")
161
+ return int(new_seed)
162
 
163
 
164
  @spaces.GPU
 
170
  slat_guidance_strength: float,
171
  slat_sampling_steps: int,
172
  req: gr.Request,
173
+ ) -> Tuple[dict, Optional[str]]:
174
  """
175
  Generates a 3D model (Gaussian and Mesh) from text and returns a
176
  serializable state dictionary and potentially a video preview path.
177
  >>> TEMPORARILY DISABLED VIDEO RENDERING FOR DEBUGGING <<<
178
  """
179
  print(f"[text_to_3d - DEBUG MODE] Received prompt: '{prompt}', Seed: {seed}")
180
+ user_dir = None # Initialize
181
+ state_dict = None # Initialize
 
 
 
 
 
 
 
182
  try:
183
+ session_hash = req.session_hash
184
+ if not session_hash:
185
+ session_hash = f"no_session_{np.random.randint(10000, 99999)}"
186
+ print(f"Warning: No session_hash in text_to_3d request, using temporary ID: {session_hash}")
187
+
188
+ # Ensure user directory exists
189
+ user_dir = os.path.join(TMP_DIR, str(session_hash))
190
+ os.makedirs(user_dir, exist_ok=True)
191
+ print(f"[text_to_3d - DEBUG MODE] User directory: {user_dir}")
192
+
193
+ # --- Generation Pipeline ---
194
  print("[text_to_3d - DEBUG MODE] Running Trellis pipeline...")
 
195
  outputs = pipeline.run(
196
  prompt=prompt,
197
  seed=seed,
198
+ formats=["gaussian", "mesh"],
199
  sparse_structure_sampler_params={
200
+ "steps": int(ss_sampling_steps),
201
  "cfg_strength": float(ss_guidance_strength),
202
  },
203
  slat_sampler_params={
204
+ "steps": int(slat_sampling_steps),
205
  "cfg_strength": float(slat_guidance_strength),
206
  },
 
207
  )
208
  print("[text_to_3d - DEBUG MODE] Pipeline run completed.")
209
 
210
+ # --- Create Serializable State Dictionary ---
 
211
  state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
212
+
213
  except Exception as e:
214
+ print(f"❌ [text_to_3d - DEBUG MODE] Error during generation or packing: {e}", file=sys.stderr)
215
  traceback.print_exc()
216
+ # Raise a Gradio error to send failure message back to client if possible
217
+ raise gr.Error(f"Core generation failed: {e}")
218
 
219
  # --- Render Video Preview (TEMPORARILY DISABLED FOR DEBUGGING) ---
220
+ video_path = None
221
  print("[text_to_3d - DEBUG MODE] Skipping video rendering.")
222
+ # --- Original Video Code Block (Keep commented) ---
223
+ # ... (video code commented out) ...
224
 
225
  # --- Cleanup and Return ---
226
  if torch.cuda.is_available():
 
229
 
230
  # --- Return Serializable Dictionary and None Video Path ---
231
  print("[text_to_3d - DEBUG MODE] Returning state dictionary and None video path.")
 
232
  if state_dict is None:
233
+ # This case should ideally be caught by the exception handling above
234
+ print("Error: state_dict is None before return, generation likely failed.", file=sys.stderr)
235
+ raise gr.Error("State dictionary creation failed.")
236
  return state_dict, video_path
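
Once the "Session not found" issue is resolved, the disabled preview can be restored from the commented-out block shown in the removed lines above; reconstructed, it would sit back inside text_to_3d roughly as follows (same calls as the original, not new behaviour):

    try:
        print("[text_to_3d] Rendering video preview...")
        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
        # Ensure frames are uint8 and place color and normal renders side by side
        video = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1)
                 for v, vg in zip(video, video_geo)]
        video_path_tmp = os.path.join(user_dir, 'sample.mp4')
        imageio.mimsave(video_path_tmp, video, fps=15, quality=8)
        video_path = video_path_tmp
    except Exception as e:
        print(f"❌ [text_to_3d] Video rendering/saving error: {e}", file=sys.stderr)
        traceback.print_exc()
        video_path = None  # keep returning the state dict even if the preview fails
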
237
 
238
 
239
+ @spaces.GPU(duration=120)
240
  def extract_glb(
241
+ state_dict: dict,
242
  mesh_simplify: float,
243
  texture_size: int,
244
  req: gr.Request,
 
247
  Extracts a GLB file from the provided 3D model state dictionary.
248
  """
249
  print(f"[extract_glb] Received request. Simplify: {mesh_simplify}, Texture Size: {texture_size}")
250
+ user_dir = None # Initialize
251
+ glb_path = None # Initialize
252
+ try:
253
+ session_hash = req.session_hash
254
+ if not session_hash:
255
+ session_hash = f"no_session_{np.random.randint(10000, 99999)}"
256
+ print(f"Warning: No session_hash in extract_glb request, using temporary ID: {session_hash}")
257
 
258
+ if not isinstance(state_dict, dict):
259
+ print("❌ [extract_glb] Error: Invalid state_dict received (not a dictionary).")
260
+ raise gr.Error("Invalid state data received. Please generate the model first.")
261
 
262
+ user_dir = os.path.join(TMP_DIR, str(session_hash))
263
+ os.makedirs(user_dir, exist_ok=True)
264
+ print(f"[extract_glb] User directory: {user_dir}")
265
 
266
+ # --- Unpack state from the dictionary ---
 
267
  gs, mesh = unpack_state(state_dict)
268
 
269
+ # --- Postprocessing and Export ---
 
270
  print("[extract_glb] Converting to GLB...")
 
271
  simplify_factor = float(mesh_simplify)
272
  tex_size = int(texture_size)
273
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=simplify_factor, texture_size=tex_size, verbose=True)
 
275
  print(f"[extract_glb] Exporting GLB to: {glb_path}")
276
  glb.export(glb_path)
277
  print("[extract_glb] GLB exported successfully.")
278
+
279
  except Exception as e:
280
+ print(f"❌ [extract_glb] Error during GLB extraction: {e}", file=sys.stderr)
281
  traceback.print_exc()
282
+ raise gr.Error(f"Failed to extract GLB: {e}") # Propagate error
283
 
284
  # --- Cleanup and Return ---
285
  if torch.cuda.is_available():
286
  torch.cuda.empty_cache()
287
  print("[extract_glb] Cleared CUDA cache.")
288
 
 
289
  print("[extract_glb] Returning GLB path.")
290
+ if glb_path is None:
291
+ print("Error: glb_path is None before return, extraction likely failed.", file=sys.stderr)
292
+ raise gr.Error("GLB path generation failed.")
293
  return glb_path, glb_path
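
The double return feeds both the Model3D preview and the GLB DownloadButton. The click handler that invokes this function is elided from the diff; per the changelog note that the `inputs=[output_buf, ...]` binding was kept, the wiring is assumed to look roughly like this:

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],  # output_buf holds the packed dict
        outputs=[model_output, download_glb],
    )
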
294
 
295
 
296
  @spaces.GPU
297
  def extract_gaussian(
298
+ state_dict: dict,
299
  req: gr.Request
300
  ) -> Tuple[str, str]:
301
  """
302
  Extracts a PLY (Gaussian) file from the provided 3D model state dictionary.
303
  """
304
  print("[extract_gaussian] Received request.")
305
+ user_dir = None # Initialize
306
+ gaussian_path = None # Initialize
307
+ try:
308
+ session_hash = req.session_hash
309
+ if not session_hash:
310
+ session_hash = f"no_session_{np.random.randint(10000, 99999)}"
311
+ print(f"Warning: No session_hash in extract_gaussian request, using temporary ID: {session_hash}")
312
 
313
+ if not isinstance(state_dict, dict):
314
+ print("❌ [extract_gaussian] Error: Invalid state_dict received (not a dictionary).")
315
+ raise gr.Error("Invalid state data received. Please generate the model first.")
316
 
317
+ user_dir = os.path.join(TMP_DIR, str(session_hash))
318
+ os.makedirs(user_dir, exist_ok=True)
319
+ print(f"[extract_gaussian] User directory: {user_dir}")
320
 
321
+ # --- Unpack state from the dictionary ---
322
+ gs, _ = unpack_state(state_dict)
323
 
324
+ # --- Export PLY ---
 
325
  gaussian_path = os.path.join(user_dir, 'sample.ply')
326
  print(f"[extract_gaussian] Saving PLY to: {gaussian_path}")
327
  gs.save_ply(gaussian_path)
328
  print("[extract_gaussian] PLY saved successfully.")
329
+
330
  except Exception as e:
331
+ print(f"❌ [extract_gaussian] Error during Gaussian extraction: {e}", file=sys.stderr)
332
  traceback.print_exc()
333
+ raise gr.Error(f"Failed to extract Gaussian PLY: {e}") # Propagate error
334
 
335
  # --- Cleanup and Return ---
336
  if torch.cuda.is_available():
337
  torch.cuda.empty_cache()
338
  print("[extract_gaussian] Cleared CUDA cache.")
339
 
 
340
  print("[extract_gaussian] Returning PLY path.")
341
+ if gaussian_path is None:
342
+ print("Error: gaussian_path is None before return, extraction likely failed.", file=sys.stderr)
343
+ raise gr.Error("Gaussian PLY path generation failed.")
344
  return gaussian_path, gaussian_path
345
 
346
 
347
  # --- Gradio UI Definition ---
348
  print("Setting up Gradio Blocks interface...")
 
349
  with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
350
  gr.Markdown("""
351
  # Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
 
356
  """)
357
 
358
  # --- State Buffer ---
 
359
  output_buf = gr.State()
360
 
361
  with gr.Row():
 
376
 
377
  generate_btn = gr.Button("Generate 3D Preview", variant="primary")
378
 
379
+ with gr.Accordion(label="GLB Extraction Settings", open=True):
380
  mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01, info="Higher value = less simplification (more polys)")
381
  texture_size = gr.Slider(512, 2048, label="Texture Size (pixels)", value=1024, step=512, info="Size of the generated texture map")
382
 
 
390
  with gr.Column(scale=1): # Output column
391
  # Video component remains for layout but won't show anything in this debug version
392
  video_output = gr.Video(label="Generated 3D Preview (DISABLED FOR DEBUG)", autoplay=False, loop=False, value=None, height=350)
393
+ model_output = gr.Model3D(label="Extracted Model Preview", height=350, clear_color=[0.95, 0.95, 0.95, 1.0])
394
 
395
  with gr.Row():
396
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
 
400
  print("Defining Gradio event handlers...")
401
 
402
  # Handle session start/end
403
+ # demo.load() is valid with inputs=None, outputs=None (though default)
404
  demo.load(start_session, inputs=None, outputs=None)
405
+ # >>> FIX: demo.unload() does NOT take inputs/outputs arguments <<<
406
+ demo.unload(end_session) # Removed inputs/outputs kwargs
407
 
408
  # --- Generate Button Click Flow ---
409
  generate_event = generate_btn.click(
 
414
  ).then(
415
  text_to_3d,
416
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
417
+ outputs=[output_buf, video_output], # state_dict -> output_buf, None -> video_output
 
418
  api_name="text_to_3d"
419
  ).then(
 
420
  lambda: (
421
  gr.Button(interactive=True),
422
  gr.Button(interactive=True),
423
  gr.DownloadButton(interactive=False),
424
  gr.DownloadButton(interactive=False)
425
  ),
426
+ inputs=None,
427
  outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
428
  )
429
 
 
457
  inputs=None,
458
  outputs=[download_glb, download_gs]
459
  )
 
460
  video_output.clear(
461
  lambda: (
462
  gr.Button(interactive=False),
 
472
 
473
 
474
  # --- Launch the Gradio app ---
 
475
  if __name__ == "__main__":
476
  print("Loading Trellis pipeline...")
477
  pipeline_loaded = False
478
+ pipeline = None # Initialize
479
  try:
 
480
  pipeline = TrellisTextTo3DPipeline.from_pretrained(
481
  "JeffreyXiang/TRELLIS-text-xlarge",
482
+ torch_dtype=torch.float16 # Use float16 if GPU supports it
 
483
  )
 
484
  if torch.cuda.is_available():
485
  pipeline = pipeline.to("cuda")
486
  print("✅ Trellis pipeline loaded successfully to GPU.")
 
491
  except Exception as e:
492
  print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
493
  traceback.print_exc()
 
494
  print("❌ Exiting due to pipeline load failure.")
495
+ sys.exit(1)
496
 
497
  if pipeline_loaded:
498
  print("Launching Gradio demo...")
499
+ # Consider increasing queue timeout if tasks are long
 
 
500
  demo.queue(
501
+ # default_concurrency_limit=2, # Limit concurrency if resource issues suspected
502
+ # status_update_rate='auto'
503
  ).launch(
504
+ # server_name="0.0.0.0", # Allows access from local network
505
+ # share=False, # Set True for public link (careful with resources)
506
+ debug=True, # Enable Gradio/FastAPI debug logs
507
+ # prevent_thread_lock=True # Might help sometimes
508
  )
509
  print("Gradio demo launched.")
510
  else: