dkatz2391's picture
keep it all on HF - easy dict errors
434fa76 verified
raw
history blame
21.1 kB
import gradio as gr
import spaces
import os
import shutil
import json
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
import traceback
import sys
# JSON encoder that understands NumPy arrays and NumPy scalar types
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values to native Python equivalents.

    ``np.ndarray`` instances are emitted as (nested) lists and NumPy scalar
    types (``np.float32``, ``np.int64``, ``np.bool_``, ...) are emitted as the
    matching Python scalar. Anything else is delegated to the base class,
    which raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars are not ndarray instances and most of them (e.g.
        # float32, int64) are rejected by the stock encoder.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)
# Largest value representable by a signed 32-bit integer; used as the upper
# bound for seeds.
MAX_SEED = np.iinfo(np.int32).max
# Scratch directory located next to this file; per-session subdirectories are
# created beneath it.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
# Create the scratch directory at import time (no error if it already exists).
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request):
    """Create the scratch directory for the requesting session.

    The directory lives under ``TMP_DIR`` and is named after the request's
    ``session_hash``. It is a no-op if the directory already exists.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """Delete the scratch directory that belongs to the requesting session.

    Removal is best-effort: errors (including a missing directory) are
    ignored so that session teardown never raises.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(session_dir, ignore_errors=True)
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Flatten a Gaussian and a mesh into a dict of NumPy arrays.

    Args:
        gs: Gaussian representation; its ``init_params`` and the five
            private tensors listed below are copied out.
        mesh: Mesh result providing ``vertices`` and ``faces`` tensors.

    Returns:
        dict: ``{'gaussian': {...}, 'mesh': {...}}`` where every tensor has
        been moved to the CPU and converted to a NumPy array.
    """
    # Start from the constructor parameters, then append the tensors in the
    # same order as they are later restored by unpack_state.
    gaussian_state = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr] = getattr(gs, attr).cpu().numpy()

    mesh_state = {
        'vertices': mesh.vertices.cpu().numpy(),
        'faces': mesh.faces.cpu().numpy(),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state}
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Rebuild GPU-resident objects from a state dict shaped like pack_state's output.

    Args:
        state: Dict with a ``'gaussian'`` entry (constructor parameters plus
            the ``_xyz``, ``_features_dc``, ``_scaling``, ``_rotation`` and
            ``_opacity`` arrays) and a ``'mesh'`` entry (``vertices`` and
            ``faces``). Values may be NumPy arrays or nested lists.

    Returns:
        Tuple[Gaussian, edict]: The reconstructed Gaussian and an ``edict``
        holding ``vertices`` (float32) and ``faces`` (int64) tensors. All
        tensors are created on the ``'cuda'`` device.
    """
    gaussian_state = state['gaussian']
    mesh_state = state['mesh']

    gs = Gaussian(
        aabb=gaussian_state['aabb'],
        sh_degree=gaussian_state['sh_degree'],
        mininum_kernel_size=gaussian_state['mininum_kernel_size'],
        scaling_bias=gaussian_state['scaling_bias'],
        opacity_bias=gaussian_state['opacity_bias'],
        scaling_activation=gaussian_state['scaling_activation'],
    )
    # Restore the learned tensors directly on the GPU as float32.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(gaussian_state[attr], device='cuda', dtype=torch.float32))

    mesh = edict(
        vertices=torch.tensor(mesh_state['vertices'], device='cuda', dtype=torch.float32),
        # Faces are index data, so they are kept as integers.
        faces=torch.tensor(mesh_state['faces'], device='cuda', dtype=torch.int64),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Pick the seed to use for generation.

    Returns ``seed`` unchanged unless ``randomize_seed`` is truthy, in which
    case a random integer in ``[0, MAX_SEED)`` is drawn instead.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
@spaces.GPU
def text_to_3d(
    prompt: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    req: gr.Request,
) -> dict:
    """
    Generate a 3D asset from a text prompt and return it as a plain state dict.

    Args:
        prompt (str): The text prompt.
        seed (int): The random seed.
        ss_guidance_strength (float): Guidance strength for sparse structure generation.
        ss_sampling_steps (int): Sampling steps for sparse structure generation.
        slat_guidance_strength (float): Guidance strength for structured latent generation.
        slat_sampling_steps (int): Sampling steps for structured latent generation.
        req (gr.Request): Gradio request; its session hash names the scratch directory.

    Returns:
        dict: The packed Gaussian/mesh state after a JSON round trip, i.e.
        containing only JSON-compatible values.
    """
    # Make sure the session directory exists even if start_session did not run.
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
    print(f"[{req.session_hash}] Running text_to_3d for prompt: {prompt}")

    sparse_structure_params = {
        "steps": ss_sampling_steps,
        "cfg_strength": ss_guidance_strength,
    }
    slat_params = {
        "steps": slat_sampling_steps,
        "cfg_strength": slat_guidance_strength,
    }
    generated = pipeline.run(
        prompt,
        seed=seed,
        formats=["gaussian", "mesh"],
        sparse_structure_sampler_params=sparse_structure_params,
        slat_sampler_params=slat_params,
    )

    # Only the first sample of each format is kept. The preview video is
    # rendered separately by render_preview_video.
    packed = pack_state(generated['gaussian'][0], generated['mesh'][0])
    # Encode and decode again so that NumPy arrays become nested lists.
    serializable_state = json.loads(json.dumps(packed, cls=NumpyEncoder))

    print(f"[{req.session_hash}] text_to_3d completed. Returning state.")
    torch.cuda.empty_cache()
    return serializable_state
# --- NEW FUNCTION ---
@spaces.GPU
def render_preview_video(state: dict, req: gr.Request) -> str:
    """
    Render a preview video of the Gaussian stored in ``state``.

    Args:
        state (dict): The state object containing Gaussian and mesh data.
        req (gr.Request): Gradio request object for session hash.

    Returns:
        str: Path of the written ``preview_sample.mp4`` inside the session
        directory, or ``None`` when ``state`` is empty.
    """
    if not state:
        print(f"[{req.session_hash}] render_preview_video called with empty state. Returning None.")
        return None

    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)

    print(f"[{req.session_hash}] Unpacking state for video rendering.")
    # The unpacked mesh is discarded: only the Gaussian is rendered here
    # because rendering the unpacked mesh caused type errors.
    gaussian, _unused_mesh = unpack_state(state)

    print(f"[{req.session_hash}] Rendering video (Gaussian only)...")
    frames = render_utils.render_video(gaussian, num_frames=120)['color']

    video_path = os.path.join(session_dir, 'preview_sample.mp4')
    print(f"[{req.session_hash}] Saving video to {video_path}")
    imageio.mimsave(video_path, frames, fps=15)

    torch.cuda.empty_cache()
    return video_path
