Spaces: Running on Zero
keep it all in Gradio, no passing state
app.py CHANGED
@@ -140,20 +140,18 @@ def text_to_3d(
     # Convert to serializable format
     serializable_state = json.loads(json.dumps(state, cls=NumpyEncoder))

-    print(f"[{req.session_hash}] text_to_3d completed.
+    print(f"[{req.session_hash}] text_to_3d completed. Returning state.") # Modified log message

     torch.cuda.empty_cache()
+
+    # --- REVERTED DEBUGGING ---
+    # Remove the temporary simple dictionary return
+    # print("[DEBUG] Returning simple dict for API test.")
+    # return {"status": "test_success", "received_prompt": prompt}
+    # --- END REVERTED DEBUGGING ---

-    #
-
-    # This tests if the API mechanism itself can return *any* data.
-    # REMEMBER TO REVERT THIS AFTER TESTING!
-    print("[DEBUG] Returning simple dict for API test.")
-    return {"status": "test_success", "received_prompt": prompt}
-    # --- END TEMPORARY DEBUGGING ---
-
-    # Original return line (commented out for test):
-    # return serializable_state # MODIFIED: Return only state
+    # Original return line (restored):
+    return serializable_state # MODIFIED: Return only state

 # --- NEW FUNCTION ---
 @spaces.GPU
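For context on the restored return value: serializable_state is produced with a NumPy-aware JSON encoder, NumpyEncoder, which is defined elsewhere in app.py and does not appear in this hunk. An encoder of that kind typically looks like the sketch below; it is illustrative only, not the Space's actual class.

import json
import numpy as np

# Illustrative sketch of a NumPy-aware encoder (the real NumpyEncoder lives elsewhere in app.py).
# It converts arrays and scalars to plain Python types so that json.dumps(state, cls=NumpyEncoder)
# succeeds and json.loads can round-trip the state dict.
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):  # e.g. np.float32, np.int64
            return obj.item()
        return super().default(obj)

state = {"coords": np.arange(3, dtype=np.float32), "seed": np.int64(7)}
serializable_state = json.loads(json.dumps(state, cls=NumpyEncoder))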
@@ -258,6 +256,91 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
     return gaussian_path, gaussian_path


+# --- NEW COMBINED API FUNCTION ---
+@spaces.GPU(duration=120) # Allow more time for combined generation + extraction
+def generate_and_extract_glb(
+    # Inputs mirror text_to_3d and extract_glb settings
+    prompt: str,
+    seed: int,
+    ss_guidance_strength: float,
+    ss_sampling_steps: int,
+    slat_guidance_strength: float,
+    slat_sampling_steps: int,
+    mesh_simplify: float, # Added from extract_glb
+    texture_size: int, # Added from extract_glb
+    req: gr.Request,
+) -> str: # MODIFIED: Returns only the final GLB path string
+    """
+    Combines 3D model generation and GLB extraction into a single step
+    for API usage, avoiding the need to transfer the state object.
+
+    Args:
+        prompt (str): Text prompt for generation.
+        seed (int): Random seed.
+        ss_guidance_strength (float): Sparse structure guidance.
+        ss_sampling_steps (int): Sparse structure steps.
+        slat_guidance_strength (float): Structured latent guidance.
+        slat_sampling_steps (int): Structured latent steps.
+        mesh_simplify (float): Mesh simplification factor for GLB.
+        texture_size (int): Texture resolution for GLB.
+        req (gr.Request): Gradio request object.
+
+    Returns:
+        str: The absolute path to the generated GLB file within the Space's filesystem.
+             Returns None if any step fails.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+    print(f"[{req.session_hash}] API: Starting combined generation and extraction for prompt: {prompt}")
+
+    # --- Step 1: Generate 3D Model (adapted from text_to_3d) ---
+    try:
+        print(f"[{req.session_hash}] API: Running generation pipeline...")
+        outputs = pipeline.run(
+            prompt,
+            seed=seed,
+            formats=["gaussian", "mesh"], # Need both for GLB extraction
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+        )
+        # Keep handles to the direct outputs (no need to pack/unpack state)
+        gs_output = outputs['gaussian'][0]
+        mesh_output = outputs['mesh'][0]
+        print(f"[{req.session_hash}] API: Generation pipeline completed.")
+    except Exception as e:
+        print(f"[{req.session_hash}] API: ERROR during generation pipeline: {e}")
+        traceback.print_exc()
+        torch.cuda.empty_cache()
+        return None # Return None on failure
+
+    # --- Step 2: Extract GLB (adapted from extract_glb) ---
+    try:
+        print(f"[{req.session_hash}] API: Extracting GLB (simplify={mesh_simplify}, texture={texture_size})...")
+        # Directly use the outputs from the pipeline
+        glb = postprocessing_utils.to_glb(gs_output, mesh_output, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+        glb_path = os.path.join(user_dir, 'api_generated_sample.glb') # Use a distinct name for API outputs
+        print(f"[{req.session_hash}] API: Saving GLB to {glb_path}")
+        glb.export(glb_path)
+        print(f"[{req.session_hash}] API: GLB extraction completed.")
+    except Exception as e:
+        print(f"[{req.session_hash}] API: ERROR during GLB extraction: {e}")
+        traceback.print_exc()
+        torch.cuda.empty_cache()
+        return None # Return None on failure
+
+    torch.cuda.empty_cache()
+    print(f"[{req.session_hash}] API: Combined process successful. Returning GLB path: {glb_path}")
+    return glb_path # Return only the path to the generated GLB
+# --- END NEW COMBINED API FUNCTION ---
+
+
 # State object to hold the generated model info between steps
 output_buf = gr.State()
 # Video component placeholder (will be populated by render_preview_video)
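Once the function above is exposed as a named endpoint (the demo.load(...) wiring in the next hunk is intended to do exactly that), a client can request generation and GLB extraction in a single call instead of recreating the two-step text_to_3d / extract_glb flow and shuttling the state dict back and forth. The snippet below is a sketch using gradio_client; the Space id is a placeholder and all parameter values are illustrative, not defaults taken from app.py.

from gradio_client import Client

client = Client("your-username/your-space")  # placeholder Space id
glb_path = client.predict(
    "a small wooden chair",  # prompt (example text)
    0,       # seed
    7.5,     # ss_guidance_strength
    12,      # ss_sampling_steps
    3.0,     # slat_guidance_strength
    12,      # slat_sampling_steps
    0.95,    # mesh_simplify
    1024,    # texture_size
    api_name="/generate_and_extract_glb",
)
print(glb_path)

Per the function's contract above, the call yields the GLB path string on success and None if either the generation or the extraction step failed, so callers should check for None before using the result.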
@@ -383,6 +466,25 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
         outputs=[download_glb, download_gs], # Disable both download buttons
     )

+    # --- NEW API ENDPOINT DEFINITION ---
+    # Define the combined function as an API endpoint.
+    # This is *separate* from the UI button clicks.
+    # It directly calls the combined function.
+    demo.load(
+        None, # No function needed on load for this endpoint
+        inputs=[
+            text_prompt, # Map inputs from API request data based on order
+            seed,
+            ss_guidance_strength,
+            ss_sampling_steps,
+            slat_guidance_strength,
+            slat_sampling_steps,
+            mesh_simplify,
+            texture_size
+        ],
+        outputs=None, # Output is handled by the function return for the API
+        api_name="generate_and_extract_glb" # Assign the specific API name
+    )

 # --- Launch the Gradio app ---
 if __name__ == "__main__":
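A quick way to check whether the demo.load(...) event above actually registered the endpoint under the intended name is gradio_client's view_api(); again, the Space id below is a placeholder.

from gradio_client import Client

client = Client("your-username/your-space")  # placeholder Space id
client.view_api()  # prints the named endpoints; /generate_and_extract_glb should be listed if registration worked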
|