dkatz2391 committed on
Commit
2f98862
·
verified ·
1 Parent(s): ce32001
Files changed (1) hide show
  1. app.py +104 -328
app.py CHANGED
@@ -1,22 +1,17 @@
1
- # Version: 1.1.0 - API State Fix (2025-05-04)
2
- # Changes:
3
- # - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`
4
- # as the first return value. This ensures the dictionary is available via the API.
5
- # - Modified `extract_glb` and `extract_gaussian` to accept `state_dict: dict` as their first argument
6
- # instead of relying on the implicit `gr.State` object type when called via API.
7
- # - Kept Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`)
8
- # so the UI continues to function by passing the dictionary through output_buf.
9
- # - Added minor safety checks and logging.
10
 
11
  import gradio as gr
12
  import spaces
13
-
14
  import os
15
  import shutil
16
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
17
- # Fix potential SpConv issue if needed, try 'hash' or 'native'
18
- # os.environ.setdefault('SPCONV_ALGO', 'native') # Use setdefault to avoid overwriting if already set
19
- os.environ['SPCONV_ALGO'] = 'native' # Direct set as per original
20
 
21
  from typing import *
22
  import torch
@@ -26,105 +21,80 @@ from easydict import EasyDict as edict
26
  from trellis.pipelines import TrellisTextTo3DPipeline
27
  from trellis.representations import Gaussian, MeshExtractResult
28
  from trellis.utils import render_utils, postprocessing_utils
29
-
30
  import traceback
31
  import sys
32
 
33
-
34
  MAX_SEED = np.iinfo(np.int32).max
35
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
36
  os.makedirs(TMP_DIR, exist_ok=True)
37
 
38
 
39
  def start_session(req: gr.Request):
40
- """Creates a temporary directory for the user session."""
41
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
42
  os.makedirs(user_dir, exist_ok=True)
43
  print(f"Started session, created directory: {user_dir}")
44
 
45
 
46
  def end_session(req: gr.Request):
47
- """Removes the temporary directory for the user session."""
48
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
49
  if os.path.exists(user_dir):
50
- try:
51
- shutil.rmtree(user_dir)
52
- print(f"Ended session, removed directory: {user_dir}")
53
- except OSError as e:
54
- print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
55
  else:
56
  print(f"Ended session, directory already removed: {user_dir}")
57
 
58
 
59
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
60
- """Packs Gaussian and Mesh data into a serializable dictionary."""
61
- # Ensure tensors are on CPU and converted to numpy before returning the dict
62
- print("[pack_state] Packing state to dictionary...")
63
  packed_data = {
64
  'gaussian': {
65
- # Spread init_params first to ensure correct types
66
- **{k: v for k, v in gs.init_params.items()}, # Ensure init_params are included
67
- '_xyz': gs._xyz.detach().cpu().numpy(),
68
- '_features_dc': gs._features_dc.detach().cpu().numpy(),
69
- '_scaling': gs._scaling.detach().cpu().numpy(),
70
- '_rotation': gs._rotation.detach().cpu().numpy(),
71
- '_opacity': gs._opacity.detach().cpu().numpy(),
72
  },
73
  'mesh': {
74
- 'vertices': mesh.vertices.detach().cpu().numpy(),
75
- 'faces': mesh.faces.detach().cpu().numpy(),
76
  },
77
  }
78
- print(f"[pack_state] Dictionary created. Keys: {list(packed_data.keys())}, Gaussian points: {len(packed_data['gaussian']['_xyz'])}, Mesh vertices: {len(packed_data['mesh']['vertices'])}")
79
  return packed_data
80
 
81
 
82
  def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
83
- """Unpacks Gaussian and Mesh data from a dictionary."""
84
- print("[unpack_state] Unpacking state from dictionary...")
85
- if not isinstance(state_dict, dict) or 'gaussian' not in state_dict or 'mesh' not in state_dict:
86
- raise ValueError("Invalid state_dict structure passed to unpack_state.")
87
-
88
- # Ensure the device is correctly set when unpacking
89
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
90
- print(f"[unpack_state] Using device: {device}")
91
-
92
  gauss_data = state_dict['gaussian']
