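# Gradio demo for SPAR3D (Stable Point-Aware 3D): upload a single image,
# remove its background, sample a conditioning point cloud, and reconstruct
# a textured GLB mesh together with an estimated illumination map.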
import os
import random
import tempfile
import time
import zipfile
from contextlib import nullcontext
from functools import lru_cache
from typing import Any
import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from gradio_litmodel3d import LitModel3D
from gradio_pointcloudeditor import PointCloudEditor
from PIL import Image
from transparent_background import Remover
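# Install the bundled native extensions (texture_baker, uv_unwrapper) and the
# pynim wheel at startup, before the spar3d modules that rely on them are imported.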
os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
os.system("pip install ./deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl")
import spar3d.utils as spar3d_utils
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")
bg_remover = Remover() # default setting
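# Conditioning camera setup shared by all requests: image size, camera
# distance, and field of view used to build the model input batch.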
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 2.2
COND_FOVY = 0.591627
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
# Cached. Doesn't change
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)
generated_files = []
# Delete previous gradio temp dir folder
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
    print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
    import shutil

    shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])
device = spar3d_utils.get_device()
model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval()
model = model.to(device)
example_files = [
    os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
]


def create_zip_file(glb_file, pc_file, illumination_file):
    if not all([glb_file, pc_file, illumination_file]):
        return None

    # Create a temporary zip file
    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "spar3d_output.zip")

    with zipfile.ZipFile(zip_path, "w") as zipf:
        zipf.write(glb_file, "mesh.glb")
        zipf.write(pc_file, "points.ply")
        zipf.write(illumination_file, "illumination.hdr")

    generated_files.append(zip_path)
    return zip_path


def forward_model(
    batch,
    system,
    guidance_scale=3.0,
    seed=0,
    device="cuda",
    remesh_option="none",
    vertex_count=-1,
    texture_resolution=1024,
):
    batch_size = batch["rgb_cond"].shape[0]

    # prepare the condition for point cloud generation
    # set seed
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    cond_tokens = system.forward_pdiff_cond(batch)

    if "pc_cond" not in batch:
        sample_iter = system.sampler.sample_batch_progressive(
            batch_size,
            cond_tokens,
            guidance_scale=guidance_scale,
            device=device,
        )
        for x in sample_iter:
            samples = x["xstart"]
        batch["pc_cond"] = samples.permute(0, 2, 1).float()
        batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])

        # subsample to 512 points
        batch["pc_cond"] = batch["pc_cond"][
            :, torch.randperm(batch["pc_cond"].shape[1])[:512]
        ]

    # get the point cloud
    xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
    color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
    pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)

    # forward for the final mesh
    trimesh_mesh, _glob_dict = system.generate_mesh(
        batch,
        texture_resolution,
        remesh=remesh_option,
        vertex_count=vertex_count,
        estimate_illumination=True,
    )
    trimesh_mesh = trimesh_mesh[0]

    illumination = _glob_dict["illumination"]

    return trimesh_mesh, pc_rgb_trimesh, illumination.cpu().detach().numpy()[0]


def process_model_run(
    fr_res,
    guidance_scale,
    random_seed,
    pc_cond,
    remesh_option,
    vertex_count_type,
    vertex_count,
    texture_resolution,
):
    # Note: pc_cond and vertex_count_type are accepted for interface parity
    # but are not used in this simplified app.
    start = time.time()
    with torch.no_grad():
        # Run in bfloat16 autocast on CUDA; fall back to full precision elsewhere.
        with (
            torch.autocast(device_type=device, dtype=torch.bfloat16)
            if "cuda" in device
            else nullcontext()
        ):
            model_batch = create_batch(fr_res)
            model_batch = {k: v.to(device) for k, v in model_batch.items()}
            trimesh_mesh, trimesh_pc, illumination_map = forward_model(
                model_batch,
                model,
                guidance_scale=guidance_scale,
                seed=random_seed,
                device=device,
                remesh_option=remesh_option.lower(),
                vertex_count=vertex_count,
                texture_resolution=texture_resolution,
            )

    # Create new tmp file
    temp_dir = tempfile.mkdtemp()
    tmp_file = os.path.join(temp_dir, "mesh.glb")

    trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
    generated_files.append(tmp_file)

    tmp_file_pc = os.path.join(temp_dir, "points.ply")
    trimesh_pc.export(tmp_file_pc)
    generated_files.append(tmp_file_pc)

    tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
    cv2.imwrite(tmp_file_illumination, illumination_map)
    generated_files.append(tmp_file_illumination)

    print("Generation took:", time.time() - start, "s")

    return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc


def create_batch(input_image: Image.Image) -> dict[str, Any]:
    img_cond = (
        torch.from_numpy(
            np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
            / 255.0
        )
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    # Composite the RGBA input over the neutral gray background using the alpha mask.
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched


def remove_background(input_image: Image.Image) -> Image.Image:
    return bg_remover.process(input_image.convert("RGB"))


def auto_process(input_image):
    if input_image is None:
        return None, None, None, None

    # Default values
    guidance_scale = 3.0
    random_seed = 0
    foreground_ratio = 1.3
    remesh_option = "None"
    vertex_count_type = "Keep Vertex Count"
    vertex_count = 2000
    texture_resolution = 1024
    no_crop = False
    pc_cond = None

    # First step: Remove background
    rem_removed = remove_background(input_image)
    fr_res = spar3d_utils.foreground_crop(
        rem_removed,
        crop_ratio=foreground_ratio,
        newsize=(COND_WIDTH, COND_HEIGHT),
        no_crop=no_crop,
    )

    # Second step: Run model
    glb_file, pc_file, illumination_file, pc_list = process_model_run(
        fr_res,
        guidance_scale,
        random_seed,
        pc_cond,
        remesh_option,
        vertex_count_type,
        vertex_count,
        texture_resolution,
    )

    zip_file = create_zip_file(glb_file, pc_file, illumination_file)
    return glb_file, illumination_file, zip_file, pc_list


# Simplified interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images

        Upload an image to generate a 3D model.
        """
    )

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload", "clipboard"],
                image_mode="RGBA",
            )
        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            download_all_btn = gr.File(
                label="Download Model (ZIP)",
                file_count="single",
                visible=True,
            )

    input_img.upload(
        auto_process,
        inputs=[input_img],
        outputs=[
            output_3d,
            gr.State(),  # for illumination file
            download_all_btn,
            gr.State(),  # for point cloud list
        ],
    )

demo.queue().launch(share=False)