"""FastAPI service for audio emotion detection using wav2vec2."""

import logging
import shutil
import time
from pathlib import Path
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks, Request
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from transformers import pipeline
import torch
import uvicorn


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Local storage for uploaded audio; stale files are pruned by cleanup_old_files().
UPLOAD_DIR = Path("uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MAX_STORAGE_MB = 100
MAX_FILE_AGE_DAYS = 1


app = FastAPI(
    title="Emotion Detection API",
    description="Audio emotion detection using wav2vec2",
    version="1.0.0",
)


@app.get("/")
async def root():
    return {"message": "Audio Emotion Detection API", "status": "running"}


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000)


# Global classifier handle; populated on startup once the model has loaded.
classifier = None


@app.on_event("startup")
async def load_model():
    global classifier
    try:
        device = 0 if torch.cuda.is_available() else -1

        if device == -1:
            # On CPU, load the model in half precision (float16) to cut memory
            # use; note this is reduced precision, not quantization.
            logger.info("Loading model in half precision for CPU usage")
            classifier = pipeline(
                "audio-classification",
                model="superb/wav2vec2-base-superb-er",
                device=device,
                torch_dtype=torch.float16,
            )
        else:
            classifier = pipeline(
                "audio-classification",
                model="superb/wav2vec2-base-superb-er",
                device=device,
            )

        logger.info("Loaded emotion recognition model (device=%s)",
                    "GPU" if device == 0 else "CPU")
    except Exception as e:
        logger.error("Failed to load model: %s", e)


async def cleanup_old_files():
    """Clean up old files to prevent storage issues on Hugging Face Spaces."""
    try:
        now = time.time()
        deleted_count = 0
        for file_path in UPLOAD_DIR.iterdir():
            if file_path.is_file():
                file_age_days = (now - file_path.stat().st_mtime) / (60 * 60 * 24)
                if file_age_days > MAX_FILE_AGE_DAYS:
                    file_path.unlink()
                    deleted_count += 1

        if deleted_count > 0:
            logger.info("Cleaned up %d old files", deleted_count)
    except Exception as e:
        logger.error("Error during file cleanup: %s", e)


@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Add X-Process-Time header to responses."""
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response


@app.get("/health")
async def health():
    """Health check endpoint."""
    return {"status": "ok", "model_loaded": classifier is not None}


@app.post("/upload")
async def upload_audio(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None,
):
    """
    Upload an audio file and analyze emotions.

    Saves the file to the uploads directory and returns model predictions.
    """
    if not classifier:
        raise HTTPException(status_code=503, detail="Model not yet loaded")

    filename = Path(file.filename).name
    if not filename:
        raise HTTPException(status_code=400, detail="Invalid filename")

    valid_extensions = [".wav", ".mp3", ".ogg", ".flac"]
    if not any(filename.lower().endswith(ext) for ext in valid_extensions):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Supported types: {', '.join(valid_extensions)}"
        )

    try:
        contents = await file.read()
    except Exception as e:
        logger.error("Error reading file %s: %s", filename, e)
        raise HTTPException(status_code=500, detail=f"Failed to read file: {str(e)}")
    finally:
        await file.close()

    if len(contents) > 10 * 1024 * 1024:
        raise HTTPException(
            status_code=413,
            detail="File too large. Maximum size is 10MB"
        )

    file_path = UPLOAD_DIR / filename
    try:
        with open(file_path, "wb") as f:
            f.write(contents)
        logger.info("Saved uploaded file: %s", file_path)
    except Exception as e:
        logger.error("Failed to save file %s: %s", filename, e)
        raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")

    try:
        results = classifier(str(file_path))

        # Schedule cleanup of stale uploads after the response is sent.
        if background_tasks:
            background_tasks.add_task(cleanup_old_files)

        return {"filename": filename, "predictions": results}
    except Exception as e:
        logger.error("Model inference failed for %s: %s", filename, e)
        # Remove the saved file so failed uploads do not accumulate.
        try:
            file_path.unlink(missing_ok=True)
        except Exception:
            pass
        raise HTTPException(status_code=500, detail=f"Emotion detection failed: {str(e)}")
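
# Note: the audio-classification pipeline returns a list of label/score pairs,
# for example [{"label": "neu", "score": 0.72}, ...]; the exact label names
# come from the superb/wav2vec2-base-superb-er checkpoint and are passed
# through to the client unchanged (the example above is illustrative only).
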

@app.get("/recordings")
async def list_recordings():
    """
    List all uploaded recordings.

    Returns the filenames in the uploads directory along with disk usage.
    """
    try:
        files = [f.name for f in UPLOAD_DIR.iterdir() if f.is_file()]
        total, used, free = shutil.disk_usage(UPLOAD_DIR)
        storage_info = {
            "total_mb": total / (1024 * 1024),
            "used_mb": used / (1024 * 1024),
            "free_mb": free / (1024 * 1024),
        }
        return {"recordings": files, "storage": storage_info}
    except Exception as e:
        logger.error("Could not list files: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to list recordings: {str(e)}")


@app.get("/recordings/{filename}")
async def get_recording(filename: str):
    """Stream/download an audio file from the server."""
    safe_name = Path(filename).name
    file_path = UPLOAD_DIR / safe_name
    if not file_path.exists() or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Recording not found")

    import mimetypes
    media_type, _ = mimetypes.guess_type(file_path)
    return FileResponse(
        file_path,
        media_type=media_type or "application/octet-stream",
        filename=safe_name,
    )


@app.get("/analyze/{filename}")
async def analyze_recording(filename: str):
    """
    Analyze an already-uploaded recording by filename.

    Returns emotion predictions for the given file.
    """
    if not classifier:
        raise HTTPException(status_code=503, detail="Model not yet loaded")

    safe_name = Path(filename).name
    file_path = UPLOAD_DIR / safe_name
    if not file_path.exists() or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Recording not found")
    try:
        results = classifier(str(file_path))
    except Exception as e:
        logger.error("Model inference failed for %s: %s", filename, e)
        raise HTTPException(status_code=500, detail=f"Emotion detection failed: {str(e)}")
    return {"filename": safe_name, "predictions": results}


@app.delete("/recordings/{filename}")
async def delete_recording(filename: str):
    """Delete a recording by filename."""
    safe_name = Path(filename).name
    file_path = UPLOAD_DIR / safe_name
    if not file_path.exists() or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Recording not found")
    try:
        file_path.unlink()
        return {"status": "success", "message": f"Deleted {safe_name}"}
    except Exception as e:
        logger.error("Failed to delete file %s: %s", filename, e)
        raise HTTPException(status_code=500, detail=f"Failed to delete file: {str(e)}")


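# Example requests, assuming the server is running locally on port 7860 and
# "sample.wav" stands in for a real uploaded audio file:
#   curl -X POST -F "file=@sample.wav" http://localhost:7860/upload
#   curl http://localhost:7860/recordings
#   curl http://localhost:7860/analyze/sample.wav
#   curl -X DELETE http://localhost:7860/recordings/sample.wav
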
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")