101Frost committed
Commit 1acef58 · verified · 1 parent: df2e373

Update app.py

Files changed (1): app.py (+51 -27)
app.py CHANGED

@@ -8,6 +8,8 @@ import difflib
 import editdistance
 from jiwer import wer
 import json
+import string
+import eng_to_ipa as ipa
 
 # Load both models at startup
 MODELS = {
@@ -19,7 +21,7 @@ MODELS = {
     "English": {
         "processor": Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h"),
         "model": Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h"),
-        "epitran": epitran.Epitran("eng-Latn")
+        "epitran": None  # Not needed; using eng_to_ipa
     }
 }
 
@@ -27,9 +29,29 @@ MODELS = {
 for lang in MODELS.values():
     lang["model"].config.ctc_loss_reduction = "mean"
 
-def clean_phonemes(ipa):
+def clean_phonemes(ipa_text):
     """Remove diacritics and length markers from phonemes"""
-    return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa)
+    return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
+
+def safe_transliterate_arabic(epi, word):
+    try:
+        word = word.strip()
+        ipa = epi.transliterate(word)
+        if not ipa.strip():
+            raise ValueError("Empty IPA string")
+        return clean_phonemes(ipa)
+    except Exception as e:
+        print(f"[Warning] Arabic transliteration failed for '{word}': {e}")
+        return ""
+
+def transliterate_english(word):
+    try:
+        word = word.lower().translate(str.maketrans('', '', string.punctuation))
+        ipa_text = ipa.convert(word)
+        return clean_phonemes(ipa_text)
+    except Exception as e:
+        print(f"[Warning] English IPA conversion failed for '{word}': {e}")
+        return ""
 
 def analyze_phonemes(language, reference_text, audio_file):
     # Get the appropriate model components
@@ -37,31 +59,34 @@ def analyze_phonemes(language, reference_text, audio_file):
     processor = lang_models["processor"]
     model = lang_models["model"]
     epi = lang_models["epitran"]
-
+
+    if language == "Arabic":
+        transliterate_fn = lambda word: safe_transliterate_arabic(epi, word)
+    else:
+        transliterate_fn = transliterate_english
+
     # Convert reference text to phonemes
     ref_phonemes = []
     for word in reference_text.split():
-        ipa = epi.transliterate(word)
-        ipa_clean = clean_phonemes(ipa)
+        ipa_clean = transliterate_fn(word)
         ref_phonemes.append(list(ipa_clean))
-
+
     # Process audio file
     audio, sr = librosa.load(audio_file, sr=16000)
     input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
-
+
     # Get transcription
     with torch.no_grad():
         logits = model(input_values).logits
     pred_ids = torch.argmax(logits, dim=-1)
     transcription = processor.batch_decode(pred_ids)[0].strip()
-
+
     # Convert transcription to phonemes
     obs_phonemes = []
     for word in transcription.split():
-        ipa = epi.transliterate(word)
-        ipa_clean = clean_phonemes(ipa)
+        ipa_clean = transliterate_fn(word)
         obs_phonemes.append(list(ipa_clean))
-
+
     # Prepare results in JSON format
     results = {
         "language": language,
@@ -70,20 +95,20 @@ def analyze_phonemes(language, reference_text, audio_file):
         "word_alignment": [],
         "metrics": {}
     }
-
+
     # Calculate metrics
     total_phoneme_errors = 0
     total_phoneme_length = 0
     correct_words = 0
     total_word_length = len(ref_phonemes)
-
+
     # Word-by-word alignment
     for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
         ref_str = ''.join(ref)
         obs_str = ''.join(obs)
         edits = editdistance.eval(ref, obs)
         acc = round((1 - edits / max(1, len(ref))) * 100, 2)
-
+
         # Get error details
         matcher = difflib.SequenceMatcher(None, ref, obs)
         ops = matcher.get_opcodes()
@@ -97,7 +122,7 @@ def analyze_phonemes(language, reference_text, audio_file):
                 "reference": ref_seg,
                 "observed": obs_seg
             })
-
+
         results["word_alignment"].append({
             "word_index": i,
             "reference_phonemes": ref_str,
@@ -107,18 +132,18 @@ def analyze_phonemes(language, reference_text, audio_file):
             "is_correct": edits == 0,
             "errors": error_details
         })
-
+
         total_phoneme_errors += edits
         total_phoneme_length += len(ref)
        correct_words += 1 if edits == 0 else 0
-
-    # Calculate metrics
+
+    # Final metrics
     phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
     phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
     word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
     word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
     text_wer = round(wer(reference_text, transcription) * 100, 2)
-
+
     results["metrics"] = {
         "word_accuracy": word_acc,
         "word_error_rate": word_er,
@@ -126,7 +151,7 @@ def analyze_phonemes(language, reference_text, audio_file):
         "phoneme_error_rate": phoneme_er,
         "asr_word_error_rate": text_wer
     }
-
+
     return json.dumps(results, indent=2, ensure_ascii=False)
 
 # Create Gradio interface with language-specific default text
@@ -139,7 +164,7 @@ def get_default_text(language):
 with gr.Blocks() as demo:
     gr.Markdown("# Multilingual Phoneme Alignment Analysis")
     gr.Markdown("Compare audio pronunciation with reference text at phoneme level")
-
+
     with gr.Row():
         language = gr.Dropdown(
            ["Arabic", "English"],
@@ -150,22 +175,21 @@ with gr.Blocks() as demo:
             label="Reference Text",
             value=get_default_text("Arabic")
         )
-
+
     audio_input = gr.Audio(label="Upload Audio File", type="filepath")
     submit_btn = gr.Button("Analyze")
     output = gr.JSON(label="Phoneme Alignment Results")
-
-    # Update default text when language changes
+
     language.change(
         fn=get_default_text,
        inputs=language,
        outputs=reference_text
    )
-
+
     submit_btn.click(
         fn=analyze_phonemes,
         inputs=[language, reference_text, audio_input],
         outputs=output
     )
 
-demo.launch()
+demo.launch()
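
Review note: the commit swaps epitran for eng_to_ipa on the English side (epitran's "eng-Latn" mode depends on the external Flite lex_lookup binary, which is presumably why it was dropped), keeps epitran for Arabic, and wraps both paths in error-tolerant helpers. A minimal sketch of the two paths, assuming `pip install epitran eng_to_ipa` and assuming the Arabic MODELS entry (outside this diff's context lines) uses the "ara-Arab" code:

import re
import string

import epitran
import eng_to_ipa as ipa

def clean_phonemes(ipa_text):
    # Strip Arabic diacritics (U+064B-U+0652) and the IPA length mark (U+02D0)
    return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)

epi = epitran.Epitran("ara-Arab")  # assumed language code; not shown in the diff
print(clean_phonemes(epi.transliterate("سلام")))  # Arabic -> IPA via epitran

word = "Hello!".lower().translate(str.maketrans('', '', string.punctuation))
print(clean_phonemes(ipa.convert(word)))  # English -> IPA via eng_to_ipa

One caveat: eng_to_ipa.convert marks words missing from its CMU-based dictionary with a trailing "*" rather than raising, so the new helpers' empty-string fallback will not catch unknown English words.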
 
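For reference, this is the per-word scoring the diff applies, run on an invented toy pair (the editdistance and difflib calls match the code above; the phoneme strings are hypothetical):

import difflib
import editdistance

ref = list("kæt")  # reference phonemes for one word (hypothetical)
obs = list("kat")  # observed phonemes from the transcription

edits = editdistance.eval(ref, obs)                   # 1 substitution
acc = round((1 - edits / max(1, len(ref))) * 100, 2)  # 66.67
print(edits, acc)

# SequenceMatcher opcodes drive the error_details entries
for op, i1, i2, j1, j2 in difflib.SequenceMatcher(None, ref, obs).get_opcodes():
    if op != "equal":
        print(op, ''.join(ref[i1:i2]), "->", ''.join(obs[j1:j2]))

Aggregated over all words, phoneme_error_rate is total edits over total reference phonemes and word_error_rate counts words with any edit, both guarded against empty references by max(1, ...). Note also that zip() truncates to the shorter list, so reference words the ASR never produced count against the word-level rates but never enter the phoneme totals.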