dkatz2391 commited on
Commit
63ce34f
·
verified ·
1 Parent(s): 3447081

comment out video save -

Browse files
Files changed (1) hide show
  1. app.py +66 -315
app.py CHANGED
@@ -1,102 +1,4 @@
1
- # Version: Add API State Fix (2025-05-04)
2
- # Changes:
3
- # - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`
4
- # as the first return value. This ensures the dictionary is available via the API.
5
- # - Modified `extract_glb` to accept `state_dict: dict` as its first argument instead of
6
- # relying on the implicit `gr.State` object type when called via API.
7
- # - Kept Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`)
8
- # so the UI continues to function by passing the dictionary through output_buf.
9
-
10
- import gradio as gr
11
- import spaces
12
-
13
- import os
14
- import shutil
15
- os.environ['TOKENIZERS_PARALLELISM'] = 'true'
16
- os.environ['SPCONV_ALGO'] = 'native'
17
- from typing import *
18
- import torch
19
- import numpy as np
20
- import imageio
21
- from easydict import EasyDict as edict
22
- from trellis.pipelines import TrellisTextTo3DPipeline
23
- from trellis.representations import Gaussian, MeshExtractResult
24
- from trellis.utils import render_utils, postprocessing_utils
25
-
26
- import traceback
27
- import sys
28
-
29
-
30
- MAX_SEED = np.iinfo(np.int32).max
31
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
32
- os.makedirs(TMP_DIR, exist_ok=True)
33
-
34
-
35
- def start_session(req: gr.Request):
36
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
37
- os.makedirs(user_dir, exist_ok=True)
38
-
39
-
40
- def end_session(req: gr.Request):
41
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
42
- # Add safety check before removing
43
- if os.path.exists(user_dir):
44
- try:
45
- shutil.rmtree(user_dir)
46
- except OSError as e:
47
- print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
48
-
49
-
50
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
51
- # Ensure tensors are on CPU and converted to numpy before returning the dict
52
- return {
53
- 'gaussian': {
54
- **gs.init_params,
55
- '_xyz': gs._xyz.detach().cpu().numpy(),
56
- '_features_dc': gs._features_dc.detach().cpu().numpy(),
57
- '_scaling': gs._scaling.detach().cpu().numpy(),
58
- '_rotation': gs._rotation.detach().cpu().numpy(),
59
- '_opacity': gs._opacity.detach().cpu().numpy(),
60
- },
61
- 'mesh': {
62
- 'vertices': mesh.vertices.detach().cpu().numpy(),
63
- 'faces': mesh.faces.detach().cpu().numpy(),
64
- },
65
- }
66
-
67
-
68
- def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
69
- # Ensure the device is correctly set when unpacking
70
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
71
-
72
- gs = Gaussian(
73
- aabb=state_dict['gaussian']['aabb'],
74
- sh_degree=state_dict['gaussian']['sh_degree'],
75
- mininum_kernel_size=state_dict['gaussian']['mininum_kernel_size'],
76
- scaling_bias=state_dict['gaussian']['scaling_bias'],
77
- opacity_bias=state_dict['gaussian']['opacity_bias'],
78
- scaling_activation=state_dict['gaussian']['scaling_activation'],
79
- )
80
- gs._xyz = torch.tensor(state_dict['gaussian']['_xyz'], device=device)
81
- gs._features_dc = torch.tensor(state_dict['gaussian']['_features_dc'], device=device)
82
- gs._scaling = torch.tensor(state_dict['gaussian']['_scaling'], device=device)
83
- gs._rotation = torch.tensor(state_dict['gaussian']['_rotation'], device=device)
84
- gs._opacity = torch.tensor(state_dict['gaussian']['_opacity'], device=device)
85
-
86
- mesh = edict(
87
- vertices=torch.tensor(state_dict['mesh']['vertices'], device=device),
88
- faces=torch.tensor(state_dict['mesh']['faces'], device=device),
89
- )
90
-
91
- return gs, mesh
92
-
93
-
94
- def get_seed(randomize_seed: bool, seed: int) -> int:
95
- """
96
- Get the random seed.
97
- """
98
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
99
-
100
 
