# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

from fairseq import utils
from fairseq.models import (
    FairseqMultiModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    Embedding,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)

@register_model("multilingual_transformer")
class MultilingualTransformerModel(FairseqMultiModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task multilingual_translation`.

    We inherit all arguments from TransformerModel and assume that all language
    pairs use a single Transformer architecture. In addition, we provide several
    options that are specific to the multilingual setting.

    Args:
        --share-encoder-embeddings: share encoder embeddings across all source languages
        --share-decoder-embeddings: share decoder embeddings across all target languages
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
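
    Example (illustrative only; the data directory and language pairs are
    placeholders, and the usual fairseq-train optimization flags are omitted)::

        fairseq-train data-bin/multilingual --task multilingual_translation --lang-pairs de-en,fr-en --arch multilingual_transformer_iwslt_de_en --share-decoder-embeddings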
""" | |

    def __init__(self, encoders, decoders):
        super().__init__(encoders, decoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--share-encoder-embeddings",
            action="store_true",
            help="share encoder embeddings across languages",
        )
        parser.add_argument(
            "--share-decoder-embeddings",
            action="store_true",
            help="share decoder embeddings across languages",
        )
        parser.add_argument(
            "--share-encoders",
            action="store_true",
            help="share encoders across languages",
        )
        parser.add_argument(
            "--share-decoders",
            action="store_true",
            help="share decoders across languages",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]

        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        return MultilingualTransformerModel(encoders, decoders)

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        state_dict_subset = state_dict.copy()
        for k, _ in state_dict.items():
            assert k.startswith("models.")
            lang_pair = k.split(".")[1]
            if lang_pair not in self.models:
                del state_dict_subset[k]
        super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
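
    # Note on the filtering above (illustrative example): checkpoint keys have the
    # form "models.<lang_pair>.<parameter path>", e.g. a hypothetical
    # "models.de-en.encoder.embed_tokens.weight". Keys whose language pair is not
    # part of this model instance are dropped before delegating to the parent's
    # load_state_dict, so checkpoints trained with extra language pairs still load.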


@register_model_architecture("multilingual_transformer", "multilingual_transformer")
def base_multilingual_architecture(args):
    base_architecture(args)
    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False)
    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False)
    args.share_encoders = getattr(args, "share_encoders", False)
    args.share_decoders = getattr(args, "share_decoders", False)


@register_model_architecture(
    "multilingual_transformer", "multilingual_transformer_iwslt_de_en"
)
def multilingual_transformer_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    base_multilingual_architecture(args)
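

# A minimal sketch of how a further variant could be registered, following the same
# pattern as the functions above. The name "multilingual_transformer_tiny" and its
# hyper-parameter values are illustrative assumptions for this sketch, not an
# architecture defined by fairseq itself.
@register_model_architecture(
    "multilingual_transformer", "multilingual_transformer_tiny"
)
def multilingual_transformer_tiny(args):
    # smaller dimensions than the IWSLT preset above; all values are placeholders
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
    base_multilingual_architecture(args)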