# --- END NEW FUNCTION ---
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Build a GLB file from the stored 3D model state.

    Args:
        state (dict): The state of the generated 3D model.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.
        req (gr.Request): Gradio request; its session hash names the output directory.

    Returns:
        Tuple[str, str]: The path of the written ``sample.glb`` twice (once
        for the Model3D component, once for the DownloadButton), or
        ``(None, None)`` when ``state`` is empty.
    """
    if not state:
        print(f"[{req.session_hash}] extract_glb called with empty state. Returning None.")
        return None, None

    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)

    print(f"[{req.session_hash}] Unpacking state for GLB extraction.")
    gaussian, mesh = unpack_state(state)

    print(f"[{req.session_hash}] Extracting GLB (simplify={mesh_simplify}, texture={texture_size})...")
    glb = postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    )

    glb_path = os.path.join(session_dir, 'sample.glb')
    print(f"[{req.session_hash}] Saving GLB to {glb_path}")
    glb.export(glb_path)

    torch.cuda.empty_cache()
    # Both output components consume the same file.
    return glb_path, glb_path
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Write the Gaussian stored in ``state`` to a PLY file.

    Args:
        state (dict): The state of the generated 3D model.
        req (gr.Request): Gradio request; its session hash names the output directory.

    Returns:
        Tuple[str, str]: The path of the written ``sample.ply`` twice (once
        for the Model3D component, once for the DownloadButton), or
        ``(None, None)`` when ``state`` is empty.
    """
    if not state:
        print(f"[{req.session_hash}] extract_gaussian called with empty state. Returning None.")
        return None, None

    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)

    print(f"[{req.session_hash}] Unpacking state for Gaussian extraction.")
    gaussian, _unused_mesh = unpack_state(state)

    gaussian_path = os.path.join(session_dir, 'sample.ply')
    print(f"[{req.session_hash}] Saving Gaussian PLY to {gaussian_path}")
    gaussian.save_ply(gaussian_path)

    torch.cuda.empty_cache()
    # Both output components consume the same file.
    return gaussian_path, gaussian_path
