# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""
from fairseq.dataclass.configs import FairseqConfig
from fairseq.distributed import utils as distributed_utils
from fairseq.trainer import Trainer

try:
    from fairseq.model_parallel.megatron.mpu import (
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_src_rank,
        get_cuda_rng_tracker,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(cfg, task, model, criterion, **kwargs)
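
    # Note: fairseq-train picks this trainer over the base Trainer when
    # model parallelism is enabled. A sketch of that dispatch (from memory
    # of fairseq_cli/train.py, not verbatim):
    #
    #     if cfg.common.model_parallel_size == 1:
    #         trainer = Trainer(cfg, task, model, criterion)
    #     else:
    #         trainer = MegatronTrainer(cfg, task, model, criterion)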

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            # Each model-parallel worker only sees the norm of its own shard
            # of the gradients. Sum the squared shard norms across the
            # model-parallel group, then take the square root to recover the
            # norm of the full gradient.
            total_norm = total_norm ** 2
            distributed_utils.all_reduce(
                total_norm, group=distributed_utils.get_model_parallel_group()
            )
            total_norm = total_norm ** 0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )
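
    # A minimal numeric sketch of the aggregation above (illustrative values,
    # not fairseq API): with two model-parallel shards whose local gradient
    # norms are 3.0 and 4.0, all-reducing the squared norms gives
    # 3.0**2 + 4.0**2 == 25.0 on every rank, and 25.0 ** 0.5 == 5.0 is the
    # norm of the full, unsharded gradient.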

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        # Persist the CUDA RNG tracker states so that dropout and other
        # random ops are reproducible after resuming from this checkpoint.
        extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        # Restore the CUDA RNG tracker states saved by save_checkpoint so
        # that random ops continue from where training left off.
        if extra_state is not None and "rng_tracker_states" in extra_state:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
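

# A minimal usage sketch (hypothetical wiring; `cfg` is assumed to be a parsed
# FairseqConfig with model-parallel settings — in practice fairseq-train
# builds the trainer itself):
#
#     from fairseq import tasks
#
#     task = tasks.setup_task(cfg.task)
#     model = task.build_model(cfg.model)
#     criterion = task.build_criterion(cfg.criterion)
#     trainer = MegatronTrainer(cfg, task, model, criterion)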