Hammad712 committed · Commit f394b62 · verified · 1 Parent(s): 3e1b72c

Update main.py

Files changed (1)
  1. main.py +7 -24
main.py CHANGED
@@ -46,16 +46,12 @@ def custom_dtw(X, Y, metric='euclidean'):
         D: Cost matrix
         wp: Warping path
     """
-    # Get sequence lengths
     n, m = len(X), len(Y)
-
-    # Initialize cost matrix
     D = np.zeros((n + 1, m + 1))
     D[0, 1:] = np.inf
     D[1:, 0] = np.inf
     D[0, 0] = 0
 
-    # Fill cost matrix
     for i in range(1, n + 1):
         for j in range(1, m + 1):
             if metric == 'euclidean':
@@ -64,7 +60,6 @@ def custom_dtw(X, Y, metric='euclidean'):
                 cost = 1 - np.dot(X[i-1], Y[j-1]) / (np.linalg.norm(X[i-1]) * np.linalg.norm(Y[j-1]))
             D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
 
-    # Backtracking
     wp = [(n, m)]
     i, j = n, m
     while i > 0 or j > 0:
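
For reference, a minimal usage sketch of custom_dtw as it stands after this change. The (D, wp) return order is taken from the docstring above, and the array shapes are purely illustrative, so treat this as an assumption rather than the file's own test code:

import numpy as np
from main import custom_dtw  # assumes main.py is importable from the working directory

# Two short feature sequences: 5 and 7 frames of 3-dimensional features
X = np.random.rand(5, 3)
Y = np.random.rand(7, 3)

D, wp = custom_dtw(X, Y, metric='cosine')  # cost matrix and warping path
print(D[-1, -1], len(wp))                  # total alignment cost and path length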
@@ -92,7 +87,6 @@ class QuranRecitationComparer:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"Using device: {self.device}")
 
-        # Load model and processor once during initialization
         try:
             if token:
                 print(f"Loading model {model_name} with token...")
@@ -105,6 +99,8 @@ class QuranRecitationComparer:
 
             self.model = self.model.to(self.device)
             self.model.eval()
+            # Set the configuration to always return hidden states
+            self.model.config.output_hidden_states = True
             print("Model loaded successfully!")
         except Exception as e:
             print(f"Error loading model: {str(e)}")
@@ -146,7 +142,8 @@ class QuranRecitationComparer:
         ).input_values.to(self.device)
 
         with torch.no_grad():
-            outputs = self.model(inputs, output_hidden_states=True)
+            # Call the model without passing output_hidden_states explicitly.
+            outputs = self.model(inputs)
 
         hidden_states = outputs.hidden_states[-1]
         embedding_seq = hidden_states.squeeze(0).cpu().numpy()
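
The two hunks above replace the per-call output_hidden_states=True argument with a one-time config flag set at load time. A standalone sketch of that Transformers pattern, assuming a Wav2Vec2-style checkpoint (the actual model class and checkpoint name used in main.py are not visible in this diff):

import torch
from transformers import AutoModelForCTC

model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.config.output_hidden_states = True   # set once, as done in __init__ above
model.eval()

dummy_audio = torch.randn(1, 16000)        # one second of fake 16 kHz input_values
with torch.no_grad():
    outputs = model(dummy_audio)           # hidden_states returned without the per-call flag
print(len(outputs.hidden_states), outputs.hidden_states[-1].shape)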
@@ -158,7 +155,6 @@ class QuranRecitationComparer:
 
     def compute_dtw_distance(self, features1, features2):
         """Compute the DTW distance between two sequences of features."""
-        # Make sure features are 2D arrays
         if features1.ndim == 1:
             features1 = features1.reshape(-1, 1)
         if features2.ndim == 1:
@@ -166,7 +162,6 @@ class QuranRecitationComparer:
 
         print(f"Feature shapes: {features1.shape}, {features2.shape}")
 
-        # Use a subsample if the sequences are too long to avoid memory issues
        max_length = 300
        if features1.shape[0] > max_length or features2.shape[0] > max_length:
            step1 = max(1, features1.shape[0] // max_length)
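
The max_length = 300 cap above bounds the size of the DTW cost matrix for long recitations. A sketch of the strided subsampling implied by step1; the slicing itself falls outside this hunk, so its exact form here is an assumption:

import numpy as np

max_length = 300
features = np.random.rand(1200, 768)            # e.g. 1200 frames of 768-dim embeddings
step = max(1, features.shape[0] // max_length)  # same rule as step1 above
subsampled = features[::step]                   # keeps roughly max_length frames
print(features.shape, "->", subsampled.shape)   # (1200, 768) -> (300, 768)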
@@ -182,7 +177,6 @@ class QuranRecitationComparer:
             return normalized_distance
         except Exception as e:
             print(f"Error in compute_dtw_distance: {str(e)}")
-            # Fallback to a basic similarity measure if DTW fails
             mean_1 = np.mean(features1, axis=0)
             mean_2 = np.mean(features2, axis=0)
             euclidean_distance = np.sqrt(np.sum((mean_1 - mean_2) ** 2))
@@ -222,7 +216,6 @@ class QuranRecitationComparer:
         audio = self.load_audio(file_path)
         embedding = self.get_deep_embedding(audio)
 
-        # Store in cache for future use
         self.embedding_cache[file_path] = embedding
         print(f"Embedding shape: {embedding.shape}")
 
@@ -234,35 +227,30 @@ class QuranRecitationComparer:
     def predict(self, file_path1, file_path2):
         """
         Predict the similarity between two audio files.
-        This method can be called repeatedly without reloading the model.
-
+
         Args:
             file_path1 (str): Path to first audio file
             file_path2 (str): Path to second audio file
-
+
         Returns:
             float: Similarity score
             str: Interpretation of similarity
         """
         print(f"Comparing {file_path1} and {file_path2}")
         try:
-            # Get embeddings (using cache if available)
             embedding1 = self.get_embedding_for_file(file_path1)
             embedding2 = self.get_embedding_for_file(file_path2)
 
-            # Compute DTW distance
             print("Computing DTW distance...")
             norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
             print(f"Normalized distance: {norm_distance}")
 
-            # Interpret results
             interpretation, similarity_score = self.interpret_similarity(norm_distance)
             print(f"Similarity score: {similarity_score}, Interpretation: {interpretation}")
 
             return similarity_score, interpretation
         except Exception as e:
             print(f"Error in predict: {str(e)}")
-            # Return a fallback response in case of error
             return 0, f"Error comparing files: {str(e)}"
 
     def clear_cache(self):
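
A hypothetical caller-side sketch of predict(); only the predict() signature is visible in this diff, so the constructor arguments below (model_name, token, and the checkpoint id) are assumptions based on the __init__ hunks above:

from main import QuranRecitationComparer

# Hypothetical instantiation: model_name and token appear in the __init__ hunks,
# but their exact signature and defaults are not shown in this diff.
comparer = QuranRecitationComparer(model_name="facebook/wav2vec2-base-960h", token=None)

score, interpretation = comparer.predict("recitation_a.wav", "recitation_b.wav")
print(score, interpretation)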
@@ -273,7 +261,6 @@ class QuranRecitationComparer:
 # Global variable for the comparer instance
 comparer = None
 
-# Use the new lifespan API
 @app.on_event("startup")
 async def startup_event():
     """Initialize the model when the application starts."""
@@ -287,7 +274,6 @@ async def startup_event():
         print("Model initialized and ready for predictions!")
     except Exception as e:
         print(f"Error initializing model: {str(e)}")
-        # Don't raise here, let the app continue to load even if model fails
 
 @app.get("/")
 async def root():
@@ -319,7 +305,6 @@ async def compare_files(
     print(f"Created temporary directory: {temp_dir}")
 
     try:
-        # Save uploaded files to temporary directory
         temp_file1 = os.path.join(temp_dir, file1.filename)
         temp_file2 = os.path.join(temp_dir, file2.filename)
 
@@ -333,7 +318,6 @@
 
         print(f"Files saved to: {temp_file1} and {temp_file2}")
 
-        # Compare the files
         similarity_score, interpretation = comparer.predict(temp_file1, temp_file2)
 
         return ComparisonResult(
@@ -346,7 +330,6 @@
         raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}")
 
     finally:
-        # Clean up temporary files
         print(f"Cleaning up temporary directory: {temp_dir}")
         shutil.rmtree(temp_dir, ignore_errors=True)
 
@@ -360,4 +343,4 @@ async def clear_cache():
     return {"message": "Embedding cache cleared successfully"}
 
 if __name__ == "__main__":
-    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
+    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
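
For completeness, a hedged client-side sketch against the compare_files endpoint in the hunks above. The route path "/compare" is an assumption (it is not visible in this diff); the file1/file2 field names come from the handler and port 7860 from the uvicorn.run line above:

import requests

with open("recitation_a.wav", "rb") as f1, open("recitation_b.wav", "rb") as f2:
    response = requests.post(
        "http://localhost:7860/compare",      # hypothetical route path
        files={"file1": f1, "file2": f2},     # field names from compare_files(...)
    )
print(response.json())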
 