import logging
import math
from typing import Dict

import torch
from pandas import MultiIndex
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import DocumentMetric
from pytorch_ie.core.metric import T
from torchmetrics import Metric, MetricCollection

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


class CorefMetricsTorchmetrics(DocumentMetric):
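    """Computes a collection of torchmetrics metrics over the binary coreference relations of
    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations documents.

    Gold relation scores are converted to integer target indices and predicted relation scores
    are used as prediction scores. Argument pairs that occur only in the gold annotations or only
    in the predictions are filled with `default_target_idx` / `default_prediction_score`,
    respectively.
    """
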
    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, Metric],
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
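        """
        Args:
            metrics: mapping from metric name to the torchmetrics Metric instance to compute
            default_target_idx: target index used for argument pairs without a gold relation
            default_prediction_score: score used for argument pairs without a predicted relation
            show_as_markdown: if True, log the results as a markdown table after computation
            markdown_precision: number of decimal places used in the markdown table
            plot: if True, plot the metrics via matplotlib when the results are computed
        """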
        self.metrics = MetricCollection(metrics)
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot

        super().__init__()

    def reset(self) -> None:
        self.metrics.reset()

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
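        # Map each (head, tail) argument pair to its gold target index (gold relation scores are
        # assumed to be 0/1) and to its predicted score. Pairs that occur only on one side are
        # filled with the configured defaults below.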
        target_args2idx = {
            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
        }
        prediction_args2score = {
            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
        }
        all_args = set(target_args2idx) | set(prediction_args2score)
        all_targets = []
        all_predictions = []
        for args in all_args:
            target_idx = target_args2idx.get(args, self.default_target_idx)
            prediction_score = prediction_args2score.get(args, self.default_prediction_score)
            all_targets.append(target_idx)
            all_predictions.append(prediction_score)
        prediction_scores = torch.tensor(all_predictions)
        target_indices = torch.tensor(all_targets)
        self.metrics.update(preds=prediction_scores, target=target_indices)

    def do_plot(self):
        from matplotlib import pyplot as plt

        # Get the number of metrics
        num_metrics = len(self.metrics)

        # Calculate rows and columns for subplots (aim for a square-like layout)
        ncols = math.ceil(math.sqrt(num_metrics))
        nrows = math.ceil(num_metrics / ncols)

        # Create the subplots; squeeze=False always yields a 2D array of axes, so the
        # flattening below also works when there is only a single metric
        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10), squeeze=False)

        # Flatten the axes array into a plain list of axes
        ax_list = ax_list.flatten().tolist()

        # Ensure that we pass exactly the number of axes required by metrics
        ax_list = ax_list[:num_metrics]

        # Plot the metrics using the list of axes
        self.metrics.plot(ax=ax_list, together=False)

        # Adjust layout to avoid overlapping plots
        plt.tight_layout()
        plt.show()

    def _compute(self) -> T:
        if self.plot:
            self.do_plot()

        result = self.metrics.compute()

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # unstacking the last index level turns the Series into a DataFrame
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result
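

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the concrete metrics and settings below are
    # assumptions, not something this module prescribes. Any binary-classification metrics from
    # torchmetrics that accept float prediction scores and integer targets should work here.
    from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision

    metric = CorefMetricsTorchmetrics(
        metrics={"auroc": BinaryAUROC(), "average_precision": BinaryAveragePrecision()},
        show_as_markdown=True,
    )
    # Following the pytorch_ie DocumentMetric interface, the metric is then applied to documents
    # of type TextPairDocumentWithLabeledSpansAndBinaryCorefRelations (e.g. via `metric(documents)`),
    # which routes each document through `_update` and finally returns the result of `_compute`.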