Update app.py
app.py CHANGED
@@ -8,6 +8,8 @@ import difflib
 import editdistance
 from jiwer import wer
 import json
+import string
+import eng_to_ipa as ipa
 
 # Load both models at startup
 MODELS = {
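For context on the new eng_to_ipa dependency: its convert() call does the English grapheme-to-phoneme lookup against the CMU dictionary and flags words it cannot find with a trailing asterisk. A rough standalone check (not part of this commit):

import eng_to_ipa as ipa

print(ipa.convert("hello world"))   # e.g. "hɛˈloʊ wɝld"; exact output depends on the dictionary lookup
print(ipa.convert("qwzx"))          # unknown words come back flagged, e.g. "qwzx*"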
@@ -19,7 +21,7 @@ MODELS = {
     "English": {
         "processor": Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h"),
         "model": Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h"),
-        "epitran":
+        "epitran": None  # Not needed; using eng_to_ipa
     }
 }
 
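The English entry drops its epitran backend; the Arabic entry (outside this hunk) presumably still holds an Epitran instance. For reference, a typical Arabic setup looks roughly like the following, with the language/script code assumed rather than taken from this diff:

import epitran

epi = epitran.Epitran("ara-Arab")   # assumed code: Arabic in Arabic script
print(epi.transliterate("كتاب"))     # IPA-ish output, roughly "ktab" for unvocalized text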
@@ -27,9 +29,29 @@ MODELS = {
 for lang in MODELS.values():
     lang["model"].config.ctc_loss_reduction = "mean"
 
-def clean_phonemes(
+def clean_phonemes(ipa_text):
     """Remove diacritics and length markers from phonemes"""
-    return re.sub(r'[\u064B-\u0652\u02D0]', '',
+    return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
+
+def safe_transliterate_arabic(epi, word):
+    try:
+        word = word.strip()
+        ipa = epi.transliterate(word)
+        if not ipa.strip():
+            raise ValueError("Empty IPA string")
+        return clean_phonemes(ipa)
+    except Exception as e:
+        print(f"[Warning] Arabic transliteration failed for '{word}': {e}")
+        return ""
+
+def transliterate_english(word):
+    try:
+        word = word.lower().translate(str.maketrans('', '', string.punctuation))
+        ipa_text = ipa.convert(word)
+        return clean_phonemes(ipa_text)
+    except Exception as e:
+        print(f"[Warning] English IPA conversion failed for '{word}': {e}")
+        return ""
 
 def analyze_phonemes(language, reference_text, audio_file):
     # Get the appropriate model components
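clean_phonemes strips the Arabic short-vowel diacritics (U+064B through U+0652, fathatan to sukun) and the IPA length mark ː (U+02D0), presumably so that vowelled/unvowelled spellings and long/short phonemes do not count as mismatches. A quick sanity check of the same regex:

import re

def clean_phonemes(ipa_text):
    return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)

print(clean_phonemes("kitaːb"))   # -> "kitab" (length mark removed)
print(clean_phonemes("كَتَبَ"))     # -> "كتب" (harakat removed)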
@@ -37,31 +59,34 @@ def analyze_phonemes(language, reference_text, audio_file):
     processor = lang_models["processor"]
     model = lang_models["model"]
     epi = lang_models["epitran"]
 
+    if language == "Arabic":
+        transliterate_fn = lambda word: safe_transliterate_arabic(epi, word)
+    else:
+        transliterate_fn = transliterate_english
+
     # Convert reference text to phonemes
     ref_phonemes = []
     for word in reference_text.split():
-        ipa = epi.transliterate(word)
-        ipa_clean = clean_phonemes(ipa)
+        ipa_clean = transliterate_fn(word)
         ref_phonemes.append(list(ipa_clean))
 
     # Process audio file
     audio, sr = librosa.load(audio_file, sr=16000)
     input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
 
     # Get transcription
     with torch.no_grad():
         logits = model(input_values).logits
     pred_ids = torch.argmax(logits, dim=-1)
     transcription = processor.batch_decode(pred_ids)[0].strip()
 
     # Convert transcription to phonemes
     obs_phonemes = []
     for word in transcription.split():
-        ipa = epi.transliterate(word)
-        ipa_clean = clean_phonemes(ipa)
+        ipa_clean = transliterate_fn(word)
         obs_phonemes.append(list(ipa_clean))
 
     # Prepare results in JSON format
     results = {
         "language": language,
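Each word ends up as a plain list of IPA characters via list(ipa_clean); that list is the unit the later edit-distance comparison operates on. For example (hypothetical phoneme string):

ipa_clean = "kæt"            # hypothetical output of transliterate_fn("cat")
print(list(ipa_clean))       # -> ['k', 'æ', 't']

Note that multi-character IPA symbols such as tʃ are split into two items under this scheme.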
@@ -70,20 +95,20 @@
         "word_alignment": [],
         "metrics": {}
     }
 
     # Calculate metrics
     total_phoneme_errors = 0
     total_phoneme_length = 0
     correct_words = 0
     total_word_length = len(ref_phonemes)
 
     # Word-by-word alignment
     for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
         ref_str = ''.join(ref)
         obs_str = ''.join(obs)
         edits = editdistance.eval(ref, obs)
         acc = round((1 - edits / max(1, len(ref))) * 100, 2)
 
         # Get error details
         matcher = difflib.SequenceMatcher(None, ref, obs)
         ops = matcher.get_opcodes()
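Per-word accuracy is one minus the edit distance normalized by the reference length. A worked example using the same calls:

import editdistance

ref = list("kæt")                                     # reference phonemes
obs = list("kat")                                     # observed phonemes
edits = editdistance.eval(ref, obs)                   # -> 1 (one substitution)
acc = round((1 - edits / max(1, len(ref))) * 100, 2)
print(edits, acc)                                     # -> 1 66.67

Also worth noting: because the loop pairs words with zip(), any surplus words on either side (reference longer or shorter than the transcription) are silently left out of the alignment.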
@@ -97,7 +122,7 @@
                 "reference": ref_seg,
                 "observed": obs_seg
             })
 
         results["word_alignment"].append({
             "word_index": i,
             "reference_phonemes": ref_str,
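The error details come from difflib.SequenceMatcher opcodes; the loop that slices ref_seg and obs_seg sits outside this hunk, but it presumably records every non-'equal' opcode, along these lines:

import difflib

ref, obs = list("kæt"), list("kat")
for tag, i1, i2, j1, j2 in difflib.SequenceMatcher(None, ref, obs).get_opcodes():
    if tag != "equal":
        print(tag, ''.join(ref[i1:i2]), "->", ''.join(obs[j1:j2]))
# prints: replace æ -> a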
@@ -107,18 +132,18 @@
             "is_correct": edits == 0,
             "errors": error_details
         })
 
         total_phoneme_errors += edits
         total_phoneme_length += len(ref)
         correct_words += 1 if edits == 0 else 0
 
-    #
+    # Final metrics
     phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
     phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
     word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
     word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
     text_wer = round(wer(reference_text, transcription) * 100, 2)
 
     results["metrics"] = {
         "word_accuracy": word_acc,
         "word_error_rate": word_er,
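text_wer is the plain word-level error rate of the raw ASR transcription against the reference text, via jiwer. For example:

from jiwer import wer

print(wer("the cat sat", "the cat sit"))   # -> 0.333... (1 of 3 reference words wrong)

The phoneme and word rates above are simple complements, accuracy = 100 - error rate, with max(1, ...) guarding against division by zero on empty input.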
@@ -126,7 +151,7 @@
         "phoneme_error_rate": phoneme_er,
         "asr_word_error_rate": text_wer
     }
 
     return json.dumps(results, indent=2, ensure_ascii=False)
 
 # Create Gradio interface with language-specific default text
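The ensure_ascii=False flag matters here: it keeps IPA and Arabic characters readable in the returned JSON instead of escaping them. Compare:

import json

print(json.dumps({"phonemes": "kitaːb"}, ensure_ascii=False))   # {"phonemes": "kitaːb"}
print(json.dumps({"phonemes": "kitaːb"}))                       # {"phonemes": "kita\u02d0b"}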
@@ -139,7 +164,7 @@ def get_default_text(language):
 with gr.Blocks() as demo:
     gr.Markdown("# Multilingual Phoneme Alignment Analysis")
     gr.Markdown("Compare audio pronunciation with reference text at phoneme level")
 
     with gr.Row():
         language = gr.Dropdown(
             ["Arabic", "English"],
@@ -150,22 +175,21 @@ with gr.Blocks() as demo:
         label="Reference Text",
         value=get_default_text("Arabic")
     )
 
     audio_input = gr.Audio(label="Upload Audio File", type="filepath")
     submit_btn = gr.Button("Analyze")
     output = gr.JSON(label="Phoneme Alignment Results")
 
-    # Update default text when language changes
     language.change(
         fn=get_default_text,
         inputs=language,
         outputs=reference_text
     )
 
     submit_btn.click(
         fn=analyze_phonemes,
         inputs=[language, reference_text, audio_input],
         outputs=output
     )
 
 demo.launch()
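The UI wiring is the standard Blocks pattern: the dropdown's change event swaps in a language-appropriate default sentence, and the button's click event runs the analysis. Stripped to a minimal standalone sketch (placeholder defaults, not from this commit):

import gradio as gr

def get_default_text(language):
    # placeholder defaults for illustration only
    return "مرحبا بالعالم" if language == "Arabic" else "Hello world"

with gr.Blocks() as demo:
    language = gr.Dropdown(["Arabic", "English"], value="Arabic", label="Language")
    reference_text = gr.Textbox(label="Reference Text", value=get_default_text("Arabic"))
    language.change(fn=get_default_text, inputs=language, outputs=reference_text)

demo.launch()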