stt / stt.py
ChandimaPrabath's picture
Update stt.py
e7c3e13 verified
# stt.py
# -*- coding: utf-8 -*-
"""
stt.py — a module for Speech-to-Text via pywhispercpp
Classes
-------
SpeechToText
Encapsulates model loading, recording, saving, and transcription.
Usage (as a script)
-------------------
python -m stt --model tiny.en --duration 5
or in code
-----------
from stt import SpeechToText
stt = SpeechToText()
text = stt.transcribe()
"""
import os
import tempfile
import time
import datetime
import numpy as np
import sounddevice as sd
from scipy.io.wavfile import write as write_wav
import webrtcvad
from pywhispercpp.model import Model as Whisper
class SpeechToText:
    """
    A Speech-to-Text helper built on pywhispercpp's Whisper model.

    Encapsulates model loading, microphone recording, WAV saving, and
    transcription.

    Parameters
    ----------
    model_name : str
        Whisper model to load (e.g. "tiny.en", "base", "small.en", etc.).
    sample_rate : int
        Audio sample rate (must match Whisper's 16 kHz).
    record_duration : float
        Default seconds to record when calling `.record_audio()`.
    temp_dir : str
        Directory for temporary WAV files. Defaults to the system temp
        directory when not given.
    verbose : bool
        Print progress messages if True.
    """

    def __init__(
        self,
        model_name: str = "tiny.en",
        sample_rate: int = 16_000,
        record_duration: float = 5.0,
        temp_dir: str = None,
        verbose: bool = True,
    ):
        self.model_name = model_name
        self.sample_rate = sample_rate
        self.record_duration = record_duration
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.verbose = verbose

        # Load the Whisper model eagerly so later calls only pay for inference.
        if self.verbose:
            print(f"[STT] Loading Whisper model '{self.model_name}'...")
        t0 = time.time()
        self._model = Whisper(model=self.model_name)
        if self.verbose:
            print(f"[STT] Model loaded in {time.time() - t0:.2f}s")

    def record_audio(self, duration: float = None) -> np.ndarray:
        """
        Record from the default mic for `duration` seconds.

        A falsy `duration` (None or 0) falls back to `self.record_duration`.
        Returns a flat float32 mono waveform scaled from int16 by 1/32768.
        """
        duration = duration or self.record_duration
        if self.verbose:
            print(f"[STT] Recording for {duration}s at {self.sample_rate}Hz...")
        frames = sd.rec(
            int(duration * self.sample_rate),
            samplerate=self.sample_rate,
            channels=1,
            dtype="int16",
        )
        # sd.rec() is followed by sd.wait() so the buffer is complete before use.
        sd.wait()
        if self.verbose:
            print("[STT] Recording finished.")
        # Convert int16 samples to float32 in [-1, 1).
        return (frames.astype(np.float32) / 32768.0).flatten()

    def save_wav(self, audio: np.ndarray, filename: str = None) -> str:
        """
        Save float waveform `audio` to an int16 WAV at `filename`.

        If `filename` is None, a timestamped name inside `temp_dir` is used.
        Samples outside [-1, 1] are clipped so they saturate instead of
        wrapping around during the int16 conversion.
        Returns the path that was written.
        """
        filename = filename or os.path.join(
            self.temp_dir,
            f"stt_{datetime.datetime.now():%Y%m%d_%H%M%S}.wav",
        )
        # A bare filename ("out.wav") has an empty dirname, and
        # os.makedirs("") raises, so only create a directory when one is named.
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Clip first, then scale back to int16.
        int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        write_wav(filename, self.sample_rate, int16)
        if self.verbose:
            print(f"[STT] Saved WAV to {filename}")
        return filename

    def _is_in_temp_dir(self, path: str) -> bool:
        """Return True if `path` lies inside `self.temp_dir` (path-aware, not a string prefix)."""
        root = os.path.normpath(self.temp_dir)
        try:
            common = os.path.commonpath([root, os.path.normpath(path)])
        except ValueError:
            # Mixed absolute/relative paths or different drives: not inside.
            return False
        return os.path.normcase(common) == os.path.normcase(root)

    def transcribe_file(
        self,
        wav_path: str,
        n_threads: int = 4,
        cleanup: bool = None,
    ) -> str:
        """
        Transcribe existing WAV file at `wav_path`. Returns the stripped text.

        Parameters
        ----------
        wav_path : str
            Path of the WAV file to transcribe.
        n_threads : int
            Forwarded to the model's `transcribe` call.
        cleanup : bool or None
            True deletes `wav_path` afterwards, False keeps it, and None
            (the default) deletes it only when it lives inside `temp_dir`.

        Raises
        ------
        FileNotFoundError
            If `wav_path` is not an existing file.
        """
        if not os.path.isfile(wav_path):
            raise FileNotFoundError(f"No such file: {wav_path}")
        if self.verbose:
            print(f"[STT] Transcribing file {wav_path}…")
        t0 = time.time()
        if cleanup is None:
            cleanup = self._is_in_temp_dir(wav_path)
        try:
            # pywhispercpp API may return segments or text
            result = self._model.transcribe(wav_path, n_threads=n_threads)
        finally:
            # Remove the temp file even when transcription fails, so failed
            # runs do not accumulate recordings. Removal itself is best-effort.
            if cleanup:
                try:
                    os.remove(wav_path)
                except OSError as err:
                    if self.verbose:
                        print(f"[STT] Could not remove {wav_path}: {err}")
        # Collect text from the returned segments.
        if isinstance(result, list):
            text = "".join(seg.text for seg in result)
        else:
            # NOTE(review): fallback assumes the model object exposes
            # get_text() — confirm against the installed pywhispercpp version.
            text = self._model.get_text()
        if self.verbose:
            print(f"[STT] Transcription complete ({time.time() - t0:.2f}s).")
        return text.strip()

    def transcribe(
        self,
        duration: float = None,
        save_temp: bool = False,
        n_threads: int = 4,
    ) -> str:
        """
        Record, write a WAV into `temp_dir`, and transcribe in one call.

        When `save_temp` is True the recorded WAV is kept on disk after
        transcription; otherwise it is deleted once transcription finishes.
        Returns the transcribed text.
        """
        audio = self.record_audio(duration)
        wav_path = self.save_wav(audio)
        return self.transcribe_file(
            wav_path,
            n_threads=n_threads,
            cleanup=not save_temp,
        )
# Optional: make module runnable as a script
if __name__ == "__main__":
    import argparse

    # Command-line interface: model choice, recording length, and whether
    # the recorded WAV should be kept.
    cli = argparse.ArgumentParser(description="STT using pywhispercpp")
    cli.add_argument("--model", "-m", default="small.en",
                     help="Whisper model name (e.g. tiny.en, base, small.en)")
    cli.add_argument("--duration", "-d", type=float, default=5.0,
                     help="Seconds to record")
    cli.add_argument("--no-save", action="store_true",
                     help="Do not save the recorded WAV")
    opts = cli.parse_args()

    # The recording length is supplied through the constructor default, so
    # transcribe() is called without an explicit duration.
    engine = SpeechToText(
        model_name=opts.model,
        record_duration=opts.duration,
        verbose=True,
    )
    keep_wav = not opts.no_save
    transcript = engine.transcribe(save_temp=keep_wav)

    print("\n=== Transcription ===")
    print(transcript)