|
from typing import Optional

from pie_modules.models import SimpleSequenceClassificationModel
from pie_modules.models.simple_sequence_classification import InputType, OutputType, TargetType
from pytorch_ie import PyTorchIEModel
from torch import nn
from transformers import BertModel
from transformers.utils import is_accelerate_available

if is_accelerate_available():
    from accelerate.hooks import add_hook_to_module

|
@PyTorchIEModel.register()
class SimpleSequenceClassificationModelWithInputTypeIds(SimpleSequenceClassificationModel):
    """Same as SimpleSequenceClassificationModel, but renames a configurable input field to
    `token_type_ids` before the forward pass and resizes the backbone's token type embedding
    matrix to `num_token_type_ids` entries."""

    def __init__(
        self, num_token_type_ids: int, use_as_token_type_ids: str = "token_type_ids", **kwargs
    ):
        super().__init__(**kwargs)
        self.num_token_type_ids = num_token_type_ids
        # name of the input field that is passed to the backbone as token_type_ids
        self.token_type_ids_key = use_as_token_type_ids
        self.resize_type_embeddings(num_token_type_ids)
|
    def get_input_type_embeddings(self) -> nn.Module:
        # pass None as the getattr default: without it, a missing base model would raise
        # an AttributeError and the explicit check below would be dead code
        base_model: Optional[BertModel] = getattr(self.model, self.model.base_model_prefix, None)
        if base_model is None:
            raise ValueError("Model has no base model.")
        return base_model.embeddings.token_type_embeddings
|
    def set_input_type_embeddings(self, value):
        # same getattr default as in get_input_type_embeddings, for the same reason
        base_model: Optional[BertModel] = getattr(self.model, self.model.base_model_prefix, None)
        if base_model is None:
            raise ValueError("Model has no base model.")
        base_model.embeddings.token_type_embeddings = value
|
    def _resize_type_embeddings(
        self, new_num_types: Optional[int], pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        old_embeddings = self.get_input_type_embeddings()
        new_embeddings = self.model._get_resized_embeddings(
            old_embeddings, new_num_types, pad_to_multiple_of
        )
        # if the old embeddings were dispatched via accelerate, carry the hook over to the
        # resized module (add_hook_to_module is only imported when accelerate is available,
        # which is implied by the presence of _hf_hook)
        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        # preserve the trainability of the original embedding matrix
        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
        new_embeddings.requires_grad_(old_embeddings_requires_grad)
        self.set_input_type_embeddings(new_embeddings)

        return self.get_input_type_embeddings()
|
    def resize_type_embeddings(
        self, new_num_types: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """
        Same as resize_token_embeddings, but for the token type embeddings.

        Resizes the input token type embedding matrix of the model if
        `new_num_types != config.type_vocab_size`.

        Takes care of re-tying the embedding weights afterwards if the model class has a
        `tie_weights()` method.

        Arguments:
            new_num_types (`int`, *optional*):
                The new number of token types in the embedding matrix. Increasing the size will add
                newly initialized vectors at the end. Reducing the size will remove vectors from the
                end. If not provided or `None`, just returns a pointer to the token type
                `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set, will pad the embedding matrix to a multiple of the provided value. If
                `new_num_types` is set to `None`, will just pad the embedding matrix to a multiple
                of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with
                compute capability `>= 7.0` (Volta), or on TPUs, which benefit from having sequence
                lengths be a multiple of 128. For more details, or help on choosing the correct
                value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the token type embeddings module of the model.
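
        Example (illustrative sketch, not from the original code; assumes `model` is an
        instance of this class whose backbone config starts with `type_vocab_size == 2`):

            # grow the token type vocabulary from 2 to 4 rows
            type_embeddings = model.resize_type_embeddings(4)
            assert type_embeddings.weight.shape[0] == 4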
|
""" |
|
model_embeds = self._resize_type_embeddings(new_num_types, pad_to_multiple_of) |
|
if new_num_types is None and pad_to_multiple_of is None: |
|
return model_embeds |
|
|
|
|
|
self.model.config.type_vocab_size = model_embeds.weight.shape[0] |
|
|
|
|
|
self.model.tie_weights() |
|
|
|
return model_embeds |
|
|
|
    def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
        kwargs = {**inputs, **(targets or {})}
        # rename the configured input field (self.token_type_ids_key) to token_type_ids so
        # that the backbone consumes it as its token type ids; the field must be present
        kwargs["token_type_ids"] = kwargs.pop(self.token_type_ids_key)
        return self.model(**kwargs)
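

# Usage sketch (illustrative only, not part of the original module): the constructor
# arguments other than `num_token_type_ids` and `use_as_token_type_ids` are assumptions
# about the `SimpleSequenceClassificationModel` base class; check its actual signature
# before relying on this.
#
#   model = SimpleSequenceClassificationModelWithInputTypeIds(
#       model_name_or_path="bert-base-cased",     # hypothetical base-class argument
#       num_classes=2,                            # hypothetical base-class argument
#       num_token_type_ids=4,
#       use_as_token_type_ids="pooler_type_ids",  # hypothetical input field name
#   )
#
# At forward time, `inputs` must contain the field named by `use_as_token_type_ids`;
# forward() renames it to `token_type_ids` before calling the backbone.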
|
|