File size: 4,155 Bytes
dcbed68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import csv
from functools import lru_cache
from pathlib import Path
from typing import Tuple

import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from whisper_bidec import decode_wav, get_logits_processor, load_corpus_from_sentences


def _parse_file(file_path: str) -> list[str]:
    """Parse .txt / .md / .csv and return its content as a list of strings by splitting per new line or row."""

    if file_path.endswith(".csv"):
        sentences = []
        with open(file_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            for row in reader:
                sentences.append(row)
    else:
        with open(file_path, "r") as f:
            sentences = f.readlines()
    return sentences


@lru_cache(maxsize=2)
def _load_whisper(processor_name: str) -> tuple:
    """Load and memoize the Whisper processor and model for *processor_name*.

    ``transcribe`` runs on every button click; without this cache each click
    would call ``from_pretrained`` twice. At most two distinct model ids are
    kept in memory so switching models does not accumulate checkpoints.
    Failed loads raise and are not cached.
    """
    processor = WhisperProcessor.from_pretrained(processor_name)
    model = WhisperForConditionalGeneration.from_pretrained(processor_name)
    return processor, model


def _collect_bias_sentences(
    bias_text: str | None, bias_text_file: str | None
) -> list[str]:
    """Return the biasing sentences from the textbox, falling back to the file.

    The textbox wins when it contains at least one non-blank comma-separated
    entry; otherwise the file is parsed with ``_parse_file``, provided a path
    was given and it points to an existing file. Returns an empty list when
    neither source provides anything.
    """
    if bias_text:
        from_text = [part.strip() for part in bias_text.split(",") if part.strip()]
        if from_text:
            return from_text
    # Gradio passes None when no file was uploaded; Path(None) would raise
    # TypeError, so guard before touching the filesystem.
    if bias_text_file and Path(bias_text_file).is_file():
        return _parse_file(bias_text_file)
    return []


def transcribe(
    processor_name: str,
    audio: str,
    bias_strength: float,
    bias_text: str | None,
    bias_text_file: str | None,
) -> Tuple[str, str]:
    """Transcribe *audio* with a Whisper model, without and with biasing.

    Args:
        processor_name: Hugging Face model id used to load both the
            ``WhisperProcessor`` and ``WhisperForConditionalGeneration``.
        audio: Path to the audio file to transcribe.
        bias_strength: Forwarded to ``get_logits_processor`` as
            ``bias_towards_lm``.
        bias_text: Comma-separated biasing sentences. Used in preference to
            ``bias_text_file`` when it contains any non-blank entry.
        bias_text_file: Path to a .txt / .md / .csv file of biasing
            sentences, or ``None`` when no file was uploaded.

    Returns:
        ``(text_no_bias, text_with_bias)``. ``text_with_bias`` is ``""`` when
        no biasing sentences were provided.
    """
    processor, model = _load_whisper(processor_name)

    sentences = _collect_bias_sentences(bias_text, bias_text_file)

    text_with_bias = ""
    if sentences:
        corpus = load_corpus_from_sentences(sentences, processor)
        logits_processor = get_logits_processor(
            corpus=corpus, processor=processor, bias_towards_lm=bias_strength
        )
        text_with_bias = decode_wav(
            model, processor, audio, logits_processor=logits_processor
        )

    text_no_bias = decode_wav(model, processor, audio, logits_processor=None)

    return text_no_bias, text_with_bias


def setup_gradio_demo(*, launch: bool = True) -> "gr.Blocks":
    """Build the Whisper Bidec Gradio UI and, by default, launch it.

    The page collects a Hugging Face model id, an audio file, biasing text
    (typed or uploaded as a file) and a bias strength. It wires the
    "Transcribe" button to ``transcribe`` and shows the unbiased and biased
    transcriptions side by side.

    Args:
        launch: When True (the default, matching the original behavior), call
            ``demo.launch()`` before returning. Pass False to get the
            ``gr.Blocks`` object without starting a server, e.g. to launch it
            later with custom options.

    Returns:
        The constructed ``gr.Blocks`` demo.
    """
    # Centers the "OR" separator vertically between the two input columns.
    css = """
    #centered-column {
        display: flex;
        justify-content: center;
        align-items: center;
        flex-direction: column; 
        text-align: center;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Whisper Bidec Demo")

        gr.Markdown("## Step 1: Select a Whisper model")
        model_name = gr.Textbox(
            value="openai/whisper-tiny.en", label="Whisper Model from Hugging Face"
        )

        gr.Markdown("## Step 2: Upload your audio file")
        audio_clip = gr.Audio(type="filepath", label="Upload a WAV file")

        gr.Markdown("## Step 3: Set your biasing text")
        with gr.Row():
            with gr.Column(scale=20):
                gr.Markdown(
                    "You can add multiple possible sentences by separating them with a comma <,>."
                )
                bias_text = gr.Textbox(label="Write your biasing text here")
            with gr.Column(scale=1, elem_id="centered-column"):
                gr.Markdown("## OR")
            with gr.Column(scale=20):
                gr.Markdown(
                    "Note that each new line (.txt / .md) or row (.csv) will be treated as a separate sentence to bias towards."
                )
                bias_text_file = gr.File(
                    label="Upload a file with multiple lines of text",
                    file_types=[".txt", ".md", ".csv"],
                )

        gr.Markdown("## Step 4: Set how much you want to bias towards the LM")
        bias_amount = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            label="Bias strength",
            interactive=True,
        )

        gr.Markdown("## Step 5: Get your transcription before and after biasing")
        transcribe_button = gr.Button("Transcribe")

        with gr.Row():
            with gr.Column():
                output = gr.Text(label="Output")
            with gr.Column():
                biased_output = gr.Text(label="Biased output")

        # Input order must match transcribe's positional parameters:
        # (processor_name, audio, bias_strength, bias_text, bias_text_file).
        transcribe_button.click(
            fn=transcribe,
            inputs=[
                model_name,
                audio_clip,
                bias_amount,
                bias_text,
                bias_text_file,
            ],
            outputs=[output, biased_output],
        )
    if launch:
        demo.launch()
    return demo


if __name__ == "__main__":
    # Script entry point: build the demo UI and start the Gradio server.
    setup_gradio_demo()