Spaces:
Sleeping
Sleeping
# app.py | |
import os
import shutil
import tempfile
import threading
import uuid

from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from stt import SpeechToText
# -----------------------------------------------------------------------------
# CONFIGURATION (every value except ALLOWED_TYPES can be overridden via env)
# -----------------------------------------------------------------------------
# Whisper model identifier passed to the STT engine.
MODEL_NAME = os.environ.get("WHISPER_MODEL", "base")
# Default recording length handed to the engine and used as the query default.
DEFAULT_DUR = float(os.environ.get("RECORD_DURATION", "5.0"))
# Directory for temporary audio files; falls back to the system temp dir.
TEMP_DIR = os.environ.get("TEMP_DIR", tempfile.gettempdir())
# Upload content types accepted by the file-transcription handler.
ALLOWED_TYPES = {
    "audio/wav",
    "audio/x-wav",
    "audio/mpeg",
    "audio/mp3",
}
# -----------------------------------------------------------------------------
# APPLICATION
# -----------------------------------------------------------------------------
app = FastAPI(
    version="1.0",
    title="STT Service",
    description="Speech-to-Text API using pywhispercpp's Whisper",
)

# CORS: every origin is accepted here; tighten `allow_origins` for production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_origins=["*"],
)
# Single shared STT engine, constructed once when the module is imported and
# reused by every handler below.
stt_engine = SpeechToText(
    model_name=MODEL_NAME,
    temp_dir=TEMP_DIR,
    record_duration=DEFAULT_DUR,
    sample_rate=16_000,
    verbose=False,  # intended to keep the engine's console output quiet
)
def health():
    """Return a small status payload naming the configured Whisper model.

    NOTE(review): no route decorator is visible in this view of the file;
    confirm how this handler is registered on `app`.
    """
    payload = {"status": "ok"}
    payload["model"] = MODEL_NAME
    return payload
async def transcribe_audio(
    file: UploadFile = File(..., description="An audio file (WAV, MP3, etc.)"),
):
    """Transcribe an uploaded audio file and return the recognised text.

    The upload is copied to a uniquely named file inside ``TEMP_DIR``, passed
    to ``stt_engine.transcribe_file`` and removed again whether or not the
    transcription succeeds.

    Args:
        file: The uploaded audio. Its content type must be one of
            ``ALLOWED_TYPES``.

    Returns:
        A dict of the form ``{"text": <transcription>}``.

    Raises:
        HTTPException: 415 when the content type is not allowed; 500 (with the
            underlying error message as detail) when saving or transcribing
            the file fails.

    NOTE(review): no route decorator is visible in this view of the file;
    confirm how this handler is registered on `app`.
    """
    if file.content_type not in ALLOWED_TYPES:
        raise HTTPException(415, detail=f"Unsupported Media Type: {file.content_type}")

    # 1) save the upload under a collision-free name, keeping its extension.
    # The client may not send a filename at all, so treat a missing name as
    # "no extension" and fall back to ".wav" instead of crashing in splitext.
    ext = os.path.splitext(file.filename or "")[1] or ".wav"
    tmp_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}{ext}")
    try:
        with open(tmp_path, "wb") as out_f:
            shutil.copyfileobj(file.file, out_f)
        # 2) run transcription
        # NOTE(review): this call is synchronous (not awaited) inside an async
        # handler, so the event loop waits on it; consider offloading it to a
        # worker thread if the engine is confirmed to be thread-safe.
        text = stt_engine.transcribe_file(tmp_path)
        return {"text": text}
    except HTTPException:
        raise
    except Exception as e:
        # Keep the original exception as the cause for server-side tracebacks.
        raise HTTPException(500, detail=str(e)) from e
    finally:
        # clean up; the file may never have been created or may already be gone
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
# Serialises recording requests: the handler below temporarily overwrites the
# shared engine's `record_duration`, so overlapping calls must not interleave
# (otherwise one request could restore another request's override).
_record_lock = threading.Lock()


def record_and_transcribe(
    duration: float = Query(
        DEFAULT_DUR, gt=0, le=30,
        description="Seconds to record from server mic"
    )
):
    """
    Records from the server's default microphone for `duration` seconds,
    then transcribes that chunk of audio.

    The shared engine's `record_duration` is overridden for the length of the
    call and restored afterwards; a module-level lock keeps concurrent
    requests from interleaving that override.

    Args:
        duration: Recording length in seconds (0 < duration <= 30).

    Returns:
        A dict of the form ``{"text": <transcription>}``.

    Raises:
        HTTPException: 500, with the underlying error message as detail, when
            recording or transcription fails.

    NOTE(review): no route decorator is visible in this view of the file;
    confirm how this handler is registered on `app`.
    """
    with _record_lock:
        try:
            # Capture the current value before entering the inner try so the
            # restore step only runs once there is something to restore.
            original = stt_engine.record_duration
            stt_engine.record_duration = duration
            try:
                # record & transcribe
                text = stt_engine.transcribe(save_temp=False)
                return {"text": text}
            finally:
                stt_engine.record_duration = original
        except Exception as e:
            # Keep the original exception as the cause for server-side logs.
            raise HTTPException(500, detail=str(e)) from e
# Script entry point: running `python app.py` launches Uvicorn directly.
if __name__ == "__main__":
    import uvicorn

    # The listening port comes from the PORT environment variable (default 7860).
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=True)