ChandimaPrabath committed on
Commit
c390e38
·
verified ·
1 Parent(s): 457b066

Update stt.py

Browse files
Files changed (1) hide show
  1. stt.py +187 -188
stt.py CHANGED
@@ -1,188 +1,187 @@
1
- # stt.py
2
- # -*- coding: utf-8 -*-
3
- """
4
- stt.py — a module for Speech-to-Text via pywhispercpp
5
-
6
- Classes
7
- -------
8
- SpeechToText
9
- Encapsulates model loading, recording, saving, and transcription.
10
-
11
- Usage (as a script)
12
- -------------------
13
- python -m stt --model tiny.en --duration 5
14
-
15
- or in code
16
- -----------
17
- from stt import SpeechToText
18
- stt = SpeechToText()
19
- text = stt.transcribe()
20
- """
21
-
22
- import os
23
- import tempfile
24
- import time
25
- import datetime
26
- import numpy as np
27
- import sounddevice as sd
28
- import soundfile as sf
29
- from scipy.io.wavfile import write as write_wav
30
- import webrtcvad
31
- from pywhispercpp.model import Model as Whisper
32
- from qdrant_client import QdrantClient
33
- from qdrant_client.http.models import Distance, VectorParams
34
-
35
-
36
class SpeechToText:
    """
    A Speech-to-Text helper built on pywhispercpp's Whisper.

    Parameters
    ----------
    model_name : str
        Whisper model to load (e.g. "tiny.en", "base", "small.en", etc.).
    sample_rate : int
        Audio sample rate in Hz (Whisper expects 16 kHz).
    record_duration : float
        Default seconds to record when calling `.record_audio()`.
    temp_dir : str
        Directory for temporary WAV files (defaults to the system temp dir).
    verbose : bool
        Print progress messages if True.
    """

    def __init__(
        self,
        model_name: str = "tiny.en",
        sample_rate: int = 16_000,
        record_duration: float = 5.0,
        temp_dir: str = None,
        verbose: bool = True,
    ):
        self.model_name = model_name
        self.sample_rate = sample_rate
        self.record_duration = record_duration
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.verbose = verbose

        # Load the Whisper model once up front; this is the slow step.
        if self.verbose:
            print(f"[STT] Loading Whisper model '{self.model_name}'...")
        t0 = time.time()
        self._model = Whisper(model=self.model_name)
        if self.verbose:
            print(f"[STT] Model loaded in {time.time() - t0:.2f}s")

    def record_audio(self, duration: float = None) -> np.ndarray:
        """
        Record from the default mic for `duration` seconds.

        Returns the captured audio as a float32 mono waveform in [-1, 1].
        """
        # Use an explicit None check (not `or`) so duration=0 is honored
        # rather than silently replaced by the default.
        if duration is None:
            duration = self.record_duration
        if self.verbose:
            print(f"[STT] Recording for {duration}s at {self.sample_rate}Hz...")
        frames = sd.rec(
            int(duration * self.sample_rate),
            samplerate=self.sample_rate,
            channels=1,
            dtype="int16",
        )
        sd.wait()
        if self.verbose:
            print("[STT] Recording finished.")
        # int16 -> float32 in [-1, 1]
        return (frames.astype(np.float32) / 32768.0).flatten()

    def save_wav(self, audio: np.ndarray, filename: str = None) -> str:
        """
        Save float32 waveform `audio` to an int16 WAV at `filename`.

        If filename is None, a timestamped file is created in temp_dir.
        Returns the path written.
        """
        filename = filename or os.path.join(
            self.temp_dir,
            f"stt_{datetime.datetime.now():%Y%m%d_%H%M%S}.wav",
        )
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # Convert back to int16 for a standard PCM WAV file.
        int16 = (audio * 32767).astype(np.int16)
        write_wav(filename, self.sample_rate, int16)
        if self.verbose:
            # Fixed: the original printed the literal text "(unknown)"
            # instead of interpolating the path.
            print(f"[STT] Saved WAV to {filename}")
        return filename

    def transcribe_file(self, wav_path: str, n_threads: int = 4) -> str:
        """
        Transcribe an existing WAV file at `wav_path`. Returns the text.

        Raises
        ------
        FileNotFoundError
            If `wav_path` does not exist.
        """
        if not os.path.isfile(wav_path):
            raise FileNotFoundError(f"No such file: {wav_path}")
        if self.verbose:
            print(f"[STT] Transcribing file {wav_path}…")
        t0 = time.time()

        # pywhispercpp may return a list of segments or keep text internally.
        result = self._model.transcribe(wav_path, n_threads=n_threads)

        # Best-effort cleanup of throwaway files we created under temp_dir.
        if wav_path.startswith(self.temp_dir):
            try:
                os.remove(wav_path)
            except OSError:
                pass

        if isinstance(result, list):
            text = "".join(seg.text for seg in result)
        else:
            # Fall back to the model's internally stored transcript.
            text = self._model.get_text()

        if self.verbose:
            print(f"[STT] Transcription complete ({time.time() - t0:.2f}s).")
        return text.strip()

    def transcribe(
        self,
        duration: float = None,
        save_temp: bool = False,
        n_threads: int = 4,
    ) -> str:
        """
        Record + save + transcribe in one call. Returns the transcribed text.

        NOTE(review): in the original, `save_temp` had no effect — both
        branches of its conditional called save_wav identically. The WAV is
        always written (transcribe_file needs a file) and is removed again
        by transcribe_file when it lives under temp_dir. The parameter is
        kept for interface compatibility; confirm its intended meaning.
        """
        audio = self.record_audio(duration)
        wav_path = self.save_wav(audio)
        return self.transcribe_file(wav_path, n_threads=n_threads)
