Hammad712 committed on
Commit
3e1b72c
·
verified ·
1 Parent(s): b7062bf

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +127 -61
main.py CHANGED
@@ -4,22 +4,35 @@ import torch
4
  import librosa
5
  import numpy as np
6
  import os
7
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
8
  import tempfile
9
  import shutil
10
  import uvicorn
11
- import scipy.spatial.distance as distance
 
 
 
 
12
 
13
  # Load environment variables
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
 
16
  app = FastAPI(title="Quran Recitation Comparer API")
17
 
 
 
 
 
 
 
 
 
 
18
  class ComparisonResult(BaseModel):
19
  similarity_score: float
20
  interpretation: str
21
 
22
- # Custom implementation of DTW to replace librosa.sequence.dtw
23
  def custom_dtw(X, Y, metric='euclidean'):
24
  """
25
  Custom Dynamic Time Warping implementation.
@@ -80,23 +93,27 @@ class QuranRecitationComparer:
80
  print(f"Using device: {self.device}")
81
 
82
  # Load model and processor once during initialization
83
- if token:
84
- print(f"Loading model {model_name} with token...")
85
- self.processor = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token)
86
- self.model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token)
87
- else:
88
- print(f"Loading model {model_name} without token...")
89
- self.processor = Wav2Vec2Processor.from_pretrained(model_name)
90
- self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
91
-
92
- self.model = self.model.to(self.device)
93
- self.model.eval()
 
 
 
 
 
94
 
95
  # Cache for embeddings to avoid recomputation
96
  self.embedding_cache = {}
97
- print("Model loaded successfully!")
98
 
99
- def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
100
  """Load and preprocess an audio file."""
101
  if not os.path.exists(file_path):
102
  raise FileNotFoundError(f"Audio file not found: {file_path}")
@@ -107,34 +124,69 @@ class QuranRecitationComparer:
107
  if normalize:
108
  y = librosa.util.normalize(y)
109
 
110
- if trim_silence:
111
- # Use librosa.effects.trim which should be available in most versions
112
- y, _ = librosa.effects.trim(y, top_db=30)
113
-
 
 
 
 
 
 
114
  return y
115
 
116
  def get_deep_embedding(self, audio, sr=16000):
117
  """Extract frame-wise deep embeddings using the pretrained model."""
118
- input_values = self.processor(
119
- audio,
120
- sampling_rate=sr,
121
- return_tensors="pt"
122
- ).input_values.to(self.device)
123
-
124
- with torch.no_grad():
125
- outputs = self.model(input_values, output_hidden_states=True)
126
-
127
- hidden_states = outputs.hidden_states[-1]
128
- embedding_seq = hidden_states.squeeze(0).cpu().numpy()
129
-
130
- return embedding_seq
 
 
 
 
131
 
132
  def compute_dtw_distance(self, features1, features2):
133
  """Compute the DTW distance between two sequences of features."""
134
- D, wp = custom_dtw(X=features1, Y=features2, metric='euclidean')
135
- distance = D[-1, -1]
136
- normalized_distance = distance / len(wp)
137
- return normalized_distance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  def interpret_similarity(self, norm_distance):
140
  """Interpret the normalized distance value."""
@@ -166,14 +218,18 @@ class QuranRecitationComparer:
166
  return self.embedding_cache[file_path]
167
 
168
  print(f"Computing new embedding for {file_path}")
169
- audio = self.load_audio(file_path)
170
- embedding = self.get_deep_embedding(audio)
171
-
172
- # Store in cache for future use
173
- self.embedding_cache[file_path] = embedding
174
- print(f"Embedding shape: {embedding.shape}")
175
-
176
- return embedding
 
 
 
 
177
 
178
  def predict(self, file_path1, file_path2):
179
  """
@@ -189,20 +245,25 @@ class QuranRecitationComparer:
189
  str: Interpretation of similarity
