bcci committed on
Commit
2b054c9
·
verified ·
1 Parent(s): 86faebb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import MoonshineForConditionalGeneration, AutoProcessor
import torch
import librosa
import io
import os

app = FastAPI()

# Prefer GPU when available; half precision is only safe on CUDA, so the
# dtype choice follows the device choice.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the model and processor once at import time so every request reuses
# them. On failure, raise SystemExit with a non-zero status instead of the
# original bare `exit()`: `exit()` returned status 0 and discarded the
# traceback, which made startup failures invisible to process supervisors.
try:
    model = MoonshineForConditionalGeneration.from_pretrained('UsefulSensors/moonshine-tiny').to(device).to(torch_dtype)
    processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-tiny')
except Exception as e:
    raise SystemExit(f"Error loading model or processor: {e}") from e
@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Moonshine model.

    The upload is decoded and resampled (via librosa) to the sampling rate
    the processor expects, then run through `model.generate`.

    Returns:
        dict: ``{"transcription": <str>}``.

    Raises:
        HTTPException 400: missing filename or unsupported extension.
        HTTPException 500: any failure while decoding or transcribing.
    """
    # FastAPI types `UploadFile.filename` as Optional[str]; guard against
    # None so a client that omits the filename gets a clear 400 instead of
    # an AttributeError surfacing as an opaque 500.
    allowed_exts = ('.mp3', '.wav', '.ogg', '.flac', '.m4a')  # add more formats as needed
    if not file.filename or not file.filename.lower().endswith(allowed_exts):
        raise HTTPException(status_code=400, detail="Invalid file format. Supported formats: .mp3, .wav, .ogg, .flac, .m4a")

    try:
        audio_bytes = await file.read()
        # Hoist the repeated attribute chain; librosa resamples the audio to
        # the rate the feature extractor was trained on.
        target_sr = processor.feature_extractor.sampling_rate
        audio_array, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=target_sr)

        inputs = processor(
            audio_array,
            return_tensors="pt",
            sampling_rate=target_sr
        )
        inputs = inputs.to(device, torch_dtype)

        # Cap generation length proportionally to the audio duration
        # (~6.5 tokens per second of input) to avoid runaway decoding.
        token_limit_factor = 6.5 / target_sr
        seq_lens = inputs.attention_mask.sum(dim=-1)
        max_length = int((seq_lens * token_limit_factor).max().item())

        generated_ids = model.generate(**inputs, max_length=max_length)
        transcription = processor.decode(generated_ids[0], skip_special_tokens=True)

        return {"transcription": transcription}

    except Exception as e:
        # Boundary handler: convert any processing failure into a 500 that
        # carries the underlying message as a hint for the client.
        raise HTTPException(status_code=500, detail=f"Error processing audio: {e}")
53
+ if __name__ == "__main__":
54
+ import uvicorn
55
+ uvicorn.run(app, host="0.0.0.0", port=7860)