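# Websockets / app.py
# Page header captured together with the file (not part of the program):
#   awacke1's picture - "Update app.py" - commit e33335a (verified)
#   raw / history / blame - 2.97 kB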
import asyncio
import websockets
import streamlit as st
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
import numpy as np
import torch
import soundfile as sf
import io
# Pre-trained speech-recognition checkpoint shared by the tokenizer and the
# acoustic model; both are loaded once at import time and reused per request.
_CHECKPOINT = "facebook/wav2vec2-base-960h"

tokenizer = Wav2Vec2Tokenizer.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
# Sampling rate (Hz) expected by the wav2vec2-base-960h checkpoint.
_TARGET_SAMPLE_RATE = 16000


def _decode_audio(payload):
    """Decode an encoded audio payload into mono float32 samples at 16 kHz.

    Raises whatever ``soundfile`` raises for undecodable data, and
    ``ValueError`` when the payload decodes to zero samples.
    """
    samples, samplerate = sf.read(io.BytesIO(payload), dtype="float32")

    # soundfile returns (frames, channels) for multi-channel audio; the
    # tokenizer would misread a 2-D array as a batch, so downmix to mono.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    if samples.size == 0:
        raise ValueError("audio payload contains no samples")

    if samplerate != _TARGET_SAMPLE_RATE:
        # Simple linear-interpolation resampling (no anti-alias filter); it
        # keeps the time axis correct without adding a new dependency.
        n_out = max(1, int(round(samples.shape[0] * _TARGET_SAMPLE_RATE / samplerate)))
        src_t = np.arange(samples.shape[0]) / samplerate
        dst_t = np.arange(n_out) / _TARGET_SAMPLE_RATE
        samples = np.interp(dst_t, src_t, samples).astype(np.float32)

    return samples


def _transcribe(payload):
    """Run the full decode -> tokenize -> CTC argmax -> text pipeline (blocking)."""
    samples = _decode_audio(payload)
    input_values = tokenizer(samples, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return tokenizer.decode(predicted_ids[0])


async def recognize_speech(websocket):
    """WebSocket handler: transcribe each incoming audio message and reply.

    For every message received on ``websocket`` the audio is decoded and
    transcribed, and the resulting text is sent back. If processing fails for
    any reason the error is printed and the fixed string
    ``"Error processing audio data."`` is sent instead, so one bad message
    does not end the connection.

    Args:
        websocket: The server-side connection object supplied by
            ``websockets.serve``.
    """
    loop = asyncio.get_running_loop()
    async for message in websocket:
        try:
            # Inference is CPU-bound and blocking; run it in the default
            # executor so the event loop can keep servicing other
            # connections and keepalive traffic meanwhile.
            reply = await loop.run_in_executor(None, _transcribe, message)
        except Exception as e:
            print(f"Error in recognize_speech: {e}")
            reply = "Error processing audio data."
        # Sending happens outside the try block: a failed send means the
        # connection is gone, and retrying on the same socket cannot succeed.
        await websocket.send(reply)
async def main_logic():
    """Serve ``recognize_speech`` on localhost:8000 until the task is cancelled."""
    # A future that nobody ever resolves: awaiting it parks this coroutine
    # indefinitely while the server context stays open.
    never_done = asyncio.Future()
    async with websockets.serve(recognize_speech, "localhost", 8000):
        await never_done  # run forever
# Streamlit interface
st.title("Real-Time ASR with Transformers")

# WebSocket script for the frontend.
#
# The markup is rendered with components.html rather than st.markdown:
# <script> elements injected as markdown HTML are inserted into the page but
# not executed, so the recorder/WebSocket code would never run. The component
# renders in its own iframe, where scripts do execute.
# Imported here so this block is self-contained.
import streamlit.components.v1 as components

# NOTE(review): nothing in this script ever calls mediaRecorder.stop(), so the
# "stop" handler that sends the audio is never reached - a stop trigger
# (button or timer) is still needed for audio to be sent.
# NOTE(review): the recorder produces audio/webm; confirm the server-side
# decoder (soundfile) can read that container, otherwise record/convert to WAV.
# NOTE(review): confirm the component iframe is granted microphone access in
# the deployed Streamlit version.
components.html(
    """
<script>
const handleAudio = async (stream) => {
const websocket = new WebSocket('ws://localhost:8000');
const mediaRecorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
const audioChunks = [];
mediaRecorder.addEventListener("dataavailable", event => {
audioChunks.push(event.data);
});
mediaRecorder.addEventListener("stop", () => {
const audioBlob = new Blob(audioChunks);
if (websocket.readyState === WebSocket.OPEN) {
websocket.send(audioBlob);
} else {
console.error('WebSocket is not open; audio was not sent.');
}
});
websocket.onmessage = (event) => {
const transcription = event.data;
const transcriptionDiv = document.getElementById("transcription");
const entry = document.createElement("div");
// textContent keeps server text literal; innerHTML would parse tokens
// such as <unk> as markup and drop them.
entry.textContent = transcription;
transcriptionDiv.appendChild(entry);
};
websocket.onopen = () => {
console.log('WebSocket connection established.');
};
websocket.onerror = (error) => {
console.error('WebSocket error:', error);
};
websocket.onclose = () => {
console.log('WebSocket connection closed.');
};
mediaRecorder.start(1000);
};
navigator.mediaDevices.getUserMedia({ audio: true })
.then(handleAudio)
.catch(error => console.error('Error accessing media devices.', error));
</script>
<div id="transcription">Your transcriptions will appear here:</div>
""",
    height=300,
    scrolling=True,
)
# To run the WebSocket server
if __name__ == "__main__":
    # asyncio.run() creates, runs and closes its own event loop. The previous
    # asyncio.get_event_loop().run_until_complete(...) form is deprecated when
    # no loop exists yet and raises RuntimeError in threads without one.
    # NOTE(review): when launched via `streamlit run`, this presumably executes
    # in a script thread and blocks it for as long as the server runs - confirm
    # whether the server should live in a separate process or thread.
    asyncio.run(main_logic())