Shujah239 commited on
Commit
c7aefa7
·
verified ·
1 Parent(s): c1dfbad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -20
app.py CHANGED
@@ -5,13 +5,12 @@ from pathlib import Path
5
  from typing import List, Dict, Any, Optional
6
 
7
  from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks, Request
8
- from fastapi.responses import FileResponse, JSONResponse
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.middleware.gzip import GZipMiddleware
11
  from transformers import pipeline
12
  import torch
13
  import uvicorn
14
- import os
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.INFO)
@@ -19,8 +18,6 @@ logger = logging.getLogger(__name__)
19
 
20
  # Define uploads directory
21
  UPLOAD_DIR = Path("uploads")
22
- UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # Create uploads directory at startup
23
-
24
  MAX_STORAGE_MB = 100 # Maximum storage in MB
25
  MAX_FILE_AGE_DAYS = 1 # Maximum age of files in days
26
 
@@ -52,41 +49,47 @@ classifier = None
52
  async def load_model():
53
  global classifier
54
  try:
 
55
  device = 0 if torch.cuda.is_available() else -1
56
 
 
57
  if device == -1:
58
- logger.info("Loading quantized model for CPU usage...")
59
  classifier = pipeline(
60
  "audio-classification",
61
  model="superb/wav2vec2-base-superb-er",
62
  device=device,
63
- torch_dtype=torch.float16
64
  )
65
  else:
66
- logger.info("Loading model on GPU...")
67
  classifier = pipeline(
68
  "audio-classification",
69
  model="superb/wav2vec2-base-superb-er",
70
  device=device
71
  )
72
- logger.info("Model loaded successfully. (Device: %s)", "GPU" if device == 0 else "CPU")
 
 
73
  except Exception as e:
74
  logger.error("Failed to load model: %s", e)
75
- classifier = None
 
76
 
77
  async def cleanup_old_files():
78
- """Clean up old files to prevent storage issues."""
79
  try:
 
80
  now = time.time()
81
  deleted_count = 0
82
  for file_path in UPLOAD_DIR.iterdir():
83
  if file_path.is_file():
84
  file_age_days = (now - file_path.stat().st_mtime) / (60 * 60 * 24)
85
  if file_age_days > MAX_FILE_AGE_DAYS:
86
- file_path.unlink(missing_ok=True) if hasattr(file_path, "missing_ok") else file_path.unlink()
87
  deleted_count += 1
 
88
  if deleted_count > 0:
89
- logger.info(f"Cleaned up {deleted_count} old files.")
90
  except Exception as e:
91
  logger.error(f"Error during file cleanup: {e}")
92
 
@@ -104,16 +107,15 @@ async def health():
104
  """Health check endpoint."""
105
  return {"status": "ok", "model_loaded": classifier is not None}
106
 
107
- @app.get("/health/health")
108
- async def double_health():
109
- """Fallback if Hugging Face requests /health/health (they sometimes do)."""
110
- return {"status": "ok", "model_loaded": classifier is not None}
111
-
112
  @app.post("/upload")
113
  async def upload_audio(
114
  file: UploadFile = File(...),
115
  background_tasks: BackgroundTasks = None
116
  ):
 
 
 
 
117
  if not classifier:
118
  raise HTTPException(status_code=503, detail="Model not yet loaded")
119
 
