phoneme-model / app.py
vladakoz's picture
Add application file
7192b75
import asyncio
import io
import logging

import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# Hugging Face checkpoint shared by the processor and the CTC model.
_CHECKPOINT = "facebook/wav2vec2-lv-60-espeak-cv-ft"

# Both objects are created once, at import time, as module-level globals.
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
processor = Wav2Vec2Processor.from_pretrained(_CHECKPOINT)

app = FastAPI()
# Sample rate the audio is resampled to and that is declared to the processor.
_TARGET_SAMPLE_RATE = 16000

_logger = logging.getLogger(__name__)


def _transcribe_bytes(audio_bytes: bytes) -> str:
    """Decode an in-memory audio file and return its transcription.

    This is synchronous, CPU-bound work (decoding, resampling, model forward
    pass) and is meant to be run off the event loop.

    Args:
        audio_bytes: Raw contents of an audio file in any format that
            ``torchaudio.load`` can decode.

    Returns:
        The first (and only) decoded string produced by the processor.

    Raises:
        ValueError: If the decoded audio contains no samples.
        Exception: Any error raised by torchaudio, the processor or the model
            is propagated to the caller.
    """
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    # Down-mix multi-channel audio to mono by averaging across channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample only when the file is not already at the target rate.
    if sample_rate != _TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=_TARGET_SAMPLE_RATE
        )
        waveform = resampler(waveform)

    # Drop the leading (now single) channel dimension to get a 1-D signal.
    waveform = waveform.squeeze(0)
    if waveform.numel() == 0:
        raise ValueError("Audio file contains no samples")

    input_values = processor(
        waveform, sampling_rate=_TARGET_SAMPLE_RATE, return_tensors="pt"
    ).input_values

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy decoding: pick the highest-scoring token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Args:
        file: The uploaded audio file.

    Returns:
        ``{"transcription": <str>}`` on success, or ``{"error": <str>}`` if
        the upload is empty or any step of the pipeline fails. Failures are
        reported in the body rather than raised, so existing clients keep
        receiving the same response shape.
    """
    try:
        audio_bytes = await file.read()
        if not audio_bytes:
            return {"error": "Uploaded file is empty"}

        # Decoding and inference are CPU-bound; run them in the default
        # executor so the event loop stays free to serve other requests.
        loop = asyncio.get_running_loop()
        transcription = await loop.run_in_executor(
            None, _transcribe_bytes, audio_bytes
        )
        return {"transcription": transcription}
    except Exception as e:
        # Top-level request boundary: record the traceback instead of
        # silently swallowing it, then report the message to the client.
        _logger.exception("Transcription failed")
        return {"error": str(e)}