import gradio as gr
from transformers import pipeline
import os
import numpy as np
import torch

# Load the model (use the GPU when one is available, so the torch import is put to use)
print("Loading model...")
model_id = "badrex/mms-300m-arabic-dialect-identifier"
device = 0 if torch.cuda.is_available() else -1
classifier = pipeline("audio-classification", model=model_id, device=device)
print("Model loaded successfully")

# Define dialect mapping
dialect_mapping = {
    "MSA": "Modern Standard Arabic",
    "Egyptian": "Egyptian Arabic",
    "Gulf": "Gulf Arabic",
    "Levantine": "Levantine Arabic",
    "Maghrebi": "Maghrebi Arabic"
}

def predict_dialect(audio):
    if audio is None:
        return {"No audio provided": 1.0}
    
    # The audio input from Gradio is a tuple of (sample_rate, audio_array)
    sr, audio_array = audio
    
    # Process the audio input
    if len(audio_array.shape) > 1:
        audio_array = audio_array.mean(axis=1)  # Convert stereo to mono

    # Convert audio to float32 if it's not already (fix for Chrome recording issue)
    if audio_array.dtype != np.float32:
        # Normalize to [-1, 1] range as expected by the model
        if audio_array.dtype == np.int16:
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            audio_array = audio_array.astype(np.float32)
    
    print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
    
    # Classify the dialect
    predictions = classifier({"sampling_rate": sr, "raw": audio_array})
    
    # Format results for display
    results = {}
    for pred in predictions:
        dialect_name = dialect_mapping.get(pred['label'], pred['label'])
        results[dialect_name] = float(pred['score'])
    
    return results
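
# Example of the mapping returned above (illustrative scores only):
# {"Egyptian Arabic": 0.91, "Levantine Arabic": 0.05, "Gulf Arabic": 0.02, ...}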

# Manually prepare example file paths without metadata
examples = []
examples_dir = "examples"
if os.path.exists(examples_dir):
    for filename in os.listdir(examples_dir):
        if filename.endswith((".wav", ".mp3", ".ogg")):
            examples.append([os.path.join(examples_dir, filename)])
    
    print(f"Found {len(examples)} example files")
else:
    print("Examples directory not found")



# Custom CSS for better styling, passed to gr.Interface via the css argument
# below (Gradio expects raw CSS here, so no <style> wrapper)
custom_css = """
.centered-content {
    text-align: center;
    max-width: 800px;
    margin: 0 auto;
    padding: 20px;
}

.logo-image {
    width: 200px;
    height: auto;
    margin: 20px auto;
    display: block;
}

.description-text {
    font-size: 16px;
    line-height: 1.6;
    margin-bottom: 20px;
}

.dialect-list {
    font-size: 15px;
    line-height: 1.8;
    text-align: left;
    max-width: 600px;
    margin: 0 auto;
}

.highlight-text {
    font-size: 16px;
    color: #2563eb;
    margin: 20px 0;
}

.footer-text {
    font-size: 13px;
    color: #6b7280;
    margin-top: 20px;
}
"""

"""
<p style="font-size: 15px; line-height: 1.8;">
                    <strong>The following Arabic language varieties are supported:</strong>
                    <br><br>
                    ✦ <strong>Modern Standard Arabic (MSA)</strong> - The formal language of media and education
                    <br>
                    ✦ <strong>Egyptian Arabic</strong> - The dialect of Cairo, Alexandria, and popular Arabic cinema
                    <br>
                    ✦ <strong>Gulf Arabic</strong> - Spoken across Saudi Arabia, UAE, Kuwait, Qatar, Bahrain, and Oman
                    <br>
                    ✦ <strong>Levantine Arabic</strong> - The dialect of Syria, Lebanon, Jordan, and Palestine
                    <br>
                    ✦ <strong>Maghrebi Arabic</strong> - The distinctive varieties of Morocco, Algeria, Tunisia, and Libya
                    </p>
                    <br>
"""

# Create the Gradio interface
demo = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(type="numpy"),  # delivers (sample_rate, np.ndarray), as predict_dialect expects
    outputs=gr.Label(num_top_classes=5, label="Predicted Dialect"),
    title="Tamyïz 🍉 Arabic Dialect Identification in Speech",
    description="""
        <div class="centered-content">
            <div>
                <p>
                By <a href="https://badrex.github.io/" style="color: #2563eb;">Badr Alabsi</a> with ❤️🤍💚 
                </p>
                <br>
                <p style="font-size: 15px; line-height: 1.8;">
                This is a demo of an accurate and robust Transformer-based <a href="https://huggingface.co/badrex/mms-300m-arabic-dialect-identifier" style="color: #FF5349;">model</a> for Spoken Arabic Dialect Identification (ADI).
                From just a short audio clip (5-10 seconds), the model can identify Modern Standard Arabic (<strong>MSA</strong>) as well as four major regional Arabic varieties: <strong>Egyptian</strong> Arabic, <strong>Gulf</strong> Arabic, <strong>Levantine</strong> Arabic, and <strong>Maghrebi</strong> Arabic.
                </p>
                <p style="font-size: 15px; line-height: 1.8;">
                Simply <strong>upload an audio file</strong> 📀 or <strong>record yourself speaking</strong> ⏯️⏺️ to try out the model!
                </p>
            </div>
        </div>
        """,
    examples=examples if examples else None,
    cache_examples=False,  # Disable caching to avoid issues
    flagging_mode=None
)

# Launch the app
demo.launch(share=True)  # note: share=True has no effect when running on Hugging Face Spaces