"""Document-level coreference metrics computed with plain ``metric(targets, preds)`` callables."""

import logging
import math
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from pandas import MultiIndex
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import DocumentMetric
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target
from torchmetrics import Metric, MetricCollection

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


def get_num_total(targets: List[int], preds: List[float]):
    """Return the number of target entries.

    ``preds`` is unused; it is accepted so the function has the same
    ``(targets, preds)`` call shape as the metrics called in
    ``CorefMetricsSKLearn._compute``.
    """
    return len(targets)


def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
    """Return how many entries of ``targets`` equal ``positive_idx``.

    ``preds`` is unused; it is accepted so the function has the same
    ``(targets, preds)`` call shape as the metrics called in
    ``CorefMetricsSKLearn._compute``.
    """
    return sum(1 for value in targets if value == positive_idx)


def discretize(
    values: List[float], threshold: Union[float, List[float], dict]
) -> Union[List[float], Dict[Any, List[float]]]:
    """Binarize ``values`` against one or several thresholds.

    Args:
        values: The scores to binarize.
        threshold: One of
            * a number: each value becomes ``1`` if ``value >= threshold``, else ``0``;
            * a list of thresholds: the result is a dict mapping each threshold to the
              values binarized with it;
            * a dict with the keys ``"start"``, ``"end"`` and ``"step"``: the thresholds
              are ``np.arange(start, end, step)`` rounded to 4 decimals, handled like
              a list of thresholds.

    Returns:
        A list of 0/1 integers for a single threshold, otherwise a dict from threshold
        to such a list.

    Raises:
        TypeError: If ``threshold`` is none of the supported types (this includes bool).
    """
    # Accept any real number, not just float: integer thresholds are legitimate and
    # also arise from np.arange when start, end and step are all integers.
    # bool is excluded explicitly because it is a subclass of int.
    if isinstance(threshold, (int, float)) and not isinstance(threshold, bool):
        return (np.array(values) >= threshold).astype(int).tolist()
    if isinstance(threshold, list):
        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
    if isinstance(threshold, dict):
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)
    raise TypeError(f"threshold has unknown type: {threshold}")


class CorefMetricsSKLearn(DocumentMetric):
    """Collect coreference targets and prediction scores and evaluate them at compute time.

    Per document, every (head, tail) argument pair that occurs in the gold or in the
    predicted binary coref relations contributes one target / prediction entry. All
    configured metrics are evaluated on the accumulated lists in ``_compute``.

    Args:
        metrics: Maps a result name to a metric specification that is passed to
            ``resolve_target``. The resolved object is called as
            ``metric(targets, preds)`` (presumably the specification is an import
            path of a scikit-learn style function; verify against the configs).
        thresholds: Optional mapping from metric name to a threshold specification
            as accepted by ``discretize``. Predictions for metrics listed here are
            binarized before the metric is called. If the specification yields
            several thresholds, only the best result is reported, under the key
            ``"<name>-<threshold>"``.
        default_target_idx: Target used for pairs that only occur in the predictions.
        default_prediction_score: Score used for pairs that only occur in the targets.
        show_as_markdown: If set, log the results as a markdown table.
        markdown_precision: Number of decimals used when rounding for that table.
        plot: If set, ``_compute`` calls ``do_plot`` (which is not implemented).
    """

    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, str],
        thresholds: Optional[Dict[str, Union[float, List[float], dict]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if thresholds_not_in_metrics:
            logger.warning(
                "there are discretizing thresholds that do not have a metric: %s",
                thresholds_not_in_metrics,
            )
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot
        # NOTE(review): the base initializer is deliberately called last, as in the
        # original code; presumably it calls reset() -- confirm before reordering.
        super().__init__()

    def reset(self) -> None:
        """Drop all accumulated prediction scores and targets."""
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
        """Append the targets and prediction scores of ``document`` to the state.

        Gold relations provide the target (their score cast to int), predicted
        relations provide the prediction score. Pairs missing on one side get the
        configured default for that side.
        """
        target_args2idx = {
            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
        }
        prediction_args2score = {
            (rel.head, rel.tail): rel.score
            for rel in document.binary_coref_relations.predictions
        }
        # Set iteration order is arbitrary, but targets and predictions are appended
        # pairwise, so the two lists stay aligned.
        all_args = set(target_args2idx) | set(prediction_args2score)
        for args in all_args:
            self._targets.append(target_args2idx.get(args, self.default_target_idx))
            self._preds.append(
                prediction_args2score.get(args, self.default_prediction_score)
            )

    def do_plot(self):
        """Plot the metrics. Not implemented.

        The former (unreachable) draft called ``self.metrics.plot(...)``, but
        ``self.metrics`` is a plain dict of callables and has no ``plot`` method.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"plotting is not implemented for {type(self).__name__}"
        )

    def _compute(self) -> T:
        """Evaluate all metrics on the accumulated targets and predictions.

        Returns:
            A dict from result name to metric value, converted with ``to_py_obj``.
            For a metric with several thresholds only the maximum value is kept,
            stored under ``"<name>-<threshold>"``.

        Raises:
            NotImplementedError: If ``plot`` was enabled.
            ValueError: If a threshold specification yields no thresholds at all.
        """
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds = discretize(values=self._preds, threshold=self.thresholds[name])
            else:
                preds = self._preds

            if not isinstance(preds, dict):
                result[name] = metric(self._targets, preds)
                continue

            if not preds:
                # Without this guard max() below fails with an unhelpful
                # "max() arg is an empty sequence".
                raise ValueError(
                    f"thresholds for metric '{name}' do not yield any value: "
                    f"{self.thresholds[name]}"
                )
            metric_results = {t: metric(self._targets, t_preds) for t, t_preds in preds.items()}
            # Report only the best threshold together with its value.
            max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
            result[f"{name}-{max_t}"] = max_v

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                "%s\n%s",
                self.current_split,
                series.round(self.markdown_precision).to_markdown(),
            )
        return result