190
  """
191
  print(f"Comparing {file_path1} and {file_path2}")
192
- # Get embeddings (using cache if available)
193
- embedding1 = self.get_embedding_for_file(file_path1)
194
- embedding2 = self.get_embedding_for_file(file_path2)
195
-
196
- # Compute DTW distance
197
- print("Computing DTW distance...")
198
- norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
199
- print(f"Normalized distance: {norm_distance}")
200
-
201
- # Interpret results
202
- interpretation, similarity_score = self.interpret_similarity(norm_distance)
203
- print(f"Similarity score: {similarity_score}, Interpretation: {interpretation}")
204
-
205
- return similarity_score, interpretation
 
 
 
 
 
206
 
207
  def clear_cache(self):
208
  """Clear the embedding cache to free memory."""
@@ -212,6 +273,7 @@ class QuranRecitationComparer:
212
  # Global variable for the comparer instance
213
  comparer = None
214
 
 
215
  @app.on_event("startup")
216
  async def startup_event():
217
  """Initialize the model when the application starts."""
@@ -225,12 +287,16 @@ async def startup_event():
225
  print("Model initialized and ready for predictions!")
226
  except Exception as e:
227
  print(f"Error initializing model: {str(e)}")
228
- raise
229
 
230
  @app.get("/")
231
  async def root():
232
  """Root endpoint to check if the API is running."""
233
- return {"message": "Quran Recitation Comparer API is running", "status": "active"}
 
 
 
 
234
 
235
  @app.post("/compare", response_model=ComparisonResult)
236
  async def compare_files(
 
4
  import librosa
5
  import numpy as np
6
  import os
7
+ from transformers import AutoProcessor, AutoModelForCTC
8
  import tempfile
9
  import shutil
10
  import uvicorn
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ import warnings
13
+
14
+ # Suppress all warnings (the filter below is not limited to deprecation warnings)
15
+ warnings.filterwarnings("ignore")
16
 
17
  # Load environment variables
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
  app = FastAPI(title="Quran Recitation Comparer API")
21
 
22
+ # Add CORS middleware
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
  class ComparisonResult(BaseModel):
32
  similarity_score: float
33
  interpretation: str
34
 
35
+ # Custom implementation of DTW
36
  def custom_dtw(X, Y, metric='euclidean'):
37
  """
38
  Custom Dynamic Time Warping implementation.
 
93
  print(f"Using device: {self.device}")
94
 
95
  # Load model and processor once during initialization
96
+ try:
97
+ if token:
98
+ print(f"Loading model {model_name} with token...")
99
+ self.processor = AutoProcessor.from_pretrained(model_name, token=token)
100
+ self.model = AutoModelForCTC.from_pretrained(model_name, token=token)
101
+ else:
102
+ print(f"Loading model {model_name} without token...")
103
+ self.processor = AutoProcessor.from_pretrained(model_name)
104
+ self.model = AutoModelForCTC.from_pretrained(model_name)
105
+
106
+ self.model = self.model.to(self.device)
107
+ self.model.eval()
108
+ print("Model loaded successfully!")
109
+ except Exception as e:
110
+ print(f"Error loading model: {str(e)}")
111
+ raise
112
 
113
  # Cache for embeddings to avoid recomputation
114
  self.embedding_cache = {}
 
115
 
116
+ def load_audio(self, file_path, target_sr=16000, normalize=True):
117
  """Load and preprocess an audio file."""
118
  if not os.path.exists(file_path):
119
  raise FileNotFoundError(f"Audio file not found: {file_path}")
 
124
  if normalize:
125
  y = librosa.util.normalize(y)
126
 
127
+ # Remove silence using a simplified approach: drop every sample whose absolute amplitude is at or below the threshold, anywhere in the signal
128
+ trim_y = []
129
+ threshold = 0.02 # Threshold for silence detection
130
+ for i in range(len(y)):
131
+ if abs(y[i]) > threshold:
132
+ trim_y.append(y[i])
133
+
134
+ if len(trim_y) > 0:
135
+ y = np.array(trim_y)
136
+
137
  return y
138
 
139
  def get_deep_embedding(self, audio, sr=16000):
140
  """Extract frame-wise deep embeddings using the pretrained model."""
141
+ try:
142
+ inputs = self.processor(
143
+ audio,
144
+ sampling_rate=sr,
145
+ return_tensors="pt"
146
+ ).input_values.to(self.device)
147
+
148
+ with torch.no_grad():
149
+ outputs = self.model(inputs, output_hidden_states=True)
150
+
151
+ hidden_states = outputs.hidden_states[-1]
152
+ embedding_seq = hidden_states.squeeze(0).cpu().numpy()
153
+
154
+ return embedding_seq
155
+ except Exception as e:
156
+ print(f"Error in get_deep_embedding: {str(e)}")
157
+ raise
158
 
