emotion_classifier / src /interfaces /metrics_calculator.py
fioriclass's picture
correction bug
65e5e42
import cupy as cp
from typing import Dict, Protocol, Tuple
import warnings
# Utiliser cuml.metrics pour les calculs accélérés par GPU
from cuml.metrics import accuracy_score, precision_recall_curve, roc_auc_score
# Filtrer les avertissements
warnings.filterwarnings("ignore", category=Warning)
class MetricsCalculator(Protocol):
"""
Interface pour les calculateurs de métriques pour la classification binaire.
"""
def calculate_and_log(
self,
y_true: cp.ndarray,
y_pred: cp.ndarray,
y_proba: cp.ndarray, # Probabilités classe positive (1D) - Supposées toujours fournies
prefix: str
) -> Dict[str, float]:
"""
Calcule les métriques pour un problème binaire.
Assume que y_proba est toujours fourni et est un tableau 1D
contenant les probabilités de la classe positive.
Retourne Accuracy, AUC ROC, Precision, Recall et F1 Score.
"""
pass
class DefaultMetricsCalculator(MetricsCalculator):
"""
Implémentation concrète de MetricsCalculator utilisant cuML pour la classification binaire.
Calcule l'accuracy, l'AUC-ROC, la précision, le rappel et le F1 score en utilisant les fonctions cuML.
Utilise precision_recall_curve pour calculer les métriques de précision, rappel et F1 score optimales.
Assume que les données d'entrée sont valides et que y_proba est toujours fourni
en tant que tableau 1D des probabilités de la classe positive.
"""
def calculate_and_log(
self,
y_true: cp.ndarray,
y_pred: cp.ndarray,
y_proba: cp.ndarray, # Probabilités classe positive (1D) - Supposées toujours fournies
prefix: str
) -> Dict[str, float]:
"""
Calcule l'accuracy, l'AUC ROC, la précision, le rappel et le F1 score pour un problème binaire en utilisant cuML.
Utilise precision_recall_curve pour calculer les métriques optimales.
Assume des entrées valides et que y_proba est un tableau 1D fourni.
"""
# 1. Calculer l'accuracy (comme dans l'exemple accuracy_score)
acc = accuracy_score(y_true, y_pred)
# 2. Calculer l'AUC binaire (comme dans l'exemple roc_auc_score)
auc = roc_auc_score(y_true.astype(cp.int32), y_proba.astype(cp.float32))
# 3. Utiliser precision_recall_curve pour obtenir les courbes
precision, recall, thresholds = precision_recall_curve(
y_true.astype(cp.int32), y_proba.astype(cp.float32)
)
# 4. Calculer la précision, le rappel et le F1 score optimaux
optimal_precision, optimal_recall, optimal_f1, optimal_threshold = self._calculate_optimal_f1(
precision, recall, thresholds
)
# Construire le dictionnaire des métriques scalaires disponibles
metrics = {
f"{prefix}_accuracy" : acc,
f"{prefix}_precision" : optimal_precision,
f"{prefix}_recall" : optimal_recall,
f"{prefix}_f1" : optimal_f1,
f"{prefix}_optimal_threshold" : optimal_threshold,
f"{prefix}_auc_roc" : auc
}
# Retourner les métriques scalaires calculées
return metrics
def _calculate_optimal_f1(
self,
precision: cp.ndarray,
recall: cp.ndarray,
thresholds: cp.ndarray
) -> Tuple[float, float, float, float]:
"""
Calcule le F1 score optimal à partir des courbes de précision et de rappel.
Args:
precision: Tableau de précisions pour différents seuils
recall: Tableau de rappels pour différents seuils
thresholds: Tableau de seuils correspondants
Returns:
Tuple contenant (précision optimale, rappel optimal, F1 score optimal, seuil optimal)
"""
# Ajouter le seuil 1.0 à thresholds (qui n'est pas inclus par défaut dans precision_recall_curve)
thresholds_with_one = cp.append(thresholds, cp.array([1.0]))
# Calculer le F1 score pour chaque point de la courbe
# F1 = 2 * (precision * recall) / (precision + recall)
f1_scores = 2 * (precision * recall) / (precision + recall)
# Trouver l'indice du F1 score maximal
best_idx = cp.argmax(f1_scores)
best_precision = float(precision[best_idx])
best_recall = float(recall[best_idx])
best_f1 = float(f1_scores[best_idx])
# Obtenir le seuil optimal
best_threshold = float(thresholds_with_one[best_idx])
return best_precision, best_recall, best_f1, best_threshold