93
  mesh_data = state_dict['mesh']
94
-
95
- # Recreate Gaussian object using parameters stored during packing
96
  gs = Gaussian(
97
- aabb=gauss_data.get('aabb'), # Use .get for safety
98
  sh_degree=gauss_data.get('sh_degree'),
99
  mininum_kernel_size=gauss_data.get('mininum_kernel_size'),
100
  scaling_bias=gauss_data.get('scaling_bias'),
101
  opacity_bias=gauss_data.get('opacity_bias'),
102
  scaling_activation=gauss_data.get('scaling_activation'),
103
  )
104
- # Load tensors, ensuring they are created on the correct device
105
- gs._xyz = torch.tensor(gauss_data['_xyz'], device=device, dtype=torch.float32)
106
- gs._features_dc = torch.tensor(gauss_data['_features_dc'], device=device, dtype=torch.float32)
107
- gs._scaling = torch.tensor(gauss_data['_scaling'], device=device, dtype=torch.float32)
108
- gs._rotation = torch.tensor(gauss_data['_rotation'], device=device, dtype=torch.float32)
109
- gs._opacity = torch.tensor(gauss_data['_opacity'], device=device, dtype=torch.float32)
110
- print(f"[unpack_state] Gaussian unpacked. Points: {gs.get_xyz.shape[0]}")
111
-
112
- # Recreate mesh object using edict for compatibility if needed elsewhere
113
  mesh = edict(
114
- vertices=torch.tensor(mesh_data['vertices'], device=device, dtype=torch.float32),
115
- faces=torch.tensor(mesh_data['faces'], device=device, dtype=torch.int64), # Faces are typically long/int64
116
  )
117
- print(f"[unpack_state] Mesh unpacked. Vertices: {mesh.vertices.shape[0]}, Faces: {mesh.faces.shape[0]}")
118
-
119
  return gs, mesh
120
 
121
 
122
  def get_seed(randomize_seed: bool, seed: int) -> int:
123
- """Gets a seed value, randomizing if requested."""
124
  new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
125
- print(f"[get_seed] Randomize: {randomize_seed}, Input Seed: {seed}, Output Seed: {new_seed}")
126
- return int(new_seed) # Ensure it's a standard int
127
-
128
 
129
  @spaces.GPU
