|
import logging |
|
import math |
|
from typing import Dict |
|
|
|
import torch |
|
from pandas import MultiIndex |
|
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations |
|
from pytorch_ie import DocumentMetric |
|
from pytorch_ie.core.metric import T |
|
from torchmetrics import Metric, MetricCollection |
|
|
|
from src.hydra_callbacks.save_job_return_value import to_py_obj |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class CorefMetricsTorchmetrics(DocumentMetric):
    """Score binary coreference relations with a collection of torchmetrics metrics.

    For each document, the gold relations and the predicted relations of the
    ``binary_coref_relations`` layer are aligned on their ``(head, tail)`` pair.
    The gold score (cast to ``int``) is used as target and the predicted score as
    prediction. Pairs that are present on only one side are filled up with
    ``default_target_idx`` or ``default_prediction_score``, respectively.

    Args:
        metrics: Mapping from metric name to metric instance; it is wrapped into
            a ``MetricCollection``.
        default_target_idx: Target value for pairs that were predicted, but have
            no gold relation.
        default_prediction_score: Prediction score for pairs that have a gold
            relation, but were not predicted.
        show_as_markdown: If True, ``_compute`` logs the result as a markdown table.
        markdown_precision: Number of decimals the values are rounded to in the
            markdown table.
        plot: If True, ``_compute`` calls ``do_plot`` before computing the result.
    """

    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, Metric],
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = MetricCollection(metrics)
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot

        # NOTE(review): the attributes are deliberately assigned before calling the
        # base constructor; presumably it calls reset(), which needs self.metrics
        # - confirm before reordering.
        super().__init__()

    def reset(self) -> None:
        """Reset the state of all wrapped metrics."""
        self.metrics.reset()

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
        """Feed the aligned gold and predicted relation scores of one document to the metrics.

        Args:
            document: The document whose ``binary_coref_relations`` layer (gold
                annotations and ``.predictions``) is evaluated.
        """
        relations = document.binary_coref_relations
        target_args2idx = {(rel.head, rel.tail): int(rel.score) for rel in relations}
        prediction_args2score = {(rel.head, rel.tail): rel.score for rel in relations.predictions}

        # Evaluate the union of gold and predicted pairs. The set is not modified
        # in between, so both comprehensions visit the pairs in the same order.
        all_args = set(target_args2idx) | set(prediction_args2score)
        all_targets = [target_args2idx.get(args, self.default_target_idx) for args in all_args]
        all_predictions = [
            prediction_args2score.get(args, self.default_prediction_score) for args in all_args
        ]

        prediction_scores = torch.tensor(all_predictions)
        # Pin the dtype: an empty list (document without any relation) would
        # otherwise be inferred as a float tensor instead of integer targets.
        target_indices = torch.tensor(all_targets, dtype=torch.long)
        self.metrics.update(preds=prediction_scores, target=target_indices)

    def do_plot(self):
        """Plot each wrapped metric into its own subplot and show the figure.

        The subplots are arranged in a grid that is as close to a square as
        possible. Nothing is plotted if the collection contains no metrics.
        """
        from matplotlib import pyplot as plt

        num_metrics = len(self.metrics)
        if num_metrics == 0:
            # Also guards the grid computation below against a division by zero.
            logger.warning("There are no metrics to plot.")
            return

        ncols = math.ceil(math.sqrt(num_metrics))
        nrows = math.ceil(num_metrics / ncols)

        # squeeze=False always returns a 2D array of axes. With the default, a
        # 1x1 grid yields a single Axes object that has no flatten().
        _, axes_grid = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10), squeeze=False)

        # The grid may have more cells than metrics; pass exactly one axis per metric.
        axes = axes_grid.flatten().tolist()[:num_metrics]
        self.metrics.plot(ax=axes, together=False)

        plt.tight_layout()
        plt.show()

    def _compute(self) -> T:
        """Compute all metrics and return the result converted via ``to_py_obj``.

        If ``plot`` is set, the metrics are plotted first. If ``show_as_markdown``
        is set, the result is additionally logged as a markdown table.

        Returns:
            The output of ``MetricCollection.compute()`` after passing it
            through ``to_py_obj``.
        """
        if self.plot:
            self.do_plot()

        result = to_py_obj(self.metrics.compute())

        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # Turn the last index level into columns to get a 2D table.
                    series = series.unstack(-1)
                else:
                    # A MultiIndex with a single level: flatten it to a plain index.
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )

        return result
|
|