import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import tempfile
import trimesh

# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Import Point-E modules
try:
    print("Loading Point-E model...")
    from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
    from point_e.diffusion.sampler import PointCloudSampler
    from point_e.models.configs import MODEL_CONFIGS, model_from_config
    from point_e.models.download import load_checkpoint
    from point_e.util.plotting import plot_point_cloud
except ImportError:
    print("Point-E modules not available. Please make sure Point-E is installed.")
    raise
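
# NOTE: Point-E is typically installed from the OpenAI repository, e.g.:
#   pip install git+https://github.com/openai/point-e.git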

# Create text-conditioned base model (base40M-textvec is conditioned on CLIP
# text embeddings, not on images)
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# Create upsampler model
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

# Create image-to-point-cloud model
img2pc_name = 'base300M'
img2pc_model = model_from_config(MODEL_CONFIGS[img2pc_name], device)
img2pc_model.eval()
img2pc_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[img2pc_name])

# Load checkpoints
print("Loading model checkpoints...")
base_model.load_state_dict(load_checkpoint(base_name, device))
upsampler_model.load_state_dict(load_checkpoint('upsample', device))
img2pc_model.load_state_dict(load_checkpoint(img2pc_name, device))

# Create samplers
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 4096 - 1024],  # the upsampler adds 3072 points for a 4096-point cloud
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
)
img2pc_sampler = PointCloudSampler(
    device=device,
    models=[img2pc_model],
    diffusions=[img2pc_diffusion],
    num_points=[1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0],
)
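
# NOTE: only `img2pc_sampler` is used by the Gradio app below; `sampler`
# (text-conditioned base + upsampler) is built but never called. The upsampler
# could in principle be chained after the image model to densify the cloud
# beyond 1024 points.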

def preprocess_image(image):
    # Ensure a 3-channel RGB image (uploads may be RGBA or grayscale) and
    # resize to the expected input size
    image = image.convert("RGB").resize((256, 256))
    return image

def image_to_3d(image, num_steps=64):
    """
    Convert a single image to a 3D point cloud using Point-E.

    Note: `num_steps` is passed in from the UI but is not currently wired into
    the sampler; the number of diffusion steps is determined by the diffusion
    configs loaded above.
    """
    if image is None:
        return None, "No image provided"
    try:
        # Preprocess image
        processed_image = preprocess_image(image)

        # Generate samples with the image-conditioned sampler
        samples = None
        for x in img2pc_sampler.sample_batch_progressive(
            batch_size=1, model_kwargs=dict(images=[processed_image])
        ):
            samples = x

        # Extract the point cloud (coordinates and per-point RGB channels)
        pc = img2pc_sampler.output_to_point_clouds(samples)[0]
        points = pc.coords
        colors_np = np.stack([pc.channels[c] for c in ("R", "G", "B")], axis=1)
        colors_np = (np.clip(colors_np, 0.0, 1.0) * 255).astype(np.uint8)

        # Create a colored point cloud
        point_cloud = trimesh.PointCloud(vertices=points, colors=colors_np)

        # Save as OBJ
        with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as obj_file:
            obj_path = obj_file.name
            point_cloud.export(obj_path)

        # Save as PLY for better Unity compatibility
        with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as ply_file:
            ply_path = ply_file.name
            point_cloud.export(ply_path)

        return [obj_path, ply_path], "3D model generated successfully!"
    except Exception as e:
        return None, f"Error: {str(e)}"
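
# Example of calling the conversion directly (outside Gradio), assuming a local
# image file named "example.png" (hypothetical path):
#   files, msg = image_to_3d(Image.open("example.png"))
#   print(msg, files)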

def process_image(image, num_steps):
    try:
        if image is None:
            return None, None, "Please upload an image first."
        results, message = image_to_3d(image, num_steps=num_steps)
        if results:
            return results[0], results[1], message
        else:
            return None, None, message
    except Exception as e:
        return None, None, f"Error: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Image to 3D Point Cloud Converter") as demo:
    gr.Markdown("# Image to 3D Point Cloud Converter")
    gr.Markdown("Upload an image to convert it to a 3D point cloud that you can use in Unity or other engines.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            num_steps = gr.Slider(minimum=16, maximum=128, value=64, step=8, label="Number of Inference Steps")
            submit_btn = gr.Button("Convert to 3D")
        with gr.Column(scale=1):
            obj_file = gr.File(label="OBJ File (for editing)")
            ply_file = gr.File(label="PLY File (for Unity)")
            output_message = gr.Textbox(label="Output Message")
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, num_steps],
        outputs=[obj_file, ply_file, output_message]
    )
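
# NOTE: on shared hardware the generation step can take a while; enabling
# Gradio's request queue (demo.queue()) before launching may help avoid
# request timeouts.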

# Launch the app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)