|
import logging |
|
import math |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
from pandas import MultiIndex |
|
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations |
|
from pytorch_ie import DocumentMetric |
|
from pytorch_ie.core.metric import T |
|
from pytorch_ie.utils.hydra import resolve_target |
|
from torchmetrics import Metric, MetricCollection |
|
|
|
from src.hydra_callbacks.save_job_return_value import to_py_obj |
|
|
|
# Module-level logger, named after this module so it inherits the package's logging config.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def get_num_total(targets: List[int], preds: List[float]):
    """Return how many target entries there are.

    ``preds`` is not used. Presumably it is accepted so that this function can be used
    as a metric callable with the ``(targets, preds)`` signature.
    """
    num_targets = len(targets)
    return num_targets
|
|
|
|
|
def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
    """Count the targets equal to ``positive_idx``.

    ``preds`` is not used. Presumably it is accepted so that this function can be used
    as a metric callable with the ``(targets, preds)`` signature.
    """
    return sum(1 for label in targets if label == positive_idx)
|
|
|
|
|
def discretize(
    values: List[float], threshold: Union[float, List[float], dict]
) -> Union[List[int], Dict[Any, List[int]]]:
    """Binarize scores against one or more thresholds.

    Args:
        values: Scores to binarize.
        threshold: One of
            * a single number (int or float, including numpy scalars): every value
              ``>= threshold`` becomes 1, every other value 0;
            * a list or tuple of numbers: binarize once per threshold;
            * a dict with keys ``"start"``, ``"end"`` and ``"step"``: thresholds are
              generated with ``np.arange(start, end, step)`` (``end`` is excluded),
              rounded to 4 decimals to hide float noise, then handled like a list.

    Returns:
        For a single number, a list of 0/1 ints with the same length as ``values``.
        Otherwise, a dict mapping each threshold to such a list.

    Raises:
        TypeError: If ``threshold`` is none of the supported types (``bool`` is
            rejected explicitly, although it is an ``int`` subclass).
    """
    is_number = isinstance(
        threshold, (int, float, np.integer, np.floating)
    ) and not isinstance(threshold, bool)
    if is_number:
        return (np.asarray(values) >= threshold).astype(int).tolist()

    if isinstance(threshold, (list, tuple)):
        return {t: discretize(values=values, threshold=t) for t in threshold}

    if isinstance(threshold, dict):
        # Integer start/end/step produce int thresholds, which the number branch accepts.
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)

    raise TypeError(f"threshold has unknown type: {threshold}")
|
|
|
|
|
class CorefMetricsSKLearn(DocumentMetric):
    """Compute score-based metrics for binary coreference relations.

    Alignment per document:
        * Gold relations (``document.binary_coref_relations``) and predicted relations
          (``document.binary_coref_relations.predictions``) are aligned by their
          ``(head, tail)`` arguments.
        * Every argument pair contributes one target (the gold relation's ``score`` cast
          to ``int``) and one prediction score.
        * Pairs that occur on only one side are filled with ``default_target_idx`` or
          ``default_prediction_score``.

    At compute time, every configured metric is called as ``metric(targets, preds)``.
    The scores may first be binarized with ``discretize``.
    """

    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, str],
        thresholds: Optional[Dict[str, Union[float, List[float], Dict[str, float]]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        """
        Args:
            metrics: Mapping from metric name to a target string. ``resolve_target`` is
                expected to turn it into a callable ``(targets, preds) -> value``
                (presumably the string is a dotted import path).
            thresholds: Optional mapping from metric name to a threshold specification
                accepted by ``discretize`` (a number, a list of numbers, or a dict
                with ``start``/``end``/``step``). When present for a metric, the scores
                are binarized before that metric is called. When the specification yields
                several thresholds, only the best result is reported (see ``_compute``).
            default_target_idx: Target used for predicted pairs that have no gold relation.
            default_prediction_score: Score used for gold pairs that have no prediction.
            show_as_markdown: If True, ``_compute`` logs the results as a markdown table.
            markdown_precision: Number of decimals shown in the markdown table.
            plot: If True, ``_compute`` calls ``do_plot``, which is not implemented and
                raises ``NotImplementedError``.
        """
        self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if len(thresholds_not_in_metrics) > 0:
            logger.warning(
                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
            )
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot

        # NOTE(review): `_preds` / `_targets` are created only in `reset()`; presumably the
        # base class calls it (e.g. from its `__init__`) — confirm.
        super().__init__()

    def reset(self) -> None:
        """Drop all scores and targets collected by previous ``_update`` calls."""
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
        """Collect one (target, score) pair per relation argument pair of ``document``.

        Targets and scores come from these sources:
            * Gold relations provide the targets (their ``score`` cast to ``int``).
            * Predicted relations provide the scores.
            * The union of both argument-pair sets is used. A pair missing on one side
              gets ``default_target_idx`` or ``default_prediction_score``, respectively.
        """
        target_args2idx = {
            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
        }
        prediction_args2score = {
            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
        }
        # Set iteration order is arbitrary, but targets and scores are appended in
        # lockstep, so the i-th target always belongs to the i-th score.
        all_args = set(target_args2idx) | set(prediction_args2score)
        all_targets: List[int] = []
        all_predictions: List[float] = []
        for args in all_args:
            all_targets.append(target_args2idx.get(args, self.default_target_idx))
            all_predictions.append(
                prediction_args2score.get(args, self.default_prediction_score)
            )

        self._preds.extend(all_predictions)
        self._targets.extend(all_targets)

    def do_plot(self):
        """Plot the metric results. Not implemented.

        Raises:
            NotImplementedError: Always.
        """
        # NOTE: `self.metrics` is a plain dict of callables (see `__init__`), so it has no
        # `plot` method to delegate to. A real implementation must plot computed values.
        raise NotImplementedError("plotting is not implemented for CorefMetricsSKLearn")

    def _compute(self) -> T:
        """Evaluate all metrics on the collected targets and prediction scores.

        Returns:
            The result of ``to_py_obj`` applied to a dict that maps metric names to values.
            For a metric whose threshold specification yields several thresholds, the key
            is ``"{name}-{best_threshold}"`` and the value is the maximum result over all
            thresholds. On ties, the first threshold wins.

        Raises:
            NotImplementedError: If ``plot`` was enabled.
            ValueError: If a metric's threshold specification yields no thresholds.
        """
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds = discretize(values=self._preds, threshold=self.thresholds[name])
            else:
                preds = self._preds

            if isinstance(preds, dict):
                metric_results = {
                    t: metric(self._targets, t_preds) for t, t_preds in preds.items()
                }
                if not metric_results:
                    raise ValueError(
                        f"threshold specification for metric '{name}' yields no thresholds: "
                        f"{self.thresholds[name]}"
                    )
                # Report only the best-scoring threshold for this metric.
                max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
                result[f"{name}-{max_t}"] = max_v
            else:
                result[name] = metric(self._targets, preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            self._log_as_markdown(result)
        return result

    def _log_as_markdown(self, result: Any) -> None:
        """Log ``result`` as a markdown table, prefixed with the current split name."""
        import pandas as pd

        series = pd.Series(result)
        if isinstance(series.index, MultiIndex):
            if len(series.index.levels) > 1:
                # Move the innermost key level into columns to get a 2D table.
                series = series.unstack(-1)
            else:
                series.index = series.index.get_level_values(0)
        logger.info(
            f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
        )
|
|