from dataclasses import dataclass, field

import torch
from omegaconf import II

from fairseq import metrics, utils
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationConfig, TranslationTask

from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork


METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])
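# In the method names above, "s"/"h" selects soft vs. hard expert selection
# and "lp"/"up" selects a learned vs. uniform expert prior p(z | x)
# (Shen et al., 2019).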


@dataclass
class TranslationMoEConfig(TranslationConfig):
    method: METHOD_CHOICES = field(
        default="hMoEup",
        metadata={"help": "MoE method"},
    )
    num_experts: int = field(
        default=3,
        metadata={"help": "number of experts"},
    )
    mean_pool_gating_network: bool = field(
        default=False,
        metadata={"help": "use a simple mean-pooling gating network"},
    )
    mean_pool_gating_network_dropout: float = field(
        default=0,
        metadata={"help": "dropout for mean-pooling gating network"},
    )
    mean_pool_gating_network_encoder_dim: int = field(
        default=0,
        metadata={"help": "encoder output dim for mean-pooling gating network"},
    )
    gen_expert: int = field(
        default=0,
        metadata={"help": "which expert to use for generation"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")


@register_task("translation_moe", dataclass=TranslationMoEConfig)
class TranslationMoETask(TranslationTask):
    """
    Translation task for Mixture of Experts (MoE) models.

    See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    cfg: TranslationMoEConfig

    def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):
        if cfg.method == "sMoElp":
            # soft MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = False
        elif cfg.method == "sMoEup":
            # soft MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = False
        elif cfg.method == "hMoElp":
            # hard MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = True
        elif cfg.method == "hMoEup":
            # hard MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = True

        # add indicator tokens for each expert
        for i in range(cfg.num_experts):
            # add to both dictionaries in case we're sharing embeddings
            src_dict.add_symbol("<expert_{}>".format(i))
            tgt_dict.add_symbol("<expert_{}>".format(i))

        super().__init__(cfg, src_dict, tgt_dict)

    def build_model(self, cfg):
        from fairseq import models

        model = models.build_model(cfg, self)
        if not self.uniform_prior and not hasattr(model, "gating_network"):
            if self.cfg.mean_pool_gating_network:
                if self.cfg.mean_pool_gating_network_encoder_dim > 0:
                    encoder_dim = self.cfg.mean_pool_gating_network_encoder_dim
                elif getattr(cfg, "encoder_embed_dim", None):
                    # assume encoder_embed_dim is the encoder's output dimension
                    encoder_dim = cfg.encoder_embed_dim
                else:
                    raise ValueError(
                        "Must specify --mean-pool-gating-network-encoder-dim"
                    )

                if self.cfg.mean_pool_gating_network_dropout > 0:
                    dropout = self.cfg.mean_pool_gating_network_dropout
                elif getattr(cfg, "dropout", None):
                    dropout = cfg.dropout
                else:
                    raise ValueError(
                        "Must specify --mean-pool-gating-network-dropout"
                    )

                model.gating_network = MeanPoolGatingNetwork(
                    encoder_dim,
                    self.cfg.num_experts,
                    dropout,
                )
            else:
                raise ValueError(
                    "translation_moe task with learned prior requires the model to "
                    "have a gating network; try using --mean-pool-gating-network"
                )
        return model
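
    # MeanPoolGatingNetwork (from .mean_pool_gating_network) mean-pools the
    # encoder output over time and maps the result to a score per expert;
    # since its output is added to log-probabilities in _get_loss below, it
    # is expected to return log p(z | x).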

    def expert_index(self, i):
        # experts are keyed by the reserved symbols added in __init__;
        # expert i maps to the vocabulary index of "<expert_i>"
        return i + self.tgt_dict.index("<expert_0>")
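
    # Loss sketch (Shen et al., 2019): with a latent expert z in {1, ..., K},
    # the model defines
    #
    #     p(y | x) = sum_z p(z | x) * p(y | z, x)
    #
    # where p(z | x) is uniform or comes from the gating network. Training is
    # EM-like: an E-step estimates responsibilities p(z | x, y) with gradients
    # disabled, then an M-step minimizes the negative log-likelihood under a
    # hard or soft assignment of those responsibilities.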

    def _get_loss(self, sample, model, criterion):
        assert hasattr(
            criterion, "compute_loss"
        ), "translation_moe task requires the criterion to implement the compute_loss() method"

        k = self.cfg.num_experts
        bsz = sample["target"].size(0)

        def get_lprob_y(encoder_out, prev_output_tokens_k):
            net_output = model.decoder(
                prev_output_tokens=prev_output_tokens_k,
                encoder_out=encoder_out,
            )
            loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
            loss = loss.view(bsz, -1)
            return -loss.sum(dim=1, keepdim=True)
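
        # compute_loss(..., reduce=False) is assumed to return per-token
        # negative log-likelihoods, so the negated row-sum above is
        # log p(y | z, x) per sentence, with shape B x 1.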

        def get_lprob_yz(winners=None):
            encoder_out = model.encoder(
                src_tokens=sample["net_input"]["src_tokens"],
                src_lengths=sample["net_input"]["src_lengths"],
            )

            if winners is None:
                # score the target under every expert
                lprob_y = []
                for i in range(k):
                    prev_output_tokens_k = sample["net_input"][
                        "prev_output_tokens"
                    ].clone()
                    assert not prev_output_tokens_k.requires_grad
                    prev_output_tokens_k[:, 0] = self.expert_index(i)
                    lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
                lprob_y = torch.cat(lprob_y, dim=1)  # B x K
            else:
                # score the target only under each sentence's winning expert
                prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
                prev_output_tokens_k[:, 0] = self.expert_index(winners)
                lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)

            if self.uniform_prior:
                lprob_yz = lprob_y
            else:
                lprob_z = model.gating_network(encoder_out)  # B x K
                if winners is not None:
                    lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
                lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K

            return lprob_yz

        # E-step: compute expert responsibilities without dropout or autograd
        with utils.model_eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
                prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
        assert not prob_z_xy.requires_grad

        # M-step: compute loss with hard or soft expert selection
        if self.hard_selection:
            winners = prob_z_xy.max(dim=1)[1]
            loss = -get_lprob_yz(winners)
        else:
            lprob_yz = get_lprob_yz()  # B x K
            loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
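            # LogSumExpMoE.apply (from .logsumexp_moe) is an autograd function
            # whose forward pass is a standard log-sum-exp over the expert
            # dimension; its backward pass weights gradients by the fixed
            # posterior prob_z_xy, which yields the EM-style soft update.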

        loss = loss.sum()
        sample_size = (
            sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": bsz,
            "sample_size": sample_size,
            "posterior": prob_z_xy.float().sum(dim=0).cpu(),
        }
        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(
        self,
        generator,
        models,
        sample,
        prefix_tokens=None,
        expert=None,
        constraints=None,
    ):
        # fall back to cfg.gen_expert only when no expert is passed explicitly
        # ("expert or ..." would wrongly discard an explicit expert 0)
        if expert is None:
            expert = self.cfg.gen_expert
        with torch.no_grad():
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=self.expert_index(expert),
            )
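
    # For diverse generation, decode the test set once per expert (for
    # example, sweeping --gen-expert over 0 .. num_experts - 1); each pass
    # seeds the decoder with a different <expert_i> BOS token.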

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        metrics.log_scalar(
            "posterior",
            sum(log["posterior"] for log in logging_outputs if "posterior" in log),
        )