dkatz2391 commited on
Commit
f1b0be5
·
verified ·
1 Parent(s): 41481bc

keep it all in GRadio no passing state

Browse files
Files changed (1) hide show
  1. app.py +113 -11
app.py CHANGED
@@ -140,20 +140,18 @@ def text_to_3d(
140
  # Convert to serializable format
141
  serializable_state = json.loads(json.dumps(state, cls=NumpyEncoder))
142
 
143
- print(f"[{req.session_hash}] text_to_3d completed. Intended state generated.") # Add logging
144
 
145
  torch.cuda.empty_cache()
 
 
 
 
 
 
146
 
147
- # --- TEMPORARY DEBUGGING ---
148
- # Instead of returning the complex state, return a simple dictionary.
149
- # This tests if the API mechanism itself can return *any* data.
150
- # REMEMBER TO REVERT THIS AFTER TESTING!
151
- print("[DEBUG] Returning simple dict for API test.")
152
- return {"status": "test_success", "received_prompt": prompt}
153
- # --- END TEMPORARY DEBUGGING ---
154
-
155
- # Original return line (commented out for test):
156
- # return serializable_state # MODIFIED: Return only state
157
 
158
  # --- NEW FUNCTION ---
159
  @spaces.GPU
@@ -258,6 +256,91 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
258
  return gaussian_path, gaussian_path
259
 
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  # State object to hold the generated model info between steps
262
  output_buf = gr.State()
263
  # Video component placeholder (will be populated by render_preview_video)
@@ -383,6 +466,25 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
383
  outputs=[download_glb, download_gs], # Disable both download buttons
384
  )
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
  # --- Launch the Gradio app ---
388
  if __name__ == "__main__":
 
140
  # Convert to serializable format
141
  serializable_state = json.loads(json.dumps(state, cls=NumpyEncoder))
142
 
143
+ print(f"[{req.session_hash}] text_to_3d completed. Returning state.") # Modified log message
144
 
145
  torch.cuda.empty_cache()
146
+
147
+ # --- REVERTED DEBUGGING ---
148
+ # Remove the temporary simple dictionary return
149
+ # print("[DEBUG] Returning simple dict for API test.")
150
+ # return {"status": "test_success", "received_prompt": prompt}
151
+ # --- END REVERTED DEBUGGING ---
152
 
153
+ # Original return line (restored):
154
+ return serializable_state # MODIFIED: Return only state
 
 
 
 
 
 
 
 
155
 
156
  # --- NEW FUNCTION ---
157
  @spaces.GPU
 
256
  return gaussian_path, gaussian_path
257
 
258
 
