File size: 14,132 Bytes
7934b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 |
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script implements Noisy Channel Reranking (NCR) - https://arxiv.org/abs/1908.05731
Given .nemo files for a reverse NMT model (target -> source) and a transformer language model (target LM),
this script can be used to re-rank a forward model's (source -> target) beam candidates.
This script can be used in two ways: 1) Given the score file generated by `nmt_transformer_infer.py`, re-rank beam candidates and
2) Given the NCR score file generated by 1), re-rank beam candidates based only on cached scores in the NCR file. This is meant to tune NCR coefficients.
Pre-requisite: Generating translations using `nmt_transformer_infer.py`
1. Obtain text file in src language. You can use sacrebleu to obtain standard test sets like so:
sacrebleu -t wmt14 -l de-en --echo src > wmt14-de-en.src
2. Translate using `nmt_transformer_infer.py` with a large beam size.:
python nmt_transformer_infer.py --model=[Path to .nemo file(s)] --srctext=wmt14-de-en.src --tgtout=wmt14-de-en.translations --beam_size 15 --write_scores
USAGE Example (case 1):
Re-rank beam candidates:
python noisy_channel_reranking.py \
--reverse_model=[Path to .nemo file] \
--language_model=[Path to .nemo file] \
--srctext=wmt14-de-en.translations.scores \
--tgtout=wmt14-de-en.ncr.translations \
--forward_model_coef=1.0 \
--reverse_model_coef=0.7 \
--target_lm_coef=0.05 \
--write_scores \
USAGE Example (case 2):
Re-rank beam candidates using cached score file only
python noisy_channel_reranking.py \
--cached_score_file=wmt14-de-en.ncr.translations.scores \
--forward_model_coef=1.0 \
--reverse_model_coef=0.7 \
--target_lm_coef=0.05 \
--tgtout=wmt14-de-en.ncr.translations \
"""
from argparse import ArgumentParser
import numpy as np
import torch
import nemo.collections.nlp as nemo_nlp
from nemo.utils import logging
def score_fusion(args, forward_scores, rev_scores, lm_scores, src_lens, tgt_lens):
    """
    Combine forward NMT, reverse NMT and target LM log-likelihoods into a
    single NCR score per beam candidate.

    Each candidate's score is a weighted sum of its three model scores
    (optionally length-normalized), optionally divided by the GNMT-style
    length penalty ((5 + tgt_len) / 6) ** len_pen.

    Returns a list of fused scores, one per candidate.
    """
    normalize = args.length_normalize_scores
    fused = []
    for fwd, rev, lm, s_len, t_len in zip(forward_scores, rev_scores, lm_scores, src_lens, tgt_lens):
        # Forward and LM scores are over the target sequence; the reverse
        # model scores the source sequence, hence the differing normalizers.
        if normalize:
            fwd = fwd / t_len
            rev = rev / s_len
            lm = lm / t_len
        total = args.forward_model_coef * fwd
        total += args.reverse_model_coef * rev
        total += args.target_lm_coef * lm
        if args.len_pen is not None:
            total = total / (((5 + t_len) / 6) ** args.len_pen)
        fused.append(total)
    return fused
def main():
    """Entry point: re-rank NMT beam candidates via Noisy Channel Reranking.

    Two mutually exclusive modes, selected by --srctext / --cached_score_file:
      1) --srctext: score each candidate with the reverse NMT model(s) and the
         target LM, fuse the scores, and pick the best candidate per beam.
      2) --cached_score_file: re-rank using scores already cached in the file
         (lets you tune fusion coefficients without re-running the models).
    Writes the selected translation per beam to --tgtout, and optionally the
    per-candidate scores to --tgtout.scores (--write_scores, mode 1 only).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--reverse_model",
        type=str,
        help="Path to .nemo model file(s). If ensembling, provide comma separated paths to multiple models.",
    )
    parser.add_argument(
        "--language_model", type=str, help="Optional path to an LM model that has the same tokenizer as NMT models.",
    )
    parser.add_argument(
        "--forward_model_coef",
        type=float,
        default=1.0,
        help="Weight assigned to the forward NMT model for re-ranking.",
    )
    parser.add_argument(
        "--reverse_model_coef",
        type=float,
        default=0.7,
        help="Weight assigned to the reverse NMT model for re-ranking.",
    )
    parser.add_argument(
        "--target_lm_coef", type=float, default=0.07, help="Weight assigned to the target LM model for re-ranking.",
    )
    parser.add_argument(
        "--srctext",
        type=str,
        default=None,
        help="Path to a TSV file containing forward model scores of the format source \t beam_candidate_i \t forward_score",
    )
    parser.add_argument(
        "--cached_score_file",
        type=str,
        default=None,
        help="Path to a TSV file containing cached scores for each beam candidate. Format source \t target \t forward_score \t reverse_score \t lm_score \t src_len \t tgt_len",
    )
    parser.add_argument(
        "--tgtout", type=str, required=True, help="Path to the file where re-ranked translations are to be written."
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=4,
        help="Beam size with which forward model translations were generated. IMPORTANT: mismatch can lead to wrong results and an incorrect number of generated translations.",
    )
    parser.add_argument(
        "--target_lang", type=str, default=None, help="Target language identifier ex: en,de,fr,es etc."
    )
    parser.add_argument(
        "--source_lang", type=str, default=None, help="Source language identifier ex: en,de,fr,es etc."
    )
    parser.add_argument(
        "--write_scores", action="store_true", help="Whether to write forward, reverse and lm scores to a file."
    )
    parser.add_argument(
        "--length_normalize_scores",
        action="store_true",
        help="If true, it will divide forward, reverse and lm scores by the corresponding sequence length.",
    )
    parser.add_argument(
        "--len_pen",
        type=float,
        default=None,
        help="Apply a length penalty based on target lengths to the final NCR score.",
    )
    args = parser.parse_args()
    # Inference only — disable autograd globally for the whole script.
    torch.set_grad_enabled(False)

    if args.cached_score_file is None:
        # Restore the reverse (target -> source) model(s); comma-separated paths form an ensemble.
        reverse_models = []
        for model_path in args.reverse_model.split(','):
            if not model_path.endswith('.nemo'):
                raise NotImplementedError(f"Only support .nemo files, but got: {model_path}")
            model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path).eval()
            # Keep per-token NLLs (no reduction) so sequence log-likelihoods can be summed below.
            model.eval_loss_fn.reduction = 'none'
            reverse_models.append(model)
        lm_model = nemo_nlp.models.language_modeling.TransformerLMModel.restore_from(
            restore_path=args.language_model
        ).eval()

    # Exactly one of --srctext / --cached_score_file is required.
    if args.srctext is not None and args.cached_score_file is not None:
        raise ValueError("Only one of --srctext or --cached_score_file must be provided.")
    if args.srctext is None and args.cached_score_file is None:
        raise ValueError("Neither --srctext nor --cached_score_file were provided.")

    if args.srctext is not None:
        logging.info(f"Re-ranking: {args.srctext}")
    else:
        logging.info(f"Re-ranking from cached score file only: {args.cached_score_file}")

    if args.cached_score_file is None:
        if torch.cuda.is_available():
            reverse_models = [model.cuda() for model in reverse_models]
            lm_model = lm_model.cuda()

    src_text = []
    tgt_text = []
    all_reverse_scores = []
    all_lm_scores = []
    all_forward_scores = []
    all_src_lens = []
    all_tgt_lens = []

    # Check args if re-ranking from cached score file.
    if args.cached_score_file is not None:
        if args.write_scores:
            raise ValueError("--write_scores cannot be provided with a cached score file.")
        if args.reverse_model is not None:
            raise ValueError(
                "--reverse_model cannot be provided with a cached score file since it assumes reverse scores already present in the cached file."
            )
        if args.language_model is not None:
            raise ValueError(
                "--language_model cannot be provided with a cached score file since it assumes language model scores already present in the cached file."
            )

    if args.srctext is not None:
        # Compute reverse scores and LM scores from the provided models since cached scores file is not provided.
        with open(args.srctext, 'r') as src_f:
            count = 0
            for line in src_f:
                src_text.append(line.strip().split('\t'))
                # Accumulate one full beam's worth of candidates, then score the batch.
                if len(src_text) == args.beam_size:
                    # Source and target sequences are flipped for the reverse direction model.
                    src_texts = [item[1] for item in src_text]
                    tgt_texts = [item[0] for item in src_text]
                    src, src_mask = reverse_models[0].prepare_inference_batch(src_texts)
                    tgt, tgt_mask = reverse_models[0].prepare_inference_batch(tgt_texts, target=True)
                    src_lens = src_mask.sum(1).data.cpu().tolist()
                    tgt_lens = tgt_mask.sum(1).data.cpu().tolist()
                    forward_scores = [float(item[2]) for item in src_text]

                    # Ensemble of reverse model scores (mean of per-model sequence log-likelihoods).
                    nmt_lls = []
                    for model in reverse_models:
                        nmt_log_probs = model(src, src_mask, tgt[:, :-1], tgt_mask[:, :-1])
                        nmt_nll = model.eval_loss_fn(log_probs=nmt_log_probs, labels=tgt[:, 1:])
                        # Sum per-token NLLs over the sequence; negate to get log-likelihood.
                        nmt_ll = nmt_nll.view(nmt_log_probs.size(0), nmt_log_probs.size(1)).sum(1) * -1.0
                        nmt_ll = nmt_ll.data.cpu().numpy().tolist()
                        nmt_lls.append(nmt_ll)
                    reverse_scores = np.stack(nmt_lls).mean(0)

                    # LM scores.
                    if lm_model is not None:
                        # Compute LM score for the src of the reverse model.
                        lm_log_probs = lm_model(src[:, :-1], src_mask[:, :-1])
                        # NOTE(review): `model` here is the last reverse model from the loop
                        # above — its eval_loss_fn has reduction='none', which the per-token
                        # view below requires. Presumably intentional; confirm whether
                        # lm_model.eval_loss_fn (with reduction set to 'none') was meant.
                        lm_nll = model.eval_loss_fn(log_probs=lm_log_probs, labels=src[:, 1:])
                        lm_ll = lm_nll.view(lm_log_probs.size(0), lm_log_probs.size(1)).sum(1) * -1.0
                        lm_ll = lm_ll.data.cpu().numpy().tolist()
                    else:
                        lm_ll = None
                    lm_scores = lm_ll

                    all_reverse_scores.extend(reverse_scores)
                    all_lm_scores.extend(lm_scores)
                    all_forward_scores.extend(forward_scores)
                    # Swapping source and target here back again since this is what gets written to the file.
                    all_src_lens.extend(tgt_lens)
                    all_tgt_lens.extend(src_lens)

                    fused_scores = score_fusion(args, forward_scores, reverse_scores, lm_scores, src_lens, tgt_lens)
                    tgt_text.append(src_texts[np.argmax(fused_scores)])
                    src_text = []
                    count += 1
                    print(f'Reranked {count} sentences')
    else:
        # Use reverse and LM scores from the cached scores file to re-rank.
        with open(args.cached_score_file, 'r') as src_f:
            count = 0
            for line in src_f:
                src_text.append(line.strip().split('\t'))
                if len(src_text) == args.beam_size:
                    if not all([len(item) == 7 for item in src_text]):
                        # Fixed: message previously said "exactly 5 fields" while the check
                        # (and the format described below) requires exactly 7.
                        raise IndexError(
                            "All lines did not contain exactly 7 fields. Format - src_txt \t tgt_text \t forward_score \t reverse_score \t lm_score \t src_len \t tgt_len"
                        )
                    src_texts = [item[0] for item in src_text]
                    tgt_texts = [item[1] for item in src_text]
                    forward_scores = [float(item[2]) for item in src_text]
                    reverse_scores = [float(item[3]) for item in src_text]
                    lm_scores = [float(item[4]) for item in src_text]
                    src_lens = [float(item[5]) for item in src_text]
                    tgt_lens = [float(item[6]) for item in src_text]
                    fused_scores = score_fusion(args, forward_scores, reverse_scores, lm_scores, src_lens, tgt_lens)
                    tgt_text.append(tgt_texts[np.argmax(fused_scores)])
                    src_text = []
                    count += 1
                    print(f'Reranked {count} sentences')

    # One re-ranked translation per beam.
    with open(args.tgtout, 'w') as tgt_f:
        for line in tgt_text:
            tgt_f.write(line + "\n")

    # Write scores file (mode 1 only) so a later run can re-rank from cache.
    if args.write_scores:
        with open(args.tgtout + '.scores', 'w') as tgt_f, open(args.srctext, 'r') as src_f:
            src_lines = []
            for line in src_f:
                src_lines.append(line.strip().split('\t'))
            if not (len(all_reverse_scores) == len(all_lm_scores) == len(all_forward_scores) == len(src_lines)):
                raise ValueError(
                    f"Length of scores files do not match. {len(all_reverse_scores)} != {len(all_lm_scores)} != {len(all_forward_scores)} != {len(src_lines)}. This is most likely because --beam_size is set incorrectly. This needs to be set to the same value that was used to generate translations."
                )
            for f, r, lm, sl, tl, src in zip(
                all_forward_scores, all_reverse_scores, all_lm_scores, all_src_lens, all_tgt_lens, src_lines
            ):
                tgt_f.write(
                    src[0]
                    + '\t'
                    + src[1]
                    + '\t'
                    + str(f)
                    + '\t'
                    + str(r)
                    + '\t'
                    + str(lm)
                    + '\t'
                    + str(sl)
                    + '\t'
                    + str(tl)
                    + '\n'
                )
# Run only when executed as a script (not on import).
if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
|