YiftachEde commited on
Commit
6aac4cc
·
verified ·
1 Parent(s): 22f4ba2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -1
app.py CHANGED
@@ -8,7 +8,97 @@ import spaces
8
  import sys
9
  import torch
10
 
11
- os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  import torch
13
  import torch.nn as nn
14
  import gradio as gr
 
8
  import sys
9
  import torch
10
 
11
+ # Print debug information about the environment
12
+ try:
13
+ cuda_version = torch.version.cuda
14
+ torch_version = torch.__version__
15
+ python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
16
+ print(f"CUDA Version: {cuda_version}")
17
+ print(f"PyTorch Version: {torch_version}")
18
+ print(f"Python Version: {python_version}")
19
+ except Exception as e:
20
+ print(f"Error detecting environment versions: {e}")
21
+
22
+ # Install PyTorch3D from source to ensure compatibility
23
+ print("Installing PyTorch3D from source...")
24
+
25
+ # Install build dependencies
26
+ os.system("apt-get update && apt-get install -y git build-essential libglib2.0-0 libsm6 libxrender-dev libxext6")
27
+ os.system("pip install 'imageio>=2.5.0' 'ipywidgets>=7.5.0' 'matplotlib>=3.1.2' 'numpy>=1.17.3' 'psutil>=5.6.5' 'scipy>=1.3.2' 'tqdm>=4.42.1' 'trimesh>=3.0.0'")
28
+ os.system("pip install fvcore iopath")
29
+
30
+ # Clone and install PyTorch3D
31
+ os.system("rm -rf pytorch3d") # Remove any existing directory
32
+ os.system("git clone https://github.com/facebookresearch/pytorch3d.git")
33
+ os.system("cd pytorch3d && pip install -e .")
34
+
35
+ # Verify installation
36
+ import_result = os.popen('python -c "import pytorch3d; print(\'PyTorch3D import successful\')" 2>&1').read()
37
+ print(import_result)
38
+
39
+ # If source installation fails, try the wheel as fallback
40
+ if 'undefined symbol: _ZN3c104cuda9SetDeviceEi' in import_result or 'No module named' in import_result:
41
+ print("Source installation failed, trying pre-built wheel...")
42
+ os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu118_pyt201/download.html")
43
+
44
+ # Verify again
45
+ import_result = os.popen('python -c "import pytorch3d; print(\'PyTorch3D import successful\')" 2>&1').read()
46
+ print(import_result)
47
+
48
+ # If all installation methods fail, try CPU-only version
49
+ if 'undefined symbol: _ZN3c104cuda9SetDeviceEi' in import_result or 'No module named' in import_result:
50
+ print("All GPU installation attempts failed, falling back to CPU-only version...")
51
+ os.system("pip install pytorch3d")
52
+
53
+ # Verify again
54
+ import_result = os.popen('python -c "import pytorch3d; print(\'PyTorch3D import successful\')" 2>&1').read()
55
+ print(import_result)
56
+
57
+ # Create a workaround for the specific error if it still persists
58
+ if 'undefined symbol: _ZN3c104cuda9SetDeviceEi' in import_result:
59
+ print("Creating workaround for the specific error...")
60
+
61
+ # Create a simple wrapper module that will handle the import error
62
+ with open("pytorch3d_wrapper.py", "w") as f:
63
+ f.write("""
64
+ import torch
65
+ import sys
66
+
67
+ # Create mock PyTorch3D module
68
+ class MockPyTorch3D:
69
+ def __init__(self):
70
+ self.available = False
71
+ print("WARNING: Using mock PyTorch3D implementation due to compatibility issues")
72
+
73
+ def __getattr__(self, name):
74
+ # Return dummy functions/objects that won't crash
75
+ if name == "__path__":
76
+ return []
77
+ return self
78
+
79
+ def __call__(self, *args, **kwargs):
80
+ return self
81
+
82
+ # Try to import the real PyTorch3D
83
+ try:
84
+ import pytorch3d as real_pytorch3d
85
+ sys.modules['pytorch3d'] = real_pytorch3d
86
+ print("Successfully imported real PyTorch3D")
87
+ except ImportError as e:
88
+ if 'undefined symbol: _ZN3c104cuda9SetDeviceEi' in str(e):
89
+ # If we get the specific error, use the mock implementation
90
+ mock_module = MockPyTorch3D()
91
+ sys.modules['pytorch3d'] = mock_module
92
+ print(f"Using mock PyTorch3D due to: {e}")
93
+ else:
94
+ # For other import errors, just raise them
95
+ raise e
96
+ """)
97
+
98
+ # Use the wrapper
99
+ os.system("cp pytorch3d_wrapper.py $(python -c 'import site; print(site.getsitepackages()[0])')/pytorch3d_wrapper.py")
100
+ os.system("echo 'from pytorch3d_wrapper import *' > $(python -c 'import site; print(site.getsitepackages()[0])')/pytorch3d.py")
101
+
102
  import torch
103
  import torch.nn as nn
104
  import gradio as gr