old_tok / paintmind /engine /trainer.py
tennant's picture
upload
af7c0ce
raw
history blame contribute delete
30.3 kB
import os, torch
import os.path as osp
import cv2
import shutil
import numpy as np
import copy
import torch_fidelity
import torch.nn as nn
from tqdm.auto import tqdm
from collections import OrderedDict
from einops import rearrange
from accelerate import Accelerator
from .util import instantiate_from_config
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler
from paintmind.utils.lr_scheduler import build_scheduler
from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from torch.optim import AdamW
from concurrent.futures import ThreadPoolExecutor
from torchmetrics.functional.image import (
peak_signal_noise_ratio as psnr,
structural_similarity_index_measure as ssim
)
def requires_grad(model, flag=True):
    """Toggle gradient tracking for every parameter of ``model``.

    Args:
        model: module whose parameters are updated in place.
        flag: value assigned to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
def save_img(img, save_path):
    """Write a single channel-first RGB tensor to ``save_path`` with OpenCV.

    Values are scaled by 255 and clipped to ``[0, 255]`` before conversion to
    ``uint8``; the channel axis is reversed because OpenCV writes BGR.
    """
    hwc = img.numpy().transpose([1, 2, 0])
    pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
    bgr = pixels[:, :, ::-1]
    cv2.imwrite(save_path, bgr)
def save_img_batch(imgs, save_paths, max_workers=32):
    """Save a batch of channel-first RGB images to disk in parallel.

    Args:
        imgs: CPU tensor of shape ``(N, C, H, W)``; values are scaled by 255 and
            clipped to ``[0, 255]`` before conversion to ``uint8``.
        save_paths: one output path per image; must have exactly ``N`` entries.
        max_workers: number of threads used for the writes.

    Raises:
        ValueError: if the number of paths differs from the number of images
            (``zip`` would otherwise silently drop the extras).
        OSError: if OpenCV reports that an image could not be written.
    """
    num_imgs = imgs.shape[0]
    if len(save_paths) != num_imgs:
        raise ValueError(f"got {num_imgs} images but {len(save_paths)} save paths")

    # Convert the whole batch at once: NCHW -> NHWC uint8, then RGB -> BGR.
    arrays = np.clip(imgs.numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    arrays = arrays[:, :, :, ::-1]

    def _write(path, img):
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(path, img):
            raise OSError(f"cv2.imwrite failed to write {path}")

    # Threads suit this I/O-bound work (encoding + file writes).
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(_write, path, img) for path, img in zip(save_paths, arrays)]
        for future in futures:
            future.result()  # re-raises the first failure, in submission order
def get_fid_stats(real_dir, rec_dir, fid_stats):
    """Run torch-fidelity on the reconstructed images and return its metrics dict.

    ``rec_dir`` is passed as ``input1`` and ``real_dir`` as ``input2``; FID and
    Inception Score are enabled, KID and precision/recall are disabled, and the
    computation runs on CUDA.

    Args:
        real_dir: folder of reference images.
        rec_dir: folder of generated/reconstructed images.
        fid_stats: value forwarded as ``fid_statistics_file``.
    """
    options = dict(
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return torch_fidelity.calculate_metrics(input1=rec_dir, input2=real_dir, **options)
class EMAModel:
    """Exponential moving average (EMA) of a model's trainable parameters.

    The averaged tensors are stored on ``device`` and exposed through
    ``state_dict`` / ``load_state_dict`` so the object can be registered with
    ``Accelerator.register_for_checkpointing``.
    """

    def __init__(self, model, device, decay=0.999):
        """Snapshot every parameter of ``model`` that has ``requires_grad`` set.

        Args:
            model: module whose trainable parameters are tracked.
            device: device on which the EMA copies are kept.
            decay: EMA decay; each update moves the average ``1 - decay`` of the
                way towards the current parameter value.
        """
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        """Blend the current trainable parameters of ``model`` into the average.

        A trainable parameter that is not tracked yet (e.g. one that was frozen
        when this EMA was created) is copied as-is, onto ``self.device`` like
        every other EMA tensor.
        """
        weight = 1 - self.decay
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            current = param.detach()
            ema = self.ema_params.get(name)
            if ema is None:
                self.ema_params[name] = current.clone().to(self.device)
            else:
                # lerp_ needs both operands on the same device.
                ema.lerp_(current.to(ema.device), weight)

    def state_dict(self):
        """Return the (live, not copied) mapping of parameter name -> EMA tensor."""
        return self.ema_params

    def load_state_dict(self, params):
        """Replace the EMA with detached copies of ``params`` moved to ``self.device``."""
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )
class DiffusionTrainer(nn.Module):
    """Accelerate-based training and evaluation loop for the diffusion model.

    Builds the model from a config, the train/valid/test data loaders, an AdamW
    optimizer with a warmup scheduler and optional EMA weights, and handles
    checkpointing, validation sample grids and FID/IS/PSNR/SSIM evaluation.
    """

    class _PaddedDataset(torch.utils.data.Dataset):
        """Dataset view that appends ``padding_size`` repeats of item 0.

        Padding lets every process receive the same number of full test
        batches; the padded samples are trimmed again after gathering.
        Defined at class level (not inside ``__init__``) so it can be pickled
        for spawn-based DataLoader workers.
        """

        def __init__(self, dataset, padding_size):
            self.dataset = dataset
            self.padding_size = padding_size

        def __len__(self):
            return len(self.dataset) + self.padding_size

        def __getitem__(self, idx):
            if idx < len(self.dataset):
                return self.dataset[idx]
            return self.dataset[0]

    class _SkipFirstBatches:
        """Sized iterable over ``loader`` whose iterator starts after ``n_skip`` batches.

        Used to resume mid-epoch: the skipped batches are consumed from the same
        iterator that the training loop then continues with.
        """

        def __init__(self, loader, n_skip):
            self.loader = loader
            self.n_skip = n_skip

        def __len__(self):
            return max(len(self.loader) - self.n_skip, 0)

        def __iter__(self):
            batches = iter(self.loader)
            for _ in range(self.n_skip):
                next(batches, None)
            return batches

    @staticmethod
    def _padding_to_multiple(total_size, chunk_size):
        """Return how many items to append so ``total_size`` becomes a multiple of ``chunk_size``."""
        return -total_size % chunk_size

    def __init__(
        self,
        model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_epoch=400,
        valid_size=32,
        lr=None,
        blr=1e-4,
        cosine_lr=True,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        eval_bs=32,
        test_bs=64,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.0,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        use_multi_epochs_dataloader=False,
        compile=False,
        overfit=False,
        making_cache=False,
        cache_mode=False,
        latent_cache_file=None,
    ):
        """Build the accelerator, model, data loaders, optimizer and scheduler.

        Selected args:
            model: config instantiated with ``instantiate_from_config``; its
                ``params`` supply ``num_slots``, ``ckpt_path`` and the legacy
                ``eval_fid`` / ``fid_stats`` / ``use_vq`` / ``code_beta`` values.
            dataset: training-set config; ``valid_size`` samples are held out
                (fixed seed 42) for validation grids.
            test_dataset: optional test-set config used for FID evaluation.
            lr: absolute learning rate; when None it is ``blr * effective_bs / 256``.
            precision: validated, but the accelerator always uses bf16 autocast.
            test_only: skip training data/optimizer setup and only evaluate.
            result_folder: root for the ``models/`` and ``images/`` outputs.

        ``use_multi_epochs_dataloader``, ``overfit``, ``making_cache`` and
        ``latent_cache_file`` are accepted but not used by this class.
        """
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )
        self.model = instantiate_from_config(model)
        self.num_slots = model.params.num_slots

        assert precision in ["bf16", "fp32"]
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Overlooking specified precision and using autocast bf16...")
        self.precision = precision

        if test_dataset is not None:
            test_dataset = instantiate_from_config(test_dataset)
            total_size = len(test_dataset)
            world_size = self.accelerator.num_processes
            # Pad so every process gets the same number of full batches
            # (no extra padding when the size is already a multiple).
            padding_size = self._padding_to_multiple(total_size, world_size * test_bs)
            self.test_dataset_size = total_size
            self.test_ds = self._PaddedDataset(test_dataset, padding_size)
            self.test_dl = DataLoader(
                self.test_ds,
                batch_size=test_bs,
                num_workers=num_workers,
                pin_memory=pin_memory,
                shuffle=False,
                drop_last=True,
            )
            if self.accelerator.is_main_process:
                print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
        else:
            self.test_dl = None

        self.test_only = test_only
        if not test_only:
            dataset = instantiate_from_config(dataset)
            train_size = len(dataset) - valid_size
            self.train_ds, self.valid_ds = random_split(
                dataset,
                [train_size, valid_size],
                generator=torch.Generator().manual_seed(42),
            )
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")
            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
            self.valid_dl = DataLoader(
                self.valid_ds,
                batch_size=eval_bs,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")
            params = [p for p in self.model.parameters() if p.requires_grad]
            self.g_optim = AdamW(params, lr=lr, betas=(0.9, 0.95), weight_decay=0)
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps
            )
            if self.g_sched is not None:
                self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        if not test_only:
            self.model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.model, self.g_optim, self.g_sched
            )
            if self.test_dl is not None:
                # Shard the test loader across processes, as in test-only mode.
                # Otherwise every rank reconstructs the same batches and the
                # gathered FID set is full of duplicates.
                self.test_dl = self.accelerator.prepare(self.test_dl)
        else:
            self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)

        if compile:
            _model = self.accelerator.unwrap_model(self.model)
            _model.vae = torch.compile(_model.vae, mode="reduce-overhead")
            _model.dit = torch.compile(_model.dit, mode="reduce-overhead")
            # NOTE: the encoder is left uncompiled; compiling it together with
            # dit produced NaN losses (cause unknown).
            _model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")

        self.enable_ema = enable_ema
        # Test-only runs load EMA weights straight into the model in
        # _load_checkpoint, so no EMA tracker is built for them.
        if self.enable_ema and not self.test_only:
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)
        self._load_checkpoint(model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm
        self.cache_mode = cache_mode
        self.cfg = cfg
        if test_num_slots is not None:
            self.test_num_slots = min(test_num_slots, self.num_slots)
        else:
            self.test_num_slots = self.num_slots

        eval_fid = eval_fid or model.params.eval_fid  # legacy
        self.eval_fid = eval_fid
        if eval_fid:
            if fid_stats is None:
                fid_stats = model.params.fid_stats  # legacy
            assert fid_stats is not None
            assert test_dataset is not None
        self.fid_stats = fid_stats
        self.use_vq = model.params.use_vq if hasattr(model.params, "use_vq") else False
        self.vq_beta = model.params.code_beta if hasattr(model.params, "code_beta") else 0.25

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)
        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

    @property
    def device(self):
        """Device assigned to this process by the accelerator."""
        return self.accelerator.device

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
        """Build the LR scheduler; ``warmup_epochs``, when given, overrides ``warmup_steps``."""
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * len(self.train_dl)
        else:
            assert warmup_steps is not None
        return build_scheduler(
            self.g_optim,
            num_epoch,
            len(self.train_dl),
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,  # if not cosine_lr, then use step_lr (warmup, then fix)
        )

    def _load_state_dict(self, state_dict):
        """Load weights non-strictly into the unwrapped model.

        Unwraps a nested ``'state_dict'`` entry and strips the ``_orig_mod.``
        prefix that ``torch.compile`` adds to parameter names.
        """
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = self.accelerator.unwrap_model(self.model).load_state_dict(
            state_dict, strict=False
        )
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

    def _load_safetensors(self, path):
        """Read a ``.safetensors`` file on CPU and load it via ``_load_state_dict``."""
        from safetensors.torch import safe_open

        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict)

    def _load_checkpoint(self, ckpt_path=None):
        """Restore from ``ckpt_path``; a missing or None path is ignored.

        A directory is treated as an ``accelerator.save_state`` output named
        ``.../step<N>``, and ``N`` becomes ``self.loaded_steps``. When training,
        the whole accelerator state is restored. In test-only mode only the
        weights are loaded. A file path is loaded either as ``.safetensors`` or
        as a ``torch.save`` state dict.
        """
        if ckpt_path is None or not osp.exists(ckpt_path):
            return
        if osp.isdir(ckpt_path):
            # ckpt_path is something like 'path/to/models/step10/'
            self.loaded_steps = int(ckpt_path.split("step")[-1].split("/")[0])
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                self._load_test_weights(ckpt_path)
        else:
            # ckpt_path is something like 'path/to/models/step10.pt'
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path)
            else:
                # Load on CPU so ranks do not all materialize the tensors on the
                # device they were saved from.
                self._load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def _load_test_weights(self, ckpt_dir):
        """Load weights for a test-only run: EMA if enabled and present, else raw model weights."""
        if self.enable_ema:
            # NOTE: index 1 matches __init__, which registers the scheduler
            # before the EMA tracker for checkpointing.
            ema_path = osp.join(ckpt_dir, "custom_checkpoint_1.pkl")
            if osp.exists(ema_path):
                self._load_state_dict(torch.load(ema_path, map_location="cpu"))
                if self.accelerator.is_main_process:
                    print(f"Loaded ema model from {ema_path}")
                return
            if self.accelerator.is_main_process:
                print(f"EMA weights not found at {ema_path}, loading raw model weights instead")
        model_path = osp.join(ckpt_dir, "model.safetensors")
        if osp.exists(model_path):
            self._load_safetensors(model_path)

    def train(self, config=None):
        """Run the training loop, or a single evaluation pass when ``test_only``.

        Args:
            config: optional config file path or OmegaConf object, written to
                ``<result_folder>/config.yaml`` by the main process.
        """
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters / 1e6:.2f}M")
            if config is not None:
                from omegaconf import OmegaConf

                config_save_path = osp.join(self.result_folder, "config.yaml")
                if isinstance(config, str) and osp.exists(config):
                    shutil.copy(config, config_save_path)
                else:
                    OmegaConf.save(config, config_save_path)
        self.accelerator.init_trackers("vqgan")

        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        steps_per_epoch = len(self.train_dl)
        for epoch in range(self.num_epoch):
            if (epoch + 1) * steps_per_epoch <= self.loaded_steps:
                if self.accelerator.is_main_process:
                    print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
                self.steps += steps_per_epoch
                continue
            self._train_one_epoch(epoch)

        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def _unpack_train_batch(self, batch):
        """Move a training batch to the device and return ``(img, targets, latent)``.

        Tuple/list batches are ``(img, targets)``, or ``(img, latent, targets)``
        in cache mode; any other batch is treated as images only.
        """
        if not isinstance(batch, (tuple, list)):
            if isinstance(batch, torch.Tensor):
                batch = batch.to(self.device, non_blocking=True)
            return batch, None, None
        batch = [b.to(self.device, non_blocking=True) for b in batch]
        if self.cache_mode:
            img, latent, targets = batch[0], batch[1], batch[2]
        else:
            img, targets, latent = batch[0], batch[1], None
        return img, targets, latent

    def _train_one_epoch(self, epoch):
        """Train one epoch, skipping the batches already covered by a loaded checkpoint."""
        sampler = getattr(self.train_dl, "sampler", None)
        if isinstance(sampler, DistributedSampler):
            # Reseed the shuffle; otherwise every epoch replays the same order.
            sampler.set_epoch(epoch)

        # Resume mid-epoch by skipping within the iterator the loop consumes
        # (a separate pass over train_dl would restart at the first batch).
        loader = self.train_dl
        n_skip = min(max(self.loaded_steps - self.steps, 0), len(self.train_dl))
        if n_skip > 0:
            loader = self._SkipFirstBatches(self.train_dl, n_skip)
            self.steps += n_skip

        self.accelerator.unwrap_model(self.model).current_epoch = epoch
        self.model.train()
        logger = MetricLogger(delimiter=" ")
        logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
        print_freq = 20

        for batch in logger.log_every(loader, print_freq, header):
            img, targets, latent = self._unpack_train_batch(batch)
            self.steps += 1
            with self.accelerator.accumulate(self.model):
                with self.accelerator.autocast():
                    if self.steps == 1:
                        print(f"Training batch size: {img.size(0)}")
                        print(f"Hello from index {self.accelerator.local_process_index}")
                    losses = self.model(img, targets, latents=latent, epoch=epoch)
                    # The optimized loss is the sum of every returned term.
                    loss = sum(losses.values())
                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.accelerator.unwrap_model(self.model).cancel_gradients_encoder(epoch)
                self.g_optim.step()
                if self.g_sched is not None:
                    self.g_sched.step_update(self.steps)
                self.g_optim.zero_grad()

            # Only fold weights into the EMA when the optimizer actually stepped.
            # With gradient accumulation, intermediate micro-batches leave the
            # weights unchanged and would otherwise shrink the effective decay.
            if self.enable_ema and self.accelerator.sync_gradients:
                self.ema_model.update(self.accelerator.unwrap_model(self.model))

            logger.update(diff_loss=losses["diff_loss"].item())
            if self.use_vq:
                logger.update(vq_loss=losses["vq_loss"].item() / self.vq_beta)
            if 'kl_loss' in losses:
                logger.update(kl_loss=losses["kl_loss"].item())
            if 'repa_loss' in losses:
                logger.update(repa_loss=losses["repa_loss"].item())
            logger.update(lr=self.g_optim.param_groups[0]["lr"])

            if self.steps % self.save_every == 0:
                self.save()
            if (self.steps % self.samp_every == 0) or (self.steps % self.fid_every == 0):
                empty_cache()
                self.evaluate()
                self.accelerator.wait_for_everyone()
                empty_cache()

        logger.synchronize_between_processes()
        if self.accelerator.is_main_process:
            print("Averaged stats:", logger)

    def save(self):
        """Save the full accelerator state to ``<result_folder>/models/step<steps>``."""
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @torch.no_grad()
    def evaluate(self, use_ema=True):
        """Save validation sample grids and, when due, compute FID/IS/PSNR/SSIM.

        Args:
            use_ema: evaluate with EMA weights. Only honoured when EMA is
                enabled, ``eval_fid`` is set and this is not a test-only run.
                The training weights are restored afterwards, even on error, and
                the model is left in train mode.
        """
        self.model.eval()
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        backup_state = self._swap_in_ema_weights() if use_ema else None
        use_ema = backup_state is not None
        try:
            self._run_evaluation(use_ema)
        finally:
            if backup_state is not None:
                if self.accelerator.is_main_process:
                    print("Switch back from ema")
                self.accelerator.unwrap_model(self.model).load_state_dict(backup_state)
            self.model.train()

    def _swap_in_ema_weights(self):
        """Load EMA weights into the model; return a backup of the live weights, or None."""
        if not hasattr(self, "ema_model"):
            print("EMA model not found, using original model")
            return None
        model_without_ddp = self.accelerator.unwrap_model(self.model)
        model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        # A shallow copy suffices: load_state_dict copies values into the live
        # parameters and never mutates the backup tensors.
        ema_state_dict = copy.copy(model_state_dict)
        ema_params = self.ema_model.state_dict()
        for name, _ in model_without_ddp.named_parameters():
            # nested_sampler parameters are skipped and keep their current values.
            if "nested_sampler" in name:
                continue
            if name in ema_params:
                ema_state_dict[name] = ema_params[name]
        if self.accelerator.is_main_process:
            print("Switch to ema")
        msg = model_without_ddp.load_state_dict(ema_state_dict)
        if self.accelerator.is_main_process:
            print(msg)
        return model_state_dict

    def _run_evaluation(self, use_ema):
        """Save validation grids (training runs only), then FID metrics when due."""
        if not self.test_only:
            with tqdm(
                self.valid_dl,
                dynamic_ncols=True,
                disable=not self.accelerator.is_main_process,
            ) as valid_dl:
                for batch_i, batch in enumerate(valid_dl):
                    if isinstance(batch, (tuple, list)):
                        img, targets = batch[0], batch[1].to(self.device, non_blocking=True)
                    else:
                        img, targets = batch, None
                    # valid_dl is not prepared by accelerate, so place inputs explicitly.
                    img = img.to(self.device, non_blocking=True)
                    self._save_valid_grid(
                        img, targets, 1.0,
                        f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg",
                    )
                    if self.cfg != 1.0:
                        self._save_valid_grid(
                            img, targets, self.cfg,
                            f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg",
                        )
        fid_due = self.test_only or (self.steps % self.fid_every == 0)
        if self.eval_fid and self.test_dl is not None and fid_due:
            self._evaluate_fid(use_ema)

    def _save_valid_grid(self, img, targets, cfg_value, filename):
        """Sample reconstructions for one validation batch and save an input/recon grid."""
        with self.accelerator.autocast():
            rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
        imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
        # Interleave so each input is followed by its reconstruction.
        imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
        imgs_and_recs = imgs_and_recs.detach().cpu().float()
        grid = make_grid(imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1))
        if self.accelerator.is_main_process:
            save_image(grid, os.path.join(self.image_saved_dir, filename))

    def _evaluate_fid(self, use_ema):
        """Reconstruct the test set without and/or with CFG and log FID/IS/PSNR/SSIM."""
        # NOTE(review): the reference image folders are hard-coded and picked by test-set size.
        if self.test_dataset_size > 10000:
            real_dir = "./dataset/imagenet/val256"
        else:
            real_dir = "./dataset/coco/val2017_256"
        rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
        os.makedirs(rec_dir, exist_ok=True)
        rec_cfg_dir = None
        if self.cfg != 1.0:
            rec_cfg_dir = os.path.join(
                self.image_saved_dir, f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}"
            )
            os.makedirs(rec_cfg_dir, exist_ok=True)

        # Test-only runs with CFG evaluate only the CFG setting.
        if self.cfg == 1.0 or not self.test_only:
            psnr_val, ssim_val = self._reconstruct_test_set(1.0, rec_dir, 'Testing: w/o CFG')
            self._log_fid_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val, use_ema)
        if rec_cfg_dir is not None:
            psnr_val, ssim_val = self._reconstruct_test_set(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
            self._log_fid_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val, use_ema)

        if self.accelerator.is_main_process:
            shutil.rmtree(rec_dir)
            if rec_cfg_dir is not None:
                shutil.rmtree(rec_cfg_dir)

    def _reconstruct_test_set(self, cfg_value, save_dir, header):
        """Reconstruct every test image, save the un-padded ones as PNGs, return mean (PSNR, SSIM)."""
        logger = MetricLogger(delimiter=" ")
        print_freq = 5
        psnr_values = []
        ssim_values = []
        total_processed = 0
        for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
            imgs, targets = (batch[0], batch[1]) if isinstance(batch, (tuple, list)) else (batch, None)
            # Skip processing if we've already processed all real samples
            if total_processed >= self.test_dataset_size:
                break
            imgs = imgs.to(self.device, non_blocking=True)
            if targets is not None:
                targets = targets.to(self.device, non_blocking=True)
            with self.accelerator.autocast():
                recs = self.model(imgs, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
            psnr_val = psnr(recs, imgs, data_range=1.0)
            ssim_val = ssim(recs, imgs, data_range=1.0)
            recs = concat_all_gather(recs).detach()
            psnr_val = concat_all_gather(psnr_val.view(1))
            ssim_val = concat_all_gather(ssim_val.view(1))
            # Drop the padded tail once results from all processes are gathered.
            samples_in_batch = min(recs.size(0), self.test_dataset_size - total_processed)
            recs = recs[:samples_in_batch]
            psnr_values.append(psnr_val[:samples_in_batch])
            ssim_values.append(ssim_val[:samples_in_batch])
            if self.accelerator.is_main_process:
                rec_paths = [
                    os.path.join(
                        save_dir,
                        f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png",
                    )
                    for j in range(recs.size(0))
                ]
                save_img_batch(recs.cpu(), rec_paths)
            total_processed += samples_in_batch
        self.accelerator.wait_for_everyone()
        return torch.cat(psnr_values).mean(), torch.cat(ssim_values).mean()

    def _log_fid_metrics(self, real_dir, rec_dir, cfg_value, psnr_val, ssim_val, use_ema):
        """On the main process, compute FID/IS for ``rec_dir`` and log them with PSNR/SSIM."""
        if not self.accelerator.is_main_process:
            return
        metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
        fid = metrics_dict["frechet_inception_distance"]
        inception_score = metrics_dict["inception_score_mean"]
        metric_prefix = "fid_ema" if use_ema else "fid"
        isc_prefix = "isc_ema" if use_ema else "isc"
        self.accelerator.log({
            metric_prefix: fid,
            isc_prefix: inception_score,
            f"psnr_{'ema' if use_ema else 'test'}": psnr_val,
            f"ssim_{'ema' if use_ema else 'test'}": ssim_val,
            "cfg": cfg_value,
        }, step=self.steps)
        print(f"{'EMA ' if use_ema else ''}CFG: {cfg_value} "
              f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
              f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")