Ankerkraut commited on
Commit
7ffa2a6
·
1 Parent(s): aff2efe
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -9,14 +9,15 @@ import torch
9
  import json
10
  import bs4
11
 
 
12
  def install_cuda_toolkit():
13
- CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
14
- # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
15
  CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
16
  subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
17
  subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
18
  subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
19
-
20
  os.environ["CUDA_HOME"] = "/usr/local/cuda"
21
  os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
22
  os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
@@ -26,7 +27,10 @@ def install_cuda_toolkit():
26
  # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
27
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
28
 
29
- install_cuda_toolkkit()
 
 
 
30
 
31
  product_strings = []
32
  with open('./Data/products.json', 'r', encoding='utf-8') as f:
 
9
  import json
10
  import bs4
11
 
12
+ @spaces.GPU
13
  def install_cuda_toolkit():
14
+ # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
15
+ CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
16
  CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
17
  subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
18
  subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
19
  subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
20
+
21
  os.environ["CUDA_HOME"] = "/usr/local/cuda"
22
  os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
23
  os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
 
27
  # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
28
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
29
 
30
+ install_cuda_toolkit()
31
+
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+
34
 
35
  product_strings = []
36
  with open('./Data/products.json', 'r', encoding='utf-8') as f: