Update main.py
main.py CHANGED
@@ -35,8 +35,6 @@ class ErrorResponse(BaseModel):
 # Initialize model from environment variable
 def initialize_model():
     global MODEL, PROCESSOR
-
-    # Get HF token from environment variable
     hf_token = os.environ.get("HF_TOKEN", None)
     model_name = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
 
@@ -44,7 +42,7 @@ def initialize_model():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Loading model on device: {device}")
 
-    # Load model and processor using
+    # Load model and processor using updated parameter `token`
     if hf_token:
         PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, token=hf_token)
         MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, token=hf_token)
@@ -59,6 +57,34 @@ def initialize_model():
         print(f"Error loading model: {e}")
         raise e
 
+# Lifespan event handler to initialize the model at startup
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    initialize_model()
+    yield
+
+# Create the FastAPI app with the lifespan handler and CORS middleware
+app = FastAPI(
+    title="Quran Recitation Comparison API",
+    description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allows all origins
+    allow_credentials=True,
+    allow_methods=["*"],  # Allows all methods
+    allow_headers=["*"],  # Allows all headers
+)
+
+# Root endpoint
+@app.get("/")
+async def root():
+    """Welcome endpoint."""
+    return {"message": "Welcome to the Quran Recitation Comparison API"}
+
 # Load audio file
 def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
     """Load and preprocess an audio file."""
@@ -118,14 +144,12 @@ def custom_dtw(X, Y, metric='euclidean'):
     wp : list
         The warping path
     """
-    # Initialize cost matrix
     n, m = len(X[0]), len(Y[0])
     D = np.zeros((n+1, m+1))
     D[0, :] = np.inf
     D[:, 0] = np.inf
     D[0, 0] = 0
 
-    # Fill cost matrix
     for i in range(1, n+1):
         for j in range(1, m+1):
             if metric == 'euclidean':
@@ -137,7 +161,6 @@ def custom_dtw(X, Y, metric='euclidean'):
 
             D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
 
-    # Backtrack to find warping path
     i, j = n, m
     wp = [(i, j)]
     while i > 1 or j > 1:
@@ -153,7 +176,6 @@ def custom_dtw(X, Y, metric='euclidean'):
 def compute_dtw_distance(features1, features2):
     """Compute the DTW distance between two sequences of features."""
     try:
-        # Use custom DTW implementation instead of librosa's
         D, wp = custom_dtw(features1, features2, metric='euclidean')
         distance = D[-1, -1]
         normalized_distance = distance / len(wp)
@@ -195,7 +217,7 @@ def cleanup_temp_files(file_paths):
         except Exception as e:
             print(f"Error removing temporary file {file_path}: {e}")
 
-# API
+# API endpoint for comparing recitations
 @app.post("/compare", response_model=SimilarityResponse)
 async def compare_recitations(
     background_tasks: BackgroundTasks,
@@ -212,37 +234,29 @@ async def compare_recitations(
     - **similarity_score**: Score between 0-100 indicating similarity
     - **interpretation**: Text interpretation of the similarity
     """
-    # Check if model is initialized
     if MODEL is None or PROCESSOR is None:
         raise HTTPException(status_code=500, detail="Model not initialized")
 
-    # Temporary file paths
     temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
     temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
 
     try:
-        # Save uploaded files
         with open(temp_file1, "wb") as f:
             shutil.copyfileobj(file1.file, f)
 
         with open(temp_file2, "wb") as f:
             shutil.copyfileobj(file2.file, f)
 
-        # Load audio files
         audio1 = load_audio(temp_file1)
         audio2 = load_audio(temp_file2)
 
-        # Extract embeddings
         embedding1 = get_deep_embedding(audio1)
         embedding2 = get_deep_embedding(audio2)
 
-        # Compute DTW distance
         norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
 
-        # Interpret results
        interpretation, similarity_score = interpret_similarity(norm_distance)
 
-        # Add cleanup task
         background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
 
         return {
@@ -251,10 +265,10 @@ async def compare_recitations(
         }
 
     except Exception as e:
-        # Ensure files are cleaned up even in case of error
         background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
         raise HTTPException(status_code=500, detail=str(e))
 
+# Health check endpoint
 @app.get("/health")
 async def health_check():
     """Health check endpoint."""
@@ -265,24 +279,8 @@ async def health_check():
         )
     return {"status": "ok", "model_loaded": True}
 
-# Use lifespan context manager instead of on_event decorators
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    initialize_model()
-    yield
-
-# Initialize FastAPI app with lifespan handler
-app = FastAPI(
-    title="Quran Recitation Comparison API",
-    description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
-    version="1.0.0",
-    lifespan=lifespan
-)
-
-# Note: Ensure that all route definitions are declared AFTER the app initialization above.
-
 # Run the FastAPI app
 if __name__ == "__main__":
     import uvicorn
-    port = int(os.environ.get("PORT", 7860))  # Default to port 7860
+    port = int(os.environ.get("PORT", 7860))  # Default to port 7860
     uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
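
Not part of the commit, but a minimal smoke test of the new startup path may be useful. This is a sketch under assumptions: the Space's dependencies (fastapi, httpx for TestClient, torch, transformers, librosa) are installed locally, main.py imports cleanly as `main`, and the model named by MODEL_NAME can be downloaded (set HF_TOKEN if it is gated). Entering the TestClient context runs the new lifespan handler, so it exercises initialize_model() together with the added `/` and `/health` routes.

# smoke_test.py: hypothetical helper, not part of this commit.
import os

from fastapi.testclient import TestClient

import main  # the module changed in this commit


def run_smoke_test():
    # Entering the context triggers the lifespan handler, which calls
    # initialize_model() and loads the Wav2Vec2 model and processor.
    with TestClient(main.app) as client:
        root = client.get("/")
        assert root.status_code == 200
        assert "Welcome" in root.json()["message"]

        health = client.get("/health")
        assert health.status_code == 200
        assert health.json() == {"status": "ok", "model_loaded": True}
    print("lifespan startup, / and /health look OK")


if __name__ == "__main__":
    # MODEL_NAME and HF_TOKEN are read by initialize_model(); override as needed.
    os.environ.setdefault("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
    run_smoke_test()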
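
A client-side sketch for the /compare endpoint, also not part of the commit. The multipart field names file1 and file2 and the documented response fields (similarity_score, interpretation) come from the handler in this diff; the base URL, file paths, and timeout are placeholder assumptions.

# compare_client.py: hypothetical client sketch for the /compare endpoint.
import requests

BASE_URL = "http://localhost:7860"  # default PORT in main.py; replace with the Space URL


def compare(path1: str, path2: str) -> dict:
    # Upload the two recitations as multipart form files named file1 and file2.
    with open(path1, "rb") as f1, open(path2, "rb") as f2:
        resp = requests.post(
            f"{BASE_URL}/compare",
            files={
                "file1": ("recitation1.wav", f1, "audio/wav"),
                "file2": ("recitation2.wav", f2, "audio/wav"),
            },
            timeout=300,  # embedding extraction + DTW can be slow on CPU
        )
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    result = compare("recitation1.wav", "recitation2.wav")
    # Documented fields: similarity_score (0-100) and interpretation.
    print(result.get("similarity_score"), result.get("interpretation"))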