ScientificArgumentRecommender / src /models /sequence_classification.py
ArneBinder's picture
upload https://github.com/ArneBinder/pie-document-level/pull/452
e7eaeed verified
from typing import Optional
from pie_modules.models import SimpleSequenceClassificationModel
from pie_modules.models.simple_sequence_classification import InputType, OutputType, TargetType
from pytorch_ie import PyTorchIEModel
from torch import nn
from transformers import BertModel
from transformers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate.hooks import add_hook_to_module
@PyTorchIEModel.register()
class SimpleSequenceClassificationModelWithInputTypeIds(SimpleSequenceClassificationModel):
    """Sequence classification model with resizable token type embeddings.

    On construction, the token type embedding matrix of the wrapped transformer
    (``self.model``) is resized to ``num_token_type_ids`` rows. In ``forward``, the input
    entry stored under the key ``use_as_token_type_ids`` is passed to the transformer as
    ``token_type_ids``.

    Args:
        num_token_type_ids: Number of rows the token type embedding matrix is resized to.
        use_as_token_type_ids: Key of the model input whose value is passed to the
            transformer as ``token_type_ids``. Defaults to ``"token_type_ids"``.
        **kwargs: Forwarded to ``SimpleSequenceClassificationModel.__init__``.
    """

    def __init__(
        self, num_token_type_ids: int, use_as_token_type_ids: str = "token_type_ids", **kwargs
    ):
        super().__init__(**kwargs)
        self.num_token_type_ids = num_token_type_ids
        # name of the input entry that forward() hands to the model as "token_type_ids"
        self.token_type_ids_key = use_as_token_type_ids
        self.resize_type_embeddings(num_token_type_ids)

    def _get_base_model(self) -> BertModel:
        """Return the base transformer of ``self.model`` (the attribute named by its
        ``base_model_prefix``).

        Raises:
            ValueError: If ``self.model`` has no such attribute, or it is ``None``.
        """
        # Pass a default so a missing attribute reaches the ValueError below; without it,
        # getattr() raises AttributeError and the None-check is dead code.
        base_model = getattr(self.model, self.model.base_model_prefix, None)
        if base_model is None:
            raise ValueError("Model has no base model.")
        return base_model

    def get_input_type_embeddings(self) -> nn.Module:
        """Return the token type embedding module of the base model."""
        return self._get_base_model().embeddings.token_type_embeddings

    def set_input_type_embeddings(self, value: nn.Module) -> None:
        """Replace the token type embedding module of the base model with ``value``."""
        self._get_base_model().embeddings.token_type_embeddings = value

    def _resize_type_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        """Create resized token type embeddings, install them and return the installed module.

        The actual resizing is delegated to ``self.model._get_resized_embeddings``.
        """
        old_embeddings = self.get_input_type_embeddings()
        new_embeddings = self.model._get_resized_embeddings(
            old_embeddings,
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
        )
        # If accelerate attached a hook to the old module, attach it to the new one as well.
        # The `_hf_hook` attribute is set by accelerate, in which case the module-level
        # conditional import of add_hook_to_module has succeeded.
        if hasattr(old_embeddings, "_hf_hook"):
            add_hook_to_module(new_embeddings, old_embeddings._hf_hook)
        # preserve the trainability (e.g. frozen weights) of the original embeddings
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
        self.set_input_type_embeddings(new_embeddings)
        return self.get_input_type_embeddings()

    def resize_type_embeddings(
        self, new_num_types: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """
        Same as `resize_token_embeddings`, but for the token type embeddings.

        Resizes the token type embedding matrix via the model's `_get_resized_embeddings`, then
        sets `config.type_vocab_size` to the resulting number of rows and calls the model's
        `tie_weights()`.

        Arguments:
            new_num_types (`int`, *optional*):
                The number of token types in the new embedding matrix. Increasing the size adds
                newly initialized vectors at the end. Reducing the size removes vectors from the
                end. If not provided or `None` (and `pad_to_multiple_of` is also `None`), just
                returns the token type embedding module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set, pads the embedding matrix to a multiple of the provided value. If
                `new_num_types` is `None`, just pads the current embedding matrix to a multiple
                of `pad_to_multiple_of`.
                This can help Tensor Cores on NVIDIA hardware, or TPUs, which benefit from
                dimensions that are a multiple of 128. For details, or help on choosing the
                correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: The (possibly resized) token type embedding module of the model.
        """
        if new_num_types is None and pad_to_multiple_of is None:
            # Nothing to resize: return the current module without touching the model,
            # its config or any attached hooks.
            return self.get_input_type_embeddings()
        model_embeds = self._resize_type_embeddings(new_num_types, pad_to_multiple_of)
        # keep the config consistent with the actual size of the embedding matrix
        self.model.config.type_vocab_size = model_embeds.weight.shape[0]
        # tie weights again if needed (as resize_token_embeddings does)
        self.model.tie_weights()
        return model_embeds

    def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
        """Run the wrapped model on ``inputs`` merged with ``targets`` (if given).

        The entry stored under ``self.token_type_ids_key`` is moved to the key
        ``token_type_ids`` before calling the model. If the configured key differs from
        ``"token_type_ids"``, any existing ``token_type_ids`` entry is overwritten.

        Raises:
            KeyError: If neither ``inputs`` nor ``targets`` contain ``self.token_type_ids_key``.
        """
        kwargs = {**inputs, **(targets or {})}
        if self.token_type_ids_key not in kwargs:
            raise KeyError(
                f"expected the model inputs to contain the key '{self.token_type_ids_key}' "
                f"(used as token_type_ids), but got keys: {list(kwargs)}"
            )
        # rename the configured key to "token_type_ids" (a no-op for the default key)
        kwargs["token_type_ids"] = kwargs.pop(self.token_type_ids_key)
        return self.model(**kwargs)