|
import gradio as gr |
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
|
import librosa |
|
import torch |
|
import epitran |
|
import re |
|
import difflib |
|
import editdistance |
|
from jiwer import wer |
|
import json |
|
import string |
|
import eng_to_ipa as ipa |
|
|
|
|
|
# Per-language resources, loaded once at import time. Each entry holds the
# Wav2Vec2 processor and CTC model for one checkpoint plus an Epitran
# transliterator for the matching language/script code.
MODELS = {
    name: {
        "processor": Wav2Vec2Processor.from_pretrained(checkpoint),
        "model": Wav2Vec2ForCTC.from_pretrained(checkpoint),
        "epitran": epitran.Epitran(epitran_code),
    }
    for name, checkpoint, epitran_code in (
        ("Arabic", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", "ara-Arab"),
        ("English", "jonatasgrosman/wav2vec2-large-xlsr-53-english", "eng-Latn"),
    )
}

# Set the CTC loss reduction mode on every loaded model's config.
# NOTE(review): presumably this only matters when a loss is computed, which
# the inference code in this file never does - confirm before removing.
for lang in MODELS.values():
    lang["model"].config.ctc_loss_reduction = "mean"
|
|
|
# Characters dropped from IPA output: U+064B..U+0652 (Arabic combining marks)
# and U+02D0 (the IPA length mark).
_STRIPPED_MARKS = re.compile(r'[\u064B-\u0652\u02D0]')


def clean_phonemes(ipa_text):
    """Return *ipa_text* with Arabic diacritics and IPA length marks removed."""
    return _STRIPPED_MARKS.sub('', ipa_text)
|
|
|
def safe_transliterate_arabic(epi, word):
    """Transliterate one Arabic word to cleaned IPA, never raising.

    Args:
        epi: Object exposing ``transliterate(word)`` (an Epitran instance
            in this file).
        word: The word to convert; surrounding whitespace is stripped first.

    Returns:
        The transliteration passed through ``clean_phonemes``, or ``""`` when
        anything fails or the transliteration is blank. Failures are reported
        with a printed warning.
    """
    try:
        word = word.strip()
        transliterated = epi.transliterate(word)
        if transliterated.strip():
            return clean_phonemes(transliterated)
        # A blank result is routed through the same warning path as an error.
        raise ValueError("Empty IPA string")
    except Exception as e:
        print(f"[Warning] Arabic transliteration failed for '{word}': {e}")
    return ""
|
|
|
def transliterate_english(word):
    """Convert one English word to cleaned IPA, never raising.

    The word is lower-cased and stripped of ASCII punctuation before being
    handed to ``ipa.convert``; the result is passed through
    ``clean_phonemes``.

    Returns:
        The cleaned IPA string, or ``""`` (after printing a warning) when any
        step raises.
    """
    try:
        punctuation_remover = str.maketrans('', '', string.punctuation)
        # Rebinding ``word`` keeps the normalised form in the warning below.
        word = word.lower().translate(punctuation_remover)
        return clean_phonemes(ipa.convert(word))
    except Exception as e:
        print(f"[Warning] English IPA conversion failed for '{word}': {e}")
        return ""
|
|
|
def _phonemes_by_word(text, transliterate_fn):
    """Return one list of IPA characters per whitespace-separated word of *text*."""
    return [list(transliterate_fn(word)) for word in text.split()]


def _transcribe(processor, model, audio_file):
    """Transcribe *audio_file* by taking the argmax token of the model logits."""
    # The audio is loaded at 16 kHz and the same rate is declared to the
    # processor.
    audio, _ = librosa.load(audio_file, sr=16000)
    input_values = processor(
        audio, sampling_rate=16000, return_tensors="pt"
    ).input_values

    with torch.no_grad():
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(pred_ids)[0].strip()


def _describe_errors(ref, obs):
    """List every non-equal difflib opcode between two phoneme lists."""
    matcher = difflib.SequenceMatcher(None, ref, obs)
    return [
        {
            "type": tag.upper(),
            # '-' stands in for an empty side (pure insertion or deletion).
            "reference": ''.join(ref[i1:i2]) or '-',
            "observed": ''.join(obs[j1:j2]) or '-',
        }
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
        if tag != 'equal'
    ]


def analyze_phonemes(language, reference_text, audio_file):
    """Compare the phonemes of a recording with those of a reference text.

    The reference text and the ASR transcription of the recording are each
    split on whitespace, transliterated word by word, and paired by position.
    Words left without a partner (the speaker skipped a word, or the
    transcription contains extra words) are still reported: they are paired
    with an empty phoneme list, marked incorrect with 0% accuracy, and their
    edit distance is added to the phoneme error total.

    Args:
        language: Key into ``MODELS``; "Arabic" selects the Epitran-based
            transliterator, anything else the English one.
        reference_text: Text the recording is compared against.
        audio_file: Path of the audio file to transcribe.

    Returns:
        A JSON string with the keys ``language``, ``reference_text``,
        ``transcription``, ``word_alignment`` (one entry per word position)
        and ``metrics``. Accuracies are clamped to a minimum of 0; error
        rates are not clamped and may exceed 100.

    Raises:
        ValueError: If no audio file is given or the reference text contains
            no words. Raised before any model work is done.
        KeyError: If *language* is not a key of ``MODELS``.
    """
    from itertools import zip_longest

    # Fail fast with a clear message instead of erroring deep inside the
    # audio loader or the WER computation after the model has already run.
    if not audio_file:
        raise ValueError("No audio file was provided")
    if not reference_text or not reference_text.strip():
        raise ValueError("Reference text is empty")

    lang_models = MODELS[language]

    if language == "Arabic":
        epi = lang_models["epitran"]

        def transliterate_fn(word):
            return safe_transliterate_arabic(epi, word)
    else:
        transliterate_fn = transliterate_english

    ref_phonemes = _phonemes_by_word(reference_text, transliterate_fn)
    transcription = _transcribe(
        lang_models["processor"], lang_models["model"], audio_file
    )
    obs_phonemes = _phonemes_by_word(transcription, transliterate_fn)

    word_alignment = []
    total_phoneme_errors = 0
    total_phoneme_length = 0
    correct_words = 0
    total_word_length = len(ref_phonemes)

    # zip_longest (rather than zip) keeps word positions that exist on only
    # one side; the missing side arrives as None.
    pairs = zip_longest(ref_phonemes, obs_phonemes)
    for index, (ref, obs) in enumerate(pairs):
        unmatched = ref is None or obs is None
        ref = ref or []
        obs = obs or []

        edits = editdistance.eval(ref, obs)
        if unmatched:
            accuracy = 0.0
        else:
            # Clamp: more edits than reference phonemes would go negative.
            accuracy = round(max(0.0, 1 - edits / max(1, len(ref))) * 100, 2)
        is_correct = edits == 0 and not unmatched

        word_alignment.append({
            "word_index": index,
            "reference_phonemes": ''.join(ref),
            "observed_phonemes": ''.join(obs),
            "edit_distance": edits,
            "accuracy": accuracy,
            "is_correct": is_correct,
            "errors": _describe_errors(ref, obs),
        })

        total_phoneme_errors += edits
        total_phoneme_length += len(ref)
        if is_correct:
            correct_words += 1

    phoneme_ratio = total_phoneme_errors / max(1, total_phoneme_length)
    word_total = max(1, total_word_length)

    metrics = {
        "word_accuracy": round((correct_words / word_total) * 100, 2),
        "word_error_rate": round(
            ((total_word_length - correct_words) / word_total) * 100, 2
        ),
        "phoneme_accuracy": round(max(0.0, 1 - phoneme_ratio) * 100, 2),
        "phoneme_error_rate": round(phoneme_ratio * 100, 2),
        "asr_word_error_rate": round(wer(reference_text, transcription) * 100, 2),
    }

    results = {
        "language": language,
        "reference_text": reference_text,
        "transcription": transcription,
        "word_alignment": word_alignment,
        "metrics": metrics,
    }
    return json.dumps(results, indent=2, ensure_ascii=False)
|
|
|
|
|
def get_default_text(language):
    """Return the sample reference sentence for *language*.

    Args:
        language: "Arabic" or "English".

    Returns:
        The sample sentence, or ``""`` for any other language.
    """
    return {
        # Fully vocalised Arabic sample (Quran 55:13). The literal had been
        # corrupted by a wrong-charset decode that also split it across two
        # lines; it is restored here as proper Unicode on a single line.
        "Arabic": "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ",
        "English": "The quick brown fox jumps over the lazy dog",
    }.get(language, "")
|
|
|
# Gradio front end: pick a language, edit the reference text, upload a
# recording, and show the analysis as JSON.
# NOTE(review): the source's indentation was lost; placing only the dropdown
# and textbox inside the Row is inferred from the blank-line grouping - confirm.
with gr.Blocks() as demo:
    gr.Markdown("# Multilingual Phoneme Alignment Analysis")
    gr.Markdown("Compare audio pronunciation with reference text at phoneme level")

    with gr.Row():
        language = gr.Dropdown(choices=["Arabic", "English"], value="Arabic", label="Language")
        reference_text = gr.Textbox(value=get_default_text("Arabic"), label="Reference Text")

    audio_input = gr.Audio(type="filepath", label="Upload Audio File")
    submit_btn = gr.Button("Analyze")
    output = gr.JSON(label="Phoneme Alignment Results")

    # Changing the language replaces the reference text with that
    # language's default sentence.
    language.change(get_default_text, inputs=language, outputs=reference_text)

    # The button runs the analysis on the current language, text and audio.
    submit_btn.click(
        analyze_phonemes,
        inputs=[language, reference_text, audio_input],
        outputs=output,
    )

demo.launch()
|
|