Create app.py
app.py
ADDED
@@ -0,0 +1,159 @@
"""
Demo to run OpenAI Whisper using Hugging Face ZeroGPU.

This lets us test the default Whisper models provided by OpenAI, for later comparison with fine-tuned ones.
"""

import subprocess
import tempfile
from pathlib import Path

import gradio as gr
import spaces
import torch
import whisper

YT_AUDIO_FORMAT = "bestaudio[ext=m4a]"


def download_youtube(url: str, tmp_dir: Path) -> Path:
    """Download the audio track from a YouTube video and return the local path."""
    out_path = tmp_dir / "%(id)s.%(ext)s"  # yt-dlp output template
    cmd = [
        "yt-dlp",
        "--quiet",
        "--no-warnings",
        "--extract-audio",
        "--audio-format",
        "m4a",
        "--audio-quality",
        "0",
        "-f",
        YT_AUDIO_FORMAT,
        "-o",
        str(out_path),
        url,
    ]
    # check=False so we can raise our own error with yt-dlp's stderr attached.
    result = subprocess.run(cmd, capture_output=True, check=False)
    if result.returncode != 0:
        raise RuntimeError(f"yt-dlp failed: {result.stderr.decode()}")

    files = list(tmp_dir.glob("*.m4a"))
    if not files:
        raise FileNotFoundError("Could not locate downloaded audio.")
    return files[0]


def _get_input_path(audio, youtube_url):
    if youtube_url and youtube_url.strip():
        # Use a directory that outlives this function so the file still exists
        # when Whisper reads it; a `with TemporaryDirectory()` block would
        # delete the download before transcription.
        tmp_dir = Path(tempfile.mkdtemp())
        return download_youtube(youtube_url, tmp_dir)
    elif audio is not None:
        return audio
    else:
        raise gr.Error("Provide audio or a YouTube URL")


def make_results_table(results):
    rows = []
    for r in results:
        row = [r["model"], r["language"], r["text"]]
        rows.append(row)
    return rows


@spaces.GPU
def transcribe_audio(
    model_sizes: list[str],
    audio: str,
    youtube_url: str,
    return_timestamps: bool,
    temperature: float,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Resolve (and, for YouTube, download) the input once, then reuse it for every model.
    inp = _get_input_path(audio, youtube_url)
    results = []
    for size in model_sizes:
        model = whisper.load_model(size, device=device)
        out = model.transcribe(
            str(inp),
            word_timestamps=return_timestamps,
            temperature=temperature,
            verbose=False,
        )
        text = out["text"].strip()
        segments = out["segments"] if return_timestamps else []
        results.append(
            {
                "model": size,
                "language": out["language"],
                "text": text,
                "segments": segments,
            }
        )
    df_results = make_results_table(results)
    return df_results


def build_demo() -> gr.Blocks:
    with gr.Blocks(title="🗣️ Whisper Transcription Demo (HF Spaces Zero-GPU)") as whisper_demo:
        gr.Markdown("""
        # Whisper Transcription Demo

        Run Whisper transcription on uploaded audio or a YouTube video. Whisper is a general-purpose
        speech recognition model trained on a large dataset of diverse audio.
        """)

        with gr.Row():
            model_choices = gr.Dropdown(
                label="Model size(s)",
                choices=["tiny", "base", "small", "medium", "large", "turbo"],
                value=["turbo"],
                multiselect=True,
                allow_custom_value=False,
            )
            ts_checkbox = gr.Checkbox(
                label="Return word timestamps",
                interactive=False,
                value=False,
            )
            temp_slider = gr.Slider(
                label="Decoding temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.01,
            )

        audio_input = gr.Audio(
            label="Upload or record audio",
            sources=["upload"],
            type="filepath",
        )

        yt_input = gr.Textbox(
            label="... or paste a YouTube URL (audio only)",
            placeholder="https://youtu.be/XYZ",
        )

        with gr.Row():
            transcribe_btn = gr.Button("Transcribe 🏁")

        out_table = gr.Dataframe(
            headers=["Model", "Language", "Transcript"],
            datatype=["str", "str", "str"],
            label="Transcription Results",
        )

        transcribe_btn.click(
            transcribe_audio,
            inputs=[model_choices, audio_input, yt_input, ts_checkbox, temp_slider],
            outputs=[out_table],
        )

    return whisper_demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
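
For quick local sanity-checking outside the Space, the call that transcribe_audio wraps can be exercised directly with the openai-whisper package. The snippet below is a minimal sketch, not part of app.py: it assumes openai-whisper is installed and that a local file named audio.m4a exists (a placeholder name), and it skips the Gradio and ZeroGPU layers entirely.

# Minimal sketch (not part of app.py): the bare openai-whisper call that
# transcribe_audio() performs for each selected model size.
# Assumes `pip install openai-whisper` and a local placeholder file audio.m4a.
import whisper

model = whisper.load_model("tiny")  # smallest checkpoint; runs on CPU
out = model.transcribe("audio.m4a", temperature=0.0, verbose=False)
print(out["language"], out["text"].strip())

Running the full app additionally assumes gradio, spaces, and torch are installed and that the yt-dlp and ffmpeg binaries are available on PATH, matching the imports and the subprocess call in download_youtube.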