# stt-4 / app.py
# Author: bcci
# Commit: "Update app.py" (0206ee8, verified)
# Size: 2.42 kB
from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import MoonshineForConditionalGeneration, AutoProcessor
import torch
import librosa
import io
import os
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
# Silero VAD model, used to locate speech segments in the uploaded audio.
# It must NOT share a name with the Moonshine model below: previously both were
# bound to `model`, so the endpoint handed the ASR model to the VAD and failed.
vad_model = load_silero_vad()

app = FastAPI()

# Check for GPU availability; half precision is only used on CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

MODEL_ID = 'UsefulSensors/moonshine-tiny'
# Upper bound on generated tokens per second of input audio.
TOKENS_PER_SECOND = 6.5
# Accepted upload extensions (add more formats as needed; keep the 400
# error message in transcribe_audio in sync).
SUPPORTED_AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac', '.m4a')

# Load the model and processor. The service cannot work without them, so abort
# start-up with a non-zero exit status (plain exit() would report success).
try:
    model = MoonshineForConditionalGeneration.from_pretrained(MODEL_ID).to(device).to(torch_dtype)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    raise SystemExit(1) from e


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file with Moonshine.

    The upload is decoded and resampled to the processor's sampling rate with
    librosa. It is then run through Silero VAD, whose speech timestamps are
    currently only printed. Finally it is transcribed with a generation length
    capped at TOKENS_PER_SECOND per second of audio.

    Returns:
        ``{"transcription": <decoded text>}``.

    Raises:
        HTTPException: 400 if the filename lacks a supported extension or the
            decoded audio is empty; 500 if decoding, VAD or transcription fails.
    """
    # UploadFile.filename may be None; treat that as an unsupported name
    # instead of crashing on .lower().
    filename = (file.filename or "").lower()
    if not filename.endswith(SUPPORTED_AUDIO_EXTENSIONS):
        raise HTTPException(status_code=400, detail="Invalid file format. Supported formats: .mp3, .wav, .ogg, .flac, .m4a")

    sampling_rate = processor.feature_extractor.sampling_rate
    # NOTE(review): librosa.load, VAD and generate() are blocking calls inside an
    # async endpoint, so they stall the event loop for the whole request. Moving
    # them to a worker thread would let requests share the global models
    # concurrently -- confirm that is safe before changing this.
    try:
        audio_bytes = await file.read()
        audio_array, _ = librosa.load(io.BytesIO(audio_bytes), sr=sampling_rate)
        if audio_array.size == 0:
            raise HTTPException(status_code=400, detail="Audio file contains no samples")

        speech_timestamps = get_speech_timestamps(
            torch.from_numpy(audio_array),
            vad_model,
            sampling_rate=sampling_rate,  # tell the VAD the actual resampled rate
            return_seconds=True,  # timestamps in seconds instead of sample indices
        )
        print(speech_timestamps)

        inputs = processor(
            audio_array,
            return_tensors="pt",
            sampling_rate=sampling_rate,
        )
        inputs = inputs.to(device, torch_dtype)

        # Limit output length to TOKENS_PER_SECOND per second of audio. This
        # presumes one attention-mask entry per input sample (samples / rate =
        # seconds) -- TODO confirm against the Moonshine feature extractor.
        token_limit_factor = TOKENS_PER_SECOND / sampling_rate
        seq_lens = inputs.attention_mask.sum(dim=-1)
        max_length = int((seq_lens * token_limit_factor).max().item())

        generated_ids = model.generate(**inputs, max_length=max_length)
        transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
    except HTTPException:
        # Deliberate client errors raised above must not be turned into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {e}") from e

    return {"transcription": transcription}
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app with uvicorn on every
    # interface, port 7860.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)