# hma / train_multi_diffusion.py
# Source snapshot: commit 246c106 ("draft") by LeroyWaa, 17.3 kB.
import logging
import math
import os
import mup
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
default_data_collator,
get_scheduler,
)
import wandb
from cont_data import RawFeatureDataset, get_maskgit_collator_feature
from genie.config import DiffusionGenieConfig
from genie.st_mar import STMAR
from datetime import datetime
from accelerate import DistributedDataParallelKwargs
from common import data_sampler
import yaml
from train_diffusion import parse_args, train
# Run timestamp (local, naive time) used by ``main`` as a prefix of the tracker run name.
now = datetime.now()
formatted_date = f"{now:%Y-%m-%d %H:%M:%S}"

# Allow float32 matmuls to use reduced-precision internal math (TF32 / bfloat16 per
# the PyTorch docs) in exchange for speed.
torch.set_float32_matmul_precision("medium")

logger = get_logger(__name__)

# NOTE(review): the PyTorch docs say anomaly detection should be enabled only for
# debugging because it slows execution; this looks like a leftover -- confirm it is
# intentionally on for training runs.
torch.autograd.set_detect_anomaly(True)
def parse_args_multi():
    """Extend the base diffusion-training CLI with multi-dataset options and parse it.

    ``parse_args`` (imported from ``train_diffusion``) is used as a parser factory:
    its return value receives the extra ``add_argument`` calls below and is then
    parsed.

    Returns:
        The parsed command-line namespace: every base option plus ``train_split``,
        ``num_episodes_per_dataset``, ``image_maskgit_path`` and ``action_network``.
    """
    parser = parse_args()

    # (flag, keyword options for ``add_argument``), registered in this order.
    multi_dataset_options = (
        # Data
        ("--train_split", dict(
            type=str,
            default="experiments/datasplit/dataset2.yaml",
            help="Config files for using multiple datasets.",
        )),
        ("--num_episodes_per_dataset", dict(
            type=int,
            default=1000000,
            help="Maximum number of trajectories per dataset",
        )),
        ("--image_maskgit_path", dict(
            type=str,
            default=None,
            help="Optional path to the official MaskGIT checkpoint. "
                 "If specified, will copy relevant weights from the checkpoint. "
                 "These weights will have a different (hard-coded) warmup schedule.",
        )),
        ("--action_network", dict(
            type=str,
            default=None,
            choices=["concat", "cross_attention"],  # TODO: add other methods (resampler_concat, modulate, etc)
            help="If specified, will override the action in the config. Helps reduce the number of config jsons.",
        )),
    )
    for flag, options in multi_dataset_options:
        parser.add_argument(flag, **options)

    return parser.parse_args()
