kostissz's picture
Add init demo version
dcbed68 verified
raw
history blame
4.16 kB
import csv
from pathlib import Path
from typing import Tuple
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from whisper_bidec import decode_wav, get_logits_processor, load_corpus_from_sentences
def _parse_file(file_path: str) -> list[str]:
    """Read a biasing-text file and return one sentence per line or CSV row.

    Args:
        file_path: Path to a ``.txt`` / ``.md`` / ``.csv`` file. The ``.csv``
            extension is matched case-insensitively; every other extension is
            read as plain text.

    Returns:
        A flat list of non-empty sentences. For CSV files the cells of each
        row are stripped and joined with ``", "`` so that a row always maps
        to exactly one string. For text files each line is stripped of
        surrounding whitespace (including the trailing newline). Blank lines
        and empty rows are skipped.
    """
    if file_path.lower().endswith(".csv"):
        # newline="" is what the csv module expects so that it can handle
        # quoted fields containing line breaks itself.
        with open(file_path, "r", encoding="utf-8", newline="") as f:
            rows = [[cell.strip() for cell in row] for row in csv.reader(f)]
        joined_rows = (", ".join(cell for cell in row if cell) for row in rows)
        return [sentence for sentence in joined_rows if sentence]

    # Same explicit encoding as the CSV branch so results do not depend on
    # the platform's default locale.
    with open(file_path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]
def transcribe(
    processor_name: str,
    audio: str,
    bias_strength: float,
    bias_text: str | None,
    bias_text_file: str | None,
) -> Tuple[str, str]:
    """Transcribe ``audio`` twice: once plainly and once biased towards given text.

    Args:
        processor_name: Model identifier passed to both
            ``WhisperProcessor.from_pretrained`` and
            ``WhisperForConditionalGeneration.from_pretrained``.
        audio: Audio file path handed to ``decode_wav``.
        bias_strength: Forwarded to ``get_logits_processor`` as
            ``bias_towards_lm``.
        bias_text: Comma-separated biasing sentences. Takes precedence over
            ``bias_text_file`` whenever it contains at least one non-blank
            sentence.
        bias_text_file: Path to a file parsed by ``_parse_file``; may be
            ``None`` when nothing was uploaded.

    Returns:
        ``(text_no_bias, text_with_bias)``. ``text_with_bias`` is an empty
        string when no biasing sentences were supplied.
    """
    processor = WhisperProcessor.from_pretrained(processor_name)
    model = WhisperForConditionalGeneration.from_pretrained(processor_name)

    # Drop surrounding whitespace and empty entries (e.g. from "a, b,").
    sentences = [
        part.strip() for part in (bias_text or "").split(",") if part.strip()
    ]
    # Guard against None: Path(None) raises TypeError when no file is uploaded.
    if not sentences and bias_text_file and Path(bias_text_file).is_file():
        sentences = _parse_file(bias_text_file)

    if sentences:
        corpus = load_corpus_from_sentences(sentences, processor)
        logits_processor = get_logits_processor(
            corpus=corpus, processor=processor, bias_towards_lm=bias_strength
        )
        text_with_bias = decode_wav(
            model, processor, audio, logits_processor=logits_processor
        )
    else:
        text_with_bias = ""

    text_no_bias = decode_wav(model, processor, audio, logits_processor=None)
    return text_no_bias, text_with_bias
def setup_gradio_demo():
    """Assemble the Whisper Bidec Gradio interface and launch it.

    The page walks the user through five steps (model, audio, biasing text,
    bias strength, transcription) and wires the "Transcribe" button to
    ``transcribe``.
    """
    # Centers the content of the column tagged with elem_id="centered-column".
    page_css = """
#centered-column {
display: flex;
justify-content: center;
align-items: center;
flex-direction: column;
text-align: center;
}
"""

    with gr.Blocks(css=page_css) as demo:
        gr.Markdown("# Whisper Bidec Demo")

        gr.Markdown("## Step 1: Select a Whisper model")
        model_name_box = gr.Textbox(
            value="openai/whisper-tiny.en",
            label="Whisper Model from Hugging Face",
        )

        gr.Markdown("## Step 2: Upload your audio file")
        audio_input = gr.Audio(type="filepath", label="Upload a WAV file")

        gr.Markdown("## Step 3: Set your biasing text")
        with gr.Row():
            # Left: free-text entry.
            with gr.Column(scale=20):
                gr.Markdown(
                    "You can add multiple possible sentences by separating them with a comma <,>."
                )
                bias_text_box = gr.Textbox(label="Write your biasing text here")
            # Middle: narrow separator column.
            with gr.Column(scale=1, elem_id="centered-column"):
                gr.Markdown("## OR")
            # Right: file upload.
            with gr.Column(scale=20):
                gr.Markdown(
                    "Note that each new line (.txt / .md) or row (.csv) will be treated as a separate sentence to bias towards to."
                )
                bias_file_input = gr.File(
                    label="Upload a file with multiple lines of text",
                    file_types=[".txt", ".md", ".csv"],
                )

        gr.Markdown("## Step 4: Set how much you want to bias towards the LM")
        strength_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            label="Bias strength",
            interactive=True,
        )

        gr.Markdown("## Step 5: Get your transcription before and after biasing")
        run_button = gr.Button("Transcribe")
        with gr.Row():
            with gr.Column():
                plain_result = gr.Text(label="Output")
            with gr.Column():
                biased_result = gr.Text(label="Biased output")

        # Order must match the positional parameters of ``transcribe``.
        transcribe_inputs = [
            model_name_box,
            audio_input,
            strength_slider,
            bias_text_box,
            bias_file_input,
        ]
        run_button.click(
            fn=transcribe,
            inputs=transcribe_inputs,
            outputs=[plain_result, biased_result],
        )

    demo.launch()
# Build and launch the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    setup_gradio_demo()