dkatz2391 committed on
Commit
3447081
·
verified ·
1 Parent(s): 5faf466

- Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state` as the first return value. This ensures the dictionary is available via the API.
- Modified `extract_glb` to accept `state_dict: dict` as its first argument instead of relying on the implicit `gr.State` object type when called via the API.
- Kept the Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`) so the UI continues to function by passing the dictionary through `output_buf`.

Browse files
Files changed (1) hide show
  1. app.py +126 -64
app.py CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
 
@@ -26,50 +35,59 @@ os.makedirs(TMP_DIR, exist_ok=True)
26
  def start_session(req: gr.Request):
27
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
28
  os.makedirs(user_dir, exist_ok=True)
29
-
30
-
31
  def end_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
- shutil.rmtree(user_dir)
 
 
 
 
 
34
 
35
 
36
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
 
37
  return {
38
  'gaussian': {
39
  **gs.init_params,
40
- '_xyz': gs._xyz.cpu().numpy(),
41
- '_features_dc': gs._features_dc.cpu().numpy(),
42
- '_scaling': gs._scaling.cpu().numpy(),
43
- '_rotation': gs._rotation.cpu().numpy(),
44
- '_opacity': gs._opacity.cpu().numpy(),
45
  },
46
  'mesh': {
47
- 'vertices': mesh.vertices.cpu().numpy(),
48
- 'faces': mesh.faces.cpu().numpy(),
49
  },
50
  }
51
-
52
-
53
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 
 
 
54
  gs = Gaussian(
55
- aabb=state['gaussian']['aabb'],
56
- sh_degree=state['gaussian']['sh_degree'],
57
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
58
- scaling_bias=state['gaussian']['scaling_bias'],
59
- opacity_bias=state['gaussian']['opacity_bias'],
60
- scaling_activation=state['gaussian']['scaling_activation'],
61
  )
62
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
63
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
64
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
65
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
66
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
67
-
68
  mesh = edict(
69
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
70
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
71
  )
72
-
73
  return gs, mesh
74
 
75
 
@@ -89,7 +107,7 @@ def text_to_3d(
89
  slat_guidance_strength: float,
90
  slat_sampling_steps: int,
91
  req: gr.Request,
92
- ) -> Tuple[dict, str]:
93
  """
94
  Convert an text prompt to a 3D model.
95
  Args:
@@ -100,15 +118,17 @@ def text_to_3d(
100
  slat_guidance_strength (float): The guidance strength for structured latent generation.
101
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
102
  Returns:
103
- dict: The information of the generated 3D model.
104
- str: The path to the video of the 3D model.
105
  """
106
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
107
  os.makedirs(user_dir, exist_ok=True)
 
 
