Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
import pandas as pd
|
5 |
+
from datasets import load_dataset
|
6 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
7 |
+
from peft import PeftModel
|
8 |
+
import torch
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import evaluate
|
12 |
+
|
13 |
+
# --- Page header -----------------------------------------------------------
st.title(body="📊 Évaluation WER d'un modèle Whisper")
st.markdown(
    body="Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio."
)

# --- Step 1: model selection -----------------------------------------------
# The option labels are matched by substring later in the script
# ("LoRA", "Post-processing"), so they must stay exactly as written.
MODEL_CHOICES = (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing (à venir)",
)
st.subheader(body="1. Choix du modèle")
model_option = st.radio(
    label="Quel modèle veux-tu utiliser ?",
    options=MODEL_CHOICES,
)

# --- Step 2: dataset location and optional access token ---------------------
st.subheader(body="2. Chargement du dataset Hugging Face")
dataset_link = st.text_input(
    label="Lien du dataset (format: user/dataset_name)",
    value="SimpleFrog/mon_dataset",
)
hf_token = st.text_input(
    label="Token Hugging Face (si dataset privé)",
    type="password",
)

# --- Step 3: launch button ---------------------------------------------------
# True only on the rerun triggered by the click.
start_eval = st.button(label="🚀 Lancer l'évaluation WER")
|
31 |
+
|
32 |
+
def clean(text):
    """Normalize a transcript for punctuation-insensitive WER scoring.

    Lowercases ``text``, keeps only alphanumeric and whitespace characters,
    and strips leading/trailing whitespace.

    Args:
        text: Transcript string (reference or model prediction).

    Returns:
        The normalized string.
    """
    return "".join(c for c in text.lower() if c.isalnum() or c.isspace()).strip()


# Evaluation pipeline: runs only on the rerun triggered by the launch button.
# Steps: load the dataset's "test" split, load Whisper (optionally wrapped with
# the LoRA adapter), transcribe every example, score it with WER, then display
# a per-file table and the mean WER.
if start_eval:
    st.subheader("🔍 Traitement en cours...")

    # A blank password field yields "", not None. Pass None in that case so the
    # Hub client uses its default behaviour rather than being handed an empty
    # string as an explicit credential.
    auth_token = hf_token or None

    # Load the dataset; abort this run with a visible error on failure.
    with st.spinner("Chargement du dataset..."):
        try:
            dataset = load_dataset(dataset_link, split="test", token=auth_token)
        except Exception as e:
            st.error(f"Erreur lors du chargement du dataset : {e}")
            st.stop()

    # Load the selected model.
    with st.spinner("Chargement du modèle..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)

        # Both LoRA options (with and without post-processing) contain "LoRA".
        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(
                model, "SimpleFrog/whisper_finetuned", token=auth_token
            )

        processor = WhisperProcessor.from_pretrained(base_model_name)
        model.eval()

    # Prepare the WER metric.
    wer_metric = evaluate.load("wer")

    results = []

    for example in dataset:
        # Best-effort per example: a failure on one file is recorded as an
        # error row instead of aborting the whole evaluation.
        try:
            # Assumes each example exposes 'file_name' (a path librosa can
            # open) and 'text' (the reference) -- TODO confirm dataset schema.
            audio_path = example["file_name"]
            reference = example["text"]

            # Load as 16 kHz audio and add a leading batch axis of size 1.
            waveform, _ = librosa.load(audio_path, sr=16000)
            waveform = np.expand_dims(waveform, axis=0)
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features)
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # Score on normalized text so punctuation and casing do not count.
            wer = wer_metric.compute(
                predictions=[clean(prediction)],
                references=[clean(reference)],
            )

            results.append({
                "Fichier": audio_path,
                "Référence": reference,
                "Transcription": prediction,
                "WER": round(wer, 4),
            })

        except Exception as e:
            results.append({
                "Fichier": example.get("file_name", "unknown"),
                "Référence": "Erreur",
                "Transcription": f"Erreur: {e}",
                "WER": "-",
            })

    # Display the results table and the mean WER.
    if not results:
        # With no rows there is no "WER" column to index.
        st.warning("Le dataset ne contient aucun exemple à évaluer.")
    else:
        df = pd.DataFrame(results)
        st.subheader("📋 Résultats de la transcription")
        st.dataframe(df)

        # Error rows carry "-" in the WER column; coerce them to NaN and drop
        # them so only successfully scored examples contribute to the mean.
        scores = pd.to_numeric(df["WER"], errors="coerce").dropna()
        if scores.empty:
            st.warning("Aucun WER n'a pu être calculé : toutes les transcriptions ont échoué.")
        else:
            mean_wer = scores.mean()
            st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")

    # Placeholder for the upcoming post-processing step.
    # NOTE(review): nesting was reconstructed from a flattened source; confirm
    # this belongs inside the evaluation branch rather than at top level.
    if "Post-processing" in model_option:
        st.info("🛠️ Le post-processing sera ajouté prochainement ici...")
|