import abc
import logging
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.components.pooler import MENTION_POOLING
from pie_modules.models.sequence_classification_with_pooler import (
InputType,
OutputType,
SequenceClassificationModelWithPooler,
SequenceClassificationModelWithPoolerBase,
TargetType,
separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from src.models.components.pooler import SpanMeanPooler

logger = logging.getLogger(__name__)


class SequenceClassificationModelWithPoolerBase2(
SequenceClassificationModelWithPoolerBase, abc.ABC
):
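    """Extends SequenceClassificationModelWithPoolerBase with support for mean
    aggregation in the mention pooler: if the pooler type is MENTION_POOLING and
    aggregate="mean" is configured, a SpanMeanPooler is set up; all other
    configurations are delegated to the parent implementation.
    """
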
def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
aggregate = self.pooler_config.get("aggregate", "max")
if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
if aggregate == "mean":
pooler_config = dict(self.pooler_config)
pooler_config.pop("type")
pooler_config.pop("aggregate")
pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
return pooler, pooler.output_dim
else:
raise ValueError(f"Unknown aggregation method: {aggregate}")
else:
            return super().setup_pooler(input_dim)


class SequenceClassificationModelWithPoolerAndAdapterBase(
SequenceClassificationModelWithPoolerBase2, abc.ABC
):
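    """Adds optional adapter support: if adapter_name_or_path is given, the base model
    is created as an AutoAdapterModel and the adapter is loaded from the Hugging Face
    hub and activated. Otherwise, the default base model setup is used.
    """
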
def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
if self.adapter_name_or_path is None:
return super().setup_base_model()
else:
config = AutoConfig.from_pretrained(self.model_name_or_path)
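            # when the model is restored from a serialized state, the transformer
            # weights will be overwritten by the checkpoint anyway, so only the
            # architecture is constructed here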
if self.is_from_pretrained:
model = AutoAdapterModel.from_config(config=config)
else:
model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # the adapter does not appear to be saved in, or restored from, the
            # serialized model state, so it needs to be loaded in any case
logger.info(f"load adapter: {self.adapter_name_or_path}")
model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPooler2(
SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
):
    """SequencePairSimilarityModelWithPooler variant that supports mean aggregation in
    the mention pooler (see SequenceClassificationModelWithPoolerBase2)."""


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequencePairSimilarityModelWithPooler variant with optional adapter support and
    mean mention pooling (via SequenceClassificationModelWithPoolerAndAdapterBase)."""


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequenceClassificationModelWithPooler variant with optional adapter support and
    mean mention pooling (via SequenceClassificationModelWithPoolerAndAdapterBase)."""


def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
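    """Return the maximum cosine similarity between any row of embeddings, shape
    (n, k), and any row of embeddings_pair, shape (m, k), as a scalar tensor."""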
# Normalize the embeddings
embeddings_normalized = F.normalize(embeddings, p=2, dim=1) # Shape: (n, k)
embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1) # Shape: (m, k)
# Compute the cosine similarity matrix
cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T) # Shape: (n, m)
# Get the overall maximum cosine similarity value
max_cosine_sim = torch.max(cosine_sim) # This will return a scalar
    return max_cosine_sim


def get_span_embeddings(
embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
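    """For each batch element, slice out the token embeddings of its first span, i.e.
    embeddings[i][start_indices[i][0] : end_indices[i][0]]. Only the first span per
    element is considered."""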
result = []
for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
span_embeds = embeds[starts[0] : ends[0]]
result.append(span_embeds)
    return result


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
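    """Sequence pair similarity model that scores two spans by the maximum cosine
    similarity between their token embeddings instead of comparing pooled span
    embeddings.
    """
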
def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
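        """Return the raw token embeddings of each span instead of a pooled vector."""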
output = self.model(**model_inputs)
hidden_state = output.last_hidden_state
        # in contrast to the parent class, do not pool the span embeddings into a
        # single vector (and apply dropout), but return the token embeddings as-is
span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

    def forward(
self,
inputs: InputType,
targets: Optional[TargetType] = None,
return_hidden_states: bool = False,
) -> OutputType:
sanitized_inputs = separate_arguments_by_prefix(
# Note that the order of the prefixes is important because one is a prefix of the other,
# so we need to start with the longer!
arguments=inputs,
prefixes=["pooler_pair_", "pooler_"],
)
span_embeddings = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding"],
pooler_inputs=sanitized_inputs["pooler_"],
)
span_embeddings_pair = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
pooler_inputs=sanitized_inputs["pooler_pair_"],
)
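        # score each pair by the maximum cosine similarity between any token embedding
        # of the first span and any token embedding of the paired span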
logits_list = [
get_max_cosine_sim(span_embeds, span_embeds_pair)
for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
]
logits = torch.stack(logits_list)
result = {"logits": logits}
if targets is not None:
labels = targets["scores"]
loss = self.loss_fct(logits, labels)
result["loss"] = loss
if return_hidden_states:
raise NotImplementedError("return_hidden_states is not yet implemented")
        return SequenceClassifierOutput(**result)


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    """SequencePairSimilarityModelWithMaxCosineSim variant with optional adapter
    support (via SequencePairSimilarityModelWithPoolerAndAdapter)."""