""" Demo to run OpenAI Whisper using HuggingFace ZeroGPU. This way we can test default Whisper models provided by OpenAI, for later comparison with fine-tuned ones. """ import subprocess import tempfile from pathlib import Path import gradio as gr import spaces import torch import whisper YT_AUDIO_FORMAT = "bestaudio[ext=m4a]" MODEL_SIZES = ["tiny", "base", "small", "medium", "large", "turbo"] for size in MODEL_SIZES: whisper.load_model(size, device="cpu") def download_youtube(url: str, tmp_dir: Path) -> Path: """Download the audio track from a YouTube video and return the local path.""" out_path = tmp_dir / r"%\(id)s.%(ext)s" cmd = [ "yt-dlp", "--quiet", "--no-warnings", "--extract-audio", "--audio-format", "m4a", "--audio-quality", "0", "-f", YT_AUDIO_FORMAT, "-o", str(out_path), url, ] result = subprocess.run(cmd, capture_output=True, check=True) if result.returncode != 0: raise RuntimeError(f"yt-dlp failed: {result.stderr.decode()}") files = list(tmp_dir.glob("*.m4a")) if not files: raise FileNotFoundError("Could not locate downloaded audio.") return files[0] def _get_input_path(audio, youtube_url): if youtube_url and youtube_url.strip(): with tempfile.TemporaryDirectory() as tmp: return download_youtube(youtube_url, Path(tmp)) elif audio is not None: return audio else: raise gr.Error("Provide audio or a YouTube URL") def make_results_table(results): rows = [] for r in results: row = [r["model"], r["language"], r["text"]] rows.append(row) return rows @spaces.GPU def transcribe_audio( model_sizes: list[str], audio: str, youtube_url: str, return_timestamps: bool, temperature: float, logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6, compression_ratio_threshold: float = 2.4, ): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") results = [] for size in model_sizes: model = whisper.load_model(size, device=device) inp = _get_input_path(audio, youtube_url) out = model.transcribe( str(inp), word_timestamps=return_timestamps, temperature=temperature, verbose=False, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold, compression_ratio_threshold=compression_ratio_threshold, ) text = out["text"].strip() segments = out["segments"] if return_timestamps else [] results.append( { "model": size, "language": out["language"], "text": text, "segments": segments, } ) df_results = make_results_table(results) return df_results def build_demo() -> gr.Blocks: with gr.Blocks(title="🗣️ Whisper Transcription Demo (HF Spaces Zero-GPU)") as whisper_demo: gr.Markdown(""" # Whisper Transcription Demo Run Whisper transcription on audio or YouTube video. 
Whisper is a general-purpose speech recognition model, trained on a large
dataset of diverse audio.
""")
        with gr.Row():
            model_choices = gr.Dropdown(
                label="Model size(s)",
                choices=MODEL_SIZES,
                value=["turbo"],
                multiselect=True,
                allow_custom_value=False,
            )
            # Interactive so users can actually toggle it; it drives the
            # word_timestamps argument of transcribe_audio.
            ts_checkbox = gr.Checkbox(
                label="Return word timestamps",
                value=False,
            )
            temp_slider = gr.Slider(
                label="Decoding temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.01,
            )
            logprob_slider = gr.Slider(
                label="Average log-probability threshold",
                minimum=-10.0,
                maximum=0.0,
                value=-1.0,
                step=0.1,
            )
            no_speech_slider = gr.Slider(
                label="No-speech probability threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.6,
                step=0.01,
            )
            compression_slider = gr.Slider(
                label="Compression ratio threshold",
                minimum=1.0,
                maximum=5.0,
                value=2.4,
                step=0.1,
            )
        # "microphone" included so the label's "record audio" actually works.
        audio_input = gr.Audio(
            label="Upload or record audio",
            sources=["upload", "microphone"],
            type="filepath",
        )
        yt_input = gr.Textbox(
            label="... or paste a YouTube URL (audio only)",
            placeholder="https://youtu.be/XYZ",
        )
        with gr.Row():
            transcribe_btn = gr.Button("Transcribe 🏁")
        out_table = gr.Dataframe(
            headers=["Model", "Language", "Transcript"],
            datatype=["str", "str", "str"],
            label="Transcription Results",
        )
        transcribe_btn.click(
            transcribe_audio,
            inputs=[
                model_choices,
                audio_input,
                yt_input,
                ts_checkbox,
                temp_slider,
                logprob_slider,
                no_speech_slider,
                compression_slider,
            ],
            outputs=[out_table],
        )
    return whisper_demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
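
# ---------------------------------------------------------------------------
# Note: when "Return word timestamps" is enabled, transcribe_audio also keeps
# the raw Whisper segments in each result dict, but the results table above
# does not display them. A minimal sketch of how those segments could be
# flattened into readable lines (`format_segments` is an illustrative helper,
# not part of the app; Whisper segments expose "start"/"end" in seconds and
# "text"):
#
#   def format_segments(segments: list[dict]) -> str:
#       return "\n".join(
#           f"[{s['start']:7.2f}s -> {s['end']:7.2f}s] {s['text'].strip()}"
#           for s in segments
#       )
# ---------------------------------------------------------------------------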