Update app.py
app.py
CHANGED
@@ -1,66 +1,207 @@
-import
-import
-import
-import os
-from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
-model = load_silero_vad()

 app = FastAPI()

-async def transcribe_audio(file: UploadFile = File(...)):
-    """
-    """
-    try:
-        audio_bytes = await file.read()
-        audio_array, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=processor.feature_extractor.sampling_rate)
-        # speech_timestamps = get_speech_timestamps(
-        #     torch.from_numpy(audio_array),
-        #     model,
-        #     return_seconds=True,  # Return speech timestamps in seconds (default is samples)
-        # )

 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=

The new app.py:

import time
import asyncio
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse

# Import your model and VAD libraries.
from silero_vad import VADIterator, load_silero_vad
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer

# Constants
SAMPLING_RATE = 16000
CHUNK_SIZE = 512  # Window size required by Silero VAD at 16 kHz.
LOOKBACK_CHUNKS = 5
MAX_SPEECH_SECS = 15  # Maximum duration for a single transcription segment.
MIN_REFRESH_SECS = 0.2  # Minimum interval between partial updates.
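# With these values, the idle lookback buffer below holds
# LOOKBACK_CHUNKS * CHUNK_SIZE / SAMPLING_RATE = 5 * 512 / 16000 = 0.16 s
# of audio, so the first words of an utterance are not clipped.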
app = FastAPI()

class Transcriber:
    def __init__(self, model_name: str, rate: int = 16000):
        if rate != 16000:
            raise ValueError("Moonshine supports a sampling rate of 16000 Hz only.")
        self.model = MoonshineOnnxModel(model_name=model_name)
        self.rate = rate
        self.tokenizer = load_tokenizer()
        # Statistics (optional)
        self.inference_secs = 0
        self.number_inferences = 0
        self.speech_secs = 0
        # Warmup run.
        self.__call__(np.zeros(int(rate), dtype=np.float32))

    def __call__(self, speech: np.ndarray) -> str:
        """Returns a transcription of the given speech (a float32 numpy array)."""
        self.number_inferences += 1
        self.speech_secs += len(speech) / self.rate
        start_time = time.time()
        tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
        text = self.tokenizer.decode_batch(tokens)[0]
        self.inference_secs += time.time() - start_time
        return text
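# Standalone usage (sketch; assumes the moonshine-onnx weights for
# "moonshine/tiny" can be downloaded):
#   transcriber = Transcriber(model_name="moonshine/tiny")
#   text = transcriber(np.zeros(SAMPLING_RATE, dtype=np.float32))  # one second of silence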
def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
    """
    Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1].
    """
    int_data = np.frombuffer(pcm_data, dtype=np.int16)
    float_data = int_data.astype(np.float32) / 32768.0
    return float_data
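# Sanity check (sketch): two little-endian int16 samples, 0x7FFF and 0x8000,
# map to roughly +1.0 and exactly -1.0:
#   pcm16_to_float32(b"\xff\x7f\x00\x80")  # -> array([ 0.99996948, -1.0], dtype=float32)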
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    # Initialize models. (A per-connection load keeps the demo simple; a
    # shared instance would avoid reloading for every client.)
    model_name = "moonshine/tiny"
    transcriber = Transcriber(model_name=model_name, rate=SAMPLING_RATE)
    vad_model = load_silero_vad(onnx=True)
    vad_iterator = VADIterator(
        model=vad_model,
        sampling_rate=SAMPLING_RATE,
        threshold=0.5,
        min_silence_duration_ms=300,
    )

    caption_cache = []
    lookback_size = LOOKBACK_CHUNKS * CHUNK_SIZE
    speech = np.empty(0, dtype=np.float32)
    recording = False
    last_partial_time = time.time()

    try:
        while True:
            # Wait for the next audio chunk (sent as binary data).
            data = await websocket.receive_bytes()
            # Convert the 16-bit PCM data to float32.
            chunk = pcm16_to_float32(data)
            speech = np.concatenate((speech, chunk))
            if not recording:
                # Retain only the last few chunks when not recording.
                speech = speech[-lookback_size:]

            # Run VAD on the current chunk. Recent silero-vad releases expect
            # exactly CHUNK_SIZE (512) samples per call at 16 kHz, while the
            # browser sends 1024-sample buffers, so process 512-sample slices.
            vad_result = None
            for i in range(0, len(chunk) - CHUNK_SIZE + 1, CHUNK_SIZE):
                result = vad_iterator(chunk[i:i + CHUNK_SIZE])
                if result:
                    vad_result = result
            current_time = time.time()
            if vad_result:
                # If VAD signals the start of speech and we're not already recording.
                if "start" in vad_result and not recording:
                    recording = True
                # If VAD signals the end of speech.
                if "end" in vad_result and recording:
                    recording = False
                    text = transcriber(speech)
                    await websocket.send_json({"type": "final", "transcript": text})
                    caption_cache.append(text)
                    speech = np.empty(0, dtype=np.float32)
                    # Reset VAD state for the next utterance.
                    vad_iterator.reset_states()
            elif recording:
                if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
                    # If speech goes on too long, force an end of segment.
                    recording = False
                    text = transcriber(speech)
                    await websocket.send_json({"type": "final", "transcript": text})
                    caption_cache.append(text)
                    speech = np.empty(0, dtype=np.float32)
                    vad_iterator.reset_states()
                elif (current_time - last_partial_time) > MIN_REFRESH_SECS:
                    # Otherwise, send partial transcription updates periodically.
                    text = transcriber(speech)
                    await websocket.send_json({"type": "partial", "transcript": text})
                    last_partial_time = current_time
    except WebSocketDisconnect:
        # The socket is gone, so nothing more can be sent; transcribe any
        # leftover speech for the logs and finish up.
        if recording and speech.size:
            text = transcriber(speech)
            caption_cache.append(text)
            print("Final transcript on disconnect:", text)
        print("WebSocket disconnected")
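# Wire format implemented above: the client streams raw little-endian 16-bit
# PCM at 16 kHz as binary WebSocket messages, and the server answers with
# JSON text frames of the form
#   {"type": "partial", "transcript": "..."}   (in-progress, may be rewritten)
#   {"type": "final", "transcript": "..."}     (committed caption line)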
@app.get("/", response_class=HTMLResponse)
async def get_home():
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Realtime Transcription</title>
    </head>
    <body>
        <h1>Realtime Transcription</h1>
        <button onclick="startTranscription()">Start Transcription</button>
        <p id="status">Click start to begin transcription.</p>
        <div id="transcription" style="border:1px solid #ccc; padding:10px; margin-top:10px; height:200px; overflow:auto;"></div>
        <script>
            let ws;
            let audioContext;
            let scriptProcessor;
            let mediaStream;
            let currentLine = document.createElement('span');
            document.getElementById('transcription').appendChild(currentLine);

            async function startTranscription() {
                document.getElementById("status").innerText = "Connecting...";
                // Match the socket scheme to the page protocol so the demo
                // also works over plain http (e.g. local development).
                const scheme = location.protocol === "https:" ? "wss://" : "ws://";
                ws = new WebSocket(scheme + location.host + "/ws/transcribe");
                ws.binaryType = 'arraybuffer';
                ws.onopen = async function() {
                    document.getElementById("status").innerText = "Connected";
                    try {
                        mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
                        audioContext = new AudioContext({ sampleRate: 16000 });
                        const source = audioContext.createMediaStreamSource(mediaStream);
                        // ScriptProcessorNode is deprecated but still widely
                        // supported; 1024-sample buffers, mono in/out.
                        scriptProcessor = audioContext.createScriptProcessor(1024, 1, 1);
                        scriptProcessor.onaudioprocess = function(event) {
                            const inputData = event.inputBuffer.getChannelData(0);
                            const pcm16 = floatTo16BitPCM(inputData);
                            if (ws.readyState === WebSocket.OPEN) {
                                ws.send(pcm16);
                            }
                        };
                        source.connect(scriptProcessor);
                        scriptProcessor.connect(audioContext.destination);
                    } catch (err) {
                        document.getElementById("status").innerText = "Error: " + err;
                    }
                };
                ws.onmessage = function(event) {
                    const data = JSON.parse(event.data);
                    if (data.type === 'partial') {
                        currentLine.style.color = 'gray';
                        currentLine.textContent = data.transcript + ' ';
                    } else if (data.type === 'final') {
                        currentLine.style.color = 'black';
                        currentLine.textContent = data.transcript;
                        currentLine = document.createElement('span');
                        document.getElementById('transcription').appendChild(document.createElement('br'));
                        document.getElementById('transcription').appendChild(currentLine);
                    }
                };
                ws.onclose = function() {
                    if (audioContext && audioContext.state !== 'closed') {
                        audioContext.close();
                    }
                    document.getElementById("status").innerText = "Closed";
                };
            }

            function floatTo16BitPCM(input) {
                const buffer = new ArrayBuffer(input.length * 2);
                const output = new DataView(buffer);
                for (let i = 0; i < input.length; i++) {
                    let s = Math.max(-1, Math.min(1, input[i]));
                    output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
                }
                return buffer;
            }
        </script>
    </body>
    </html>
    """

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
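
For quick testing without the browser page, a minimal client sketch like the following should work; it assumes the third-party websockets package (pip install websockets, not used by app.py itself) and a 16 kHz mono 16-bit WAV file, and the filename is a placeholder:

import asyncio
import wave

import websockets  # assumed dependency

async def stream_wav(path, url="ws://localhost:8000/ws/transcribe"):
    async with websockets.connect(url) as ws:
        with wave.open(path, "rb") as wav:
            # The server expects raw 16-bit mono PCM at 16 kHz.
            assert wav.getframerate() == 16000
            assert wav.getnchannels() == 1 and wav.getsampwidth() == 2

            async def receiver():
                async for message in ws:
                    print(message)  # JSON partial/final transcripts

            recv_task = asyncio.create_task(receiver())
            while frames := wav.readframes(1024):
                await ws.send(frames)
                await asyncio.sleep(1024 / 16000)  # pace roughly in real time
            await asyncio.sleep(1.0)  # give the server time for a final result
            recv_task.cancel()

asyncio.run(stream_wav("sample.wav"))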