# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
    roberta_base_architecture,
    roberta_prenorm_architecture,
    RobertaEncoder,
    RobertaModel,
)
from fairseq.modules import LayerNorm

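# Megatron-LM's `mpu` primitives (parallel-region copy/gather, column-parallel
# linear layers and the vocabulary-parallel embedding) ship as an optional git
# submodule, so the import is guarded and only a flag is recorded on failure.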
try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        ColumnParallelLinear,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


logger = logging.getLogger(__name__)


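# Model-parallel counterpart of RobertaModel: the embedding table, transformer
# encoder and LM-head projection are sharded across model-parallel workers via the
# Megatron primitives imported above, while classification heads gather their
# activations so the final projection can stay an ordinary nn.Linear.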
@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
    def __init__(self, args, encoder):
        super().__init__(args, encoder)

        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        RobertaModel.add_args(parser)
        parser.add_argument(
            "--no-final-layer-norm",
            action="store_true",
            help=(
                "don't add final layernorm (only applicable when "
                "--encoder-normalize-before=True)"
            ),
        )

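    # Note on the dictionary padding below: VocabParallelEmbedding splits the
    # vocabulary across model-parallel workers, so the padded vocab size must be
    # divisible by the model-parallel world size; the extra factor of 8 is, as far
    # as we can tell, there to keep each shard a tensor-core-friendly size.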
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

        # make sure all arguments are present
        base_architecture(args)

        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        if getattr(args, "untie_weights_roberta", False):
            raise NotImplementedError(
                "--untie-weights-roberta is not supported in model parallel mode"
            )

        encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
        return cls(args, encoder)

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ModelParallelRobertaClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )


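# When the LM head is built by ModelParallelRobertaEncoder below, `weight` is the
# vocabulary-parallel embedding table, so each model-parallel worker only holds the
# output-projection rows for its own vocabulary shard: `copy_to_model_parallel_region`
# makes the hidden states visible to every worker (and all-reduces their gradients on
# the backward pass), each worker computes logits for its shard with F.linear, and
# `gather_from_model_parallel_region` concatenates the shards into full-vocabulary
# logits before the shared bias is added.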
class ModelParallelRobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # Only project the unmasked tokens while training,
        # saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        x = self.dense(features)
        x = self.activation_fn(x)
        x = self.layer_norm(x)

        x = copy_to_model_parallel_region(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight)
        x = gather_from_model_parallel_region(x).contiguous()
        x = x + self.bias
        return x


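# In the classification head, `gather_output=True` means the column-parallel dense
# layer all-gathers its output, so every worker sees the full `inner_dim` activation
# and the final projection can remain a plain, replicated nn.Linear.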
class ModelParallelRobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


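# The encoder reuses RobertaEncoder wholesale and only overrides its build_* hooks,
# swapping in the vocabulary-parallel embedding, the model-parallel transformer
# encoder and the model-parallel LM head defined above.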
class ModelParallelRobertaEncoder(RobertaEncoder):
    """RoBERTa encoder."""

    def __init__(self, args, dictionary):
        super().__init__(args, dictionary)
        assert not self.args.untie_weights_roberta

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)

    def build_encoder(self, args, dictionary, embed_tokens):
        return ModelParallelTransformerEncoder(args, dictionary, embed_tokens)

    def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
        return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight)


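# Named architectures: the decorators below register each configuration with fairseq
# so it can be selected via --arch; "v1" reproduces the earlier variant without a
# final layer norm, while "postnorm" restores the original BERT/RoBERTa Post-LN layout.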
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
    args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False)
    # model parallel RoBERTa defaults to "Pre-LN" formulation
    roberta_prenorm_architecture(args)


# earlier versions of model parallel RoBERTa removed the final layer norm
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1")
def model_parallel_roberta_v1_architecture(args):
    args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True)
    base_architecture(args)


@register_model_architecture(
    "model_parallel_roberta", "model_parallel_roberta_postnorm"
)
def model_parallel_roberta_postnorm_architecture(args):
    # the original BERT/RoBERTa uses the "Post-LN" formulation
    roberta_base_architecture(args)


@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def model_parallel_roberta_base_architecture(args):
    base_architecture(args)


@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def model_parallel_roberta_large_architecture(args):
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    base_architecture(args)
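

# Usage sketch (hedged, not part of this file): once this module has been imported so
# the registrations above have run, one of these architectures can be requested from
# the fairseq CLI, for example:
#
#   fairseq-train <data-bin> --task masked_lm \
#       --arch model_parallel_roberta --model-parallel-size 2 ...
#
# The exact set of training flags depends on the surrounding fairseq/Megatron setup.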