# Sharp-It / app.py
import os
# this is a HF Spaces specific hack, as
# (i) building pytorch3d with GPU support is a bit tricky here
# (ii) installing the wheel via requirements.txt breaks ZeroGPU
import spaces
# Detect the environment, then build PyTorch3D from source; a pinned wheel
# serves as the fallback if the source build fails (see below)
import sys
import torch
# Print debug information about the environment
try:
cuda_version = torch.version.cuda
torch_version = torch.__version__
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
print(f"CUDA Version: {cuda_version}")
print(f"PyTorch Version: {torch_version}")
print(f"Python Version: {python_version}")
except Exception as e:
print(f"Error detecting environment versions: {e}")
# Install PyTorch3D properly from source
print("Installing PyTorch3D from source...")
# First uninstall any existing PyTorch3D installation to avoid conflicts
os.system("pip uninstall -y pytorch3d")
# Install dependencies required for building PyTorch3D
os.system("apt-get update && apt-get install -y git build-essential libglib2.0-0 libsm6 libxrender-dev libxext6 ninja-build")
os.system("pip install 'imageio>=2.5.0' 'matplotlib>=3.1.2' 'numpy>=1.17.3' 'psutil>=5.6.5' 'scipy>=1.3.2' 'tqdm>=4.42.1' 'trimesh>=3.0.0'")
os.system("pip install fvcore iopath")
# Clone the PyTorch3D repository
os.system("rm -rf pytorch3d") # Remove any existing directory
os.system("git clone https://github.com/facebookresearch/pytorch3d.git")
# Use a specific release tag that is known to be stable
os.system("cd pytorch3d && git checkout v0.7.4")
# Build and install PyTorch3D from source (CPU-only when no GPU is visible at build time)
os.system("cd pytorch3d && pip install -e .")
# Verify the installation
import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
print(import_result)
# If the installation fails, try a different approach with a specific wheel
if "No module named" in import_result or "Error" in import_result:
print("Source installation failed, trying with a specific wheel...")
os.system("pip uninstall -y pytorch3d")
# Try with a specific wheel that's known to work
os.system("pip install https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cpu_pyt201/pytorch3d-0.7.4-cp310-cp310-linux_x86_64.whl")
# Verify again
import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
print(import_result)
# Patch the shap_e renderer to soften the PyTorch3D fallback warnings. Locate
# renderer.py via the installed package rather than relying only on the
# hardcoded site-packages path of the default Space image.
import importlib.util
try:
    _spec = importlib.util.find_spec("shap_e.models.stf.renderer")
    shap_e_renderer_path = _spec.origin
except (ImportError, AttributeError):
    shap_e_renderer_path = "/usr/local/lib/python3.10/site-packages/shap_e/models/stf/renderer.py"
if os.path.exists(shap_e_renderer_path):
print(f"Patching shap_e renderer at {shap_e_renderer_path}")
# Read the current content
with open(shap_e_renderer_path, "r") as f:
content = f.read()
# Create a backup
os.system(f"cp {shap_e_renderer_path} {shap_e_renderer_path}.bak")
# Modify the content to handle the error more gracefully
modified_content = content
# Replace the error message
if "exception rendering with PyTorch3D" in content:
modified_content = modified_content.replace(
'warnings.warn(f"exception rendering with PyTorch3D: {exc}")',
'warnings.warn("Using native PyTorch renderer")'
)
# Replace the fallback warning
if "falling back on native PyTorch renderer" in modified_content:
modified_content = modified_content.replace(
'warnings.warn("falling back on native PyTorch renderer, which does not support full gradients")',
'warnings.warn("Using native PyTorch renderer")'
)
# Write the modified content
with open(shap_e_renderer_path, "w") as f:
f.write(modified_content)
print("Successfully patched shap_e renderer")
else:
print(f"shap_e renderer not found at {shap_e_renderer_path}")
# Add a helper function to ensure PyTorch3D works with ZeroGPU
def ensure_pytorch3d_cuda_compatibility():
"""
This function ensures PyTorch3D works correctly with CUDA in ZeroGPU environments.
It should be called at the beginning of any @spaces.GPU decorated function.
"""
try:
import pytorch3d
if torch.cuda.is_available():
# Check if we can access the renderer module
from pytorch3d import renderer
print("PyTorch3D renderer module is available with CUDA")
else:
print("CUDA is not available, using CPU version of PyTorch3D")
except ImportError as e:
print(f"Error importing PyTorch3D: {e}")
except Exception as e:
print(f"Unexpected error with PyTorch3D: {e}")
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
import math
import time
from requests.exceptions import ReadTimeout, ConnectionError
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
get_zero123plus_input_cameras,
get_circular_camera_poses,
spherical_camera_pose
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground
def create_custom_cameras(size: int, device: torch.device, azimuths: list, elevations: list,
fov_degrees: float, distance: float) -> DifferentiableCameraBatch:
    # The object is assumed to fit a 2x2x2 bounding box (-1 to 1 per axis);
    # `distance` is treated as the diagonal extent that must stay in frame
    object_diagonal = distance
    # Place the camera just far enough away that this extent fills the FOV:
    # radius = (extent / 2) / tan(fov / 2)
    fov_radians = math.radians(fov_degrees)
    radius = (object_diagonal / 2) / math.tan(fov_radians / 2)
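    # e.g. distance=3.0, fov=30 degrees -> radius = 1.5 / tan(15 deg) ~= 5.6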
origins = []
xs = []
ys = []
zs = []
for azimuth, elevation in zip(azimuths, elevations):
        azimuth_rad = np.radians(azimuth - 90)  # note: azimuth is offset by -90 degrees
elevation_rad = np.radians(elevation)
# Calculate camera position
x = radius * np.cos(elevation_rad) * np.cos(azimuth_rad)
y = radius * np.cos(elevation_rad) * np.sin(azimuth_rad)
z = radius * np.sin(elevation_rad)
origin = np.array([x, y, z])
        # Build a right-handed camera basis: z points at the origin, x is the
        # horizontal tangent direction, and y = cross(z, x) completes the frame
        z_axis = -origin / np.linalg.norm(origin)
        x_axis = np.array([-np.sin(azimuth_rad), np.cos(azimuth_rad), 0])
        y_axis = np.cross(z_axis, x_axis)
origins.append(origin)
zs.append(z_axis)
xs.append(x_axis)
ys.append(y_axis)
return DifferentiableCameraBatch(
shape=(1, len(origins)),
flat_camera=DifferentiableProjectiveCamera(
origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
width=size,
height=size,
x_fov=fov_radians,
y_fov=fov_radians,
),
)
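# Example usage (hypothetical values): a single front-facing 320x320 view --
#   cams = create_custom_cameras(320, torch.device('cpu'), azimuths=[0],
#                                elevations=[0], fov_degrees=30, distance=3.0)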
def load_models():
"""Initialize and load all required models"""
config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
model_config = config.model_config
infer_config = config.infer_config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load diffusion pipeline with retry logic
print('Loading diffusion pipeline...')
max_retries = 3
retry_delay = 5
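    # Retries use exponential backoff: 5 s, then 10 s, then 20 s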
for attempt in range(max_retries):
try:
pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.2",
custom_pipeline="zero123plus",
torch_dtype=torch.float16,
local_files_only=False,
resume_download=True,
)
break
except (ReadTimeout, ConnectionError) as e:
if attempt == max_retries - 1:
raise Exception(f"Failed to download pipeline after {max_retries} attempts: {str(e)}")
print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config, timestep_spacing='trailing'
)
# Modify UNet to handle 8 input channels instead of 4
in_channels = 8
out_channels = pipeline.unet.conv_in.out_channels
pipeline.unet.register_to_config(in_channels=in_channels)
with torch.no_grad():
new_conv_in = nn.Conv2d(
in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
pipeline.unet.conv_in = new_conv_in
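    # Zero-initializing the widened conv_in weights and copying the pretrained
    # weights into the first 4 channels makes the extra 4 conditioning channels
    # a no-op initially; these weights are then overwritten anyway by the
    # fine-tuned Sharp-It UNet loaded below.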
# Load custom UNet with retry logic
print('Loading custom UNet...')
for attempt in range(max_retries):
try:
pipeline.unet = pipeline.unet.from_pretrained(
"YiftachEde/Sharp-It",
local_files_only=False,
resume_download=True,
).to(torch.float16)
break
except (ReadTimeout, ConnectionError) as e:
if attempt == max_retries - 1:
raise Exception(f"Failed to download UNet after {max_retries} attempts: {str(e)}")
print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
retry_delay *= 2
pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
# Load reconstruction model with retry logic
print('Loading reconstruction model...')
model = instantiate_from_config(model_config)
for attempt in range(max_retries):
try:
model_path = hf_hub_download(
repo_id="TencentARC/InstantMesh",
filename="instant_nerf_large.ckpt",
repo_type="model",
local_files_only=False,
resume_download=True,
cache_dir="model_cache" # Use a specific cache directory
)
break
except (ReadTimeout, ConnectionError) as e:
if attempt == max_retries - 1:
raise Exception(f"Failed to download model after {max_retries} attempts: {str(e)}")
print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
retry_delay *= 2
state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    # Strip the 'lrm_generator.' prefix (14 chars) and drop source_camera buffers
    state_dict = {k[14:]: v for k, v in state_dict.items()
                  if k.startswith('lrm_generator.') and 'source_camera' not in k}
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()
return pipeline, model, infer_config
@spaces.GPU(duration=20)
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
"""Process input images and run refinement"""
# Ensure PyTorch3D works with CUDA
ensure_pytorch3d_cuda_compatibility()
device = pipeline.device
if isinstance(input_images, list):
if len(input_images) == 1:
            # A single file may already be a pre-arranged layout; PIL's
            # img.size is (width, height), so (640, 960) = 640 wide x 960 tall
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):
# This is already a layout, use it directly
input_image = img
else:
# Single view - need 6 copies
img = img.resize((320, 320))
img_array = np.array(img) / 255.0
images = [img_array] * 6
images = np.stack(images)
                # Tile the 6 views into a 3x2 layout:
                # (6, 3, 320, 320) -> (1, 3, 960, 640), 3 rows x 2 columns
                images = torch.from_numpy(images).float()
                images = images.permute(0, 3, 1, 2)          # NHWC -> NCHW
                images = images.reshape(3, 2, 3, 320, 320)   # (row, col, ch, h, w)
                images = images.permute(0, 2, 3, 1, 4)       # (row, ch, h, col, w)
                images = images.reshape(3, 3, 320, 640)      # join columns: (row, ch, h, w)
                images = images.permute(1, 0, 2, 3)          # channels first, else rows land in the channel dim
                images = images.reshape(1, 3, 960, 640)      # stack rows into height
# Convert back to PIL
images = images.permute(0, 2, 3, 1)[0]
images = (images.numpy() * 255).astype(np.uint8)
input_image = Image.fromarray(images)
else:
# Multiple individual views
images = []
for img_file in input_images:
img = Image.open(img_file.name).convert('RGB')
img = img.resize((320, 320))
img = np.array(img) / 255.0
images.append(img)
# Pad to 6 images if needed
while len(images) < 6:
images.append(np.zeros_like(images[0]))
images = np.stack(images[:6])
            # Tile into the same 3x2 layout as the single-view branch:
            # (6, 3, 320, 320) -> (1, 3, 960, 640)
            images = torch.from_numpy(images).float()
            images = images.permute(0, 3, 1, 2)          # NHWC -> NCHW
            images = images.reshape(3, 2, 3, 320, 320)   # (row, col, ch, h, w)
            images = images.permute(0, 2, 3, 1, 4)       # (row, ch, h, col, w)
            images = images.reshape(3, 3, 320, 640)      # join columns: (row, ch, h, w)
            images = images.permute(1, 0, 2, 3)          # channels first, else rows land in the channel dim
            images = images.reshape(1, 3, 960, 640)      # stack rows into height
# Convert back to PIL
images = images.permute(0, 2, 3, 1)[0]
images = (images.numpy() * 255).astype(np.uint8)
input_image = Image.fromarray(images)
else:
raise ValueError("Expected a list of images")
# Generate refined output
output = pipeline.refine(
input_image,
prompt=prompt,
num_inference_steps=int(steps),
guidance_scale=guidance_scale
).images[0]
return output, input_image
@spaces.GPU(duration=20)
def create_mesh(refined_image, model, infer_config):
"""Generate mesh from refined image"""
# Ensure PyTorch3D works with CUDA
ensure_pytorch3d_cuda_compatibility()
# Convert PIL image to tensor
image = np.array(refined_image) / 255.0
image = torch.from_numpy(image).float().permute(2, 0, 1)
    # Undo the 3x2 layout: (3, 960, 640) -> 6 separate (3, 320, 320) views
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)      # (ch, row, h, w)
    image = image.permute(1, 0, 2, 3)          # (row, ch, h, w)
    image = image.reshape(3, 3, 320, 2, 320)   # (row, ch, h, col, w)
    image = image.permute(0, 3, 1, 2, 4)       # (row, col, ch, h, w)
    image = image.reshape(6, 3, 320, 320)      # views in row-major order
# Add batch dimension
image = image.unsqueeze(0)
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
image = image.to("cuda")
with torch.no_grad():
planes = model.forward_planes(image, input_cameras)
mesh_out = model.extract_mesh(planes, **infer_config)
vertices, faces, vertex_colors = mesh_out
return vertices, faces, vertex_colors
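# The returned mesh can be written out with the repo's helpers, e.g.
# save_obj(vertices, faces, vertex_colors, path), as done in refine_model below.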
class ShapERenderer:
def __init__(self, device):
print("Initializing Shap-E models...")
self.device = device
torch.cuda.empty_cache() # Clear GPU memory before loading
self.xm = load_model('transmitter', device=self.device)
self.model = load_model('text300M', device=self.device)
self.diffusion = diffusion_from_config(load_config('diffusion'))
print("Shap-E models initialized!")
def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
try:
torch.cuda.empty_cache() # Clear GPU memory before generation
# Generate latents using the text-to-3D model
batch_size = 1
guidance_scale = float(guidance_scale)
with torch.amp.autocast('cuda'): # Use automatic mixed precision
latents = sample_latents(
batch_size=batch_size,
model=self.model,
diffusion=self.diffusion,
guidance_scale=guidance_scale,
model_kwargs=dict(texts=[prompt] * batch_size),
progress=True,
clip_denoised=True,
use_fp16=True,
use_karras=True,
karras_steps=num_steps,
sigma_min=1e-3,
sigma_max=160,
s_churn=0,
)
# Render the 6 views we need with specific viewing angles
size = 320 # Size of each rendered image
images = []
            # Six camera poses matching refine.py: azimuths every 60 degrees,
            # elevations alternating +20/-10 (the zero123plus six-view convention)
azimuths = [30, 90, 150, 210, 270, 330]
elevations = [20, -10, 20, -10, 20, -10]
for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
with torch.amp.autocast('cuda'): # Use automatic mixed precision
rendered_image = decode_latent_images(
self.xm,
latents[0],
cameras=cameras,
rendering_mode='stf'
)
images.append(rendered_image[0])
torch.cuda.empty_cache() # Clear GPU memory after each view
            # Convert rendered PIL images to numpy arrays (uint8)
            images = [np.array(image) for image in images]
            # Tile into a 3x2 grid: 640 px wide x 960 px tall
            layout = np.zeros((960, 640, 3), dtype=np.uint8)
for i, img in enumerate(images):
row = i // 2
col = i % 2
layout[row*320:(row+1)*320, col*320:(col+1)*320] = img
return Image.fromarray(layout), images
except Exception as e:
print(f"Error in generate_views: {e}")
torch.cuda.empty_cache() # Clear GPU memory on error
raise
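# Typical flow (sketch): ShapERenderer(device).generate_views(prompt) produces
# the 640x960 layout that RefinerInterface.refine_model below consumes.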
class RefinerInterface:
def __init__(self):
print("Initializing InstantMesh models...")
torch.cuda.empty_cache() # Clear GPU memory before loading
self.pipeline, self.model, self.infer_config = load_models()
print("InstantMesh models initialized!")
def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
"""Main refinement function"""
try:
torch.cuda.empty_cache() # Clear GPU memory before processing
# Process image and get refined output
input_image = Image.fromarray(input_image)
            # If the input arrives as a 960-wide x 640-tall 2x3 grid, remap its
            # tiles into the 640-wide x 960-tall 3x2 grid the pipeline expects
            if input_image.width == 960 and input_image.height == 640:
                input_array = np.array(input_image)
                new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
                # Move each 320x320 tile from its 2x3 position to its 3x2 position
for i in range(6):
src_row = i // 3
src_col = i % 3
dst_row = i // 2
dst_col = i % 2
new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
input_image = Image.fromarray(new_layout)
# Process with the pipeline (expects 960x640)
with torch.amp.autocast('cuda'): # Use automatic mixed precision
refined_output_960x640 = self.pipeline.refine(
input_image,
prompt=prompt,
num_inference_steps=int(steps),
guidance_scale=guidance_scale
).images[0]
torch.cuda.empty_cache() # Clear GPU memory after refinement
# Generate mesh using the 960x640 format
with torch.amp.autocast('cuda'): # Use automatic mixed precision
vertices, faces, vertex_colors = create_mesh(
refined_output_960x640,
self.model,
self.infer_config
)
torch.cuda.empty_cache() # Clear GPU memory after mesh generation
# Save temporary mesh file
os.makedirs("temp", exist_ok=True)
temp_obj = os.path.join("temp", "refined_mesh.obj")
save_obj(vertices, faces, vertex_colors, temp_obj)
            # Prepare the refined layout for display. The original tile-copy loop
            # here mapped every tile to the same position (identical source and
            # destination indices), so a direct copy is equivalent.
            refined_array = np.array(refined_output_960x640)
            refined_output_640x960 = Image.fromarray(refined_array)
return refined_output_640x960, temp_obj
except Exception as e:
print(f"Error in refine_model: {e}")
torch.cuda.empty_cache() # Clear GPU memory on error
raise
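# refine_model returns (refined layout as a PIL image, path to the saved .obj);
# both outputs are wired into the Gradio handlers in create_demo below.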
def create_demo():
print("Initializing models...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize models at startup
shap_e = ShapERenderer(device)
refiner = RefinerInterface()
print("All models initialized!")
with gr.Blocks() as demo:
gr.Markdown("# Shap-E to InstantMesh Pipeline")
# First row: Controls
with gr.Row():
with gr.Column():
# Shap-E inputs
shape_prompt = gr.Textbox(
label="Shap-E Prompt",
placeholder="Enter text to generate initial 3D model..."
)
shape_guidance = gr.Slider(
minimum=1,
maximum=30,
value=15.0,
label="Shap-E Guidance Scale"
)
shape_steps = gr.Slider(
minimum=16,
maximum=128,
value=64,
step=16,
label="Shap-E Steps"
)
generate_btn = gr.Button("Generate Views")
with gr.Column():
# Refinement inputs
refine_prompt = gr.Textbox(
label="Refinement Prompt",
placeholder="Enter prompt to guide refinement..."
)
refine_steps = gr.Slider(
minimum=30,
maximum=100,
value=75,
step=1,
label="Refinement Steps"
)
refine_guidance = gr.Slider(
minimum=1,
maximum=20,
value=7.5,
label="Refinement Guidance Scale"
)
refine_btn = gr.Button("Refine")
error_output = gr.Textbox(label="Status/Error Messages", interactive=False)
# Second row: Image panels side by side
with gr.Row():
# Outputs - Images side by side
shape_output = gr.Image(
label="Generated Views",
width=640,
height=960
)
refined_output = gr.Image(
label="Refined Output",
width=640,
height=960
)
# Third row: 3D mesh panel below
with gr.Row():
# 3D mesh centered
mesh_output = gr.Model3D(
label="3D Mesh",
clear_color=[1.0, 1.0, 1.0, 1.0],
)
# Set up event handlers
        @spaces.GPU(duration=20)
def generate(prompt, guidance_scale, num_steps):
try:
# Ensure PyTorch3D works with CUDA
ensure_pytorch3d_cuda_compatibility()
torch.cuda.empty_cache() # Clear GPU memory before starting
with torch.no_grad():
layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
return layout, None # Return None for error message
except Exception as e:
torch.cuda.empty_cache() # Clear GPU memory on error
error_msg = f"Error during generation: {str(e)}"
print(error_msg)
return None, error_msg
        @spaces.GPU(duration=20)
def refine(input_image, prompt, steps, guidance_scale):
try:
# Ensure PyTorch3D works with CUDA
ensure_pytorch3d_cuda_compatibility()
torch.cuda.empty_cache() # Clear GPU memory before starting
refined_img, mesh_path = refiner.refine_model(
input_image,
prompt,
steps,
guidance_scale
)
return refined_img, mesh_path, None # Return None for error message
except Exception as e:
torch.cuda.empty_cache() # Clear GPU memory on error
error_msg = f"Error during refinement: {str(e)}"
print(error_msg)
return None, None, error_msg
generate_btn.click(
fn=generate,
inputs=[shape_prompt, shape_guidance, shape_steps],
outputs=[shape_output, error_output]
)
refine_btn.click(
fn=refine,
inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
outputs=[refined_output, mesh_output, error_output]
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch(share=True)