108
  outputs = pipeline.run(
109
  prompt,
110
  seed=seed,
111
- formats=["gaussian", "mesh"],
112
  sparse_structure_sampler_params={
113
  "steps": ss_sampling_steps,
114
  "cfg_strength": ss_guidance_strength,
@@ -118,62 +138,84 @@ def text_to_3d(
118
  "cfg_strength": slat_guidance_strength,
119
  },
120
  )
 
 
 
 
 
 
 
121
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
122
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
123
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
124
  video_path = os.path.join(user_dir, 'sample.mp4')
125
  imageio.mimsave(video_path, video, fps=15)
126
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
127
  torch.cuda.empty_cache()
128
- return state, video_path
 
 
129
 
130
 
131
  @spaces.GPU(duration=90)
132
  def extract_glb(
133
- state: dict,
134
  mesh_simplify: float,
135
  texture_size: int,
136
  req: gr.Request,
137
  ) -> Tuple[str, str]:
138
  """
139
- Extract a GLB file from the 3D model.
140
  Args:
141
- state (dict): The state of the generated 3D model.
142
  mesh_simplify (float): The mesh simplification factor.
143
  texture_size (int): The texture resolution.
144
  Returns:
145
- str: The path to the extracted GLB file.
 
146
  """
147
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
148
  os.makedirs(user_dir, exist_ok=True)
149
- gs, mesh = unpack_state(state)
 
 
 
 
150
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
151
  glb_path = os.path.join(user_dir, 'sample.glb')
152
  glb.export(glb_path)
 
153
  torch.cuda.empty_cache()
 
154
  return glb_path, glb_path
155
 
156
 
157
  @spaces.GPU
158
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
159
  """
160
- Extract a Gaussian file from the 3D model.
161
  Args:
162
- state (dict): The state of the generated 3D model.
163
  Returns:
164
- str: The path to the extracted Gaussian file.
 
165
  """
166
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
167
  os.makedirs(user_dir, exist_ok=True)
168
- gs, _ = unpack_state(state)
 
 
 
169
  gaussian_path = os.path.join(user_dir, 'sample.ply')
170
  gs.save_ply(gaussian_path)
171
  torch.cuda.empty_cache()
 
172
  return gaussian_path, gaussian_path
173
 
174
 
175
- output_buf = gr.State()
176
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
177
 
178
  with gr.Blocks(delete_cache=(600, 600)) as demo:
179
  gr.Markdown("""
@@ -181,11 +223,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
181
  * Type a text prompt and click "Generate" to create a 3D asset.
182
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
183
  """)
184
-
185
  with gr.Row():
186
  with gr.Column():
187
  text_prompt = gr.Textbox(label="Text Prompt", lines=5)
188
-
189
  with gr.Accordion(label="Generation Settings", open=False):
190
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
191
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
@@ -199,11 +241,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
199
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
200
 
201
  generate_btn = gr.Button("Generate")
202
-
203
  with gr.Accordion(label="GLB Extraction Settings", open=False):
204
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
205
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
206
-
207
  with gr.Row():
208
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
209
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
@@ -214,17 +256,21 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
214
  with gr.Column():
215
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
216
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
217
-
218
  with gr.Row():
219
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
220
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
221
-
 
 
222
  output_buf = gr.State()
223
 
224
- # Handlers
225
  demo.load(start_session)
226
  demo.unload(end_session)
227
 
 
 
228
  generate_btn.click(
229
  get_seed,
230
  inputs=[randomize_seed, seed],
@@ -232,7 +278,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
232
  ).then(
233
  text_to_3d,
234
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
235
- outputs=[output_buf, video_output],
236
  ).then(
237
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
238
  outputs=[extract_glb_btn, extract_gs_btn],
@@ -243,18 +289,22 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
243
  outputs=[extract_glb_btn, extract_gs_btn],
244
  )
245
 
 
 
246
  extract_glb_btn.click(
247
  extract_glb,
248
- inputs=[output_buf, mesh_simplify, texture_size],
249
  outputs=[model_output, download_glb],
250
  ).then(
251
  lambda: gr.Button(interactive=True),
252
  outputs=[download_glb],
253
  )
254
-
 
 
255
  extract_gs_btn.click(
256
  extract_gaussian,
257
- inputs=[output_buf],
258
  outputs=[model_output, download_gs],
259
  ).then(
260
  lambda: gr.Button(interactive=True),
@@ -262,13 +312,25 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
262
  )
263
 
264
  model_output.clear(
265
- lambda: gr.Button(interactive=False),
266
- outputs=[download_glb],
267
  )
268
-
269
 
270
- # Launch the Gradio app
 
271
  if __name__ == "__main__":
272
- pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
273
- pipeline.cuda()
 
 
 
 
 
 
 
 
 
 
 
 
274
  demo.launch()
 
1
+ # Version: Add API State Fix (2025-05-04)
2
+ # Changes:
3
+ # - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`
4
+ # as the first return value. This ensures the dictionary is available via the API.
5
+ # - Modified `extract_glb` to accept `state_dict: dict` as its first argument instead of
6
+ # relying on the implicit `gr.State` object type when called via API.
7
+ # - Kept Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`)
8
+ # so the UI continues to function by passing the dictionary through output_buf.
9
+
10
  import gradio as gr
11
  import spaces
12
 
 
35
  def start_session(req: gr.Request):
36
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
37
  os.makedirs(user_dir, exist_ok=True)
38
+
39
+
40
  def end_session(req: gr.Request):
41
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
42
+ # Add safety check before removing
43
+ if os.path.exists(user_dir):
44
+ try:
45
+ shutil.rmtree(user_dir)
46
+ except OSError as e:
47
+ print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
48
 
49
 
50
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
51
+ # Ensure tensors are on CPU and converted to numpy before returning the dict
52
  return {
53
  'gaussian': {
54
  **gs.init_params,
55
+ '_xyz': gs._xyz.detach().cpu().numpy(),
56
+ '_features_dc': gs._features_dc.detach().cpu().numpy(),
57
+ '_scaling': gs._scaling.detach().cpu().numpy(),
58
+ '_rotation': gs._rotation.detach().cpu().numpy(),
59
+ '_opacity': gs._opacity.detach().cpu().numpy(),
60
  },
61
  'mesh': {
62
+ 'vertices': mesh.vertices.detach().cpu().numpy(),
63
+ 'faces': mesh.faces.detach().cpu().numpy(),
64
  },
65
  }
66
+
67
+
68
def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
    """Rebuild a ``Gaussian`` and a mesh ``edict`` from a packed state dict.

    Reads the ``'gaussian'`` and ``'mesh'`` entries of ``state_dict`` (the
    layout produced by ``pack_state``) and turns the stored arrays back into
    tensors on CUDA when available, otherwise on the CPU.

    Returns:
        The reconstructed ``Gaussian`` and an ``edict`` with ``vertices`` and
        ``faces`` tensors.
    """
    # Fall back to the CPU when no CUDA device is present.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    gaussian_state = state_dict['gaussian']

    # Constructor arguments are forwarded under the same names they are stored with.
    init_keys = (
        'aabb',
        'sh_degree',
        'mininum_kernel_size',
        'scaling_bias',
        'opacity_bias',
        'scaling_activation',
    )
    gs = Gaussian(**{key: gaussian_state[key] for key in init_keys})

    # Restore the per-point tensors onto the chosen device.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(gaussian_state[attr], device=device))

    mesh_state = state_dict['mesh']
    mesh = edict(
        vertices=torch.tensor(mesh_state['vertices'], device=device),
        faces=torch.tensor(mesh_state['faces'], device=device),
    )
    return gs, mesh
92
 
93
 
 
107
  slat_guidance_strength: float,
108
  slat_sampling_steps: int,
109
  req: gr.Request,
110
+ ) -> Tuple[dict, str]: # <- Changed return annotation for clarity
111
  """
112
  Convert an text prompt to a 3D model.
113
  Args:
 
118
  slat_guidance_strength (float): The guidance strength for structured latent generation.
119
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
120
  Returns:
121
+ dict: The *serializable dictionary* representing the state of the generated 3D model. <-- CHANGE
122
+ str: The path to the video preview of the 3D model.
123
  """
