name: megatron_t5_finetuning
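
# Launch sketch (the script path and config-name below are assumptions based on
# typical NeMo example layouts, not taken from this file):
#   python examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py \
#     --config-path=conf --config-name=megatron_t5_finetuning \
#     model.restore_from_path=/path/to/t5.nemo \
#     model.data.train_ds.src_file_name=/path/to/train.src \
#     model.data.train_ds.tgt_file_name=/path/to/train.tgt \
#     model.data.validation_ds.src_file_name=/path/to/val.src \
#     model.data.validation_ds.tgt_file_name=/path/to/val.tgt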
trainer:
  devices: 2
  num_nodes: 1
  accelerator: gpu
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  replace_sampler_ddp: False
  max_epochs: 10
  max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
  val_check_interval: 300
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
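  # Note (assumption based on common NeMo Megatron behavior, not stated in this
  # file): for Megatron-based models, gradient accumulation is derived from
  # model.data.*.global_batch_size and micro_batch_size, so
  # trainer.accumulate_grad_batches is expected to stay at 1.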
exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: megatron_t5_finetune
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: validation_${model.data.validation_ds.metric.name}
    save_top_k: 10
    mode: max
    always_save_nemo: False # TODO: add support
    filename: 'megatron_t5--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    save_best_model: True
model:
  restore_from_path: null # Path to a trained T5 .nemo file.
  pretrained_checkpoint:
    checkpoint_dir: null # Path to a folder that contains a .ckpt file.
    checkpoint_name: null # Name of the .ckpt file within checkpoint_dir.
    hparams_file: null # Path to a .yaml file that contains the hyperparameters of the checkpoint.
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  pipeline_model_parallel_split_rank: 0
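  # Worked sizing example for the values above: data_parallel_size =
  # (devices * num_nodes) / (tensor_model_parallel_size * pipeline_model_parallel_size)
  # = (2 * 1) / (1 * 1) = 2, i.e. this config runs pure data parallelism across both GPUs.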
  gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory).
  megatron_amp_O2: False # Enable O2-style mixed precision: half-precision model weights with FP32 master weights kept in the optimizer.
  resume_from_checkpoint: null # Manually set a checkpoint file to resume training from.
  hidden_dropout: 0.1 # Overrides the hidden-state dropout probability from pretraining.
  attention_dropout: 0.1 # Overrides the attention dropout probability from pretraining.
  data:
    train_ds:
      src_file_name: ??? # Path to the txt file corresponding to the source data.
      tgt_file_name: ??? # Path to the txt file corresponding to the target data.
      global_batch_size: 128
      micro_batch_size: 64
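      # Batch math for the sizes above: global_batch_size = micro_batch_size
      # * data_parallel_size * num_grad_accumulation_steps, i.e. 128 = 64 * 2 * 1
      # with the 2-GPU, TP=PP=1 settings in this file.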
      shuffle: True
      num_workers: 0
      pin_memory: True
      max_src_seq_length: 512
      max_tgt_seq_length: 128
      drop_last: True
      concat_sampling_technique: temperature # When providing a list of datasets, this arg defines the sampling strategy. Options: ['temperature', 'random'].
      concat_sampling_temperature: 5 # When providing a list of datasets, this arg defines the sampling temperature when strategy='temperature'.
      concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random'.
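      # Example (hypothetical paths and weights) of sampling from two datasets
      # with strategy='random'; the probabilities are assumed to sum to 1:
      #   src_file_name: [/data/a.src, /data/b.src]
      #   tgt_file_name: [/data/a.tgt, /data/b.tgt]
      #   concat_sampling_technique: random
      #   concat_sampling_probabilities: [0.7, 0.3]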
      replace_bos_with_pad: False # Replaces <bos> with <pad> for both the encoder and decoder. This is necessary when using Google's T5 checkpoints.
      add_bos_to_input: False # Adds <bos> to the input sequence.
      add_eos_to_input: False # Adds <eos> to the input sequence.
    validation_ds:
      src_file_name: ??? # Path to the txt file corresponding to the source data.
      tgt_file_name: ??? # Path to the txt file corresponding to the target data.
      names: null # If src/tgt file names are ListConfigs, the corresponding label is used to log metrics.
      global_batch_size: 128
      micro_batch_size: 64
      shuffle: False
      num_workers: 0
      pin_memory: True
      max_src_seq_length: 512
      max_tgt_seq_length: 128
      drop_last: False # TODO: Figure out if there is a way to avoid dropping last.
      write_predictions_to_file: False
      output_file_path_prefix: null # Prefix of the file to write predictions to.
      replace_bos_with_pad: ${model.data.train_ds.replace_bos_with_pad}
      add_bos_to_input: ${model.data.train_ds.add_bos_to_input}
      add_eos_to_input: ${model.data.train_ds.add_eos_to_input}
      metric:
        name: "exact_string_match" # Name of the evaluation metric to use.
        average: micro # How to average the metric over the dataset. Options: ['macro', 'micro']. Applies only to metrics like 'F1' and 'accuracy'; refer to torchmetrics for metrics where this is supported.
        num_classes: null # Number of classes for the metric. Applies only to metrics like 'F1', 'accuracy', and 'average_precision'; refer to torchmetrics for metrics where this is supported.
        class_labels: null # If the targets in your dataset are strings rather than integers/floats, provide a list of class labels (size = num_classes) so strings can be converted to integer categories to compute the metric.
        labels_are_strings: True # NOTE: Only required to properly handle metrics like F1, accuracy, and average_precision. This does not affect exact_string_match.
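      # Example (hypothetical labels) metric block for a 3-way classification task
      # with string targets:
      #   metric:
      #     name: "accuracy"
      #     average: micro
      #     num_classes: 3
      #     class_labels: ["entailment", "neutral", "contradiction"]
      #     labels_are_strings: True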
  optim:
    name: fused_adam
    lr: 5e-6
    weight_decay: 0.0
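    # A learning-rate schedule can typically be attached under optim.sched (a
    # sketch based on common NeMo optimizer configs; these exact keys are an
    # assumption, not part of this file):
    #   sched:
    #     name: WarmupAnnealing
    #     warmup_steps: 50
    #     min_lr: 0.0
    #     last_epoch: -1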