from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import MoonshineForConditionalGeneration, AutoProcessor
import torch
import librosa
import io
import os
from silero_vad import load_silero_vad, get_speech_timestamps

# Load the Silero VAD model under its own name so it is not overwritten by the
# Moonshine model loaded below.
vad_model = load_silero_vad()
app = FastAPI()
# Check for GPU availability
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load the model and processor
try:
    model = MoonshineForConditionalGeneration.from_pretrained(
        'UsefulSensors/moonshine-tiny'
    ).to(device).to(torch_dtype)
    processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-tiny')
except Exception as e:
    print(f"Error loading model or processor: {e}")
    exit()
@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file and return the text.
    """
    if not file.filename.lower().endswith(('.mp3', '.wav', '.ogg', '.flac', '.m4a')):  # add more formats as needed
        raise HTTPException(status_code=400, detail="Invalid file format. Supported formats: .mp3, .wav, .ogg, .flac, .m4a")
    try:
        audio_bytes = await file.read()
        # Decode and resample to the rate the Moonshine processor expects.
        audio_array, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=processor.feature_extractor.sampling_rate)
        # Run Silero VAD on the decoded audio. Note that the VAD model, not the
        # Moonshine model, must be passed here.
        speech_timestamps = get_speech_timestamps(
            torch.from_numpy(audio_array),
            vad_model,
            sampling_rate=sampling_rate,
            return_seconds=True,  # Return speech timestamps in seconds (default is samples)
        )
        print(speech_timestamps)
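        # The VAD output above is informational only; the full clip is
        # transcribed below. A possible extension (a sketch, not part of this
        # app as written) would be to transcribe only the detected speech spans:
        #
        #   segments = [audio_array[int(t["start"] * sampling_rate):
        #                           int(t["end"] * sampling_rate)]
        #               for t in speech_timestamps]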
        inputs = processor(
            audio_array,
            return_tensors="pt",
            sampling_rate=processor.feature_extractor.sampling_rate,
        )
        inputs = inputs.to(device, torch_dtype)
        # Cap generation length at ~6.5 tokens per second of audio so decoding
        # scales with the clip duration.
        token_limit_factor = 6.5 / processor.feature_extractor.sampling_rate
        seq_lens = inputs.attention_mask.sum(dim=-1)
        max_length = int((seq_lens * token_limit_factor).max().item())
        generated_ids = model.generate(**inputs, max_length=max_length)
        transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
        return {"transcription": transcription}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {e}")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
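
# Example client call (a sketch; assumes the server is running locally and a
# file named "sample.wav" exists, both of which are illustrative):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/transcribe/",
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   print(resp.json())  # {"transcription": "..."}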