130
  def text_to_3d(
@@ -135,331 +105,137 @@ def text_to_3d(
135
  slat_guidance_strength: float,
136
  slat_sampling_steps: int,
137
  req: gr.Request,
138
- ) -> Tuple[dict, str]: # Return type changed for clarity
139
- """
140
- Generates a 3D model (Gaussian and Mesh) from text and returns a
141
- serializable state dictionary and a video preview path.
142
- """
143
- print(f"[text_to_3d] Received prompt: '{prompt}', Seed: {seed}")
 
 
 
 
 
 
144
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
145
  os.makedirs(user_dir, exist_ok=True)
146
- print(f"[text_to_3d] User directory: {user_dir}")
147
-
148
- # --- Generation Pipeline ---
149
- try:
150
- print("[text_to_3d] Running Trellis pipeline...")
151
- outputs = pipeline.run(
152
- prompt,
153
- seed=seed,
154
- formats=["gaussian", "mesh"], # Ensure both are generated
155
- sparse_structure_sampler_params={
156
- "steps": int(ss_sampling_steps), # Ensure steps are int
157
- "cfg_strength": float(ss_guidance_strength),
158
- },
159
- slat_sampler_params={
160
- "steps": int(slat_sampling_steps), # Ensure steps are int
161
- "cfg_strength": float(slat_guidance_strength),
162
- },
163
- )
164
- print("[text_to_3d] Pipeline run completed.")
165
- except Exception as e:
166
- print(f"❌ [text_to_3d] Pipeline error: {e}", file=sys.stderr)
167
- traceback.print_exc()
168
- # Return an empty dict and maybe an error indicator path or None?
169
- # For now, re-raise to signal failure clearly upstream.
170
- raise gr.Error(f"Trellis pipeline failed: {e}")
171
-
172
- # --- Create Serializable State Dictionary --- VITAL CHANGE for API
173
- # This dictionary holds the necessary data for later extraction.
174
- try:
175
- state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
176
- except Exception as e:
177
- print(f"❌ [text_to_3d] pack_state error: {e}", file=sys.stderr)
178
- traceback.print_exc()
179
- raise gr.Error(f"Failed to pack state: {e}")
180
-
181
- # --- Render Video Preview ---
182
- try:
183
- print("[text_to_3d] Rendering video preview...")
184
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
185
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
186
- # Ensure video frames are uint8
187
- video = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
188
- video_path = os.path.join(user_dir, 'sample.mp4')
189
- imageio.mimsave(video_path, video, fps=15, quality=8) # Added quality setting
190
- print(f"[text_to_3d] Video saved to: {video_path}")
191
- except Exception as e:
192
- print(f"❌ [text_to_3d] Video rendering/saving error: {e}", file=sys.stderr)
193
- traceback.print_exc()
194
- # Still return state_dict, but maybe signal video error? Return None for path.
195
- video_path = None # Indicate video failure
196
-
197
- # --- Cleanup and Return ---
198
- # Clear CUDA cache if GPU was used
199
- if torch.cuda.is_available():
200
- torch.cuda.empty_cache()
201
- print("[text_to_3d] Cleared CUDA cache.")
202
-
203
- # --- Return Serializable Dictionary and Video Path --- VITAL CHANGE for API
204
- print("[text_to_3d] Returning state dictionary and video path.")
205
  return state_dict, video_path
206
 
207
-
208
- @spaces.GPU(duration=120) # Increased duration slightly
209
  def extract_glb(
210
- state_dict: dict, # <-- VITAL CHANGE: Accept the dictionary directly
211
  mesh_simplify: float,
212
  texture_size: int,
213
  req: gr.Request,
214
  ) -> Tuple[str, str]:
215
- """
216
- Extracts a GLB file from the provided 3D model state dictionary.
217
- """
218
- print(f"[extract_glb] Received request. Simplify: {mesh_simplify}, Texture Size: {texture_size}")
219
- if not isinstance(state_dict, dict):
220
- print("❌ [extract_glb] Error: Invalid state_dict received (not a dictionary).")
221
- raise gr.Error("Invalid state data received. Please generate the model first.")
222
-
223
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
224
  os.makedirs(user_dir, exist_ok=True)
225
- print(f"[extract_glb] User directory: {user_dir}")
226
-
227
- # --- Unpack state from the dictionary --- VITAL CHANGE for API
228
- try:
229
- gs, mesh = unpack_state(state_dict)
230
- except Exception as e:
231
- print(f"❌ [extract_glb] unpack_state error: {e}", file=sys.stderr)
232
- traceback.print_exc()
233
- raise gr.Error(f"Failed to unpack state: {e}")
234
-
235
- # --- Postprocessing and Export ---
236
- try:
237
- print("[extract_glb] Converting to GLB...")
238
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=float(mesh_simplify), texture_size=int(texture_size), verbose=True) # Verbose for debugging
239
- glb_path = os.path.join(user_dir, 'sample.glb')
240
- print(f"[extract_glb] Exporting GLB to: {glb_path}")
241
- glb.export(glb_path)
242
- print("[extract_glb] GLB exported successfully.")
243
- except Exception as e:
244
- print(f"❌ [extract_glb] GLB conversion/export error: {e}", file=sys.stderr)
245
- traceback.print_exc()
246
- raise gr.Error(f"Failed to extract GLB: {e}")
247
-
248
- # --- Cleanup and Return ---
249
- if torch.cuda.is_available():
250
- torch.cuda.empty_cache()
251
- print("[extract_glb] Cleared CUDA cache.")
252
-
253
- # Return path twice for both Model3D and DownloadButton components
254
- print("[extract_glb] Returning GLB path.")
255
  return glb_path, glb_path
