bcci committed · verified
Commit a4181e3 · 1 Parent(s): c9365be

Update app.py

Files changed (1)
  1. app.py +191 -50
app.py CHANGED
@@ -1,66 +1,207 @@
- from fastapi import FastAPI, UploadFile, File, HTTPException
- from transformers import MoonshineForConditionalGeneration, AutoProcessor
- import torch
- import librosa
- import io
- import os
- from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
- model = load_silero_vad()

  app = FastAPI()

- # Check for GPU availability
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

- # Load the model and processor
- try:
-     model = MoonshineForConditionalGeneration.from_pretrained('UsefulSensors/moonshine-tiny').to(device).to(torch_dtype)
-     processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-tiny')
- except Exception as e:
-     print(f"Error loading model or processor: {e}")
-     exit()

- @app.post("/transcribe/")
- async def transcribe_audio(file: UploadFile = File(...)):
      """
-     Transcribes an uploaded audio file.
      """
-     if not file.filename.lower().endswith(('.mp3', '.wav', '.ogg', '.flac', '.m4a')): # add more formats as needed
-         raise HTTPException(status_code=400, detail="Invalid file format. Supported formats: .mp3, .wav, .ogg, .flac, .m4a")
-
-     try:
-         audio_bytes = await file.read()
-         audio_array, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=processor.feature_extractor.sampling_rate)
-
-         # speech_timestamps = get_speech_timestamps(
-         #     torch.from_numpy(audio_array),
-         #     model,
-         #     return_seconds=True, # Return speech timestamps in seconds (default is samples)
-         # )

-         # print(speech_timestamps)

-         inputs = processor(
-             audio_array,
-             return_tensors="pt",
-             sampling_rate=processor.feature_extractor.sampling_rate
-         )
-         inputs = inputs.to(device, torch_dtype)

-         token_limit_factor = 6.5 / processor.feature_extractor.sampling_rate
-         seq_lens = inputs.attention_mask.sum(dim=-1)
-         max_length = int((seq_lens * token_limit_factor).max().item())

-         generated_ids = model.generate(**inputs, max_length=max_length)
-         transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
-
-         return {"transcription": transcription}

-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error processing audio: {e}")

  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)
 
+ import time
+ import asyncio
+ import numpy as np
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+ from fastapi.responses import HTMLResponse
+
+ # Import your model and VAD libraries.
+ from silero_vad import VADIterator, load_silero_vad
+ from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
+
+ # Constants
+ SAMPLING_RATE = 16000
+ CHUNK_SIZE = 512  # Required for Silero VAD at 16kHz.
+ LOOKBACK_CHUNKS = 5
+ MAX_SPEECH_SECS = 15  # Maximum duration for a single transcription segment.
+ MIN_REFRESH_SECS = 0.2  # Minimum interval for sending partial updates.
+
  app = FastAPI()

+ class Transcriber:
+     def __init__(self, model_name: str, rate: int = 16000):
+         if rate != 16000:
+             raise ValueError("Moonshine supports sampling rate 16000 Hz.")
+         self.model = MoonshineOnnxModel(model_name=model_name)
+         self.rate = rate
+         self.tokenizer = load_tokenizer()
+         # Statistics (optional)
+         self.inference_secs = 0
+         self.number_inferences = 0
+         self.speech_secs = 0
+         # Warmup run.
+         self.__call__(np.zeros(int(rate), dtype=np.float32))
+
+     def __call__(self, speech: np.ndarray) -> str:
+         """Returns a transcription of the given speech (a float32 numpy array)."""
+         self.number_inferences += 1
+         self.speech_secs += len(speech) / self.rate
+         start_time = time.time()
+         tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
+         text = self.tokenizer.decode_batch(tokens)[0]
+         self.inference_secs += time.time() - start_time
+         return text
+
+ def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
      """
+     Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1].
      """
+     int_data = np.frombuffer(pcm_data, dtype=np.int16)
+     float_data = int_data.astype(np.float32) / 32768.0
+     return float_data
+
+ @app.websocket("/ws/transcribe")
+ async def websocket_endpoint(websocket: WebSocket):
+     await websocket.accept()
+
+     # Initialize models.
+     model_name = "moonshine/tiny"
+     transcriber = Transcriber(model_name=model_name, rate=SAMPLING_RATE)
+     vad_model = load_silero_vad(onnx=True)
+     vad_iterator = VADIterator(
+         model=vad_model,
+         sampling_rate=SAMPLING_RATE,
+         threshold=0.5,
+         min_silence_duration_ms=300,
+     )
+
+     caption_cache = []
+     lookback_size = LOOKBACK_CHUNKS * CHUNK_SIZE
+     speech = np.empty(0, dtype=np.float32)
+     recording = False
+     last_partial_time = time.time()
+
+     try:
+         while True:
+             # Wait for the next audio chunk (sent as binary data).
+             data = await websocket.receive_bytes()
+             # Convert the 16-bit PCM data to float32.
+             chunk = pcm16_to_float32(data)
+             speech = np.concatenate((speech, chunk))
+             if not recording:
+                 # Retain only the last few chunks when not recording.
+                 speech = speech[-lookback_size:]
+
+             # Process VAD on the current chunk.
+             vad_result = vad_iterator(chunk)
+             current_time = time.time()
+             if vad_result:
+                 # If VAD signals the start of speech and we're not already recording.
+                 if "start" in vad_result and not recording:
+                     recording = True
+                     start_time = current_time
+                 # If VAD signals the end of speech.
+                 if "end" in vad_result and recording:
+                     recording = False
+                     text = transcriber(speech)
+                     await websocket.send_json({"type": "final", "transcript": text})
+                     caption_cache.append(text)
+                     speech = np.empty(0, dtype=np.float32)
+                     # Reset VAD state.
+                     vad_iterator.triggered = False
+                     vad_iterator.temp_end = 0
+                     vad_iterator.current_sample = 0
+             elif recording:
+                 # If speech goes on too long, force an end.
+                 if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
+                     recording = False
+                     text = transcriber(speech)
+                     await websocket.send_json({"type": "final", "transcript": text})
+                     caption_cache.append(text)
+                     speech = np.empty(0, dtype=np.float32)
+                     vad_iterator.triggered = False
+                     vad_iterator.temp_end = 0
+                     vad_iterator.current_sample = 0
+                 # Send partial transcription updates periodically.
+                 if (current_time - last_partial_time) > MIN_REFRESH_SECS:
+                     text = transcriber(speech)
+                     await websocket.send_json({"type": "partial", "transcript": text})
+                     last_partial_time = current_time
+     except WebSocketDisconnect:
+         # If the client disconnects, send any final transcript if available.
+         if recording and speech.size:
+             text = transcriber(speech)
+             await websocket.send_json({"type": "final", "transcript": text})
+         print("WebSocket disconnected")
+
+ @app.get("/", response_class=HTMLResponse)
+ async def get_home():
+     return """
+     <!DOCTYPE html>
+     <html>
+     <head>
+         <meta charset="UTF-8">
+         <title>AssemblyAI Realtime Transcription</title>
+     </head>
+     <body>
+         <h1>Realtime Transcription</h1>
+         <button onclick="startTranscription()">Start Transcription</button>
+         <p id="status">Click start to begin transcription.</p>
+         <div id="transcription" style="border:1px solid #ccc; padding:10px; margin-top:10px; height:200px; overflow:auto;"></div>
+         <script>
+             let ws;
+             let audioContext;
+             let scriptProcessor;
+             let mediaStream;
+             let currentLine = document.createElement('span');
+             document.getElementById('transcription').appendChild(currentLine);
+             async function startTranscription() {
+                 document.getElementById("status").innerText = "Connecting...";
+                 ws = new WebSocket("wss://" + location.host + "/ws/transcribe");
+                 ws.binaryType = 'arraybuffer';
+                 ws.onopen = async function() {
+                     document.getElementById("status").innerText = "Connected";
+                     try {
+                         mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
+                         audioContext = new AudioContext({ sampleRate: 16000 });
+                         const source = audioContext.createMediaStreamSource(mediaStream);
+                         scriptProcessor = audioContext.createScriptProcessor(1024, 1, 1);
+                         scriptProcessor.onaudioprocess = function(event) {
+                             const inputData = event.inputBuffer.getChannelData(0);
+                             const pcm16 = floatTo16BitPCM(inputData);
+                             if (ws.readyState === WebSocket.OPEN) {
+                                 ws.send(pcm16);
+                             }
+                         };
+                         source.connect(scriptProcessor);
+                         scriptProcessor.connect(audioContext.destination);
+                     } catch (err) {
+                         document.getElementById("status").innerText = "Error: " + err;
+                     }
+                 };
+                 ws.onmessage = function(event) {
+                     const data = JSON.parse(event.data);
+                     if (data.type === 'partial') {
+                         currentLine.style.color = 'gray';
+                         currentLine.textContent = data.transcript + ' ';
+                     } else if (data.type === 'final') {
+                         currentLine.style.color = 'black';
+                         currentLine.textContent = data.transcript;
+                         currentLine = document.createElement('span');
+                         document.getElementById('transcription').appendChild(document.createElement('br'));
+                         document.getElementById('transcription').appendChild(currentLine);
+                     }
+                 };
+                 ws.onclose = function() {
+                     if (audioContext && audioContext.state !== 'closed') {
+                         audioContext.close();
+                     }
+                     document.getElementById("status").innerText = "Closed";
+                 };
+             }
+             function floatTo16BitPCM(input) {
+                 const buffer = new ArrayBuffer(input.length * 2);
+                 const output = new DataView(buffer);
+                 for (let i = 0; i < input.length; i++) {
+                     let s = Math.max(-1, Math.min(1, input[i]));
+                     output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
+                 }
+                 return buffer;
+             }
+         </script>
+     </body>
+     </html>
+     """

  if __name__ == "__main__":
      import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
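
The new /ws/transcribe endpoint consumes raw 16-bit little-endian PCM at 16 kHz as binary WebSocket messages and replies with JSON objects of the form {"type": "partial" | "final", "transcript": ...}. Below is a minimal local smoke-test sketch, not part of the commit: it assumes the third-party `websockets` package is installed and the server is running on port 8000 as in the `__main__` block, and it sends 512-sample chunks to match the server's CHUNK_SIZE constant. It feeds silence, so the VAD will not fire; with real speech samples you would see partial and final transcripts printed.

# test_ws_client.py - hypothetical smoke test for the /ws/transcribe endpoint.
# Assumes `pip install websockets numpy` and the app running on localhost:8000.
import asyncio
import json

import numpy as np
import websockets


async def main():
    # One second of silence as 16-bit PCM at 16 kHz; real speech audio would
    # trigger the server-side VAD and produce partial/final transcripts.
    pcm = np.zeros(16000, dtype=np.int16)
    async with websockets.connect("ws://localhost:8000/ws/transcribe") as ws:
        # Send 512-sample chunks, mirroring the server's CHUNK_SIZE constant.
        for start in range(0, len(pcm), 512):
            await ws.send(pcm[start:start + 512].tobytes())
        # Print any messages the server sends back, then give up after a pause.
        try:
            while True:
                msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=2.0))
                print(msg["type"], msg["transcript"])
        except asyncio.TimeoutError:
            pass


if __name__ == "__main__":
    asyncio.run(main())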