# ScientificArgumentRecommender / src/models/sequence_classification_with_pooler.py
# Source PR: https://github.com/ArneBinder/pie-document-level/pull/312
import abc
import logging
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.components.pooler import MENTION_POOLING
from pie_modules.models.sequence_classification_with_pooler import (
InputType,
OutputType,
SequenceClassificationModelWithPooler,
SequenceClassificationModelWithPoolerBase,
TargetType,
separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from src.models.components.pooler import SpanMeanPooler
logger = logging.getLogger(__name__)

class SequenceClassificationModelWithPoolerBase2(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
    """Variant of SequenceClassificationModelWithPoolerBase that supports mean
    aggregation for mention pooling via SpanMeanPooler."""

    def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
        aggregate = self.pooler_config.get("aggregate", "max")
        if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
            if aggregate == "mean":
                # Drop the entries that SpanMeanPooler does not accept and pass
                # the remaining config through.
                pooler_config = dict(self.pooler_config)
                pooler_config.pop("type")
                pooler_config.pop("aggregate")
                pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
                return pooler, pooler.output_dim
            else:
                raise ValueError(f"Unknown aggregation method: {aggregate}")
        else:
            return super().setup_pooler(input_dim)
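
# A minimal configuration sketch (assumption: the pooler config is passed via
# the `pooler` constructor argument, as in pie_modules). For example, for
# SequencePairSimilarityModelWithPooler2 defined below:
#
#     pooler={"type": MENTION_POOLING, "aggregate": "mean"}
#
# selects the SpanMeanPooler; any other non-"max" aggregate value raises a
# ValueError in setup_pooler.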


class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase2, abc.ABC
):
    """Base model that can optionally load and activate an adapter on top of the
    base transformer (via AutoAdapterModel) when adapter_name_or_path is given."""

    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                # The weights will be restored from the serialized state, so it is
                # enough to create the model from the config alone.
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # Load the adapter in any case: it does not appear to be saved in, or
            # restored from, the serialized state.
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model
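
# A usage sketch (hedged: the checkpoint and adapter id below are illustrative
# placeholders, and the remaining constructor arguments follow the parent class):
#
#     model = SequenceClassificationModelWithPoolerAndAdapter(
#         model_name_or_path="bert-base-uncased",
#         adapter_name_or_path="AdapterHub/bert-base-uncased-pf-mnli",  # hypothetical
#         num_classes=2,  # plus any further SequenceClassificationModelWithPooler arguments
#     )
#
# With adapter_name_or_path=None, setup_base_model falls back to the parent behavior.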


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPooler2(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
):
    """SequencePairSimilarityModelWithPooler with support for mean mention pooling."""


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequencePairSimilarityModelWithPooler with optional adapter support."""


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """SequenceClassificationModelWithPooler with optional adapter support."""


def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    """Return the maximum pairwise cosine similarity between two sets of embeddings."""
    # Normalize both embedding matrices along the feature dimension
    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # shape: (n, k)
    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # shape: (m, k)
    # Compute the full cosine similarity matrix
    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # shape: (n, m)
    # Reduce to the overall maximum similarity (a scalar tensor)
    max_cosine_sim = torch.max(cosine_sim)
    return max_cosine_sim


def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    """Slice out the token embeddings of one span per sequence.

    Note that only the first span per sequence (starts[0] to ends[0]) is used.
    """
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result
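
# Shape walk-through: given embeddings of shape (batch_size, seq_len, hidden_size)
# and start/end indices of shape (batch_size, num_spans), the result is a list of
# batch_size tensors, the i-th of shape (end_indices[i][0] - start_indices[i][0], hidden_size).
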
@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # Instead of pooling (and applying dropout) as in the parent class, return
        # the raw token embeddings of each span; forward() computes the similarity
        # directly over these.
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds
def forward(
self,
inputs: InputType,
targets: Optional[TargetType] = None,
return_hidden_states: bool = False,
) -> OutputType:
        sanitized_inputs = separate_arguments_by_prefix(
            # NOTE: the order of the prefixes matters: "pooler_" is itself a prefix
            # of "pooler_pair_", so the longer prefix needs to be matched first.
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )
span_embeddings = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding"],
pooler_inputs=sanitized_inputs["pooler_"],
)
span_embeddings_pair = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
pooler_inputs=sanitized_inputs["pooler_pair_"],
)
logits_list = [
get_max_cosine_sim(span_embeds, span_embeds_pair)
for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
]
logits = torch.stack(logits_list)
result = {"logits": logits}
if targets is not None:
labels = targets["scores"]
loss = self.loss_fct(logits, labels)
result["loss"] = loss
if return_hidden_states:
raise NotImplementedError("return_hidden_states is not yet implemented")
return SequenceClassifierOutput(**result)
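
# Note on the expected input layout of SequencePairSimilarityModelWithMaxCosineSim
# (derived from the prefix handling in forward()): `inputs` holds the tokenized
# "encoding" and "encoding_pair" together with span indices under prefixed keys
# ("pooler_start_indices", "pooler_end_indices", and their "pooler_pair_"-prefixed
# counterparts); separate_arguments_by_prefix strips the prefixes so that
# get_span_embeddings receives them as start_indices/end_indices.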


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    """SequencePairSimilarityModelWithMaxCosineSim with optional adapter support."""
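

if __name__ == "__main__":
    # Minimal smoke test for the two pure-tensor helpers above (the shapes and
    # index values are arbitrary examples, not values used by the models):
    token_embeddings = torch.randn(2, 10, 8)  # (batch_size, seq_len, hidden_size)
    starts = torch.tensor([[2], [1]])  # one span per sequence; only index 0 is used
    ends = torch.tensor([[5], [7]])
    spans = get_span_embeddings(token_embeddings, start_indices=starts, end_indices=ends)
    print([tuple(s.shape) for s in spans])  # [(3, 8), (6, 8)]
    # Maximum cosine similarity between the two span-embedding matrices
    print(get_max_cosine_sim(spans[0], spans[1]))  # scalar tensor in [-1.0, 1.0]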