# Rightlight / app.py
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import tempfile
from transformers import AutoImageProcessor, AutoModel
from tqdm.auto import tqdm

# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the model
print("Loading Shap-E model...")
model_id = "openai/shap-e-img2img"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).to(device)
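
# Note (assumption): the "openai/shap-e-img2img" checkpoint on the Hub is, to my
# knowledge, published as a diffusers pipeline rather than a plain transformers model,
# so the AutoImageProcessor / AutoModel loading above may not resolve on every
# transformers version. Below is a minimal, unverified sketch of the diffusers route;
# it is not called anywhere in this app, and the guidance_scale / frame_size values
# are illustrative defaults, not tuned.
def generate_mesh_with_diffusers(image, guidance_scale=3.0, num_inference_steps=64):
    # Lazy imports so the app still starts if diffusers is not installed
    from diffusers import ShapEImg2ImgPipeline
    from diffusers.utils import export_to_obj

    pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img").to(device)
    # output_type="mesh" returns decoded meshes instead of rendered frames
    meshes = pipe(
        image,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        frame_size=256,
        output_type="mesh",
    ).images
    # Writes the first mesh to an OBJ file and returns its path
    return export_to_obj(meshes[0], "mesh.obj")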

def preprocess_image(image):
    # Resize and center crop to 256x256
    width, height = image.size
    size = min(width, height)
    left = (width - size) // 2
    top = (height - size) // 2
    right = left + size
    bottom = top + size
    image = image.crop((left, top, right, bottom))
    image = image.resize((256, 256))
    return image

def generate_3d_mesh(image, guidance_scale=15.0, num_inference_steps=64):
    """
    Convert a single image to a 3D model using Shap-E.
    """
    if image is None:
        return None, "No image provided"

    try:
        # Preprocess image
        image = preprocess_image(image)

        # Process image
        inputs = processor(images=image, return_tensors="pt").to(device)

        # Generate latents
        with torch.no_grad():
            latents = model.encode(inputs["pixel_values"]).latents

        # Decode the latents
        with torch.no_grad():
            with tqdm(total=num_inference_steps) as progress_bar:
                def callback(i, t, latents):
                    progress_bar.update(1)

                sample = model.decode(
                    latents,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    callback=callback
                )

        # Get mesh
        obj_mesh = sample.get_mesh()
        glb_mesh = sample.get_glb()

        # Save mesh to files (delete=False keeps them around for Gradio to serve)
        with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as obj_file:
            obj_path = obj_file.name
        obj_mesh.write_obj(obj_path)

        with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as glb_file:
            glb_path = glb_file.name
            glb_file.write(glb_mesh)

        return [obj_path, glb_path], "3D model generated successfully!"
    except Exception as e:
        return None, f"Error: {str(e)}"

def process_image(image, guidance_scale, num_steps):
    try:
        if image is None:
            return None, None, "Please upload an image first."

        results, message = generate_3d_mesh(
            image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps
        )

        if results:
            return results[0], results[1], message
        else:
            return None, None, message
    except Exception as e:
        return None, None, f"Error: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Image to 3D Model Converter") as demo:
    gr.Markdown("# Image to 3D Model Converter")
    gr.Markdown("Upload an image to convert it to a 3D model that you can use in Unity or other engines.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            guidance = gr.Slider(minimum=5.0, maximum=20.0, value=15.0, step=0.5, label="Guidance Scale")
            num_steps = gr.Slider(minimum=16, maximum=128, value=64, step=8, label="Number of Inference Steps")
            submit_btn = gr.Button("Convert to 3D")

        with gr.Column(scale=1):
            obj_file = gr.File(label="OBJ File (for editing)")
            glb_file = gr.File(label="GLB File (for Unity)")
            output_message = gr.Textbox(label="Output Message")

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, guidance, num_steps],
        outputs=[obj_file, glb_file, output_message]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
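
# Usage note: running `python app.py` locally serves the UI on http://0.0.0.0:7860;
# on Hugging Face Spaces this same app.py is picked up as the Space's entry point.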