Dit-document-layout-analysis / unilm/decoding/GAD/fairseq/model_parallel/models/transformer.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch.nn as nn

from fairseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)
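

# The Megatron-LM "mpu" utilities imported below supply the tensor-parallel
# primitives used in this file (a vocab-sharded embedding plus copy/gather ops
# across model-parallel ranks). They exist only when the megatron submodule has
# been checked out, so the import is wrapped in try/except and gated by
# `has_megatron_submodule`.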
try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


logger = logging.getLogger(__name__)


@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
    """
    Model parallel Transformer model.
    """

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
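        # Pad the vocabulary so len(dictionary) is divisible by
        # model_parallel_size * 8: every model-parallel rank then owns an
        # equal slice of the embedding table (the extra factor of 8 follows
        # Megatron's convention for efficient GEMM shapes).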
        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
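
        # Embedding init used by VocabParallelEmbedding: normal with
        # std = num_embeddings ** -0.5, then zero the vector at index 1
        # (fairseq's default padding index).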
        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
            nn.init.constant_(tensor[1], 0)
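
        # VocabParallelEmbedding shards the embedding matrix along the
        # vocabulary dimension, so each model-parallel rank stores and looks
        # up only its own slice of the vocabulary.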
        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )
        # if provided, load from preloaded dictionaries
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
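
        # Optionally drop the final encoder LayerNorm when
        # args.no_final_layer_norm is set.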
        if args.no_final_layer_norm:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires --share-decoder-input-output-embed"
            )
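        # copy_to_model_parallel_region is an identity in the forward pass and
        # all-reduces gradients in the backward pass, so every model-parallel
        # rank projects the same features against its own vocab shard.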
        features = copy_to_model_parallel_region(features)

        # project back to size of vocabulary
        x = self.output_projection(features)
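
        # vocab_parallel_cross_entropy consumes the vocab-sharded logits
        # directly, so the all-gather across model-parallel ranks is only
        # needed for other criteria.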
        if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x
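

# A minimal usage sketch, kept in comments to avoid import-time side effects.
# It assumes a fully configured `args` namespace (model_parallel_size plus the
# usual Transformer hyperparameters) and a fairseq task; the variable names
# below are illustrative, not part of this file:
#
#     from fairseq import tasks
#     task = tasks.setup_task(args)
#     model = ModelParallelTransformerModel.build_model(args, task)
#     logits = model(src_tokens, src_lengths, prev_output_tokens)[0]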