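"""Gradio app: convert a single image into a 3D mesh (OBJ + GLB) with Shap-E.

Note: openai/shap-e-img2img is published for diffusers' ShapEImg2ImgPipeline
(not transformers' AutoModel); GLB export goes through trimesh, since
diffusers only exports OBJ/PLY meshes directly.
"""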
import tempfile

import gradio as gr
import torch
import trimesh
from diffusers import ShapEImg2ImgPipeline
from diffusers.utils import export_to_obj, export_to_ply
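# Assumed dependencies (not pinned in the original): torch, diffusers,
# transformers (pulled in internally by the Shap-E pipeline), gradio, trimesh.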
# Check if CUDA is available, otherwise use CPU (works, but much slower)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the model via diffusers' image-to-3D pipeline for this checkpoint
print("Loading Shap-E model...")
model_id = "openai/shap-e-img2img"
pipe = ShapEImg2ImgPipeline.from_pretrained(model_id).to(device)
def preprocess_image(image):
    """Center-crop to a square, then resize to the 256x256 input Shap-E expects."""
    width, height = image.size
    size = min(width, height)
    left = (width - size) // 2
    top = (height - size) // 2
    right = left + size
    bottom = top + size
    image = image.crop((left, top, right, bottom))
    image = image.resize((256, 256))
    return image
def generate_3d_mesh(image, guidance_scale=15.0, num_inference_steps=64):
    """
    Convert a single image to a 3D model using Shap-E.

    Returns ([obj_path, glb_path], message) on success, (None, message) otherwise.
    """
    if image is None:
        return None, "No image provided"
    try:
        # Preprocess image
        image = preprocess_image(image)

        # Run the diffusion pipeline and request raw meshes instead of
        # rendered views; the pipeline reports its own step progress.
        output = pipe(
            image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            frame_size=256,
            output_type="mesh",
        )
        mesh = output.images[0]

        # Save the mesh as OBJ and PLY with diffusers' export helpers, then
        # convert the PLY to GLB with trimesh (diffusers has no GLB exporter).
        with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as obj_file:
            obj_path = obj_file.name
        export_to_obj(mesh, obj_path)

        with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as ply_file:
            ply_path = ply_file.name
        export_to_ply(mesh, ply_path)

        glb_path = ply_path.replace(".ply", ".glb")
        trimesh.load(ply_path).export(glb_path, file_type="glb")

        return [obj_path, glb_path], "3D model generated successfully!"
    except Exception as e:
        return None, f"Error: {str(e)}"
def process_image(image, guidance_scale, num_steps):
    try:
        if image is None:
            return None, None, "Please upload an image first."
        results, message = generate_3d_mesh(
            image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
        )
        if results:
            return results[0], results[1], message
        else:
            return None, None, message
    except Exception as e:
        return None, None, f"Error: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Image to 3D Model Converter") as demo:
    gr.Markdown("# Image to 3D Model Converter")
    gr.Markdown("Upload an image to convert it to a 3D model that you can use in Unity or other engines.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            guidance = gr.Slider(minimum=5.0, maximum=20.0, value=15.0, step=0.5, label="Guidance Scale")
            num_steps = gr.Slider(minimum=16, maximum=128, value=64, step=8, label="Number of Inference Steps")
            submit_btn = gr.Button("Convert to 3D")
        with gr.Column(scale=1):
            obj_file = gr.File(label="OBJ File (for editing)")
            glb_file = gr.File(label="GLB File (for Unity)")
            output_message = gr.Textbox(label="Output Message")
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, guidance, num_steps],
        outputs=[obj_file, glb_file, output_message],
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
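# Example local run (assumes the dependencies above are installed):
#   $ python app.py
#   # then open http://localhost:7860, upload an image, and download the
#   # generated OBJ/GLB files from the output column.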