import logging
from typing import Dict, Protocol

import cupy as cp
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

logger = logging.getLogger(__name__)


class MetricsCalculator(Protocol):
    """Interface for metrics calculators."""

    def calculate_and_log(
        self,
        y_true: cp.ndarray,
        y_pred: cp.ndarray,
        prefix: str,
    ) -> Dict[str, float]:
        """Compute and log metrics for a binary problem."""
        ...

    def calculate_and_log_multiclass(
        self,
        y_true: cp.ndarray,
        y_pred: cp.ndarray,
        prefix: str,
    ) -> Dict[str, float]:
        """Compute and log metrics for a multiclass problem."""
        ...


class DefaultMetricsCalculator(MetricsCalculator):
    """
    Concrete MetricsCalculator that computes accuracy, precision,
    recall, F1, and ROC AUC, for binary or multiclass problems
    (macro averaging in the multiclass case).
    """

    def calculate_and_log(
        self,
        y_true: cp.ndarray,
        y_pred: cp.ndarray,
        prefix: str,
    ) -> Dict[str, float]:
        """
        Compute and log metrics for a binary problem, assuming
        y_pred is in {0, 1} or {True, False}.
        """
        # scikit-learn runs on the CPU, so copy the arrays off the GPU.
        y_true_np = cp.asnumpy(y_true)
        y_pred_np = cp.asnumpy(y_pred)

        acc = accuracy_score(y_true_np, y_pred_np)
        prec = precision_score(y_true_np, y_pred_np, zero_division=0)
        rec = recall_score(y_true_np, y_pred_np, zero_division=0)
        f1 = f1_score(y_true_np, y_pred_np, zero_division=0)

        # roc_auc_score raises ValueError when y_true contains only one
        # class; report 0.0 in that degenerate case.
        try:
            auc = roc_auc_score(y_true_np, y_pred_np)
        except ValueError:
            auc = 0.0

        metrics = {
            f"{prefix}_accuracy": acc,
            f"{prefix}_precision": prec,
            f"{prefix}_recall": rec,
            f"{prefix}_f1": f1,
            f"{prefix}_auc_roc": auc,
        }
        logger.info(f"[{prefix}] Metrics: {metrics}")
        return metrics

    def calculate_and_log_multiclass(
        self,
        y_true: cp.ndarray,
        y_pred: cp.ndarray,
        prefix: str,
    ) -> Dict[str, float]:
        """
        Compute and log macro-averaged metrics for a multiclass
        problem. ROC AUC is reported as 0.0: computing it for
        multiclass data requires per-class probability scores,
        which hard label predictions do not provide.
        """
        y_true_np = cp.asnumpy(y_true)
        y_pred_np = cp.asnumpy(y_pred)

        acc = accuracy_score(y_true_np, y_pred_np)
        prec = precision_score(y_true_np, y_pred_np, average="macro", zero_division=0)
        rec = recall_score(y_true_np, y_pred_np, average="macro", zero_division=0)
        f1 = f1_score(y_true_np, y_pred_np, average="macro", zero_division=0)

        # Placeholder: multiclass roc_auc_score needs an
        # (n_samples, n_classes) score matrix, not hard labels.
        auc = 0.0

        metrics = {
            f"{prefix}_accuracy": acc,
            f"{prefix}_precision": prec,
            f"{prefix}_recall": rec,
            f"{prefix}_f1": f1,
            f"{prefix}_auc_roc": auc,
        }
        logger.info(f"[{prefix}] Multiclass metrics: {metrics}")
        return metrics
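# A hypothetical variant, sketched here for illustration only: if per-class
# probability scores were available (an (n_samples, n_classes) array, e.g.
# from a classifier's predict_proba), the multiclass ROC AUC above could be
# computed with scikit-learn's one-vs-rest mode instead of being hard-coded
# to 0.0. `_macro_auc_from_scores` is an assumed name, not part of this module.
def _macro_auc_from_scores(y_true: cp.ndarray, y_score: cp.ndarray) -> float:
    """Macro-averaged one-vs-rest ROC AUC from probability scores."""
    try:
        return float(
            roc_auc_score(
                cp.asnumpy(y_true),
                cp.asnumpy(y_score),
                multi_class="ovr",
                average="macro",
            )
        )
    except ValueError:
        # Mirror the binary path: degrade to 0.0 when AUC is undefined.
        return 0.0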
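# Usage sketch (illustrative only; the label arrays below are made up):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    calculator = DefaultMetricsCalculator()

    # Binary case: hard 0/1 predictions.
    y_true_bin = cp.array([0, 1, 1, 0, 1])
    y_pred_bin = cp.array([0, 1, 0, 0, 1])
    calculator.calculate_and_log(y_true_bin, y_pred_bin, prefix="val")

    # Multiclass case: hard class labels, so AUC is reported as 0.0 here.
    y_true_mc = cp.array([0, 1, 2, 2, 1, 0])
    y_pred_mc = cp.array([0, 2, 2, 2, 1, 0])
    calculator.calculate_and_log_multiclass(y_true_mc, y_pred_mc, prefix="val")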