Spaces:
Sleeping
Sleeping
import logging, os, sys | |
import time | |
import torch | |
from torch import Tensor | |
from typing import Dict, List, Optional | |
import copy | |
from tqdm import tqdm | |
from omegaconf import open_dict | |
import fairseq | |
from fairseq.checkpoint_utils import load_model_ensemble_and_task | |
from fairseq import utils | |
from fairseq.data import data_utils | |
import argparse | |
# Root logging configuration for this script: timestamped records go to
# stdout, and verbosity is taken from the LOGLEVEL environment variable
# (falling back to INFO).
_LOGGING_CONFIG = dict(
    format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logging.basicConfig(**_LOGGING_CONFIG)

# Module logger used by the generation routines below.
logger = logging.getLogger("inference")
def write_result(results, output_file):
    """Write every string in ``results`` to ``output_file``, one per line.

    The file is opened with mode ``'w'``, so existing content is replaced.
    Each entry is followed by a newline, including the last one.
    """
    with open(output_file, 'w') as handle:
        handle.writelines(line + '\n' for line in results)
def fairseq_generate(data_lines, args, models, task, batch_size, beam_size, device):
    """Decode ``data_lines`` in batches with the generator built by fairseq.

    Beam search (or greedy decoding when ``beam_size`` is 1) is delegated to
    ``task.build_generator``.

    Args:
        data_lines: BPE-encoded source sentences, one string per sentence.
        args: Configuration returned by ``load_model_ensemble_and_task``. A
            shallow copy is made and its ``beam`` field is overridden.
        models: Model ensemble handed to the generator.
        task: fairseq task providing the dictionaries and the inference dataset.
        batch_size: Number of sentences per ``generator.generate`` call.
        beam_size: Beam width.
        device: Device every collated batch is moved to.

    Returns:
        Tuple ``(results, seconds)``. ``results`` holds the top hypothesis for
        each input line, in input order, with ``@@`` subword joiners removed.
        ``seconds`` is the wall-clock decoding time (BPE removal excluded).
    """
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    gen_args = copy.copy(args)
    with open_dict(gen_args):
        gen_args.beam = beam_size
    generator = task.build_generator(models, gen_args)
    all_results = []
    logger.info(f'Fairseq generate batch {batch_size}, beam {beam_size}')
    start = time.perf_counter()
    for start_idx in tqdm(range(0, len(data_lines), batch_size)):
        batch_lines = data_lines[start_idx:start_idx + batch_size]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long()
                     for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch_dataset.left_pad_source = True
        # collater() takes a list of samples; build it explicitly rather than
        # relying on the dataset object being iterable.
        samples = [batch_dataset[i] for i in range(len(batch_dataset))]
        batch = batch_dataset.collater(samples)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        translations = generator.generate(models, batch, prefix_tokens=None)
        # The collated batch need not preserve input order, so re-sort the
        # hypotheses by their sample id before collecting the best one.
        ordered = sorted(zip(batch["id"].tolist(), translations), key=lambda pair: pair[0])
        all_results.extend(tgt_dict.string(hypos[0]['tokens']) for _, hypos in ordered)
    delta = time.perf_counter() - start
    # subword-nmt removal: the appended space also strips a joiner on the
    # final token (e.g. "foo@@" at line end), which a bare replace would miss.
    remove_bpe_results = [(line + ' ').replace('@@ ', '').rstrip() for line in all_results]
    return remove_bpe_results, delta
def forward_decoder(model,
                    input_tokens,
                    encoder_out,
                    incremental_state,
                    parallel_forward_start_pos=None,
                    temperature=1.0,
                    use_log_softmax=True):
    """Run one decoder forward pass and return the arg-max token ids.

    Args:
        model: Encoder-decoder model. Its ``decoder.forward`` must accept the
            ``parallel_forward_start_pos`` keyword (a customised decoder).
        input_tokens: Decoder input ids; the callers here pass shape (1, len).
        encoder_out: Encoder output forwarded unchanged to the decoder.
        incremental_state: Decoder cache dict, updated by the decoder.
        parallel_forward_start_pos: Forwarded to the decoder unchanged.
        temperature: The decoder logits are divided by this *in place*.
        use_log_softmax: Take the arg-max over ``model.get_normalized_probs``
            (log-probs) when True, otherwise over the raw scaled logits.

    Returns:
        ``argmax(scores, dim=-1).squeeze(0)``, i.e. a 1-D id tensor for a
        batch of one.
    """
    raw_out = model.decoder.forward(
        input_tokens,
        encoder_out=encoder_out,
        incremental_state=incremental_state,
        parallel_forward_start_pos=parallel_forward_start_pos,
    )
    # NB: div_ mutates the decoder's logits tensor in place.
    scaled_out = (raw_out[0].div_(temperature), raw_out[1])
    if not use_log_softmax:
        scores = scaled_out[0]
    else:
        scores = model.get_normalized_probs(scaled_out, log_probs=True, sample=None)
    return scores.argmax(dim=-1).squeeze(0)
@torch.no_grad()
def baseline_generate(data_lines, model, task, device, no_use_logsoft=False, max_len=200):
    """Greedy decoding, one target token per decoder call.

    Decoding starts from ``[eos]`` and reuses the decoder's incremental state.
    It stops when ``eos`` is predicted or ``max_len`` tokens have been
    generated. Runs under ``torch.no_grad()`` so the hand-written loop does
    not build autograd graphs, which would waste memory and distort timings.

    Args:
        data_lines: BPE-encoded source sentences, one string per sentence.
        model: Encoder-decoder model (``model.encoder`` / ``model.decoder``).
        task: fairseq task supplying the source/target dictionaries.
        device: Device the source and decoder inputs are placed on.
        no_use_logsoft: Take the arg-max over raw logits instead of log-probs.
        max_len: Maximum number of generated tokens per sentence.

    Returns:
        Tuple ``(results, seconds)``: one BPE-removed output string per input
        line, and the wall-clock decoding time (BPE removal excluded).
    """
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    eos = tgt_dict.eos()
    all_results = []
    start = time.perf_counter()
    logger.info('Baseline generate')
    for bpe_line in tqdm(data_lines):
        src_tokens = src_dict.encode_line(bpe_line, add_if_not_exist=False).long()
        net_input = {'src_tokens': src_tokens.unsqueeze(0).to(device),
                     'src_lengths': torch.LongTensor([src_tokens.numel()]).to(device)}
        encoder_out = model.encoder.forward_torchscript(net_input)
        incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
        tokens = [eos]
        for _ in range(max_len):
            cur_input_tokens = torch.tensor([tokens], dtype=torch.long, device=device)
            pred_token = forward_decoder(model,
                                         cur_input_tokens,
                                         encoder_out,
                                         incremental_state,
                                         use_log_softmax=not no_use_logsoft).item()
            if pred_token == eos:
                break
            tokens.append(pred_token)
        # Drop the leading eos that seeded decoding.
        all_results.append(tgt_dict.string(tokens[1:]))
    delta = time.perf_counter() - start
    # subword-nmt removal that also handles a joiner on the final token.
    remove_bpe_results = [(line + ' ').replace('@@ ', '').rstrip() for line in all_results]
    return remove_bpe_results, delta
def construct_hash_sets(sent, min_gram=1, max_gram=3):
    """Index every n-gram of ``sent`` by the position just past its end.

    Args:
        sent: Token sequence (sliceable).
        min_gram: Shortest n-gram length indexed.
        max_gram: Longest n-gram length indexed.

    Returns:
        Dict mapping ``tuple(ngram)`` to the list of exclusive end offsets at
        which that n-gram occurs, in increasing order. A unique n-gram
        therefore maps to a one-element list.
    """
    index = {}
    sent_len = len(sent)
    for begin in range(sent_len - min_gram + 1):
        for size in range(min_gram, max_gram + 1):
            end = begin + size
            if end > sent_len:
                continue
            index.setdefault(tuple(sent[begin:end]), []).append(end)
    return index
def find_hash_sets(hash_set, tokens, min_gram=1, max_gram=3):
    """Locate where the suffix of ``tokens`` uniquely occurs in the source.

    Suffixes of increasing length (``min_gram`` .. ``max_gram``) are looked
    up in ``hash_set``, which is shaped like the output of
    ``construct_hash_sets``. The search gives up as soon as a suffix is
    missing or ``tokens`` is too short.

    Returns:
        The end offset of the first suffix that occurs exactly once, or -1 if
        none does.
    """
    for size in range(min_gram, max_gram + 1):
        if size > len(tokens):
            break
        suffix = tuple(tokens[-size:])
        if suffix not in hash_set:
            break
        positions = hash_set[suffix]
        if len(positions) == 1:
            return positions[0]
    return -1
def cut_incremental_state(incremental_state, keep_len, encoder_state_ids):
    """Truncate cached decoder states in place to ``keep_len`` positions.

    Each key of ``incremental_state`` is expected to look like
    ``"<module_id>.<name>"``. ``str.index`` raises ValueError when the ``'.'``
    is missing. Entries whose ``<module_id>`` is in ``encoder_state_ids`` are
    left untouched. For every other entry, 4-D tensors are sliced along dim 2
    and 2-D tensors along dim 1. ``None`` values and tensors of any other rank
    are kept as-is.
    """
    for state_name, state in incremental_state.items():
        owner_id = state_name[:state_name.index('.')]
        if owner_id in encoder_state_ids:
            continue
        for key, cached in state.items():
            if cached is None:
                continue
            ndim = cached.dim()
            if ndim == 4:
                state[key] = cached[:, :, :keep_len]
            elif ndim == 2:
                state[key] = cached[:, :keep_len]
@torch.no_grad()
def aggressive_generate(data_lines, model, task, device, no_use_logsoft=False, max_len=200):
    """Decode by speculatively copying the source, falling back to greedy steps.

    Per sentence:

    1. The remaining source tokens are appended to the current prefix and run
       through the decoder in one call.
    2. The longest prefix of predictions that matches the source is accepted,
       together with the first mismatching prediction.
    3. The self-attention cache is rolled back to the accepted length.
    4. Greedy steps follow until ``eos``, or until the suffix of the output
       uniquely matches an n-gram of the source. In the latter case step 1
       resumes from that source offset.

    Runs under ``torch.no_grad()`` so the hand-written loop does not build
    autograd graphs.

    Args:
        data_lines: BPE-encoded source sentences, one string per sentence.
        model: Encoder-decoder model whose decoder accepts
            ``parallel_forward_start_pos``.
        task: fairseq task supplying the source/target dictionaries.
        device: Device the model inputs are placed on.
        no_use_logsoft: Take the arg-max over raw logits instead of log-probs.
        max_len: Maximum number of output tokens (the leading eos excluded).

    Returns:
        Tuple ``(results, seconds)``: one BPE-removed output per input line,
        and the wall-clock decoding time (BPE removal excluded).
    """
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    eos = tgt_dict.eos()
    # Cross-attention caches are excluded from cache truncation.
    encoder_state_ids = [layer.encoder_attn._incremental_state_id for layer in model.decoder.layers]
    all_results = []
    start_time = time.perf_counter()
    for bpe_line in tqdm(data_lines):
        src_tokens = src_dict.encode_line(bpe_line, add_if_not_exist=False).long()
        net_input = {'src_tokens': src_tokens.unsqueeze(0).to(device),
                     'src_lengths': torch.LongTensor([src_tokens.numel()]).to(device)}
        encoder_out = model.encoder.forward_torchscript(net_input)
        incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
        # Drop the final token (the eos appended by encode_line, per the name).
        src_tokens_remove_eos_list = src_tokens[:-1].tolist()
        src_len = len(src_tokens_remove_eos_list)
        src_hash = construct_hash_sets(src_tokens_remove_eos_list)
        start = 0
        tokens = [eos]
        while start < src_len and len(tokens) < max_len + 1:
            # Speculative step: feed prefix + remaining source. The comparison
            # below assumes one prediction per position from
            # parallel_forward_start_pos onward (TODO confirm against the decoder).
            cur_span_input_tokens = torch.tensor([tokens + src_tokens_remove_eos_list[start:]],
                                                 dtype=torch.long, device=device)
            pred_tokens = forward_decoder(model,
                                          cur_span_input_tokens,
                                          encoder_out,
                                          incremental_state,
                                          parallel_forward_start_pos=len(tokens) - 1,
                                          use_log_softmax=not no_use_logsoft).cpu()
            pred_judge = pred_tokens == src_tokens[start:]
            if bool(pred_judge.all()):
                # Whole remaining source (including its eos) reproduced.
                tokens += src_tokens[start:].tolist()
                break
            wrong_pos = pred_judge.tolist().index(False)
            start += wrong_pos
            # Keep the matched run plus the model's correction at the mismatch.
            tokens.extend(pred_tokens.tolist()[: wrong_pos + 1])
            cut_incremental_state(incremental_state, keep_len=len(tokens) - 1,
                                  encoder_state_ids=encoder_state_ids)
            # Greedy steps until eos, or until the output re-syncs with the source.
            for _ in range(len(tokens), max_len + 1):
                cur_input_tokens = torch.tensor([tokens], dtype=torch.long, device=device)
                pred_token = forward_decoder(model,
                                             cur_input_tokens,
                                             encoder_out,
                                             incremental_state,
                                             use_log_softmax=not no_use_logsoft).item()
                if pred_token == eos:
                    start = src_len
                    break
                tokens.append(pred_token)
                find_end_idx = find_hash_sets(src_hash, tokens)
                if find_end_idx != -1:
                    start = find_end_idx
                    if start < src_len:
                        break
        if len(tokens) > max_len + 1:
            tokens = tokens[:max_len + 1]
        all_results.append(tgt_dict.string(tokens[1:]))
    delta = time.perf_counter() - start_time
    # subword-nmt removal that also handles a joiner on the final token.
    remove_bpe_results = [(line + ' ').replace('@@ ', '').rstrip() for line in all_results]
    return remove_bpe_results, delta
@torch.no_grad()
def paper_aggressive_generate(data_lines, model, task, device, no_use_logsoft=False, max_len=200):
    """Aggressive decoding with n-gram re-synchronisation (variant used by ``__main__``).

    Each iteration decides between two kinds of step:

    - Speculative step: taken at the start of decoding, or whenever the
      output's suffix uniquely matches a source n-gram ending before the end
      of the source. The source continuation is fed in one decoder call.
      Every matching prediction is accepted, plus the first mismatching one,
      and the self-attention cache is rolled back.
    - Greedy step: otherwise, one token is decoded.

    Decoding ends at ``eos`` or at ``max_len`` output tokens. Runs under
    ``torch.no_grad()`` so the hand-written loop does not build autograd
    graphs.

    Args:
        data_lines: BPE-encoded source sentences, one string per sentence.
        model: Encoder-decoder model whose decoder accepts
            ``parallel_forward_start_pos``.
        task: fairseq task supplying the source/target dictionaries.
        device: Device the model inputs are placed on.
        no_use_logsoft: Take the arg-max over raw logits instead of log-probs.
        max_len: Maximum number of output tokens (the leading eos excluded).

    Returns:
        Tuple ``(results, seconds)``: one BPE-removed output per input line,
        and the wall-clock decoding time (BPE removal excluded).
    """
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    eos = tgt_dict.eos()
    # Cross-attention caches are excluded from cache truncation.
    encoder_state_ids = [layer.encoder_attn._incremental_state_id for layer in model.decoder.layers]
    all_results = []
    start_time = time.perf_counter()
    for bpe_line in tqdm(data_lines):
        src_tokens = src_dict.encode_line(bpe_line, add_if_not_exist=False).long()
        net_input = {'src_tokens': src_tokens.unsqueeze(0).to(device),
                     'src_lengths': torch.LongTensor([src_tokens.numel()]).to(device)}
        encoder_out = model.encoder.forward_torchscript(net_input)
        incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
        # Drop the final token (the eos appended by encode_line, per the name).
        src_tokens_remove_eos_list = src_tokens[:-1].tolist()
        src_len = len(src_tokens_remove_eos_list)
        src_hash = construct_hash_sets(src_tokens_remove_eos_list)
        # -1 sentinel can never equal an arg-max id, so the last compared
        # position always mismatches and `.index(False)` below always succeeds.
        src_tokens_add_pad_list = torch.tensor(src_tokens_remove_eos_list + [-1])
        tokens = [eos]
        while (len(tokens) == 1 or tokens[-1] != eos) and len(tokens) < max_len + 1:
            find_end_idx = 0 if len(tokens) == 1 else find_hash_sets(src_hash, tokens)
            if find_end_idx != -1 and find_end_idx < src_len:
                cur_span_input_tokens = torch.tensor([tokens + src_tokens_remove_eos_list[find_end_idx:]],
                                                     dtype=torch.long, device=device)
                pred_tokens = forward_decoder(model,
                                              cur_span_input_tokens,
                                              encoder_out,
                                              incremental_state,
                                              parallel_forward_start_pos=len(tokens) - 1,
                                              use_log_softmax=not no_use_logsoft).cpu()
                pred_judge = pred_tokens == src_tokens_add_pad_list[find_end_idx:]
                wrong_pos = pred_judge.tolist().index(False)
                # Keep the matched run plus the model's token at the mismatch.
                tokens.extend(pred_tokens.tolist()[: wrong_pos + 1])
                cut_incremental_state(incremental_state, keep_len=len(tokens) - 1,
                                      encoder_state_ids=encoder_state_ids)
            else:
                cur_input_tokens = torch.tensor([tokens], dtype=torch.long, device=device)
                pred_token = forward_decoder(model,
                                             cur_input_tokens,
                                             encoder_out,
                                             incremental_state,
                                             use_log_softmax=not no_use_logsoft).item()
                tokens.append(pred_token)
        if len(tokens) > max_len + 1:
            tokens = tokens[:max_len + 1]
        if tokens[-1] == eos:
            tokens = tokens[:-1]
        all_results.append(tgt_dict.string(tokens[1:]))
    delta = time.perf_counter() - start_time
    # subword-nmt removal that also handles a joiner on the final token.
    remove_bpe_results = [(line + ' ').replace('@@ ', '').rstrip() for line in all_results]
    return remove_bpe_results, delta
if __name__ == '__main__':
    # CLI entry point: load a fairseq checkpoint, decode --input-path with the
    # selected strategy, log the decoding time and optionally write outputs.
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-path', type=str, required=True,
                        help='path to model file, e.g., /to/path/checkpoint_best.pt')
    parser.add_argument('--bin-data', type=str, required=True,
                        help='directory containing src and tgt dictionaries')
    parser.add_argument('--input-path', type=str, required=True,
                        help='path to eval file, e.g., /to/path/conll14.bpe.txt')
    parser.add_argument('--output-path', type=str, default=None,
                        help='path to output file, e.g., /to/path/conll14.pred.txt')
    parser.add_argument('--batch', type=int, default=None,
                        help='batch size')
    parser.add_argument('--beam', type=int, default=5,
                        help='beam size')
    parser.add_argument('--baseline', action='store_true', default=False,
                        help='greedy/one-by-one decoding')
    parser.add_argument('--aggressive', action='store_true', default=False,
                        help='aggressive decoding')
    parser.add_argument('--no_use_logsoft', action='store_true', default=False,
                        help='not use log_softmax when aggressive decoding')
    # NOTE(review): --block and --match are parsed but never read in this file.
    parser.add_argument('--block', type=int, default=None)
    parser.add_argument('--match', type=int, default=1)
    parser.add_argument('--cpu', action='store_true', default=False)
    cmd_args = parser.parse_args()
    # Fail fast, before the expensive checkpoint load, when no decoding mode is
    # selected; otherwise remove_bpe_results would be undefined below.
    if cmd_args.batch is None and not (cmd_args.baseline or cmd_args.aggressive):
        parser.error('one of --batch, --baseline or --aggressive is required')
    cmd_args.checkpoint_path = os.path.expanduser(cmd_args.checkpoint_path)
    cmd_args.bin_data = os.path.expanduser(cmd_args.bin_data)
    cmd_args.input_path = os.path.expanduser(cmd_args.input_path)
    # --output-path is optional; os.path.expanduser(None) raises TypeError.
    if cmd_args.output_path is not None:
        cmd_args.output_path = os.path.expanduser(cmd_args.output_path)
    models, args, task = load_model_ensemble_and_task(filenames=[cmd_args.checkpoint_path],
                                                      arg_overrides={'data': cmd_args.bin_data})
    device = torch.device('cpu' if cmd_args.cpu else 'cuda')
    model = models[0].to(device).eval()
    with open(cmd_args.input_path, 'r') as f:
        bpe_sents = [line.strip() for line in f]
    # Inference only: disable autograd for every decoding strategy.
    with torch.no_grad():
        if cmd_args.batch is not None:
            remove_bpe_results, delta = fairseq_generate(bpe_sents, args, models, task,
                                                         cmd_args.batch, cmd_args.beam, device)
            logger.info(f'Fairseq generate batch {cmd_args.batch}, beam {cmd_args.beam}: {delta}')
        elif cmd_args.baseline:
            remove_bpe_results, delta = baseline_generate(bpe_sents, model, task, device,
                                                          no_use_logsoft=cmd_args.no_use_logsoft)
            logger.info(f'Baseline generate: {delta}')
        else:  # cmd_args.aggressive (guaranteed by the check above)
            remove_bpe_results, delta = paper_aggressive_generate(bpe_sents, model, task, device,
                                                                  no_use_logsoft=cmd_args.no_use_logsoft)
            logger.info(f'Aggressive generate: {delta}')
    if cmd_args.output_path is not None:
        write_result(remove_bpe_results, cmd_args.output_path)