ankandrew committed on
Commit
0ce1e8d
·
verified ·
1 Parent(s): 6cc4b64

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Demo to run OpenAI Whisper using HuggingFace ZeroGPU.
3
+
4
+ This way we can test default Whisper models provided by OpenAI, for later comparison with fine-tuned ones.
5
+ """
6
+
7
+ import subprocess
8
+ import tempfile
9
+ from pathlib import Path
10
+
11
+ import gradio as gr
12
+ import spaces
13
+ import torch
14
+ import whisper
15
+
16
+ YT_AUDIO_FORMAT = "bestaudio[ext=m4a]"
17
+
18
+
19
def download_youtube(url: str, tmp_dir: Path) -> Path:
    """Download the audio track of a YouTube video and return its local path.

    Args:
        url: YouTube video URL accepted by yt-dlp.
        tmp_dir: Directory where the extracted ``.m4a`` file is written.

    Returns:
        Path to the downloaded ``.m4a`` audio file.

    Raises:
        RuntimeError: If yt-dlp exits with a non-zero status.
        FileNotFoundError: If no ``.m4a`` file is present after the download.
    """
    # yt-dlp output template. The original used r"%\(id)s.%(ext)s" — the stray
    # backslash broke the template so yt-dlp never substituted the video id.
    out_path = tmp_dir / "%(id)s.%(ext)s"
    cmd = [
        "yt-dlp",
        "--quiet",
        "--no-warnings",
        "--extract-audio",
        "--audio-format",
        "m4a",
        "--audio-quality",
        "0",
        "-f",
        YT_AUDIO_FORMAT,
        "-o",
        str(out_path),
        url,
    ]
    # check=False so we can attach yt-dlp's stderr to the error message; with
    # check=True (as originally written) the returncode branch was unreachable
    # because subprocess.run raised CalledProcessError first.
    result = subprocess.run(cmd, capture_output=True, check=False)
    if result.returncode != 0:
        raise RuntimeError(f"yt-dlp failed: {result.stderr.decode()}")

    files = list(tmp_dir.glob("*.m4a"))
    if not files:
        raise FileNotFoundError("Could not locate downloaded audio.")
    return files[0]
45
+
46
+
47
+ def _get_input_path(audio, youtube_url):
48
+ if youtube_url and youtube_url.strip():
49
+ with tempfile.TemporaryDirectory() as tmp:
50
+ return download_youtube(youtube_url, Path(tmp))
51
+ elif audio is not None:
52
+ return audio
53
+ else:
54
+ raise gr.Error("Provide audio or a YouTube URL")
55
+
56
+
57
def make_results_table(results, return_timestamps=False):
    """Convert per-model transcription results into Dataframe rows.

    Args:
        results: List of dicts, each with at least the keys "model",
            "language" and "text".
        return_timestamps: Accepted for compatibility with the call site in
            ``transcribe_audio`` (which passes it as a second positional
            argument — the original one-argument signature raised TypeError
            there); segment timestamps are not rendered in the table yet.

    Returns:
        List of ``[model, language, text]`` rows.
    """
    return [[r["model"], r["language"], r["text"]] for r in results]
63
+
64
+
65
@spaces.GPU
def transcribe_audio(
    model_sizes: list[str],
    audio: str,
    youtube_url: str,
    return_timestamps: bool,
    temperature: float,
):
    """Transcribe the input audio with every selected Whisper model size.

    Args:
        model_sizes: Whisper checkpoint names (e.g. "tiny", "turbo").
        audio: Local path of an uploaded audio file, or None.
        youtube_url: Optional YouTube URL; takes precedence over `audio`.
        return_timestamps: Whether to request word-level timestamps.
        temperature: Decoding temperature forwarded to Whisper.

    Returns:
        Rows for the results Dataframe: one [model, language, text] per model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Resolve the input once, outside the loop — the original called
    # _get_input_path per model, re-downloading the YouTube audio every time.
    inp = _get_input_path(audio, youtube_url)
    results = []
    for size in model_sizes:
        model = whisper.load_model(size, device=device)
        out = model.transcribe(
            str(inp),
            word_timestamps=return_timestamps,
            temperature=temperature,
            verbose=False,
        )
        results.append(
            {
                "model": size,
                "language": out["language"],
                "text": out["text"].strip(),
                "segments": out["segments"] if return_timestamps else [],
            }
        )
    # Single-argument call: make_results_table takes only `results`, so the
    # original two-argument call raised TypeError at runtime.
    return make_results_table(results)
96
+
97
+
98
def build_demo() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI for the Whisper demo."""
    with gr.Blocks(title="🗣️ Whisper Transcription Demo (HF Spaces Zero-GPU)") as demo:
        gr.Markdown("""
        # Whisper Transcription Demo

        Run Whisper transcription on audio or YouTube video. Whisper is a general-purpose speech recognition model,
        trained on a large dataset
        """)

        # Model selection plus decoding options, side by side.
        with gr.Row():
            size_selector = gr.Dropdown(
                label="Model size(s)",
                choices=["tiny", "base", "small", "medium", "large", "turbo"],
                value=["turbo"],
                multiselect=True,
                allow_custom_value=False,
            )
            timestamps_cb = gr.Checkbox(
                label="Return word timestamps",
                interactive=False,
                value=False,
            )
            temperature_slider = gr.Slider(
                label="Decoding temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.01,
            )

        # Two input routes: direct upload or a YouTube link.
        uploaded_audio = gr.Audio(
            label="Upload or record audio",
            sources=["upload"],
            type="filepath",
        )
        youtube_url_box = gr.Textbox(
            label="... or paste a YouTube URL (audio only)",
            placeholder="https://youtu.be/XYZ",
        )

        with gr.Row():
            run_button = gr.Button("Transcribe 🏁")

        # One row per selected model size.
        results_table = gr.Dataframe(
            headers=["Model", "Language", "Transcript"],
            datatype=["str", "str", "str"],
            label="Transcription Results",
        )

        run_button.click(
            transcribe_audio,
            inputs=[size_selector, uploaded_audio, youtube_url_box, timestamps_cb, temperature_slider],
            outputs=[results_table],
        )

    return demo
155
+
156
+
157
if __name__ == "__main__":
    # Build the Gradio app and start serving it.
    build_demo().launch()