124
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
125
  os.makedirs(user_dir, exist_ok=True)
126
+
127
+ # --- Generation Pipeline ---
128
  outputs = pipeline.run(
129
  prompt,
130
  seed=seed,
131
+ formats=["gaussian", "mesh"], # Ensure both are generated
132
  sparse_structure_sampler_params={
133
  "steps": ss_sampling_steps,
134
  "cfg_strength": ss_guidance_strength,
 
138
  "cfg_strength": slat_guidance_strength,
139
  },
140
  )
141
+
142
+ # --- Create Serializable State Dictionary --- VITAL CHANGE for API
143
+ # Instead of returning the raw state object, return a serializable dictionary
144
+ # which can be passed via the API correctly.
145
+ state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
146
+
147
+ # --- Render Video Preview ---
148
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
149
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
150
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
151
  video_path = os.path.join(user_dir, 'sample.mp4')
152
  imageio.mimsave(video_path, video, fps=15)
153
+
154
  torch.cuda.empty_cache()
155
+
156
+ # --- Return Serializable Dictionary and Video Path --- VITAL CHANGE for API
157
+ return state_dict, video_path
158
 
159
 
160
  @spaces.GPU(duration=90)
161
  def extract_glb(
162
+ state_dict: dict, # <-- VITAL CHANGE: Accept the dictionary directly
163
  mesh_simplify: float,
164
  texture_size: int,
165
  req: gr.Request,
166
  ) -> Tuple[str, str]:
167
  """
168
+ Extract a GLB file from the 3D model state dictionary.
169
  Args:
170
+ state_dict (dict): The serializable dictionary state of the generated 3D model. <-- CHANGE
171
  mesh_simplify (float): The mesh simplification factor.
172
  texture_size (int): The texture resolution.
173
  Returns:
174
+ str: The path to the extracted GLB file (for Model3D component).
175
+ str: The path to the extracted GLB file (for DownloadButton).
176
  """
177
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
178
  os.makedirs(user_dir, exist_ok=True)
179
+
180
+ # --- Unpack state from the dictionary --- VITAL CHANGE for API
181
+ gs, mesh = unpack_state(state_dict)
182
+
183
+ # --- Postprocessing and Export ---
184
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
185
  glb_path = os.path.join(user_dir, 'sample.glb')
186
  glb.export(glb_path)
187
+
188
  torch.cuda.empty_cache()
189
+ # Return path twice for both Model3D and DownloadButton components
190
  return glb_path, glb_path
191
 
192
 
193
  @spaces.GPU
