"""
Demo to run OpenAI Whisper using HuggingFace ZeroGPU.

This way we can test the default Whisper models provided by OpenAI, for later comparison with fine-tuned ones.
"""

import subprocess
import tempfile
from pathlib import Path

import gradio as gr
import spaces
import torch
import whisper

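# Note: the YouTube path below shells out to the yt-dlp CLI, so yt-dlp (and ffmpeg,
# which both yt-dlp's audio extraction and Whisper's audio loading rely on) is
# assumed to be installed in the Space environment.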
YT_AUDIO_FORMAT = "bestaudio[ext=m4a]"

MODEL_SIZES = ["tiny", "base", "small", "medium", "large", "turbo"]
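
# Download and cache every checkpoint on CPU at startup, so the GPU-decorated
# transcription calls below don't have to fetch model weights on first use.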
for size in MODEL_SIZES:
    whisper.load_model(size, device="cpu")


def download_youtube(url: str, tmp_dir: Path) -> Path:
    """Download the audio track from a YouTube video and return the local path."""
    # yt-dlp output template: "<video id>.<extension>" inside tmp_dir.
    out_path = tmp_dir / "%(id)s.%(ext)s"
    cmd = [
        "yt-dlp",
        "--quiet",
        "--no-warnings",
        "--extract-audio",
        "--audio-format",
        "m4a",
        "--audio-quality",
        "0",
        "-f",
        YT_AUDIO_FORMAT,
        "-o",
        str(out_path),
        url,
    ]
    # check=False: we report yt-dlp's stderr ourselves instead of letting
    # subprocess raise a bare CalledProcessError.
    result = subprocess.run(cmd, capture_output=True, check=False)
    if result.returncode != 0:
        raise RuntimeError(f"yt-dlp failed: {result.stderr.decode()}")

    files = list(tmp_dir.glob("*.m4a"))
    if not files:
        raise FileNotFoundError("Could not locate downloaded audio.")
    return files[0]


def _get_input_path(audio, youtube_url):
    if youtube_url and youtube_url.strip():
        # Use mkdtemp() rather than a TemporaryDirectory context manager here:
        # the downloaded file must still exist after this function returns.
        tmp_dir = Path(tempfile.mkdtemp())
        return download_youtube(youtube_url, tmp_dir)
    elif audio is not None:
        return audio
    else:
        raise gr.Error("Provide audio or a YouTube URL")


def make_results_table(results):
    rows = []
    for r in results:
        row = [r["model"], r["language"], r["text"]]
        rows.append(row)
    return rows


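# On ZeroGPU Spaces, the spaces.GPU decorator attaches a GPU only for the
# duration of this call; the rest of the app runs on CPU.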
@spaces.GPU
def transcribe_audio(
    model_sizes: list[str],
    audio: str,
    youtube_url: str,
    return_timestamps: bool,
    temperature: float,
    logprob_threshold: float = -1.0,
    no_speech_threshold: float = 0.6,
    compression_ratio_threshold: float = 2.4,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Resolve the input once, so a YouTube URL is downloaded a single time
    # rather than once per selected model.
    inp = _get_input_path(audio, youtube_url)
    results = []
    for size in model_sizes:
        model = whisper.load_model(size, device=device)
        out = model.transcribe(
            str(inp),
            word_timestamps=return_timestamps,
            temperature=temperature,
            verbose=False,
            # Whisper's fallback heuristics: segments whose average log-probability,
            # no-speech probability, or compression ratio cross these thresholds are
            # treated as failed or silent and retried or skipped.
            logprob_threshold=logprob_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
        )
        text = out["text"].strip()
        segments = out["segments"] if return_timestamps else []
        results.append(
            {
                "model": size,
                "language": out["language"],
                "text": text,
                "segments": segments,
            }
        )
    df_results = make_results_table(results)
    return df_results


def build_demo() -> gr.Blocks:
    with gr.Blocks(title="🗣️ Whisper Transcription Demo (HF Spaces Zero-GPU)") as whisper_demo:
        gr.Markdown("""
        # Whisper Transcription Demo

        Run Whisper transcription on audio or a YouTube video. Whisper is a general-purpose
        speech recognition model trained on a large dataset of diverse audio.
        """)

        with gr.Row():
            model_choices = gr.Dropdown(
                label="Model size(s)",
                choices=MODEL_SIZES,
                value=["turbo"],
                multiselect=True,
                allow_custom_value=False,
            )
            ts_checkbox = gr.Checkbox(
                label="Return word timestamps",
                interactive=False,
                value=False,
            )
            temp_slider = gr.Slider(
                label="Decoding temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.01,
            )

        logprob_slider = gr.Slider(
            label="Average log-probability threshold",
            minimum=-10.0,
            maximum=0.0,
            value=-1.0,
            step=0.1,
        )
        no_speech_slider = gr.Slider(
            label="No-speech probability threshold",
            minimum=0.0,
            maximum=1.0,
            value=0.6,
            step=0.01,
        )
        compression_slider = gr.Slider(
            label="Compression ratio threshold",
            minimum=1.0,
            maximum=5.0,
            value=2.4,
            step=0.1,
        )

        audio_input = gr.Audio(
            label="Upload or record audio",
            sources=["upload", "microphone"],
            type="filepath",
        )

        yt_input = gr.Textbox(
            label="... or paste a YouTube URL (audio only)",
            placeholder="https://youtu.be/XYZ",
        )

        with gr.Row():
            transcribe_btn = gr.Button("Transcribe 🏁")

        out_table = gr.Dataframe(
            headers=["Model", "Language", "Transcript"],
            datatype=["str", "str", "str"],
            label="Transcription Results",
        )

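        # The order of this inputs list must match transcribe_audio's positional parameters.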
        transcribe_btn.click(
            transcribe_audio,
            inputs=[
                model_choices,
                audio_input,
                yt_input,
                ts_checkbox,
                temp_slider,
                logprob_slider,
                no_speech_slider,
                compression_slider,
            ],
            outputs=[out_table],
        )

    return whisper_demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()