def main():
    """Train an STMAR diffusion model on a mixture of per-domain feature datasets.

    Steps, in order:
      1. Parse CLI args and set up the ``Accelerator`` (trackers, logging, seed,
         output directory).
      2. Read the comma-separated ``domains`` entry of the ``--train_split`` YAML and
         load one train and one validation ``RawFeatureDataset`` per domain. Domains
         whose datasets fail to load are reported and skipped.
      3. Build the model (optionally loading pre-trained MaskGIT weights and applying
         muP), a four-group AdamW optimizer, temperature-weighted multi-task
         dataloaders and the LR scheduler, and wrap them with ``accelerator.prepare``.
      4. Log the experiment config and hand everything to ``train``.

    Raises:
        AssertionError: if not exactly one of ``--llama_config`` / ``--genie_config``
            is set, or if a domain's train and val metadata disagree on a shared key.
        NotImplementedError: if the config sets ``drop_action_ratio > 0``.
        RuntimeError: if no domain could be loaded.
    """
    args = parse_args_multi()
    # NOTE(review): ``config`` below is always loaded from ``args.genie_config``, so the
    # ``--llama_config`` branch this assert allows is not wired up in this script.
    assert (args.llama_config is not None) ^ (args.genie_config is not None), \
        "Exactly one of `llama_config` and `genie_config` should be set."

    # Manual gradient accumulation
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to,
                              even_batches=False, project_dir=args.output_dir, kwargs_handlers=[ddp_kwargs])
    # Trackers are started early (without config) so the dataset-mixture plots below can
    # be logged; ``init_trackers`` is called again with the full config at the end.
    accelerator.init_trackers("video")
    if accelerator.is_main_process:
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # create multiple datasets
    with open(args.train_split, 'r') as file:
        datasplit = yaml.safe_load(file)

    config = DiffusionGenieConfig.from_pretrained(args.genie_config)
    # Extract the 'domains' value and split it into a list
    domains_list = [domain.strip() for domain in datasplit['domains'].split(',')]

    # Checked once, up front: raised inside the per-domain ``try`` below it would be
    # swallowed by the ``except`` and every domain would merely be skipped.
    if config.drop_action_ratio > 0:
        raise NotImplementedError

    # All per-domain lists are appended together, only for domains that loaded
    # successfully, so index i refers to the same domain in every list.
    loaded_domains = []
    train_datasets = []
    val_datasets = []
    dataset_num_samples = []
    val_dataset_num_samples = []
    action_dimensions = []
    action_stats = []
    shared_keys = ("s", "h", "w", "vocab_size", "latent_channels",
                   "encoder_type", "encoder_name_or_path", "quantized")  # TODO: check train/val hz per dataset?
    for domain in domains_list:
        # Earlier data layouts: f"data/{domain}_vae_traj500_{{train,val}}" and
        # f"data/{domain}_vae_traj{args.num_episodes_per_dataset}_{{train,val}}".
        train_data_dir = f"data/{domain}_noquant_temporalvae_shard0_of_1_train"
        val_data_dir = f"data/{domain}_noquant_temporalvae_shard0_of_1_val"
        try:
            train_dataset = RawFeatureDataset(train_data_dir, window_size=args.window_size,
                                              stride=args.stride, filter_overlaps=args.filter_overlaps,
                                              max_traj_num=args.num_episodes_per_dataset,
                                              use_actions=config.use_actions, domain=domain)
            n_action = train_dataset.n_action
            action_stat = train_dataset.action_stat if config.use_actions else None
            if not args.overfit_first_batch:
                eval_dataset = RawFeatureDataset(val_data_dir, window_size=args.window_size,
                                                 stride=args.stride, filter_overlaps=True,
                                                 use_actions=config.use_actions, domain=domain)
            else:
                # Overfit mode: keep only one effective batch worth of windows
                # (per-device batch * grad-accum steps * processes) and evaluate on it.
                train_dataset.valid_start_inds = train_dataset.valid_start_inds[:args.per_device_train_batch_size
                                                                               * args.gradient_accumulation_steps
                                                                               * accelerator.num_processes]
                eval_dataset = train_dataset
            # Shuffle eval dataset and then set shuffle=False on the dataloader.
            # Shuffling in the dataloader results in reshuffling with each iteration.
            eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
                torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
            ].tolist()
        except Exception:
            # Best effort: report the broken domain and skip it. Nothing has been
            # appended for it yet, so the per-domain lists stay aligned.
            import traceback
            print(f"Skipping domain {domain!r}: failed to load its datasets.")
            print(traceback.format_exc())
            continue

        assert all(train_dataset.metadata.get(shared_key) == eval_dataset.metadata.get(shared_key)
                   for shared_key in shared_keys), \
            f"Train/val metadata mismatch for domain {domain!r}"  # TODO: check this across all datasets

        loaded_domains.append(domain)
        train_datasets.append(train_dataset)
        val_datasets.append(eval_dataset)
        # Lengths are taken after any overfit truncation, so the counts handed to the
        # batch samplers match the datasets placed in the ConcatDatasets below.
        dataset_num_samples.append(len(train_dataset))
        val_dataset_num_samples.append(len(eval_dataset))
        action_dimensions.append(n_action)
        if config.use_actions:
            action_stats.append(action_stat)

    if not train_datasets:
        raise RuntimeError(f"No dataset could be loaded for any of the domains {domains_list}.")
    print("dataset_num_samples:", dataset_num_samples)

    # Will not store key in metadata if it's missing, so that defaults can be filled by functions later? # TODO: handle missing keys
    reference_metadata = train_datasets[-1].metadata
    shared_metadata = {shared_key: reference_metadata[shared_key]
                       for shared_key in shared_keys if shared_key in reference_metadata}

    config.use_mup = args.mu_transfer  # Note: changing this may affect pre-trained model due to attn scaling
    config.image_vocab_size = None
    config.T = args.window_size
    config.S = shared_metadata["h"] * shared_metadata["w"]  # TODO: make STMaskGIT use h and w instead of S
    config.vae_embed_dim = shared_metadata["latent_channels"]
    if args.action_network is not None:
        print("Using action network", args.action_network)
        config.action_network = args.action_network

    model = STMAR(config)
    if config.use_actions:
        # Only successfully loaded domains, so names line up with action_dimensions/action_stats.
        model.init_action_projectors(loaded_domains, action_dimensions, action_stats, config.action_network)

    if args.image_maskgit_path is not None:
        model.init_weights()
        model.load_pretrained_image_weights(args.image_maskgit_path)
        if args.mu_transfer:
            model.set_mup_shapes(rescale_params=False)
    elif args.mu_transfer:
        model.set_mup_shapes(rescale_params=True)
        # model.init_weights()  # might be unnecessary if `rescale_params` is True

    # Optimizer. Split weights in two groups, one with weight decay and the other not.
    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW
    # Scale the base learning rate linearly with the effective batch size relative to
    # 64, clamped to a factor in [1, 8].
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
        * accelerator.num_processes
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)

    no_decay = ["bias", "layer_norm.weight"]
    # More accurately the params we want a lower lr for; some weights like pos_embed_TSC
    # are pre-trained but not treated as lower lr.
    pretrained_params = {
        param_name
        for param_name, _ in model.named_parameters()
        if any(term in param_name for term in ("spatial_attn.qkv", "spatial_attn.proj", "mlp"))
    } if args.image_maskgit_path is not None else set()
    # Give pre-trained weights 10x lower learning rate
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay) and n not in pretrained_params],
            "weight_decay": args.weight_decay,
            "lr": args.learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay) and n not in pretrained_params],
            "weight_decay": 0.0,
            "lr": args.learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay) and n in pretrained_params],
            "weight_decay": args.weight_decay,
            "lr": args.learning_rate * 0.1,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay) and n in pretrained_params],
            "weight_decay": 0.0,
            "lr": args.learning_rate * 0.1,
        },
    ]
    optimizer = opt_class(optimizer_grouped_parameters, lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # DataLoaders creation:
    collate_fn = default_data_collator if args.llama_config is not None else get_maskgit_collator_feature(config)
    combined_dataset = torch.utils.data.ConcatDataset(train_datasets)
    batch_sampler = data_sampler.MultiTaskBatchSampler(
        dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=3.  # the higher the more flat the distribution
    )
    dataset_traj_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_num_samples)
    accelerator.log(({"dataset_mixture": wandb.Image(dataset_traj_image)}), log_kwargs={"wandb": {"commit": False}})
    dataset_weights = batch_sampler.generate_tasks_distribution().cpu().numpy()
    dataset_weight_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_weights)
    accelerator.log(({"dataset_mixture_weight": wandb.Image(dataset_weight_image)}), log_kwargs={"wandb": {"commit": False}})
    train_dataloader = DataLoader(combined_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
                                  num_workers=24, pin_memory=False)

    batch_val_sampler = data_sampler.MultiTaskBatchSampler(
        val_dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=4.  # the higher the more flat the distribution
    )
    combined_val_dataset = torch.utils.data.ConcatDataset(val_datasets)
    eval_dataloader = DataLoader(combined_val_dataset, batch_sampler=batch_val_sampler, collate_fn=collate_fn,
                                 num_workers=24, pin_memory=False)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    if args.max_train_steps < 2000 and args.resume_from_checkpoint is None:  # minimal number of training steps
        args.max_train_steps = 2000
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if args.lr_scheduler_type == "custom_cosine":  # decay to `end_ratio` of the peak learning rate
        def get_lr_wrapper(warmup_steps, max_steps, end_ratio=0.1):
            """Return a LambdaLR multiplier: linear warmup, then cosine decay to ``end_ratio``."""
            def get_lr(step):
                if step < warmup_steps:
                    return (step + 1) / warmup_steps
                # Guard against max_steps == warmup_steps, and hold at ``end_ratio`` once
                # past max_steps instead of letting the cosine climb back up.
                remaining_steps = max(1, max_steps - warmup_steps)
                progress = min(1.0, (step - warmup_steps) / remaining_steps)
                return ((1 + math.cos(math.pi * progress)) / 2) * (1 - end_ratio) + end_ratio
            return get_lr

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, get_lr_wrapper(args.num_warmup_steps * accelerator.num_processes,
                                      args.max_train_steps if overrode_max_train_steps
                                      else args.max_train_steps * accelerator.num_processes)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps
            if overrode_max_train_steps
            else args.max_train_steps * accelerator.num_processes,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    accelerator.wait_for_everyone()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    experiment_config = vars(args) | vars(config)
    seq_len = shared_metadata["h"] * shared_metadata["w"] * args.window_size
    args.num_datasets = len(train_datasets)
    model_module = model.module if hasattr(model, "module") else model
    num_params = sum(p.numel() for p in model.parameters())
    num_trunk_params = sum(p.numel() for p in model_module.decoder.parameters())
    experiment_config.update(shared_metadata | {
        "model_parameters": num_params,
        "model_parameters_M": round(num_params / 1e6),
        "trunk_parameters": num_trunk_params,
        "trunk_parameters_M": round(num_trunk_params / 1e6),
        "seq_len": seq_len,
        # All loaded domains, not just the last one.
        "train_data_tokens": len(combined_dataset) * seq_len,
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": args.num_datasets
    })
    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
        * experiment_config["effective_batch_size_tokens"]

    accelerator.init_trackers(project_name="video", config=experiment_config)

    # Train!
    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args)
if __name__ == "__main__":
main()