Hammad712 committed on
Commit
ef20d33
·
verified ·
1 Parent(s): b05966a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -34
main.py CHANGED
@@ -35,8 +35,6 @@ class ErrorResponse(BaseModel):
35
  # Initialize model from environment variable
36
  def initialize_model():
37
  global MODEL, PROCESSOR
38
-
39
- # Get HF token from environment variable
40
  hf_token = os.environ.get("HF_TOKEN", None)
41
  model_name = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
42
 
@@ -44,7 +42,7 @@ def initialize_model():
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
  print(f"Loading model on device: {device}")
46
 
47
- # Load model and processor using the updated parameter `token`
48
  if hf_token:
49
  PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, token=hf_token)
50
  MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, token=hf_token)
@@ -59,6 +57,34 @@ def initialize_model():
59
  print(f"Error loading model: {e}")
60
  raise e
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # Load audio file
63
  def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
64
  """Load and preprocess an audio file."""
@@ -118,14 +144,12 @@ def custom_dtw(X, Y, metric='euclidean'):
118
  wp : list
119
  The warping path
120
  """
121
- # Initialize cost matrix
122
  n, m = len(X[0]), len(Y[0])
123
  D = np.zeros((n+1, m+1))
124
  D[0, :] = np.inf
125
  D[:, 0] = np.inf
126
  D[0, 0] = 0
127
 
128
- # Fill cost matrix
129
  for i in range(1, n+1):
130
  for j in range(1, m+1):
131
  if metric == 'euclidean':
@@ -137,7 +161,6 @@ def custom_dtw(X, Y, metric='euclidean'):
137
 
138
  D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
139
 
140
- # Backtrack to find warping path
141
  i, j = n, m
142
  wp = [(i, j)]
143
  while i > 1 or j > 1:
@@ -153,7 +176,6 @@ def custom_dtw(X, Y, metric='euclidean'):
153
  def compute_dtw_distance(features1, features2):
154
  """Compute the DTW distance between two sequences of features."""
155
  try:
156
- # Use custom DTW implementation instead of librosa's
157
  D, wp = custom_dtw(features1, features2, metric='euclidean')
158
  distance = D[-1, -1]
159
  normalized_distance = distance / len(wp)
@@ -195,7 +217,7 @@ def cleanup_temp_files(file_paths):
195
  except Exception as e:
196
  print(f"Error removing temporary file {file_path}: {e}")
197
 
198
- # API endpoints
199
  @app.post("/compare", response_model=SimilarityResponse)
200
  async def compare_recitations(
201
  background_tasks: BackgroundTasks,
@@ -212,37 +234,29 @@ async def compare_recitations(
212
  - **similarity_score**: Score between 0-100 indicating similarity
213
  - **interpretation**: Text interpretation of the similarity
214
  """
215
- # Check if model is initialized
216
  if MODEL is None or PROCESSOR is None:
217
  raise HTTPException(status_code=500, detail="Model not initialized")
218
 
219
- # Temporary file paths
220
  temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
221
  temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
222
 
223
  try:
224
- # Save uploaded files
225
  with open(temp_file1, "wb") as f:
226
  shutil.copyfileobj(file1.file, f)
227
 
228
  with open(temp_file2, "wb") as f:
229
  shutil.copyfileobj(file2.file, f)
230
 
231
- # Load audio files
232
  audio1 = load_audio(temp_file1)
233
  audio2 = load_audio(temp_file2)
234
 
235
- # Extract embeddings
236
  embedding1 = get_deep_embedding(audio1)
237
  embedding2 = get_deep_embedding(audio2)
238
 
239
- # Compute DTW distance
240
  norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
241
 
242
- # Interpret results
243
  interpretation, similarity_score = interpret_similarity(norm_distance)
244
 
245
- # Add cleanup task
246
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
247
 
248
  return {
@@ -251,10 +265,10 @@ async def compare_recitations(
251
  }
252
 
253
  except Exception as e:
254
- # Ensure files are cleaned up even in case of error
255
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
256
  raise HTTPException(status_code=500, detail=str(e))
257
 
 
258
  @app.get("/health")
259
  async def health_check():
260
  """Health check endpoint."""
@@ -265,24 +279,8 @@ async def health_check():
265
  )
266
  return {"status": "ok", "model_loaded": True}
267
 
268
- # Use lifespan context manager instead of on_event decorators
269
- @asynccontextmanager
270
- async def lifespan(app: FastAPI):
271
- initialize_model()
272
- yield
273
-
274
- # Initialize FastAPI app with lifespan handler
275
- app = FastAPI(
276
- title="Quran Recitation Comparison API",
277
- description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
278
- version="1.0.0",
279
- lifespan=lifespan
280
- )
281
-
282
- # Note: Ensure that all route definitions are declared AFTER the app initialization above.
283
-
284
  # Run the FastAPI app
