File size: 6,309 Bytes
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import abc
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.components.pooler import MENTION_POOLING
from pie_modules.models.sequence_classification_with_pooler import (
    InputType,
    OutputType,
    SequenceClassificationModelWithPooler,
    SequenceClassificationModelWithPoolerBase,
    TargetType,
    separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from src.models.components.pooler import SpanMeanPooler

logger = logging.getLogger(__name__)


class SequenceClassificationModelWithPoolerBase2(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
    """Extends the base pooler setup with a "mean" aggregation option for mention pooling."""

    def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
        """Build the pooler for the given input dimension.

        For mention pooling with aggregate="mean", a SpanMeanPooler is created from the
        remaining pooler config entries; every other configuration is delegated to the
        base implementation. Raises ValueError for an unknown aggregation method.
        """
        aggregate = self.pooler_config.get("aggregate", "max")
        # The base class already covers everything except mention pooling with a
        # non-"max" aggregation.
        if self.pooler_config["type"] != MENTION_POOLING or aggregate == "max":
            return super().setup_pooler(input_dim)
        if aggregate != "mean":
            raise ValueError(f"Unknown aggregation method: {aggregate}")
        # Pass the pooler config on, minus the keys consumed here.
        remaining_config = {
            key: value
            for key, value in self.pooler_config.items()
            if key not in ("type", "aggregate")
        }
        pooler = SpanMeanPooler(input_dim=input_dim, **remaining_config)
        return pooler, pooler.output_dim


class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase2, abc.ABC
):
    """Base class that optionally loads an (Hugging Face hub) adapter on top of the base model.

    If `adapter_name_or_path` is None, behaves exactly like the parent class.
    """

    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        # NOTE(review): set before super().__init__ — presumably the parent constructor
        # calls setup_base_model(), which reads this attribute; confirm against the base class.
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        """Create the backbone model; when an adapter is configured, build an
        AutoAdapterModel and load + activate the adapter."""
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                # Restoring from a serialized state: build the architecture from the
                # config alone instead of downloading pretrained weights.
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # load the adapter in any case (it looks like it is not saved in the state or loaded
            # from a serialized state)
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPooler2(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
):
    """SequencePairSimilarityModelWithPooler whose mention pooling additionally
    supports mean aggregation (inherited from SequenceClassificationModelWithPoolerBase2)."""

    pass


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequencePairSimilarityModelWithPooler with optional adapter loading
    (inherited from SequenceClassificationModelWithPoolerAndAdapterBase)."""

    pass


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequenceClassificationModelWithPooler with optional adapter loading
    (inherited from SequenceClassificationModelWithPoolerAndAdapterBase)."""

    pass


def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    """Return the highest cosine similarity between any row of `embeddings`
    (shape (n, k)) and any row of `embeddings_pair` (shape (m, k)) as a scalar tensor."""
    # Unit-normalize rows so that a plain dot product equals cosine similarity.
    unit = F.normalize(embeddings, p=2, dim=1)  # (n, k)
    unit_pair = F.normalize(embeddings_pair, p=2, dim=1)  # (m, k)
    # Pairwise similarity matrix of shape (n, m), reduced to its global maximum.
    return (unit @ unit_pair.T).max()


def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    """For each sequence in the batch, slice out the token embeddings of its span.

    Only the *first* span per sequence is used (index 0 of each start/end row).
    Returns one (span_length, hidden_dim) tensor per batch entry.
    """
    return [
        seq_embeds[span_starts[0] : span_ends[0]]
        for seq_embeds, span_starts, span_ends in zip(embeddings, start_indices, end_indices)
    ]


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    """Scores a text pair by the maximum cosine similarity between any token embedding
    of the span in the first encoding and any token embedding of the span in the
    second encoding, instead of using the pooler + classification head."""

    def get_pooled_output(
        self, model_inputs: Dict[str, Any], pooler_inputs: Dict[str, Tensor]
    ) -> List[FloatTensor]:
        """Run the backbone and return the raw span token embeddings per batch entry
        (the pooler and dropout of the base class are intentionally bypassed)."""
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # pooled_output = self.pooler(hidden_state, **pooler_inputs)
        # pooled_output = self.dropout(pooled_output)
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
        """Compute one max-cosine-similarity logit per pair; if `targets` is given,
        also compute the loss against targets["scores"] via self.loss_fct.

        Raises:
            NotImplementedError: if return_hidden_states is requested.
        """
        sanitized_inputs = separate_arguments_by_prefix(
            # Note that the order of the prefixes is important because one is a prefix of the other,
            # so we need to start with the longer!
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )

        # Span token embeddings for the first and second encoding of each pair.
        span_embeddings = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        span_embeddings_pair = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )

        # One scalar similarity per pair, stacked into a (batch,) logits tensor.
        logits_list = [
            get_max_cosine_sim(span_embeds, span_embeds_pair)
            for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
        ]
        logits = torch.stack(logits_list)

        result = {"logits": logits}
        if targets is not None:
            labels = targets["scores"]
            loss = self.loss_fct(logits, labels)
            result["loss"] = loss
        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        return SequenceClassifierOutput(**result)


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    """Max-cosine-similarity pair model combined with optional adapter loading."""

    pass