import time

import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from silero_vad import VADIterator, load_silero_vad
from transformers import AutoProcessor, WhisperTokenizerFast, pipeline
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq


def load_asr_pipeline(model_id: str):
    """Load an ONNX Whisper checkpoint and wrap it in a transformers ASR pipeline."""
    processor = AutoProcessor.from_pretrained(model_id)
    model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, subfolder="onnx")
    tokenizer = WhisperTokenizerFast.from_pretrained(model_id, language="english")
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=tokenizer,
        feature_extractor=processor.feature_extractor,
    )


# Load models: tiny for low latency, base for higher accuracy.
pipe_tiny = load_asr_pipeline("onnx-community/whisper-tiny.en")
pipe_base = load_asr_pipeline("onnx-community/whisper-base.en")

# Constants
SAMPLING_RATE = 16000  # Whisper and Silero VAD both expect 16 kHz audio
CHUNK_SIZE = 512  # samples per VAD window at 16 kHz
LOOKBACK_CHUNKS = 5  # audio kept before speech onset so the first word is not clipped
MAX_SPEECH_SECS = 15  # force a final transcript if an utterance runs this long
MIN_REFRESH_SECS = 1  # minimum interval between partial transcripts

app = FastAPI()

vad_model = load_silero_vad(onnx=True)
vad_iterator = VADIterator(
    model=vad_model,
    sampling_rate=SAMPLING_RATE,
    threshold=0.5,
    min_silence_duration_ms=300,
)


def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
    """Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1]."""
    int_data = np.frombuffer(pcm_data, dtype=np.int16)
    return int_data.astype(np.float32) / 32768.0


def soft_reset_vad():
    """Reset the VAD iterator's trigger state without clearing the model's internal state."""
    vad_iterator.triggered = False
    vad_iterator.temp_end = 0
    vad_iterator.current_sample = 0


@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    caption_cache = []
    speech = np.empty(0, dtype=np.float32)
    recording = False
    last_partial_time = time.time()
    current_pipe = pipe_tiny

    try:
        while True:
            data = await websocket.receive()
            # The raw receive() does not raise on disconnect, so surface it here.
            if data["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(data.get("code") or 1000)

            # Text frames carry model-switch commands from the client.
            if data.get("text") == "switch_to_tiny":
                current_pipe = pipe_tiny
                continue
            if data.get("text") == "switch_to_base":
                current_pipe = pipe_base
                continue
            if data.get("bytes") is None:
                continue  # ignore any other non-binary frame

            chunk = pcm16_to_float32(data["bytes"])
            speech = np.concatenate((speech, chunk))
            if not recording:
                # Keep only a short lookback buffer until speech starts.
                speech = speech[-(LOOKBACK_CHUNKS * CHUNK_SIZE):]

            vad_result = vad_iterator(chunk)
            if vad_result:
                if "start" in vad_result and not recording:
                    recording = True
                    await websocket.send_json({"type": "status", "message": "speaking_started"})
                if "end" in vad_result and recording:
                    recording = False
                    text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                    await websocket.send_json({"type": "final", "transcript": text})
                    caption_cache.append(text)
                    speech = np.empty(0, dtype=np.float32)
                    soft_reset_vad()
                    await websocket.send_json({"type": "status", "message": "speaking_stopped"})
            elif recording:
                if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
                    # Utterance is too long: emit a final transcript and start over.
                    recording = False
                    text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                    await websocket.send_json({"type": "final", "transcript": text})
                    caption_cache.append(text)
                    speech = np.empty(0, dtype=np.float32)
                    soft_reset_vad()
                    await websocket.send_json({"type": "status", "message": "speaking_stopped"})
                elif (time.time() - last_partial_time) > MIN_REFRESH_SECS:
                    # Refresh the interim caption while the user is still speaking
                    # (assumes the client handles a "partial" message type).
                    text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                    await websocket.send_json({"type": "partial", "transcript": text})
                    last_partial_time = time.time()
    except WebSocketDisconnect:
        # The socket is closed, so any leftover speech can no longer be sent back;
        # transcribe it for the server log instead.
        if recording and speech.size:
            text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
            print(f"Unsent final transcript: {text}")
        print("WebSocket disconnected")


@app.get("/", response_class=HTMLResponse)
async def get_home():
    return """
Click start to begin transcription.