256
 
257
-
258
  @spaces.GPU
259
  def extract_gaussian(
260
- state_dict: dict, # <-- VITAL CHANGE: Accept the dictionary directly
261
  req: gr.Request
262
  ) -> Tuple[str, str]:
263
- """
264
- Extracts a PLY (Gaussian) file from the provided 3D model state dictionary.
265
- """
266
- print("[extract_gaussian] Received request.")
267
- if not isinstance(state_dict, dict):
268
- print("❌ [extract_gaussian] Error: Invalid state_dict received (not a dictionary).")
269
- raise gr.Error("Invalid state data received. Please generate the model first.")
270
-
271
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
272
  os.makedirs(user_dir, exist_ok=True)
273
- print(f"[extract_gaussian] User directory: {user_dir}")
274
-
275
- # --- Unpack state from the dictionary --- VITAL CHANGE for API
276
- try:
277
- gs, _ = unpack_state(state_dict) # Only need Gaussian part
278
- except Exception as e:
279
- print(f"❌ [extract_gaussian] unpack_state error: {e}", file=sys.stderr)
280
- traceback.print_exc()
281
- raise gr.Error(f"Failed to unpack state: {e}")
282
-
283
- # --- Export PLY ---
284
- try:
285
- gaussian_path = os.path.join(user_dir, 'sample.ply')
286
- print(f"[extract_gaussian] Saving PLY to: {gaussian_path}")
287
- gs.save_ply(gaussian_path)
288
- print("[extract_gaussian] PLY saved successfully.")
289
- except Exception as e:
290
- print(f"❌ [extract_gaussian] PLY saving error: {e}", file=sys.stderr)
291
- traceback.print_exc()
292
- raise gr.Error(f"Failed to extract Gaussian PLY: {e}")
293
-
294
- # --- Cleanup and Return ---
295
- if torch.cuda.is_available():
296
- torch.cuda.empty_cache()
297
- print("[extract_gaussian] Cleared CUDA cache.")
298
-
299
- # Return path twice for both Model3D and DownloadButton components
300
- print("[extract_gaussian] Returning PLY path.")
301
  return gaussian_path, gaussian_path
302
 
303
-
304
  # --- Gradio UI Definition ---
305
- print("Setting up Gradio Blocks interface...")
306
  with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
307
  gr.Markdown("""
308
  # Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
309
- * Type a text prompt and click "Generate" to create a 3D asset preview.
310
- * Adjust extraction settings if desired.
311
- * Click "Extract GLB" or "Extract Gaussian" to get the downloadable 3D file.
312
  """)
313
 
314
- # --- State Buffer ---
315
- # This hidden component will hold the dictionary returned by text_to_3d,
316
- # acting as the state link between generation and extraction for the UI/API.
317
  output_buf = gr.State()
318
 
319
  with gr.Row():
320
- with gr.Column(scale=1): # Input column
321
- text_prompt = gr.Textbox(label="Text Prompt", lines=5, placeholder="e.g., a cute red dragon")
322
-
323
  with gr.Accordion(label="Generation Settings", open=False):
324
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
325
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
326
- gr.Markdown("--- \n **Stage 1: Sparse Structure Generation**")
327
- with gr.Row():
328
- ss_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
329
- ss_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
330
- gr.Markdown("--- \n **Stage 2: Structured Latent Generation**")
331
- with gr.Row():
332
- slat_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
333
- slat_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
334
-
335
  generate_btn = gr.Button("Generate 3D Preview", variant="primary")
