KDM999 committed
Commit a635c25 · verified · 1 Parent(s): b21ecd7

Update app.py

Files changed (1)
  1. app.py +95 -50
app.py CHANGED
@@ -1,31 +1,38 @@
  import gradio as gr
  import random
  import json
- import os
  from difflib import SequenceMatcher
  from jiwer import wer
  import torchaudio
  from transformers import pipeline
+ import os
+ import string

  # Load metadata
  with open("common_voice_en_validated_249_hf_ready.json") as f:
      data = json.load(f)

- # Available filter values
+ # Prepare dropdown options
  ages = sorted(set(entry["age"] for entry in data))
  genders = sorted(set(entry["gender"] for entry in data))
  accents = sorted(set(entry["accent"] for entry in data))

- # Load pipelines
- device = 0  # 0 for CUDA/GPU, -1 for CPU
-
- pipe_whisper = pipeline("automatic-speech-recognition", model="openai/whisper-medium", device=device)
- pipe_wav2vec2 = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=device)
- pipe_hubert = pipeline("automatic-speech-recognition", model="facebook/hubert-base-ls960", device=device)
-
- def load_audio(file_path):
-     waveform, sr = torchaudio.load(file_path)
-     return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0].numpy()
+ # Load ASR pipelines
+ device = 0
+ pipe_whisper_medium = pipeline("automatic-speech-recognition", model="openai/whisper-medium", device=device, generate_kwargs={"language": "en"})
+ pipe_whisper_base = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device, generate_kwargs={"language": "en"})
+ pipe_whisper_tiny = pipeline("automatic-speech-recognition", model="openai/whisper-tiny", device=device, generate_kwargs={"language": "en"})
+ pipe_wav2vec2_base_960h = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=device)
+ pipe_hubert_large_ls960_ft = pipeline("automatic-speech-recognition", model="facebook/hubert-large-ls960-ft", device=device)
+
+ # Functions
+ def convert_to_wav(file_path):
+     wav_path = file_path.replace(".mp3", ".wav")
+     if not os.path.exists(wav_path):
+         waveform, sample_rate = torchaudio.load(file_path)
+         waveform = waveform.mean(dim=0, keepdim=True)
+         torchaudio.save(wav_path, waveform, sample_rate)
+     return wav_path

  def transcribe(pipe, file_path):
      result = pipe(file_path)
@@ -35,62 +42,100 @@ def highlight_differences(ref, hyp):
      sm = SequenceMatcher(None, ref.split(), hyp.split())
      result = []
      for opcode, i1, i2, j1, j2 in sm.get_opcodes():
-         if opcode == 'equal':
+         if opcode == "equal":
              result.extend(hyp.split()[j1:j2])
-         elif opcode in ('replace', 'insert', 'delete'):
+         else:
              wrong = hyp.split()[j1:j2]
              result.extend([f"<span style='color:red'>{w}</span>" for w in wrong])
      return " ".join(result)

- def run_demo(age, gender, accent):
+ def normalize(text):
+     text = text.lower()
+     text = text.translate(str.maketrans('', '', string.punctuation))
+     return text.strip()
+
+ # Generate Audio
+ def generate_audio(age, gender, accent):
      filtered = [
          entry for entry in data
          if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
      ]
      if not filtered:
-         return "No matching sample.", None, "", "", "", "", "", ""
-
+         return None, "No matching sample."
      sample = random.choice(filtered)
      file_path = os.path.join("common_voice_en_validated_249", sample["path"])
