Hammad712 committed on
Commit 0356e8f · verified · 1 Parent(s): e74caf7

Update main.py

Files changed (1)
  main.py  +15 -44
main.py CHANGED
@@ -5,13 +5,8 @@ import numpy as np
 import tempfile
 from fastapi import FastAPI, UploadFile, File, HTTPException
 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
-from librosa.sequence import dtw  # Ensure librosa==0.9.2 is installed
-
-app = FastAPI(
-    title="Quran Recitation Comparer API",
-    description="Compares two Quran recitations using a deep wav2vec2 model.",
-    version="1.0"
-)
+from librosa.sequence import dtw
+from contextlib import asynccontextmanager
 
 # --- Core Class Definition ---
 class QuranRecitationComparer:
@@ -21,7 +16,6 @@ class QuranRecitationComparer:
         """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        # Load model and processor once during initialization
         if auth_token:
             self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
             self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
@@ -31,12 +25,9 @@ class QuranRecitationComparer:
 
         self.model = self.model.to(self.device)
         self.model.eval()
-
-        # Cache for embeddings to avoid recomputation
         self.embedding_cache = {}
 
     def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
-        """Load and preprocess an audio file."""
         if not os.path.exists(file_path):
             raise FileNotFoundError(f"Audio file not found: {file_path}")
         y, sr = librosa.load(file_path, sr=target_sr)
@@ -47,7 +38,6 @@ class QuranRecitationComparer:
         return y
 
     def get_deep_embedding(self, audio, sr=16000):
-        """Extract frame-wise deep embeddings using the pretrained model."""
         input_values = self.processor(
             audio,
             sampling_rate=sr,
@@ -62,14 +52,12 @@ class QuranRecitationComparer:
         return embedding_seq
 
     def compute_dtw_distance(self, features1, features2):
-        """Compute the DTW distance between two sequences of features."""
         D, wp = dtw(X=features1, Y=features2, metric='euclidean')
         distance = D[-1, -1]
         normalized_distance = distance / len(wp)
         return normalized_distance
 
     def interpret_similarity(self, norm_distance):
-        """Interpret the normalized distance value."""
         if norm_distance == 0:
             result = "The recitations are identical based on the deep embeddings."
             score = 100
@@ -91,48 +79,45 @@ class QuranRecitationComparer:
         return result, score
 
     def get_embedding_for_file(self, file_path):
-        """Get embedding for a file, using cache if available."""
         if file_path in self.embedding_cache:
             return self.embedding_cache[file_path]
         audio = self.load_audio(file_path)
         embedding = self.get_deep_embedding(audio)
-        # Store in cache for future use
         self.embedding_cache[file_path] = embedding
         return embedding
 
     def predict(self, file_path1, file_path2):
-        """
-        Predict the similarity between two audio files.
-        Args:
-            file_path1 (str): Path to first audio file.
-            file_path2 (str): Path to second audio file.
-        Returns:
-            (float, str): Similarity score and interpretation.
-        """
         embedding1 = self.get_embedding_for_file(file_path1)
         embedding2 = self.get_embedding_for_file(file_path2)
         norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
         interpretation, similarity_score = self.interpret_similarity(norm_distance)
-        # Optionally log the results instead of printing in production
         print(f"Similarity Score: {similarity_score:.1f}/100")
         print(f"Interpretation: {interpretation}")
        return similarity_score, interpretation
 
     def clear_cache(self):
-        """Clear the embedding cache to free memory."""
         self.embedding_cache = {}
 
-# --- FastAPI Startup Event ---
-@app.on_event("startup")
-def startup_event():
+# --- Lifespan Event Handler ---
+@asynccontextmanager
+async def lifespan(app: FastAPI):
     global comparer
-    # In production, use environment variables or configuration management for tokens.
+    # Use environment variables or a secure configuration in production
     auth_token = os.environ.get("HF_TOKEN")
     comparer = QuranRecitationComparer(
         model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
         auth_token=auth_token
     )
     print("Model initialized and ready for predictions!")
+    yield
+    print("Application shutdown: Cleanup if necessary.")
+
+app = FastAPI(
+    title="Quran Recitation Comparer API",
+    description="Compares two Quran recitations using a deep wav2vec2 model.",
+    version="1.0",
+    lifespan=lifespan
+)
 
 # --- API Endpoints ---
 @app.get("/", summary="Health Check")
@@ -141,24 +126,15 @@ async def root():
 
 @app.post("/predict", summary="Compare Two Audio Files", response_model=dict)
 async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
-    """
-    Compare two uploaded audio files and return a similarity score along with an interpretation.
-
-    - **file1**: The first audio file.
-    - **file2**: The second audio file.
-    """
     tmp1_path = None
     tmp2_path = None
-
     try:
-        # Save first file to a temporary location
         suffix1 = os.path.splitext(file1.filename)[1] or ".wav"
         with tempfile.NamedTemporaryFile(delete=False, suffix=suffix1) as tmp1:
             content1 = await file1.read()
             tmp1.write(content1)
             tmp1_path = tmp1.name
 
-        # Save second file to a temporary location
         suffix2 = os.path.splitext(file2.filename)[1] or ".wav"
         with tempfile.NamedTemporaryFile(delete=False, suffix=suffix2) as tmp2:
             content2 = await file2.read()
@@ -167,11 +143,9 @@ async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
 
         similarity_score, interpretation = comparer.predict(tmp1_path, tmp2_path)
         return {"similarity_score": similarity_score, "interpretation": interpretation}
-
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
     finally:
-        # Clean up temporary files
         if tmp1_path and os.path.exists(tmp1_path):
             os.remove(tmp1_path)
         if tmp2_path and os.path.exists(tmp2_path):
@@ -179,8 +153,5 @@ async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
 
 @app.post("/clear_cache", summary="Clear Embedding Cache", response_model=dict)
 async def clear_cache():
-    """
-    Clear the embedding cache. This can help free memory if many comparisons have been made.
-    """
     comparer.clear_cache()
     return {"message": "Cache cleared."}
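The core of this change is replacing the deprecated @app.on_event("startup") hook with a lifespan context manager and moving app = FastAPI(...) below it, so the wav2vec2 model is loaded once before the first request. A minimal sketch of launching the updated module locally, assuming the file is still importable as main and that uvicorn is installed; neither the launch command nor the port comes from this commit:

# Hypothetical local launch; the commit does not specify how the service is run.
# Assumes uvicorn is installed and the updated file is importable as "main".
import uvicorn

if __name__ == "__main__":
    # The lifespan handler registered via FastAPI(lifespan=lifespan) runs here:
    # the model loads before requests are served, and the shutdown message
    # prints when the server exits.
    uvicorn.run("main:app", host="127.0.0.1", port=8000)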
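For reference, a hedged client-side sketch of calling the /predict endpoint defined above. The multipart field names file1 and file2 match the endpoint signature; the base URL and the .wav file names are placeholders, not part of the commit:

# Illustrative client for POST /predict; assumes the API is reachable at
# http://127.0.0.1:8000 and that the two .wav paths below exist locally.
import requests

with open("recitation_a.wav", "rb") as f1, open("recitation_b.wav", "rb") as f2:
    resp = requests.post(
        "http://127.0.0.1:8000/predict",
        files={
            "file1": ("recitation_a.wav", f1, "audio/wav"),
            "file2": ("recitation_b.wav", f2, "audio/wav"),
        },
        timeout=300,  # embedding extraction + DTW can take a while on CPU
    )

resp.raise_for_status()
print(resp.json())  # {"similarity_score": ..., "interpretation": ...}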