# stt-4 / app.py
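# Real-time speech-to-text demo: the browser streams 16 kHz PCM audio over a
# WebSocket, Silero VAD segments the stream into utterances, and a Whisper
# ONNX pipeline (tiny.en or base.en, switchable at runtime) transcribes them.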
import time
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from silero_vad import VADIterator, load_silero_vad
from transformers import AutoProcessor, pipeline, WhisperTokenizerFast
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
# Load models: build an ONNX Runtime Whisper pipeline for each model size.
def load_asr_pipeline(model_id: str):
    processor = AutoProcessor.from_pretrained(model_id)
    model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, subfolder="onnx")
    tokenizer = WhisperTokenizerFast.from_pretrained(model_id, language="english")
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=tokenizer,
        feature_extractor=processor.feature_extractor,
    )

pipe_tiny = load_asr_pipeline("onnx-community/whisper-tiny.en")
pipe_base = load_asr_pipeline("onnx-community/whisper-base.en")
# Constants
SAMPLING_RATE = 16000   # Silero VAD and Whisper both operate on 16 kHz audio
CHUNK_SIZE = 512        # samples per frame; Silero VAD requires 512-sample frames at 16 kHz
LOOKBACK_CHUNKS = 5     # pre-speech frames kept so utterance onsets are not clipped
MAX_SPEECH_SECS = 15    # force-finalize utterances longer than this
MIN_REFRESH_SECS = 1    # minimum interval between partial transcript updates
app = FastAPI()
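# Streaming VAD: load_silero_vad(onnx=True) runs the detector on ONNX Runtime;
# VADIterator consumes 512-sample frames and returns {"start": ...} /
# {"end": ...} dicts when it detects speech boundaries.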
vad_model = load_silero_vad(onnx=True)
vad_iterator = VADIterator(
model=vad_model,
sampling_rate=SAMPLING_RATE,
threshold=0.5,
min_silence_duration_ms=300,
)
def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
"""
Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1].
"""
int_data = np.frombuffer(pcm_data, dtype=np.int16)
float_data = int_data.astype(np.float32) / 32768.0
return float_data
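# Example: b"\x00\x80" decodes to int16 -32768 -> -1.0, and b"\xff\x7f"
# decodes to int16 32767 -> ~0.99997 (assuming a little-endian host).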
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    caption_cache = []
    speech = np.empty(0, dtype=np.float32)
    recording = False
    last_partial_time = time.time()
    current_pipe = pipe_tiny

    async def finalize():
        # Transcribe the buffered utterance, send it as a final caption, and
        # reset the buffer and the VAD iterator's internal state.
        nonlocal recording, speech
        recording = False
        text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
        await websocket.send_json({"type": "final", "transcript": text})
        caption_cache.append(text)
        speech = np.empty(0, dtype=np.float32)
        vad_iterator.triggered = False
        vad_iterator.temp_end = 0
        vad_iterator.current_sample = 0
        await websocket.send_json({"type": "status", "message": "speaking_stopped"})

    try:
        while True:
            data = await websocket.receive()
            if data["type"] == "websocket.disconnect":
                # websocket.receive() reports disconnects as messages rather than
                # raising, so convert them for the except clause below.
                raise WebSocketDisconnect(data.get("code", 1000))
            if data["type"] == "websocket.receive":
                # Text frames are control messages that switch the active model.
                if data.get("text") == "switch_to_tiny":
                    current_pipe = pipe_tiny
                    continue
                elif data.get("text") == "switch_to_base":
                    current_pipe = pipe_base
                    continue
                if data.get("bytes") is None:
                    continue
                chunk = pcm16_to_float32(data["bytes"])
                speech = np.concatenate((speech, chunk))
                if not recording:
                    # Keep only a short lookback window until speech starts,
                    # so the utterance onset is not clipped.
                    speech = speech[-(LOOKBACK_CHUNKS * CHUNK_SIZE):]
                vad_result = vad_iterator(chunk)
                if vad_result:
                    if "start" in vad_result and not recording:
                        recording = True
                        last_partial_time = time.time()
                        await websocket.send_json({"type": "status", "message": "speaking_started"})
                    if "end" in vad_result and recording:
                        await finalize()
                elif recording:
                    if len(speech) / SAMPLING_RATE > MAX_SPEECH_SECS:
                        # Cap utterance length so the buffer cannot grow unbounded.
                        await finalize()
                    elif time.time() - last_partial_time > MIN_REFRESH_SECS:
                        # Throttled partial transcript while speech is ongoing; the
                        # client renders it in gray until the final caption arrives.
                        last_partial_time = time.time()
                        text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                        await websocket.send_json({"type": "partial", "transcript": text})
    except WebSocketDisconnect:
        # The client is gone, so a last transcript can only be logged, not sent.
        if recording and speech.size:
            text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
            print("Final transcript after disconnect:", text)
        print("WebSocket disconnected")
@app.get("/", response_class=HTMLResponse)
async def get_home():
return """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
        <title>Realtime Transcription</title>
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css" rel="stylesheet">
</head>
<body class="bg-gray-100 p-6">
<div class="max-w-3xl mx-auto bg-white p-6 rounded-lg shadow-md">
<h1 class="text-2xl font-bold mb-4">Realtime Transcription</h1>
<button onclick="startTranscription()" class="bg-blue-500 text-white px-4 py-2 rounded mb-4">Start Transcription</button>
<select id="modelSelect" onchange="switchModel()" class="bg-gray-200 px-4 py-2 rounded mb-4">
<option value="tiny">Tiny Model</option>
<option value="base">Base Model</option>
</select>
<p id="status" class="text-gray-600 mb-4">Click start to begin transcription.</p>
<p id="speakingStatus" class="text-gray-600 mb-4"></p>
<div id="transcription" class="border p-4 rounded mb-4 h-64 overflow-auto"></div>
<div id="visualizer" class="border p-4 rounded h-64">
<canvas id="audioCanvas" class="w-full h-full"></canvas>
</div>
</div>
<script>
let ws;
let audioContext;
let scriptProcessor;
let mediaStream;
let currentLine = document.createElement('span');
let analyser;
let canvas, canvasContext;
document.getElementById('transcription').appendChild(currentLine);
canvas = document.getElementById('audioCanvas');
canvasContext = canvas.getContext('2d');
async function startTranscription() {
document.getElementById("status").innerText = "Connecting...";
            ws = new WebSocket((location.protocol === "https:" ? "wss://" : "ws://") + location.host + "/ws/transcribe");
ws.binaryType = 'arraybuffer';
ws.onopen = async function() {
document.getElementById("status").innerText = "Connected";
try {
mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
                    audioContext = new AudioContext({ sampleRate: 16000 }); // browsers may ignore the requested rate; the server assumes 16 kHz input
const source = audioContext.createMediaStreamSource(mediaStream);
analyser = audioContext.createAnalyser();
analyser.fftSize = 2048;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
source.connect(analyser);
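                    // Buffer size 512 matches the server's CHUNK_SIZE, so each audio
                    // event yields exactly one VAD frame. (ScriptProcessorNode is
                    // deprecated; an AudioWorklet would be the modern replacement.)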
scriptProcessor = audioContext.createScriptProcessor(512, 1, 1);
scriptProcessor.onaudioprocess = function(event) {
const inputData = event.inputBuffer.getChannelData(0);
const pcm16 = floatTo16BitPCM(inputData);
if (ws.readyState === WebSocket.OPEN) {
ws.send(pcm16);
}
analyser.getByteTimeDomainData(dataArray);
canvasContext.fillStyle = 'rgb(200, 200, 200)';
canvasContext.fillRect(0, 0, canvas.width, canvas.height);
canvasContext.lineWidth = 2;
canvasContext.strokeStyle = 'rgb(0, 0, 0)';
canvasContext.beginPath();
let sliceWidth = canvas.width * 1.0 / bufferLength;
let x = 0;
for (let i = 0; i < bufferLength; i++) {
let v = dataArray[i] / 128.0;
let y = v * canvas.height / 2;
if (i === 0) {
canvasContext.moveTo(x, y);
} else {
canvasContext.lineTo(x, y);
}
x += sliceWidth;
}
canvasContext.lineTo(canvas.width, canvas.height / 2);
canvasContext.stroke();
};
source.connect(scriptProcessor);
scriptProcessor.connect(audioContext.destination);
} catch (err) {
document.getElementById("status").innerText = "Error: " + err;
}
};
ws.onmessage = function(event) {
const data = JSON.parse(event.data);
if (data.type === 'partial') {
currentLine.style.color = 'gray';
currentLine.textContent = data.transcript + ' ';
} else if (data.type === 'final') {
currentLine.style.color = 'black';
currentLine.textContent = data.transcript;
currentLine = document.createElement('span');
document.getElementById('transcription').appendChild(document.createElement('br'));
document.getElementById('transcription').appendChild(currentLine);
} else if (data.type === 'status') {
if (data.message === 'speaking_started') {
document.getElementById("speakingStatus").innerText = "Speaking Started";
document.getElementById("speakingStatus").style.color = "green";
} else if (data.message === 'speaking_stopped') {
document.getElementById("speakingStatus").innerText = "Speaking Stopped";
document.getElementById("speakingStatus").style.color = "red";
}
}
};
            ws.onclose = function() {
                // Release the microphone and audio graph when the socket closes.
                if (mediaStream) {
                    mediaStream.getTracks().forEach(track => track.stop());
                }
                if (audioContext && audioContext.state !== 'closed') {
                    audioContext.close();
                }
                document.getElementById("status").innerText = "Closed";
            };
}
function switchModel() {
const model = document.getElementById("modelSelect").value;
if (ws && ws.readyState === WebSocket.OPEN) {
if (model === "tiny") {
ws.send("switch_to_tiny");
} else if (model === "base") {
ws.send("switch_to_base");
}
}
}
function floatTo16BitPCM(input) {
const buffer = new ArrayBuffer(input.length * 2);
const output = new DataView(buffer);
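            // Clamp each sample to [-1, 1], then scale asymmetrically (0x8000 for
            // negatives, 0x7FFF for positives) so both rails map to valid int16
            // values; the final 'true' writes little-endian, matching the server's
            // np.int16 decoding on little-endian hosts.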
for (let i = 0; i < input.length; i++) {
let s = Math.max(-1, Math.min(1, input[i]));
output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
}
return buffer;
}
</script>
</body>
</html>
"""
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)