# Rightlight / app.py — image-to-3D point cloud converter
# (Hugging Face Space by mike23415; commit b33bab2, file size 5.18 kB).
# NOTE: this header replaces Hub page chrome ("raw / history / blame")
# that was captured when the file was scraped; it was not part of the code.
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import tempfile
import trimesh
# Pick the compute device once at startup: use CUDA when a GPU is present,
# otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Import Point-E modules
try:
print("Loading Point-E model...")
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint
from point_e.util.plotting import plot_point_cloud
except ImportError:
print("Point-E modules not available. Please make sure Point-E is installed.")
raise
# Text-conditioned base point-cloud model ('base40M-textvec' is the
# text-vector-conditioned 40M-parameter variant from Point-E's configs).
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])
# Upsampler model: refines a coarse 1024-point cloud into 4096 points.
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])
# Image-conditioned model used by the app's image-to-3D path.
img2pc_name = 'base300M'
img2pc_model = model_from_config(MODEL_CONFIGS[img2pc_name], device)
img2pc_model.eval()
img2pc_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[img2pc_name])
# Download (if needed) and load pretrained weights for all three models.
print("Loading model checkpoints...")
base_model.load_state_dict(load_checkpoint(base_name, device))
upsampler_model.load_state_dict(load_checkpoint('upsample', device))
img2pc_model.load_state_dict(load_checkpoint(img2pc_name, device))
# Two-stage sampler chaining the text-conditioned base model with the
# upsampler (1024 -> 4096 points, with RGB aux channels).
# NOTE(review): this sampler is never used by the Gradio callback below —
# presumably kept for a future text-to-3D path; confirm before removing.
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 4096],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
)
# Single-stage sampler for the image-to-point-cloud path (1024 points).
img2pc_sampler = PointCloudSampler(
    device=device,
    models=[img2pc_model],
    diffusions=[img2pc_diffusion],
    num_points=[1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0],
)
def preprocess_image(image):
    """Resize a PIL image to the fixed 256x256 input size the model expects."""
    return image.resize((256, 256))
def image_to_3d(image, num_steps=64):
    """
    Convert a single PIL image to a colored 3D point cloud using Point-E.

    Args:
        image: Input PIL image, or None.
        num_steps: Kept for interface compatibility with the UI slider.
            NOTE(review): the Point-E sampler's step schedule is fixed by its
            diffusion config at construction time, so this value is currently
            not applied; wiring it through would require rebuilding the
            sampler per request. TODO confirm intended behavior.

    Returns:
        ([obj_path, ply_path], message) on success, or (None, message) when
        the input is missing or generation fails.
    """
    if image is None:
        return None, "No image provided"
    try:
        processed_image = preprocess_image(image)
        # Run the progressive sampler to completion; the final yielded value
        # is the finished sample batch.
        samples = None
        for x in img2pc_sampler.sample_batch_progressive(
            batch_size=1, model_kwargs=dict(images=[processed_image])
        ):
            samples = x
        # Decode the raw sampler output into a Point-E PointCloud. The
        # previous dict-style indexing (samples[-1]['pred_pc'] /
        # ['pred_pc_aux']) does not match the sampler's output format and
        # failed at runtime; output_to_point_clouds() is the documented API.
        pc = img2pc_sampler.output_to_point_clouds(samples)[0]
        points = pc.coords  # (N, 3) xyz positions
        # Stack the per-channel color arrays into an (N, 3) array for trimesh.
        colors_np = np.stack([pc.channels[c] for c in ("R", "G", "B")], axis=-1)
        point_cloud = trimesh.PointCloud(vertices=points, colors=colors_np)
        # Export twice: OBJ for editing tools, PLY for Unity. Export after the
        # tempfile handle is closed so we never write to an open handle.
        with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as obj_file:
            obj_path = obj_file.name
        point_cloud.export(obj_path)
        with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as ply_file:
            ply_path = ply_file.name
        point_cloud.export(ply_path)
        return [obj_path, ply_path], "3D model generated successfully!"
    except Exception as e:
        # Surface the error to the UI instead of crashing the worker.
        return None, f"Error: {str(e)}"
def process_image(image, num_steps):
    """
    Gradio click handler: turn an uploaded image into downloadable 3D files.

    Returns a 3-tuple (obj_path, ply_path, message) matching the three
    output widgets; the paths are None when no files were produced.
    """
    if image is None:
        return None, None, "Please upload an image first."
    try:
        files, message = image_to_3d(image, num_steps=num_steps)
    except Exception as e:
        return None, None, f"Error: {str(e)}"
    if files:
        return files[0], files[1], message
    return None, None, message
# ---- Gradio UI ----
with gr.Blocks(title="Image to 3D Point Cloud Converter") as demo:
    gr.Markdown("# Image to 3D Point Cloud Converter")
    gr.Markdown("Upload an image to convert it to a 3D point cloud that you can use in Unity or other engines.")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Input Image")
            steps_slider = gr.Slider(
                minimum=16, maximum=128, value=64, step=8,
                label="Number of Inference Steps",
            )
            convert_btn = gr.Button("Convert to 3D")
        # Right column: downloadable artifacts plus a status line.
        with gr.Column(scale=1):
            obj_out = gr.File(label="OBJ File (for editing)")
            ply_out = gr.File(label="PLY File (for Unity)")
            status_box = gr.Textbox(label="Output Message")

    # Wire the button to the handler; outputs map 1:1 to the widgets above.
    convert_btn.click(
        fn=process_image,
        inputs=[image_in, steps_slider],
        outputs=[obj_out, ply_out, status_box],
    )
# Script entry point: serve the Gradio app on all interfaces at port 7860
# (the port Hugging Face Spaces routes to).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)