emotion_classifier / src /interfaces /metrics_calculator.py
fioriclass's picture
correction bug config
43d4438
raw
history blame
3.56 kB
import cupy as cp
import numpy as np
from typing import Dict, Protocol
import logging
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
class MetricsCalculator(Protocol):
"""
Interface pour les calculateurs de métriques.
"""
def calculate_and_log(
self,
y_true: cp.ndarray,
y_pred: cp.ndarray,
prefix: str
) -> Dict[str, float]:
"""
Calcule et log les métriques pour un problème binaire.
"""
pass
def calculate_and_log_multiclass(
self,
y_true: cp.ndarray,
y_pred: cp.ndarray,
prefix: str
) -> Dict[str, float]:
"""
Calcule et log les métriques pour un problème multiclasses.
"""
pass
logger = logging.getLogger(__name__)
class DefaultMetricsCalculator(MetricsCalculator):
"""
Implémentation concrète de MetricsCalculator qui calcule
accuracy, f1, precision, recall, et auc-roc.
Fonctionne pour binaire ou multiclasses (avec 'ovr' ou 'macro').
"""
def calculate_and_log(
self,
y_true: cp.ndarray,
y_pred: cp.ndarray,
prefix: str
) -> Dict[str, float]:
"""
Calcule et log les métriques pour un problème binaire
en supposant y_pred est dans {0,1} ou {True,False}.
"""
y_true_np = cp.asnumpy(y_true)
y_pred_np = cp.asnumpy(y_pred)
acc = accuracy_score(y_true_np, y_pred_np)
prec = precision_score(y_true_np, y_pred_np, zero_division=0)
rec = recall_score(y_true_np, y_pred_np, zero_division=0)
f1 = f1_score(y_true_np, y_pred_np, zero_division=0)
# Calcul AUC pour un problème binaire (si y_pred est 0/1)
# On treat y_pred_np as our "probabilities" only if truly 0/1.
# In a real pipeline you might store probabilities separately.
try:
auc = roc_auc_score(y_true_np, y_pred_np)
except ValueError:
auc = 0.0
metrics = {
f"{prefix}_accuracy" : acc,
f"{prefix}_precision" : prec,
f"{prefix}_recall" : rec,
f"{prefix}_f1" : f1,
f"{prefix}_auc_roc" : auc
}
logger.info(f"[{prefix}] Metrics: {metrics}")
return metrics
def calculate_and_log_multiclass(
self,
y_true: cp.ndarray,
y_pred: cp.ndarray,
prefix: str
) -> Dict[str, float]:
"""
Calcule et log les métriques pour un problème multiclasses.
AUC-ROC en mode 'macro' si possible.
"""
y_true_np = cp.asnumpy(y_true)
y_pred_np = cp.asnumpy(y_pred)
acc = accuracy_score(y_true_np, y_pred_np)
prec = precision_score(y_true_np, y_pred_np, average="macro", zero_division=0)
rec = recall_score(y_true_np, y_pred_np, average="macro", zero_division=0)
f1 = f1_score(y_true_np, y_pred_np, average="macro", zero_division=0)
# Pour le multiclasses, la roc_auc_score nécessite des scores proba
# ou "decision_function" => vous ajusterez selon votre cas.
# Ici, on met 0.0 en fallback.
auc = 0.0
metrics = {
f"{prefix}_accuracy" : acc,
f"{prefix}_precision" : prec,
f"{prefix}_recall" : rec,
f"{prefix}_f1" : f1,
f"{prefix}_auc_roc" : auc
}
logger.info(f"[{prefix}] Multiclass metrics: {metrics}")
return metrics