# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script implemts Noisy Channel Reranking (NCR) - https://arxiv.org/abs/1908.05731
Given .nemo files for a, reverse model (target -> source) and transformer LM (target LM) NMT model's .nemo file, 
this script can be used to re-rank a forward model's (source -> target) beam candidates.

This script can be used in two ways 1) Given the score file generated by `nmt_transformer_infer.py`, re-rank beam candidates and 
2) Given NCR score file generated by 1), Re-rank beam candidates based only on cached scores in the ncr file. This is meant to tune NCR coeficients.
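
Each candidate is scored with a weighted sum of three log-likelihoods (see `score_fusion` below):

    score = forward_model_coef * log P(tgt | src)
          + reverse_model_coef * log P(src | tgt)
          + target_lm_coef     * log P(tgt)

with optional per-term length normalization (--length_normalize_scores) and an optional
GNMT-style length penalty, score / (((5 + tgt_len) / 6) ** len_pen).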

Pre-requisite: Generating translations using `nmt_transformer_infer.py`
1. Obtain text file in src language. You can use sacrebleu to obtain standard test sets like so:
    sacrebleu -t wmt14 -l de-en --echo src > wmt14-de-en.src
2. Translate using `nmt_transformer_infer.py` with a large beam size:
    python nmt_transformer_infer.py --model=[Path to .nemo file(s)] --srctext=wmt14-de-en.src --tgtout=wmt14-de-en.translations --beam_size 15 --write_scores

USAGE Example (case 1):
Re-rank beam candidates:
    python noisy_channel_reranking.py \
        --reverse_model=[Path to .nemo file] \
        --language_model=[Path to .nemo file] \
        --srctext=wmt14-de-en.translations.scores \
        --tgtout=wmt14-de-en.ncr.translations \
        --forward_model_coef=1.0 \
        --reverse_model_coef=0.7 \
        --target_lm_coef=0.05 \
        --write_scores

USAGE Example (case 2):
Re-rank beam candidates using the cached score file only:
    python noisy_channel_reranking.py \
        --cached_score_file=wmt14-de-en.ncr.translations.scores \
        --forward_model_coef=1.0 \
        --reverse_model_coef=0.7 \
        --target_lm_coef=0.05 \
        --tgtout=wmt14-de-en.ncr.translations
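
A simple coefficient sweep over the cached score file (case 2) might look like this; the grid
values below are hypothetical starting points, not tuned recommendations:
    for rev in 0.5 0.7 0.9; do
        for lm in 0.03 0.05 0.07; do
            python noisy_channel_reranking.py \
                --cached_score_file=wmt14-de-en.ncr.translations.scores \
                --forward_model_coef=1.0 \
                --reverse_model_coef=${rev} \
                --target_lm_coef=${lm} \
                --tgtout=wmt14-de-en.ncr.rev${rev}.lm${lm}.translations
        done
    done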
"""


from argparse import ArgumentParser

import numpy as np
import torch

import nemo.collections.nlp as nemo_nlp
from nemo.utils import logging


def score_fusion(args, forward_scores, rev_scores, lm_scores, src_lens, tgt_lens):
    """
    Fuse forward, reverse and language model scores.
    """
    fused_scores = []
    for forward_score, rev_score, lm_score, src_len, tgt_len in zip(
        forward_scores, rev_scores, lm_scores, src_lens, tgt_lens
    ):
        score = 0

        forward_score = forward_score / tgt_len if args.length_normalize_scores else forward_score
        score += args.forward_model_coef * forward_score

        rev_score = rev_score / src_len if args.length_normalize_scores else rev_score
        score += args.reverse_model_coef * rev_score

        lm_score = lm_score / tgt_len if args.length_normalize_scores else lm_score
        score += args.target_lm_coef * lm_score

        if args.len_pen is not None:
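            # GNMT-style length penalty, lp = ((5 + tgt_len) / 6) ** len_pen (Wu et al., 2016).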
            score = score / (((5 + tgt_len) / 6) ** args.len_pen)

        fused_scores.append(score)

    return fused_scores
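
# A minimal worked example of the fusion above (hypothetical scores; length
# normalization off and len_pen unset), using the default coefficients 1.0 / 0.7 / 0.07:
#   forward = -4.2, reverse = -5.0, lm = -12.0
#   fused   = 1.0 * (-4.2) + 0.7 * (-5.0) + 0.07 * (-12.0) = -8.54
# The candidate with the highest (least negative) fused score is selected.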


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--reverse_model",
        type=str,
        help="Path to .nemo model file(s). If ensembling, provide comma separated paths to multiple models.",
    )
    parser.add_argument(
        "--language_model", type=str, help="Optional path to an LM model that has the same tokenizer as NMT models.",
    )
    parser.add_argument(
        "--forward_model_coef",
        type=float,
        default=1.0,
        help="Weight assigned to the forward NMT model for re-ranking.",
    )
    parser.add_argument(
        "--reverse_model_coef",
        type=float,
        default=0.7,
        help="Weight assigned to the reverse NMT model for re-ranking.",
    )
    parser.add_argument(
        "--target_lm_coef", type=float, default=0.07, help="Weight assigned to the target LM model for re-ranking.",
    )
    parser.add_argument(
        "--srctext",
        type=str,
        default=None,
        help="Path to a TSV file containing forward model scores of the format source \t beam_candidate_i \t forward_score",
    )
    parser.add_argument(
        "--cached_score_file",
        type=str,
        default=None,
        help="Path to a TSV file containing cached scores for each beam candidate. Format source \t target \t forward_score \t reverse_score \t lm_score \t src_len \t tgt_len",
    )
    parser.add_argument(
        "--tgtout", type=str, required=True, help="Path to the file where re-ranked translations are to be written."
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=4,
        help="Beam size with which forward model translations were generated. IMPORTANT: mismatch can lead to wrong results and an incorrect number of generated translations.",
    )
    parser.add_argument(
        "--target_lang", type=str, default=None, help="Target language identifier ex: en,de,fr,es etc."
    )
    parser.add_argument(
        "--source_lang", type=str, default=None, help="Source language identifier ex: en,de,fr,es etc."
    )
    parser.add_argument(
        "--write_scores", action="store_true", help="Whether to write forward, reverse and lm scores to a file."
    )
    parser.add_argument(
        "--length_normalize_scores",
        action="store_true",
        help="If true, it will divide forward, reverse and lm scores by the corresponding sequence length.",
    )
    parser.add_argument(
        "--len_pen",
        type=float,
        default=None,
        help="Apply a length penalty based on target lengths to the final NCR score.",
    )

    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.cached_score_file is None:
        if args.reverse_model is None:
            raise ValueError("--reverse_model must be provided when not re-ranking from a cached score file.")
        reverse_models = []
        for model_path in args.reverse_model.split(','):
            if not model_path.endswith('.nemo'):
                raise NotImplementedError(f"Only .nemo files are supported, but got: {model_path}")
            model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path).eval()
            model.eval_loss_fn.reduction = 'none'
            reverse_models.append(model)

        lm_model = (
            nemo_nlp.models.language_modeling.TransformerLMModel.restore_from(restore_path=args.language_model).eval()
            if args.language_model is not None
            else None
        )

    if args.srctext is not None and args.cached_score_file is not None:
        raise ValueError("Only one of --srctext or --cached_score_file must be provided.")

    if args.srctext is None and args.cached_score_file is None:
        raise ValueError("Neither --srctext nor --cached_score_file were provided.")

    if args.srctext is not None:
        logging.info(f"Re-ranking: {args.srctext}")
    else:
        logging.info(f"Re-ranking from cached score file only: {args.cached_score_file}")

    if args.cached_score_file is None:
        if torch.cuda.is_available():
            reverse_models = [model.cuda() for model in reverse_models]
            if lm_model is not None:
                lm_model = lm_model.cuda()

    src_text = []
    tgt_text = []
    all_reverse_scores = []
    all_lm_scores = []
    all_forward_scores = []
    all_src_lens = []
    all_tgt_lens = []

    # Check args if re-ranking from cached score file.
    if args.cached_score_file is not None:
        if args.write_scores:
            raise ValueError("--write_scores cannot be provided with a cached score file.")
        if args.reverse_model is not None:
            raise ValueError(
                "--reverse_model cannot be provided with a cached score file since it assumes reverse scores already present in the cached file."
            )
        if args.language_model is not None:
            raise ValueError(
                "--language_model cannot be provided with a cached score file since it assumes language model scores already present in the cached file."
            )

    if args.srctext is not None:
        # Compute reverse scores and LM scores from the provided models since cached scores file is not provided.
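        # The score file groups candidates by source sentence: every `beam_size` consecutive
        # lines belong to the same source sentence, so lines are buffered until a full beam is read.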
        with open(args.srctext, 'r') as src_f:
            count = 0
            for line in src_f:
                src_text.append(line.strip().split('\t'))
                if len(src_text) == args.beam_size:
                    # Source and target sequences are flipped for the reverse direction model.
                    src_texts = [item[1] for item in src_text]
                    tgt_texts = [item[0] for item in src_text]
                    src, src_mask = reverse_models[0].prepare_inference_batch(src_texts)
                    tgt, tgt_mask = reverse_models[0].prepare_inference_batch(tgt_texts, target=True)
                    src_lens = src_mask.sum(1).data.cpu().tolist()
                    tgt_lens = tgt_mask.sum(1).data.cpu().tolist()
                    forward_scores = [float(item[2]) for item in src_text]

                    # Ensemble of reverse model scores.
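                    # Each model scores tgt -> src with teacher forcing; per-token log-probs are
                    # summed into a sentence-level log-likelihood and averaged across the ensemble.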
                    nmt_lls = []
                    for model in reverse_models:
                        nmt_log_probs = model(src, src_mask, tgt[:, :-1], tgt_mask[:, :-1])
                        nmt_nll = model.eval_loss_fn(log_probs=nmt_log_probs, labels=tgt[:, 1:])
                        nmt_ll = nmt_nll.view(nmt_log_probs.size(0), nmt_log_probs.size(1)).sum(1) * -1.0
                        nmt_ll = nmt_ll.data.cpu().numpy().tolist()
                        nmt_lls.append(nmt_ll)
                    reverse_scores = np.stack(nmt_lls).mean(0)

                    # LM scores.
                    if lm_model is not None:
                        # Compute the LM score for the src of the reverse model (the forward model's candidate),
                        # reusing the reverse model's per-token loss to turn log-probs into a sentence log-likelihood.
                        lm_log_probs = lm_model(src[:, :-1], src_mask[:, :-1])
                        lm_nll = model.eval_loss_fn(log_probs=lm_log_probs, labels=src[:, 1:])
                        lm_ll = lm_nll.view(lm_log_probs.size(0), lm_log_probs.size(1)).sum(1) * -1.0
                        lm_ll = lm_ll.data.cpu().numpy().tolist()
                    else:
                        # No LM provided; zero scores contribute nothing to the fused score.
                        lm_ll = [0.0] * len(src_texts)
                    lm_scores = lm_ll

                    all_reverse_scores.extend(reverse_scores)
                    all_lm_scores.extend(lm_scores)
                    all_forward_scores.extend(forward_scores)

                    # Swap the lengths back into the forward orientation (source, candidate); this is
                    # what score_fusion expects and what gets written to the score file.
                    all_src_lens.extend(tgt_lens)
                    all_tgt_lens.extend(src_lens)

                    fused_scores = score_fusion(args, forward_scores, reverse_scores, lm_scores, tgt_lens, src_lens)
                    tgt_text.append(src_texts[np.argmax(fused_scores)])
                    src_text = []
                    count += 1
                    print(f'Reranked {count} sentences')

    else:
        # Use reverse and LM scores from the cached scores file to re-rank.
        with open(args.cached_score_file, 'r') as src_f:
            count = 0
            for line in src_f:
                src_text.append(line.strip().split('\t'))
                if len(src_text) == args.beam_size:
                    if not all(len(item) == 7 for item in src_text):
                        raise IndexError(
                            "All lines must contain exactly 7 fields. Format - src_text \t tgt_text \t forward_score \t reverse_score \t lm_score \t src_len \t tgt_len"
                        )
                    src_texts = [item[0] for item in src_text]
                    tgt_texts = [item[1] for item in src_text]
                    forward_scores = [float(item[2]) for item in src_text]
                    reverse_scores = [float(item[3]) for item in src_text]
                    lm_scores = [float(item[4]) for item in src_text]
                    src_lens = [float(item[5]) for item in src_text]
                    tgt_lens = [float(item[6]) for item in src_text]

                    fused_scores = score_fusion(args, forward_scores, reverse_scores, lm_scores, src_lens, tgt_lens)
                    tgt_text.append(tgt_texts[np.argmax(fused_scores)])
                    src_text = []
                    count += 1
                    print(f'Reranked {count} sentences')

    with open(args.tgtout, 'w') as tgt_f:
        for line in tgt_text:
            tgt_f.write(line + "\n")

    # Write scores file
    if args.write_scores:
        with open(args.tgtout + '.scores', 'w') as tgt_f, open(args.srctext, 'r') as src_f:
            src_lines = []
            for line in src_f:
                src_lines.append(line.strip().split('\t'))
            if not (len(all_reverse_scores) == len(all_lm_scores) == len(all_forward_scores) == len(src_lines)):
                raise ValueError(
                    f"Length of scores files do not match. {len(all_reverse_scores)} != {len(all_lm_scores)} != {len(all_forward_scores)} != {len(src_lines)}. This is most likely because --beam_size is set incorrectly. This needs to be set to the same value that was used to generate translations."
                )
            for f, r, lm, sl, tl, src in zip(
                all_forward_scores, all_reverse_scores, all_lm_scores, all_src_lens, all_tgt_lens, src_lines
            ):
                tgt_f.write('\t'.join([src[0], src[1], str(f), str(r), str(lm), str(sl), str(tl)]) + '\n')


if __name__ == '__main__':
    main()