#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Translate pre-processed data with a trained model.
"""

import torch

from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
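
# Example invocation (a sketch; the data-bin directory and checkpoint path are
# illustrative placeholders, not files shipped with this script):
#
#   python generate.py data-bin/wmt14_en_de \
#       --path checkpoints/checkpoint_best.pt \
#       --beam 5 --remove-bpe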


def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)
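
    # Note: resolve_max_positions() below combines the task's and each model's
    # position limits by taking the most restrictive value, so batches never
    # exceed what any ensemble member can handle.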
    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
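    # --sacrebleu scores detokenized strings (scorer.add_string in the loop
    # below); the default Scorer operates on token ID tensors (scorer.add).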
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue
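
            # If --prefix-size is N > 0, feed the first N reference tokens to
            # the decoder as a forced prefix for each sentence's hypotheses.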
            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
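
                # Output line prefixes (one tab-separated record per line):
                #   S - source sentence, T - reference target, H - hypothesis
                #   with its score, P - per-token positional scores, A - source/
                #   target alignment, I - decoding steps, E - iteration history.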
                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))
                        ))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
                            ))

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']))

                        if getattr(args, 'retain_iter_history', False):
                            print("\n".join([
                                'E-{}_{}\t{}'.format(
                                    sample_id, step,
                                    utils.post_process_prediction(
                                        h['tokens'].int().cpu(),
                                        src_str, None, None, tgt_dict, None)[1])
                                for step, h in enumerate(hypo['history'])]))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))

    return scorer


def cli_main():
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()