File size: 3,917 Bytes
bdecbe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
import os
import tempfile
import pandas as pd
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import librosa
import numpy as np
import evaluate

st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")

# Section : Choix du modèle
st.subheader("1. Choix du modèle")
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing (à venir)"
))

# Section : Lien du dataset
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/mon_dataset")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")

# Section : Bouton pour lancer l'évaluation
start_eval = st.button("🚀 Lancer l'évaluation WER")

if start_eval:
    st.subheader("🔍 Traitement en cours...")

    # 🔹 Télécharger dataset
    with st.spinner("Chargement du dataset..."):
        try:
            dataset = load_dataset(dataset_link, split="test", token=hf_token)
        except Exception as e:
            st.error(f"Erreur lors du chargement du dataset : {e}")
            st.stop()

    # 🔹 Charger le modèle choisi
    with st.spinner("Chargement du modèle..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)

        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)

        processor = WhisperProcessor.from_pretrained(base_model_name)
        model.eval()

    # 🔹 Préparer WER metric
    wer_metric = evaluate.load("wer")

    results = []

    for example in dataset:
        try:
            audio_path = example["file_name"]  # full path or relative path in AudioFolder
            reference = example["text"]

            # Load audio (we assume dataset is structured with 'file_name')
            waveform, _ = librosa.load(audio_path, sr=16000)
            waveform = np.expand_dims(waveform, axis=0)
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features)
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
            def clean(text):
                return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()

            ref_clean = clean(reference)
            pred_clean = clean(prediction)
            wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])

            results.append({
                "Fichier": audio_path,
                "Référence": reference,
                "Transcription": prediction,
                "WER": round(wer, 4)
            })

        except Exception as e:
            results.append({
                "Fichier": example.get("file_name", "unknown"),
                "Référence": "Erreur",
                "Transcription": f"Erreur: {e}",
                "WER": "-"
            })

    # 🔹 Afficher le tableau de résultats
    df = pd.DataFrame(results)
    st.subheader("📋 Résultats de la transcription")
    st.dataframe(df)

    mean_wer = df[df["WER"] != "-"]["WER"].mean()
    st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")

    # Bloc placeholder pour post-processing à venir
    if "Post-processing" in model_option:
        st.info("🛠️ Le post-processing sera ajouté prochainement ici...")