# Hugging Face Space — Streamlit app: WER evaluation of a Whisper model.
# Redirect every config/cache location to /tmp: Hugging Face Spaces containers
# only allow writes under /tmp. These MUST be set before importing the HF
# libraries, which read the variables at import time.
import os

os.environ["XDG_CONFIG_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/huggingface"  # root cache for models/datasets
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

import tempfile

import evaluate
import librosa
import numpy as np
import pandas as pd
import streamlit as st
import torch
from datasets import load_dataset
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor
# --- Page header -------------------------------------------------------------
st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")

# --- Step 1: model selection -------------------------------------------------
st.subheader("1. Choix du modèle")
MODEL_CHOICES = (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing (à venir)",
)
model_option = st.radio("Quel modèle veux-tu utiliser ?", MODEL_CHOICES)

# --- Step 2: dataset source --------------------------------------------------
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/mon_dataset")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")

# --- Step 3: launch ----------------------------------------------------------
start_eval = st.button("🚀 Lancer l'évaluation WER")
def _clean_text(text):
    """Lowercase *text* and drop punctuation, keeping alphanumerics and spaces.

    Both the reference and the prediction go through this normalization so
    the reported WER is punctuation-insensitive.
    """
    return ''.join(c for c in text.lower() if c.isalnum() or c.isspace()).strip()


if start_eval:
    st.subheader("🔍 Traitement en cours...")

    # 1) Load the dataset (train split). The token is needed for private repos.
    with st.spinner("Chargement du dataset..."):
        try:
            dataset = load_dataset(dataset_link, split="train", token=hf_token)
        except Exception as e:
            st.error(f"Erreur lors du chargement du dataset : {e}")
            st.stop()

    # 2) Load the chosen model: Whisper Large baseline, optionally with the
    #    LoRA adapter applied on top.
    with st.spinner("Chargement du modèle..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)
        processor = WhisperProcessor.from_pretrained(base_model_name)
        model.eval()

    # 3) WER metric. NOTE(review): the score below is computed per file and
    #    then averaged, which is not the corpus-level WER (files are not
    #    weighted by reference length) — confirm this is the intended metric.
    wer_metric = evaluate.load("wer")
    results = []

    # Download the whole dataset repo locally so each "file_name" entry can be
    # resolved to an on-disk audio path.
    repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token)

    for example in dataset:
        st.write("Exemple brut :", example)  # debug output; drop once the schema is confirmed
        try:
            # Assumes an AudioFolder-style dataset with "file_name" (path
            # relative to the repo root) and "text" (reference transcript).
            audio_path = os.path.join(repo_local_path, example["file_name"])
            reference = example["text"]

            # Whisper expects 16 kHz input; add a leading batch dimension.
            waveform, _ = librosa.load(audio_path, sr=16000)
            waveform = np.expand_dims(waveform, axis=0)
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features)
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # Punctuation-insensitive WER (see _clean_text).
            wer = wer_metric.compute(
                predictions=[_clean_text(prediction)],
                references=[_clean_text(reference)],
            )
            results.append({
                "Fichier": audio_path,
                "Référence": reference,
                "Transcription": prediction,
                "WER": round(wer, 4),
            })
        except Exception as e:
            # Best-effort: keep evaluating the remaining files and record the
            # failure in the results table instead of aborting the run.
            results.append({
                "Fichier": example.get("file_name", "unknown"),
                "Référence": "Erreur",
                "Transcription": f"Erreur: {e}",
                "WER": "-",
            })

    # 4) Results table + CSV export.
    df = pd.DataFrame(results)
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv:
        df.to_csv(tmp_csv.name, index=False)

    # Failed files carry WER == "-"; coerce the column to numeric (errors
    # become NaN, which .mean() skips) instead of averaging an object column.
    mean_wer = pd.to_numeric(df["WER"], errors="coerce").mean()
    if pd.isna(mean_wer):
        st.warning("Aucun fichier n'a pu être évalué.")
    else:
        st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")

    # Placeholder for the upcoming post-processing option.
    if "Post-processing" in model_option:
        st.info("🛠️ Le post-processing sera ajouté prochainement ici...")

    # Offer the per-file results as a downloadable CSV.
    with open(tmp_csv.name, "rb") as f:
        st.download_button(
            label="📥 Télécharger les résultats WER (.csv)",
            data=f,
            file_name="wer_results.csv",
            mime="text/csv",
        )