YiftachEde commited on
Commit
12b3742
·
verified ·
1 Parent(s): 85865a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -110
app.py CHANGED
@@ -19,124 +19,44 @@ try:
19
  except Exception as e:
20
  print(f"Error detecting environment versions: {e}")
21
 
22
- # Install PyTorch3D with CPU support for ZeroGPU environment
23
- print("Installing PyTorch3D for CPU/ZeroGPU environment...")
24
 
25
  # First uninstall any existing PyTorch3D installation to avoid conflicts
26
  os.system("pip uninstall -y pytorch3d")
27
 
28
- # Install dependencies required for PyTorch3D
 
 
29
  os.system("pip install fvcore iopath")
30
 
31
- # For ZeroGPU, we need a CPU-only version that still has the renderer module
32
- # Try installing from a pre-built wheel first (faster)
33
- os.system("pip install pytorch3d")
34
 
35
- # Verify if the renderer module is available
36
- import_result = os.popen('python -c "import pytorch3d; print(\'renderer\' in dir(pytorch3d))" 2>&1').read().strip()
37
- print(f"Renderer module check: {import_result}")
38
 
39
- # If the renderer module is not available, create it manually
40
- if import_result == "False" or "Error" in import_result:
41
- print("Renderer module not found, creating it manually...")
42
-
43
- # Get the site-packages directory
44
- site_packages = os.popen('python -c "import site; print(site.getsitepackages()[0])"').read().strip()
45
-
46
- # Check if the pytorch3d directory exists
47
- pytorch3d_dir = f"{site_packages}/pytorch3d"
48
- if not os.path.exists(pytorch3d_dir):
49
- print(f"PyTorch3D directory not found at {pytorch3d_dir}, creating it...")
50
- os.system(f"mkdir -p {pytorch3d_dir}")
51
-
52
- # Create an __init__.py file in the pytorch3d directory
53
- with open(f"{pytorch3d_dir}/__init__.py", "w") as f:
54
- f.write("""
55
- # PyTorch3D module
56
- import torch
57
- import warnings
58
 
59
- warnings.warn("Using custom PyTorch3D module")
 
 
60
 
61
- __version__ = "0.7.4"
62
- """)
 
 
63
 
64
- # Create the renderer directory if it doesn't exist
65
- os.system(f"mkdir -p {pytorch3d_dir}/renderer")
66
 
67
- # Create an __init__.py file in the renderer directory
68
- with open(f"{pytorch3d_dir}/renderer/__init__.py", "w") as f:
69
- f.write("""
70
- # PyTorch3D renderer module
71
- import torch
72
- import warnings
73
-
74
- warnings.warn("Using custom PyTorch3D renderer module")
75
-
76
- # Basic renderer components
77
- class TexturesBase:
78
- def __init__(self, *args, **kwargs):
79
- pass
80
-
81
- class TexturesVertex(TexturesBase):
82
- def __init__(self, *args, **kwargs):
83
- super().__init__(*args, **kwargs)
84
-
85
- class TexturesUV(TexturesBase):
86
- def __init__(self, *args, **kwargs):
87
- super().__init__(*args, **kwargs)
88
-
89
- class TexturesAtlas(TexturesBase):
90
- def __init__(self, *args, **kwargs):
91
- super().__init__(*args, **kwargs)
92
-
93
- class RasterizationSettings:
94
- def __init__(self, *args, **kwargs):
95
- pass
96
-
97
- class MeshRasterizer:
98
- def __init__(self, *args, **kwargs):
99
- pass
100
-
101
- class PointLights:
102
- def __init__(self, *args, **kwargs):
103
- pass
104
-
105
- class Materials:
106
- def __init__(self, *args, **kwargs):
107
- pass
108
 
109
- class MeshRenderer:
110
- def __init__(self, *args, **kwargs):
111
- pass
112
-
113
- class SoftPhongShader:
114
- def __init__(self, *args, **kwargs):
115
- pass
116
-
117
- class HardPhongShader:
118
- def __init__(self, *args, **kwargs):
119
- pass
120
-
121
- class SoftSilhouetteShader:
122
- def __init__(self, *args, **kwargs):
123
- pass
124
-
125
- class BlendParams:
126
- def __init__(self, *args, **kwargs):
127
- pass
128
-
129
- def look_at_view_transform(*args, **kwargs):
130
- return torch.eye(4), torch.eye(4)
131
-
132
- class FoVPerspectiveCameras:
133
- def __init__(self, *args, **kwargs):
134
- pass
135
- """)
136
-
137
- print("Created custom renderer module")
138
-
139
- # Patch the shap_e renderer to handle PyTorch3D renderer import error
140
  shap_e_renderer_path = "/usr/local/lib/python3.10/site-packages/shap_e/models/stf/renderer.py"
