""" |
|
This script implemts Noisy Channel Reranking (NCR) - https://arxiv.org/abs/1908.05731 |
|
Given .nemo files for a, reverse model (target -> source) and transformer LM (target LM) NMT model's .nemo file, |
|
this script can be used to re-rank a forward model's (source -> target) beam candidates. |
|
|
|
This script can be used in two ways 1) Given the score file generated by `nmt_transformer_infer.py`, re-rank beam candidates and |
|
2) Given NCR score file generated by 1), Re-rank beam candidates based only on cached scores in the ncr file. This is meant to tune NCR coeficients. |
|
|
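
The re-ranking score computed in `score_fusion` below is a weighted sum of the three model scores:

    fused_score = forward_model_coef * forward_score
                + reverse_model_coef * reverse_score
                + target_lm_coef * lm_score

Each score can optionally be normalized by its sequence length (--length_normalize_scores), and the fused
score can be divided by a GNMT-style length penalty ((5 + tgt_len) / 6) ** len_pen (--len_pen).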

Pre-requisite: generating translations with `nmt_transformer_infer.py`
1. Obtain a text file in the source language. You can use sacrebleu to fetch standard test sets, for example:
    sacrebleu -t wmt14 -l de-en --echo src > wmt14-de-en.src
2. Translate with `nmt_transformer_infer.py` using a large beam size and write the scores:
    python nmt_transformer_infer.py --model=[Path to .nemo file(s)] --srctext=wmt14-de-en.src --tgtout=wmt14-de-en.translations --beam_size 15 --write_scores
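
The resulting score file (wmt14-de-en.translations.scores in the example above) is expected to contain one
tab-separated line per beam candidate (source, beam_candidate, forward_score), with the beam_size candidates
for each source sentence on consecutive lines. This is the file passed to --srctext below.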

USAGE Example (case 1):
Re-rank beam candidates:
    python noisy_channel_reranking.py \
        --reverse_model=[Path to .nemo file] \
        --language_model=[Path to .nemo file] \
        --srctext=wmt14-de-en.translations.scores \
        --tgtout=wmt14-de-en.ncr.translations \
        --forward_model_coef=1.0 \
        --reverse_model_coef=0.7 \
        --target_lm_coef=0.05 \
        --write_scores \

USAGE Example (case 2):
Re-rank beam candidates using the cached score file only:
    python noisy_channel_reranking.py \
        --cached_score_file=wmt14-de-en.ncr.translations.scores \
        --forward_model_coef=1.0 \
        --reverse_model_coef=0.7 \
        --target_lm_coef=0.05 \
        --tgtout=wmt14-de-en.ncr.translations \
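
Running case 1 with --write_scores also writes <tgtout>.scores (wmt14-de-en.ncr.translations.scores in the
example above): a tab-separated file with the source, candidate, forward, reverse and LM scores, and the
source/target lengths for every beam candidate. Case 2 reads this file via --cached_score_file, so the
fusion coefficients can be tuned without re-running the models.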
"""

from argparse import ArgumentParser

import numpy as np
import torch

import nemo.collections.nlp as nemo_nlp
from nemo.utils import logging


def score_fusion(args, forward_scores, rev_scores, lm_scores, src_lens, tgt_lens):
    """
    Fuse forward, reverse and language model scores into a single re-ranking score per candidate.
    """
    fused_scores = []
    for forward_score, rev_score, lm_score, src_len, tgt_len in zip(
        forward_scores, rev_scores, lm_scores, src_lens, tgt_lens
    ):
        score = 0

        forward_score = forward_score / tgt_len if args.length_normalize_scores else forward_score
        score += args.forward_model_coef * forward_score

        rev_score = rev_score / src_len if args.length_normalize_scores else rev_score
        score += args.reverse_model_coef * rev_score

        lm_score = lm_score / tgt_len if args.length_normalize_scores else lm_score
        score += args.target_lm_coef * lm_score

        if args.len_pen is not None:
            # GNMT-style length penalty (Wu et al., 2016): ((5 + len) / 6) ** alpha
            score = score / (((5 + tgt_len) / 6) ** args.len_pen)

        fused_scores.append(score)

    return fused_scores
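
# Illustrative example of the fusion above (assuming --length_normalize_scores and --len_pen are not set):
# with forward_model_coef=1.0, reverse_model_coef=0.7 and target_lm_coef=0.05, a candidate with
# forward_score=-4.0, reverse_score=-6.0 and lm_score=-10.0 gets
#     fused_score = 1.0 * -4.0 + 0.7 * -6.0 + 0.05 * -10.0 = -8.7
# and the candidate with the highest fused score within each beam is kept.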


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--reverse_model",
        type=str,
        help="Path to .nemo model file(s). If ensembling, provide comma-separated paths to multiple models.",
    )
    parser.add_argument(
        "--language_model", type=str, help="Optional path to an LM model that has the same tokenizer as the NMT models.",
    )
    parser.add_argument(
        "--forward_model_coef",
        type=float,
        default=1.0,
        help="Weight assigned to the forward NMT model for re-ranking.",
    )
    parser.add_argument(
        "--reverse_model_coef",
        type=float,
        default=0.7,
        help="Weight assigned to the reverse NMT model for re-ranking.",
    )
    parser.add_argument(
        "--target_lm_coef", type=float, default=0.07, help="Weight assigned to the target LM for re-ranking.",
    )
    parser.add_argument(
        "--srctext",
        type=str,
        default=None,
        help="Path to a TSV file containing forward model scores in the format: source \\t beam_candidate_i \\t forward_score.",
    )
    parser.add_argument(
        "--cached_score_file",
        type=str,
        default=None,
        help="Path to a TSV file containing cached scores for each beam candidate in the format: source \\t target \\t forward_score \\t reverse_score \\t lm_score \\t src_len \\t tgt_len.",
    )
    parser.add_argument(
        "--tgtout", type=str, required=True, help="Path to the file where re-ranked translations are to be written."
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=4,
        help="Beam size with which the forward model translations were generated. IMPORTANT: a mismatch can lead to wrong results and an incorrect number of generated translations.",
    )
    parser.add_argument(
        "--target_lang", type=str, default=None, help="Target language identifier, e.g. en, de, fr, es."
    )
    parser.add_argument(
        "--source_lang", type=str, default=None, help="Source language identifier, e.g. en, de, fr, es."
    )
    parser.add_argument(
        "--write_scores", action="store_true", help="Whether to write forward, reverse and LM scores to a file."
    )
    parser.add_argument(
        "--length_normalize_scores",
        action="store_true",
        help="If set, divide forward, reverse and LM scores by the corresponding sequence lengths before fusion.",
    )
    parser.add_argument(
        "--len_pen",
        type=float,
        default=None,
        help="Apply a length penalty based on target length to the final NCR score.",
    )

    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.cached_score_file is None:
        # Load the reverse NMT model(s) and the target LM only when scores have to be computed.
        reverse_models = []
        for model_path in args.reverse_model.split(','):
            if not model_path.endswith('.nemo'):
                raise NotImplementedError(f"Only support .nemo files, but got: {model_path}")
            model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path).eval()
            # Per-token (unreduced) loss is needed to compute per-sentence log-likelihoods.
            model.eval_loss_fn.reduction = 'none'
            reverse_models.append(model)

        lm_model = nemo_nlp.models.language_modeling.TransformerLMModel.restore_from(
            restore_path=args.language_model
        ).eval()

    if args.srctext is not None and args.cached_score_file is not None:
        raise ValueError("Only one of --srctext and --cached_score_file may be provided.")

    if args.srctext is None and args.cached_score_file is None:
        raise ValueError("Neither --srctext nor --cached_score_file was provided.")

    if args.srctext is not None:
        logging.info(f"Re-ranking: {args.srctext}")
    else:
        logging.info(f"Re-ranking from cached score file only: {args.cached_score_file}")

    if args.cached_score_file is None:
        if torch.cuda.is_available():
            reverse_models = [model.cuda() for model in reverse_models]
            lm_model = lm_model.cuda()

    src_text = []
    tgt_text = []
    all_reverse_scores = []
    all_lm_scores = []
    all_forward_scores = []
    all_src_lens = []
    all_tgt_lens = []

    if args.cached_score_file is not None:
        if args.write_scores:
            raise ValueError("--write_scores cannot be provided with a cached score file.")
        if args.reverse_model is not None:
            raise ValueError(
                "--reverse_model cannot be provided with a cached score file since reverse scores are assumed to be present in the cached file."
            )
        if args.language_model is not None:
            raise ValueError(
                "--language_model cannot be provided with a cached score file since language model scores are assumed to be present in the cached file."
            )

    if args.srctext is not None:
        with open(args.srctext, 'r') as src_f:
            count = 0
            for line in src_f:
                src_text.append(line.strip().split('\t'))
                if len(src_text) == args.beam_size:
                    # From the reverse model's point of view, the forward model's beam candidates
                    # are the source and the original source sentences are the target.
                    src_texts = [item[1] for item in src_text]
                    tgt_texts = [item[0] for item in src_text]
                    src, src_mask = reverse_models[0].prepare_inference_batch(src_texts)
                    tgt, tgt_mask = reverse_models[0].prepare_inference_batch(tgt_texts, target=True)
                    src_lens = src_mask.sum(1).data.cpu().tolist()
                    tgt_lens = tgt_mask.sum(1).data.cpu().tolist()
                    forward_scores = [float(item[2]) for item in src_text]

                    # Reverse model (ensemble) scores: average per-candidate log-likelihoods over the models.
                    nmt_lls = []
                    for model in reverse_models:
                        nmt_log_probs = model(src, src_mask, tgt[:, :-1], tgt_mask[:, :-1])
                        nmt_nll = model.eval_loss_fn(log_probs=nmt_log_probs, labels=tgt[:, 1:])
                        nmt_ll = nmt_nll.view(nmt_log_probs.size(0), nmt_log_probs.size(1)).sum(1) * -1.0
                        nmt_ll = nmt_ll.data.cpu().numpy().tolist()
                        nmt_lls.append(nmt_ll)
                    reverse_scores = np.stack(nmt_lls).mean(0)

                    # Target LM scores for the beam candidates. The reverse model's per-token loss is
                    # reused here, which works because the LM shares the NMT models' tokenizer.
                    if lm_model is not None:
                        lm_log_probs = lm_model(src[:, :-1], src_mask[:, :-1])
                        lm_nll = model.eval_loss_fn(log_probs=lm_log_probs, labels=src[:, 1:])
                        lm_ll = lm_nll.view(lm_log_probs.size(0), lm_log_probs.size(1)).sum(1) * -1.0
                        lm_ll = lm_ll.data.cpu().numpy().tolist()
                    else:
                        lm_ll = None
                    lm_scores = lm_ll

                    all_reverse_scores.extend(reverse_scores)
                    all_lm_scores.extend(lm_scores)
                    all_forward_scores.extend(forward_scores)

                    all_src_lens.extend(tgt_lens)
                    all_tgt_lens.extend(src_lens)

                    fused_scores = score_fusion(args, forward_scores, reverse_scores, lm_scores, src_lens, tgt_lens)
                    tgt_text.append(src_texts[np.argmax(fused_scores)])
                    src_text = []
                    count += 1
                    print(f'Reranked {count} sentences')

    else:
        with open(args.cached_score_file, 'r') as src_f:
            count = 0
            for line in src_f:
                src_text.append(line.strip().split('\t'))
                if len(src_text) == args.beam_size:
                    if not all([len(item) == 7 for item in src_text]):
                        raise IndexError(
                            "Not all lines contained exactly 7 fields. Expected format: "
                            "src_text \\t tgt_text \\t forward_score \\t reverse_score \\t lm_score \\t src_len \\t tgt_len"
                        )
                    src_texts = [item[0] for item in src_text]
                    tgt_texts = [item[1] for item in src_text]
                    forward_scores = [float(item[2]) for item in src_text]
                    reverse_scores = [float(item[3]) for item in src_text]
                    lm_scores = [float(item[4]) for item in src_text]
                    src_lens = [float(item[5]) for item in src_text]
                    tgt_lens = [float(item[6]) for item in src_text]

                    fused_scores = score_fusion(args, forward_scores, reverse_scores, lm_scores, src_lens, tgt_lens)
                    tgt_text.append(tgt_texts[np.argmax(fused_scores)])
                    src_text = []
                    count += 1
                    print(f'Reranked {count} sentences')

    with open(args.tgtout, 'w') as tgt_f:
        for line in tgt_text:
            tgt_f.write(line + "\n")

    if args.write_scores:
        with open(args.tgtout + '.scores', 'w') as tgt_f, open(args.srctext, 'r') as src_f:
            src_lines = []
            for line in src_f:
                src_lines.append(line.strip().split('\t'))
            if not (len(all_reverse_scores) == len(all_lm_scores) == len(all_forward_scores) == len(src_lines)):
                raise ValueError(
                    f"Lengths of the score lists do not match: {len(all_reverse_scores)} != {len(all_lm_scores)} "
                    f"!= {len(all_forward_scores)} != {len(src_lines)}. This is most likely because --beam_size "
                    "is set incorrectly; it must match the value used to generate the translations."
                )
            for f, r, lm, sl, tl, src in zip(
                all_forward_scores, all_reverse_scores, all_lm_scores, all_src_lens, all_tgt_lens, src_lines
            ):
                tgt_f.write('\t'.join([src[0], src[1], str(f), str(r), str(lm), str(sl), str(tl)]) + '\n')


if __name__ == '__main__':
    main()