Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -19,124 +19,44 @@ try:
|
|
19 |
except Exception as e:
|
20 |
print(f"Error detecting environment versions: {e}")
|
21 |
|
22 |
-
# Install PyTorch3D
|
23 |
-
print("Installing PyTorch3D
|
24 |
|
25 |
# First uninstall any existing PyTorch3D installation to avoid conflicts
|
26 |
os.system("pip uninstall -y pytorch3d")
|
27 |
|
28 |
-
# Install dependencies required for PyTorch3D
|
|
|
|
|
29 |
os.system("pip install fvcore iopath")
|
30 |
|
31 |
-
#
|
32 |
-
|
33 |
-
os.system("
|
34 |
|
35 |
-
#
|
36 |
-
|
37 |
-
print(f"Renderer module check: {import_result}")
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
print("Renderer module not found, creating it manually...")
|
42 |
-
|
43 |
-
# Get the site-packages directory
|
44 |
-
site_packages = os.popen('python -c "import site; print(site.getsitepackages()[0])"').read().strip()
|
45 |
-
|
46 |
-
# Check if the pytorch3d directory exists
|
47 |
-
pytorch3d_dir = f"{site_packages}/pytorch3d"
|
48 |
-
if not os.path.exists(pytorch3d_dir):
|
49 |
-
print(f"PyTorch3D directory not found at {pytorch3d_dir}, creating it...")
|
50 |
-
os.system(f"mkdir -p {pytorch3d_dir}")
|
51 |
-
|
52 |
-
# Create an __init__.py file in the pytorch3d directory
|
53 |
-
with open(f"{pytorch3d_dir}/__init__.py", "w") as f:
|
54 |
-
f.write("""
|
55 |
-
# PyTorch3D module
|
56 |
-
import torch
|
57 |
-
import warnings
|
58 |
|
59 |
-
|
|
|
|
|
60 |
|
61 |
-
|
62 |
-
"""
|
|
|
|
|
63 |
|
64 |
-
#
|
65 |
-
os.system(
|
66 |
|
67 |
-
#
|
68 |
-
|
69 |
-
|
70 |
-
# PyTorch3D renderer module
|
71 |
-
import torch
|
72 |
-
import warnings
|
73 |
-
|
74 |
-
warnings.warn("Using custom PyTorch3D renderer module")
|
75 |
-
|
76 |
-
# Basic renderer components
|
77 |
-
class TexturesBase:
|
78 |
-
def __init__(self, *args, **kwargs):
|
79 |
-
pass
|
80 |
-
|
81 |
-
class TexturesVertex(TexturesBase):
|
82 |
-
def __init__(self, *args, **kwargs):
|
83 |
-
super().__init__(*args, **kwargs)
|
84 |
-
|
85 |
-
class TexturesUV(TexturesBase):
|
86 |
-
def __init__(self, *args, **kwargs):
|
87 |
-
super().__init__(*args, **kwargs)
|
88 |
-
|
89 |
-
class TexturesAtlas(TexturesBase):
|
90 |
-
def __init__(self, *args, **kwargs):
|
91 |
-
super().__init__(*args, **kwargs)
|
92 |
-
|
93 |
-
class RasterizationSettings:
|
94 |
-
def __init__(self, *args, **kwargs):
|
95 |
-
pass
|
96 |
-
|
97 |
-
class MeshRasterizer:
|
98 |
-
def __init__(self, *args, **kwargs):
|
99 |
-
pass
|
100 |
-
|
101 |
-
class PointLights:
|
102 |
-
def __init__(self, *args, **kwargs):
|
103 |
-
pass
|
104 |
-
|
105 |
-
class Materials:
|
106 |
-
def __init__(self, *args, **kwargs):
|
107 |
-
pass
|
108 |
|
109 |
-
|
110 |
-
def __init__(self, *args, **kwargs):
|
111 |
-
pass
|
112 |
-
|
113 |
-
class SoftPhongShader:
|
114 |
-
def __init__(self, *args, **kwargs):
|
115 |
-
pass
|
116 |
-
|
117 |
-
class HardPhongShader:
|
118 |
-
def __init__(self, *args, **kwargs):
|
119 |
-
pass
|
120 |
-
|
121 |
-
class SoftSilhouetteShader:
|
122 |
-
def __init__(self, *args, **kwargs):
|
123 |
-
pass
|
124 |
-
|
125 |
-
class BlendParams:
|
126 |
-
def __init__(self, *args, **kwargs):
|
127 |
-
pass
|
128 |
-
|
129 |
-
def look_at_view_transform(*args, **kwargs):
|
130 |
-
return torch.eye(4), torch.eye(4)
|
131 |
-
|
132 |
-
class FoVPerspectiveCameras:
|
133 |
-
def __init__(self, *args, **kwargs):
|
134 |
-
pass
|
135 |
-
""")
|
136 |
-
|
137 |
-
print("Created custom renderer module")
|
138 |
-
|
139 |
-
# Patch the shap_e renderer to handle PyTorch3D renderer import error
|
140 |
shap_e_renderer_path = "/usr/local/lib/python3.10/site-packages/shap_e/models/stf/renderer.py"
|
141 |
if os.path.exists(shap_e_renderer_path):
|
142 |
print(f"Patching shap_e renderer at {shap_e_renderer_path}")
|
@@ -173,9 +93,24 @@ if os.path.exists(shap_e_renderer_path):
|
|
173 |
else:
|
174 |
print(f"shap_e renderer not found at {shap_e_renderer_path}")
|
175 |
|
176 |
-
#
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
import torch
|
181 |
import torch.nn as nn
|
@@ -354,6 +289,9 @@ def load_models():
|
|
354 |
@spaces.GPU(duration=20)
|
355 |
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
|
356 |
"""Process input images and run refinement"""
|
|
|
|
|
|
|
357 |
device = pipeline.device
|
358 |
|
359 |
if isinstance(input_images, list):
|
@@ -424,6 +362,9 @@ def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=
|
|
424 |
@spaces.GPU(duration=20)
|
425 |
def create_mesh(refined_image, model, infer_config):
|
426 |
"""Generate mesh from refined image"""
|
|
|
|
|
|
|
427 |
# Convert PIL image to tensor
|
428 |
image = np.array(refined_image) / 255.0
|
429 |
image = torch.from_numpy(image).float().permute(2, 0, 1)
|
@@ -686,6 +627,9 @@ def create_demo():
|
|
686 |
@spaces.GPU(duration=20) # Reduced duration to 20 seconds
|
687 |
def generate(prompt, guidance_scale, num_steps):
|
688 |
try:
|
|
|
|
|
|
|
689 |
torch.cuda.empty_cache() # Clear GPU memory before starting
|
690 |
with torch.no_grad():
|
691 |
layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
|
@@ -699,11 +643,14 @@ def create_demo():
|
|
699 |
@spaces.GPU(duration=20) # Reduced duration to 20 seconds
|
700 |
def refine(input_image, prompt, steps, guidance_scale):
|
701 |
try:
|
|
|
|
|
|
|
702 |
torch.cuda.empty_cache() # Clear GPU memory before starting
|
703 |
refined_img, mesh_path = refiner.refine_model(
|
704 |
input_image,
|
705 |
-
prompt,
|
706 |
-
steps,
|
707 |
guidance_scale
|
708 |
)
|
709 |
return refined_img, mesh_path, None # Return None for error message
|
|
|
19 |
except Exception as e:
|
20 |
print(f"Error detecting environment versions: {e}")
|
21 |
|
22 |
+
# Install PyTorch3D properly from source
|
23 |
+
print("Installing PyTorch3D from source...")
|
24 |
|
25 |
# First uninstall any existing PyTorch3D installation to avoid conflicts
|
26 |
os.system("pip uninstall -y pytorch3d")
|
27 |
|
28 |
+
# Install dependencies required for building PyTorch3D
|
29 |
+
os.system("apt-get update && apt-get install -y git build-essential libglib2.0-0 libsm6 libxrender-dev libxext6 ninja-build")
|
30 |
+
os.system("pip install 'imageio>=2.5.0' 'matplotlib>=3.1.2' 'numpy>=1.17.3' 'psutil>=5.6.5' 'scipy>=1.3.2' 'tqdm>=4.42.1' 'trimesh>=3.0.0'")
|
31 |
os.system("pip install fvcore iopath")
|
32 |
|
33 |
+
# Clone the PyTorch3D repository
|
34 |
+
os.system("rm -rf pytorch3d") # Remove any existing directory
|
35 |
+
os.system("git clone https://github.com/facebookresearch/pytorch3d.git")
|
36 |
|
37 |
+
# Use a specific release tag that is known to be stable
|
38 |
+
os.system("cd pytorch3d && git checkout v0.7.4")
|
|
|
39 |
|
40 |
+
# Install PyTorch3D from source with CPU support
|
41 |
+
os.system("cd pytorch3d && pip install -e .")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
# Verify the installation
|
44 |
+
import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
|
45 |
+
print(import_result)
|
46 |
|
47 |
+
# If the installation fails, try a different approach with a specific wheel
|
48 |
+
if "No module named" in import_result or "Error" in import_result:
|
49 |
+
print("Source installation failed, trying with a specific wheel...")
|
50 |
+
os.system("pip uninstall -y pytorch3d")
|
51 |
|
52 |
+
# Try with a specific wheel that's known to work
|
53 |
+
os.system("pip install https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cpu_pyt201/pytorch3d-0.7.4-cp310-cp310-linux_x86_64.whl")
|
54 |
|
55 |
+
# Verify again
|
56 |
+
import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
|
57 |
+
print(import_result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
# Patch the shap_e renderer to handle PyTorch3D renderer import error if needed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
shap_e_renderer_path = "/usr/local/lib/python3.10/site-packages/shap_e/models/stf/renderer.py"
|
61 |
if os.path.exists(shap_e_renderer_path):
|
62 |
print(f"Patching shap_e renderer at {shap_e_renderer_path}")
|
|
|
93 |
else:
|
94 |
print(f"shap_e renderer not found at {shap_e_renderer_path}")
|
95 |
|
96 |
+
# Add a helper function to ensure PyTorch3D works with ZeroGPU
|
97 |
+
def ensure_pytorch3d_cuda_compatibility():
|
98 |
+
"""
|
99 |
+
This function ensures PyTorch3D works correctly with CUDA in ZeroGPU environments.
|
100 |
+
It should be called at the beginning of any @spaces.GPU decorated function.
|
101 |
+
"""
|
102 |
+
try:
|
103 |
+
import pytorch3d
|
104 |
+
if torch.cuda.is_available():
|
105 |
+
# Check if we can access the renderer module
|
106 |
+
from pytorch3d import renderer
|
107 |
+
print("PyTorch3D renderer module is available with CUDA")
|
108 |
+
else:
|
109 |
+
print("CUDA is not available, using CPU version of PyTorch3D")
|
110 |
+
except ImportError as e:
|
111 |
+
print(f"Error importing PyTorch3D: {e}")
|
112 |
+
except Exception as e:
|
113 |
+
print(f"Unexpected error with PyTorch3D: {e}")
|
114 |
|
115 |
import torch
|
116 |
import torch.nn as nn
|
|
|
289 |
@spaces.GPU(duration=20)
|
290 |
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
|
291 |
"""Process input images and run refinement"""
|
292 |
+
# Ensure PyTorch3D works with CUDA
|
293 |
+
ensure_pytorch3d_cuda_compatibility()
|
294 |
+
|
295 |
device = pipeline.device
|
296 |
|
297 |
if isinstance(input_images, list):
|
|
|
362 |
@spaces.GPU(duration=20)
|
363 |
def create_mesh(refined_image, model, infer_config):
|
364 |
"""Generate mesh from refined image"""
|
365 |
+
# Ensure PyTorch3D works with CUDA
|
366 |
+
ensure_pytorch3d_cuda_compatibility()
|
367 |
+
|
368 |
# Convert PIL image to tensor
|
369 |
image = np.array(refined_image) / 255.0
|
370 |
image = torch.from_numpy(image).float().permute(2, 0, 1)
|
|
|
627 |
@spaces.GPU(duration=20) # Reduced duration to 20 seconds
|
628 |
def generate(prompt, guidance_scale, num_steps):
|
629 |
try:
|
630 |
+
# Ensure PyTorch3D works with CUDA
|
631 |
+
ensure_pytorch3d_cuda_compatibility()
|
632 |
+
|
633 |
torch.cuda.empty_cache() # Clear GPU memory before starting
|
634 |
with torch.no_grad():
|
635 |
layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
|
|
|
643 |
@spaces.GPU(duration=20) # Reduced duration to 20 seconds
|
644 |
def refine(input_image, prompt, steps, guidance_scale):
|
645 |
try:
|
646 |
+
# Ensure PyTorch3D works with CUDA
|
647 |
+
ensure_pytorch3d_cuda_compatibility()
|
648 |
+
|
649 |
torch.cuda.empty_cache() # Clear GPU memory before starting
|
650 |
refined_img, mesh_path = refiner.refine_model(
|
651 |
input_image,
|
652 |
+
prompt,
|
653 |
+
steps,
|
654 |
guidance_scale
|
655 |
)
|
656 |
return refined_img, mesh_path, None # Return None for error message
|