Spaces:
Sleeping
Sleeping
import asyncio
import io

import numpy as np
import soundfile as sf
import streamlit as st
import streamlit.components.v1 as components
import torch
import websockets
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
# Pre-trained wav2vec2 checkpoint shared by the tokenizer and the CTC model.
_PRETRAINED_CHECKPOINT = "facebook/wav2vec2-base-960h"

# Both objects are created once at import time and reused for every request.
tokenizer = Wav2Vec2Tokenizer.from_pretrained(_PRETRAINED_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_PRETRAINED_CHECKPOINT)
# Sample rate (Hz) the audio is converted to before it is fed to the model.
_MODEL_SAMPLE_RATE = 16_000


def _transcribe(audio_bytes):
    """Decode one encoded audio clip and return its transcription.

    The clip is decoded with soundfile, down-mixed to mono and resampled to
    ``_MODEL_SAMPLE_RATE`` before being passed through the module-level
    ``tokenizer`` and ``model``.

    Args:
        audio_bytes: Bytes of a complete audio file in a format soundfile
            can read.

    Returns:
        The decoded transcription string.

    Raises:
        ValueError: If the clip contains no samples.
        Exception: Whatever soundfile, the tokenizer or the model raise for
            undecodable or otherwise unusable input.
    """
    # float32 avoids carrying soundfile's default float64 samples into the model.
    waveform, sample_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32")

    # Multi-channel audio is returned as (frames, channels); average to mono so
    # the tokenizer does not mistake the channel axis for a batch.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)
    if waveform.size == 0:
        raise ValueError("audio clip contains no samples")

    # Linear-interpolation resample when the clip is not already at the
    # target rate; without this the audio is interpreted at the wrong speed.
    if sample_rate != _MODEL_SAMPLE_RATE:
        target_len = max(1, round(waveform.size * _MODEL_SAMPLE_RATE / sample_rate))
        source_times = np.arange(waveform.size) / sample_rate
        target_times = np.arange(target_len) / _MODEL_SAMPLE_RATE
        waveform = np.interp(target_times, source_times, waveform).astype(np.float32)

    input_values = tokenizer(waveform, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return tokenizer.decode(predicted_ids[0])


async def recognize_speech(websocket):
    """Transcribe every audio message received on ``websocket``.

    Each incoming message is treated as one complete audio file. The reply is
    either the transcription or the fixed text "Error processing audio data."
    when the message could not be processed.

    Args:
        websocket: The server-side connection passed in by ``websockets.serve``.
    """
    loop = asyncio.get_running_loop()
    async for message in websocket:
        try:
            # Inference is blocking work; run it in the default executor so
            # the event loop keeps serving other connections meanwhile.
            reply = await loop.run_in_executor(None, _transcribe, message)
        except Exception as e:
            print(f"Error in recognize_speech: {e}")
            reply = "Error processing audio data."
        # Sent outside the try block so a failed send is not answered with
        # a second send on the same broken connection.
        await websocket.send(reply)
async def main_logic(host: str = "localhost", port: int = 8000) -> None:
    """Serve ``recognize_speech`` over WebSocket until the task is cancelled.

    Args:
        host: Interface to bind the server to. Defaults to "localhost".
        port: TCP port to listen on. Defaults to 8000, the port the page's
            script connects to.
    """
    async with websockets.serve(recognize_speech, host, port):
        # This future is never resolved, so awaiting it keeps the server
        # context open indefinitely.
        await asyncio.Future()
# Streamlit interface
st.title("Real-Time ASR with Transformers")

# The recorder is rendered with components.html because st.markdown does not
# execute <script> tags. The script captures raw microphone samples, converts
# them to 16 kHz mono 16-bit WAV in the browser (a format soundfile can read)
# and sends one clip every CHUNK_SECONDS over the WebSocket.
components.html("""
<div id="transcription">Your transcriptions will appear here:</div>
<script>
const TARGET_RATE = 16000;   // sample rate of the WAV clips sent to the server
const CHUNK_SECONDS = 3;     // length of each clip sent for transcription

// Resample mono float samples to TARGET_RATE by linear interpolation.
function resample(samples, fromRate) {
    if (fromRate === TARGET_RATE) {
        return samples;
    }
    const ratio = fromRate / TARGET_RATE;
    const outLength = Math.floor(samples.length / ratio);
    const out = new Float32Array(outLength);
    for (let i = 0; i < outLength; i++) {
        const pos = i * ratio;
        const left = Math.floor(pos);
        const right = Math.min(left + 1, samples.length - 1);
        const frac = pos - left;
        out[i] = samples[left] * (1 - frac) + samples[right] * frac;
    }
    return out;
}

// Encode mono float samples in [-1, 1] as a 16-bit PCM WAV file.
function encodeWav(samples, sampleRate) {
    const buffer = new ArrayBuffer(44 + samples.length * 2);
    const view = new DataView(buffer);
    const writeText = (offset, text) => {
        for (let i = 0; i < text.length; i++) {
            view.setUint8(offset + i, text.charCodeAt(i));
        }
    };
    writeText(0, 'RIFF');
    view.setUint32(4, 36 + samples.length * 2, true);
    writeText(8, 'WAVE');
    writeText(12, 'fmt ');
    view.setUint32(16, 16, true);              // fmt chunk size
    view.setUint16(20, 1, true);               // format: PCM
    view.setUint16(22, 1, true);               // channels: mono
    view.setUint32(24, sampleRate, true);
    view.setUint32(28, sampleRate * 2, true);  // byte rate
    view.setUint16(32, 2, true);               // block align
    view.setUint16(34, 16, true);              // bits per sample
    writeText(36, 'data');
    view.setUint32(40, samples.length * 2, true);
    for (let i = 0; i < samples.length; i++) {
        const s = Math.max(-1, Math.min(1, samples[i]));
        view.setInt16(44 + i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
    }
    return buffer;
}

const handleAudio = (stream) => {
    const websocket = new WebSocket('ws://localhost:8000');
    const AudioContextClass = window.AudioContext || window.webkitAudioContext;
    const audioContext = new AudioContextClass();
    const source = audioContext.createMediaStreamSource(stream);
    const processor = audioContext.createScriptProcessor(4096, 1, 1);
    let pending = [];
    let pendingLength = 0;

    processor.onaudioprocess = (event) => {
        // Copy the samples: the underlying buffer is reused between callbacks.
        pending.push(new Float32Array(event.inputBuffer.getChannelData(0)));
        pendingLength += event.inputBuffer.length;
        if (pendingLength < audioContext.sampleRate * CHUNK_SECONDS) {
            return;
        }
        const merged = new Float32Array(pendingLength);
        let offset = 0;
        for (const part of pending) {
            merged.set(part, offset);
            offset += part.length;
        }
        pending = [];
        pendingLength = 0;
        // Drop the clip rather than throw if the socket is not open yet.
        if (websocket.readyState === WebSocket.OPEN) {
            websocket.send(encodeWav(resample(merged, audioContext.sampleRate), TARGET_RATE));
        }
    };

    websocket.onmessage = (event) => {
        // textContent keeps the server's text from being parsed as HTML.
        const line = document.createElement('div');
        line.textContent = event.data;
        document.getElementById('transcription').appendChild(line);
    };
    websocket.onopen = () => {
        console.log('WebSocket connection established.');
    };
    websocket.onerror = (error) => {
        console.error('WebSocket error:', error);
    };
    websocket.onclose = () => {
        console.log('WebSocket connection closed.');
        processor.disconnect();
        source.disconnect();
    };

    source.connect(processor);
    // The processor only receives audio while connected to a destination;
    // its output buffer is left untouched, so nothing is played back.
    processor.connect(audioContext.destination);
    audioContext.resume();
};

navigator.mediaDevices.getUserMedia({ audio: true })
    .then(handleAudio)
    .catch(error => console.error('Error accessing media devices.', error));
</script>
""", height=300, scrolling=True)
# To run the WebSocket server
if __name__ == "__main__":
    # asyncio.run creates and closes its own event loop. The previous
    # asyncio.get_event_loop() call is deprecated when no loop is running and
    # raises RuntimeError when invoked outside the main thread.
    asyncio.run(main_logic())