336
-
337
- with gr.Accordion(label="GLB Extraction Settings", open=True): # Open by default
338
- # Tooltips added for clarity
339
- mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01, info="Higher value = less simplification (more polys)")
340
- texture_size = gr.Slider(512, 2048, label="Texture Size (pixels)", value=1024, step=512, info="Size of the generated texture map")
341
-
342
- with gr.Row():
343
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
344
- extract_gs_btn = gr.Button("Extract Gaussian (PLY)", interactive=False)
345
- gr.Markdown("""
346
- *NOTE: Gaussian file (.ply) can be very large (~50MB+) and may take time to process/download.*
347
- """)
348
-
349
- with gr.Column(scale=1): # Output column
350
- video_output = gr.Video(label="Generated 3D Preview (Geometry | Texture)", autoplay=True, loop=True, height=350) # Slightly larger height
351
- model_output = gr.Model3D(label="Extracted Model Preview", height=350, clear_color=[0.95, 0.95, 0.95, 1.0]) # Light background
352
-
353
- with gr.Row():
354
- # Link download button visibility/interactivity to model_output potentially
355
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
356
- download_gs = gr.DownloadButton(label="Download Gaussian (PLY)", interactive=False)
357
-
358
- # --- Event Handlers ---
359
- print("Defining Gradio event handlers...")
360
-
361
- # Handle session start/end
362
- demo.load(start_session)
363
- demo.unload(end_session)
364
-
365
- # --- Generate Button Click Flow ---
366
- # 1. Get Seed -> 2. Run text_to_3d -> 3. Enable extraction buttons
367
  generate_event = generate_btn.click(
368
  get_seed,
369
  inputs=[randomize_seed, seed],
370
  outputs=[seed],
371
- api_name="get_seed" # Optional API name
372
  ).then(
373
  text_to_3d,
374
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
375
- outputs=[output_buf, video_output], # output_buf receives state_dict
376
- api_name="text_to_3d"
377
  ).then(
378
- lambda: ( # Return tuple for multiple outputs
379
- gr.Button(interactive=True),
380
- gr.Button(interactive=True),
381
- gr.DownloadButton(interactive=False), # Ensure download buttons are disabled initially
382
- gr.DownloadButton(interactive=False)
383
- ),
384
- outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs], # Update interactivity
385
  )
386
 
387
- # --- Clear video/model outputs if prompt changes (optional, prevents confusion)
388
- # text_prompt.change(lambda: (None, None, gr.Button(interactive=False), gr.Button(interactive=False)), outputs=[video_output, model_output, extract_glb_btn, extract_gs_btn])
389
-
390
- # --- Extract GLB Button Click Flow ---
391
- # 1. Run extract_glb -> 2. Update Model3D and Download Button
392
  extract_glb_event = extract_glb_btn.click(
393
  extract_glb,
394
- inputs=[output_buf, mesh_simplify, texture_size], # Pass the state_dict via output_buf
395
- outputs=[model_output, download_glb], # Returns path to both
396
- api_name="extract_glb"
397
  ).then(
398
- lambda: gr.DownloadButton(interactive=True), # Enable download button
399
  outputs=[download_glb],
400
  )
401
 
402
- # --- Extract Gaussian Button Click Flow ---
403
- # 1. Run extract_gaussian -> 2. Update Model3D and Download Button
404
  extract_gs_event = extract_gs_btn.click(
405
  extract_gaussian,
406
- inputs=[output_buf], # Pass the state_dict via output_buf
407
- outputs=[model_output, download_gs], # Returns path to both
408
- api_name="extract_gaussian"
409
  ).then(
410
- lambda: gr.DownloadButton(interactive=True), # Enable download button
411
  outputs=[download_gs],
412
  )
413
 
414
- # --- Clear Download Button Interactivity when model preview is cleared ---
415
- # This might be redundant if generate disables them, but adds safety
416
  model_output.clear(
417
- lambda: (gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)),
418
- outputs=[download_glb, download_gs]
419
  )
420
- video_output.clear( # Also disable extraction if video is cleared (e.g., new generation starts)
421
- lambda: (
422
- gr.Button(interactive=False),
423
- gr.Button(interactive=False),
424
- gr.DownloadButton(interactive=False),
425
- gr.DownloadButton(interactive=False)
426
- ),
427
  outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
428
  )
429
 
430
- print("Gradio interface setup complete.")
431
-
432
-
433
- # --- Launch the Gradio app ---
434
  if __name__ == "__main__":
