Update main.py
main.py
CHANGED
@@ -46,16 +46,12 @@ def custom_dtw(X, Y, metric='euclidean'):
         D: Cost matrix
         wp: Warping path
     """
-    # Get sequence lengths
     n, m = len(X), len(Y)
-
-    # Initialize cost matrix
     D = np.zeros((n + 1, m + 1))
     D[0, 1:] = np.inf
     D[1:, 0] = np.inf
     D[0, 0] = 0
 
-    # Fill cost matrix
     for i in range(1, n + 1):
         for j in range(1, m + 1):
             if metric == 'euclidean':
@@ -64,7 +60,6 @@ def custom_dtw(X, Y, metric='euclidean'):
                 cost = 1 - np.dot(X[i-1], Y[j-1]) / (np.linalg.norm(X[i-1]) * np.linalg.norm(Y[j-1]))
             D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
 
-    # Backtracking
     wp = [(n, m)]
     i, j = n, m
     while i > 0 or j > 0:
@@ -92,7 +87,6 @@ class QuranRecitationComparer:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"Using device: {self.device}")
 
-        # Load model and processor once during initialization
         try:
             if token:
                 print(f"Loading model {model_name} with token...")
@@ -105,6 +99,8 @@ class QuranRecitationComparer:
 
             self.model = self.model.to(self.device)
             self.model.eval()
+            # Set the configuration to always return hidden states
+            self.model.config.output_hidden_states = True
             print("Model loaded successfully!")
         except Exception as e:
             print(f"Error loading model: {str(e)}")
@@ -146,7 +142,8 @@ class QuranRecitationComparer:
         ).input_values.to(self.device)
 
         with torch.no_grad():
-            outputs = self.model(inputs, output_hidden_states=True)
+            # Call the model without passing output_hidden_states explicitly.
+            outputs = self.model(inputs)
 
         hidden_states = outputs.hidden_states[-1]
         embedding_seq = hidden_states.squeeze(0).cpu().numpy()
@@ -158,7 +155,6 @@ class QuranRecitationComparer:
 
     def compute_dtw_distance(self, features1, features2):
         """Compute the DTW distance between two sequences of features."""
-        # Make sure features are 2D arrays
        if features1.ndim == 1:
             features1 = features1.reshape(-1, 1)
         if features2.ndim == 1:
@@ -166,7 +162,6 @@ class QuranRecitationComparer:
 
         print(f"Feature shapes: {features1.shape}, {features2.shape}")
 
-        # Use a subsample if the sequences are too long to avoid memory issues
         max_length = 300
         if features1.shape[0] > max_length or features2.shape[0] > max_length:
             step1 = max(1, features1.shape[0] // max_length)
@@ -182,7 +177,6 @@ class QuranRecitationComparer:
             return normalized_distance
         except Exception as e:
             print(f"Error in compute_dtw_distance: {str(e)}")
-            # Fallback to a basic similarity measure if DTW fails
             mean_1 = np.mean(features1, axis=0)
             mean_2 = np.mean(features2, axis=0)
             euclidean_distance = np.sqrt(np.sum((mean_1 - mean_2) ** 2))
@@ -222,7 +216,6 @@ class QuranRecitationComparer:
         audio = self.load_audio(file_path)
         embedding = self.get_deep_embedding(audio)
 
-        # Store in cache for future use
         self.embedding_cache[file_path] = embedding
         print(f"Embedding shape: {embedding.shape}")
 
@@ -234,35 +227,30 @@ class QuranRecitationComparer:
     def predict(self, file_path1, file_path2):
         """
         Predict the similarity between two audio files.
-
-
+
         Args:
             file_path1 (str): Path to first audio file
             file_path2 (str): Path to second audio file
-
+
         Returns:
             float: Similarity score
             str: Interpretation of similarity
         """
         print(f"Comparing {file_path1} and {file_path2}")
         try:
-            # Get embeddings (using cache if available)
             embedding1 = self.get_embedding_for_file(file_path1)
             embedding2 = self.get_embedding_for_file(file_path2)
 
-            # Compute DTW distance
             print("Computing DTW distance...")
             norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
             print(f"Normalized distance: {norm_distance}")
 
-            # Interpret results
             interpretation, similarity_score = self.interpret_similarity(norm_distance)
             print(f"Similarity score: {similarity_score}, Interpretation: {interpretation}")
 
             return similarity_score, interpretation
         except Exception as e:
             print(f"Error in predict: {str(e)}")
-            # Return a fallback response in case of error
             return 0, f"Error comparing files: {str(e)}"
 
     def clear_cache(self):
@@ -273,7 +261,6 @@ class QuranRecitationComparer:
 # Global variable for the comparer instance
 comparer = None
 
-# Use the new lifespan API
 @app.on_event("startup")
 async def startup_event():
     """Initialize the model when the application starts."""
@@ -287,7 +274,6 @@ async def startup_event():
         print("Model initialized and ready for predictions!")
     except Exception as e:
         print(f"Error initializing model: {str(e)}")
-        # Don't raise here, let the app continue to load even if model fails
 
 @app.get("/")
 async def root():
@@ -319,7 +305,6 @@ async def compare_files(
     print(f"Created temporary directory: {temp_dir}")
 
     try:
-        # Save uploaded files to temporary directory
         temp_file1 = os.path.join(temp_dir, file1.filename)
         temp_file2 = os.path.join(temp_dir, file2.filename)
 
@@ -333,7 +318,6 @@ async def compare_files(
 
         print(f"Files saved to: {temp_file1} and {temp_file2}")
 
-        # Compare the files
         similarity_score, interpretation = comparer.predict(temp_file1, temp_file2)
 
         return ComparisonResult(
@@ -346,7 +330,6 @@ async def compare_files(
         raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}")
 
     finally:
-        # Clean up temporary files
         print(f"Cleaning up temporary directory: {temp_dir}")
         shutil.rmtree(temp_dir, ignore_errors=True)
 
@@ -360,4 +343,4 @@ async def clear_cache():
     return {"message": "Embedding cache cleared successfully"}
 
 if __name__ == "__main__":
-    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
+    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
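Notes on the change. The custom_dtw hunks only strip comments and blank lines; the DTW recurrence itself is unchanged. As a reference point, here is a minimal standalone sketch of that fill step (toy inputs, euclidean branch only, backtracking omitted):

import numpy as np

# Toy sequences shaped (time, features), as custom_dtw expects.
X = np.array([[0.0], [1.0], [2.0]])
Y = np.array([[0.0], [2.0]])

n, m = len(X), len(Y)
D = np.zeros((n + 1, m + 1))
D[0, 1:] = np.inf
D[1:, 0] = np.inf

for i in range(1, n + 1):
    for j in range(1, m + 1):
        cost = np.linalg.norm(X[i - 1] - Y[j - 1])  # euclidean branch
        D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])

print(D[n, m])  # accumulated DTW cost for the toy pair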
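The one functional change in this commit is where hidden states are requested: the flag is now set once at load time (self.model.config.output_hidden_states = True), so the forward call inside torch.no_grad() no longer passes output_hidden_states explicitly. A minimal sketch of that pattern, assuming a Wav2Vec2-style model from Hugging Face transformers (the checkpoint name is illustrative, not necessarily this Space's model):

import numpy as np
import torch
from transformers import Wav2Vec2Model, Wav2Vec2Processor

model_name = "facebook/wav2vec2-base-960h"  # illustrative checkpoint
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2Model.from_pretrained(model_name)
model.eval()
model.config.output_hidden_states = True  # return hidden states on every forward pass

audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz as dummy input
inputs = processor(audio, sampling_rate=16000, return_tensors="pt").input_values

with torch.no_grad():
    outputs = model(inputs)  # no output_hidden_states kwarg needed now
hidden_states = outputs.hidden_states[-1]  # shape (1, frames, hidden_dim)
embedding_seq = hidden_states.squeeze(0).cpu().numpy()
print(embedding_seq.shape)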
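compute_dtw_distance keeps its max_length = 300 guard; the subsampling body itself falls outside the shown context, so the following is only a sketch of the likely striding pattern (the step1 line is in the diff, the slicing line is an assumption):

import numpy as np

max_length = 300
features1 = np.random.rand(1200, 768)  # e.g. (frames, hidden_dim)

if features1.shape[0] > max_length:
    step1 = max(1, features1.shape[0] // max_length)
    features1 = features1[::step1]  # assumed: keep every step1-th frame
print(features1.shape)  # roughly (300, 768)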
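The deleted "# Use the new lifespan API" comment referred to FastAPI's lifespan mechanism; the code itself still registers the model via @app.on_event("startup"). For comparison only, a minimal sketch of the lifespan pattern that comment pointed at (not what this Space currently does):

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup: build the comparer here instead of in a startup event handler
    yield
    # shutdown: release resources here

app = FastAPI(lifespan=lifespan)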
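Once the app is running via uvicorn on port 7860, the comparison endpoint can be exercised roughly as below. The file1/file2 field names come from compare_files in the diff; the route path is a placeholder, since the decorator sits outside the hunks shown:

import requests

url = "http://localhost:7860/compare"  # placeholder path; check the actual route in main.py

with open("recitation1.wav", "rb") as f1, open("recitation2.wav", "rb") as f2:
    response = requests.post(url, files={"file1": f1, "file2": f2})
print(response.json())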