Sofia Casadei committed
Commit 0d64afb · Parent: 7d60045

install flash attention

Files changed (1): main.py (+10, −0)
main.py CHANGED

@@ -3,6 +3,7 @@ import logging
 import json
 import torch
 import asyncio
+import subprocess
 
 import gradio as gr
 import numpy as np
@@ -41,6 +42,15 @@ MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
 LANGUAGE = os.getenv("LANGUAGE", "english").lower()
 
 device = get_device(force_cpu=False)
+
+# Install Flash Attention 2 if device is "cuda"
+if device == "cuda":
+    subprocess.run(
+        ["pip", "install", "flash-attn", "--no-build-isolation"],
+        env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+        shell=True,
+    )
+
 torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
 logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
 
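A note on the added block: as committed, subprocess.run receives a list argv together with shell=True, and env= replaces the child's entire environment. With shell=True on POSIX, only the first list element ("pip") is handed to the shell as the command (the remaining items become shell positional parameters), and the bare env mapping drops PATH, so the shell may not even locate pip. A minimal sketch of the conventional form, reusing the diff's device variable; this is a suggested correction, not the code the commit ships:

```python
import os
import subprocess

# Sketch: list argv without shell=True, environment merged rather than
# replaced, and check=True so a failed install raises instead of failing
# silently. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE asks flash-attn's setup
# to skip compiling the CUDA extension and rely on a prebuilt wheel.
if device == "cuda":
    subprocess.run(
        ["pip", "install", "flash-attn", "--no-build-isolation"],
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        check=True,
    )
```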
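The commit only installs the package; nothing in this diff actually switches the model over to it. In transformers, Whisper models pick up FlashAttention-2 through the attn_implementation argument at load time. A hedged sketch of what the model-loading code elsewhere in main.py might look like, reusing MODEL_ID, device, and torch_dtype from the diff; the loading call itself is an assumption, not shown in this commit:

```python
from transformers import AutoModelForSpeechSeq2Seq

# Assumed loading code (not part of this diff): request FlashAttention-2
# on CUDA and fall back to PyTorch's SDPA attention elsewhere.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    attn_implementation="flash_attention_2" if device == "cuda" else "sdpa",
)
```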