435
- print("Loading Trellis pipeline...")
436
- try:
437
- # Ensure model/variant matches requirements, use revision if needed
438
- pipeline = TrellisTextTo3DPipeline.from_pretrained(
439
- "JeffreyXiang/TRELLIS-text-xlarge",
440
- # revision="main", # Specify if needed
441
- torch_dtype=torch.float16 # Use float16 if GPU supports it for less memory
442
- )
443
- # Move to GPU if available
444
- if torch.cuda.is_available():
445
- pipeline = pipeline.to("cuda")
446
- print("✅ Trellis pipeline loaded successfully to GPU.")
447
- else:
448
- print("⚠️ WARNING: CUDA not available, running on CPU (will be very slow).")
449
- print("✅ Trellis pipeline loaded successfully to CPU.")
450
- except Exception as e:
451
- print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
452
- traceback.print_exc()
453
- # Exit if pipeline is critical for the app to run
454
- print("❌ Exiting due to pipeline load failure.")
455
- sys.exit(1)
456
-
457
- print("Launching Gradio demo...")
458
- # Set share=True if you need a public link (e.g., for testing from outside local network)
459
- # Set server_name="0.0.0.0" to allow access from local network IP
460
- demo.queue().launch( # Use queue for potentially long-running tasks
461
- # server_name="0.0.0.0",
462
- # share=False,
463
- debug=True # Enable debug mode for more logs
464
  )
465
- print("Gradio demo launched.")
 
 
1
+ # Version: 1.1.1 - Targeted signature and state fixes
2
+ # Applied:
3
+ # - Removed unsupported inputs/outputs kwargs on demo.load/unload
4
+ # - Converted NumPy arrays to lists in pack_state for JSON safety
5
+ # - Fixed indentation in Blocks event-handlers
6
+ # - Verified clear() callbacks use only callback + outputs
7
+ # - Bumped version, added comments at change sites
 
 
8
 
9
  import gradio as gr
10
  import spaces
 
11
  import os
12
  import shutil
13
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
14
+ os.environ['SPCONV_ALGO'] = 'native'
 
 
15
 
16
  from typing import *
17
  import torch
 
21
  from trellis.pipelines import TrellisTextTo3DPipeline
22
  from trellis.representations import Gaussian, MeshExtractResult
23
  from trellis.utils import render_utils, postprocessing_utils
 
24
  import traceback
25
  import sys
26
 
 
27
  MAX_SEED = np.iinfo(np.int32).max
28
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
29
  os.makedirs(TMP_DIR, exist_ok=True)
30
 
31
 
def start_session(req: gr.Request):
    """Create the per-session scratch directory under ``TMP_DIR``.

    The directory is named after ``req.session_hash``; an already existing
    directory is left as it is.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    print(f"Started session, created directory: {session_dir}")
36
 
37
 
def end_session(req: gr.Request):
    """Delete the per-session scratch directory if it still exists.

    Removal is best-effort: an ``OSError`` raised while deleting is reported
    on stderr instead of being propagated.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))

    # Nothing to clean up: just report and leave.
    if not os.path.exists(session_dir):
        print(f"Ended session, directory already removed: {session_dir}")
        return

    try:
        shutil.rmtree(session_dir)
        print(f"Ended session, removed directory: {session_dir}")
    except OSError as err:
        print(f"Error removing tmp directory {session_dir}: {err.strerror}", file=sys.stderr)
48
 
