File size: 5,650 Bytes
9ba8fab
 
 
3769468
9ba8fab
 
3769468
6960dc6
2ab6a6b
9ba8fab
 
 
2ab6a6b
9ba8fab
 
 
2ab6a6b
3769468
 
 
 
 
 
 
 
 
 
 
d415750
3769468
 
 
 
 
 
2ab6a6b
3769468
d415750
3769468
 
 
 
d415750
3769468
 
d415750
2ab6a6b
3769468
 
d415750
 
 
 
 
3769468
2ab6a6b
3769468
 
d415750
3769468
 
d415750
 
3769468
 
 
d415750
 
 
3769468
d415750
 
3769468
 
 
9ba8fab
 
 
2ab6a6b
9ba8fab
d415750
3769468
d415750
 
29c8f24
d415750
 
9ba8fab
d415750
 
2ab6a6b
3769468
 
d415750
3769468
d415750
 
3769468
d415750
6960dc6
2ab6a6b
d415750
 
 
 
 
 
2ab6a6b
9d80aed
d415750
 
 
 
9ba8fab
 
 
 
 
 
 
 
 
 
d415750
3769468
9ba8fab
d415750
c726970
2ab6a6b
d415750
9ba8fab
 
d415750
 
 
 
 
 
9ba8fab
 
d415750
9ba8fab
d415750
 
 
 
 
9ba8fab
d415750
 
9ba8fab
 
d415750
9ba8fab
 
 
 
 
 
2ab6a6b
d415750
 
3769468
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import gradio as gr
import pandas as pd
from datasets import load_dataset
from jiwer import wer, cer
import os
from datetime import datetime
import re


# Benchmark split used as ground truth, loaded when the module is executed.
dataset = load_dataset("sudoping01/bambara-asr-benchmark", name="default")["train"]
# Index the reference transcripts by sample id for direct lookup while scoring.
references = dict((row["id"], row["text"]) for row in dataset)


# CSV file in which leaderboard rows are persisted.
leaderboard_file = "leaderboard.csv"

# Seed the file with a header-only table when it does not exist yet.
if not os.path.exists(leaderboard_file):
    _empty_board = pd.DataFrame(columns=["submitter", "WER", "CER", "timestamp"])
    _empty_board.to_csv(leaderboard_file, index=False)


def normalize_text(text):
    """
    Canonicalize a transcript before WER/CER scoring.

    Non-string input is converted with ``str()`` first. The text is then
    lowercased, every character that is neither a word character nor
    whitespace is removed, each run of whitespace is collapsed to a single
    space, and leading/trailing whitespace is trimmed.
    """
    raw = text if isinstance(text, str) else str(text)
    without_punctuation = re.sub(r'[^\w\s]', '', raw.lower())
    single_spaced = re.sub(r'\s+', ' ', without_punctuation)
    return single_spaced.strip()

def calculate_metrics(predictions_df):
    """
    Score a predictions table against the module-level ``references``.

    Each row's ``id`` is looked up in ``references``; rows whose id is not
    present are ignored. Reference and hypothesis are both passed through
    ``normalize_text`` and scored per sample with jiwer's ``wer`` and ``cer``.
    The returned averages are unweighted means of the per-sample scores.

    A missing (NaN/None) hypothesis, or one that is empty after
    normalization, is scored as WER = CER = 1.0 (every reference token
    deleted) instead of being dropped, so leaving utterances blank cannot
    improve a submission's average. Samples whose reference is empty after
    normalization are skipped because both rates are undefined for them.

    Args:
        predictions_df: pandas DataFrame with ``id`` and ``text`` columns.

    Returns:
        Tuple ``(avg_wer, avg_cer, results)`` where ``results`` is a list of
        dicts with keys ``id``, ``reference``, ``hypothesis``, ``wer`` and
        ``cer``, one per scored sample.

    Raises:
        ValueError: if no sample could be scored.
    """
    results = []

    for id_val, raw_text in zip(predictions_df["id"], predictions_df["text"]):
        if id_val not in references:
            continue

        reference = normalize_text(references[id_val])
        if not reference:
            # Error rates divide by the reference length; nothing to score.
            continue

        # An empty CSV cell arrives as NaN; treat it as "no prediction"
        # rather than letting str() turn it into the literal word "nan".
        hypothesis = "" if pd.isna(raw_text) else normalize_text(raw_text)

        if not hypothesis:
            # Nothing was predicted: every reference word/character counts as
            # a deletion, i.e. an error rate of exactly 1.0.
            sample_wer, sample_cer = 1.0, 1.0
        else:
            try:
                sample_wer = wer(reference, hypothesis)
                sample_cer = cer(reference, hypothesis)
            except Exception as e:
                # Best effort: one unscorable sample must not abort the run.
                print(f"Error calculating metrics for ID {id_val}: {str(e)}")
                continue

        results.append({
            "id": id_val,
            "reference": reference,
            "hypothesis": hypothesis,
            "wer": sample_wer,
            "cer": sample_cer
        })

    if not results:
        raise ValueError("No valid samples for WER/CER calculation")

    avg_wer = sum(item["wer"] for item in results) / len(results)
    avg_cer = sum(item["cer"] for item in results) / len(results)
    return avg_wer, avg_cer, results

