# ArneBinder's picture
# https://github.com/ArneBinder/pie-document-level/pull/312
# 3133b5e verified
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from pandas import MultiIndex
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import DocumentMetric
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target
from torchmetrics import Metric, MetricCollection
from src.hydra_callbacks.save_job_return_value import to_py_obj
logger = logging.getLogger(__name__)
def get_num_total(targets: List[int], preds: List[float]):
    """Return the number of entries in ``targets``.

    ``preds`` is part of the signature but is not used; only ``targets``
    is counted.
    """
    total = len(targets)
    return total
def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
    """Count the entries of ``targets`` that equal ``positive_idx``.

    ``preds`` is part of the signature but is not used.
    """
    return sum(1 for target in targets if target == positive_idx)
def discretize(
    values: List[float], threshold: Union[float, List[float], dict]
) -> Union[List[float], Dict[Any, List[float]]]:
    """Binarize scores against one or several thresholds.

    Args:
        values: Scores to discretize.
        threshold: One of

            - a number (int or float): every value ``>= threshold`` becomes 1,
              every other value becomes 0;
            - a list of thresholds: each entry is applied on its own;
            - a dict with the keys ``"start"``, ``"end"`` and ``"step"``: expanded
              with ``np.arange`` (so ``"end"`` itself is excluded) and rounded to
              4 decimals before being applied like a list.

    Returns:
        For a single threshold, a list with one 0/1 int per value. For a list or
        dict of thresholds, a dict mapping each threshold to such a list.

    Raises:
        TypeError: If ``threshold`` is none of the supported types.
    """
    # Accept ints as well as floats (e.g. a config value of ``1``). bool is a
    # subclass of int but is kept out on purpose: it still raises TypeError below.
    if isinstance(threshold, (int, float)) and not isinstance(threshold, bool):
        return (np.array(values) >= threshold).astype(int).tolist()
    if isinstance(threshold, list):
        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
    if isinstance(threshold, dict):
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)
    raise TypeError(f"threshold has unknown type: {threshold}")
class CorefMetricsSKLearn(DocumentMetric):
    """Document metric for binary coreference relations backed by plain callables.

    Targets and prediction scores are collected over all documents passed to
    ``_update``. On ``_compute`` each configured metric is called as
    ``metric(targets, preds)``; judging by the class name these are presumably
    scikit-learn style metric functions (TODO confirm against the configs).

    Args:
        metrics: Maps a result name to a target string that is resolved to a
            callable via ``resolve_target``.
        thresholds: Optional mapping from metric name to a threshold spec as
            accepted by ``discretize`` (single number, list of numbers, or a
            ``start``/``end``/``step`` dict). Metrics without an entry receive the
            raw prediction scores. A warning is logged for entries whose name is
            not in ``metrics``.
        default_target_idx: Target value used for argument pairs that only occur
            in the predictions.
        default_prediction_score: Score used for argument pairs that only occur
            in the gold relations.
        show_as_markdown: If true, log the computed results as a markdown table.
        markdown_precision: Number of decimals used when rounding for the table.
        plot: If true, ``_compute`` calls ``do_plot``, which is not implemented
            and therefore raises ``NotImplementedError``.
    """

    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, str],
        thresholds: Optional[Dict[str, Union[float, List[float], dict]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if len(thresholds_not_in_metrics) > 0:
            logger.warning(
                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
            )
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot
        # NOTE(review): the base constructor is called last, as in the original;
        # presumably it invokes reset() to create the state lists - confirm.
        super().__init__()

    def reset(self) -> None:
        """Clear the accumulated prediction scores and targets."""
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
        """Collect aligned targets and prediction scores from one document.

        Relations are keyed by their ``(head, tail)`` pair. Pairs present on only
        one side (gold or predicted) are filled with ``default_prediction_score``
        or ``default_target_idx`` respectively.
        """
        # the gold relation score, cast to int, serves as the target label
        target_args2idx = {
            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
        }
        prediction_args2score = {
            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
        }
        all_args = set(target_args2idx) | set(prediction_args2score)
        all_targets: List[int] = []
        all_predictions: List[float] = []
        # Both lists are filled in the same iteration, so they stay aligned even
        # though the set iteration order is arbitrary.
        for args in all_args:
            all_targets.append(target_args2idx.get(args, self.default_target_idx))
            all_predictions.append(
                prediction_args2score.get(args, self.default_prediction_score)
            )
        self._preds.extend(all_predictions)
        self._targets.extend(all_targets)

    def do_plot(self):
        """Plotting hook; not supported by this metric.

        ``self.metrics`` is a plain dict of callables and offers no plotting
        facility, so this always fails.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support plotting; set plot=False to compute the metrics"
        )

    def _compute(self) -> T:
        """Evaluate all configured metrics on the accumulated state.

        For a metric with several thresholds only the best (maximum) result is
        kept, stored under the key ``"<name>-<threshold>"``. All other metrics
        are stored under their plain name.

        Returns:
            The results after conversion with ``to_py_obj``.

        Raises:
            NotImplementedError: If ``plot`` is enabled.
            ValueError: If a threshold spec expands to no threshold at all.
        """
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds = discretize(values=self._preds, threshold=self.thresholds[name])
            else:
                preds = self._preds
            if isinstance(preds, dict):
                if not preds:
                    # would otherwise surface as "max() arg is an empty sequence"
                    raise ValueError(
                        f"thresholds for metric '{name}' do not contain any value: "
                        f"{self.thresholds[name]}"
                    )
                metric_results = {
                    t: metric(self._targets, t_preds) for t, t_preds in preds.items()
                }
                # just get the max
                max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
                result[f"{name}-{max_t}"] = max_v
            else:
                result[name] = metric(self._targets, preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result