lucas-ventura committed on
Commit
b7915ab
·
verified ·
1 Parent(s): e27182c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import tempfile
3
  from pathlib import Path
 
4
 
5
  import gradio as gr
6
  import spaces
@@ -14,6 +15,7 @@ from src.models.llama_inference import inference
14
  from src.test.vidchapters import get_chapters
15
  from tools.download.models import download_base_model, download_model
16
 
 
17
  # Set up proxies
18
  # from urllib.request import getproxies
19
  # proxies = getproxies()
@@ -29,6 +31,23 @@ inference_model = None
29
 
30
  LLAMA_CKPT_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  @spaces.GPU
33
  def load_base_model():
34
  """Load the base Llama model and tokenizer once at startup."""
 
1
  import os
2
  import tempfile
3
  from pathlib import Path
4
+ import subprocess
5
 
6
  import gradio as gr
7
  import spaces
 
15
  from src.test.vidchapters import get_chapters
16
  from tools.download.models import download_base_model, download_model
17
 
18
+
19
  # Set up proxies
20
  # from urllib.request import getproxies
21
  # proxies = getproxies()
 
31
 
32
  LLAMA_CKPT_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
33
 
34
def install_cudnn():
    """Install pinned cuDNN 8 packages, link the runtime library, and enable TF32.

    Steps, executed in order; the first failing step aborts the rest:

    1. ``apt-get update`` followed by installation of ``libcudnn8`` and
       ``libcudnn8-dev`` pinned to the same version string.
    2. Create the ``libcudnn_ops_infer.so.8`` symlink next to
       ``libcudnn_ops_infer.so``.
    3. Turn on the TF32 flags of ``torch.backends`` in *this* interpreter.

    The function is best-effort: failures are reported with ``print`` and
    never raised, so importing the module does not crash when the system
    packages cannot be installed.
    """
    cudnn_version = "8.9.2.26-1+cuda12.1"
    lib_dir = "/usr/lib/x86_64-linux-gnu"
    commands = [
        ["apt-get", "update"],
        ["apt-get", "install", "-y", f"libcudnn8={cudnn_version}"],
        ["apt-get", "install", "-y", f"libcudnn8-dev={cudnn_version}"],
        # -f replaces an existing link, so re-running this function (e.g. on
        # an application restart) does not fail with "File exists".
        [
            "ln",
            "-sf",
            f"{lib_dir}/libcudnn_ops_infer.so",
            f"{lib_dir}/libcudnn_ops_infer.so.8",
        ],
    ]

    try:
        for command in commands:
            subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error during cuDNN installation: {e}")
        return
    except FileNotFoundError as e:
        # Raised by subprocess when the executable itself cannot be found.
        print(f"Required executable not found: {e.filename}. Ensure it is in your PATH.")
        return

    # The TF32 switches are per-process state. Setting them through a child
    # `python -c ...` process would be lost as soon as that child exits, so
    # they are set directly in the current interpreter instead.
    try:
        import torch
    except ImportError as e:
        print(f"cuDNN installed, but torch could not be imported to enable TF32: {e}")
        return

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print("cuDNN installation and configuration successful.")
48
+
49
+ install_cudnn()
50
+
51
  @spaces.GPU
52
  def load_base_model():
53
  """Load the base Llama model and tokenizer once at startup."""