# unilm/decoding/GAD/block_plugins/tasks/translation_lev_modified.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field

import torch

from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import (
    TranslationConfig,
    TranslationTask,
    load_langpair_dataset,
)
from fairseq.utils import new_arange

NOISE_CHOICES = ChoiceEnum(
    ["random_delete", "random_mask", "no_noise", "full_mask", "block_mask"]
)
@dataclass
class TranslationLevenshteinConfig(TranslationConfig):
    noise: NOISE_CHOICES = field(
        default="random_delete",
        metadata={"help": "type of noise injected into the target"},
    )
    start_p: float = field(
        default=0.5,
        metadata={"help": "initial glancing probability (context_p at update 0)"},
    )
    minus_p: float = field(
        default=0.2,
        metadata={"help": "amount subtracted from start_p over total_up updates"},
    )
    total_up: int = field(
        default=300000,
        metadata={"help": "total updates over which the glancing probability anneals"},
    )
    block_size: int = field(
        default=5,
        metadata={"help": "block size for block_mask noise"},
    )
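
# These fields surface as command-line flags through fairseq's dataclass-driven
# config (underscores become dashes). A hypothetical invocation, for
# illustration only:
#   fairseq-train data-bin --task translation_lev_modified \
#       --noise block_mask --block-size 5 \
#       --start-p 0.5 --minus-p 0.2 --total-up 300000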


logger = logging.getLogger(__name__)


# Registered under the name used elsewhere in this repository's scripts;
# the task name is assumed from the file name.
@register_task("translation_lev_modified", dataclass=TranslationLevenshteinConfig)
class TranslationLevenshteinModifiedTask(TranslationTask):
    """
    Translation (Sequence Generation) task for Levenshtein Transformer

    See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_.
    """

    cfg: TranslationLevenshteinConfig
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.cfg.source_lang, self.cfg.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.cfg.dataset_impl,
            upsample_primary=self.cfg.upsample_primary,
            left_pad_source=self.cfg.left_pad_source,
            left_pad_target=self.cfg.left_pad_target,
            max_source_positions=self.cfg.max_source_positions,
            max_target_positions=self.cfg.max_target_positions,
            truncate_source=self.cfg.truncate_source,
        )
    def inject_noise(self, target_tokens):
        def _random_delete(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()

            max_len = target_tokens.size(1)
            target_mask = target_tokens.eq(pad)
            # random per-token scores; bos/eos get 0 (never deleted),
            # pads get 1 (always sorted to the end)
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(
                target_tokens.eq(bos) | target_tokens.eq(eos), 0.0
            )
            target_score.masked_fill_(target_mask, 1)
            target_score, target_rank = target_score.sort(1)
            target_length = target_mask.size(1) - target_mask.float().sum(
                1, keepdim=True
            )

            # do not delete <bos> and <eos> (we assign 0 score for them)
            target_cutoff = (
                2
                + (
                    (target_length - 2)
                    * target_score.new_zeros(target_score.size(0), 1).uniform_()
                ).long()
            )
            target_cutoff = target_score.sort(1)[1] >= target_cutoff

            # fill everything past the sampled cutoff with pad, then restore
            # the surviving tokens to their original relative order
            prev_target_tokens = (
                target_tokens.gather(1, target_rank)
                .masked_fill_(target_cutoff, pad)
                .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1])
            )
            prev_target_tokens = prev_target_tokens[
                :, : prev_target_tokens.ne(pad).sum(1).max()
            ]

            return prev_target_tokens
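
        # Toy illustration of the mechanics above (hypothetical draw,
        # assuming the fairseq defaults bos=0, pad=1, eos=2):
        #   target_tokens = [[0, 5, 6, 7, 2]]        # <bos> w1 w2 w3 <eos>
        #   one possible draw keeps w1 and w3 and deletes w2, yielding
        #   [[0, 5, 7, 2, 1]], which is then truncated to [[0, 5, 7, 2]].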
        def _random_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            target_masks = (
                target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos)
            )
            # random scores; special tokens get 2.0 so they are never selected
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(~target_masks, 2.0)
            # sample how many tokens to mask, uniform in (0, target_length]
            target_length = target_masks.sum(1).float()
            target_length = target_length * target_length.clone().uniform_()
            target_length = target_length + 1  # make sure to mask at least one token.

            _, target_rank = target_score.sort(1)
            target_cutoff = new_arange(target_rank) < target_length[:, None].long()
            prev_target_tokens = target_tokens.masked_fill(
                target_cutoff.scatter(1, target_rank, target_cutoff), unk
            )
            return prev_target_tokens
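
        # Toy sketch (hypothetical draw): the scatter maps the k lowest-scoring
        # positions back to their original indices, so for
        # [<bos>, w1, w2, w3, <eos>] a draw of k=2 could produce
        # [<bos>, <unk>, w2, <unk>, <eos>].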
        def _full_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            # replace every ordinary token with <unk>, keeping special tokens,
            # e.g. [<bos>, w1, w2, <eos>] -> [<bos>, <unk>, <unk>, <eos>]
            target_mask = (
                target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad)
            )
            return target_tokens.masked_fill(~target_mask, unk)
        def _block_mask(target_tokens):
            block_size = self.cfg.block_size
            pad = self.tgt_dict.pad()
            unk = self.tgt_dict.unk()

            target_masks = target_tokens.ne(pad)
            target_length = target_masks.sum(1).float()
            cutoff_length = target_length * target_length.clone().uniform_()
            cutoff_length = cutoff_length.int() + 1  # make sure to mask at least one token.

            # torch.ones(...) relies on the fairseq default pad index being 1,
            # so both buffers start out filled with padding
            prev_target_tokens = torch.ones(
                (target_tokens.size(0), target_tokens.size(1) + block_size)
            ).to(target_tokens)
            padded_target_tokens = torch.ones(
                (target_tokens.size(0), target_tokens.size(1) + block_size)
            ).to(target_tokens)
            for i in range(target_tokens.size(0)):
                # keep a prefix of the target and append one block of <unk>
                remain_length = target_length[i].int() - cutoff_length[i]
                prev_target_tokens[i][:remain_length] = target_tokens[i][:remain_length]
                prev_target_tokens[i][remain_length : block_size + remain_length] = unk
                padded_target_tokens[i][: target_tokens.size(1)] = target_tokens[i]

            # truncate both buffers to the longest non-pad row of prev_target
            # so the prefix + masked block and the gold target stay width-aligned
            prev_target_tokens = prev_target_tokens[
                :, : prev_target_tokens.ne(pad).sum(1).max()
            ]
            padded_target_tokens = padded_target_tokens[
                :, : prev_target_tokens.ne(pad).sum(1).max()
            ]

            return prev_target_tokens, padded_target_tokens
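
        # Layout sketch for one sample (hypothetical values, block_size=3):
        #   target:       w1 w2 w3 w4 w5 <eos>               (length 6)
        #   cutoff_length = 4  ->  remain_length = 2
        #   prev_target:  w1 w2 <unk> <unk> <unk> <pad> ...
        #   padded:       w1 w2 w3 w4 w5 <eos> <pad> ...
        # so the model observes a prefix plus one block of masked positions,
        # with the gold target padded to the same width.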
        # train_step/valid_step unpack a (prev_target, target) pair, so every
        # branch returns one; block_mask additionally re-pads the target to
        # match the width of its noised output.
        if self.cfg.noise == "random_delete":
            return _random_delete(target_tokens), target_tokens
        elif self.cfg.noise == "random_mask":
            return _random_mask(target_tokens), target_tokens
        elif self.cfg.noise == "block_mask":
            return _block_mask(target_tokens)
        elif self.cfg.noise == "full_mask":
            return _full_mask(target_tokens), target_tokens
        elif self.cfg.noise == "no_noise":
            return target_tokens, target_tokens
        else:
            raise NotImplementedError
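
    # Minimal usage sketch (hypothetical tensors):
    #   task.cfg.noise = "block_mask"
    #   prev_target, target = task.inject_noise(sample["target"])
    #   # prev_target: observed prefix plus one block of <unk> to predict
    #   # target:      gold tokens, padded to the same width as prev_target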
    def build_generator(self, models, args, **unused):
        # add models input to match the API for SequenceGenerator
        from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

        return IterativeRefinementGenerator(
            self.target_dictionary,
            eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0),
            max_iter=getattr(args, "iter_decode_max_iter", 10),
            beam_size=getattr(args, "iter_decode_with_beam", 1),
            reranking=getattr(args, "iter_decode_with_external_reranker", False),
            decoding_format=getattr(args, "decoding_format", None),
            adaptive=not getattr(args, "iter_decode_force_max_iter", False),
            retain_history=getattr(args, "retain_iter_history", False),
        )
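
    # The attributes read above correspond to fairseq's standard iterative
    # decoding flags; a hypothetical invocation, for illustration only:
    #   fairseq-generate data-bin --task translation_lev_modified \
    #       --iter-decode-max-iter 10 --iter-decode-with-beam 1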
    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
            raise NotImplementedError(
                "Constrained decoding with the translation_lev task is not supported"
            )

        return LanguagePairDataset(
            src_tokens, src_lengths, self.source_dictionary, append_bos=False
        )
    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        # anneal the glancing probability linearly from start_p down to
        # start_p - minus_p over the first total_up updates
        train_ratio = max(0, min(1, update_num / self.cfg.total_up))
        sample["glat"] = {"context_p": self.cfg.start_p - self.cfg.minus_p * train_ratio}
        sample["prev_target"], sample["target"] = self.inject_noise(sample["target"])
        with torch.autograd.profiler.record_function("forward"):
            loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output
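
    # Schedule sketch with the defaults (start_p=0.5, minus_p=0.2,
    # total_up=300000): context_p decays linearly from 0.5 at update 0 to
    # 0.3 at update >= 300000; e.g. update_num=150000 -> train_ratio=0.5 ->
    # context_p = 0.5 - 0.2 * 0.5 = 0.4.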
    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            sample["prev_target"], sample["target"] = self.inject_noise(sample["target"])
            loss, sample_size, logging_output = criterion(model, sample)
        EVAL_BLEU_ORDER = 4  # sacrebleu reports 1- to 4-gram statistics
        if self.cfg.eval_bleu:
            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = bleu.sys_len
            logging_output["_bleu_ref_len"] = bleu.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(bleu.counts) == EVAL_BLEU_ORDER
            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
                logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
        return loss, sample_size, logging_output
    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"]))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])
        if self.cfg.eval_tokenized_bleu:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])
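
    # Minimal sacrebleu sketch mirroring the call above (hypothetical data):
    #   import sacrebleu
    #   bleu = sacrebleu.corpus_bleu(["the cat sat"], [["the cat sat"]])
    #   bleu.counts, bleu.totals, bleu.sys_len, bleu.ref_len  # the fields
    #   # logged in valid_step; counts/totals hold n-gram match statistics.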