194
+ def extract_gaussian(state_dict: dict, req: gr.Request) -> Tuple[str, str]: # <-- CHANGE: Accept dict
195
  """
196
+ Extract a Gaussian file from the 3D model state dictionary.
197
  Args:
198
+ state_dict (dict): The serializable dictionary state of the generated 3D model. <-- CHANGE
199
  Returns:
200
+ str: The path to the extracted Gaussian file (for Model3D component).
201
+ str: The path to the extracted Gaussian file (for DownloadButton).
202
  """
203
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
204
  os.makedirs(user_dir, exist_ok=True)
205
+
206
+ # --- Unpack state from the dictionary --- VITAL CHANGE for API
207
+ gs, _ = unpack_state(state_dict)
208
+
209
  gaussian_path = os.path.join(user_dir, 'sample.ply')
210
  gs.save_ply(gaussian_path)
211
  torch.cuda.empty_cache()
212
+ # Return path twice for both Model3D and DownloadButton components
213
  return gaussian_path, gaussian_path
214
 
215
 
216
+ # --- Gradio UI Definition ---
217
+ # output_buf = gr.State() # No change needed here, it will now hold the dict
218
+ # video_output = gr.Video(...) # No change needed
219
 
220
  with gr.Blocks(delete_cache=(600, 600)) as demo:
221
  gr.Markdown("""
 
223
  * Type a text prompt and click "Generate" to create a 3D asset.
224
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
225
  """)
226
+
227
  with gr.Row():
228
  with gr.Column():
229
  text_prompt = gr.Textbox(label="Text Prompt", lines=5)
230
+
231
  with gr.Accordion(label="Generation Settings", open=False):
232
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
233
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
241
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
242
 
243
  generate_btn = gr.Button("Generate")
244
+
245
  with gr.Accordion(label="GLB Extraction Settings", open=False):
246
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
247
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
248
+
249
  with gr.Row():
250
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
251
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
256
  with gr.Column():
257
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
258
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
259
+
260
  with gr.Row():
261
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
262
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
263
+
264
+ # --- State Buffer ---
265
+ # This will now hold the dictionary returned by text_to_3d
266
  output_buf = gr.State()
267
 
268
+ # --- Handlers ---
269
  demo.load(start_session)
270
  demo.unload(end_session)
271
 
272
+ # --- Generate Button Click Flow ---
273
+ # No changes needed to the structure, but text_to_3d now puts the dictionary into output_buf
274
  generate_btn.click(
275
  get_seed,
276
  inputs=[randomize_seed, seed],
 
278
  ).then(
279
  text_to_3d,
280
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
281
+ outputs=[output_buf, video_output], # output_buf receives state_dict
282
  ).then(
283
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
284
  outputs=[extract_glb_btn, extract_gs_btn],
 
289
  outputs=[extract_glb_btn, extract_gs_btn],
290
  )
291
 
292
+ # --- Extract GLB Button Click Flow ---
293
+ # The input 'output_buf' now contains the state_dict needed by the modified extract_glb function
294
  extract_glb_btn.click(
295
  extract_glb,
296
+ inputs=[output_buf, mesh_simplify, texture_size], # Pass the state_dict via output_buf
297
  outputs=[model_output, download_glb],
298
  ).then(
299
  lambda: gr.Button(interactive=True),
300
  outputs=[download_glb],
301
  )
302
+
303
+ # --- Extract Gaussian Button Click Flow ---
304
+ # The input 'output_buf' now contains the state_dict needed by the modified extract_gaussian function
305
  extract_gs_btn.click(
306
  extract_gaussian,
307
+ inputs=[output_buf], # Pass the state_dict via output_buf
308
  outputs=[model_output, download_gs],
309
  ).then(
310
  lambda: gr.Button(interactive=True),
 
312
  )
313
 
314
  model_output.clear(
315
+ lambda: gr.Button(interactive=False), # Should clear both potentially?
316
+ outputs=[download_glb, download_gs], # Clear both download buttons
317
  )
 
318
 
319
+
320
+ # --- Launch the Gradio app ---
321
  if __name__ == "__main__":
322
+ # Consider adding error handling for pipeline loading
323
+ try:
324
+ pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
325
+ # Move to GPU if available
326
+ if torch.cuda.is_available():
327
+ pipeline.cuda()
328
+ else:
329
+ print("WARNING: CUDA not available, running on CPU (will be very slow).")
330
+ print("✅ Trellis pipeline loaded successfully.")
331
+ except Exception as e:
332
+ print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
333
+ # Optionally exit if pipeline is critical
334
+ # sys.exit(1)
335
+
336
  demo.launch()