@@ -121,6 +123,7 @@ async def upload_audio(
121
  if not filename:
122
  raise HTTPException(status_code=400, detail="Invalid filename")
123
 
 
124
  valid_extensions = [".wav", ".mp3", ".ogg", ".flac"]
125
  if not any(filename.lower().endswith(ext) for ext in valid_extensions):
126
  raise HTTPException(
@@ -128,6 +131,7 @@ async def upload_audio(
128
  detail=f"Invalid file type. Supported types: {', '.join(valid_extensions)}"
129
  )
130
 
 
131
  try:
132
  contents = await file.read()
133
  except Exception as e:
@@ -136,25 +140,33 @@ async def upload_audio(
136
  finally:
137
  await file.close()
138
 
 
139
  if len(contents) > 10 * 1024 * 1024:
140
  raise HTTPException(
141
  status_code=413,
142
  detail="File too large. Maximum size is 10MB"
143
  )
144
 
 
145
  try:
146
  total, used, free = shutil.disk_usage(UPLOAD_DIR)
147
  free_mb = free / (1024 * 1024)
148
 
149
- if free_mb < 10:
 
150
  if background_tasks:
151
  background_tasks.add_task(cleanup_old_files)
152
 
153
  if len(contents) > free:
 
 
 
 
154
  raise HTTPException(status_code=507, detail="Insufficient storage to save file")
155
  except Exception as e:
156
  logger.warning(f"Failed to check disk usage: {e}")
157
 
 
158
  file_path = UPLOAD_DIR / filename
159
  try:
160
  with open(file_path, "wb") as f:
@@ -164,21 +176,30 @@ async def upload_audio(
164
  logger.error("Failed to save file %s: %s", filename, e)
165
  raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
166
 
 
167
  try:
168
  results = classifier(str(file_path))
 
 
169
  if background_tasks:
170
  background_tasks.add_task(cleanup_old_files)
 
171
  return {"filename": filename, "predictions": results}
172
  except Exception as e:
173
  logger.error("Model inference failed for %s: %s", filename, e)
 
174
  try:
175
- file_path.unlink(missing_ok=True) if hasattr(file_path, "missing_ok") else file_path.unlink()
176
  except Exception:
177
  pass
178
  raise HTTPException(status_code=500, detail=f"Emotion detection failed: {str(e)}")
179
 
180
  @app.get("/recordings")
181
  async def list_recordings():
 
 
 
 
182
  try:
183
  files = [f.name for f in UPLOAD_DIR.iterdir() if f.is_file()]
184
  total, used, free = shutil.disk_usage(UPLOAD_DIR)
@@ -194,10 +215,14 @@ async def list_recordings():
194
 
195
  @app.get("/recordings/{filename}")
196
  async def get_recording(filename: str):
 
 
 
197
  safe_name = Path(filename).name
198
  file_path = UPLOAD_DIR / safe_name
199
  if not file_path.exists() or not file_path.is_file():
200
  raise HTTPException(status_code=404, detail="Recording not found")
 
201
  import mimetypes
202
  media_type, _ = mimetypes.guess_type(file_path)
203
  return FileResponse(
@@ -208,6 +233,10 @@ async def get_recording(filename: str):
208
 
209
  @app.get("/analyze/{filename}")
210
  async def analyze_recording(filename: str):
 
 
 
 
211
  if not classifier:
212
  raise HTTPException(status_code=503, detail="Model not yet loaded")
213
 
@@ -224,16 +253,46 @@ async def analyze_recording(filename: str):
224
 
225
  @app.delete("/recordings/{filename}")
226
  async def delete_recording(filename: str):
 
 
 
227
  safe_name = Path(filename).name
228
  file_path = UPLOAD_DIR / safe_name
229
  if not file_path.exists() or not file_path.is_file():
230
  raise HTTPException(status_code=404, detail="Recording not found")
231
  try:
232
- file_path.unlink(missing_ok=True) if hasattr(file_path, "missing_ok") else file_path.unlink()
233
  return {"status": "success", "message": f"Deleted {safe_name}"}
234
  except Exception as e:
235
  logger.error("Failed to delete file %s: %s", filename, e)
236
  raise HTTPException(status_code=500, detail=f"Failed to delete file: {str(e)}")
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  if __name__ == "__main__":
 
239
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
5
  from typing import List, Dict, Any, Optional
6
 
7
  from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks, Request
8
+ from fastapi.responses import FileResponse
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.middleware.gzip import GZipMiddleware
11
  from transformers import pipeline
12
  import torch
13
  import uvicorn
 
14
 
15
  # Configure logging
16
  logging.basicConfig(level=logging.INFO)
 
18
 
19
  # Define uploads directory
20
  UPLOAD_DIR = Path("uploads")
 
 
21
  MAX_STORAGE_MB = 100 # Maximum storage in MB
22
  MAX_FILE_AGE_DAYS = 1 # Maximum age of files in days
23
 
 
49
  async def load_model():
50
  global classifier
51
  try:
52
+ # Use GPU if available, else CPU
53
  device = 0 if torch.cuda.is_available() else -1
54
 
55
+ # For Hugging Face Spaces with limited resources, use quantized model if on CPU
56
  if device == -1:
57
+ logger.info("Loading quantized model for CPU usage")
58
  classifier = pipeline(
59
  "audio-classification",
60
  model="superb/wav2vec2-base-superb-er",
61
  device=device,
62
+ torch_dtype=torch.float16 # Use half precision
63
  )
64
  else:
 
65
  classifier = pipeline(
66
  "audio-classification",
67
  model="superb/wav2vec2-base-superb-er",
68
  device=device
69
  )
70
+
71
+ logger.info("Loaded emotion recognition model (device=%s)",
72
+ "GPU" if device == 0 else "CPU")
73
  except Exception as e:
74
  logger.error("Failed to load model: %s", e)
75
+ # Don't raise the error - let the app start even if model fails
76
+ # We'll handle this in the endpoints
77
 
78
  async def cleanup_old_files():
79
+ """Clean up old files to prevent storage issues on Hugging Face Spaces."""
80
  try:
81
+ # Remove files older than MAX_FILE_AGE_DAYS
82
  now = time.time()
83
  deleted_count = 0
84
  for file_path in UPLOAD_DIR.iterdir():
85
  if file_path.is_file():
86
  file_age_days = (now - file_path.stat().st_mtime) / (60 * 60 * 24)
87
  if file_age_days > MAX_FILE_AGE_DAYS:
88
+ file_path.unlink()
89
  deleted_count += 1
90
+
91
  if deleted_count > 0:
92
+ logger.info(f"Cleaned up {deleted_count} old files")
93
  except Exception as e:
94
  logger.error(f"Error during file cleanup: {e}")
95
 
 
107
  """Health check endpoint."""
108
  return {"status": "ok", "model_loaded": classifier is not None}
109
 
 
 
 
 
 
110
  @app.post("/upload")
111
  async def upload_audio(
112
  file: UploadFile = File(...),
113
  background_tasks: BackgroundTasks = None
114
  ):
115
+ """
116
+ Upload an audio file and analyze emotions.
117
+ Saves the file to the uploads directory and returns model predictions.
118
+ """
119
  if not classifier:
120
  raise HTTPException(status_code=503, detail="Model not yet loaded")
121
 
 
123
  if not filename:
124
  raise HTTPException(status_code=400, detail="Invalid filename")
125
 
126
+ # Check file extension
127
  valid_extensions = [".wav", ".mp3", ".ogg", ".flac"]
128
  if not any(filename.lower().endswith(ext) for ext in valid_extensions):
129
  raise HTTPException(
 
131
  detail=f"Invalid file type. Supported types: {', '.join(valid_extensions)}"
132
  )
133
 
134
+ # Read file contents
135
  try:
136
  contents = await file.read()
137
  except Exception as e:
 
140
  finally:
141
  await file.close()
142
 
143
+ # Check file size (limit to 10MB for Spaces)
144
  if len(contents) > 10 * 1024 * 1024:
145
  raise HTTPException(
146
  status_code=413,
147
  detail="File too large. Maximum size is 10MB"
148
  )
149
 
150
+ # Check available disk space
151
  try:
152
  total, used, free = shutil.disk_usage(UPLOAD_DIR)
153
  free_mb = free / (1024 * 1024)
154
 
155
+ if free_mb < 10: # Keep at least 10MB free
156
+ # Schedule cleanup in background
157
  if background_tasks:
158
  background_tasks.add_task(cleanup_old_files)
159
 
160
  if len(contents) > free:
161
+ logger.error(
162
+ "Insufficient storage: needed %d bytes, free %d bytes",
163
+ len(contents), free
164
+ )
165
  raise HTTPException(status_code=507, detail="Insufficient storage to save file")
166
  except Exception as e:
167
  logger.warning(f"Failed to check disk usage: {e}")
168
 
169
+ # Save file to uploads directory
170
  file_path = UPLOAD_DIR / filename
171
  try:
172
  with open(file_path, "wb") as f:
 
176
  logger.error("Failed to save file %s: %s", filename, e)
177
  raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
178
 
179
+ # Analyze the audio file using the pretrained model pipeline
180
  try:
181
  results = classifier(str(file_path))
182
+
183
+ # Schedule cleanup in background
184
  if background_tasks:
185
  background_tasks.add_task(cleanup_old_files)
186
+
187
  return {"filename": filename, "predictions": results}
188
  except Exception as e:
189
  logger.error("Model inference failed for %s: %s", filename, e)
190
+ # Try to remove the file if inference fails
191
  try:
192
+ file_path.unlink(missing_ok=True)
193
  except Exception:
194
  pass
195
  raise HTTPException(status_code=500, detail=f"Emotion detection failed: {str(e)}")
196
 
197
  @app.get("/recordings")
198
  async def list_recordings():
199
+ """
200
+ List all uploaded recordings.
201
+ Returns a JSON list of filenames in the uploads directory.
202
+ """
203
  try:
204
  files = [f.name for f in UPLOAD_DIR.iterdir() if f.is_file()]
205
  total, used, free = shutil.disk_usage(UPLOAD_DIR)
 
215
 
216
  @app.get("/recordings/{filename}")
217
  async def get_recording(filename: str):
218
+ """
219
+ Stream/download an audio file from the server.
220
+ """
221
  safe_name = Path(filename).name
222
  file_path = UPLOAD_DIR / safe_name
223
  if not file_path.exists() or not file_path.is_file():
224
  raise HTTPException(status_code=404, detail="Recording not found")
225
+ # Guess MIME type (fallback to octet-stream)
226
  import mimetypes
227
  media_type, _ = mimetypes.guess_type(file_path)
228
  return FileResponse(
 
233
 
234
  @app.get("/analyze/{filename}")
235
  async def analyze_recording(filename: str):
236
+ """
237
+ Analyze an already-uploaded recording by filename.
238
+ Returns emotion predictions for the given file.
239
+ """
240
  if not classifier:
241
  raise HTTPException(status_code=503, detail="Model not yet loaded")
242
 
 
253
 
254
  @app.delete("/recordings/{filename}")
255
  async def delete_recording(filename: str):
256
+ """
257
+ Delete a recording by filename.
258
+ """
259
  safe_name = Path(filename).name
260
  file_path = UPLOAD_DIR / safe_name
261
  if not file_path.exists() or not file_path.is_file():
262
  raise HTTPException(status_code=404, detail="Recording not found")
263
  try:
264
+ file_path.unlink()
265
  return {"status": "success", "message": f"Deleted {safe_name}"}
266
  except Exception as e:
267
  logger.error("Failed to delete file %s: %s", filename, e)
268
  raise HTTPException(status_code=500, detail=f"Failed to delete file: {str(e)}")
269
 
270
+ # New endpoint to analyze emotion directly from uploaded file
271
+ @app.post("/analyze_emotion")
272
+ async def analyze_emotion(file: UploadFile = File(...)):
273
+ """
274
+ Analyze the uploaded audio file and return emotion predictions.
275
+ """
276
+ if not classifier:
277
+ raise HTTPException(status_code=503, detail="Model not yet loaded")
278
+
279
+ # Save uploaded file temporarily
280
+ temp_file = Path("temp_audio_file.wav")
281
+ try:
282
+ contents = await file.read()
283
+ with open(temp_file, "wb") as f:
284
+ f.write(contents)
285
+
286
+ # Run analysis on the uploaded file
287
+ results = classifier(str(temp_file))
288
+ return {"predictions": results}
289
+ except Exception as e:
290
+ logger.error("Failed to analyze the file: %s", e)
291
+ raise HTTPException(status_code=500, detail=f"Failed to analyze the file: {str(e)}")
292
+ finally:
293
+ if temp_file.exists():
294
+ temp_file.unlink() # Clean up temporary file
295
+
296
  if __name__ == "__main__":
297
+ # Bind to 0.0.0.0:7860 for Hugging Face Spaces compatibility
298
  uvicorn.run(app, host="0.0.0.0", port=7860)