Spaces:
Runtime error
Runtime error
import torch | |
import torchaudio | |
from fastapi import FastAPI, UploadFile, File | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
import io | |
# Checkpoint name shared by the processor and the CTC model loaded below.
_MODEL_ID = "facebook/wav2vec2-lv-60-espeak-cv-ft"

# FastAPI application object for this module.
app = FastAPI()

# Load the Wav2Vec2 processor and model once, at import time, so that
# request handlers can reuse the same objects.
processor = Wav2Vec2Processor.from_pretrained(_MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(_MODEL_ID)
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the module-level Wav2Vec2 model.

    The upload is decoded with ``torchaudio.load``, averaged down to a single
    channel when it has more than one, resampled to 16 kHz when its sample
    rate differs, and then run through ``processor`` and ``model``. The
    arg-max token ids are decoded with ``processor.batch_decode``.

    Args:
        file: The uploaded audio file.

    Returns:
        ``{"transcription": <first decoded string>}`` on success, or
        ``{"error": <exception message>}`` if any step raises.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view -- confirm it is registered with `app` somewhere.
    target_sample_rate = 16000

    try:
        # Read the whole upload into memory and decode it.
        raw_bytes = await file.read()
        waveform, source_rate = torchaudio.load(io.BytesIO(raw_bytes))

        # Mix multi-channel audio down to one channel by averaging.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Bring the audio to the target sample rate when it differs.
        if source_rate != target_sample_rate:
            resample = torchaudio.transforms.Resample(
                orig_freq=source_rate,
                new_freq=target_sample_rate,
            )
            waveform = resample(waveform)

        # Drop the leading size-1 dimension before handing the samples to
        # the processor.
        waveform = waveform.squeeze(0)

        features = processor(
            waveform,
            sampling_rate=target_sample_rate,
            return_tensors="pt",
        )

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(features.input_values).logits

        # Greedy decoding: take the highest-scoring token at each step.
        token_ids = logits.argmax(dim=-1)
        decoded = processor.batch_decode(token_ids)
        return {"transcription": decoded[0]}
    except Exception as e:
        # Failures are reported in the response body instead of propagating.
        return {"error": str(e)}