259
+ # --- NEW COMBINED API FUNCTION ---
260
+ @spaces.GPU(duration=120) # Allow more time for combined generation + extraction
261
+ def generate_and_extract_glb(
262
+ # Inputs mirror text_to_3d and extract_glb settings
263
+ prompt: str,
264
+ seed: int,
265
+ ss_guidance_strength: float,
266
+ ss_sampling_steps: int,
267
+ slat_guidance_strength: float,
268
+ slat_sampling_steps: int,
269
+ mesh_simplify: float, # Added from extract_glb
270
+ texture_size: int, # Added from extract_glb
271
+ req: gr.Request,
272
+ ) -> str: # MODIFIED: Returns only the final GLB path string
273
+ """
274
+ Combines 3D model generation and GLB extraction into a single step
275
+ for API usage, avoiding the need to transfer the state object.
276
+
277
+ Args:
278
+ prompt (str): Text prompt for generation.
279
+ seed (int): Random seed.
280
+ ss_guidance_strength (float): Sparse structure guidance.
281
+ ss_sampling_steps (int): Sparse structure steps.
282
+ slat_guidance_strength (float): Structured latent guidance.
283
+ slat_sampling_steps (int): Structured latent steps.
284
+ mesh_simplify (float): Mesh simplification factor for GLB.
285
+ texture_size (int): Texture resolution for GLB.
286
+ req (gr.Request): Gradio request object.
287
+
288
+ Returns:
289
+ str: The absolute path to the generated GLB file within the Space's filesystem.
290
+ Returns None if any step fails.
291
+ """
292
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
293
+ os.makedirs(user_dir, exist_ok=True)
294
+
295
+ print(f"[{req.session_hash}] API: Starting combined generation and extraction for prompt: {prompt}")
296
+
297
+ # --- Step 1: Generate 3D Model (adapted from text_to_3d) ---
298
+ try:
299
+ print(f"[{req.session_hash}] API: Running generation pipeline...")
300
+ outputs = pipeline.run(
301
+ prompt,
302
+ seed=seed,
303
+ formats=["gaussian", "mesh"], # Need both for GLB extraction
304
+ sparse_structure_sampler_params={
305
+ "steps": ss_sampling_steps,
306
+ "cfg_strength": ss_guidance_strength,
307
+ },
308
+ slat_sampler_params={
309
+ "steps": slat_sampling_steps,
310
+ "cfg_strength": slat_guidance_strength,
311
+ },
312
+ )
313
+ # Keep handles to the direct outputs (no need to pack/unpack state)
314
+ gs_output = outputs['gaussian'][0]
315
+ mesh_output = outputs['mesh'][0]
316
+ print(f"[{req.session_hash}] API: Generation pipeline completed.")
317
+ except Exception as e:
318
+ print(f"[{req.session_hash}] API: ERROR during generation pipeline: {e}")
319
+ traceback.print_exc()
320
+ torch.cuda.empty_cache()
321
+ return None # Return None on failure
322
+
323
+ # --- Step 2: Extract GLB (adapted from extract_glb) ---
324
+ try:
325
+ print(f"[{req.session_hash}] API: Extracting GLB (simplify={mesh_simplify}, texture={texture_size})...")
326
+ # Directly use the outputs from the pipeline
327
+ glb = postprocessing_utils.to_glb(gs_output, mesh_output, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
328
+ glb_path = os.path.join(user_dir, 'api_generated_sample.glb') # Use a distinct name for API outputs
329
+ print(f"[{req.session_hash}] API: Saving GLB to {glb_path}")
330
+ glb.export(glb_path)
331
+ print(f"[{req.session_hash}] API: GLB extraction completed.")
332
+ except Exception as e:
333
+ print(f"[{req.session_hash}] API: ERROR during GLB extraction: {e}")
334
+ traceback.print_exc()
335
+ torch.cuda.empty_cache()
336
+ return None # Return None on failure
337
+
338
+ torch.cuda.empty_cache()
339
+ print(f"[{req.session_hash}] API: Combined process successful. Returning GLB path: {glb_path}")
340
+ return glb_path # Return only the path to the generated GLB
341
+ # --- END NEW COMBINED API FUNCTION ---
342
+
343
+
344
  # State object to hold the generated model info between steps
345
  output_buf = gr.State()
346
  # Video component placeholder (will be populated by render_preview_video)
 
466
  outputs=[download_glb, download_gs], # Disable both download buttons
467
  )
468
 
469
+ # --- NEW API ENDPOINT DEFINITION ---
470
+ # Define the combined function as an API endpoint.
471
+ # This is *separate* from the UI button clicks.
472
+ # It directly calls the combined function.
473
+ demo.load(
474
+ None, # No function needed on load for this endpoint
475
+ inputs=[
476
+ text_prompt, # Map inputs from API request data based on order
477
+ seed,
478
+ ss_guidance_strength,
479
+ ss_sampling_steps,
480
+ slat_guidance_strength,
481
+ slat_sampling_steps,
482
+ mesh_simplify,
483
+ texture_size
484
+ ],
485
+ outputs=None, # Output is handled by the function return for the API
486
+ api_name="generate_and_extract_glb" # Assign the specific API name
487
+ )
488
 
489
  # --- Launch the Gradio app ---
490
  if __name__ == "__main__":