File size: 3,329 Bytes
0bf8a29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d247a3
0bf8a29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c3ae3e
0bf8a29
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# app.py
import os
import shutil
import tempfile
import threading
import uuid

from fastapi import FastAPI, UploadFile, File, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from stt import SpeechToText

# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Whisper model name handed to the speech-to-text engine.
MODEL_NAME = os.environ.get("WHISPER_MODEL", "base")
# Default recording length in seconds, parsed from the environment as a float.
DEFAULT_DUR = float(os.environ.get("RECORD_DURATION", "5.0"))
# Directory for temporary audio files; falls back to the system temp dir.
TEMP_DIR = os.environ.get("TEMP_DIR", tempfile.gettempdir())
# Upload content types accepted by the /transcribe endpoint.
ALLOWED_TYPES = {
    "audio/wav",
    "audio/x-wav",
    "audio/mpeg",
    "audio/mp3",
}
# -----------------------------------------------------------------------------

# The FastAPI application object together with its descriptive metadata.
app = FastAPI(
    version="1.0",
    title="STT Service",
    description="Speech-to-Text API using pywhispercpp's Whisper",
)

# CORS is wide open here: every origin and every header is accepted for the
# listed methods. Tighten this before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_origins=["*"],
)

# One engine instance is built when the module is imported and is shared by
# every endpoint below.
stt_engine = SpeechToText(
    verbose=False,  # mute console logs in API
    temp_dir=TEMP_DIR,
    record_duration=DEFAULT_DUR,
    sample_rate=16000,
    model_name=MODEL_NAME,
)


@app.get("/health", summary="Health check")
def health():
    # Reports a fixed "ok" status along with the configured model name.
    payload = dict(status="ok", model=MODEL_NAME)
    return payload


@app.post("/transcribe", summary="Transcribe uploaded audio file")
async def transcribe_audio(
    file: UploadFile = File(..., description="An audio file (WAV, MP3, etc.)"),
):
    """
    Transcribe an uploaded audio file and return the recognized text.

    The upload is written to a uniquely named temporary file, handed to the
    speech-to-text engine, and the temporary file is removed afterwards.

    Returns a JSON object of the form `{"text": ...}`.

    Errors: 415 when the upload's content type is not an accepted audio type;
    500 (carrying the underlying error message) when saving or transcribing
    the file fails.
    """
    if file.content_type not in ALLOWED_TYPES:
        raise HTTPException(415, detail=f"Unsupported Media Type: {file.content_type}")

    # 1) save upload to a temp file, keeping the extension the client sent.
    # `filename` may be missing on an upload, so fall back to "" (and thus to
    # ".wav") rather than letting os.path.splitext(None) raise a TypeError
    # before the error handling below is in effect.
    ext = os.path.splitext(file.filename or "")[1] or ".wav"
    tmp_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}{ext}")

    try:
        with open(tmp_path, "wb") as out_f:
            shutil.copyfileobj(file.file, out_f)
        # 2) run transcription
        # NOTE(review): the copy above and this call are synchronous inside an
        # `async def`, so they presumably block the event loop while running.
        # Offloading to a worker thread is only safe if the engine tolerates
        # concurrent calls -- confirm before changing.
        text = stt_engine.transcribe_file(tmp_path)
        return {"text": text}
    except HTTPException:
        raise
    except Exception as e:
        # Keep the original exception chained for server-side tracebacks.
        raise HTTPException(500, detail=str(e)) from e
    finally:
        # clean up; the file may never have been created if open() failed,
        # so a missing file is not an error here.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass


# Guards the save/override/restore of the shared engine's `record_duration`.
# Plain `def` endpoints are run in a thread pool, so without this two
# overlapping requests could capture each other's override as the "original"
# value and leave the engine's default permanently changed.
_RECORD_LOCK = threading.Lock()


@app.post("/record", summary="Record from mic + transcribe")
def record_and_transcribe(
    duration: float = Query(
        DEFAULT_DUR, gt=0, le=30,
        description="Seconds to record from server mic"
    )
):
    """
    Records from the server's default microphone for `duration` seconds,
    then transcribes that chunk of audio.

    Requests are handled one at a time. Returns a JSON object of the form
    `{"text": ...}`; any failure is reported as a 500 carrying the underlying
    error message.
    """
    try:
        with _RECORD_LOCK:
            # temporarily override record_duration; `original` is bound
            # before the inner try so the restore below can never reference
            # an unassigned name.
            original = stt_engine.record_duration
            stt_engine.record_duration = duration
            try:
                # record & transcribe
                text = stt_engine.transcribe(save_temp=False)
                return {"text": text}
            finally:
                stt_engine.record_duration = original
    except Exception as e:
        # Keep the original exception chained for server-side tracebacks.
        raise HTTPException(500, detail=str(e)) from e


# Entry point for `python app.py`: start Uvicorn on all interfaces, taking the
# port from the PORT environment variable (7860 when unset), with auto-reload.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        reload=True,
        port=int(os.environ.get("PORT", 7860)),
        host="0.0.0.0",
    )