Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
import os
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import librosa
|
4 |
import numpy as np
|
@@ -8,24 +11,16 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
|
8 |
from librosa.sequence import dtw
|
9 |
from contextlib import asynccontextmanager
|
10 |
|
11 |
-
os.environ["NUMBA_CACHE_DIR"] = "/tmp" # Ensure Numba caching works in container environments
|
12 |
-
|
13 |
-
|
14 |
# --- Core Class Definition ---
|
15 |
class QuranRecitationComparer:
|
16 |
def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
|
17 |
-
"""
|
18 |
-
Initialize the Quran recitation comparer with a specific Wav2Vec2 model.
|
19 |
-
"""
|
20 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
-
|
22 |
if auth_token:
|
23 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
|
24 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
|
25 |
else:
|
26 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
|
27 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
28 |
-
|
29 |
self.model = self.model.to(self.device)
|
30 |
self.model.eval()
|
31 |
self.embedding_cache = {}
|
@@ -46,10 +41,8 @@ class QuranRecitationComparer:
|
|
46 |
sampling_rate=sr,
|
47 |
return_tensors="pt"
|
48 |
).input_values.to(self.device)
|
49 |
-
|
50 |
with torch.no_grad():
|
51 |
outputs = self.model(input_values, output_hidden_states=True)
|
52 |
-
|
53 |
hidden_states = outputs.hidden_states[-1]
|
54 |
embedding_seq = hidden_states.squeeze(0).cpu().numpy()
|
55 |
return embedding_seq
|
@@ -105,7 +98,6 @@ class QuranRecitationComparer:
|
|
105 |
@asynccontextmanager
|
106 |
async def lifespan(app: FastAPI):
|
107 |
global comparer
|
108 |
-
# Use environment variables or a secure configuration in production
|
109 |
auth_token = os.environ.get("HF_TOKEN")
|
110 |
comparer = QuranRecitationComparer(
|
111 |
model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
|
|
|
1 |
import os
|
2 |
+
os.environ["NUMBA_CACHE_DIR"] = "/tmp" # Point Numba's on-disk cache at /tmp so it lands in a writable directory; presumably must be set before librosa is imported below - confirm
|
3 |
+
os.environ["NUMBA_DISABLE_CACHE"] = "1" # Intended to disable Numba caching to avoid errors. NOTE(review): NUMBA_DISABLE_CACHE does not appear among Numba's documented environment variables - confirm it has any effect
|
4 |
+
|
5 |
import torch
|
6 |
import librosa
|
7 |
import numpy as np
|
|
|
11 |
from librosa.sequence import dtw
|
12 |
from contextlib import asynccontextmanager
|
13 |
|
|
|
|
|
|
|
14 |
# --- Core Class Definition ---
|
15 |
class QuranRecitationComparer:
|
16 |
def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
|
|
|
|
|
|
|
17 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
18 |
if auth_token:
|
19 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
|
20 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
|
21 |
else:
|
22 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
|
23 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
|
|
24 |
self.model = self.model.to(self.device)
|
25 |
self.model.eval()
|
26 |
self.embedding_cache = {}
|
|
|
41 |
sampling_rate=sr,
|
42 |
return_tensors="pt"
|
43 |
).input_values.to(self.device)
|
|
|
44 |
with torch.no_grad():
|
45 |
outputs = self.model(input_values, output_hidden_states=True)
|
|
|
46 |
hidden_states = outputs.hidden_states[-1]
|
47 |
embedding_seq = hidden_states.squeeze(0).cpu().numpy()
|
48 |
return embedding_seq
|
|
|
98 |
@asynccontextmanager
|
99 |
async def lifespan(app: FastAPI):
|
100 |
global comparer
|
|
|
101 |
auth_token = os.environ.get("HF_TOKEN")
|
102 |
comparer = QuranRecitationComparer(
|
103 |
model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
|