from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import MoonshineForConditionalGeneration, AutoProcessor
import torch
import librosa
import io
from silero_vad import load_silero_vad, get_speech_timestamps

# Load the Silero VAD model under its own name so it is not shadowed by the
# Moonshine ASR model assigned to `model` below
vad_model = load_silero_vad()


app = FastAPI()

# Check for GPU availability
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the Moonshine model and processor
try:
    model = MoonshineForConditionalGeneration.from_pretrained('UsefulSensors/moonshine-tiny').to(device).to(torch_dtype)
    processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-tiny')
except Exception as e:
    raise SystemExit(f"Error loading model or processor: {e}")

@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribes an uploaded audio file.
    """
    if not file.filename.lower().endswith(('.mp3', '.wav', '.ogg', '.flac', '.m4a')): #add more formats as needed
        raise HTTPException(status_code=400, detail="Invalid file format. Supported formats: .mp3, .wav, .ogg, .flac, .m4a")

    try:
        audio_bytes = await file.read()
        # Decode and resample to the rate the Moonshine processor expects (16 kHz)
        audio_array, _ = librosa.load(
            io.BytesIO(audio_bytes), sr=processor.feature_extractor.sampling_rate
        )

        # Run Silero VAD to locate speech segments in the audio
        speech_timestamps = get_speech_timestamps(
            torch.from_numpy(audio_array),
            vad_model,
            return_seconds=True,  # return speech timestamps in seconds (default is samples)
        )
        print(speech_timestamps)
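        # NOTE: the VAD output is only logged for now. A minimal sketch of
        # transcribing just the detected speech (assumes `import numpy as np`
        # at the top of the file) could concatenate the voiced regions first:
        #
        #   sr = processor.feature_extractor.sampling_rate
        #   speech_only = np.concatenate([
        #       audio_array[int(seg["start"] * sr):int(seg["end"] * sr)]
        #       for seg in speech_timestamps
        #   ]) if speech_timestamps else audio_array
        #
        # and then feed `speech_only` to the processor instead of `audio_array`.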

        # Extract input features for the ASR model
        inputs = processor(
            audio_array,
            return_tensors="pt",
            sampling_rate=processor.feature_extractor.sampling_rate,
        )
        inputs = inputs.to(device, torch_dtype)

        # Moonshine emits at most ~6.5 tokens per second of audio, so derive a
        # max_length from the input length to keep generation bounded
        token_limit_factor = 6.5 / processor.feature_extractor.sampling_rate
        seq_lens = inputs.attention_mask.sum(dim=-1)
        max_length = int((seq_lens * token_limit_factor).max().item())

        generated_ids = model.generate(**inputs, max_length=max_length)
        transcription = processor.decode(generated_ids[0], skip_special_tokens=True)

        return {"transcription": transcription}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {e}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
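
# A minimal smoke test for the endpoint (hypothetical file name sample.wav;
# any of the supported formats works):
#
#   curl -X POST "http://localhost:7860/transcribe/" -F "file=@sample.wav"
#
# Expected response shape: {"transcription": "..."}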