File size: 7,567 Bytes
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import copy
from itertools import chain
from typing import Dict, Optional, Sequence, Type

import torch
from pie_modules.annotations import BinaryCorefRelation
from pie_modules.document.processing.text_pair import shift_span
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory
from pie_modules.taskmodules.re_text_classification_with_indices import (
    ModelTargetType as REModelTargetType,
)
from pie_modules.taskmodules.re_text_classification_with_indices import (
    TaskOutputType as RETaskOutputType,
)
from pytorch_ie import Document, TaskModule
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations


class SharpBracketMarkerFactory(MarkerFactory):
    """Marker factory that renders markers in sharp brackets (``<...>``).

    Both overrides obtain the role part of the marker from
    ``self._get_role_marker(role)``, which is inherited from ``MarkerFactory``.
    """

    def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
        """Build a start or end marker for the given argument role.

        Args:
            role: argument role, passed to ``self._get_role_marker``.
            is_start: if False, a ``/`` is placed right after the opening bracket.
            label: optional label; when given it is appended as ``:<label>``.

        Returns:
            A string of the form ``<ROLE>``, ``</ROLE>``, ``<ROLE:label>`` or
            ``</ROLE:label>`` where ``ROLE`` is the role marker.
        """
        closing_slash = "" if is_start else "/"
        role_marker = self._get_role_marker(role)
        label_suffix = "" if label is None else f":{label}"
        # join (rather than an f-string) keeps the strict str-only concatenation semantics
        return "".join(["<", closing_slash, role_marker, label_suffix, ">"])

    def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
        """Build the append marker for the given argument role.

        Args:
            role: argument role, passed to ``self._get_role_marker``.
            label: optional label; when given it is attached as ``=<label>``.

        Returns:
            ``<ROLE>`` without a label, ``<ROLE=label>`` otherwise.
        """
        role_marker = self._get_role_marker(role)
        content = role_marker if label is None else f"{role_marker}={label}"
        return f"<{content}>"


