ChandimaPrabath committed on
Commit
4b6cfde
·
verified ·
1 Parent(s): a36b5bb

Upload stt.py

Browse files
Files changed (1) hide show
  1. stt.py +188 -0
stt.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stt.py
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ stt.py — a module for Speech-to-Text via pywhispercpp
5
+
6
+ Classes
7
+ -------
8
+ SpeechToText
9
+ Encapsulates model loading, recording, saving, and transcription.
10
+
11
+ Usage (as a script)
12
+ -------------------
13
+ python -m stt --model tiny.en --duration 5
14
+
15
+ or in code
16
+ -----------
17
+ from stt import SpeechToText
18
+ stt = SpeechToText()
19
+ text = stt.transcribe()
20
+ """
21
+
22
+ import os
23
+ import tempfile
24
+ import time
25
+ import datetime
26
+ import numpy as np
27
+ import sounddevice as sd
28
+ import soundfile as sf
29
+ from scipy.io.wavfile import write as write_wav
30
+ import webrtcvad
31
+ from pywhispercpp.model import Model as Whisper
32
+ from qdrant_client import QdrantClient
33
+ from qdrant_client.http.models import Distance, VectorParams
34
+
35
+
36
class SpeechToText:
    """
    A Speech-to-Text helper built on pywhispercpp's Whisper model.

    Bundles model loading, microphone recording, WAV saving and transcription.

    Parameters
    ----------
    model_name : str
        Whisper model to load (e.g. "tiny.en", "base", "small.en", etc.).
    sample_rate : int
        Audio sample rate (must match Whisper's 16 kHz).
    record_duration : float
        Default seconds to record when calling `.record_audio()`.
    temp_dir : str
        Directory for temporary WAV files. Falls back to the system temp
        directory when not given.
    verbose : bool
        Print progress messages if True.
    """

    def __init__(
        self,
        model_name: str = "tiny.en",
        sample_rate: int = 16_000,
        record_duration: float = 5.0,
        temp_dir: str = None,
        verbose: bool = True,
    ):
        """Store the configuration and load the Whisper model eagerly."""
        self.model_name = model_name
        self.sample_rate = sample_rate
        self.record_duration = record_duration
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.verbose = verbose

        # load Whisper model
        if self.verbose:
            print(f"[STT] Loading Whisper model '{self.model_name}'...")
        t0 = time.time()
        self._model = Whisper(model=self.model_name)
        if self.verbose:
            print(f"[STT] Model loaded in {time.time() - t0:.2f}s")

    def _is_in_temp_dir(self, path: str) -> bool:
        """Return True if `path` lies inside `self.temp_dir`."""
        temp_root = os.path.normcase(os.path.abspath(self.temp_dir))
        target = os.path.normcase(os.path.abspath(path))
        # Compare whole path components: a plain string-prefix test would
        # also match sibling directories such as "/tmp_other" for "/tmp".
        try:
            return os.path.commonpath([temp_root, target]) == temp_root
        except ValueError:
            # Raised e.g. for paths on different drives: not inside temp_dir.
            return False

    def _remove_file(self, path: str) -> None:
        """Best-effort removal of `path`; a failure is reported, not raised."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already gone, nothing to clean up
        except OSError as err:
            if self.verbose:
                print(f"[STT] Could not remove {path}: {err}")

    def record_audio(self, duration: float = None) -> np.ndarray:
        """
        Record from the default mic for `duration` seconds, return float32 mono waveform.

        Parameters
        ----------
        duration : float, optional
            Seconds to record; `self.record_duration` is used when None.

        Raises
        ------
        ValueError
            If the effective duration is not positive.
        """
        # `is None` rather than `or`, so an explicit 0 is not silently
        # replaced by the default.
        if duration is None:
            duration = self.record_duration
        if duration <= 0:
            raise ValueError(f"duration must be positive, got {duration!r}")

        if self.verbose:
            print(f"[STT] Recording for {duration}s at {self.sample_rate}Hz...")
        frames = sd.rec(
            int(duration * self.sample_rate),
            samplerate=self.sample_rate,
            channels=1,
            dtype="int16",
        )
        sd.wait()
        if self.verbose:
            print("[STT] Recording finished.")
        # convert to float32 in [-1, 1]
        return (frames.astype(np.float32) / 32768.0).flatten()

    def save_wav(self, audio: np.ndarray, filename: str = None) -> str:
        """
        Save float32 waveform `audio` to an int16 WAV at `filename`.

        If filename is None, a timestamped name inside `temp_dir` is used.
        Samples outside [-1, 1] are clipped before conversion.
        Returns the path.
        """
        if not filename:
            # Microseconds keep two saves within the same second from colliding.
            stamp = datetime.datetime.now()
            filename = os.path.join(
                self.temp_dir,
                f"stt_{stamp:%Y%m%d_%H%M%S_%f}.wav",
            )

        # A bare filename has an empty dirname, and os.makedirs("") raises.
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)

        # Clip first: without it, out-of-range samples wrap around in int16.
        samples = np.clip(np.asarray(audio, dtype=np.float32), -1.0, 1.0)
        pcm = (samples * 32767).astype(np.int16)
        write_wav(filename, self.sample_rate, pcm)
        if self.verbose:
            print(f"[STT] Saved WAV to {filename}")
        return filename

    def transcribe_file(
        self,
        wav_path: str,
        n_threads: int = 4,
        *,
        keep_file: bool = False,
    ) -> str:
        """
        Transcribe existing WAV file at `wav_path`. Returns the text.

        Parameters
        ----------
        wav_path : str
            Path of the WAV file to transcribe.
        n_threads : int
            Forwarded to the model's `transcribe` call.
        keep_file : bool
            By default a file located inside `temp_dir` is deleted after a
            successful transcription. Pass True to leave it in place.

        Raises
        ------
        FileNotFoundError
            If `wav_path` is not an existing file.
        """
        if not os.path.isfile(wav_path):
            raise FileNotFoundError(f"No such file: {wav_path}")
        if self.verbose:
            print(f"[STT] Transcribing file {wav_path}…")
        t0 = time.time()

        result = self._model.transcribe(wav_path, n_threads=n_threads)

        # cleanup temp if in our temp_dir
        if not keep_file and self._is_in_temp_dir(wav_path):
            self._remove_file(wav_path)

        # collect text
        if isinstance(result, list):
            # Join with a space so the words of adjacent segments do not
            # fuse together when segment texts carry no surrounding space.
            pieces = (seg.text.strip() for seg in result)
            text = " ".join(piece for piece in pieces if piece)
        else:
            # NOTE(review): fallback kept from the original code; confirm
            # that the model object actually provides `get_text()`.
            text = self._model.get_text()

        if self.verbose:
            print(f"[STT] Transcription complete ({time.time() - t0:.2f}s).")
        return text.strip()

    def transcribe(
        self,
        duration: float = None,
        save_temp: bool = False,
        n_threads: int = 4,
    ) -> str:
        """
        Record + save + transcribe in one call.

        Parameters
        ----------
        duration : float, optional
            Seconds to record; `self.record_duration` is used when None.
        save_temp : bool
            Keep the recorded WAV in `temp_dir` if True; otherwise it is
            removed once transcription finishes (or fails).
        n_threads : int
            Forwarded to `transcribe_file`.

        Returns the transcribed text.
        """
        audio = self.record_audio(duration)
        wav_path = self.save_wav(audio)
        try:
            # Cleanup is decided here, so transcribe_file must not delete.
            return self.transcribe_file(
                wav_path, n_threads=n_threads, keep_file=True
            )
        finally:
            if not save_temp:
                self._remove_file(wav_path)
157
+
158
+
159
+ # Optional: make module runnable as a script
160
if __name__ == "__main__":
    # Command-line entry point: parse the options, run one
    # record-and-transcribe cycle, and print the resulting text.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="STT using pywhispercpp")
    cli.add_argument(
        "--model",
        "-m",
        default="small.en",
        help="Whisper model name (e.g. tiny.en, base, small.en)",
    )
    cli.add_argument(
        "--duration",
        "-d",
        type=float,
        default=5.0,
        help="Seconds to record",
    )
    cli.add_argument(
        "--no-save",
        action="store_true",
        help="Do not save the recorded WAV",
    )
    options = cli.parse_args()

    # --no-save is the negation of the `save_temp` argument of transcribe().
    keep_recording = not options.no_save

    engine = SpeechToText(
        model_name=options.model,
        record_duration=options.duration,
        verbose=True,
    )
    transcript = engine.transcribe(save_temp=keep_recording)
    print("\n=== Transcription ===")
    print(transcript)