dkatz2391 committed
Commit 1809fe4 · verified · 1 Parent(s): 63ce34f

try that again last one was missing import spaces

Files changed (1)
  1. app.py +478 -29
app.py CHANGED
@@ -1,4 +1,163 @@
-# In appTrellis_fileErrorsFIx.py
+# Version: 1.1.2 - API State Fix + DEBUG (Video Disabled) + Import Fix (2025-05-04)
+# Changes:
+# - ENSURED `import spaces` is present for the @spaces.GPU decorator.
+# - TEMPORARY DEBUGGING STEP: Commented out video rendering in `text_to_3d`
+#   and return None for video_path to isolate the "Session not found" error.
+# - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`
+#   as the first return value. This ensures the dictionary is available via the API.
+# - Modified `extract_glb` and `extract_gaussian` to accept `state_dict: dict` as their first argument
+#   instead of relying on the implicit `gr.State` object type when called via API.
+# - Kept Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`)
+#   so the UI continues to function by passing the dictionary through output_buf.
+# - Added minor safety checks and logging.
+
+import gradio as gr
+import spaces  # <<<--- ENSURE THIS IMPORT IS PRESENT
+
+import os
+import shutil
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+# Fix potential SpConv issue if needed, try 'hash' or 'native'
+# os.environ.setdefault('SPCONV_ALGO', 'native') # Use setdefault to avoid overwriting if already set
+os.environ['SPCONV_ALGO'] = 'native' # Direct set as per original
+
+from typing import *
+import torch
+import numpy as np
+import imageio
+from easydict import EasyDict as edict
+from trellis.pipelines import TrellisTextTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
+import traceback
+import sys
+
+
+MAX_SEED = np.iinfo(np.int32).max
+# Ensure TMP_DIR is correctly defined relative to the script location
+# Using /tmp/ directly might be more robust in some container environments
+# TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+TMP_DIR = '/tmp/gradio_sessions' # Use standard /tmp directory
+print(f"Using temporary directory: {TMP_DIR}")
+os.makedirs(TMP_DIR, exist_ok=True)
+
+
+def start_session(req: gr.Request):
+    """Creates a temporary directory for the user session."""
+    try:
+        session_hash = req.session_hash
+        if not session_hash:
+            # Fallback or generate a temporary ID if session_hash is missing (might happen on first load?)
+            session_hash = f"no_session_{np.random.randint(10000, 99999)}"
+            print(f"Warning: No session_hash in request, using temporary ID: {session_hash}")
+
+        user_dir = os.path.join(TMP_DIR, str(session_hash))
+        os.makedirs(user_dir, exist_ok=True)
+        print(f"Started session, created directory: {user_dir}")
+    except Exception as e:
+        print(f"Error in start_session: {e}", file=sys.stderr)
+        # Decide if this is critical - maybe raise to prevent further issues?
+
+
+def end_session(req: gr.Request):
+    """Removes the temporary directory for the user session."""
+    try:
+        session_hash = req.session_hash
+        if not session_hash:
+            print("Warning: No session_hash in end_session request, cannot clean up.")
+            return
+
+        user_dir = os.path.join(TMP_DIR, str(session_hash))
+        if os.path.exists(user_dir):
+            try:
+                shutil.rmtree(user_dir)
+                print(f"Ended session, removed directory: {user_dir}")
+            except OSError as e:
+                print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
+        else:
+            print(f"Ended session, directory already removed or hash mismatch: {user_dir}")
+    except Exception as e:
+        print(f"Error in end_session: {e}", file=sys.stderr)
+
+
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    """Packs Gaussian and Mesh data into a serializable dictionary."""
+    print("[pack_state] Packing state to dictionary...")
+    try:
+        packed_data = {
+            'gaussian': {
+                **{k: v for k, v in gs.init_params.items()}, # Ensure init_params are included
+                '_xyz': gs._xyz.detach().cpu().numpy(),
+                '_features_dc': gs._features_dc.detach().cpu().numpy(),
+                '_scaling': gs._scaling.detach().cpu().numpy(),
+                '_rotation': gs._rotation.detach().cpu().numpy(),
+                '_opacity': gs._opacity.detach().cpu().numpy(),
+            },
+            'mesh': {
+                'vertices': mesh.vertices.detach().cpu().numpy(),
+                'faces': mesh.faces.detach().cpu().numpy(),
+            },
+        }
+        print(f"[pack_state] Dictionary created. Keys: {list(packed_data.keys())}, Gaussian points: {len(packed_data['gaussian']['_xyz'])}, Mesh vertices: {len(packed_data['mesh']['vertices'])}")
+        return packed_data
+    except Exception as e:
+        print(f"Error during pack_state: {e}", file=sys.stderr)
+        traceback.print_exc()
+        raise # Re-raise the error to be caught upstream
+
+
+def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
+    """Unpacks Gaussian and Mesh data from a dictionary."""
+    print("[unpack_state] Unpacking state from dictionary...")
+    try:
+        if not isinstance(state_dict, dict) or 'gaussian' not in state_dict or 'mesh' not in state_dict:
+            raise ValueError("Invalid state_dict structure passed to unpack_state.")
+
+        # Ensure the device is correctly set when unpacking
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        print(f"[unpack_state] Using device: {device}")
+
+        gauss_data = state_dict['gaussian']
+        mesh_data = state_dict['mesh']
+
+        # Recreate Gaussian object using parameters stored during packing
+        gs = Gaussian(
+            aabb=gauss_data.get('aabb'), # Use .get for safety
+            sh_degree=gauss_data.get('sh_degree'),
+            mininum_kernel_size=gauss_data.get('mininum_kernel_size'),
+            scaling_bias=gauss_data.get('scaling_bias'),
+            opacity_bias=gauss_data.get('opacity_bias'),
+            scaling_activation=gauss_data.get('scaling_activation'),
+        )
+        # Load tensors, ensuring they are created on the correct device
+        gs._xyz = torch.tensor(gauss_data['_xyz'], device=device, dtype=torch.float32)
+        gs._features_dc = torch.tensor(gauss_data['_features_dc'], device=device, dtype=torch.float32)
+        gs._scaling = torch.tensor(gauss_data['_scaling'], device=device, dtype=torch.float32)
+        gs._rotation = torch.tensor(gauss_data['_rotation'], device=device, dtype=torch.float32)
+        gs._opacity = torch.tensor(gauss_data['_opacity'], device=device, dtype=torch.float32)
+        print(f"[unpack_state] Gaussian unpacked. Points: {gs.get_xyz.shape[0]}")
+
+        # Recreate mesh object using edict for compatibility if needed elsewhere
+        mesh = edict(
+            vertices=torch.tensor(mesh_data['vertices'], device=device, dtype=torch.float32),
+            faces=torch.tensor(mesh_data['faces'], device=device, dtype=torch.int64), # Faces are typically long/int64
+        )
+        print(f"[unpack_state] Mesh unpacked. Vertices: {mesh.vertices.shape[0]}, Faces: {mesh.faces.shape[0]}")
+
+        return gs, mesh
+    except Exception as e:
+        print(f"Error during unpack_state: {e}", file=sys.stderr)
+        traceback.print_exc()
+        raise # Re-raise the error
+
+
+def get_seed(randomize_seed: bool, seed: int) -> int:
+    """Gets a seed value, randomizing if requested."""
+    new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
+    print(f"[get_seed] Randomize: {randomize_seed}, Input Seed: {seed}, Output Seed: {new_seed}")
+    return int(new_seed) # Ensure it's a standard int
+
 
 @spaces.GPU
 def text_to_3d(
@@ -9,21 +168,27 @@ def text_to_3d(
     slat_guidance_strength: float,
     slat_sampling_steps: int,
     req: gr.Request,
-) -> Tuple[dict, str]: # Return type changed for clarity
+) -> Tuple[dict, Optional[str]]: # Return type changed Optional[str] for video path
     """
     Generates a 3D model (Gaussian and Mesh) from text and returns a
-    serializable state dictionary and a video preview path.
+    serializable state dictionary and potentially a video preview path.
+    >>> TEMPORARILY DISABLED VIDEO RENDERING FOR DEBUGGING <<<
     """