159
  def compute_dtw_distance(self, features1, features2):
160
  """Compute the DTW distance between two sequences of features."""
161
+ # Make sure features are 2D arrays
162
+ if features1.ndim == 1:
163
+ features1 = features1.reshape(-1, 1)
164
+ if features2.ndim == 1:
165
+ features2 = features2.reshape(-1, 1)
166
+
167
+ print(f"Feature shapes: {features1.shape}, {features2.shape}")
168
+
169
+ # Use a subsample if the sequences are too long to avoid memory issues
170
+ max_length = 300
171
+ if features1.shape[0] > max_length or features2.shape[0] > max_length:
172
+ step1 = max(1, features1.shape[0] // max_length)
173
+ step2 = max(1, features2.shape[0] // max_length)
174
+ features1 = features1[::step1]
175
+ features2 = features2[::step2]
176
+ print(f"Subsampled feature shapes: {features1.shape}, {features2.shape}")
177
+
178
+ try:
179
+ D, wp = custom_dtw(X=features1, Y=features2, metric='euclidean')
180
+ distance = D[-1, -1]
181
+ normalized_distance = distance / len(wp)
182
+ return normalized_distance
183
+ except Exception as e:
184
+ print(f"Error in compute_dtw_distance: {str(e)}")
185
+ # Fall back to a basic distance measure (Euclidean distance between the mean feature vectors) if DTW fails
186
+ mean_1 = np.mean(features1, axis=0)
187
+ mean_2 = np.mean(features2, axis=0)
188
+ euclidean_distance = np.sqrt(np.sum((mean_1 - mean_2) ** 2))
189
+ return euclidean_distance
190
 
191
  def interpret_similarity(self, norm_distance):
192
  """Interpret the normalized distance value."""
 
218
  return self.embedding_cache[file_path]
219
 
220
  print(f"Computing new embedding for {file_path}")
221
+ try:
222
+ audio = self.load_audio(file_path)
223
+ embedding = self.get_deep_embedding(audio)
224
+
225
+ # Store in cache for future use
226
+ self.embedding_cache[file_path] = embedding
227
+ print(f"Embedding shape: {embedding.shape}")
228
+
229
+ return embedding
230
+ except Exception as e:
231
+ print(f"Error getting embedding: {str(e)}")
232
+ raise
233
 
234
  def predict(self, file_path1, file_path2):
235
  """
 
245
  str: Interpretation of similarity
246
  """
247
  print(f"Comparing {file_path1} and {file_path2}")
248
+ try:
249
+ # Get embeddings (using cache if available)
250
+ embedding1 = self.get_embedding_for_file(file_path1)
251
+ embedding2 = self.get_embedding_for_file(file_path2)
252
+
253
+ # Compute DTW distance
254
+ print("Computing DTW distance...")
255
+ norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
256
+ print(f"Normalized distance: {norm_distance}")
257
+
258
+ # Interpret results
259
+ interpretation, similarity_score = self.interpret_similarity(norm_distance)
260
+ print(f"Similarity score: {similarity_score}, Interpretation: {interpretation}")
261
+
262
+ return similarity_score, interpretation
263
+ except Exception as e:
264
+ print(f"Error in predict: {str(e)}")
265
+ # Return a fallback response in case of error
266
+ return 0, f"Error comparing files: {str(e)}"
267
 
268
  def clear_cache(self):
269
  """Clear the embedding cache to free memory."""
 
273
  # Global variable for the comparer instance
274
  comparer = None
275
 
276
+ # NOTE: the handler below is still registered with on_event("startup"); the lifespan API is not used yet
277
  @app.on_event("startup")
278
  async def startup_event():
279
  """Initialize the model when the application starts."""
 
287
  print("Model initialized and ready for predictions!")
288
  except Exception as e:
289
  print(f"Error initializing model: {str(e)}")
290
+ # Don't raise here; let the app continue to load even if the model fails to initialize
291
 
292
  @app.get("/")
293
  async def root():
294
  """Root endpoint to check if the API is running."""
295
+ if comparer:
296
+ status = "active"
297
+ else:
298
+ status = "model not loaded"
299
+ return {"message": "Quran Recitation Comparer API is running", "status": status}
300
 
301
  @app.post("/compare", response_model=ComparisonResult)
302
  async def compare_files(