import os

# Streamlit/HF Spaces housekeeping: disable telemetry and point config at a writable dir
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
os.environ["XDG_CONFIG_HOME"] = "/tmp"

import streamlit as st
import pandas as pd
import torch
import librosa
import evaluate
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel

st.title("📊 WER evaluation of a Whisper model")
st.markdown("This Space evaluates the WER (word error rate) of a Whisper model on an audio dataset.")

# Section: model selection
st.subheader("1. Model selection")
model_option = st.radio("Which model do you want to use?", (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing (coming soon)"
))

# Section: dataset path
st.subheader("2. Load the Hugging Face dataset")
dataset_link = st.text_input("Dataset path (format: user/dataset_name)", value="SimpleFrog/mon_dataset")
hf_token = st.text_input("Hugging Face token (for private datasets)", type="password")

# Section: button to launch the evaluation
start_eval = st.button("🚀 Run WER evaluation")

# Punctuation stripping for the "no punctuation" WER: replace every
# non-alphanumeric character with a space (rather than deleting it) so that
# words joined by apostrophes or hyphens are not glued together.
def clean(text):
    return " ".join("".join(c if c.isalnum() else " " for c in text.lower()).split())

if start_eval:
    st.subheader("🔍 Processing...")

    # 🔹 Download the dataset (an empty token field means anonymous access)
    with st.spinner("Loading dataset..."):
        try:
            dataset = load_dataset(dataset_link, split="test", token=hf_token or None)
        except Exception as e:
            st.error(f"Error while loading the dataset: {e}")
            st.stop()

    # 🔹 Load the selected model
    with st.spinner("Loading model..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token or None)
        processor = WhisperProcessor.from_pretrained(base_model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()

    # 🔹 Prepare the WER metric
    wer_metric = evaluate.load("wer")
    results = []

    for example in dataset:
        try:
            reference = example["text"]

            # Prefer a decoded "audio" column (datasets' Audio feature); fall back
            # to loading the file ourselves when only 'file_name' is available.
            if "audio" in example:
                audio = example["audio"]
                audio_path = audio.get("path") or "unknown"
                waveform = audio["array"]
                if audio["sampling_rate"] != 16000:
                    waveform = librosa.resample(waveform, orig_sr=audio["sampling_rate"], target_sr=16000)
            else:
                audio_path = example["file_name"]  # assumes the path is resolvable locally (AudioFolder layout)
                waveform, _ = librosa.load(audio_path, sr=16000)

            # The processor accepts the raw 1-D waveform directly
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features.to(device))
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # 🔹 WER on punctuation-stripped, lowercased text
            wer = wer_metric.compute(predictions=[clean(prediction)], references=[clean(reference)])

            results.append({
                "File": audio_path,
                "Reference": reference,
                "Transcription": prediction,
                "WER": round(wer, 4)
            })
        except Exception as e:
            results.append({
                "File": example.get("file_name", "unknown"),
                "Reference": "Error",
                "Transcription": f"Error: {e}",
                "WER": "-"
            })

    # 🔹 Build and display the results table
    df = pd.DataFrame(results)
    df.to_csv("wer_results.csv", index=False)
    st.dataframe(df)

    # Per-file average ("-" rows are coerced to NaN and skipped); note this is
    # not a corpus-level WER weighted by word count.
    mean_wer = pd.to_numeric(df["WER"], errors="coerce").mean()
    st.markdown(f"### 🎯 Mean WER (punctuation stripped): `{mean_wer:.3f}`")

    # Placeholder for the upcoming post-processing step
    if "Post-processing" in model_option:
        st.info("🛠️ Post-processing will be added here soon...")

if os.path.exists("wer_results.csv"):
    with open("wer_results.csv", "rb") as f:
        st.download_button(
            label="📥 Download WER results (.csv)",
            data=f,
            file_name="wer_results.csv",
            mime="text/csv"
        )