-    print(f"[text_to_3d] Received prompt: '{prompt}', Seed: {seed}")
-    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    os.makedirs(user_dir, exist_ok=True)
-    print(f"[text_to_3d] User directory: {user_dir}")
+    print(f"[text_to_3d - DEBUG MODE] Received prompt: '{prompt}', Seed: {seed}")
+    session_hash = req.session_hash
+    if not session_hash:
+        session_hash = f"no_session_{np.random.randint(10000, 99999)}" # Use consistent fallback
+        print(f"Warning: No session_hash in text_to_3d request, using temporary ID: {session_hash}")
+    user_dir = os.path.join(TMP_DIR, str(session_hash))
+    os.makedirs(user_dir, exist_ok=True) # Ensure it exists for this request
+    print(f"[text_to_3d - DEBUG MODE] User directory: {user_dir}")
 
     # --- Generation Pipeline ---
     try:
-        print("[text_to_3d] Running Trellis pipeline...")
+        print("[text_to_3d - DEBUG MODE] Running Trellis pipeline...")
+        # Add more specific pipeline settings if needed based on Trellis docs
         outputs = pipeline.run(
-            prompt,
+            prompt=prompt,
             seed=seed,
             formats=["gaussian", "mesh"], # Ensure both are generated
             sparse_structure_sampler_params={
@@ -34,54 +199,338 @@ def text_to_3d(
                 "steps": int(slat_sampling_steps), # Ensure steps are int
                 "cfg_strength": float(slat_guidance_strength),
             },
+            # device='cuda' # Explicitly specify device if needed
         )
-        print("[text_to_3d] Pipeline run completed.")
+        print("[text_to_3d - DEBUG MODE] Pipeline run completed.")
     except Exception as e:
-        print(f"❌ [text_to_3d] Pipeline error: {e}", file=sys.stderr)
+        print(f"❌ [text_to_3d - DEBUG MODE] Pipeline error: {e}", file=sys.stderr)
         traceback.print_exc()
-        raise gr.Error(f"Trellis pipeline failed: {e}")
+        raise gr.Error(f"Trellis pipeline failed during generation: {e}") # More specific error
 
     # --- Create Serializable State Dictionary ---
     try:
        state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
     except Exception as e:
-        print(f"❌ [text_to_3d] pack_state error: {e}", file=sys.stderr)
+        print(f"❌ [text_to_3d - DEBUG MODE] pack_state error: {e}", file=sys.stderr)
         traceback.print_exc()
-        raise gr.Error(f"Failed to pack state: {e}")
+        raise gr.Error(f"Failed to pack state after generation: {e}")
 
     # --- Render Video Preview (TEMPORARILY DISABLED FOR DEBUGGING) ---
-    video_path = None # Set path to None
+    video_path = None # Explicitly set path to None for this debug version
+    print("[text_to_3d - DEBUG MODE] Skipping video rendering.")
+    # --- Original Video Code Block Start (Keep commented for now) ---
     # try:
     # print("[text_to_3d] Rendering video preview...")
     # video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
     # video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
     # # Ensure video frames are uint8
     # video = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
-    # video_path = os.path.join(user_dir, 'sample.mp4')
-    # imageio.mimsave(video_path, video, fps=15, quality=8) # Added quality setting
-    # print(f"[text_to_3d] Video saved to: {video_path}")
+    # video_path_tmp = os.path.join(user_dir, 'sample.mp4') # Use temp name
+    # imageio.mimsave(video_path_tmp, video, fps=15, quality=8) # Added quality setting
+    # print(f"[text_to_3d] Video saved to: {video_path_tmp}")
+    # video_path = video_path_tmp # Assign if successful
     # except Exception as e:
    # print(f"❌ [text_to_3d] Video rendering/saving error: {e}", file=sys.stderr)
    # traceback.print_exc()
    # # Still return state_dict, but maybe signal video error? Return None for path.
    # video_path = None # Indicate video failure
-    print("[text_to_3d] Skipping video rendering for debugging.")
+    # --- Original Video Code Block End ---
+
+    # --- Cleanup and Return ---
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        print("[text_to_3d - DEBUG MODE] Cleared CUDA cache.")
+
+    # --- Return Serializable Dictionary and None Video Path ---
+    print("[text_to_3d - DEBUG MODE] Returning state dictionary and None video path.")
+    # Ensure state_dict is not None before returning
+    if state_dict is None:
+        raise gr.Error("Failed to create state dictionary.")
+    return state_dict, video_path
+
+
+@spaces.GPU(duration=120) # Increased duration slightly
+def extract_glb(
+    state_dict: dict, # <-- Accepts the dictionary directly
+    mesh_simplify: float,
+    texture_size: int,
+    req: gr.Request,
+) -> Tuple[str, str]:
+    """
+    Extracts a GLB file from the provided 3D model state dictionary.
+    """
+    print(f"[extract_glb] Received request. Simplify: {mesh_simplify}, Texture Size: {texture_size}")
+    session_hash = req.session_hash
+    if not session_hash:
+        session_hash = f"no_session_{np.random.randint(10000, 99999)}"
+        print(f"Warning: No session_hash in extract_glb request, using temporary ID: {session_hash}")
+
+    if not isinstance(state_dict, dict):
+        print("❌ [extract_glb] Error: Invalid state_dict received (not a dictionary).")
+        raise gr.Error("Invalid state data received. Please generate the model first.")
+
+    user_dir = os.path.join(TMP_DIR, str(session_hash))
+    os.makedirs(user_dir, exist_ok=True) # Ensure it exists
+    print(f"[extract_glb] User directory: {user_dir}")
+
+    # --- Unpack state from the dictionary ---
+    try:
+        gs, mesh = unpack_state(state_dict)
+    except Exception as e:
+        print(f"❌ [extract_glb] unpack_state error: {e}", file=sys.stderr)
+        traceback.print_exc()
+        raise gr.Error(f"Failed to unpack state during GLB extraction: {e}")
+
+    # --- Postprocessing and Export ---
+    try:
+        print("[extract_glb] Converting to GLB...")
+        # Ensure parameters have correct types
+        simplify_factor = float(mesh_simplify)
+        tex_size = int(texture_size)
+        glb = postprocessing_utils.to_glb(gs, mesh, simplify=simplify_factor, texture_size=tex_size, verbose=True)
+        glb_path = os.path.join(user_dir, 'sample.glb')
+        print(f"[extract_glb] Exporting GLB to: {glb_path}")
+        glb.export(glb_path)
+        print("[extract_glb] GLB exported successfully.")
+    except Exception as e:
+        print(f"❌ [extract_glb] GLB conversion/export error: {e}", file=sys.stderr)
+        traceback.print_exc()
+        raise gr.Error(f"Failed to extract GLB: {e}")
+
+    # --- Cleanup and Return ---
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        print("[extract_glb] Cleared CUDA cache.")
+
+    # Return path twice for both Model3D and DownloadButton components
+    print("[extract_glb] Returning GLB path.")
+    # Ensure path is returned, even if export failed somehow (though error should raise first)
+    return glb_path, glb_path
+
+
+@spaces.GPU
+def extract_gaussian(
+    state_dict: dict, # <-- Accepts the dictionary directly
+    req: gr.Request
+) -> Tuple[str, str]:
+    """
+    Extracts a PLY (Gaussian) file from the provided 3D model state dictionary.
+    """
+    print("[extract_gaussian] Received request.")
+    session_hash = req.session_hash
+    if not session_hash:
+        session_hash = f"no_session_{np.random.randint(10000, 99999)}"
+        print(f"Warning: No session_hash in extract_gaussian request, using temporary ID: {session_hash}")
+
+    if not isinstance(state_dict, dict):
+        print("❌ [extract_gaussian] Error: Invalid state_dict received (not a dictionary).")
+        raise gr.Error("Invalid state data received. Please generate the model first.")
+
+    user_dir = os.path.join(TMP_DIR, str(session_hash))
+    os.makedirs(user_dir, exist_ok=True) # Ensure it exists
+    print(f"[extract_gaussian] User directory: {user_dir}")
+
+    # --- Unpack state from the dictionary ---
+    try:
+        gs, _ = unpack_state(state_dict) # Only need Gaussian part
+    except Exception as e:
+        print(f"❌ [extract_gaussian] unpack_state error: {e}", file=sys.stderr)
+        traceback.print_exc()
+        raise gr.Error(f"Failed to unpack state during Gaussian extraction: {e}")
+
+    # --- Export PLY ---
+    try:
+        gaussian_path = os.path.join(user_dir, 'sample.ply')
+        print(f"[extract_gaussian] Saving PLY to: {gaussian_path}")
+        gs.save_ply(gaussian_path)
+        print("[extract_gaussian] PLY saved successfully.")
+    except Exception as e:
+        print(f"❌ [extract_gaussian] PLY saving error: {e}", file=sys.stderr)
+        traceback.print_exc()
+        raise gr.Error(f"Failed to extract Gaussian PLY: {e}")
 
     # --- Cleanup and Return ---
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        print("[text_to_3d] Cleared CUDA cache.")
+        print("[extract_gaussian] Cleared CUDA cache.")
+
+    # Return path twice for both Model3D and DownloadButton components
+    print("[extract_gaussian] Returning PLY path.")
+    # Ensure path is returned
+    return gaussian_path, gaussian_path
 
-    # --- Return Serializable Dictionary and None for Video Path ---
-    print("[text_to_3d] Returning state dictionary and None video path.")
-    return state_dict, video_path # Return dict and None video path
 
 # --- Gradio UI Definition ---
-# ... (rest of the file is the same, but you might want to adjust the output mapping if needed)
+print("Setting up Gradio Blocks interface...")
+# Define the interface layout
+with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
+    gr.Markdown("""
+    # Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
+    * Type a text prompt and click "Generate" to create a 3D asset preview.
+    * Adjust extraction settings if desired.
+    * Click "Extract GLB" or "Extract Gaussian" to get the downloadable 3D file.
+    *(Note: Video preview is temporarily disabled for debugging)*
+    """)
+
+    # --- State Buffer ---
+    # This hidden component holds the dictionary linking generation and extraction.
+    output_buf = gr.State()
+
+    with gr.Row():
+        with gr.Column(scale=1): # Input column
+            text_prompt = gr.Textbox(label="Text Prompt", lines=5, placeholder="e.g., a cute red dragon")
+
+            with gr.Accordion(label="Generation Settings", open=False):
+                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                gr.Markdown("--- \n **Stage 1: Sparse Structure Generation**")
+                with gr.Row():
+                    ss_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
+                    ss_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
+                gr.Markdown("--- \n **Stage 2: Structured Latent Generation**")
+                with gr.Row():
+                    slat_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
+                    slat_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
+
+            generate_btn = gr.Button("Generate 3D Preview", variant="primary")
+
+            with gr.Accordion(label="GLB Extraction Settings", open=True): # Open by default
+                mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01, info="Higher value = less simplification (more polys)")
+                texture_size = gr.Slider(512, 2048, label="Texture Size (pixels)", value=1024, step=512, info="Size of the generated texture map")
+
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+                extract_gs_btn = gr.Button("Extract Gaussian (PLY)", interactive=False)
+            gr.Markdown("""
+            *NOTE: Gaussian file (.ply) can be very large (~50MB+) and may take time to process/download.*
+            """)
+
+        with gr.Column(scale=1): # Output column
+            # Video component remains for layout but won't show anything in this debug version
+            video_output = gr.Video(label="Generated 3D Preview (DISABLED FOR DEBUG)", autoplay=False, loop=False, value=None, height=350)
+            model_output = gr.Model3D(label="Extracted Model Preview", height=350, clear_color=[0.95, 0.95, 0.95, 1.0]) # Light background
+
+            with gr.Row():
+                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                download_gs = gr.DownloadButton(label="Download Gaussian (PLY)", interactive=False)
+
+    # --- Event Handlers ---
+    print("Defining Gradio event handlers...")
 
