import abc
import logging
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.components.pooler import MENTION_POOLING
from pie_modules.models.sequence_classification_with_pooler import (
    InputType,
    OutputType,
    SequenceClassificationModelWithPooler,
    SequenceClassificationModelWithPoolerBase,
    TargetType,
    separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from src.models.components.pooler import SpanMeanPooler

logger = logging.getLogger(__name__)


class SequenceClassificationModelWithPoolerBase2(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
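    """Extends SequenceClassificationModelWithPoolerBase with a configurable span aggregation:
    for pooler type MENTION_POOLING with aggregate="mean", spans are pooled with SpanMeanPooler
    instead of the default max aggregation.
    """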

    def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
        aggregate = self.pooler_config.get("aggregate", "max")
        if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
            if aggregate == "mean":
                # remove the entries that are handled here; the remaining config is
                # passed through to the pooler
                pooler_config = dict(self.pooler_config)
                pooler_config.pop("type")
                pooler_config.pop("aggregate")
                pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
                return pooler, pooler.output_dim
            else:
                raise ValueError(f"Unknown aggregation method: {aggregate}")
        else:
            # fall back to the default pooler setup
            return super().setup_pooler(input_dim)


class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase2, abc.ABC
):
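    """Adds optional adapter support (see https://docs.adapterhub.ml): if adapter_name_or_path
    is provided, the base model is created as an AutoAdapterModel and the adapter is loaded
    from the Hugging Face Hub and activated.

    Minimal usage sketch, with a hypothetical adapter name (all other constructor arguments
    are forwarded to the base class):

        model = SequenceClassificationModelWithPoolerAndAdapter(
            model_name_or_path="bert-base-uncased",
            adapter_name_or_path="some-org/some-task-adapter",  # hypothetical
        )
    """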

    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        # must be set before super().__init__(), which calls setup_base_model()
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                # the weights get restored from a checkpoint afterwards, so do not
                # download the pre-trained weights here
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPooler2(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
):
    """SequencePairSimilarityModelWithPooler with support for mean span aggregation."""


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequencePairSimilarityModelWithPooler with support for loading a pre-trained adapter."""


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequenceClassificationModelWithPooler with support for loading a pre-trained adapter."""


def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
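    """Return the maximum pairwise cosine similarity between two sets of embeddings, given as
    tensors of shape (num_embeddings, embedding_dim) each. The result is a scalar tensor.
    """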
    # normalize to unit length so that dot products equal cosine similarities
    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)
    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)

    # pairwise similarity matrix of shape (num_embeddings, num_embeddings_pair)
    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)

    max_cosine_sim = torch.max(cosine_sim)
    return max_cosine_sim


def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
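    """Slice the token embeddings of one span out of each sequence in the batch. Note that only
    the first start/end index per item is used, i.e. exactly one span is extracted per sequence.
    """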
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
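    """Variant of SequencePairSimilarityModelWithPooler that scores a pair of spans with the
    maximum cosine similarity over all combinations of their token embeddings instead of
    comparing pooled span embeddings.
    """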

    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # collect the token embeddings of the annotated spans
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
        # split the inputs into the indices for both span poolers and the remaining
        # model inputs (the encodings of both sequences)
        sanitized_inputs = separate_arguments_by_prefix(
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )

        span_embeddings = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        span_embeddings_pair = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )

        # one score per pair: the maximum cosine similarity over all token combinations
        logits_list = [
            get_max_cosine_sim(span_embeds, span_embeds_pair)
            for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
        ]
        logits = torch.stack(logits_list)

        result = {"logits": logits}
        if targets is not None:
            labels = targets["scores"]
            loss = self.loss_fct(logits, labels)
            result["loss"] = loss
        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        return SequenceClassifierOutput(**result)


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    """SequencePairSimilarityModelWithMaxCosineSim with support for loading a pre-trained
    adapter."""