# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentParser
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from nemo.collections.nlp.data.language_modeling import TarredSentenceDataset
from nemo.collections.nlp.data.machine_translation import TarredTranslationDataset
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel
from nemo.utils import logging
def get_args():
    """Parse the command-line options of the multi-GPU batch translation script.

    Options are read from ``sys.argv`` (``parse_args()`` is called without an
    explicit argument list).

    Returns:
        The parsed ``argparse.Namespace`` with the attributes ``model``,
        ``text2translate``, ``result_dir``, ``twoside``, ``metadata_path``,
        ``topk``, ``src_language``, ``tgt_language``,
        ``reverse_lang_direction`` and ``n_gpus``.
    """
    cli = ArgumentParser(description='Batch translation of sentences from a pre-trained model on multiple GPUs')

    # (flag, add_argument keyword arguments), kept in the order in which the
    # options should appear in ``--help``.
    option_table = (
        ("--model", {"type": str, "required": True, "help": "Path to the .nemo translation model file"}),
        (
            "--text2translate",
            {"type": str, "required": True, "help": "Path to the pre-processed tarfiles for translation"},
        ),
        ("--result_dir", {"type": str, "required": True, "help": "Folder to write translation results"}),
        (
            "--twoside",
            {"action": "store_true", "help": "Set flag when translating the source side of a parallel dataset"},
        ),
        (
            '--metadata_path',
            {"type": str, "required": True, "help": "Path to the JSON file that contains dataset info"},
        ),
        ('--topk', {"type": int, "default": 500, "help": "Value of k for topk sampling"}),
        ('--src_language', {"type": str, "required": True, "help": "Source lang ID for detokenization"}),
        ('--tgt_language', {"type": str, "required": True, "help": "Target lang ID for detokenization"}),
        (
            '--reverse_lang_direction',
            {
                "action": "store_true",
                "help": "Reverse source and target language direction for parallel dataset",
            },
        ),
        ('--n_gpus', {"type": int, "default": 1, "help": "Number of GPUs to use"}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def _load_translation_model(model_path, rank):
    """Restore an ``MTEncDecModel`` from ``model_path`` mapped onto ``cuda:{rank}``.

    Raises:
        ValueError: if ``model_path`` ends with neither ``.nemo`` nor ``.ckpt``.
    """
    device = f"cuda:{rank}"
    if model_path.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        return MTEncDecModel.restore_from(restore_path=model_path, map_location=device)
    if model_path.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        return MTEncDecModel.load_from_checkpoint(checkpoint_path=model_path, map_location=device)
    # Without this the caller would fail later with an opaque UnboundLocalError.
    raise ValueError(f"Unsupported model file {model_path!r}: expected a path ending in '.nemo' or '.ckpt'")


def _build_tarred_dataset(model, rank, world_size, args):
    """Build the tarred dataset for this rank (parallel data when ``args.twoside`` is set)."""
    if args.twoside:
        return TarredTranslationDataset(
            text_tar_filepaths=args.text2translate,
            metadata_path=args.metadata_path,
            encoder_tokenizer=model.encoder_tokenizer,
            decoder_tokenizer=model.decoder_tokenizer,
            shuffle_n=100,
            shard_strategy="scatter",
            world_size=world_size,
            global_rank=rank,
            reverse_lang_direction=args.reverse_lang_direction,
        )
    return TarredSentenceDataset(
        text_tar_filepaths=args.text2translate,
        metadata_path=args.metadata_path,
        tokenizer=model.encoder_tokenizer,
        shuffle_n=100,
        shard_strategy="scatter",
        world_size=world_size,
        global_rank=rank,
    )


def translate(rank, world_size, args):
    """Translate this rank's share of a tarred dataset and write the results to disk.

    Intended as the per-process target of ``mp.spawn``. The sources and their
    translations are written line by line to
    ``<args.result_dir>/rank<rank>/originals.txt`` and
    ``<args.result_dir>/rank<rank>/translations.txt``.

    Args:
        rank: Index of this process; used as the CUDA device index and as the
            dataset's ``global_rank``.
        world_size: Total number of processes, passed to the dataset.
        args: Parsed command-line namespace. Reads ``model``, ``topk``,
            ``twoside``, ``text2translate``, ``metadata_path``,
            ``reverse_lang_direction`` and ``result_dir``.

    Raises:
        ValueError: if ``args.model`` is neither a ``.nemo`` nor a ``.ckpt`` path.
    """
    model = _load_translation_model(args.model, rank)
    model.replace_beam_with_sampling(topk=args.topk)
    model.eval()

    dataset = _build_tarred_dataset(model, rank, world_size, args)
    # NOTE(review): batch_size=1 presumably because every tarred item is already
    # a full batch; the extra leading dimension is squeezed away below. Confirm
    # against the dataset classes.
    loader = DataLoader(dataset, batch_size=1)

    result_dir = os.path.join(args.result_dir, f'rank{rank}')
    os.makedirs(result_dir, exist_ok=True)
    originals_file_name = os.path.join(result_dir, 'originals.txt')
    translations_file_name = os.path.join(result_dir, 'translations.txt')
    num_translated_sentences = 0

    # Explicit UTF-8: translated text is routinely non-ASCII, and the platform
    # default encoding may be unable to represent it.
    with open(originals_file_name, 'w', encoding='utf-8') as of, open(
        translations_file_name, 'w', encoding='utf-8'
    ) as tf:
        for batch_idx, batch in enumerate(loader):
            # Drop the dimension added by the DataLoader from 3-D tensors, then
            # move every tensor to this rank's device.
            batch = [(item.squeeze(dim=0) if item.ndim == 3 else item).to(rank) for item in batch]
            if args.twoside:
                src_ids, src_mask, _, _, _ = batch
            else:
                src_ids, src_mask = batch
            if batch_idx % 100 == 0:
                # Progress is reported before the current batch is counted.
                logging.info(
                    f"{batch_idx} batches ({num_translated_sentences} sentences) were translated by process with "
                    f"rank {rank}"
                )
            num_translated_sentences += len(src_ids)
            inputs, translations = model.batch_translate(src=src_ids, src_mask=src_mask)
            for src, translation in zip(inputs, translations):
                of.write(src + '\n')
                tf.write(translation + '\n')
def main() -> None:
    """Parse the CLI arguments and launch one ``translate`` worker per GPU.

    ``--n_gpus -1`` means "use every CUDA device reported by
    ``torch.cuda.device_count()``"; any other value is used as the number of
    worker processes directly.

    Raises:
        ValueError: if the resolved number of workers is not positive (for
            example ``--n_gpus -1`` on a machine with no visible CUDA device).
    """
    args = get_args()
    world_size = torch.cuda.device_count() if args.n_gpus == -1 else args.n_gpus
    # A non-positive process count would start no workers at all, so the script
    # would exit successfully without translating anything. Fail loudly instead.
    if world_size < 1:
        raise ValueError(
            f"No worker processes to launch: resolved world size is {world_size} (--n_gpus={args.n_gpus}). "
            "Pass a positive --n_gpus, or -1 on a machine with visible CUDA devices."
        )
    mp.spawn(translate, args=(world_size, args), nprocs=world_size, join=True)
# Script entry point: start the translation job only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == '__main__':
    main()