|
from fastapi import FastAPI, UploadFile, File, HTTPException |
|
from transformers import MoonshineForConditionalGeneration, AutoProcessor |
|
import torch |
|
import librosa |
|
import io |
|
import os |
|
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps |
|
# Voice-activity-detection (VAD) model from silero_vad.  It is kept under its
# own name so that the speech-recognition model bound to `model` below does
# not overwrite it.
vad_model = load_silero_vad()

app = FastAPI()

# Use the first CUDA device with half precision when CUDA is available,
# otherwise fall back to CPU with single precision.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

try:
    model = (
        MoonshineForConditionalGeneration.from_pretrained('UsefulSensors/moonshine-tiny')
        .to(device)
        .to(torch_dtype)
    )
    processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-tiny')
except Exception as e:
    print(f"Error loading model or processor: {e}")
    # Terminate with a non-zero status so the failure is visible to whatever
    # launched the process; SystemExit also avoids relying on the `exit`
    # helper, which is only injected by the `site` module.
    raise SystemExit(1)


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file.

    The upload is decoded with librosa at the processor's sampling rate,
    passed through the Silero VAD model (the detected speech timestamps are
    printed), and then transcribed with the Moonshine model.

    Args:
        file: The uploaded audio file.  Its name must end with one of
            .mp3, .wav, .ogg, .flac or .m4a (case-insensitive).

    Returns:
        A dict of the form ``{"transcription": <decoded text>}``.

    Raises:
        HTTPException: 400 if the file name is missing or has an unsupported
            extension, or if the decoded audio contains no samples; 500 if
            anything else fails while decoding or transcribing.
    """
    # `filename` can be None when the client sends no file name; treat that
    # the same as an unsupported extension instead of crashing on `.lower()`.
    filename = (file.filename or "").lower()
    if not filename.endswith(('.mp3', '.wav', '.ogg', '.flac', '.m4a')):
        raise HTTPException(status_code=400, detail="Invalid file format. Supported formats: .mp3, .wav, .ogg, .flac, .m4a")

    # NOTE(review): decoding and inference below are synchronous calls inside
    # this coroutine, so they run on the event loop for their full duration.
    try:
        audio_bytes = await file.read()
        sampling_rate = processor.feature_extractor.sampling_rate
        audio_array, _ = librosa.load(io.BytesIO(audio_bytes), sr=sampling_rate)

        if audio_array.size == 0:
            raise HTTPException(status_code=400, detail="Uploaded audio contains no samples.")

        # Run VAD with the dedicated VAD model (not the Moonshine `model`) and
        # tell it the rate the audio was actually loaded at.
        speech_timestamps = get_speech_timestamps(
            torch.from_numpy(audio_array),
            vad_model,
            sampling_rate=sampling_rate,
            return_seconds=True,
        )
        print(speech_timestamps)

        inputs = processor(
            audio_array,
            return_tensors="pt",
            sampling_rate=sampling_rate,
        )
        inputs = inputs.to(device, torch_dtype)

        # Cap the generated length in proportion to the audio length
        # (6.5 tokens per second of audio).
        token_limit_factor = 6.5 / sampling_rate
        seq_lens = inputs.attention_mask.sum(dim=-1)
        max_length = int((seq_lens * token_limit_factor).max().item())
        # Very short clips would otherwise give a limit of 0 or 1; keep a
        # floor of 2 so at least one token can be produced (assumes a single
        # decoder start token counts towards max_length -- TODO confirm).
        max_length = max(max_length, 2)

        generated_ids = model.generate(**inputs, max_length=max_length)
        transcription = processor.decode(generated_ids[0], skip_special_tokens=True)

        return {"transcription": transcription}

    except HTTPException:
        # Deliberate HTTP errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {e}") from e
|
|
|
if __name__ == "__main__":
    # Executed as a script: serve the FastAPI app with uvicorn, listening on
    # every interface at port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)