Hammad712 commited on
Commit
d7fd2ab
·
verified ·
1 Parent(s): d8e677e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -11
main.py CHANGED
@@ -1,4 +1,7 @@
1
  import os
 
 
 
2
  import torch
3
  import librosa
4
  import numpy as np
@@ -8,24 +11,16 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
8
  from librosa.sequence import dtw
9
  from contextlib import asynccontextmanager
10
 
11
- os.environ["NUMBA_CACHE_DIR"] = "/tmp" # Ensure Numba caching works in container environments
12
-
13
-
14
  # --- Core Class Definition ---
15
  class QuranRecitationComparer:
16
  def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
17
- """
18
- Initialize the Quran recitation comparer with a specific Wav2Vec2 model.
19
- """
20
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
-
22
  if auth_token:
23
  self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
24
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
25
  else:
26
  self.processor = Wav2Vec2Processor.from_pretrained(model_name)
27
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
28
-
29
  self.model = self.model.to(self.device)
30
  self.model.eval()
31
  self.embedding_cache = {}
@@ -46,10 +41,8 @@ class QuranRecitationComparer:
46
  sampling_rate=sr,
47
  return_tensors="pt"
48
  ).input_values.to(self.device)
49
-
50
  with torch.no_grad():
51
  outputs = self.model(input_values, output_hidden_states=True)
52
-
53
  hidden_states = outputs.hidden_states[-1]
54
  embedding_seq = hidden_states.squeeze(0).cpu().numpy()
55
  return embedding_seq
@@ -105,7 +98,6 @@ class QuranRecitationComparer:
105
  @asynccontextmanager
106
  async def lifespan(app: FastAPI):
107
  global comparer
108
- # Use environment variables or a secure configuration in production
109
  auth_token = os.environ.get("HF_TOKEN")
110
  comparer = QuranRecitationComparer(
111
  model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
 
1
  import os
2
+ os.environ["NUMBA_CACHE_DIR"] = "/tmp" # Ensure a writable cache directory
3
+ os.environ["NUMBA_DISABLE_CACHE"] = "1" # Disable Numba caching to avoid errors
4
+
5
  import torch
6
  import librosa
7
  import numpy as np
 
11
  from librosa.sequence import dtw
12
  from contextlib import asynccontextmanager
13
 
 
 
 
14
  # --- Core Class Definition ---
15
  class QuranRecitationComparer:
16
  def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
 
 
 
17
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
18
  if auth_token:
19
  self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
20
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
21
  else:
22
  self.processor = Wav2Vec2Processor.from_pretrained(model_name)
23
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
24
  self.model = self.model.to(self.device)
25
  self.model.eval()
26
  self.embedding_cache = {}
 
41
  sampling_rate=sr,
42
  return_tensors="pt"
43
  ).input_values.to(self.device)
 
44
  with torch.no_grad():
45
  outputs = self.model(input_values, output_hidden_states=True)
 
46
  hidden_states = outputs.hidden_states[-1]
47
  embedding_seq = hidden_states.squeeze(0).cpu().numpy()
48
  return embedding_seq
 
98
  @asynccontextmanager
99
  async def lifespan(app: FastAPI):
100
  global comparer
 
101
  auth_token = os.environ.get("HF_TOKEN")
102
  comparer = QuranRecitationComparer(
103
  model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",