# --- NEW COMBINED API FUNCTION ---
@spaces.GPU(duration=120)  # Longer budget: generation and extraction run back to back
def generate_and_extract_glb(
    prompt: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> str:
    """
    Generate a 3D model from text and export it to GLB in one call.

    Intended for API usage: the pipeline outputs are handed straight to the
    GLB exporter, so no state object has to be transferred between calls.

    Args:
        prompt (str): Text prompt for generation.
        seed (int): Random seed.
        ss_guidance_strength (float): Sparse structure guidance.
        ss_sampling_steps (int): Sparse structure steps.
        slat_guidance_strength (float): Structured latent guidance.
        slat_sampling_steps (int): Structured latent steps.
        mesh_simplify (float): Mesh simplification factor for GLB.
        texture_size (int): Texture resolution for GLB.
        req (gr.Request): Gradio request object.

    Returns:
        str: Path of the written ``api_generated_sample.glb`` inside the
        session directory, or ``None`` if generation or extraction raised.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    print(f"[{req.session_hash}] API: Starting combined generation and extraction for prompt: {prompt}")

    sparse_structure_params = {
        "steps": ss_sampling_steps,
        "cfg_strength": ss_guidance_strength,
    }
    slat_params = {
        "steps": slat_sampling_steps,
        "cfg_strength": slat_guidance_strength,
    }

    # Step 1: run the generation pipeline (both formats are needed for GLB).
    try:
        print(f"[{req.session_hash}] API: Running generation pipeline...")
        generated = pipeline.run(
            prompt,
            seed=seed,
            formats=["gaussian", "mesh"],
            sparse_structure_sampler_params=sparse_structure_params,
            slat_sampler_params=slat_params,
        )
        gaussian = generated['gaussian'][0]
        mesh = generated['mesh'][0]
        print(f"[{req.session_hash}] API: Generation pipeline completed.")
    except Exception as e:
        # Failures are logged and reported to the caller as None.
        print(f"[{req.session_hash}] API: ERROR during generation pipeline: {e}")
        traceback.print_exc()
        torch.cuda.empty_cache()
        return None

    # Step 2: convert the direct pipeline outputs to GLB and write the file.
    try:
        print(f"[{req.session_hash}] API: Extracting GLB (simplify={mesh_simplify}, texture={texture_size})...")
        glb = postprocessing_utils.to_glb(
            gaussian,
            mesh,
            simplify=mesh_simplify,
            texture_size=texture_size,
            verbose=False,
        )
        # Distinct file name so API output does not overwrite the UI's sample.glb.
        glb_path = os.path.join(session_dir, 'api_generated_sample.glb')
        print(f"[{req.session_hash}] API: Saving GLB to {glb_path}")
        glb.export(glb_path)
        print(f"[{req.session_hash}] API: GLB extraction completed.")
    except Exception as e:
        print(f"[{req.session_hash}] API: ERROR during GLB extraction: {e}")
        traceback.print_exc()
        torch.cuda.empty_cache()
        return None

    torch.cuda.empty_cache()
    print(f"[{req.session_hash}] API: Combined process successful. Returning GLB path: {glb_path}")
    return glb_path
# --- END NEW COMBINED API FUNCTION ---
# Gradio UI definition. All components, the state buffer and every event
# handler are created inside the Blocks context so they are attached to `demo`.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
* Type a text prompt and click "Generate" to create a 3D asset.
* The preview video will appear after generation.
* If you find the generated 3D asset satisfactory, click "Extract GLB" or "Extract Gaussian" to extract the file and download it.
""")

    with gr.Row():
        with gr.Column():
            text_prompt = gr.Textbox(label="Text Prompt", lines=5)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            with gr.Row():
                # Extraction buttons start disabled and are enabled after generation.
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
*NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
""")

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset Preview", autoplay=True, loop=True, height=300)
            model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)

            with gr.Row():
                # Download buttons start disabled and are enabled after extraction.
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    # Per-session buffer that carries the generated model state between steps.
    output_buf = gr.State()

    # --- Session lifecycle ---
    demo.load(start_session)
    demo.unload(end_session)

    # --- Generation chain ---
    # 1. Resolve the seed.
    # 2. text_to_3d writes the state into output_buf.
    # 3. render_preview_video turns that state into the preview video.
    # 4. The extraction buttons are enabled.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
        queue=True  # Use queue for potentially long-running steps
    ).then(
        text_to_3d,
        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf],
        api_name="text_to_3d"
    ).then(
        render_preview_video,
        inputs=[output_buf],
        outputs=[video_output],
        api_name="render_preview_video"
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),  # Enable extraction buttons
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # Editing the prompt clears the preview and disables the extraction buttons.
    text_prompt.change(
        lambda: (None, gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[video_output, extract_glb_btn, extract_gs_btn]
    )

    # Clearing the video also disables the extraction buttons.
    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # --- Extraction handlers ---
    # GLB extraction: state in, model viewer + download path out.
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
        api_name="extract_glb"
    ).then(
        lambda: gr.Button(interactive=True),  # Enable download button
        outputs=[download_glb],
    )

    # Gaussian extraction: state in, model viewer + download path out.
    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
        api_name="extract_gaussian"
    ).then(
        lambda: gr.Button(interactive=True),  # Enable download button
        outputs=[download_gs],
    )

    # Clearing the model viewer disables both download buttons.
    model_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[download_glb, download_gs],
    )

    # --- Combined API endpoint ---
    # Hidden components whose only purpose is to expose
    # generate_and_extract_glb under its own API name. Binding the function
    # to a click event (instead of a load event with no function) makes the
    # endpoint actually run the function and return its result, and avoids
    # firing anything on page load.
    api_generate_glb_trigger = gr.Button(visible=False)
    api_generate_glb_file = gr.File(visible=False)
    api_generate_glb_trigger.click(
        generate_and_extract_glb,
        inputs=[
            text_prompt,  # API request data is mapped to these inputs by order
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            mesh_simplify,
            texture_size
        ],
        outputs=[api_generate_glb_file],
        api_name="generate_and_extract_glb"
    )
# --- Launch the Gradio app ---
if __name__ == "__main__":
    # The pipeline is bound as a module-level global here; the handler
    # functions above look it up by name when they run.
    print("Loading Trellis pipeline...")
    try:
        pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
        pipeline.cuda()
    except Exception as e:
        # Without a pipeline nothing can be generated, so stop with a
        # non-zero exit status instead of launching the UI.
        print(f"Error loading pipeline: {e}")
        sys.exit(1)
    else:
        print("Pipeline loaded successfully.")

    print("Launching Gradio demo...")
    # The queue is enabled before launching so concurrent requests are handled.
    demo.queue().launch()