@TaskModule.register()
class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
    RETextClassificationWithIndicesTaskModule
):
    """RE text classification task module with selectable marker bracket style.

    Args:
        use_sharp_marker: if True, ``get_marker_factory`` returns a
            ``SharpBracketMarkerFactory``; otherwise a plain ``MarkerFactory``.
        **kwargs: forwarded unchanged to ``RETextClassificationWithIndicesTaskModule``.
    """

    def __init__(self, use_sharp_marker: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.use_sharp_marker = use_sharp_marker

    def get_marker_factory(self) -> MarkerFactory:
        """Create the marker factory matching the ``use_sharp_marker`` setting.

        The factory is constructed with ``self.argument_role_to_marker`` as its
        ``role_to_marker`` mapping in both cases.
        """
        factory_cls = SharpBracketMarkerFactory if self.use_sharp_marker else MarkerFactory
        return factory_cls(role_to_marker=self.argument_role_to_marker)


def construct_text_document_from_text_pair_coref_document(
    document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
    glue_text: str,
    no_relation_label: str,
    relation_label_mapping: Optional[Dict[str, str]] = None,
    add_span_mapping_to_metadata: bool = False,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
    """Convert a text-pair coref document into a single-text relation document.

    If ``document.text`` equals ``document.text_pair``, the new document carries
    that text once and the spans of both layers are copied without shifting.
    Otherwise the new text is ``text + glue_text + text_pair`` and the spans of
    ``labeled_spans_pair`` are shifted by ``len(text) + len(glue_text)``.

    Args:
        document: the text-pair document to convert.
        glue_text: text inserted between ``text`` and ``text_pair`` when they differ.
        no_relation_label: label assigned to relations whose score is not
            greater than 0.0.
        relation_label_mapping: optional mapping applied to the relation label
            (after the score check); unmapped labels are kept as they are.
        add_span_mapping_to_metadata: if True, the original-to-new span mapping
            is stored in the new document's metadata under ``"span_mapping"``.

    Returns:
        The new document with the converted spans (sorted by start, end, label)
        and one binary relation with score 1.0 per original coref relation.
    """
    same_text = document.text == document.text_pair
    combined_text = document.text if same_text else document.text + glue_text + document.text_pair
    result = TextDocumentWithLabeledSpansAndBinaryRelations(
        text=combined_text,
        id=document.id,
        metadata=copy.deepcopy(document.metadata),
    )

    span_map: Dict[LabeledSpan, LabeledSpan] = {}
    if same_text:
        # A copied span can compare equal to a copy created earlier (e.g. from the
        # other span layer); in that case the first copy is re-used.
        canonical: Dict[LabeledSpan, LabeledSpan] = {}
        for source_span in chain(document.labeled_spans, document.labeled_spans_pair):
            candidate = source_span.copy()
            span_map[source_span] = canonical.setdefault(candidate, candidate)
    else:
        # NOTE(review): spans serve as dict keys here, so a span of labeled_spans_pair
        # that compares equal to a span of labeled_spans replaces the earlier mapping
        # entry -- confirm whether this can happen with the span types in use.
        for source_span in document.labeled_spans:
            span_map[source_span] = source_span.copy()
        pair_offset = len(document.text) + len(glue_text)
        for source_span in document.labeled_spans_pair:
            span_map[source_span] = shift_span(source_span.copy(), pair_offset)

    # a fixed sort order makes the span layer deterministic
    ordered_spans = list(span_map.values())
    ordered_spans.sort(key=lambda span: (span.start, span.end, span.label))
    result.labeled_spans.extend(ordered_spans)

    for coref_rel in document.binary_coref_relations:
        if coref_rel.score > 0.0:
            rel_label = coref_rel.label
        else:
            rel_label = no_relation_label
        if relation_label_mapping is not None:
            rel_label = relation_label_mapping.get(rel_label, rel_label)
        result.binary_relations.append(
            coref_rel.copy(
                head=span_map[coref_rel.head],
                tail=span_map[coref_rel.tail],
                label=rel_label,
                score=1.0,
            )
        )

    if add_span_mapping_to_metadata:
        result.metadata["span_mapping"] = span_map
    return result


@TaskModule.register()
class CrossTextBinaryCorefByRETextClassificationTaskModule(
    TaskModuleWithDocumentConverter,
    RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
):
    """Task module that handles cross-text binary coref as relation classification.

    Input documents are of type
    ``TextPairDocumentWithLabeledSpansAndBinaryCorefRelations``. Each one is
    converted into a single-text document with binary relations via
    ``construct_text_document_from_text_pair_coref_document``; predictions made
    on the converted document are written back to the original document as
    ``BinaryCorefRelation`` annotations.

    Args:
        coref_relation_label: relation label that the ``"coref"`` label is mapped
            to in the converted document; predictions must carry this label.
        relation_annotation: must be ``"binary_relations"``.
        probability_threshold: predicted relations with a score below this value
            are not written back to the original document.
        **kwargs: forwarded to the parent classes.

    Raises:
        ValueError: if ``relation_annotation`` is not ``"binary_relations"``.
    """

    def __init__(
        self,
        coref_relation_label: str,
        relation_annotation: str = "binary_relations",
        probability_threshold: float = 0.0,
        **kwargs,
    ):
        if relation_annotation != "binary_relations":
            raise ValueError(
                f"{type(self).__name__} requires relation_annotation='binary_relations', "
                f"but it is: {relation_annotation}"
            )
        super().__init__(relation_annotation=relation_annotation, **kwargs)
        self.probability_threshold = probability_threshold
        self.coref_relation_label = coref_relation_label

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """The document type this task module takes as input."""
        return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def _get_glue_text(self) -> str:
        """Return the glue token ids decoded to text by the tokenizer."""
        return self.tokenizer.decode(self._get_glue_token_ids())

    def _convert_document(
        self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
    ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
        """Convert a text-pair coref document into a single-text relation document.

        The span mapping is always stored in the converted document's metadata
        because ``_integrate_predictions_from_converted_document`` reads it.
        """
        return construct_text_document_from_text_pair_coref_document(
            document,
            glue_text=self._get_glue_text(),
            no_relation_label=self.none_label,
            relation_label_mapping={"coref": self.coref_relation_label},
            add_span_mapping_to_metadata=True,
        )

    def _integrate_predictions_from_converted_document(
        self,
        document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
        converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
    ) -> None:
        """Write predicted relations of the converted document back to the original.

        Each predicted relation is mapped back to the original spans via the
        ``"span_mapping"`` metadata entry and, if its score reaches
        ``self.probability_threshold``, appended to the predictions of
        ``document.binary_coref_relations`` with label ``"coref"``.

        Raises:
            ValueError: if a predicted relation has a label other than
                ``self.coref_relation_label``.
        """
        span_mapping = converted_document.metadata["span_mapping"]
        # invert the original -> converted mapping
        converted2original = dict(zip(span_mapping.values(), span_mapping.keys()))

        for predicted in converted_document.binary_relations.predictions:
            head = converted2original[predicted.head]
            tail = converted2original[predicted.tail]
            if predicted.label != self.coref_relation_label:
                raise ValueError(f"unexpected label: {predicted.label}")
            if not (predicted.score >= self.probability_threshold):
                continue
            document.binary_coref_relations.predictions.append(
                BinaryCorefRelation(head=head, tail=tail, label="coref", score=predicted.score)
            )

    def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
        """Unbatch the model output with every label set to the coref class.

        The ``"labels"`` entry is replaced in a shallow copy, so the passed-in
        mapping itself is not modified.
        """
        coref_id = self.label_to_id[self.coref_relation_label]
        # only the coref class is of interest, so the labels field is overwritten
        patched_output = copy.copy(model_output)
        patched_output["labels"] = torch.ones_like(model_output["labels"]) * coref_id
        return super().unbatch_output(model_output=patched_output)