KDM999 committed on
Commit
1f97be9
·
verified ·
1 Parent(s): da0f868

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -42
app.py CHANGED
@@ -1,15 +1,13 @@
1
  import gradio as gr
2
  import spaces
3
- import torch
4
- from accelerate import init_empty_weights
5
  import random
6
  import json
 
 
7
  from difflib import SequenceMatcher
8
  from jiwer import wer
9
  import torchaudio
10
  from transformers import pipeline
11
- import os
12
- import string
13
 
14
  # Load metadata
15
  with open("common_voice_en_validated_249_hf_ready.json") as f:
@@ -20,19 +18,7 @@ ages = sorted(set(entry["age"] for entry in data))
20
  genders = sorted(set(entry["gender"] for entry in data))
21
  accents = sorted(set(entry["accent"] for entry in data))
22
 
23
- # Load ASR pipelines
24
- pipe_whisper_tiny = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
25
- pipe_whisper_tiny_en = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")
26
- pipe_whisper_base = pipeline("automatic-speech-recognition", model="openai/whisper-base")
27
- pipe_whisper_base_en = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
28
- pipe_whisper_medium = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
29
- pipe_whisper_medium_en = pipeline("automatic-speech-recognition", model="openai/whisper-medium.en")
30
- pipe_distil_whisper_large = pipeline("automatic-speech-recognition", model="distil-whisper/distil-large-v3.5")
31
- pipe_wav2vec2_base_960h = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
32
- pipe_wav2vec2_large_960h = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-large-960h")
33
- pipe_hubert_large_ls960_ft = pipeline("automatic-speech-recognition", model="facebook/hubert-large-ls960-ft")
34
-
35
- # Functions
36
  def convert_to_wav(file_path):
37
  wav_path = file_path.replace(".mp3", ".wav")
38
  if not os.path.exists(wav_path):
@@ -41,10 +27,6 @@ def convert_to_wav(file_path):
41
  torchaudio.save(wav_path, waveform, sample_rate)
42
  return wav_path
43
 
44
def transcribe(pipe, file_path):
    """Run an ASR pipeline on the given audio file.

    Args:
        pipe: a transformers automatic-speech-recognition pipeline (any
            callable returning a dict with a "text" key works).
        file_path: path to the audio file to transcribe.

    Returns:
        The transcript with surrounding whitespace removed and lowercased.
    """
    output = pipe(file_path)
    text = output["text"]
    return text.strip().lower()
47
-
48
  def highlight_differences(ref, hyp):
49
  sm = SequenceMatcher(None, ref.split(), hyp.split())
50
  result = []
@@ -74,7 +56,7 @@ def generate_audio(age, gender, accent):
74
  wav_file_path = convert_to_wav(file_path)
75
  return wav_file_path, wav_file_path
76
 
77
- # Transcribe & Compare
78
  @spaces.GPU
79
  def transcribe_audio(file_path):
80
  if not file_path:
@@ -89,29 +71,33 @@ def transcribe_audio(file_path):
89
  if not gold:
90
  return "Reference not found.", "", "", "", "", "", ""
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  outputs = {}
93
- models = {
94
- "openai/whisper-tiny": pipe_whisper_tiny,
95
- "openai/whisper-tiny.en": pipe_whisper_tiny_en,
96
- "openai/whisper-base": pipe_whisper_base,
97
- "openai/whisper-base.en": pipe_whisper_base_en,
98
- "openai/whisper-medium": pipe_whisper_medium,
99
- "openai/whisper-medium.en": pipe_whisper_medium_en,
100
- "distil-whisper/distil-large-v3.5": pipe_distil_whisper_large,
101
- "facebook/wav2vec2-base-960h": pipe_wav2vec2_base_960h,
102
- "facebook/wav2vec2-large-960h": pipe_wav2vec2_large_960h,
103
- "facebook/hubert-large-ls960-ft": pipe_hubert_large_ls960_ft,
104
- }
105
-
106
- for name, model in models.items():
107
- text = transcribe(model, file_path)
108
- clean = normalize(text)
109
- wer_score = wer(gold, clean)
110
- outputs[name] = f"<b>{name} (WER: {wer_score:.2f}):</b><br>{highlight_differences(gold, clean)}"
111
 