49
 
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Flatten a Gaussian and a mesh into a dictionary of plain Python data.

    Every tensor is detached, moved to the CPU and converted to nested lists.
    The entries of ``gs.init_params`` are copied unchanged.

    Returns:
        ``{'gaussian': {...}, 'mesh': {...}}``. The 'gaussian' part holds the
        init params followed by '_xyz', '_features_dc', '_scaling',
        '_rotation' and '_opacity'; the 'mesh' part holds 'vertices' and
        'faces'.
    """
    def as_nested_list(tensor):
        return tensor.detach().cpu().numpy().tolist()

    # Init params go in first so the tensor entries below win on a key clash.
    gaussian_part = dict(gs.init_params.items())
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_part[attr] = as_nested_list(getattr(gs, attr))

    mesh_part = {
        field: as_nested_list(getattr(mesh, field))
        for field in ('vertices', 'faces')
    }
    return {'gaussian': gaussian_part, 'mesh': mesh_part}
68
 
69
 
def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
    """Rebuild a Gaussian and a mesh from a packed state dictionary.

    Tensors are created on CUDA when it is available, otherwise on the CPU.
    The Gaussian attributes and the mesh vertices become float32 tensors, the
    mesh faces an int64 tensor.

    Args:
        state_dict: Mapping with a 'gaussian' and a 'mesh' entry.

    Returns:
        ``(gs, mesh)`` where ``mesh`` is an ``edict`` exposing ``vertices``
        and ``faces``.

    Raises:
        KeyError: If 'gaussian', 'mesh' or one of the tensor entries is missing.
    """
    print("[unpack_state] Unpacking state from dictionary... ")
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    gaussian_part, mesh_part = state_dict['gaussian'], state_dict['mesh']

    def to_tensor(values, dtype):
        return torch.tensor(np.array(values), device=target, dtype=dtype)

    # Constructor arguments are optional in the dictionary: absent keys pass None.
    ctor_keys = (
        'aabb',
        'sh_degree',
        'mininum_kernel_size',
        'scaling_bias',
        'opacity_bias',
        'scaling_activation',
    )
    gs = Gaussian(**{key: gaussian_part.get(key) for key in ctor_keys})

    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, to_tensor(gaussian_part[attr], torch.float32))

    mesh = edict(
        vertices=to_tensor(mesh_part['vertices'], torch.float32),
        faces=to_tensor(mesh_part['faces'], torch.int64),
    )
    return gs, mesh
93
 
94
 
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use for generation as a plain ``int``.

    Args:
        randomize_seed: When true, draw a fresh seed from ``[0, MAX_SEED)``.
        seed: The seed to use when no randomization is requested.
    """
    if randomize_seed:
        return int(np.random.randint(0, MAX_SEED))
    return int(seed)
 
 
98
 
99
  @spaces.GPU
100
  def text_to_3d(
 
105
  slat_guidance_strength: float,
106
  slat_sampling_steps: int,
107
  req: gr.Request,
108
+ ) -> Tuple[dict, str]:
109
+ outputs = pipeline.run(
110
+ prompt,
111
+ seed=seed,
112
+ formats=["gaussian", "mesh"],
113
+ sparse_structure_sampler_params={"steps": int(ss_sampling_steps), "cfg_strength": float(ss_guidance_strength)},
114
+ slat_sampler_params={"steps": int(slat_sampling_steps), "cfg_strength": float(slat_guidance_strength)},
115
+ )
116
+ state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
117
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
118
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
119
+ video_combined = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
120
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
121
  os.makedirs(user_dir, exist_ok=True)
122
+ video_path = os.path.join(user_dir, 'sample.mp4')
123
+ imageio.mimsave(video_path, video_combined, fps=15, quality=8)
124
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  return state_dict, video_path
126
 
@spaces.GPU(duration=120)
def extract_glb(
    state_dict: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """Export the packed model state as a GLB file in the session directory.

    Args:
        state_dict: Packed state, handed to ``unpack_state``.
        mesh_simplify: Forwarded to ``to_glb`` as ``simplify`` (cast to float).
        texture_size: Forwarded to ``to_glb`` as ``texture_size`` (cast to int).
        req: Gradio request; its ``session_hash`` names the output directory.

    Returns:
        The path of the written ``sample.glb``, twice (one value per output
        component the handler is bound to).
    """
    gaussian, mesh = unpack_state(state_dict)

    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    out_path = os.path.join(session_dir, 'sample.glb')

    glb = postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=float(mesh_simplify),
        texture_size=int(texture_size),
        verbose=True,
    )
    glb.export(out_path)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return out_path, out_path
