Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -5,13 +5,8 @@ import numpy as np
|
|
5 |
import tempfile
|
6 |
from fastapi import FastAPI, UploadFile, File, HTTPException
|
7 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
8 |
-
from librosa.sequence import dtw
|
9 |
-
|
10 |
-
app = FastAPI(
|
11 |
-
title="Quran Recitation Comparer API",
|
12 |
-
description="Compares two Quran recitations using a deep wav2vec2 model.",
|
13 |
-
version="1.0"
|
14 |
-
)
|
15 |
|
16 |
# --- Core Class Definition ---
|
17 |
class QuranRecitationComparer:
|
@@ -21,7 +16,6 @@ class QuranRecitationComparer:
|
|
21 |
"""
|
22 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
|
24 |
-
# Load model and processor once during initialization
|
25 |
if auth_token:
|
26 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
|
27 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
|
@@ -31,12 +25,9 @@ class QuranRecitationComparer:
|
|
31 |
|
32 |
self.model = self.model.to(self.device)
|
33 |
self.model.eval()
|
34 |
-
|
35 |
-
# Cache for embeddings to avoid recomputation
|
36 |
self.embedding_cache = {}
|
37 |
|
38 |
def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
|
39 |
-
"""Load and preprocess an audio file."""
|
40 |
if not os.path.exists(file_path):
|
41 |
raise FileNotFoundError(f"Audio file not found: {file_path}")
|
42 |
y, sr = librosa.load(file_path, sr=target_sr)
|
@@ -47,7 +38,6 @@ class QuranRecitationComparer:
|
|
47 |
return y
|
48 |
|
49 |
def get_deep_embedding(self, audio, sr=16000):
|
50 |
-
"""Extract frame-wise deep embeddings using the pretrained model."""
|
51 |
input_values = self.processor(
|
52 |
audio,
|
53 |
sampling_rate=sr,
|
@@ -62,14 +52,12 @@ class QuranRecitationComparer:
|
|
62 |
return embedding_seq
|
63 |
|
64 |
def compute_dtw_distance(self, features1, features2):
    """Return the path-length-normalized DTW cost between two feature sequences.

    Both arguments are forwarded unchanged to ``dtw`` (imported from
    ``librosa.sequence``) as ``X`` and ``Y`` with the ``'euclidean'``
    metric. The bottom-right entry of the returned matrix is divided by
    the number of entries in the returned warping path.
    """
    cost_matrix, warping_path = dtw(X=features1, Y=features2, metric='euclidean')
    return cost_matrix[-1, -1] / len(warping_path)
|
70 |
|
71 |
def interpret_similarity(self, norm_distance):
|
72 |
-
"""Interpret the normalized distance value."""
|
73 |
if norm_distance == 0:
|
74 |
result = "The recitations are identical based on the deep embeddings."
|
75 |
score = 100
|
@@ -91,48 +79,45 @@ class QuranRecitationComparer:
|
|
91 |
return result, score
|
92 |
|
93 |
def get_embedding_for_file(self, file_path):
    """Return the embedding for ``file_path``, computing it on a cache miss.

    ``self.embedding_cache`` is keyed by the ``file_path`` value itself.
    On a hit the stored embedding is returned as-is; on a miss the audio
    is loaded with ``load_audio``, embedded with ``get_deep_embedding``,
    stored under ``file_path`` and returned.
    """
    try:
        return self.embedding_cache[file_path]
    except KeyError:
        # Not cached yet; fall through and compute it below.
        pass

    embedding = self.get_deep_embedding(self.load_audio(file_path))
    self.embedding_cache[file_path] = embedding
    return embedding
|
102 |
|
103 |
def predict(self, file_path1, file_path2):
    """Compare two audio files and return ``(similarity_score, interpretation)``.

    Embeddings for both paths are obtained through
    ``get_embedding_for_file`` (first path first), their ``.T``
    attributes are passed to ``compute_dtw_distance``, and the resulting
    distance is handed to ``interpret_similarity``. The score and the
    interpretation are also printed to stdout before being returned.
    """
    first_embedding = self.get_embedding_for_file(file_path1)
    second_embedding = self.get_embedding_for_file(file_path2)

    norm_distance = self.compute_dtw_distance(
        first_embedding.T,
        second_embedding.T,
    )
    # interpret_similarity yields (text, score); the return order is swapped.
    interpretation, similarity_score = self.interpret_similarity(norm_distance)

    print(f"Similarity Score: {similarity_score:.1f}/100")
    print(f"Interpretation: {interpretation}")
    return similarity_score, interpretation
|
120 |
|
121 |
def clear_cache(self):
    """Discard all cached embeddings by rebinding the cache to a new empty dict."""
    self.embedding_cache = dict()
|
124 |
|
125 |
-
# ---
|
126 |
-
@
|
127 |
-
def
|
128 |
global comparer
|
129 |
-
#
|
130 |
auth_token = os.environ.get("HF_TOKEN")
|
131 |
comparer = QuranRecitationComparer(
|
132 |
model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
|
133 |
auth_token=auth_token
|
134 |
)
|
135 |
print("Model initialized and ready for predictions!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
# --- API Endpoints ---
|
138 |
@app.get("/", summary="Health Check")
|
@@ -141,24 +126,15 @@ async def root():
|
|
141 |
|
142 |
@app.post("/predict", summary="Compare Two Audio Files", response_model=dict)
|
143 |
async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
|
144 |
-
"""
|
145 |
-
Compare two uploaded audio files and return a similarity score along with an interpretation.
|
146 |
-
|
147 |
-
- **file1**: The first audio file.
|
148 |
-
- **file2**: The second audio file.
|
149 |
-
"""
|
150 |
tmp1_path = None
|
151 |
tmp2_path = None
|
152 |
-
|
153 |
try:
|
154 |
-
# Save first file to a temporary location
|
155 |
suffix1 = os.path.splitext(file1.filename)[1] or ".wav"
|
156 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix1) as tmp1:
|
157 |
content1 = await file1.read()
|
158 |
tmp1.write(content1)
|
159 |
tmp1_path = tmp1.name
|
160 |
|
161 |
-
# Save second file to a temporary location
|
162 |
suffix2 = os.path.splitext(file2.filename)[1] or ".wav"
|
163 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix2) as tmp2:
|
164 |
content2 = await file2.read()
|
@@ -167,11 +143,9 @@ async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
|
|
167 |
|
168 |
similarity_score, interpretation = comparer.predict(tmp1_path, tmp2_path)
|
169 |
return {"similarity_score": similarity_score, "interpretation": interpretation}
|
170 |
-
|
171 |
except Exception as e:
|
172 |
raise HTTPException(status_code=500, detail=str(e))
|
173 |
finally:
|
174 |
-
# Clean up temporary files
|
175 |
if tmp1_path and os.path.exists(tmp1_path):
|
176 |
os.remove(tmp1_path)
|
177 |
if tmp2_path and os.path.exists(tmp2_path):
|
@@ -179,8 +153,5 @@ async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
|
|
179 |
|
180 |
@app.post("/clear_cache", summary="Clear Embedding Cache", response_model=dict)
async def clear_cache():
    """
    Clear the embedding cache. This can help free memory if many comparisons have been made.
    """
    # Delegate to the module-level comparer, then acknowledge the request.
    comparer.clear_cache()
    payload = {"message": "Cache cleared."}
    return payload
|
|
|
5 |
import tempfile
|
6 |
from fastapi import FastAPI, UploadFile, File, HTTPException
|
7 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
8 |
+
from librosa.sequence import dtw
|
9 |
+
from contextlib import asynccontextmanager
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# --- Core Class Definition ---
|
12 |
class QuranRecitationComparer:
|
|
|
16 |
"""
|
17 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
18 |
|
|
|
19 |
if auth_token:
|
20 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
|
21 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
|
|
|
25 |
|
26 |
self.model = self.model.to(self.device)
|
27 |
self.model.eval()
|
|
|
|
|
28 |
self.embedding_cache = {}
|
29 |
|
30 |
def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
|
|
|
31 |
if not os.path.exists(file_path):
|
32 |
raise FileNotFoundError(f"Audio file not found: {file_path}")
|
33 |
y, sr = librosa.load(file_path, sr=target_sr)
|
|
|
38 |
return y
|
39 |
|
40 |
def get_deep_embedding(self, audio, sr=16000):
|
|
|
41 |
input_values = self.processor(
|
42 |
audio,
|
43 |
sampling_rate=sr,
|
|
|
52 |
return embedding_seq
|
53 |
|
54 |
def compute_dtw_distance(self, features1, features2):
    """Return the path-length-normalized DTW cost between two feature sequences.

    Both arguments are forwarded unchanged to ``dtw`` (imported from
    ``librosa.sequence``) as ``X`` and ``Y`` with the ``'euclidean'``
    metric. The bottom-right entry of the returned matrix is divided by
    the number of entries in the returned warping path.
    """
    cost_matrix, warping_path = dtw(X=features1, Y=features2, metric='euclidean')
    return cost_matrix[-1, -1] / len(warping_path)
|
59 |
|
60 |
def interpret_similarity(self, norm_distance):
|
|
|
61 |
if norm_distance == 0:
|
62 |
result = "The recitations are identical based on the deep embeddings."
|
63 |
score = 100
|
|
|
79 |
return result, score
|
80 |
|
81 |
def get_embedding_for_file(self, file_path):
    """Return the embedding for ``file_path``, computing it on a cache miss.

    ``self.embedding_cache`` is keyed by the ``file_path`` value itself.
    On a hit the stored embedding is returned as-is; on a miss the audio
    is loaded with ``load_audio``, embedded with ``get_deep_embedding``,
    stored under ``file_path`` and returned.
    """
    try:
        return self.embedding_cache[file_path]
    except KeyError:
        # Not cached yet; fall through and compute it below.
        pass

    embedding = self.get_deep_embedding(self.load_audio(file_path))
    self.embedding_cache[file_path] = embedding
    return embedding
|
88 |
|
89 |
def predict(self, file_path1, file_path2):
    """Compare two audio files and return ``(similarity_score, interpretation)``.

    Embeddings for both paths are obtained through
    ``get_embedding_for_file`` (first path first), their ``.T``
    attributes are passed to ``compute_dtw_distance``, and the resulting
    distance is handed to ``interpret_similarity``. The score and the
    interpretation are also printed to stdout before being returned.
    """
    first_embedding = self.get_embedding_for_file(file_path1)
    second_embedding = self.get_embedding_for_file(file_path2)

    norm_distance = self.compute_dtw_distance(
        first_embedding.T,
        second_embedding.T,
    )
    # interpret_similarity yields (text, score); the return order is swapped.
    interpretation, similarity_score = self.interpret_similarity(norm_distance)

    print(f"Similarity Score: {similarity_score:.1f}/100")
    print(f"Interpretation: {interpretation}")
    return similarity_score, interpretation
|
97 |
|
98 |
def clear_cache(self):
    """Discard all cached embeddings by rebinding the cache to a new empty dict."""
    self.embedding_cache = dict()
|
100 |
|
101 |
+
# --- Lifespan Event Handler ---
|
102 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the module-level ``comparer`` before yielding; print on exit.

    Before the ``yield``: reads the ``HF_TOKEN`` environment variable
    (``None`` when unset), constructs a ``QuranRecitationComparer`` for
    the hard-coded Arabic wav2vec2 model name, binds it to the global
    ``comparer`` and prints a readiness message. After the ``yield``:
    prints a shutdown message; no other cleanup is performed here.
    """
    global comparer
    # Use environment variables or a secure configuration in production
    comparer = QuranRecitationComparer(
        model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
        auth_token=os.environ.get("HF_TOKEN"),
    )
    print("Model initialized and ready for predictions!")

    yield

    print("Application shutdown: Cleanup if necessary.")
|
114 |
+
|
115 |
+
app = FastAPI(
|
116 |
+
title="Quran Recitation Comparer API",
|
117 |
+
description="Compares two Quran recitations using a deep wav2vec2 model.",
|
118 |
+
version="1.0",
|
119 |
+
lifespan=lifespan
|
120 |
+
)
|
121 |
|
122 |
# --- API Endpoints ---
|
123 |
@app.get("/", summary="Health Check")
|
|
|
126 |
|
127 |
@app.post("/predict", summary="Compare Two Audio Files", response_model=dict)
|
128 |
async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
tmp1_path = None
|
130 |
tmp2_path = None
|
|
|
131 |
try:
|
|
|
132 |
suffix1 = os.path.splitext(file1.filename)[1] or ".wav"
|
133 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix1) as tmp1:
|
134 |
content1 = await file1.read()
|
135 |
tmp1.write(content1)
|
136 |
tmp1_path = tmp1.name
|
137 |
|
|
|
138 |
suffix2 = os.path.splitext(file2.filename)[1] or ".wav"
|
139 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix2) as tmp2:
|
140 |
content2 = await file2.read()
|
|
|
143 |
|
144 |
similarity_score, interpretation = comparer.predict(tmp1_path, tmp2_path)
|
145 |
return {"similarity_score": similarity_score, "interpretation": interpretation}
|
|
|
146 |
except Exception as e:
|
147 |
raise HTTPException(status_code=500, detail=str(e))
|
148 |
finally:
|
|
|
149 |
if tmp1_path and os.path.exists(tmp1_path):
|
150 |
os.remove(tmp1_path)
|
151 |
if tmp2_path and os.path.exists(tmp2_path):
|
|
|
153 |
|
154 |
@app.post("/clear_cache", summary="Clear Embedding Cache", response_model=dict)
async def clear_cache():
    # Reset the module-level comparer's embedding cache, then acknowledge.
    # Documented with comments (not a docstring) so no new docstring text
    # is attached to the handler at runtime.
    comparer.clear_cache()
    payload = {"message": "Cache cleared."}
    return payload
|