141
  if os.path.exists(shap_e_renderer_path):
142
  print(f"Patching shap_e renderer at {shap_e_renderer_path}")
@@ -173,9 +93,24 @@ if os.path.exists(shap_e_renderer_path):
173
  else:
174
  print(f"shap_e renderer not found at {shap_e_renderer_path}")
175
 
176
- # Verify the installation again
177
- import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
178
- print(import_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  import torch
181
  import torch.nn as nn
@@ -354,6 +289,9 @@ def load_models():
354
  @spaces.GPU(duration=20)
355
  def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
356
  """Process input images and run refinement"""
 
 
 
357
  device = pipeline.device
358
 
359
  if isinstance(input_images, list):
@@ -424,6 +362,9 @@ def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=
424
  @spaces.GPU(duration=20)
425
  def create_mesh(refined_image, model, infer_config):
426
  """Generate mesh from refined image"""
 
 
 
427
  # Convert PIL image to tensor
428
  image = np.array(refined_image) / 255.0
429
  image = torch.from_numpy(image).float().permute(2, 0, 1)
@@ -686,6 +627,9 @@ def create_demo():
686
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
687
  def generate(prompt, guidance_scale, num_steps):
688
  try:
 
 
 
689
  torch.cuda.empty_cache() # Clear GPU memory before starting
690
  with torch.no_grad():
691
  layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
@@ -699,11 +643,14 @@ def create_demo():
699
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
700
  def refine(input_image, prompt, steps, guidance_scale):
701
  try:
 
 
 
702
  torch.cuda.empty_cache() # Clear GPU memory before starting
703
  refined_img, mesh_path = refiner.refine_model(
704
  input_image,
705
- prompt,
706
- steps,
707
  guidance_scale
708
  )
709
  return refined_img, mesh_path, None # Return None for error message
 
19
  except Exception as e:
20
  print(f"Error detecting environment versions: {e}")
21
 
22
+ # Install PyTorch3D properly from source
23
+ print("Installing PyTorch3D from source...")
24
 
25
  # First uninstall any existing PyTorch3D installation to avoid conflicts
26
  os.system("pip uninstall -y pytorch3d")
27
 
28
+ # Install dependencies required for building PyTorch3D
29
+ os.system("apt-get update && apt-get install -y git build-essential libglib2.0-0 libsm6 libxrender-dev libxext6 ninja-build")
30
+ os.system("pip install 'imageio>=2.5.0' 'matplotlib>=3.1.2' 'numpy>=1.17.3' 'psutil>=5.6.5' 'scipy>=1.3.2' 'tqdm>=4.42.1' 'trimesh>=3.0.0'")
31
  os.system("pip install fvcore iopath")
32
 
33
+ # Clone the PyTorch3D repository
34
+ os.system("rm -rf pytorch3d") # Remove any existing directory
35
+ os.system("git clone https://github.com/facebookresearch/pytorch3d.git")
36
 
37
+ # Use a specific release tag that is known to be stable
38
+ os.system("cd pytorch3d && git checkout v0.7.4")
 
39
 
40
+ # Install PyTorch3D from source with CPU support
41
+ os.system("cd pytorch3d && pip install -e .")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Verify the installation
44
+ import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
45
+ print(import_result)
46
 
47
+ # If the installation fails, try a different approach with a specific wheel
48
+ if "No module named" in import_result or "Error" in import_result:
49
+ print("Source installation failed, trying with a specific wheel...")
50
+ os.system("pip uninstall -y pytorch3d")
51
 
52
+ # Try with a specific wheel that's known to work
53
+ os.system("pip install https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cpu_pyt201/pytorch3d-0.7.4-cp310-cp310-linux_x86_64.whl")
54
 
55
+ # Verify again
56
+ import_result = os.popen('python -c "import pytorch3d; from pytorch3d import renderer; print(\'PyTorch3D and renderer successfully imported\')" 2>&1').read()
57
+ print(import_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # Patch the shap_e renderer to handle PyTorch3D renderer import error if needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  shap_e_renderer_path = "/usr/local/lib/python3.10/site-packages/shap_e/models/stf/renderer.py"
61
  if os.path.exists(shap_e_renderer_path):
62
  print(f"Patching shap_e renderer at {shap_e_renderer_path}")
 
93
  else:
94
  print(f"shap_e renderer not found at {shap_e_renderer_path}")
95
 
96
+ # Add a helper function to ensure PyTorch3D works with ZeroGPU
97
+ def ensure_pytorch3d_cuda_compatibility():
98
+ """
99
+ This function ensures PyTorch3D works correctly with CUDA in ZeroGPU environments.
100
+ It should be called at the beginning of any @spaces.GPU decorated function.
101
+ """
102
+ try:
103
+ import pytorch3d
104
+ if torch.cuda.is_available():
105
+ # Check if we can access the renderer module
106
+ from pytorch3d import renderer
107
+ print("PyTorch3D renderer module is available with CUDA")
108
+ else:
109
+ print("CUDA is not available, using CPU version of PyTorch3D")
110
+ except ImportError as e:
111
+ print(f"Error importing PyTorch3D: {e}")
112
+ except Exception as e:
113
+ print(f"Unexpected error with PyTorch3D: {e}")
114
 
115
  import torch
116
  import torch.nn as nn
 
289
  @spaces.GPU(duration=20)
290
  def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
291
  """Process input images and run refinement"""
292
+ # Ensure PyTorch3D works with CUDA
293
+ ensure_pytorch3d_cuda_compatibility()
294
+
295
  device = pipeline.device
296
 
297
  if isinstance(input_images, list):
 
362
  @spaces.GPU(duration=20)
363
  def create_mesh(refined_image, model, infer_config):
364
  """Generate mesh from refined image"""
365
+ # Ensure PyTorch3D works with CUDA
366
+ ensure_pytorch3d_cuda_compatibility()
367
+
368
  # Convert PIL image to tensor
369
  image = np.array(refined_image) / 255.0
370
  image = torch.from_numpy(image).float().permute(2, 0, 1)
 
627
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
628
  def generate(prompt, guidance_scale, num_steps):
629
  try:
630
+ # Ensure PyTorch3D works with CUDA
631
+ ensure_pytorch3d_cuda_compatibility()
632
+
633
  torch.cuda.empty_cache() # Clear GPU memory before starting
634
  with torch.no_grad():
635
  layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
 
643
  @spaces.GPU(duration=20) # Reduced duration to 20 seconds
644
  def refine(input_image, prompt, steps, guidance_scale):
645
  try:
646
+ # Ensure PyTorch3D works with CUDA
647
+ ensure_pytorch3d_cuda_compatibility()
648
+
649
  torch.cuda.empty_cache() # Clear GPU memory before starting
650
  refined_img, mesh_path = refiner.refine_model(
651
  input_image,
652
+ prompt,
653
+ steps,
654
  guidance_scale
655
  )
656
  return refined_img, mesh_path, None # Return None for error message