142
 
 
@spaces.GPU
def extract_gaussian(
    state_dict: dict,
    req: gr.Request
) -> Tuple[str, str]:
    """Write the Gaussian part of the packed state to a PLY file.

    Args:
        state_dict: Packed state, handed to ``unpack_state``; only the
            Gaussian it returns is used.
        req: Gradio request; its ``session_hash`` names the output directory.

    Returns:
        The path of the written ``sample.ply``, twice (one value per output
        component the handler is bound to).
    """
    gaussian, _unused_mesh = unpack_state(state_dict)

    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    ply_path = os.path.join(session_dir, 'sample.ply')
    gaussian.save_ply(ply_path)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return ply_path, ply_path
155
 
 
156
  # --- Gradio UI Definition ---
 
with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
    gr.Markdown("""
    # Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
    """)

    # Carries the state dictionary returned by text_to_3d over to the
    # extract_glb / extract_gaussian handlers.
    output_buf = gr.State()

    # NOTE(review): the nesting of the layout below was reconstructed from a
    # flattened diff view; confirm the intended placement of the extract and
    # download buttons.
    with gr.Row():
        with gr.Column(scale=1):
            text_prompt = gr.Textbox(label="Text Prompt", lines=5)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("---\n**Stage 1**")
                ss_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
                ss_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
                gr.Markdown("---\n**Stage 2**")
                slat_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
                slat_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)

            generate_btn = gr.Button("Generate 3D Preview", variant="primary")
            with gr.Accordion(label="GLB Extraction Settings", open=True):
                mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            # Extraction and download stay disabled until the matching step
            # of the event chains below has produced something to act on.
            extract_glb_btn = gr.Button("Extract GLB", interactive=False)
            extract_gs_btn = gr.Button("Extract Gaussian (PLY)", interactive=False)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
            download_gs = gr.DownloadButton(label="Download Gaussian (PLY)", interactive=False)
        with gr.Column(scale=1):
            video_output = gr.Video(label="3D Preview", autoplay=True, loop=True)
            model_output = gr.Model3D(label="Extracted Model Preview")

    # --- Event handlers ---
    # Session directory lifecycle: created on page load, removed on unload.
    demo.load(start_session)
    demo.unload(end_session)

    # The callbacks below return fresh component instances to change
    # `interactive`. Calling `<component>.update(...)` is not available in the
    # Gradio versions that provide `delete_cache` / `demo.unload`, and would
    # fail when the event fires.

    # Generate: resolve the seed, run the pipeline, then unlock extraction.
    generate_event = generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        text_to_3d,
        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).then(
        lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # Extract GLB: show the model, hand the file to the download button, unlock it.
    extract_glb_event = extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_glb],
    )

    # Extract Gaussian: same flow for the PLY export.
    extract_gs_event = extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_gs],
    )

    # Clear callbacks: lock the buttons again once their source is gone.
    model_output.clear(
        lambda: (gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)),
        outputs=[download_glb, download_gs],
    )
    video_output.clear(
        lambda: (
            gr.Button(interactive=False),
            gr.Button(interactive=False),
            gr.DownloadButton(interactive=False),
            gr.DownloadButton(interactive=False),
        ),
        outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
    )
234
 
 
 
 
 
if __name__ == "__main__":
    # The pipeline is created at script start-up and read by text_to_3d as a
    # module-level global.
    # NOTE(review): confirm that TrellisTextTo3DPipeline.from_pretrained
    # accepts `torch_dtype`; if it only takes the model path this call raises
    # TypeError and the keyword has to go.
    pipeline = TrellisTextTo3DPipeline.from_pretrained(
        "JeffreyXiang/TRELLIS-text-xlarge",
        torch_dtype=torch.float16,
    )
    if torch.cuda.is_available():
        # A `.to()` that moves the models in place and returns None must not
        # overwrite the only reference to the pipeline; rebind only when an
        # object is actually handed back.
        moved = pipeline.to("cuda")
        if moved is not None:
            pipeline = moved
    demo.queue().launch(debug=True)