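"""Batch translation of pre-processed, tarred datasets with a trained NMT model, using
one worker process per GPU.

Example invocation (the script name, paths, and shard pattern below are illustrative
placeholders, not values taken from this repository):

    python translate_ddp.py \
        --model /path/to/model.nemo \
        --text2translate '/path/to/tarred_data/batches.tokens.{0..99}.tar' \
        --metadata_path /path/to/tarred_data/metadata.json \
        --result_dir /path/to/results \
        --src_language en \
        --tgt_language de \
        --n_gpus 4
"""
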
import os
from argparse import ArgumentParser

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

from nemo.collections.nlp.data.language_modeling import TarredSentenceDataset
from nemo.collections.nlp.data.machine_translation import TarredTranslationDataset
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel
from nemo.utils import logging
|
|
def get_args():
    parser = ArgumentParser(description='Batch translation of sentences from a pre-trained model on multiple GPUs')
    parser.add_argument("--model", type=str, required=True, help="Path to the .nemo translation model file")
    parser.add_argument(
        "--text2translate", type=str, required=True, help="Path to the pre-processed tarfiles for translation"
    )
    parser.add_argument("--result_dir", type=str, required=True, help="Folder to write translation results")
    parser.add_argument(
        "--twoside", action="store_true", help="Set flag when translating the source side of a parallel dataset"
    )
    parser.add_argument(
        '--metadata_path', type=str, required=True, help="Path to the JSON file that contains dataset info"
    )
    parser.add_argument('--topk', type=int, default=500, help="Value of k for top-k sampling")
    parser.add_argument('--src_language', type=str, required=True, help="Source language ID for detokenization")
    parser.add_argument('--tgt_language', type=str, required=True, help="Target language ID for detokenization")
    parser.add_argument(
        '--reverse_lang_direction',
        action="store_true",
        help="Reverse source and target language direction for parallel dataset",
    )
    parser.add_argument('--n_gpus', type=int, default=1, help="Number of GPUs to use (-1 uses all available GPUs)")
    args = parser.parse_args()
    return args
|
|
def translate(rank, world_size, args):
    if args.model.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = MTEncDecModel.restore_from(restore_path=args.model, map_location=f"cuda:{rank}")
    elif args.model.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = MTEncDecModel.load_from_checkpoint(checkpoint_path=args.model, map_location=f"cuda:{rank}")
    else:
        raise ValueError(f"Model file must end in .nemo or .ckpt, got: {args.model}")
    model.replace_beam_with_sampling(topk=args.topk)
    model.eval()
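    # Note: replace_beam_with_sampling above swaps the beam-search decoder for top-k
    # sampling, so generation draws from the args.topk most likely tokens at each step.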
|
    if args.twoside:
        dataset = TarredTranslationDataset(
            text_tar_filepaths=args.text2translate,
            metadata_path=args.metadata_path,
            encoder_tokenizer=model.encoder_tokenizer,
            decoder_tokenizer=model.decoder_tokenizer,
            shuffle_n=100,
            shard_strategy="scatter",
            world_size=world_size,
            global_rank=rank,
            reverse_lang_direction=args.reverse_lang_direction,
        )
    else:
        dataset = TarredSentenceDataset(
            text_tar_filepaths=args.text2translate,
            metadata_path=args.metadata_path,
            tokenizer=model.encoder_tokenizer,
            shuffle_n=100,
            shard_strategy="scatter",
            world_size=world_size,
            global_rank=rank,
        )
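    # shard_strategy="scatter" assigns each rank a disjoint subset of the tar shards,
    # so the spawned processes jointly cover the dataset without duplicating work.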
|
    # The tarred datasets yield pre-batched examples, so the DataLoader batch size stays at 1.
    loader = DataLoader(dataset, batch_size=1)
    # Each rank writes to its own subdirectory to avoid clobbering other processes' output.
    result_dir = os.path.join(args.result_dir, f'rank{rank}')
    os.makedirs(result_dir, exist_ok=True)
    originals_file_name = os.path.join(result_dir, 'originals.txt')
    translations_file_name = os.path.join(result_dir, 'translations.txt')
    num_translated_sentences = 0
|
    with open(originals_file_name, 'w') as of, open(translations_file_name, 'w') as tf:
        for batch_idx, batch in enumerate(loader):
            for i in range(len(batch)):
                # Drop the dummy batch dimension added by the loader and move tensors to this GPU.
                if batch[i].ndim == 3:
                    batch[i] = batch[i].squeeze(dim=0)
                batch[i] = batch[i].to(rank)
            if args.twoside:
                src_ids, src_mask, _, _, _ = batch
            else:
                src_ids, src_mask = batch
            if batch_idx % 100 == 0:
                logging.info(
                    f"{batch_idx} batches ({num_translated_sentences} sentences) were translated by process with "
                    f"rank {rank}"
                )
            num_translated_sentences += len(src_ids)
            inputs, translations = model.batch_translate(src=src_ids, src_mask=src_mask)
            for src, translation in zip(inputs, translations):
                of.write(src + '\n')
                tf.write(translation + '\n')
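    # The per-rank files (rank*/originals.txt, rank*/translations.txt) can be concatenated
    # once all processes finish; sentence order across ranks is not preserved.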
|
|
def main() -> None:
    args = get_args()
    # -1 means "use every visible GPU"; otherwise spawn exactly args.n_gpus processes.
    world_size = torch.cuda.device_count() if args.n_gpus == -1 else args.n_gpus
    # mp.spawn launches world_size copies of translate(); each process receives its
    # rank automatically as the first positional argument.
    mp.spawn(translate, args=(world_size, args), nprocs=world_size, join=True)
|
|
if __name__ == '__main__':
    main()