import os, torch
import os.path as osp
import cv2
import shutil
import numpy as np
import copy
import torch_fidelity
import torch.nn as nn
from tqdm.auto import tqdm
from collections import OrderedDict
from einops import rearrange
from accelerate import Accelerator
from .util import instantiate_from_config
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler
from paintmind.utils.lr_scheduler import build_scheduler
from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from torch.optim import AdamW
from concurrent.futures import ThreadPoolExecutor
from torchmetrics.functional.image import (
    peak_signal_noise_ratio as psnr,
    structural_similarity_index_measure as ssim,
)

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

def save_img(img, save_path):
    img = np.clip(img.numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]
    cv2.imwrite(save_path, img)

def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    imgs = np.clip(imgs.numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]

    with ThreadPoolExecutor(max_workers=32) as pool:
        futures = [pool.submit(cv2.imwrite, path, img)
                   for path, img in zip(save_paths, imgs)]
        for future in futures:
            future.result()

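# NOTE on the two savers above: both appear to expect float tensors in [0, 1]
# (hence the * 255 and the clipping), and the final channel reversal converts
# RGB to BGR because cv2.imwrite expects BGR channel order.
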
def get_fid_stats(real_dir, rec_dir, fid_stats):
    stats = torch_fidelity.calculate_metrics(
        input1=rec_dir,
        input2=real_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats

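# NOTE: the dict returned by get_fid_stats is consumed downstream (see
# calculate_and_log_metrics inside DiffusionTrainer.evaluate) via the keys
# "frechet_inception_distance" and "inception_score_mean".
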
class EMAModel:
    """Model Exponential Moving Average."""

    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )

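# NOTE: EMAModel.update() implements the usual exponential moving average
#   ema = decay * ema + (1 - decay) * param
# via Tensor.lerp_(param, 1 - decay). DiffusionTrainer below keeps one EMA copy
# when enable_ema is set, registers it with accelerate for checkpointing, and
# calls update() once per training step.
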
class DiffusionTrainer(nn.Module):
    def __init__(
        self,
        model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_epoch=400,
        valid_size=32,
        lr=None,
        blr=1e-4,
        cosine_lr=True,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        eval_bs=32,
        test_bs=64,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.0,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        use_multi_epochs_dataloader=False,
        compile=False,
        overfit=False,
        making_cache=False,
        cache_mode=False,
        latent_cache_file=None,
    ):
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.model = instantiate_from_config(model)
        self.num_slots = model.params.num_slots

        assert precision in ["bf16", "fp32"]
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Overriding the specified precision and using autocast bf16...")
        self.precision = precision

        if test_dataset is not None:
            test_dataset = instantiate_from_config(test_dataset)
            self.test_ds = test_dataset

            # Pad the test set so that every process sees full batches of size test_bs.
            total_size = len(test_dataset)
            world_size = self.accelerator.num_processes
            padding_size = world_size * test_bs - (total_size % (world_size * test_bs))
            self.test_dataset_size = total_size

            class PaddedDataset(torch.utils.data.Dataset):
                def __init__(self, dataset, padding_size):
                    self.dataset = dataset
                    self.padding_size = padding_size

                def __len__(self):
                    return len(self.dataset) + self.padding_size

                def __getitem__(self, idx):
                    if idx < len(self.dataset):
                        return self.dataset[idx]
                    return self.dataset[0]

            self.test_ds = PaddedDataset(self.test_ds, padding_size)
            self.test_dl = DataLoader(
                self.test_ds,
                batch_size=test_bs,
                num_workers=num_workers,
                pin_memory=pin_memory,
                shuffle=False,
                drop_last=True,
            )
            if self.accelerator.is_main_process:
                print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
        else:
            self.test_dl = None
        self.test_only = test_only

        if not test_only:
            dataset = instantiate_from_config(dataset)
            train_size = len(dataset) - valid_size
            self.train_ds, self.valid_ds = random_split(
                dataset,
                [train_size, valid_size],
                generator=torch.Generator().manual_seed(42),
            )
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")

            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
            self.valid_dl = DataLoader(
                self.valid_ds,
                batch_size=eval_bs,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )

            # Linear lr scaling rule: lr = blr * effective_batch_size / 256.
            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            params = filter(lambda p: p.requires_grad, self.model.parameters())
            self.g_optim = AdamW(params, lr=lr, betas=(0.9, 0.95), weight_decay=0)
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps
            )
            if self.g_sched is not None:
                self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        if not test_only:
            self.model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.model, self.g_optim, self.g_sched
            )
        else:
            self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)

        if compile:
            _model = self.accelerator.unwrap_model(self.model)
            _model.vae = torch.compile(_model.vae, mode="reduce-overhead")
            _model.dit = torch.compile(_model.dit, mode="reduce-overhead")
            _model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm

        self.cache_mode = cache_mode

        self.cfg = cfg
        self.test_num_slots = test_num_slots
        if self.test_num_slots is not None:
            self.test_num_slots = min(self.test_num_slots, self.num_slots)
        else:
            self.test_num_slots = self.num_slots

        eval_fid = eval_fid or model.params.eval_fid
        self.eval_fid = eval_fid
        if eval_fid:
            if fid_stats is None:
                fid_stats = model.params.fid_stats
            assert fid_stats is not None
            assert test_dataset is not None
        self.fid_stats = fid_stats

        self.use_vq = model.params.use_vq if hasattr(model.params, "use_vq") else False
        self.vq_beta = model.params.code_beta if hasattr(model.params, "code_beta") else 0.25

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)

        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

    @property
    def device(self):
        return self.accelerator.device

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * len(self.train_dl)
        else:
            assert warmup_steps is not None

        scheduler = build_scheduler(
            self.g_optim,
            num_epoch,
            len(self.train_dl),
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,
        )
        return scheduler

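    # NOTE: when warmup_epochs is given it takes precedence over warmup_steps and is
    # converted to optimizer steps using len(self.train_dl) as the steps-per-epoch;
    # otherwise an explicit warmup_steps must be supplied.
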
    def _load_state_dict(self, state_dict):
        """Helper to load a state dict with proper prefix handling."""
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Strip the torch.compile wrapper prefix if present.
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = self.accelerator.unwrap_model(self.model).load_state_dict(
            state_dict, strict=False
        )
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

    def _load_safetensors(self, path):
        """Helper to load a safetensors checkpoint."""
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict)

    def _load_checkpoint(self, ckpt_path=None):
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if osp.isdir(ckpt_path):
            # ckpt_path is an accelerate save_state folder, e.g. ".../models/step250000".
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        self._load_state_dict(state_dict)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        self._load_safetensors(model_path)
        else:
            # ckpt_path is a single checkpoint file.
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path)
            else:
                state_dict = torch.load(ckpt_path)
                self._load_state_dict(state_dict)
        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

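    # NOTE: directory checkpoints are expected to follow the "step{N}" naming produced
    # by save() below, so loaded_steps can be parsed from the folder name
    # (e.g. ".../models/step250000" -> 250000); single-file checkpoints only restore
    # weights and leave loaded_steps at -1.
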
    def train(self, config=None):
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters / 1e6:.2f}M")
        if config is not None:
            # Save a copy of the config next to the results for reproducibility.
            import shutil
            from omegaconf import OmegaConf

            if isinstance(config, str) and osp.exists(config):
                # config is a path to a yaml file
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                # config is an OmegaConf object
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

        self.accelerator.init_trackers("vqgan")

        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        for epoch in range(self.num_epoch):
            if ((epoch + 1) * len(self.train_dl)) <= self.loaded_steps:
                if self.accelerator.is_main_process:
                    print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
                self.steps += len(self.train_dl)
                continue

            if self.steps < self.loaded_steps:
                for _ in self.train_dl:
                    self.steps += 1
                    if self.steps >= self.loaded_steps:
                        break

            self.accelerator.unwrap_model(self.model).current_epoch = epoch
            self.model.train()

            logger = MetricLogger(delimiter=" ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
            header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
            print_freq = 20
            for data_iter_step, batch in enumerate(logger.log_every(self.train_dl, print_freq, header)):
                if isinstance(batch, tuple) or isinstance(batch, list):
                    batch = tuple(b.to(self.device, non_blocking=True) for b in batch)
                    if self.cache_mode:
                        img, latent, targets = batch[0], batch[1], batch[2]
                        img = img.to(self.device, non_blocking=True)
                        latent = latent.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                    else:
                        latent = None
                        img, targets = batch[0], batch[1]
                        img = img.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                else:
                    img = batch.to(self.device, non_blocking=True)
                    latent = None
                    targets = None

                self.steps += 1

                with self.accelerator.accumulate(self.model):
                    with self.accelerator.autocast():
                        if self.steps == 1:
                            print(f"Training batch size: {img.size(0)}")
                            print(f"Hello from index {self.accelerator.local_process_index}")
                        losses = self.model(img, targets, latents=latent, epoch=epoch)
                        # combine the individual loss terms
                        loss = sum([v for _, v in losses.items()])
                        diff_loss = losses["diff_loss"]
                        if self.use_vq:
                            vq_loss = losses["vq_loss"]

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    self.accelerator.unwrap_model(self.model).cancel_gradients_encoder(epoch)
                    self.g_optim.step()
                    if self.g_sched is not None:
                        self.g_sched.step_update(self.steps)
                    self.g_optim.zero_grad()

                if self.enable_ema:
                    self.ema_model.update(self.accelerator.unwrap_model(self.model))

                logger.update(diff_loss=diff_loss.item())
                if self.use_vq:
                    logger.update(vq_loss=vq_loss.item() / self.vq_beta)
                if 'kl_loss' in losses:
                    logger.update(kl_loss=losses["kl_loss"].item())
                if 'repa_loss' in losses:
                    logger.update(repa_loss=losses["repa_loss"].item())
                logger.update(lr=self.g_optim.param_groups[0]["lr"])

                if self.steps % self.save_every == 0:
                    self.save()

                if (self.steps % self.samp_every == 0) or (self.steps % self.fid_every == 0):
                    empty_cache()
                    self.evaluate()
                    self.accelerator.wait_for_everyone()
                    empty_cache()

            logger.synchronize_between_processes()
            if self.accelerator.is_main_process:
                print("Averaged stats:", logger)

        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

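    # NOTE: accelerator.save_state() stores everything registered with the Accelerator
    # (prepared model and optimizer, plus the scheduler and EMA weights registered via
    # register_for_checkpointing) under models/step{N}, which _load_checkpoint above
    # can restore as a directory checkpoint.
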
    @torch.no_grad()
    def evaluate(self, use_ema=True):
        self.model.eval()

        # Only swap in the EMA weights for mid-training FID evaluation.
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only

        if use_ema:
            if hasattr(self, "ema_model"):
                model_without_ddp = self.accelerator.unwrap_model(self.model)
                model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
                    if "nested_sampler" in name:
                        continue
                    if name in self.ema_model.state_dict():
                        ema_state_dict[name] = self.ema_model.state_dict()[name]
                if self.accelerator.is_main_process:
                    print("Switch to ema")
                msg = model_without_ddp.load_state_dict(ema_state_dict)
                if self.accelerator.is_main_process:
                    print(msg)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        if not self.test_only:
            with tqdm(
                self.valid_dl,
                dynamic_ncols=True,
                disable=not self.accelerator.is_main_process,
            ) as valid_dl:
                for batch_i, batch in enumerate(valid_dl):
                    if isinstance(batch, tuple) or isinstance(batch, list):
                        img, targets = batch[0], batch[1]
                    else:
                        img = batch
                        targets = None

                    with self.accelerator.autocast():
                        rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)

                    imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
                    imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
                    imgs_and_recs = imgs_and_recs.detach().cpu().float()

                    grid = make_grid(
                        imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
                    )
                    if self.accelerator.is_main_process:
                        save_image(
                            grid,
                            os.path.join(
                                self.image_saved_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg"
                            ),
                        )

                    if self.cfg != 1.0:
                        with self.accelerator.autocast():
                            rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)

                        imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
                        imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
                        imgs_and_recs = imgs_and_recs.detach().cpu().float()

                        grid = make_grid(
                            imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
                        )
                        if self.accelerator.is_main_process:
                            save_image(
                                grid,
                                os.path.join(
                                    self.image_saved_dir, f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg"
                                ),
                            )

        if (self.eval_fid and self.test_dl is not None) and (self.test_only or (self.steps % self.fid_every == 0)):
            # Pick the real-image folder based on test set size (ImageNet val vs. COCO val).
            if self.test_dataset_size > 10000:
                real_dir = "./dataset/imagenet/val256"
            else:
                real_dir = "./dataset/coco/val2017_256"
            rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
            os.makedirs(rec_dir, exist_ok=True)

            if self.cfg != 1.0:
                rec_cfg_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}")
                os.makedirs(rec_cfg_dir, exist_ok=True)

            def process_batch(cfg_value, save_dir, header):
                logger = MetricLogger(delimiter=" ")
                print_freq = 5
                psnr_values = []
                ssim_values = []
                total_processed = 0

                for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
                    imgs, targets = (batch[0], batch[1]) if isinstance(batch, (tuple, list)) else (batch, None)

                    # Stop once every real test sample has been processed (the rest is padding).
                    if total_processed >= self.test_dataset_size:
                        break

                    imgs = imgs.to(self.device, non_blocking=True)
                    if targets is not None:
                        targets = targets.to(self.device, non_blocking=True)

                    with self.accelerator.autocast():
                        recs = self.model(imgs, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)

                    psnr_val = psnr(recs, imgs, data_range=1.0)
                    ssim_val = ssim(recs, imgs, data_range=1.0)

                    recs = concat_all_gather(recs).detach()
                    psnr_val = concat_all_gather(psnr_val.view(1))
                    ssim_val = concat_all_gather(ssim_val.view(1))

                    # Drop samples that came from dataset padding in the last batches.
                    samples_in_batch = min(
                        recs.size(0),
                        self.test_dataset_size - total_processed
                    )
                    recs = recs[:samples_in_batch]
                    psnr_val = psnr_val[:samples_in_batch]
                    ssim_val = ssim_val[:samples_in_batch]
                    psnr_values.append(psnr_val)
                    ssim_values.append(ssim_val)

                    if self.accelerator.is_main_process:
                        rec_paths = [os.path.join(save_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png")
                                     for j in range(recs.size(0))]
                        save_img_batch(recs.cpu(), rec_paths)

                    total_processed += samples_in_batch

                    self.accelerator.wait_for_everyone()

                return torch.cat(psnr_values).mean(), torch.cat(ssim_values).mean()

            def calculate_and_log_metrics(real_dir, rec_dir, cfg_value, psnr_val, ssim_val):
                if self.accelerator.is_main_process:
                    metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
                    fid = metrics_dict["frechet_inception_distance"]
                    inception_score = metrics_dict["inception_score_mean"]

                    metric_prefix = "fid_ema" if use_ema else "fid"
                    isc_prefix = "isc_ema" if use_ema else "isc"
                    self.accelerator.log({
                        metric_prefix: fid,
                        isc_prefix: inception_score,
                        f"psnr_{'ema' if use_ema else 'test'}": psnr_val,
                        f"ssim_{'ema' if use_ema else 'test'}": ssim_val,
                        "cfg": cfg_value
                    }, step=self.steps)

                    print(f"{'EMA ' if use_ema else ''}CFG: {cfg_value} "
                          f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
                          f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")

            # Evaluate without CFG (always during training; only if cfg == 1.0 in test-only mode).
            if self.cfg == 1.0 or not self.test_only:
                psnr_val, ssim_val = process_batch(1.0, rec_dir, 'Testing: w/o CFG')
                calculate_and_log_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val)

            # Evaluate with CFG.
            if self.cfg != 1.0:
                psnr_val, ssim_val = process_batch(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
                calculate_and_log_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val)

            # Remove the temporary reconstruction folders once the metrics are computed.
            if self.accelerator.is_main_process:
                shutil.rmtree(rec_dir)
                if self.cfg != 1.0:
                    shutil.rmtree(rec_cfg_dir)

        # Restore the original (non-EMA) weights.
        if use_ema:
            if self.accelerator.is_main_process:
                print("Switch back from ema")
            model_without_ddp.load_state_dict(model_state_dict)

        self.model.train()