bcci committed
Commit de2f549 · verified · 1 Parent(s): 7dd345d

Update app.py

Files changed (1)
  1. app.py +32 -197
app.py CHANGED
@@ -4,63 +4,29 @@ import numpy as np
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
 
-# Import your model and VAD libraries.
 from silero_vad import VADIterator, load_silero_vad
-
 from transformers import AutoProcessor, pipeline, WhisperTokenizerFast
 from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
 
-processor = AutoProcessor.from_pretrained("onnx-community/whisper-tiny.en")
-model = ORTModelForSpeechSeq2Seq.from_pretrained("onnx-community/whisper-tiny.en", subfolder="onnx")
-tokenizer = WhisperTokenizerFast.from_pretrained("onnx-community/whisper-tiny.en", language="english")
-pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=processor.feature_extractor)
+# Load models
+processor_tiny = AutoProcessor.from_pretrained("onnx-community/whisper-tiny.en")
+model_tiny = ORTModelForSpeechSeq2Seq.from_pretrained("onnx-community/whisper-tiny.en", subfolder="onnx")
+tokenizer_tiny = WhisperTokenizerFast.from_pretrained("onnx-community/whisper-tiny.en", language="english")
+pipe_tiny = pipeline("automatic-speech-recognition", model=model_tiny, tokenizer=tokenizer_tiny, feature_extractor=processor_tiny.feature_extractor)
+
+processor_base = AutoProcessor.from_pretrained("onnx-community/whisper-base.en")
+model_base = ORTModelForSpeechSeq2Seq.from_pretrained("onnx-community/whisper-base.en", subfolder="onnx")
+tokenizer_base = WhisperTokenizerFast.from_pretrained("onnx-community/whisper-base.en", language="english")
+pipe_base = pipeline("automatic-speech-recognition", model=model_base, tokenizer=tokenizer_base, feature_extractor=processor_base.feature_extractor)
 
 # Constants
 SAMPLING_RATE = 16000
-CHUNK_SIZE = 512 # Required for Silero VAD at 16kHz.
+CHUNK_SIZE = 512
 LOOKBACK_CHUNKS = 5
-MAX_SPEECH_SECS = 15 # Maximum duration for a single transcription segment.
+MAX_SPEECH_SECS = 15
 MIN_REFRESH_SECS = 1
 
 app = FastAPI()
-
-# class Transcriber:
-#     def __init__(self, model_name: str, rate: int = 16000):
-#         if rate != 16000:
-#             raise ValueError("Moonshine supports sampling rate 16000 Hz.")
-#         self.model = MoonshineOnnxModel(model_name=model_name)
-#         self.rate = rate
-#         self.tokenizer = load_tokenizer()
-#         # Statistics (optional)
-#         self.inference_secs = 0
-#         self.number_inferences = 0
-#         self.speech_secs = 0
-#         # Warmup run.
-#         self.__call__(np.zeros(int(rate), dtype=np.float32))
-
-#     def __call__(self, speech: np.ndarray) -> str:
-#         """Returns a transcription of the given speech (a float32 numpy array)."""
-#         self.number_inferences += 1
-#         self.speech_secs += len(speech) / self.rate
-#         start_time = time.time()
-#         tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
-#         text = self.tokenizer.decode_batch(tokens)[0]
-#         self.inference_secs += time.time() - start_time
-#         return text
-
-def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
-    """
-    Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1].
-    """
-    int_data = np.frombuffer(pcm_data, dtype=np.int16)
-    float_data = int_data.astype(np.float32) / 32768.0
-    return float_data
-
-# Initialize models.
-# model_name_tiny = "moonshine/tiny"
-# model_name_base = "moonshine/base"
-# transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
-# transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)
 vad_model = load_silero_vad(onnx=True)
 vad_iterator = VADIterator(
     model=vad_model,
@@ -72,33 +38,30 @@ vad_iterator = VADIterator(
 @app.websocket("/ws/transcribe")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
-
+
     caption_cache = []
-    lookback_size = LOOKBACK_CHUNKS * CHUNK_SIZE
     speech = np.empty(0, dtype=np.float32)
     recording = False
     last_partial_time = time.time()
-    # current_model = transcriber_tiny # Default to tiny model
-    last_output = ""
-
+    current_pipe = pipe_tiny
+
     try:
         while True:
             data = await websocket.receive()
             if data["type"] == "websocket.receive":
                 if data.get("text") == "switch_to_tiny":
-                    # current_model = transcriber_tiny
+                    current_pipe = pipe_tiny
                     continue
                 elif data.get("text") == "switch_to_base":
-                    # current_model = transcriber_base
+                    current_pipe = pipe_base
                     continue
-
+
                 chunk = pcm16_to_float32(data["bytes"])
                 speech = np.concatenate((speech, chunk))
                 if not recording:
-                    speech = speech[-lookback_size:]
-
+                    speech = speech[-(LOOKBACK_CHUNKS * CHUNK_SIZE):]
+
                 vad_result = vad_iterator(chunk)
-                current_time = time.time()
 
                 if vad_result:
                     if "start" in vad_result and not recording:
@@ -107,18 +70,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
                     if "end" in vad_result and recording:
                         recording = False
-                        text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
-                        await websocket.send_json({"type": "final", "transcript": text})
-                        caption_cache.append(text)
-                        speech = np.empty(0, dtype=np.float32)
-                        vad_iterator.triggered = False
-                        vad_iterator.temp_end = 0
-                        vad_iterator.current_sample = 0
-                        await websocket.send_json({"type": "status", "message": "speaking_stopped"})
-                elif recording:
-                    if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
-                        recording = False
-                        text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
+                        text = current_pipe({"sampling_rate": 16000, "raw": speech})["text"]
                         await websocket.send_json({"type": "final", "transcript": text})
                         caption_cache.append(text)
                         speech = np.empty(0, dtype=np.float32)
@@ -126,157 +78,40 @@ async def websocket_endpoint(websocket: WebSocket):
                         vad_iterator.temp_end = 0
                         vad_iterator.current_sample = 0
                         await websocket.send_json({"type": "status", "message": "speaking_stopped"})
-
-                # if (current_time - last_partial_time) > MIN_REFRESH_SECS:
-                #     text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
-                #     if last_output != text:
-                #         last_output = text
-                #         await websocket.send_json({"type": "partial", "transcript": text})
-                #     last_partial_time = current_time
     except WebSocketDisconnect:
         if recording and speech.size:
-            text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
+            text = current_pipe({"sampling_rate": 16000, "raw": speech})["text"]
             await websocket.send_json({"type": "final", "transcript": text})
         print("WebSocket disconnected")
 
 @app.get("/", response_class=HTMLResponse)
 async def get_home():
     return """
-    <!DOCTYPE html>
     <html>
-    <head>
-        <meta charset="UTF-8">
-        <title>AssemblyAI Realtime Transcription</title>
-        <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css" rel="stylesheet">
-    </head>
-    <body class="bg-gray-100 p-6">
-        <div class="max-w-3xl mx-auto bg-white p-6 rounded-lg shadow-md">
-        <h1 class="text-2xl font-bold mb-4">Realtime Transcription</h1>
-        <button onclick="startTranscription()" class="bg-blue-500 text-white px-4 py-2 rounded mb-4">Start Transcription</button>
-        <select id="modelSelect" onchange="switchModel()" class="bg-gray-200 px-4 py-2 rounded mb-4">
+    <body>
+        <button onclick="startTranscription()">Start Transcription</button>
+        <select id="modelSelect" onchange="switchModel()">
            <option value="tiny">Tiny Model</option>
            <option value="base">Base Model</option>
        </select>
-        <p id="status" class="text-gray-600 mb-4">Click start to begin transcription.</p>
-        <p id="speakingStatus" class="text-gray-600 mb-4"></p>
-        <div id="transcription" class="border p-4 rounded mb-4 h-64 overflow-auto"></div>
-        <div id="visualizer" class="border p-4 rounded h-64">
-            <canvas id="audioCanvas" class="w-full h-full"></canvas>
-        </div>
-        </div>
+        <p id="status">Click start to begin transcription.</p>
+        <div id="transcription"></div>
        <script>
            let ws;
-            let audioContext;
-            let scriptProcessor;
-            let mediaStream;
-            let currentLine = document.createElement('span');
-            let analyser;
-            let canvas, canvasContext;
-
-            document.getElementById('transcription').appendChild(currentLine);
-            canvas = document.getElementById('audioCanvas');
-            canvasContext = canvas.getContext('2d');
-
-            async function startTranscription() {
-                document.getElementById("status").innerText = "Connecting...";
+            function startTranscription() {
                ws = new WebSocket("wss://" + location.host + "/ws/transcribe");
-                ws.binaryType = 'arraybuffer';
-
-                ws.onopen = async function() {
-                    document.getElementById("status").innerText = "Connected";
-                    try {
-                        mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
-                        audioContext = new AudioContext({ sampleRate: 16000 });
-                        const source = audioContext.createMediaStreamSource(mediaStream);
-                        analyser = audioContext.createAnalyser();
-                        analyser.fftSize = 2048;
-                        const bufferLength = analyser.frequencyBinCount;
-                        const dataArray = new Uint8Array(bufferLength);
-                        source.connect(analyser);
-                        scriptProcessor = audioContext.createScriptProcessor(512, 1, 1);
-                        scriptProcessor.onaudioprocess = function(event) {
-                            const inputData = event.inputBuffer.getChannelData(0);
-                            const pcm16 = floatTo16BitPCM(inputData);
-                            if (ws.readyState === WebSocket.OPEN) {
-                                ws.send(pcm16);
-                            }
-                            analyser.getByteTimeDomainData(dataArray);
-                            canvasContext.fillStyle = 'rgb(200, 200, 200)';
-                            canvasContext.fillRect(0, 0, canvas.width, canvas.height);
-                            canvasContext.lineWidth = 2;
-                            canvasContext.strokeStyle = 'rgb(0, 0, 0)';
-                            canvasContext.beginPath();
-                            let sliceWidth = canvas.width * 1.0 / bufferLength;
-                            let x = 0;
-                            for (let i = 0; i < bufferLength; i++) {
-                                let v = dataArray[i] / 128.0;
-                                let y = v * canvas.height / 2;
-                                if (i === 0) {
-                                    canvasContext.moveTo(x, y);
-                                } else {
-                                    canvasContext.lineTo(x, y);
-                                }
-                                x += sliceWidth;
-                            }
-                            canvasContext.lineTo(canvas.width, canvas.height / 2);
-                            canvasContext.stroke();
-                        };
-                        source.connect(scriptProcessor);
-                        scriptProcessor.connect(audioContext.destination);
-                    } catch (err) {
-                        document.getElementById("status").innerText = "Error: " + err;
-                    }
-                };
-
                ws.onmessage = function(event) {
                    const data = JSON.parse(event.data);
-                    if (data.type === 'partial') {
-                        currentLine.style.color = 'gray';
-                        currentLine.textContent = data.transcript + ' ';
-                    } else if (data.type === 'final') {
-                        currentLine.style.color = 'black';
-                        currentLine.textContent = data.transcript;
-                        currentLine = document.createElement('span');
-                        document.getElementById('transcription').appendChild(document.createElement('br'));
-                        document.getElementById('transcription').appendChild(currentLine);
-                    } else if (data.type === 'status') {
-                        if (data.message === 'speaking_started') {
-                            document.getElementById("speakingStatus").innerText = "Speaking Started";
-                            document.getElementById("speakingStatus").style.color = "green";
-                        } else if (data.message === 'speaking_stopped') {
-                            document.getElementById("speakingStatus").innerText = "Speaking Stopped";
-                            document.getElementById("speakingStatus").style.color = "red";
-                        }
+                    if (data.type === 'final') {
+                        document.getElementById("transcription").innerHTML += `<p>${data.transcript}</p>`;
                    }
                };
-
-                ws.onclose = function() {
-                    if (audioContext && audioContext.state !== 'closed') {
-                        audioContext.close();
-                    }
-                    document.getElementById("status").innerText = "Closed";
-                };
            }
-
            function switchModel() {
                const model = document.getElementById("modelSelect").value;
                if (ws && ws.readyState === WebSocket.OPEN) {
-                    if (model === "tiny") {
-                        ws.send("switch_to_tiny");
-                    } else if (model === "base") {
-                        ws.send("switch_to_base");
-                    }
-                }
-            }
-
-            function floatTo16BitPCM(input) {
-                const buffer = new ArrayBuffer(input.length * 2);
-                const output = new DataView(buffer);
-                for (let i = 0; i < input.length; i++) {
-                    let s = Math.max(-1, Math.min(1, input[i]));
-                    output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
+                    ws.send(model === "tiny" ? "switch_to_tiny" : "switch_to_base");
                }
-                return buffer;
            }
        </script>
    </body>
@@ -285,4 +120,4 @@ async def get_home():
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
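
For anyone exercising the updated endpoint, below is a minimal client sketch; it is not part of this commit. It assumes the `websockets` package is installed, a hypothetical WS_URL for the running app, and an input.wav that is 16 kHz, mono, 16-bit PCM. The wire format (raw PCM16 chunks sent as binary frames, JSON "final"/"status" messages coming back, and the "switch_to_tiny"/"switch_to_base" control strings) follows the handler shown in the diff above. It also assumes pcm16_to_float32 remains defined server-side, since the handler still calls it even though this diff removes its definition.

# Minimal test client for /ws/transcribe (a sketch, not part of the commit).
# Assumptions: `websockets` is installed, WS_URL points at the running app,
# and input.wav is 16 kHz, mono, 16-bit PCM.
import asyncio
import json
import wave

import websockets

WS_URL = "wss://<your-space-host>/ws/transcribe"  # hypothetical; replace with the real host
CHUNK_SAMPLES = 512  # matches CHUNK_SIZE used by the server's Silero VAD loop


async def stream_file(path: str) -> None:
    async with websockets.connect(WS_URL) as ws:
        # Optional: ask the server to use the base pipeline instead of tiny.
        await ws.send("switch_to_base")

        async def receiver() -> None:
            # Print whatever the server pushes back (final transcripts, status updates).
            async for message in ws:
                msg = json.loads(message)
                if msg.get("type") == "final":
                    print("final:", msg["transcript"])
                elif msg.get("type") == "status":
                    print("status:", msg["message"])

        recv_task = asyncio.create_task(receiver())

        with wave.open(path, "rb") as wf:
            assert wf.getframerate() == 16000 and wf.getnchannels() == 1 and wf.getsampwidth() == 2
            while True:
                frames = wf.readframes(CHUNK_SAMPLES)  # 2 bytes per sample; arrives as data["bytes"]
                if not frames:
                    break
                await ws.send(frames)
                await asyncio.sleep(CHUNK_SAMPLES / 16000)  # pace roughly in real time

        await asyncio.sleep(2)  # give the VAD a moment to emit a trailing "end"
        recv_task.cancel()
        try:
            await recv_task
        except asyncio.CancelledError:
            pass


if __name__ == "__main__":
    asyncio.run(stream_file("input.wav"))

Chunks of 512 samples match the chunk size the Silero VAD loop expects at 16 kHz, and the sleep paces the stream at roughly real time so the voice-activity detector sees plausible timing between chunks.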