tclin commited on
Commit
722c2f4
·
verified ·
1 Parent(s): a0004b6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
6
+
7
+ # Model loading function with caching
8
+ def load_model():
9
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
11
+
12
+ model = WhisperForConditionalGeneration.from_pretrained("tclin/whisper-large-v3-turbo-atcosim-finetune")
13
+ model = model.to(device=device, dtype=torch_dtype)
14
+ processor = WhisperProcessor.from_pretrained("tclin/whisper-large-v3-turbo-atcosim-finetune")
15
+
16
+ return model, processor, device, torch_dtype
17
+
18
+ # Load model and processor once at startup
19
+ model, processor, device, torch_dtype = load_model()
20
+
21
+ # Define the transcription function
22
+ def transcribe_audio(audio_file):
23
+ # Check if audio file exists
24
+ if audio_file is None:
25
+ return "Please upload an audio file"
26
+
27
+ try:
28
+ # Load and preprocess audio
29
+ waveform, sample_rate = torchaudio.load(audio_file)
30
+
31
+ # Resample to 16kHz (required for Whisper models)
32
+ if sample_rate != 16000:
33
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
34
+ waveform = resampler(waveform)
35
+
36
+ # Convert stereo to mono if needed
37
+ if waveform.shape[0] > 1:
38
+ waveform = waveform.mean(dim=0, keepdim=True)
39
+
40
+ # Convert to numpy array
41
+ waveform_np = waveform.squeeze().cpu().numpy()
42
+
43
+ # Process with model
44
+ input_features = processor(waveform_np, sampling_rate=16000, return_tensors="pt").input_features
45
+ input_features = input_features.to(device=device, dtype=torch_dtype)
46
+
47
+ generated_ids = model.generate(input_features, max_new_tokens=128)
48
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
49
+
50
+ return transcription
51
+
52
+ except Exception as e:
53
+ return f"Error processing audio: {str(e)}"
54
+
55
+ # Create Gradio interface
56
+ demo = gr.Interface(
57
+ fn=transcribe_audio,
58
+ inputs=gr.Audio(type="filepath"),
59
+ outputs="text",
60
+ title="ATC Speech Transcription",
61
+ description="Upload an air traffic control audio file and get an accurate transcription using a Whisper model fine-tuned on the ATCOSIM dataset.",
62
+ examples=[
63
+ ["example1.wav"],
64
+ ["example2.wav"]
65
+ ],
66
+ article="This model is fine-tuned on the ATCOSIM dataset to accurately transcribe air traffic control communications with a Word Error Rate (WER) of 3.73%."
67
+ )
68
+
69
+ # Launch the interface
70
+ if __name__ == "__main__":
71
+ demo.launch()