-     gold = sample["sentence"].strip().lower()
-
-     whisper_text = transcribe(pipe_whisper, file_path)
-     wav2vec_text = transcribe(pipe_wav2vec2, file_path)
-     hubert_text = transcribe(pipe_hubert, file_path)
-
-     table = f"""
-     <table border="1" style="width:100%">
-     <tr><th>Model</th><th>Transcription</th><th>WER</th></tr>
-     <tr><td><b>Gold</b></td><td>{gold}</td><td>0.00</td></tr>
-     <tr><td>Whisper</td><td>{highlight_differences(gold, whisper_text)}</td><td>{wer(gold, whisper_text):.2f}</td></tr>
-     <tr><td>Wav2Vec2</td><td>{highlight_differences(gold, wav2vec_text)}</td><td>{wer(gold, wav2vec_text):.2f}</td></tr>
-     <tr><td>HuBERT</td><td>{highlight_differences(gold, hubert_text)}</td><td>{wer(gold, hubert_text):.2f}</td></tr>
-     </table>
-     """
-
-     return sample["sentence"], file_path, gold, whisper_text, wav2vec_text, hubert_text, table, f"Audio path: {file_path}"
-
+     wav_file_path = convert_to_wav(file_path)
+     return wav_file_path, wav_file_path
+
+ # Transcribe & Compare
+ def transcribe_audio(file_path):
+     if not file_path:
+         return "No file selected.", "", "", "", "", ""
+
+     filename_mp3 = os.path.basename(file_path).replace(".wav", ".mp3")
+     gold = ""
+     for entry in data:
+         if entry["path"].endswith(filename_mp3):
+             gold = normalize(entry["sentence"])
+             break
+     if not gold:
+         return "Reference not found.", "", "", "", "", ""
+
+     outputs = {}
+     models = {
+         "openai/whisper-medium": pipe_whisper_medium,
+         "openai/whisper-base": pipe_whisper_base,
+         "openai/whisper-tiny": pipe_whisper_tiny,
+         "facebook/wav2vec2-base-960h": pipe_wav2vec2_base_960h,
+         "facebook/hubert-large-ls960-ft": pipe_hubert_large_ls960_ft,
+     }
+
+     for name, model in models.items():
+         text = transcribe(model, file_path)
+         clean = normalize(text)
+         wer_score = wer(gold, clean)
+         outputs[name] = f"<b>{name} (WER: {wer_score:.2f}):</b><br>{highlight_differences(gold, clean)}"
+
+     return (gold, *outputs.values())
+
+ # Gradio Interface
  with gr.Blocks() as demo:
-     gr.Markdown("# ASR Model Comparison on ESL Audio")
-     gr.Markdown("Filter by age, gender, and accent. Then generate a random ESL learner's audio to compare how Whisper, Wav2Vec2, and HuBERT transcribe it.")
+     gr.Markdown("# Comparing ASR Models on Diverse English Speech Samples")
+     gr.Markdown(
+         "This demo compares the transcription performance of five automatic speech recognition (ASR) models on audio samples from English learners. "
+         "Users can select speaker metadata (age, gender, accent) to explore how models handle diverse speech profiles. "
+         "All samples are drawn from the validated subset (n=249) of the English dataset in the Common Voice Delta Segment 21.0 release.")

      with gr.Row():
          age = gr.Dropdown(choices=ages, label="Age")
          gender = gr.Dropdown(choices=genders, label="Gender")
          accent = gr.Dropdown(choices=accents, label="Accent")

-     btn = gr.Button("Generate and Transcribe")
-     audio = gr.Audio(label="Audio", type="filepath")
-     wer_output = gr.HTML()
-
-     btn.click(fn=run_demo, inputs=[age, gender, accent], outputs=[
-         gr.Textbox(label="Gold (Correct)"),
-         audio,
-         gr.Textbox(label="Whisper Output"),
-         gr.Textbox(label="Wav2Vec2 Output"),
-         gr.Textbox(label="HuBERT Output"),
-         wer_output,
-         gr.Textbox(label="Path")
-     ])
+     generate_btn = gr.Button("Get Audio")
+     audio_output = gr.Audio(label="Audio", type="filepath", interactive=False)
+     file_path_output = gr.Textbox(label="Audio File Path", visible=False)
+
+     generate_btn.click(generate_audio, [age, gender, accent], [audio_output, file_path_output])
+
+     transcribe_btn = gr.Button("Transcribe with All Models")
+     gold_text = gr.Textbox(label="Reference (Gold Standard)")
+     whisper_medium_html = gr.HTML(label="Whisper Medium")
+     whisper_base_html = gr.HTML(label="Whisper Base")
+     whisper_tiny_html = gr.HTML(label="Whisper Tiny")
+     wav2vec_html = gr.HTML(label="Wav2Vec2 Base")
+     hubert_html = gr.HTML(label="HuBERT Large")
+
+     transcribe_btn.click(
+         transcribe_audio,
+         inputs=[file_path_output],
+         outputs=[
+             gold_text,
+             whisper_medium_html,
+             whisper_base_html,
+             whisper_tiny_html,
+             wav2vec_html,
+             hubert_html,
+         ],
+     )

  demo.launch()