Tzktz's picture
Upload 7664 files
6fc683c verified
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""
import collections
import math
import random
import numpy as np
import torch
from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
def main(args, init_distributed=False):
utils.import_user_module(args)
assert args.max_tokens is not None or args.max_sentences is not None, \
'Must specify batch size either with --max-tokens or --max-sentences'
# Initialize CUDA and distributed training
if torch.cuda.is_available() and not args.cpu:
torch.cuda.set_device(args.device_id)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if init_distributed:
args.distributed_rank = distributed_utils.distributed_init(args)
if distributed_utils.is_master(args):
checkpoint_utils.verify_checkpoint_directory(args.save_dir)
# Print args
print(args)
# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(args)
# Load valid dataset (we load training data below, based on the latest checkpoint)
for valid_sub_split in args.valid_subset.split(','):
task.load_dataset(valid_sub_split, combine=False, epoch=0)
# Build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)
print(model)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {} (num. trained: {})'.format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
))
# Build trainer
trainer = Trainer(args, task, model, criterion)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
args.max_sentences,
))
# Load the latest checkpoint if one is available and restore the
# corresponding train iterator
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
# Prepare train
task.prepare_train(model, criterion)
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
train_meter.start()
valid_subsets = args.valid_subset.split(',')
while (
lr > args.min_lr
and (epoch_itr.epoch < max_epoch or (epoch_itr.epoch == max_epoch
and epoch_itr._next_epoch_itr is not None))
and trainer.get_num_updates() < max_update
):
# train for one epoch
train(args, trainer, task, epoch_itr)
if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
else:
valid_losses = [None]
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
# save checkpoint
if epoch_itr.epoch % args.save_interval == 0:
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
reload_dataset = ':' in getattr(args, 'data', '')
reload_dataset = reload_dataset or args.reload_dataset_per_epoch
# sharded data: get train iterator for next epoch
epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Update parameters every N batches
print("| Start train.train ..." , flush=True)
update_freq = args.update_freq[epoch_itr.epoch - 1] \
if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
# Initialize data iterator
itr = epoch_itr.next_epoch_itr(
fix_batches_to_gpus=args.fix_batches_to_gpus,
shuffle=(epoch_itr.epoch >= args.curriculum),
)
print("| Itr init (1) ...", flush=True)
itr = iterators.GroupedIterator(itr, update_freq)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch, no_progress_bar='simple',
)
print("| Itr init (2) ...", flush=True)
extra_meters = collections.defaultdict(lambda: AverageMeter())
valid_subsets = args.valid_subset.split(',')
max_update = args.max_update or math.inf
# ##################### DEBUG #####################
# debug_samples = []
# print("Fetch debug examples ...")
# for i in range(1000):
# debug_samples.append(next(itr))
# progress = progress_bar.build_progress_bar(
# args, iter(debug_samples), epoch_itr.epoch, no_progress_bar='simple',
# )
# ##################### DEBUG #####################
for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
log_output = trainer.train_step(samples)
if log_output is None:
continue
# log mid-epoch stats
stats = get_training_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue # these are already logged above
if 'loss' in k or k == 'accuracy':
extra_meters[k].update(v, log_output['sample_size'])
else:
extra_meters[k].update(v)
stats[k] = extra_meters[k].val
progress.log(stats, tag='train', step=stats['num_updates'])
# ignore the first mini-batch in words-per-second and updates-per-second calculation
if i == 0:
trainer.get_meter('wps').reset()
trainer.get_meter('ups').reset()
num_updates = trainer.get_num_updates()
if (
not args.disable_validation
and args.save_interval_updates > 0
and num_updates % args.save_interval_updates == 0
and num_updates > 0
):
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
elif (args.save_interval_updates > 0
and num_updates % args.save_interval_updates == 0
and num_updates > 0):
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, None)
if num_updates >= max_update:
break
# log end-of-epoch stats
stats = get_training_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.val
progress.print(stats, tag='train', step=stats['num_updates'])
# reset training meters
for k in [
'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
]:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
def get_training_stats(trainer):
stats = collections.OrderedDict()
stats['loss'] = trainer.get_meter('train_loss')
if trainer.get_meter('train_nll_loss').count > 0:
nll_loss = trainer.get_meter('train_nll_loss')
stats['nll_loss'] = nll_loss
else:
nll_loss = trainer.get_meter('train_loss')
stats['ppl'] = utils.get_perplexity(nll_loss.avg)
stats['wps'] = trainer.get_meter('wps')
stats['ups'] = trainer.get_meter('ups')
stats['wpb'] = trainer.get_meter('wpb')
stats['bsz'] = trainer.get_meter('bsz')
stats['num_updates'] = trainer.get_num_updates()
stats['lr'] = trainer.get_lr()
stats['gnorm'] = trainer.get_meter('gnorm')
stats['clip'] = trainer.get_meter('clip')
stats['oom'] = trainer.get_meter('oom')
if trainer.get_meter('loss_scale') is not None:
stats['loss_scale'] = trainer.get_meter('loss_scale')
stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
stats['train_wall'] = trainer.get_meter('train_wall')
return stats
def validate(args, trainer, task, epoch_itr, subsets):
"""Evaluate the model on the validation set(s) and return the losses."""
if args.fixed_validation_seed is not None:
# set fixed seed for every validation
utils.set_torch_seed(args.fixed_validation_seed)
valid_losses = []
for subset in subsets:
# Initialize data iterator
itr = task.get_batch_iterator(
dataset=task.dataset(subset),
max_tokens=args.max_tokens_valid,
max_sentences=args.max_sentences_valid,
max_positions=utils.resolve_max_positions(
task.max_positions(),
trainer.get_model().max_positions(),
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch,
prefix='valid on \'{}\' subset'.format(subset),
no_progress_bar='simple'
)
# reset validation loss meters
for k in ['valid_loss', 'valid_nll_loss']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
extra_meters = collections.defaultdict(lambda: AverageMeter())
for sample in progress:
log_output = trainer.valid_step(sample)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue
extra_meters[k].update(v)
# log validation stats
stats = get_valid_stats(trainer, args, extra_meters)
for k, meter in extra_meters.items():
stats[k] = meter.avg
progress.print(stats, tag=subset, step=trainer.get_num_updates())
valid_losses.append(
stats[args.best_checkpoint_metric].avg
if args.best_checkpoint_metric == 'loss'
else stats[args.best_checkpoint_metric]
)
return valid_losses
def get_valid_stats(trainer, args, extra_meters=None):
stats = collections.OrderedDict()
stats['loss'] = trainer.get_meter('valid_loss')
if trainer.get_meter('valid_nll_loss').count > 0:
nll_loss = trainer.get_meter('valid_nll_loss')
stats['nll_loss'] = nll_loss
else:
nll_loss = stats['loss']
stats['ppl'] = utils.get_perplexity(nll_loss.avg)
stats['num_updates'] = trainer.get_num_updates()
if hasattr(checkpoint_utils.save_checkpoint, 'best'):
key = 'best_{0}'.format(args.best_checkpoint_metric)
best_function = max if args.maximize_best_checkpoint_metric else min
current_metric = None
if args.best_checkpoint_metric == 'loss':
current_metric = stats['loss'].avg
elif args.best_checkpoint_metric in extra_meters:
current_metric = extra_meters[args.best_checkpoint_metric].avg
elif args.best_checkpoint_metric in stats:
current_metric = stats[args.best_checkpoint_metric]
else:
raise ValueError("best_checkpoint_metric not found in logs")
stats[key] = best_function(
checkpoint_utils.save_checkpoint.best,
current_metric,
)
return stats
def distributed_main(i, args, start_rank=0):
args.device_id = i
if args.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = start_rank + i
main(args, init_distributed=True)
def cli_main():
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
if args.distributed_init_method is None:
distributed_utils.infer_init_method(args)
if args.distributed_init_method is not None:
# distributed training
if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
start_rank = args.distributed_rank
args.distributed_rank = None # assign automatically
torch.multiprocessing.spawn(
fn=distributed_main,
args=(args, start_rank),
nprocs=torch.cuda.device_count(),
)
else:
distributed_main(args.device_id, args)
elif args.distributed_world_size > 1:
# fallback for single node with multiple GPUs
assert args.distributed_world_size <= torch.cuda.device_count()
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_rank = None # set based on device id
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
torch.multiprocessing.spawn(
fn=distributed_main,
args=(args, ),
nprocs=args.distributed_world_size,
)
else:
# single GPU training
main(args)
if __name__ == '__main__':
cli_main()