SimpleFrog commited on
Commit
bdecbe9
·
verified ·
1 Parent(s): 7eeab23

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import pandas as pd
5
+ from datasets import load_dataset
6
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
7
+ from peft import PeftModel
8
+ import torch
9
+ import librosa
10
+ import numpy as np
11
+ import evaluate
12
+
13
+ st.title("📊 Évaluation WER d'un modèle Whisper")
14
+ st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
15
+
16
+ # Section : Choix du modèle
17
+ st.subheader("1. Choix du modèle")
18
+ model_option = st.radio("Quel modèle veux-tu utiliser ?", (
19
+ "Whisper Large (baseline)",
20
+ "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
21
+ "Whisper Large + LoRA + Post-processing (à venir)"
22
+ ))
23
+
24
+ # Section : Lien du dataset
25
+ st.subheader("2. Chargement du dataset Hugging Face")
26
+ dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/mon_dataset")
27
+ hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
28
+
29
+ # Section : Bouton pour lancer l'évaluation
30
+ start_eval = st.button("🚀 Lancer l'évaluation WER")
31
+
32
+ if start_eval:
33
+ st.subheader("🔍 Traitement en cours...")
34
+
35
+ # 🔹 Télécharger dataset
36
+ with st.spinner("Chargement du dataset..."):
37
+ try:
38
+ dataset = load_dataset(dataset_link, split="test", token=hf_token)
39
+ except Exception as e:
40
+ st.error(f"Erreur lors du chargement du dataset : {e}")
41
+ st.stop()
42
+
43
+ # 🔹 Charger le modèle choisi
44
+ with st.spinner("Chargement du modèle..."):
45
+ base_model_name = "openai/whisper-large"
46
+ model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
47
+
48
+ if "LoRA" in model_option:
49
+ model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)
50
+
51
+ processor = WhisperProcessor.from_pretrained(base_model_name)
52
+ model.eval()
53
+
54
+ # 🔹 Préparer WER metric
55
+ wer_metric = evaluate.load("wer")
56
+
57
+ results = []
58
+
59
+ for example in dataset:
60
+ try:
61
+ audio_path = example["file_name"] # full path or relative path in AudioFolder
62
+ reference = example["text"]
63
+
64
+ # Load audio (we assume dataset is structured with 'file_name')
65
+ waveform, _ = librosa.load(audio_path, sr=16000)
66
+ waveform = np.expand_dims(waveform, axis=0)
67
+ inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
68
+
69
+ with torch.no_grad():
70
+ pred_ids = model.generate(input_features=inputs.input_features)
71
+ prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
72
+
73
+ # 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
74
+ def clean(text):
75
+ return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
76
+
77
+ ref_clean = clean(reference)
78
+ pred_clean = clean(prediction)
79
+ wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
80
+
81
+ results.append({
82
+ "Fichier": audio_path,
83
+ "Référence": reference,
84
+ "Transcription": prediction,
85
+ "WER": round(wer, 4)
86
+ })
87
+
88
+ except Exception as e:
89
+ results.append({
90
+ "Fichier": example.get("file_name", "unknown"),
91
+ "Référence": "Erreur",
92
+ "Transcription": f"Erreur: {e}",
93
+ "WER": "-"
94
+ })
95
+
96
+ # 🔹 Afficher le tableau de résultats
97
+ df = pd.DataFrame(results)
98
+ st.subheader("📋 Résultats de la transcription")
99
+ st.dataframe(df)
100
+
101
+ mean_wer = df[df["WER"] != "-"]["WER"].mean()
102
+ st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
103
+
104
+ # Bloc placeholder pour post-processing à venir
105
+ if "Post-processing" in model_option:
106
+ st.info("🛠️ Le post-processing sera ajouté prochainement ici...")