import logging
import math
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from pandas import MultiIndex
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import DocumentMetric
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target
from torchmetrics import Metric, MetricCollection

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


def get_num_total(targets: List[int], preds: List[float]):
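    """Return the total number of examples; `preds` is accepted but unused so that this
    helper can be called with the same ``metric(targets, preds)`` signature as the others."""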
    return len(targets)


def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
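    """Return the number of examples whose target equals ``positive_idx``; `preds` is unused."""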
    return len([v for v in targets if v == positive_idx])


def discretize(
    values: List[float], threshold: Union[float, List[float], dict]
) -> Union[List[float], Dict[Any, List[float]]]:
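    """Binarize a list of scores against one or more thresholds.

    A single float threshold yields one binarized list; a list of thresholds (or a dict with
    "start", "end" and "step" keys describing a range) yields a mapping from each threshold
    to its binarized list, e.g.:

        discretize([0.2, 0.4, 0.8], threshold=0.5)         -> [0, 0, 1]
        discretize([0.2, 0.4, 0.8], threshold=[0.3, 0.5])  -> {0.3: [0, 1, 1], 0.5: [0, 0, 1]}
    """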
    if isinstance(threshold, float):
        result = (np.array(values) >= threshold).astype(int).tolist()
        return result
    if isinstance(threshold, list):
        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
    if isinstance(threshold, dict):
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)
    raise TypeError(f"threshold has unknown type: {threshold}")


class CorefMetricsSKLearn(DocumentMetric):
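    """DocumentMetric that accumulates predicted coreference scores and gold labels over
    text-pair documents and evaluates them with arbitrary (e.g. sklearn) metric functions of
    the form ``metric(targets, preds)``. Metrics listed in ``thresholds`` are computed on
    binarized scores; for a list or range of thresholds only the best value is reported."""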
    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, str],
        thresholds: Optional[Dict[str, Union[float, List[float], dict]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
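        """Args:
            metrics: mapping from metric name to a target string that ``resolve_target`` can
                resolve to a callable with signature ``metric(targets, preds)``, e.g.
                ``"sklearn.metrics.f1_score"``.
            thresholds: optional mapping from metric name to a discretization threshold (a
                float, a list of floats, or a dict with "start"/"end"/"step"); prediction
                scores are binarized with it before the metric is computed.
            default_target_idx: target index used for predicted pairs without a gold relation.
            default_prediction_score: score used for gold pairs without a prediction.
            show_as_markdown: if True, log the results as a markdown table.
            markdown_precision: number of decimals shown in the markdown table.
            plot: if True, call ``do_plot()`` before computing results (currently raises
                NotImplementedError).
        """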
        self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if len(thresholds_not_in_metrics) > 0:
            logger.warning(
                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
            )
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot
        super().__init__()
    def reset(self) -> None:
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
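        """Collect the gold label and predicted score for every (head, tail) span pair of the
        document. Pairs that occur only on one side fall back to ``default_target_idx`` or
        ``default_prediction_score``, respectively."""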
        target_args2idx = {
            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
        }
        prediction_args2score = {
            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
        }
        all_args = set(target_args2idx) | set(prediction_args2score)
        all_targets: List[int] = []
        all_predictions: List[float] = []
        for args in all_args:
            target_idx = target_args2idx.get(args, self.default_target_idx)
            prediction_score = prediction_args2score.get(args, self.default_prediction_score)
            all_targets.append(target_idx)
            all_predictions.append(prediction_score)
        # prediction_scores = torch.tensor(all_predictions)
        # target_indices = torch.tensor(all_targets)
        # self.metrics.update(preds=prediction_scores, target=target_indices)
        self._preds.extend(all_predictions)
        self._targets.extend(all_targets)
    def do_plot(self):
        raise NotImplementedError()

        # NOTE: the code below is unreachable. It appears to have been written for a
        # torchmetrics MetricCollection (which provides a .plot() method), whereas
        # self.metrics is a plain dict of callables here, so it would need to be adapted
        # before re-enabling it.
        from matplotlib import pyplot as plt

        # Get the number of metrics
        num_metrics = len(self.metrics)
        # Calculate rows and columns for subplots (aim for a square-like layout)
        ncols = math.ceil(math.sqrt(num_metrics))
        nrows = math.ceil(num_metrics / ncols)
        # Create the subplots
        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
        # Flatten the ax_list if necessary (in case of multiple rows/columns)
        ax_list = ax_list.flatten().tolist()  # Ensure it's a list, and flatten it if necessary
        # Ensure that we pass exactly the number of axes required by metrics
        ax_list = ax_list[:num_metrics]
        # Plot the metrics using the list of axes
        self.metrics.plot(ax=ax_list, together=False)
        # Adjust layout to avoid overlapping plots
        plt.tight_layout()
        plt.show()
    def _compute(self) -> T:
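        """Compute all configured metrics over the accumulated targets and prediction scores.
        For metrics with a list or range of thresholds, only the best value is kept and the
        winning threshold is appended to the metric name (e.g. "f1-0.4")."""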
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds = discretize(values=self._preds, threshold=self.thresholds[name])
            else:
                preds = self._preds
            if isinstance(preds, dict):
                metric_results = {
                    t: metric(self._targets, t_preds) for t, t_preds in preds.items()
                }
                # report only the best threshold and its value
                max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
                result[f"{name}-{max_t}"] = max_v
            else:
                result[name] = metric(self._targets, preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result
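
# Illustrative configuration (a sketch, assuming the sklearn metrics are referenced by their
# dotted import paths, which resolve_target can resolve to callables):
#
#   metric = CorefMetricsSKLearn(
#       metrics={
#           "auroc": "sklearn.metrics.roc_auc_score",
#           "f1": "sklearn.metrics.f1_score",
#       },
#       thresholds={"f1": {"start": 0.1, "end": 1.0, "step": 0.1}},
#       show_as_markdown=True,
#   )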