157
-
158
-
159
# Optional: make module runnable as a script
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="STT using pywhispercpp")
    cli.add_argument(
        "--model", "-m",
        default="small.en",
        help="Whisper model name (e.g. tiny.en, base, small.en)",
    )
    cli.add_argument(
        "--duration", "-d",
        type=float,
        default=5.0,
        help="Seconds to record",
    )
    cli.add_argument(
        "--no-save", action="store_true",
        help="Do not save the recorded WAV",
    )
    opts = cli.parse_args()

    # One record -> transcribe cycle with the requested model.
    engine = SpeechToText(
        model_name=opts.model,
        record_duration=opts.duration,
        verbose=True,
    )
    transcript = engine.transcribe(save_temp=not opts.no_save)
    print("\n=== Transcription ===")
    print(transcript)
 
1
+ # stt.py
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ stt.py — a module for Speech-to-Text via pywhispercpp
5
+
6
+ Classes
7
+ -------
8
+ SpeechToText
9
+ Encapsulates model loading, recording, saving, and transcription.
10
+
11
+ Usage (as a script)
12
+ -------------------
13
+ python -m stt --model tiny.en --duration 5
14
+
15
+ or in code
16
+ -----------
17
+ from stt import SpeechToText
18
+ stt = SpeechToText()
19
+ text = stt.transcribe()
20
+ """
21
+
22
+ import os
23
+ import tempfile
24
+ import time
25
+ import datetime
26
+ import numpy as np
27
+ import sounddevice as sd
28
+ from scipy.io.wavfile import write as write_wav
29
+ import webrtcvad
30
+ from pywhispercpp.model import Model as Whisper
31
+ from qdrant_client import QdrantClient
32
+ from qdrant_client.http.models import Distance, VectorParams
33
+
34
+
35
class SpeechToText:
    """
    A Speech-to-Text helper built on pywhispercpp's Whisper.

    Parameters
    ----------
    model_name : str
        Whisper model to load (e.g. "tiny.en", "base", "small.en", etc.).
    sample_rate : int
        Audio sample rate in Hz (Whisper expects 16 kHz).
    record_duration : float
        Default seconds to record when calling `.record_audio()`.
    temp_dir : str
        Directory for temporary WAV files (defaults to the system temp dir).
    verbose : bool
        Print progress messages if True.
    """

    def __init__(
        self,
        model_name: str = "tiny.en",
        sample_rate: int = 16_000,
        record_duration: float = 5.0,
        temp_dir: str = None,
        verbose: bool = True,
    ):
        self.model_name = model_name
        self.sample_rate = sample_rate
        self.record_duration = record_duration
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.verbose = verbose

        # Load the Whisper model once up front; this is the slow step.
        if self.verbose:
            print(f"[STT] Loading Whisper model '{self.model_name}'...")
        t0 = time.time()
        self._model = Whisper(model=self.model_name)
        if self.verbose:
            print(f"[STT] Model loaded in {time.time() - t0:.2f}s")

    def record_audio(self, duration: float = None) -> np.ndarray:
        """
        Record from the default mic for `duration` seconds.

        Returns the captured audio as a float32 mono waveform in [-1, 1].
        """
        # Use an explicit None check (not `or`) so duration=0 is honored
        # rather than silently replaced by the default.
        if duration is None:
            duration = self.record_duration
        if self.verbose:
            print(f"[STT] Recording for {duration}s at {self.sample_rate}Hz...")
        frames = sd.rec(
            int(duration * self.sample_rate),
            samplerate=self.sample_rate,
            channels=1,
            dtype="int16",
        )
        sd.wait()
        if self.verbose:
            print("[STT] Recording finished.")
        # int16 -> float32 in [-1, 1]
        return (frames.astype(np.float32) / 32768.0).flatten()

    def save_wav(self, audio: np.ndarray, filename: str = None) -> str:
        """
        Save float32 waveform `audio` to an int16 WAV at `filename`.

        If filename is None, a timestamped file is created in temp_dir.
        Returns the path written.
        """
        filename = filename or os.path.join(
            self.temp_dir,
            f"stt_{datetime.datetime.now():%Y%m%d_%H%M%S}.wav",
        )
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # Convert back to int16 for a standard PCM WAV file.
        int16 = (audio * 32767).astype(np.int16)
        write_wav(filename, self.sample_rate, int16)
        if self.verbose:
            # Fixed: the original printed the literal text "(unknown)"
            # instead of interpolating the path.
            print(f"[STT] Saved WAV to {filename}")
        return filename

    def transcribe_file(self, wav_path: str, n_threads: int = 4) -> str:
        """
        Transcribe an existing WAV file at `wav_path`. Returns the text.

        Raises
        ------
        FileNotFoundError
            If `wav_path` does not exist.
        """
        if not os.path.isfile(wav_path):
            raise FileNotFoundError(f"No such file: {wav_path}")
        if self.verbose:
            print(f"[STT] Transcribing file {wav_path}…")
        t0 = time.time()

        # pywhispercpp may return a list of segments or keep text internally.
        result = self._model.transcribe(wav_path, n_threads=n_threads)

        # Best-effort cleanup of throwaway files we created under temp_dir.
        if wav_path.startswith(self.temp_dir):
            try:
                os.remove(wav_path)
            except OSError:
                pass

        if isinstance(result, list):
            text = "".join(seg.text for seg in result)
        else:
            # Fall back to the model's internally stored transcript.
            text = self._model.get_text()

        if self.verbose:
            print(f"[STT] Transcription complete ({time.time() - t0:.2f}s).")
        return text.strip()

    def transcribe(
        self,
        duration: float = None,
        save_temp: bool = False,
        n_threads: int = 4,
    ) -> str:
        """
        Record + save + transcribe in one call. Returns the transcribed text.

        NOTE(review): in the original, `save_temp` had no effect — both
        branches of its conditional called save_wav identically. The WAV is
        always written (transcribe_file needs a file) and is removed again
        by transcribe_file when it lives under temp_dir. The parameter is
        kept for interface compatibility; confirm its intended meaning.
        """
        audio = self.record_audio(duration)
        wav_path = self.save_wav(audio)
        return self.transcribe_file(wav_path, n_threads=n_threads)
156
+
157
+
158
# Optional: make module runnable as a script
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser(description="STT using pywhispercpp")
    arg_parser.add_argument(
        "--model", "-m",
        default="small.en",
        help="Whisper model name (e.g. tiny.en, base, small.en)",
    )
    arg_parser.add_argument(
        "--duration", "-d",
        type=float,
        default=5.0,
        help="Seconds to record",
    )
    arg_parser.add_argument(
        "--no-save", action="store_true",
        help="Do not save the recorded WAV",
    )
    ns = arg_parser.parse_args()

    # Instantiate the transcriber, then run a single capture/transcribe pass.
    recognizer = SpeechToText(
        model_name=ns.model,
        record_duration=ns.duration,
        verbose=True,
    )
    output = recognizer.transcribe(save_temp=not ns.no_save)
    print("\n=== Transcription ===")
    print(output)