import os

os.environ["XDG_CONFIG_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/huggingface"  # cache location for models/datasets
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

import tempfile

import streamlit as st
import pandas as pd
import torch
import librosa
import numpy as np
import evaluate
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
from huggingface_hub import snapshot_download

st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")

# Section: model choice
st.subheader("1. Choix du modèle")
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing (à venir)"
))

# Section: dataset link
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/mon_dataset")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")

# Section: button to start the evaluation
start_eval = st.button("🚀 Lancer l'évaluation WER")


def clean(text):
    """Lowercase and strip punctuation so the WER is computed 'without punctuation'."""
    return ''.join(c for c in text.lower() if c.isalnum() or c.isspace()).strip()


if start_eval:
    st.subheader("🔍 Traitement en cours...")

    # 🔹 Load the dataset
    with st.spinner("Chargement du dataset..."):
        try:
            dataset = load_dataset(dataset_link, split="train", token=hf_token)
        except Exception as e:
            st.error(f"Erreur lors du chargement du dataset : {e}")
            st.stop()

    # 🔹 Load the selected model
    with st.spinner("Chargement du modèle..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)
        processor = WhisperProcessor.from_pretrained(base_model_name)
        model.eval()

    # 🔹 Prepare the WER metric
    wer_metric = evaluate.load("wer")
    results = []

    # Download the dataset repository explicitly to get a local path to each audio file
    repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token)

    for example in dataset:
        st.write("Exemple brut :", example)
        try:
            # The dataset is expected to provide 'file_name' (relative audio path) and 'text' (reference transcript)
            audio_path = os.path.join(repo_local_path, example["file_name"])
            reference = example["text"]

            # Load audio at 16 kHz, the sampling rate expected by Whisper
            waveform, _ = librosa.load(audio_path, sr=16000)
            waveform = np.expand_dims(waveform, axis=0)

            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features)
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # 🔹 Strip punctuation before computing the WER "sans ponctuation"
            ref_clean = clean(reference)
            pred_clean = clean(prediction)

            wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])

            results.append({
                "Fichier": audio_path,
                "Référence": reference,
                "Transcription": prediction,
                "WER": round(wer, 4)
            })
        except Exception as e:
            results.append({
                "Fichier": example.get("file_name", "unknown"),
                "Référence": "Erreur",
                "Transcription": f"Erreur: {e}",
                "WER": "-"
            })

    # 🔹 Build the results table
    df = pd.DataFrame(results)

    # 🔹 Write the results to a temporary CSV file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv:
        df.to_csv(tmp_csv.name, index=False)

    # Average WER over the rows that were scored successfully ("-" marks failed files)
    mean_wer = pd.to_numeric(df["WER"], errors="coerce").mean()
    st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")

    # Placeholder for the upcoming post-processing step
    if "Post-processing" in model_option:
        st.info("🛠️ Le post-processing sera ajouté prochainement ici...")

    # 🔹 Download button
    with open(tmp_csv.name, "rb") as f:
        st.download_button(
            label="📥 Télécharger les résultats WER (.csv)",
            data=f,
            file_name="wer_results.csv",
            mime="text/csv"
        )