WER_Evaluation / app.py
SimpleFrog's picture
Update app.py
98ab6e8 verified
import os

# Point every config/cache directory at /tmp.
# NOTE(review): presumably done before the Hugging Face / Streamlit imports
# below so those libraries pick the locations up at import time, and because
# /tmp is the writable location in the hosting environment — confirm.
os.environ.update(
    {
        "XDG_CONFIG_HOME": "/tmp",
        "XDG_CACHE_HOME": "/tmp",
        # Root directory for Hugging Face models/datasets.
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
        "HF_HUB_CACHE": "/tmp/huggingface/hub",
    }
)
import streamlit as st
import tempfile
import pandas as pd
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import librosa
import numpy as np
import evaluate
import tempfile
from huggingface_hub import snapshot_download
# Page header.
st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")

# Section: model selection.
st.subheader("1. Choix du modèle")
_model_choices = (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing (à venir)",
)
model_option = st.radio(label="Quel modèle veux-tu utiliser ?", options=_model_choices)

# Section: dataset location and (optional) access token.
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input(
    label="Lien du dataset (format: user/dataset_name)",
    value="SimpleFrog/mon_dataset",
)
hf_token = st.text_input(
    label="Token Hugging Face (si dataset privé)",
    type="password",
)

# Section: button that triggers the evaluation run.
start_eval = st.button(label="🚀 Lancer l'évaluation WER")
def clean(text):
    """Normalize *text* for a punctuation-insensitive WER.

    Lowercases the text, keeps only alphanumeric and whitespace characters,
    and strips leading/trailing whitespace.
    """
    return "".join(c for c in text.lower() if c.isalnum() or c.isspace()).strip()


# Run the whole evaluation only when the user pressed the button:
# load dataset + model, transcribe every example, compute a per-file WER,
# then show the mean WER and offer the per-file results as a CSV download.
if start_eval:
    st.subheader("🔍 Traitement en cours...")

    # An empty token field yields ""; pass None instead so the Hub libraries
    # use their default (anonymous / locally stored) authentication.
    auth_token = hf_token or None

    # Load the dataset rows and download the repository files locally, so each
    # row's "file_name" can be resolved to an actual audio file on disk.
    # Both steps hit the Hub and are reported to the user the same way.
    with st.spinner("Chargement du dataset..."):
        try:
            dataset = load_dataset(dataset_link, split="train", token=auth_token)
            repo_local_path = snapshot_download(
                repo_id=dataset_link, repo_type="dataset", token=auth_token
            )
        except Exception as e:
            st.error(f"Erreur lors du chargement du dataset : {e}")
            st.stop()

    # Load the selected model (the LoRA adapter is applied on top of the base
    # model whenever the chosen option mentions "LoRA").
    with st.spinner("Chargement du modèle..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=auth_token)
        processor = WhisperProcessor.from_pretrained(base_model_name)
        model.eval()

    # Prepare the WER metric.
    wer_metric = evaluate.load("wer")
    results = []

    for example in dataset:
        st.write("Exemple brut :", example)
        try:
            # Each row is expected to carry "file_name" (joined onto the local
            # repository snapshot) and "text" (the reference transcription).
            audio_path = os.path.join(repo_local_path, example["file_name"])
            reference = example["text"]

            # Load the audio resampled to 16 kHz and add a leading batch axis.
            waveform, _ = librosa.load(audio_path, sr=16000)
            waveform = np.expand_dims(waveform, axis=0)
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features)
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # WER is computed on punctuation-free, lowercased text.
            wer = wer_metric.compute(
                predictions=[clean(prediction)],
                references=[clean(reference)],
            )
            results.append({
                "Fichier": audio_path,
                "Référence": reference,
                "Transcription": prediction,
                "WER": round(wer, 4),
            })
        except Exception as e:
            # Best effort: one failing example must not abort the whole run;
            # it is recorded in the table with "-" as its WER.
            results.append({
                "Fichier": example.get("file_name", "unknown"),
                "Référence": "Erreur",
                "Transcription": f"Erreur: {e}",
                "WER": "-",
            })

    # Build the results table; fixing the columns keeps the CSV header present
    # even when the dataset yielded no rows.
    df = pd.DataFrame(results, columns=["Fichier", "Référence", "Transcription", "WER"])

    # Average only the examples that were actually scored ("-" marks failures).
    wer_scores = [row["WER"] for row in results if row["WER"] != "-"]
    if wer_scores:
        mean_wer = sum(wer_scores) / len(wer_scores)
        st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
    else:
        st.warning("Aucun exemple n'a pu être évalué : WER moyen indisponible.")

    # Placeholder for the upcoming post-processing step.
    if "Post-processing" in model_option:
        st.info("🛠️ Le post-processing sera ajouté prochainement ici...")

    # Download button: serve the CSV straight from memory so no temporary file
    # is left behind on disk.
    st.download_button(
        label="📥 Télécharger les résultats WER (.csv)",
        data=df.to_csv(index=False).encode("utf-8"),
        file_name="wer_results.csv",
        mime="text/csv",
    )