285
  if __name__ == "__main__":
286
  import uvicorn
287
- port = int(os.environ.get("PORT", 7860)) # Default to port 7860 for Hugging Face Spaces
288
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
 
35
  # Initialize model from environment variable
36
  def initialize_model():
37
  global MODEL, PROCESSOR
 
 
38
  hf_token = os.environ.get("HF_TOKEN", None)
39
  model_name = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
40
 
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
  print(f"Loading model on device: {device}")
44
 
45
+ # Load model and processor using updated parameter `token`
46
  if hf_token:
47
  PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, token=hf_token)
48
  MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, token=hf_token)
 
57
  print(f"Error loading model: {e}")
58
  raise e
59
 
60
+ # Lifespan event handler to initialize the model at startup
61
+ @asynccontextmanager
62
+ async def lifespan(app: FastAPI):
63
+ initialize_model()
64
+ yield
65
+
66
+ # Create the FastAPI app with the lifespan handler and CORS middleware
67
+ app = FastAPI(
68
+ title="Quran Recitation Comparison API",
69
+ description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
70
+ version="1.0.0",
71
+ lifespan=lifespan
72
+ )
73
+
74
+ app.add_middleware(
75
+ CORSMiddleware,
76
+ allow_origins=["*"], # Allows all origins
77
+ allow_credentials=True,
78
+ allow_methods=["*"], # Allows all methods
79
+ allow_headers=["*"], # Allows all headers
80
+ )
81
+
82
+ # Root endpoint
83
+ @app.get("/")
84
+ async def root():
85
+ """Welcome endpoint."""
86
+ return {"message": "Welcome to the Quran Recitation Comparison API"}
87
+
88
  # Load audio file
89
  def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
90
  """Load and preprocess an audio file."""
 
144
  wp : list
145
  The warping path
146
  """
 
147
  n, m = len(X[0]), len(Y[0])
148
  D = np.zeros((n+1, m+1))
149
  D[0, :] = np.inf
150
  D[:, 0] = np.inf
151
  D[0, 0] = 0
152
 
 
153
  for i in range(1, n+1):
154
  for j in range(1, m+1):
155
  if metric == 'euclidean':
 
161
 
162
  D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
163
 
 
164
  i, j = n, m
165
  wp = [(i, j)]
166
  while i > 1 or j > 1:
 
176
  def compute_dtw_distance(features1, features2):
177
  """Compute the DTW distance between two sequences of features."""
178
  try:
 
179
  D, wp = custom_dtw(features1, features2, metric='euclidean')
180
  distance = D[-1, -1]
181
  normalized_distance = distance / len(wp)
 
217
  except Exception as e:
218
  print(f"Error removing temporary file {file_path}: {e}")
219
 
220
+ # API endpoint for comparing recitations
221
  @app.post("/compare", response_model=SimilarityResponse)
222
  async def compare_recitations(
223
  background_tasks: BackgroundTasks,
 
234
  - **similarity_score**: Score between 0-100 indicating similarity
235
  - **interpretation**: Text interpretation of the similarity
236
  """
 
237
  if MODEL is None or PROCESSOR is None:
238
  raise HTTPException(status_code=500, detail="Model not initialized")
239
 
 
240
  temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
241
  temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
242
 
243
  try:
 
244
  with open(temp_file1, "wb") as f:
245
  shutil.copyfileobj(file1.file, f)
246
 
247
  with open(temp_file2, "wb") as f:
248
  shutil.copyfileobj(file2.file, f)
249
 
 
250
  audio1 = load_audio(temp_file1)
251
  audio2 = load_audio(temp_file2)
252
 
 
253
  embedding1 = get_deep_embedding(audio1)
254
  embedding2 = get_deep_embedding(audio2)
255
 
 
256
  norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
257
 
 
258
  interpretation, similarity_score = interpret_similarity(norm_distance)
259
 
 
260
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
261
 
262
  return {
 
265
  }
266
 
267
  except Exception as e:
 
268
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
269
  raise HTTPException(status_code=500, detail=str(e))
270
 
271
+ # Health check endpoint
272
  @app.get("/health")
273
  async def health_check():
274
  """Health check endpoint."""
 
279
  )
280
  return {"status": "ok", "model_loaded": True}
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  # Run the FastAPI app
283
  if __name__ == "__main__":
284
  import uvicorn
285
+ port = int(os.environ.get("PORT", 7860)) # Default to port 7860
286
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)