-# In the generate_btn.click handler, adjust the outputs if the video component causes issues:
-# Option 1: Keep Video component, it will just show nothing.
-# outputs=[output_buf, video_output], # This might be fine
+    # Handle session start/end
+    demo.load(start_session, inputs=None, outputs=None)
+    demo.unload(end_session, inputs=None, outputs=None)
 
-# Option 2: Use a dummy hidden component if video_output causes issues receiving None
-# outputs=[output_buf, gr.Textbox(visible=False)], # Example dummy
+    # --- Generate Button Click Flow ---
+    generate_event = generate_btn.click(
+        get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
+        api_name="get_seed"
+    ).then(
+        text_to_3d,
+        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        # Output state_dict to buffer, output None to video component
+        outputs=[output_buf, video_output],
+        api_name="text_to_3d"
+    ).then(
+        # Function to update button interactivity after generation attempt
+        lambda: (
+            gr.Button(interactive=True),
+            gr.Button(interactive=True),
+            gr.DownloadButton(interactive=False),
+            gr.DownloadButton(interactive=False)
+        ),
+        inputs=None, # No inputs needed for the lambda
+        outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
+    )
+
+    # --- Extract GLB Button Click Flow ---
+    extract_glb_event = extract_glb_btn.click(
+        extract_glb,
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+        api_name="extract_glb"
+    ).then(
+        lambda: gr.DownloadButton(interactive=True),
+        inputs=None,
+        outputs=[download_glb],
+    )
+
+    # --- Extract Gaussian Button Click Flow ---
+    extract_gs_event = extract_gs_btn.click(
+        extract_gaussian,
+        inputs=[output_buf],
+        outputs=[model_output, download_gs],
+        api_name="extract_gaussian"
+    ).then(
+        lambda: gr.DownloadButton(interactive=True),
+        inputs=None,
+        outputs=[download_gs],
+    )
+
+    # --- Clear Download Button Interactivity when model preview is cleared ---
+    model_output.clear(
+        lambda: (gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)),
+        inputs=None,
+        outputs=[download_glb, download_gs]
+    )
+    # Also disable buttons if the (currently disabled) video output is cleared
+    video_output.clear(
+        lambda: (
+            gr.Button(interactive=False),
+            gr.Button(interactive=False),
+            gr.DownloadButton(interactive=False),
+            gr.DownloadButton(interactive=False)
+        ),
+        inputs=None,
+        outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
+    )
+
+    print("Gradio interface setup complete.")
+
+
+# --- Launch the Gradio app ---
+# Main execution block
+if __name__ == "__main__":
+    print("Loading Trellis pipeline...")
+    pipeline_loaded = False
+    try:
+        # Ensure model/variant matches requirements, use revision if needed
+        pipeline = TrellisTextTo3DPipeline.from_pretrained(
+            "JeffreyXiang/TRELLIS-text-xlarge",
+            # revision="main", # Specify if needed
+            torch_dtype=torch.float16 # Use float16 if GPU supports it for less memory
+        )
+        # Move to GPU if available
+        if torch.cuda.is_available():
+            pipeline = pipeline.to("cuda")
+            print("✅ Trellis pipeline loaded successfully to GPU.")
+        else:
+            print("⚠️ WARNING: CUDA not available, running on CPU (will be very slow).")
+            print("✅ Trellis pipeline loaded successfully to CPU.")
+        pipeline_loaded = True
+    except Exception as e:
+        print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
+        traceback.print_exc()
+        # Exit if pipeline is critical for the app to run
+        print("❌ Exiting due to pipeline load failure.")
+        sys.exit(1) # Exit if pipeline fails
+
+    if pipeline_loaded:
+        print("Launching Gradio demo...")
+        # Set share=True if you need a public link (e.g., for testing from outside local network)
+        # Set server_name="0.0.0.0" to allow access from local network IP
+        # Increased concurrency_limit and timeout for queue might help
+        demo.queue(
+            # default_concurrency_limit=5, # Adjust based on expected load and space resources
+            # api_open=True # Keep API accessible
+        ).launch(
+            # server_name="0.0.0.0", # Make accessible on local network
+            # share=False, # Set to True for public link if needed
+            debug=True, # Enable Gradio debug mode for more detailed logs
+            # prevent_thread_lock=True # May help with async issues in some cases
+        )
+        print("Gradio demo launched.")
+    else:
+        print("Gradio demo not launched due to pipeline loading failure.")