112
  return (gold, *outputs.values())
113
 
114
- # Gradio Interface
115
  with gr.Blocks() as demo:
116
  gr.Markdown("# Comparing ASR Models on Diverse English Speech Samples")
117
  gr.Markdown("""
@@ -119,7 +105,7 @@ with gr.Blocks() as demo:
119
  Users can select age, gender, and accent to generate diverse English audio samples.
120
  The models are evaluated on their ability to transcribe those samples.
121
  Data is sourced from 249 validated entries in the Common Voice English Delta Segment 21.0 release.
122
- """)
123
 
124
  with gr.Row():
125
  age = gr.Dropdown(choices=ages, label="Age")
 
1
  import gradio as gr
2
  import spaces
 
 
3
  import random
4
  import json
5
+ import os
6
+ import string
7
  from difflib import SequenceMatcher
8
  from jiwer import wer
9
  import torchaudio
10
  from transformers import pipeline
 
 
11
 
12
  # Load metadata
13
  with open("common_voice_en_validated_249_hf_ready.json") as f:
 
18
  genders = sorted(set(entry["gender"] for entry in data))
19
  accents = sorted(set(entry["accent"] for entry in data))
20
 
21
+ # Utility functions
 
 
 
 
 
 
 
 
 
 
 
 
22
  def convert_to_wav(file_path):
23
  wav_path = file_path.replace(".mp3", ".wav")
24
  if not os.path.exists(wav_path):
 
27
  torchaudio.save(wav_path, waveform, sample_rate)
28
  return wav_path
29
 
 
 
 
 
30
  def highlight_differences(ref, hyp):
31
  sm = SequenceMatcher(None, ref.split(), hyp.split())
32
  result = []
 
56
  wav_file_path = convert_to_wav(file_path)
57
  return wav_file_path, wav_file_path
58
 
59
+ # Transcribe & Compare (GPU Decorated)
60
  @spaces.GPU
61
  def transcribe_audio(file_path):
62
  if not file_path:
 
71
  if not gold:
72
  return "Reference not found.", "", "", "", "", "", ""
73
 
74
+ model_ids = [
75
+ "openai/whisper-tiny",
76
+ "openai/whisper-tiny.en",
77
+ "openai/whisper-base",
78
+ "openai/whisper-base.en",
79
+ "openai/whisper-medium",
80
+ "openai/whisper-medium.en",
81
+ "distil-whisper/distil-large-v3.5",
82
+ "facebook/wav2vec2-base-960h",
83
+ "facebook/wav2vec2-large-960h",
84
+ "facebook/hubert-large-ls960-ft",
85
+ ]
86
+
87
  outputs = {}
88
+ for model_id in model_ids:
89
+ try:
90
+ pipe = pipeline("automatic-speech-recognition", model=model_id)
91
+ text = pipe(file_path)["text"].strip().lower()
92
+ clean = normalize(text)
93
+ wer_score = wer(gold, clean)
94
+ outputs[model_id] = f"<b>{model_id} (WER: {wer_score:.2f}):</b><br>{highlight_differences(gold, clean)}"
95
+ except Exception as e:
96
+ outputs[model_id] = f"<b>{model_id}:</b><br><span style='color:red'>Error: {str(e)}</span>"
 
 
 
 
 
 
 
 
 
97
 
98
  return (gold, *outputs.values())
99
 
100
+ # Gradio UI
101
  with gr.Blocks() as demo:
102
  gr.Markdown("# Comparing ASR Models on Diverse English Speech Samples")
103
  gr.Markdown("""
 
105
  Users can select age, gender, and accent to generate diverse English audio samples.
106
  The models are evaluated on their ability to transcribe those samples.
107
  Data is sourced from 249 validated entries in the Common Voice English Delta Segment 21.0 release.
108
+ """)
109
 
110
  with gr.Row():
111
  age = gr.Dropdown(choices=ages, label="Age")