dkatz2391 committed
Commit ef798fd · verified · 1 Parent(s): ab1a5ae

gemini inside cursor state change

Files changed (1)
  1. app.py +151 -41
app.py CHANGED
@@ -39,7 +39,8 @@ def start_session(req: gr.Request):
 
  def end_session(req: gr.Request):
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-     shutil.rmtree(user_dir)
+     # Use shutil.rmtree with ignore_errors=True for robustness
+     shutil.rmtree(user_dir, ignore_errors=True)
 
 
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -68,15 +69,16 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
          opacity_bias=state['gaussian']['opacity_bias'],
          scaling_activation=state['gaussian']['scaling_activation'],
      )
-     gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
-     gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
-     gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
-     gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
-     gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+     # Ensure tensors are created on the correct device ('cuda')
+     gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda', dtype=torch.float32)
+     gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda', dtype=torch.float32)
+     gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda', dtype=torch.float32)
+     gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda', dtype=torch.float32)
+     gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda', dtype=torch.float32)
 
      mesh = edict(
-         vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
-         faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+         vertices=torch.tensor(state['mesh']['vertices'], device='cuda', dtype=torch.float32),
+         faces=torch.tensor(state['mesh']['faces'], device='cuda', dtype=torch.int64),  # Faces are usually integers
      )
 
      return gs, mesh
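The keys read back here mirror whatever pack_state writes out; pack_state itself is unchanged by this commit (only its signature appears as context at the end of the previous hunk). A minimal sketch of that packing side, assuming tensors are detached to CPU numpy arrays before JSON encoding — the helper name and exact field handling are illustrative, not the actual implementation:

    # Hypothetical sketch of the packing side; field names are taken from unpack_state above.
    def pack_state_sketch(gs, mesh) -> dict:
        return {
            'gaussian': {
                'opacity_bias': gs.opacity_bias,              # scalar settings stored as-is (assumed attributes)
                'scaling_activation': gs.scaling_activation,
                '_xyz': gs._xyz.detach().cpu().numpy(),       # tensors moved off the GPU
                '_features_dc': gs._features_dc.detach().cpu().numpy(),
                '_scaling': gs._scaling.detach().cpu().numpy(),
                '_rotation': gs._rotation.detach().cpu().numpy(),
                '_opacity': gs._opacity.detach().cpu().numpy(),
            },
            'mesh': {
                'vertices': mesh.vertices.detach().cpu().numpy(),
                'faces': mesh.faces.detach().cpu().numpy(),
            },
        }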
@@ -98,9 +100,9 @@ def text_to_3d(
      slat_guidance_strength: float,
      slat_sampling_steps: int,
      req: gr.Request,
- ) -> Tuple[dict, str]:
+ ) -> dict:  # MODIFIED: Now returns only the state dict
      """
-     Convert an text prompt to a 3D model.
+     Convert a text prompt to a 3D model state object.
      Args:
          prompt (str): The text prompt.
          seed (int): The random seed.
@@ -109,11 +111,14 @@ def text_to_3d(
          slat_guidance_strength (float): The guidance strength for structured latent generation.
          slat_sampling_steps (int): The number of sampling steps for structured latent generation.
      Returns:
-         dict: The information of the generated 3D model.
-         str: The path to the video of the 3D model.
+         dict: The JSON-serializable state object containing the generated 3D model info.
      """
+     # Ensure user directory exists (redundant if start_session is always called, but safe)
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-     os.makedirs(user_dir, exist_ok=True)
+     os.makedirs(user_dir, exist_ok=True)
+ 
+     print(f"[{req.session_hash}] Running text_to_3d for prompt: {prompt}")  # Add logging
+ 
      outputs = pipeline.run(
          prompt,
          seed=seed,
@@ -127,19 +132,58 @@ def text_to_3d(
              "cfg_strength": slat_guidance_strength,
          },
      )
-     video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
-     video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
-     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
-     video_path = os.path.join(user_dir, 'sample.mp4')
-     imageio.mimsave(video_path, video, fps=15)
+ 
+     # REMOVED: Video rendering logic moved to render_preview_video
+     # video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+     # video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+     # video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+     # video_path = os.path.join(user_dir, 'sample.mp4')
+     # imageio.mimsave(video_path, video, fps=15)
 
      # Create the state object and ensure it's JSON serializable for API calls
      state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
      # Convert to serializable format
      serializable_state = json.loads(json.dumps(state, cls=NumpyEncoder))
 
+     print(f"[{req.session_hash}] text_to_3d completed. Returning state.")  # Add logging
+ 
      torch.cuda.empty_cache()
-     return serializable_state, video_path
+     return serializable_state  # MODIFIED: Return only state
+ 
+ # --- NEW FUNCTION ---
+ @spaces.GPU
+ def render_preview_video(state: dict, req: gr.Request) -> str:
+     """
+     Renders a preview video from the provided state object.
+     Args:
+         state (dict): The state object containing Gaussian and mesh data.
+         req (gr.Request): Gradio request object for session hash.
+     Returns:
+         str: The path to the rendered video file.
+     """
+     if not state:
+         print(f"[{req.session_hash}] render_preview_video called with empty state. Returning None.")
+         # Consider returning a placeholder or raising an error if state is required
+         return None
+ 
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     os.makedirs(user_dir, exist_ok=True)  # Ensure directory exists
+ 
+     print(f"[{req.session_hash}] Unpacking state for video rendering.")  # Add logging
+     gs, mesh = unpack_state(state)
+ 
+     print(f"[{req.session_hash}] Rendering video...")  # Add logging
+     video = render_utils.render_video(gs, num_frames=120)['color']
+     video_geo = render_utils.render_video(mesh, num_frames=120)['normal']
+     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+ 
+     video_path = os.path.join(user_dir, 'preview_sample.mp4')  # Use a distinct name
+     print(f"[{req.session_hash}] Saving video to {video_path}")  # Add logging
+     imageio.mimsave(video_path, video, fps=15)
+ 
+     torch.cuda.empty_cache()
+     return video_path
+ # --- END NEW FUNCTION ---
 
 
  @spaces.GPU(duration=90)
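The json.dumps(state, cls=NumpyEncoder) round trip above relies on a NumpyEncoder helper defined earlier in app.py, outside this diff. It is presumably the usual json.JSONEncoder subclass that turns numpy values into plain Python types; a minimal sketch under that assumption (the class-name suffix marks it as illustrative):

    import json
    import numpy as np

    class NumpyEncoderSketch(json.JSONEncoder):
        """Illustrative stand-in for the NumpyEncoder used by text_to_3d."""
        def default(self, obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()   # arrays become nested lists
            if isinstance(obj, np.generic):
                return obj.item()     # numpy scalars become Python ints/floats
            return super().default(obj)

    # json.loads(json.dumps(state, cls=NumpyEncoderSketch)) then yields a dict of plain
    # lists and numbers that survives gr.State and the HTTP API unchanged.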
@@ -150,50 +194,76 @@ def extract_glb(
      req: gr.Request,
  ) -> Tuple[str, str]:
      """
-     Extract a GLB file from the 3D model.
+     Extract a GLB file from the 3D model state.
      Args:
          state (dict): The state of the generated 3D model.
          mesh_simplify (float): The mesh simplification factor.
          texture_size (int): The texture resolution.
      Returns:
-         str: The path to the extracted GLB file.
+         str: The path to the extracted GLB file (for Model3D component).
+         str: The path to the extracted GLB file (for DownloadButton).
      """
+     if not state:
+         print(f"[{req.session_hash}] extract_glb called with empty state. Returning None.")
+         return None, None  # Return Nones if state is missing
+ 
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
      os.makedirs(user_dir, exist_ok=True)
+ 
+     print(f"[{req.session_hash}] Unpacking state for GLB extraction.")  # Add logging
      gs, mesh = unpack_state(state)
+ 
+     print(f"[{req.session_hash}] Extracting GLB (simplify={mesh_simplify}, texture={texture_size})...")  # Add logging
      glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
      glb_path = os.path.join(user_dir, 'sample.glb')
+     print(f"[{req.session_hash}] Saving GLB to {glb_path}")  # Add logging
      glb.export(glb_path)
+ 
      torch.cuda.empty_cache()
+     # Return the same path for both Model3D and DownloadButton components
      return glb_path, glb_path
 
 
  @spaces.GPU
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
      """
-     Extract a Gaussian file from the 3D model.
+     Extract a Gaussian PLY file from the 3D model state.
      Args:
          state (dict): The state of the generated 3D model.
      Returns:
-         str: The path to the extracted Gaussian file.
+         str: The path to the extracted Gaussian file (for Model3D component).
+         str: The path to the extracted Gaussian file (for DownloadButton).
      """
+     if not state:
+         print(f"[{req.session_hash}] extract_gaussian called with empty state. Returning None.")
+         return None, None  # Return Nones if state is missing
+ 
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
      os.makedirs(user_dir, exist_ok=True)
+ 
+     print(f"[{req.session_hash}] Unpacking state for Gaussian extraction.")  # Add logging
      gs, _ = unpack_state(state)
+ 
      gaussian_path = os.path.join(user_dir, 'sample.ply')
+     print(f"[{req.session_hash}] Saving Gaussian PLY to {gaussian_path}")  # Add logging
      gs.save_ply(gaussian_path)
+ 
      torch.cuda.empty_cache()
+     # Return the same path for both Model3D and DownloadButton components
      return gaussian_path, gaussian_path
 
 
- output_buf = gr.State()
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+ # State object to hold the generated model info between steps
+ output_buf = gr.State()
+ # Video component placeholder (will be populated by render_preview_video)
+ # video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)  # Defined later inside the Blocks
 
  with gr.Blocks(delete_cache=(600, 600)) as demo:
      gr.Markdown("""
      ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
      * Type a text prompt and click "Generate" to create a 3D asset.
-     * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
+     * The preview video will appear after generation.
+     * If you find the generated 3D asset satisfactory, click "Extract GLB" or "Extract Gaussian" to extract the file and download it.
      """)
 
      with gr.Row():
@@ -219,6 +289,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
                  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
              with gr.Row():
+                 # Buttons start non-interactive, enabled after generation
                  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
              gr.Markdown("""
@@ -226,63 +297,102 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
              """)
 
          with gr.Column():
-             video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+             # Define UI components here
+             video_output = gr.Video(label="Generated 3D Asset Preview", autoplay=True, loop=True, height=300)
              model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
 
              with gr.Row():
+                 # Buttons start non-interactive, enabled after extraction
                  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
 
+     # Define the state buffer here, outside the component definitions but inside the Blocks scope
      output_buf = gr.State()
 
-     # Handlers
+     # --- Handlers ---
      demo.load(start_session)
      demo.unload(end_session)
 
+     # --- MODIFIED UI CHAIN ---
+     # 1. Get Seed
+     # 2. Run text_to_3d -> outputs state to output_buf
+     # 3. Run render_preview_video (using state from output_buf) -> outputs video to video_output
+     # 4. Enable extraction buttons
      generate_btn.click(
          get_seed,
          inputs=[randomize_seed, seed],
          outputs=[seed],
+         queue=True  # Use queue for potentially long-running steps
      ).then(
          text_to_3d,
          inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
-         outputs=[output_buf, video_output],
+         outputs=[output_buf],  # text_to_3d now ONLY outputs state
+         api_name="text_to_3d"  # Keep API name consistent if needed
      ).then(
-         lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
+         render_preview_video,  # NEW step: Render video from state
+         inputs=[output_buf],
+         outputs=[video_output],
+         api_name="render_preview_video"  # Assign API name if you want to call this separately
+     ).then(
+         lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),  # Enable extraction buttons
          outputs=[extract_glb_btn, extract_gs_btn],
      )
 
-     video_output.clear(
+     # Clear video and disable extraction buttons if prompt is cleared or generation restarted
+     # (Consider adding logic to clear prompt on successful generation if desired)
+     text_prompt.change(  # Example: Clear video if prompt changes
+         lambda: (None, gr.Button(interactive=False), gr.Button(interactive=False)),
+         outputs=[video_output, extract_glb_btn, extract_gs_btn]
+     )
+     video_output.clear(  # This might be redundant if text_prompt.change handles it
          lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
          outputs=[extract_glb_btn, extract_gs_btn],
      )
 
+     # --- Extraction Handlers ---
+     # GLB Extraction: Takes state from output_buf, outputs model and download path
      extract_glb_btn.click(
          extract_glb,
          inputs=[output_buf, mesh_simplify, texture_size],
-         outputs=[model_output, download_glb],
+         outputs=[model_output, download_glb],  # Outputs to Model3D and DownloadButton path
+         api_name="extract_glb"
      ).then(
-         lambda: gr.Button(interactive=True),
+         lambda: gr.Button(interactive=True),  # Enable download button
          outputs=[download_glb],
      )
 
+     # Gaussian Extraction: Takes state from output_buf, outputs model and download path
      extract_gs_btn.click(
          extract_gaussian,
          inputs=[output_buf],
-         outputs=[model_output, download_gs],
+         outputs=[model_output, download_gs],  # Outputs to Model3D and DownloadButton path
+         api_name="extract_gaussian"
      ).then(
-         lambda: gr.Button(interactive=True),
+         lambda: gr.Button(interactive=True),  # Enable download button
          outputs=[download_gs],
      )
 
+     # Clear model and disable download buttons if video/state is cleared
      model_output.clear(
-         lambda: gr.Button(interactive=False),
-         outputs=[download_glb],
+         lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
+         outputs=[download_glb, download_gs],  # Disable both download buttons
      )
 
 
- # Launch the Gradio app
+ # --- Launch the Gradio app ---
  if __name__ == "__main__":
-     pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
-     pipeline.cuda()
-     demo.launch()
+     print("Loading Trellis pipeline...")
+     # Consider adding error handling for pipeline loading
+     try:
+         pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
+         pipeline.cuda()
+         print("Pipeline loaded successfully.")
+     except Exception as e:
+         print(f"Error loading pipeline: {e}")
+         # Optionally exit or provide a fallback UI
+         sys.exit(1)
+ 
+     print("Launching Gradio demo...")
+     # Enable queue for handling multiple users/requests
+     # Set share=True if you need a public link (requires login for private spaces)
+     demo.queue().launch()
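Because text_to_3d, render_preview_video, extract_glb, and extract_gaussian each carry an api_name after this change, the whole flow can be driven programmatically, which is presumably what the "gemini inside cursor" workflow in the commit message calls into. A hedged sketch using gradio_client — the Space id, prompt, and numeric values below are placeholders; the real defaults live in the unchanged slider definitions:

    from gradio_client import Client

    client = Client("dkatz2391/Trellis-Text-To-3D")  # placeholder Space id

    # Step 1: generate; the endpoint now returns only the JSON-serializable state.
    state = client.predict(
        "a weathered wooden rowing boat",  # prompt
        0,      # seed
        7.5,    # ss_guidance_strength (illustrative)
        12,     # ss_sampling_steps (illustrative)
        3.0,    # slat_guidance_strength (illustrative)
        12,     # slat_sampling_steps (illustrative)
        api_name="/text_to_3d",
    )

    # Step 2 (optional for API callers): render the turntable preview from that state.
    video_path = client.predict(state, api_name="/render_preview_video")

    # Step 3: extract assets from the same state; each endpoint returns the path twice
    # (once for the Model3D viewer, once for the DownloadButton).
    glb_path, _ = client.predict(state, 0.95, 1024, api_name="/extract_glb")  # mesh_simplify, texture_size (illustrative)
    ply_path, _ = client.predict(state, api_name="/extract_gaussian")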