101Frost committed on
Commit dd8edb5 · verified · 1 Parent(s): fa237cc

Create app.py

Files changed (1)
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ import gradio as gr
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ import librosa
+ import torch
+ import epitran
+ import re
+ import difflib
+ import editdistance
+ from jiwer import wer
+ import json
+
+ # Load model and phonemizer once at startup
+ model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
+ epi = epitran.Epitran('ara-Arab')
+
+ def clean_phonemes(ipa):
+     """Remove diacritics and length markers from phonemes."""
+     return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa)
+
+ def analyze_phonemes(language, reference_text, audio_file):
+     # language is currently unused; only Arabic is supported
+     # Convert reference text to phonemes
+     ref_phonemes = []
+     for word in reference_text.split():
+         ipa = epi.transliterate(word)
+         ipa_clean = clean_phonemes(ipa)
+         ref_phonemes.append(list(ipa_clean))
+
+     # Process audio file (gr.File with type="filepath" passes a plain path string)
+     audio, sr = librosa.load(audio_file, sr=16000)
+     input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
+
+     # Get transcription
+     with torch.no_grad():
+         logits = model(input_values).logits
+     pred_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(pred_ids)[0].strip()
+
+     # Convert transcription to phonemes
+     obs_phonemes = []
+     for word in transcription.split():
+         ipa = epi.transliterate(word)
+         ipa_clean = clean_phonemes(ipa)
+         obs_phonemes.append(list(ipa_clean))
+
+     # Prepare results for JSON output
+     results = {
+         "reference_text": reference_text,
+         "transcription": transcription,
+         "word_alignment": [],
+         "metrics": {}
+     }
+
+     # Accumulators for aggregate metrics
+     total_phoneme_errors = 0
+     total_phoneme_length = 0
+     correct_words = 0
+     total_word_length = len(ref_phonemes)
+
+     # Word-by-word alignment (positional; extra or missing words are dropped by zip)
+     for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
+         ref_str = ''.join(ref)
+         obs_str = ''.join(obs)
+         edits = editdistance.eval(ref, obs)
+         acc = round((1 - edits / max(1, len(ref))) * 100, 2)
+
+         # Get error details
+         matcher = difflib.SequenceMatcher(None, ref, obs)
+         ops = matcher.get_opcodes()
+         error_details = []
+         for tag, i1, i2, j1, j2 in ops:
+             ref_seg = ''.join(ref[i1:i2]) or '-'
+             obs_seg = ''.join(obs[j1:j2]) or '-'
+             if tag != 'equal':
+                 error_details.append({
+                     "type": tag.upper(),
+                     "reference": ref_seg,
+                     "observed": obs_seg
+                 })
+
+         results["word_alignment"].append({
+             "word_index": i,
+             "reference_phonemes": ref_str,
+             "observed_phonemes": obs_str,
+             "edit_distance": edits,
+             "accuracy": acc,
+             "is_correct": edits == 0,
+             "errors": error_details
+         })
+
+         total_phoneme_errors += edits
+         total_phoneme_length += len(ref)
+         correct_words += 1 if edits == 0 else 0
+
+     # Calculate aggregate metrics
+     phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
+     phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
+     word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
+     word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
+     text_wer = round(wer(reference_text, transcription) * 100, 2)
+
+     results["metrics"] = {
+         "word_accuracy": word_acc,
+         "word_error_rate": word_er,
+         "phoneme_accuracy": phoneme_acc,
+         "phoneme_error_rate": phoneme_er,
+         "asr_word_error_rate": text_wer
+     }
+
+     return json.dumps(results, indent=2, ensure_ascii=False)
+
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=analyze_phonemes,
+     inputs=[
+         gr.Dropdown(["Arabic"], label="Language", value="Arabic"),
+         gr.Textbox(label="Reference Text", value="فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ"),
+         gr.File(label="Upload Audio File", type="filepath")
+     ],
+     outputs=gr.JSON(label="Phoneme Alignment Results"),
+     title="Arabic Phoneme Alignment Analysis",
+     description="Compare audio pronunciation with reference text at the phoneme level"
+ )
+
+ demo.launch()