101
  @spaces.GPU
102
  def text_to_3d(
@@ -107,230 +9,79 @@ def text_to_3d(
107
  slat_guidance_strength: float,
108
  slat_sampling_steps: int,
109
  req: gr.Request,
110
- ) -> Tuple[dict, str]: # <- Changed return annotation for clarity
111
  """
112
- Convert an text prompt to a 3D model.
113
- Args:
114
- prompt (str): The text prompt.
115
- seed (int): The random seed.
116
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
117
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
118
- slat_guidance_strength (float): The guidance strength for structured latent generation.
119
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
120
- Returns:
121
- dict: The *serializable dictionary* representing the state of the generated 3D model. <-- CHANGE
122
- str: The path to the video preview of the 3D model.
123
  """
 
124
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
125
  os.makedirs(user_dir, exist_ok=True)
 
126
 
127
  # --- Generation Pipeline ---
128
- outputs = pipeline.run(
129
- prompt,
130
- seed=seed,
131
- formats=["gaussian", "mesh"], # Ensure both are generated
132
- sparse_structure_sampler_params={
133
- "steps": ss_sampling_steps,
134
- "cfg_strength": ss_guidance_strength,
135
- },
136
- slat_sampler_params={
137
- "steps": slat_sampling_steps,
138
- "cfg_strength": slat_guidance_strength,
139
- },
140
- )
141
-
142
- # --- Create Serializable State Dictionary --- VITAL CHANGE for API
143
- # Instead of returning the raw state object, return a serializable dictionary
144
- # which can be passed via the API correctly.
145
- state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
146
-
147
- # --- Render Video Preview ---
148
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
149
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
150
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
151
- video_path = os.path.join(user_dir, 'sample.mp4')
152
- imageio.mimsave(video_path, video, fps=15)
153
-
154
- torch.cuda.empty_cache()
155
-
156
- # --- Return Serializable Dictionary and Video Path --- VITAL CHANGE for API
157
- return state_dict, video_path
158
-
159
-
160
- @spaces.GPU(duration=90)
161
- def extract_glb(
162
- state_dict: dict, # <-- VITAL CHANGE: Accept the dictionary directly
163
- mesh_simplify: float,
164
- texture_size: int,
165
- req: gr.Request,
166
- ) -> Tuple[str, str]:
167
- """
168
- Extract a GLB file from the 3D model state dictionary.
169
- Args:
170
- state_dict (dict): The serializable dictionary state of the generated 3D model. <-- CHANGE
171
- mesh_simplify (float): The mesh simplification factor.
172
- texture_size (int): The texture resolution.
173
- Returns:
174
- str: The path to the extracted GLB file (for Model3D component).
175
- str: The path to the extracted GLB file (for DownloadButton).
176
- """
177
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
178
- os.makedirs(user_dir, exist_ok=True)
179
-
180
- # --- Unpack state from the dictionary --- VITAL CHANGE for API
181
- gs, mesh = unpack_state(state_dict)
182
-
183
- # --- Postprocessing and Export ---
184
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
185
- glb_path = os.path.join(user_dir, 'sample.glb')
186
- glb.export(glb_path)
187
-
188
- torch.cuda.empty_cache()
189
- # Return path twice for both Model3D and DownloadButton components
190
- return glb_path, glb_path
191
-
192
-
193
- @spaces.GPU
194
- def extract_gaussian(state_dict: dict, req: gr.Request) -> Tuple[str, str]: # <-- CHANGE: Accept dict
195
- """
196
- Extract a Gaussian file from the 3D model state dictionary.
197
- Args:
198
- state_dict (dict): The serializable dictionary state of the generated 3D model. <-- CHANGE
199
- Returns:
200
- str: The path to the extracted Gaussian file (for Model3D component).
201
- str: The path to the extracted Gaussian file (for DownloadButton).
202
- """
203
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
204
- os.makedirs(user_dir, exist_ok=True)
205
-
206
- # --- Unpack state from the dictionary --- VITAL CHANGE for API
207
- gs, _ = unpack_state(state_dict)
208
-
209
- gaussian_path = os.path.join(user_dir, 'sample.ply')
210
- gs.save_ply(gaussian_path)
211
- torch.cuda.empty_cache()
212
- # Return path twice for both Model3D and DownloadButton components
213
- return gaussian_path, gaussian_path
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  # --- Gradio UI Definition ---
217
- # output_buf = gr.State() # No change needed here, it will now hold the dict
218
- # video_output = gr.Video(...) # No change needed
219
-
220
- with gr.Blocks(delete_cache=(600, 600)) as demo:
221
- gr.Markdown("""
222
- ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
223
- * Type a text prompt and click "Generate" to create a 3D asset.
224
- * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
225
- """)
226
 
227
- with gr.Row():
228
- with gr.Column():
229
- text_prompt = gr.Textbox(label="Text Prompt", lines=5)
230
-
231
- with gr.Accordion(label="Generation Settings", open=False):
232
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
233
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
234
- gr.Markdown("Stage 1: Sparse Structure Generation")
235
- with gr.Row():
236
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
237
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
238
- gr.Markdown("Stage 2: Structured Latent Generation")
239
- with gr.Row():
240
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
241
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
242
-
243
- generate_btn = gr.Button("Generate")
244
-
245
- with gr.Accordion(label="GLB Extraction Settings", open=False):
246
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
247
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
248
-
249
- with gr.Row():
250
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
251
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
252
- gr.Markdown("""
253
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
254
- """)
255
-
256
- with gr.Column():
257
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
258
- model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
259
-
260
- with gr.Row():
261
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
262
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
263
-
264
- # --- State Buffer ---
265
- # This will now hold the dictionary returned by text_to_3d
266
- output_buf = gr.State()
267
-
268
- # --- Handlers ---
269
- demo.load(start_session)
270
- demo.unload(end_session)
271
-
272
- # --- Generate Button Click Flow ---
273
- # No changes needed to the structure, but text_to_3d now puts the dictionary into output_buf
274
- generate_btn.click(
275
- get_seed,
276
- inputs=[randomize_seed, seed],
277
- outputs=[seed],
278
- ).then(
279
- text_to_3d,
280
- inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
281
- outputs=[output_buf, video_output], # output_buf receives state_dict
282
- ).then(
283
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
284
- outputs=[extract_glb_btn, extract_gs_btn],
285
- )
286
-
287
- video_output.clear(
288
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
289
- outputs=[extract_glb_btn, extract_gs_btn],
290
- )
291
-
292
- # --- Extract GLB Button Click Flow ---
293
- # The input 'output_buf' now contains the state_dict needed by the modified extract_glb function
294
- extract_glb_btn.click(
295
- extract_glb,
296
- inputs=[output_buf, mesh_simplify, texture_size], # Pass the state_dict via output_buf
297
- outputs=[model_output, download_glb],
298
- ).then(
299
- lambda: gr.Button(interactive=True),
300
- outputs=[download_glb],
301
- )
302
-
303
- # --- Extract Gaussian Button Click Flow ---
304
- # The input 'output_buf' now contains the state_dict needed by the modified extract_gaussian function
305
- extract_gs_btn.click(
306
- extract_gaussian,
307
- inputs=[output_buf], # Pass the state_dict via output_buf
308
- outputs=[model_output, download_gs],
309
- ).then(
310
- lambda: gr.Button(interactive=True),
311
- outputs=[download_gs],
312
- )
313
-
314
- model_output.clear(
315
- lambda: gr.Button(interactive=False), # Should clear both potentially?
316
- outputs=[download_glb, download_gs], # Clear both download buttons
317
- )
318
-
319
-
320
- # --- Launch the Gradio app ---
321
- if __name__ == "__main__":
322
- # Consider adding error handling for pipeline loading
323
- try:
324
- pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
325
- # Move to GPU if available
326
- if torch.cuda.is_available():
327
- pipeline.cuda()
328
- else:
329
- print("WARNING: CUDA not available, running on CPU (will be very slow).")
330
- print("✅ Trellis pipeline loaded successfully.")
331
- except Exception as e:
332
- print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
333
- # Optionally exit if pipeline is critical
334
- # sys.exit(1)
335
 
336
- demo.launch()
 
 
1
+ # In appTrellis_fileErrorsFIx.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  @spaces.GPU
4
  def text_to_3d(
 
9
  slat_guidance_strength: float,
10
  slat_sampling_steps: int,
11
  req: gr.Request,
12
+ ) -> Tuple[dict, str]: # Return type changed for clarity
13
  """
14
+ Generates a 3D model (Gaussian and Mesh) from text and returns a
15
+ serializable state dictionary and a video preview path.
 
 
 
 
 
 
 
 
 
16
  """
17
+ print(f"[text_to_3d] Received prompt: '{prompt}', Seed: {seed}")
18
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
19
  os.makedirs(user_dir, exist_ok=True)
20
+ print(f"[text_to_3d] User directory: {user_dir}")
21
 
22
  # --- Generation Pipeline ---
23
+ try:
24
+ print("[text_to_3d] Running Trellis pipeline...")
25
+ outputs = pipeline.run(
26
+ prompt,
27
+ seed=seed,
28
+ formats=["gaussian", "mesh"], # Ensure both are generated
29
+ sparse_structure_sampler_params={
30
+ "steps": int(ss_sampling_steps), # Ensure steps are int
31
+ "cfg_strength": float(ss_guidance_strength),
32
+ },
33
+ slat_sampler_params={
34
+ "steps": int(slat_sampling_steps), # Ensure steps are int
35
+ "cfg_strength": float(slat_guidance_strength),
36
+ },
37
+ )
38
+ print("[text_to_3d] Pipeline run completed.")
39
+ except Exception as e:
40
+ print(f"❌ [text_to_3d] Pipeline error: {e}", file=sys.stderr)
41
+ traceback.print_exc()
42
+ raise gr.Error(f"Trellis pipeline failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ # --- Create Serializable State Dictionary ---
45
+ try:
46
+ state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
47
+ except Exception as e:
48
+ print(f"❌ [text_to_3d] pack_state error: {e}", file=sys.stderr)
49
+ traceback.print_exc()
50
+ raise gr.Error(f"Failed to pack state: {e}")
51
+
52
+ # --- Render Video Preview (TEMPORARILY DISABLED FOR DEBUGGING) ---
53
+ video_path = None # Set path to None
54
+ # try:
55
+ # print("[text_to_3d] Rendering video preview...")
56
+ # video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
57
+ # video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
58
+ # # Ensure video frames are uint8
59
+ # video = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
60
+ # video_path = os.path.join(user_dir, 'sample.mp4')
61
+ # imageio.mimsave(video_path, video, fps=15, quality=8) # Added quality setting
62
+ # print(f"[text_to_3d] Video saved to: {video_path}")
63
+ # except Exception as e:
64
+ # print(f"❌ [text_to_3d] Video rendering/saving error: {e}", file=sys.stderr)
65
+ # traceback.print_exc()
66
+ # # Still return state_dict, but maybe signal video error? Return None for path.
67
+ # video_path = None # Indicate video failure
68
+ print("[text_to_3d] Skipping video rendering for debugging.")
69
+
70
+ # --- Cleanup and Return ---
71
+ if torch.cuda.is_available():
72
+ torch.cuda.empty_cache()
73
+ print("[text_to_3d] Cleared CUDA cache.")
74
+
75
+ # --- Return Serializable Dictionary and None for Video Path ---
76
+ print("[text_to_3d] Returning state dictionary and None video path.")
77
+ return state_dict, video_path # Return dict and None video path
78
 
79
  # --- Gradio UI Definition ---
80
+ # ... (rest of the file is the same, but you might want to adjust the output mapping if needed)
 
 
 
 
 
 
 
 
81
 
82
+ # In the generate_btn.click handler, adjust the outputs if the video component causes issues:
83
+ # Option 1: Keep Video component, it will just show nothing.
84
+ # outputs=[output_buf, video_output], # This might be fine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ # Option 2: Use a dummy hidden component if video_output causes issues receiving None
87
+ # outputs=[output_buf, gr.Textbox(visible=False)], # Example dummy