def process_submission(submitter_name, csv_file):
    """
    Validate an uploaded predictions CSV, score it, and record the result.

    The CSV must contain exactly the columns ``id`` and ``text``, no duplicate
    ids, and exactly the set of ids present in the module-level
    ``references``. On success a row (submitter, WER, CER, timestamp) is
    appended to the CSV at ``leaderboard_file`` and the file is rewritten
    sorted by ascending WER.

    Args:
        submitter_name: Name recorded on the leaderboard; surrounding
            whitespace is stripped and blank names are rejected.
        csv_file: Path or file-like object accepted by ``pandas.read_csv``;
            ``None`` (nothing uploaded) is rejected.

    Returns:
        Tuple ``(message, leaderboard)``. ``leaderboard`` is the updated
        DataFrame on success and ``None`` on any failure, in which case
        ``message`` starts with "Error".
    """
    # Cheap input checks first, so the user gets a precise message instead of
    # a pandas traceback string.
    if csv_file is None:
        return "Error: No CSV file uploaded.", None

    submitter_name = "" if submitter_name is None else str(submitter_name).strip()
    if not submitter_name:
        return "Error: Submitter name must not be empty.", None

    try:
        df = pd.read_csv(csv_file)

        if df.empty:
            return "Error: Uploaded CSV is empty.", None

        if set(df.columns) != {"id", "text"}:
            return f"Error: CSV must contain exactly 'id' and 'text' columns. Found: {', '.join(df.columns)}", None

        if df["id"].duplicated().any():
            dup_ids = df[df["id"].duplicated()]["id"].unique()
            return f"Error: Duplicate IDs found: {', '.join(map(str, dup_ids[:5]))}", None

        reference_ids = set(references)
        submitted_ids = set(df["id"])
        missing_ids = reference_ids - submitted_ids
        extra_ids = submitted_ids - reference_ids

        # Sort (by string form, so mixed id types cannot raise) to make the
        # reported sample deterministic instead of depending on set order.
        if missing_ids:
            return f"Error: Missing {len(missing_ids)} IDs in submission. First few missing: {', '.join(map(str, sorted(missing_ids, key=str)[:5]))}", None

        if extra_ids:
            return f"Error: Found {len(extra_ids)} extra IDs not in reference dataset. First few extra: {', '.join(map(str, sorted(extra_ids, key=str)[:5]))}", None

        try:
            avg_wer, avg_cer, detailed_results = calculate_metrics(df)

            print(f"Calculated metrics - WER: {avg_wer:.4f}, CER: {avg_cer:.4f}")
            print(f"Processed {len(detailed_results)} valid samples")

            # TODO: revisit this guard; it also rejects a genuinely perfect
            # submission (WER of exactly 0).
            if avg_wer < 0.000001:
                return "Error: WER calculation yielded suspicious results (near-zero). Please check your submission CSV.", None

        except Exception as e:
            return f"Error calculating metrics: {str(e)}", None

        leaderboard = pd.read_csv(leaderboard_file)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        new_entry = pd.DataFrame(
            [[submitter_name, avg_wer, avg_cer, timestamp]],
            columns=["submitter", "WER", "CER", "timestamp"]
        )
        # Rebuild a clean 0..n-1 index so the returned table has no duplicate
        # index labels left over from the concatenation.
        leaderboard = (
            pd.concat([leaderboard, new_entry], ignore_index=True)
            .sort_values("WER")
            .reset_index(drop=True)
        )
        leaderboard.to_csv(leaderboard_file, index=False)

        return f"Submission processed successfully! WER: {avg_wer:.4f}, CER: {avg_cer:.4f}", leaderboard

    except Exception as e:
        # Boundary for the UI callback: report any failure as a status string.
        return f"Error processing submission: {str(e)}", None


def _load_leaderboard():
    """Read the current leaderboard table from ``leaderboard_file``."""
    return pd.read_csv(leaderboard_file)


# Gradio UI: a submission form (name + CSV upload), a status box and the
# leaderboard table. Clicking "Submit" runs process_submission and writes its
# two return values into the status box and the table.
with gr.Blocks(title="Bambara ASR Leaderboard") as demo:
    # NOTE(review): the dataset linked in the text below has a different id
    # than the one passed to load_dataset in this file; confirm which is right.
    gr.Markdown(
        """
        # Bambara ASR Leaderboard
        Upload a CSV file with 'id' and 'text' columns to evaluate your ASR predictions.
        The 'id's must match those in the dataset.
        [View the dataset here](https://huggingface.co/datasets/MALIBA-AI/bambara_general_leaderboard_dataset).
        - **WER**: Word Error Rate (lower is better).
        - **CER**: Character Error Rate (lower is better).
        """
    )

    with gr.Row():
        submitter = gr.Textbox(label="Submitter Name or Model Name", placeholder="e.g., MALIBA-AI/asr")
        csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])

    submit_btn = gr.Button("Submit")
    output_msg = gr.Textbox(label="Status", interactive=False)
    # Pass a callable rather than a DataFrame: Gradio evaluates a callable
    # ``value`` each time the app loads, so a page reload shows the file's
    # current contents instead of the snapshot taken when the app was built.
    leaderboard_display = gr.DataFrame(
        label="Leaderboard",
        value=_load_leaderboard,
        interactive=False
    )

    # NOTE(review): process_submission returns None for the table on errors,
    # which presumably blanks the displayed leaderboard; confirm and consider
    # returning the current table instead.
    submit_btn.click(
        fn=process_submission,
        inputs=[submitter, csv_upload],
        outputs=[output_msg, leaderboard_display]
    )


# Startup banner; this is a module-level statement, so it is printed on import
# as well as when the file is run directly.
print("Starting Bambara ASR Leaderboard app...")

# Start the Gradio server only when the file is executed as a script.
# NOTE(review): share=True presumably requests a public share link; confirm
# it is needed in the deployment environment.
if __name__ == "__main__":
    demo.launch(share=True)