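# hma / train_multi.py
# Revision 246c106 ("draft"), uploaded by LeroyWaa; 16.4 kB.
# (Header recovered from a web file-viewer page; the viewer's
#  "raw / history / blame / contribute / delete" controls are omitted.)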
import logging
import math
import os
import mup
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
default_data_collator,
get_scheduler,
)
import wandb
from data import RawTokenDataset, get_maskgit_collator
from genie.st_mask_git import GenieConfig, STMaskGIT
from datetime import datetime
from accelerate import DistributedDataParallelKwargs
from common import data_sampler
import yaml
from train import parse_args, train
# Timestamp of this launch; main() prefixes the tracker run name with it.
now = datetime.now()
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")

# Let float32 matmuls use lower-precision internal math (bfloat16/TF32) for speed.
torch.set_float32_matmul_precision("medium")

logger = get_logger(__name__)

# Autograd anomaly detection is a debugging aid that adds checks to every
# backward pass and slows training considerably, so it is opt-in:
# set DETECT_ANOMALY=1 (or "true"/"yes") to turn it on.
_DETECT_ANOMALY = os.environ.get("DETECT_ANOMALY", "0").strip().lower() in ("1", "true", "yes")
torch.autograd.set_detect_anomaly(_DETECT_ANOMALY)
def parse_args_multi():
    """Parse the command line for multi-dataset training.

    Takes the parser returned by ``train.parse_args`` and registers the extra
    options used only by this script, then parses ``sys.argv`` with it.

    Returns:
        The namespace produced by calling ``parse_args()`` on that parser.
    """
    # ``train.parse_args`` is used as a parser factory here: its return value
    # must expose ``add_argument`` and ``parse_args``.
    cli = parse_args()

    multi_dataset_options = [
        # Data
        ("--train_split", {
            "type": str,
            "default": "experiments/datasplit/dataset2.yaml",
            "help": "Config files for using multiple datasets.",
        }),
        ("--num_episodes_per_dataset", {
            "type": int,
            "default": 1000000,
            "help": "Maximum number of trajectories per dataset",
        }),
        ("--image_maskgit_path", {
            "type": str,
            "default": None,
            "help": "Optional path to the official MaskGIT checkpoint. "
                    "If specified, will copy relevant weights from the checkpoint. "
                    "These weights will have a different (hard-coded) warmup schedule.",
        }),
        ("--action_network", {
            "type": str,
            "default": None,
            "choices": ["concat", "cross_attention"],  # TODO: add other methods (resampler_concat, modulate, etc)
            "help": "If specified, will override the action in the config. Helps reduce the number of config jsons.",
        }),
    ]
    for flag, options in multi_dataset_options:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def _load_domain_datasets(domain, args, config, accelerator):
    """Build the ``(train_dataset, eval_dataset)`` pair for a single domain.

    Args:
        domain: Domain name; selects ``data/{domain}_magvit_traj1000000_{train,val}``.
        args: Parsed CLI namespace from ``parse_args_multi``.
        config: ``GenieConfig``; supplies ``use_actions`` and ``drop_action_ratio``.
        accelerator: Used only for ``num_processes`` when overfitting one batch.

    Returns:
        ``(train_dataset, eval_dataset)``. With ``--overfit_first_batch`` both names
        refer to the same (truncated) training dataset object.

    Raises:
        Propagates any exception raised while constructing ``RawTokenDataset``.
    """
    # NOTE: the directory suffix is hard-coded to ``traj1000000``;
    # ``--num_episodes_per_dataset`` only caps trajectories via ``max_traj_num``.
    train_data_dir = f"data/{domain}_magvit_traj1000000_train"  # {args.num_episodes_per_dataset}
    val_data_dir = f"data/{domain}_magvit_traj1000000_val"
    train_dataset = RawTokenDataset(train_data_dir, window_size=args.window_size, name=domain,
                                    stride=args.stride, filter_overlaps=args.filter_overlaps,
                                    max_traj_num=args.num_episodes_per_dataset,
                                    use_actions=config.use_actions, drop_action_ratio=config.drop_action_ratio)
    if not args.overfit_first_batch:
        eval_dataset = RawTokenDataset(val_data_dir, window_size=args.window_size, name=domain,
                                       stride=args.stride, filter_overlaps=True,
                                       use_actions=config.use_actions, drop_action_ratio=config.drop_action_ratio)
    else:
        # Keep exactly one global batch worth of training windows and evaluate on it.
        train_dataset.valid_start_inds = train_dataset.valid_start_inds[:args.per_device_train_batch_size
                                                                         * args.gradient_accumulation_steps
                                                                         * accelerator.num_processes]
        eval_dataset = train_dataset

    # Shuffle eval once here (fixed seed) and use shuffle=False on the dataloader:
    # shuffling in the dataloader would reshuffle on every iteration.
    eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
        torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
    ].tolist()
    return train_dataset, eval_dataset


def main():
    """Train a Genie ``STMaskGIT`` model on a mixture of tokenized video datasets.

    Parses CLI args, sets up the ``Accelerator`` and trackers, loads one
    train/val dataset pair per domain listed in ``--train_split``, builds the
    model, optimizer, multi-task dataloaders and LR scheduler, then hands
    everything to ``train.train``.

    Raises:
        ValueError: if ``--genie_config`` is not given, or a domain's train and
            val metadata disagree on ``s``/``vocab_size``/``hz``.
        RuntimeError: if no domain's datasets could be loaded.
    """
    args = parse_args_multi()
    assert (args.llama_config is not None) ^ (args.genie_config is not None), \
        "Exactly one of `llama_config` and `genie_config` should be set."
    if args.genie_config is None:
        # Everything below builds a GenieConfig/STMaskGIT; the Llama path is not wired up here.
        raise ValueError("train_multi.py only supports `--genie_config`; `--llama_config` is not implemented.")

    # Manual gradient accumulation: the Accelerator itself accumulates over 1 step.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to,
                              even_batches=False, project_dir=args.output_dir, kwargs_handlers=[ddp_kwargs])
    accelerator.init_trackers("video")
    if accelerator.is_main_process:
        # NOTE(review): assumes trackers[0] is a wandb tracker (it is accessed via `.run`).
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # ---- Datasets: one train/val pair per domain listed in the split YAML ----
    with open(args.train_split, "r") as file:
        datasplit = yaml.safe_load(file)
    config = GenieConfig.from_pretrained(args.genie_config)
    # 'domains' is a single comma-separated string; drop empty entries (e.g. a trailing comma).
    domains_list = [domain.strip() for domain in datasplit["domains"].split(",") if domain.strip()]

    # Every per-domain list below is index-aligned with `loaded_domains`.
    loaded_domains = []
    train_datasets = []
    val_datasets = []
    dataset_num_samples = []
    val_dataset_num_samples = []
    action_dimensions = []
    action_stats = []
    total_num_videos = 0
    for domain in domains_list:
        try:
            train_dataset, eval_dataset = _load_domain_datasets(domain, args, config, accelerator)
        except Exception:
            # Best-effort: skip a domain whose data cannot be loaded, but skip it
            # *entirely* so the aligned lists never get a partial entry.
            logger.warning("Skipping domain %r: failed to load its datasets.", domain,
                           exc_info=True, main_process_only=False)
            continue

        mismatched = [key for key in ("s", "vocab_size", "hz")
                      if train_dataset.metadata[key] != eval_dataset.metadata[key]]
        if mismatched:
            raise ValueError(f"Train/val metadata mismatch for domain {domain!r} on keys {mismatched}.")

        loaded_domains.append(domain)
        train_datasets.append(train_dataset)
        val_datasets.append(eval_dataset)
        # Lengths are taken *after* any overfit truncation so they match what ConcatDataset sees.
        dataset_num_samples.append(len(train_dataset))
        val_dataset_num_samples.append(len(eval_dataset))
        action_dimensions.append(train_dataset.n_action)
        total_num_videos += train_dataset.num_videos
        if config.use_actions:
            action_stats.append(train_dataset.action_stat)

    if not train_datasets:
        raise RuntimeError(f"No dataset could be loaded for domains {domains_list} (split file: {args.train_split}).")
    print("dataset_num_samples:", dataset_num_samples)
    latent_side_len, vocab_size, hz = [train_datasets[-1].metadata[key] for key in ("s", "vocab_size", "hz")]

    # ---- Model ----
    config.use_mup = args.mu_transfer  # Note: changing this may affect pre-trained model due to attn scaling
    config.image_vocab_size = vocab_size
    config.T = args.window_size
    if args.action_network is not None:
        print("Using action network", args.action_network)
        config.action_network = args.action_network
    model = STMaskGIT(config)
    if config.use_actions:
        model.init_action_projectors(loaded_domains, action_dimensions, action_stats, config.action_network)

    if args.image_maskgit_path is not None:
        model.init_weights()
        model.load_pretrained_image_weights(args.image_maskgit_path)
        if args.mu_transfer:
            model.set_mup_shapes(rescale_params=False)
    elif args.mu_transfer:
        model.set_mup_shapes(rescale_params=True)

    # ---- Optimizer: split weights by (weight decay?, pre-trained?) ----
    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
        * accelerator.num_processes
    # Scale the base learning rate with the effective batch size (clamped to [1x, 8x]).
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)

    no_decay = ["bias", "layer_norm.weight"]
    # More accurately: the params we want a lower lr for. Some weights such as
    # pos_embed_TSC are pre-trained but intentionally not in this set.
    pretrained_params = {
        param_name
        for param_name, _ in model.named_parameters()
        if any(term in param_name for term in ("spatial_attn.qkv", "spatial_attn.proj", "mlp"))
    } if args.image_maskgit_path is not None else set()

    # Pre-trained weights get a 10x lower learning rate. Group order is kept
    # stable because optimizer checkpoints are restored by group index.
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay) and n not in pretrained_params],
            "weight_decay": args.weight_decay,
            "lr": args.learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay) and n not in pretrained_params],
            "weight_decay": 0.0,
            "lr": args.learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay) and n in pretrained_params],
            "weight_decay": args.weight_decay,
            "lr": args.learning_rate * 0.1,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay) and n in pretrained_params],
            "weight_decay": 0.0,
            "lr": args.learning_rate * 0.1,
        },
    ]
    optimizer = opt_class(optimizer_grouped_parameters, lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # ---- DataLoaders: multi-task sampling over the concatenated per-domain datasets ----
    collate_fn = get_maskgit_collator(config)  # genie_config is required above
    combined_dataset = torch.utils.data.ConcatDataset(train_datasets)
    batch_sampler = data_sampler.MultiTaskBatchSampler(
        dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=3.  # the higher the more flat the distribution
    )

    dataset_traj_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_num_samples)
    accelerator.log({"dataset_mixture": wandb.Image(dataset_traj_image)}, log_kwargs={"wandb": {"commit": False}})
    dataset_weights = batch_sampler.generate_tasks_distribution().cpu().numpy()
    dataset_weight_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_weights)
    accelerator.log({"dataset_mixture_weight": wandb.Image(dataset_weight_image)},
                    log_kwargs={"wandb": {"commit": False}})

    train_dataloader = DataLoader(combined_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
                                  num_workers=16, pin_memory=True)

    batch_val_sampler = data_sampler.MultiTaskBatchSampler(
        val_dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=4.  # the higher the more flat the distribution
    )
    combined_val_dataset = torch.utils.data.ConcatDataset(val_datasets)
    eval_dataloader = DataLoader(combined_val_dataset, batch_sampler=batch_val_sampler, collate_fn=collate_fn,
                                 num_workers=16, pin_memory=True)

    # ---- Scheduler and math around the number of training steps ----
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    if args.max_train_steps < 2000 and args.resume_from_checkpoint is None:  # minimal number of training steps
        args.max_train_steps = 2000
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if args.lr_scheduler_type == "custom_cosine":  # decay to `end_ratio` of the peak learning rate
        def get_lr_wrapper(warmup_steps, max_steps, end_ratio=0.1):
            def get_lr(step):
                if step < warmup_steps:
                    return (step + 1) / warmup_steps
                # Guard max_steps <= warmup_steps, and hold the floor past max_steps
                # instead of letting cos() swing the learning rate back up.
                remaining_steps = max(1, max_steps - warmup_steps)
                progress = min(1.0, (step - warmup_steps) / remaining_steps)
                return ((1 + math.cos(math.pi * progress)) / 2) * (1 - end_ratio) + end_ratio
            return get_lr

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, get_lr_wrapper(args.num_warmup_steps * accelerator.num_processes,
                                      args.max_train_steps if overrode_max_train_steps
                                      else args.max_train_steps * accelerator.num_processes)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps
            if overrode_max_train_steps
            else args.max_train_steps * accelerator.num_processes,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    accelerator.wait_for_everyone()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # Recalculate total training steps: the dataloader length may have changed after prepare().
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # Store the run configuration with the trackers (they were created above
    # without a config so the run name could be set early).
    experiment_config = vars(args) | vars(config)
    seq_len = latent_side_len**2 * args.window_size
    args.num_datasets = len(train_datasets)
    model_module = model.module if hasattr(model, "module") else model
    experiment_config.update({
        "model_parameters": sum(p.numel() for p in model.parameters()),
        "model_parameters_M": round(sum(p.numel() for p in model.parameters()) / 1e6),
        "trunk_parameters": sum(p.numel() for p in model_module.decoder.parameters()),
        "trunk_parameters_M": round(sum(p.numel() for p in model_module.decoder.parameters()) / 1e6),
        "seq_len": seq_len,
        "hz": hz / args.stride if args.stride is not None else hz,
        # Count tokens over ALL domains, not just the last one loaded.
        "train_data_tokens": len(combined_dataset) * seq_len,
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": args.num_datasets,
        "total_num_videos": total_num_videos,
    })
    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
        * experiment_config["effective_batch_size_tokens"]
    accelerator.init_trackers(project_name="video", config=experiment_config)

    # Train!
    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args)
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()