|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import math |
|
from tqdm.auto import tqdm |
|
import torch |
|
import torch.nn as nn |
|
import torchvision |
|
import numpy as np |
|
from torchvision.utils import make_grid |
|
from einops import rearrange, repeat |
|
from accelerate.logging import get_logger |
|
from taming.modules.losses.vqperceptual import hinge_d_loss |
|
|
|
from .base_trainer import Trainer |
|
from lam.utils.profiler import DummyProfiler |
|
from lam.runners import REGISTRY_RUNNERS |
|
from lam.utils.hf_hub import wrap_model_hub |
|
from safetensors.torch import load_file |
|
from pytorch3d.ops.knn import knn_points |
|
import torch.nn.functional as F |
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
from omegaconf import OmegaConf |
|
@REGISTRY_RUNNERS.register('train.lam') |
|
class LAMTrainer(Trainer): |
|
|
|
EXP_TYPE: str = 'lam' |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.model = self._build_model(self.cfg) |
|
if self.has_disc: |
|
self.model_disc = self._build_model_disc(self.cfg) |
|
self.optimizer = self._build_optimizer(self.model, self.cfg) |
|
if self.has_disc: |
|
self.optimizer_disc = self._build_optimizer(self.model_disc, self.cfg) |
|
|
|
self.train_loader, self.val_loader = self._build_dataloader(self.cfg) |
|
self.scheduler = self._build_scheduler(self.optimizer, self.cfg) |
|
if self.has_disc: |
|
self.scheduler_disc = self._build_scheduler(self.optimizer_disc, self.cfg) |
|
self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg) |
|
self.only_sym_conf = 2 |
|
print("==="*16*3, "\n"+"only_sym_conf:", self.only_sym_conf, "\n"+"==="*16*3) |
|
|
|
|
|
def _build_model(self, cfg): |
|
assert cfg.experiment.type == 'lrm', \ |
|
f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}" |
|
from lam.models import ModelLAM |
|
model = ModelLAM(**cfg.model) |
|
|
|
|
|
if len(self.cfg.train.resume) > 0: |
|
resume = self.cfg.train.resume |
|
print("==="*16*3) |
|
self.accelerator.print("loading pretrained weight from:", resume) |
|
if resume.endswith('safetensors'): |
|
ckpt = load_file(resume, device='cpu') |
|
else: |
|
ckpt = torch.load(resume, map_location='cpu') |
|
state_dict = model.state_dict() |
|
for k, v in ckpt.items(): |
|
if k in state_dict: |
|
if state_dict[k].shape == v.shape: |
|
state_dict[k].copy_(v) |
|
else: |
|
self.accelerator.print(f"WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.") |
|
else: |
|
self.accelerator.print(f"WARN] unexpected param {k}: {v.shape}") |
|
self.accelerator.print("Finish loading ckpt:", resume, "\n"+"==="*16*3) |
|
return model |
|
|
|
def _build_model_disc(self, cfg): |
|
if cfg.model.disc.type == "pix2pix": |
|
from lam.models.discriminator import NLayerDiscriminator, weights_init |
|
model = NLayerDiscriminator(input_nc=cfg.model.disc.in_channels, |
|
n_layers=cfg.model.disc.num_layers, |
|
use_actnorm=cfg.model.disc.use_actnorm |
|
).apply(weights_init) |
|
|
|
elif cfg.model.disc.type == "vqgan": |
|
from lam.models.discriminator import Discriminator |
|
model = Discriminator(in_channels=cfg.model.disc.in_channels, |
|
cond_channels=0, hidden_channels=512, |
|
depth=cfg.model.disc.depth) |
|
elif cfg.model.disc.type == "stylegan": |
|
from lam.models.gan.stylegan_discriminator import SingleDiscriminatorV2, SingleDiscriminator |
|
from lam.models.gan.stylegan_discriminator_torch import Discriminator |
|
|
|
model = Discriminator(512, channel_multiplier=2) |
|
|
|
model.input_size = cfg.model.disc.img_res |
|
else: |
|
raise NotImplementedError |
|
return model |
|
|
|
def _build_optimizer(self, model: nn.Module, cfg): |
|
decay_params, no_decay_params = [], [] |
|
|
|
|
|
for name, module in model.named_modules(): |
|
if isinstance(module, nn.LayerNorm): |
|
no_decay_params.extend([p for p in module.parameters()]) |
|
elif hasattr(module, 'bias') and module.bias is not None: |
|
no_decay_params.append(module.bias) |
|
|
|
|
|
_no_decay_ids = set(map(id, no_decay_params)) |
|
decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids] |
|
|
|
|
|
decay_params = list(filter(lambda p: p.requires_grad, decay_params)) |
|
no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params)) |
|
|
|
|
|
logger.info("======== Weight Decay Parameters ========") |
|
logger.info(f"Total: {len(decay_params)}") |
|
logger.info("======== No Weight Decay Parameters ========") |
|
logger.info(f"Total: {len(no_decay_params)}") |
|
|
|
|
|
opt_groups = [ |
|
{'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay}, |
|
{'params': no_decay_params, 'weight_decay': 0.0}, |
|
] |
|
optimizer = torch.optim.AdamW( |
|
opt_groups, |
|
lr=cfg.train.optim.lr, |
|
betas=(cfg.train.optim.beta1, cfg.train.optim.beta2), |
|
) |
|
|
|
return optimizer |
|
|
|
def _build_scheduler(self, optimizer, cfg): |
|
local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes) |
|
total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps) |
|
effective_warmup_iters = cfg.train.scheduler.warmup_real_iters |
|
logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========") |
|
logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========") |
|
if cfg.train.scheduler.type == 'cosine': |
|
from lam.utils.scheduler import CosineWarmupScheduler |
|
scheduler = CosineWarmupScheduler( |
|
optimizer=optimizer, |
|
warmup_iters=effective_warmup_iters, |
|
max_iters=total_global_batches, |
|
) |
|
else: |
|
raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented") |
|
return scheduler |
|
|
|
def _build_dataloader(self, cfg): |
|
|
|
from lam.datasets import MixerDataset |
|
gaga_track_type = cfg.dataset.get("gaga_track_type", "vfhq_gagtrack") |
|
sample_aug_views = cfg.dataset.get("sample_aug_views", 0) |
|
|
|
|
|
load_normal = cfg.train.loss.get("normal_weight", False) > 0. if hasattr(cfg.train.loss, "normal_weight") else False |
|
load_normal = load_normal or (cfg.train.loss.get("surfel_normal_weight", False) > 0. if hasattr(cfg.train.loss, "surfel_normal_weight") else False) |
|
print("==="*16*3, "\nload_normal:", load_normal) |
|
train_dataset = MixerDataset( |
|
split="train", |
|
subsets=cfg.dataset.subsets, |
|
sample_side_views=cfg.dataset.sample_side_views, |
|
render_image_res_low=cfg.dataset.render_image.low, |
|
render_image_res_high=cfg.dataset.render_image.high, |
|
render_region_size=cfg.dataset.render_image.region, |
|
source_image_res=cfg.dataset.source_image_res, |
|
repeat_num=cfg.dataset.repeat_num if hasattr(cfg.dataset, "repeat_num") else 1, |
|
multiply=cfg.dataset.multiply if hasattr(cfg.dataset, "multiply") else 14, |
|
debug=cfg.dataset.debug if hasattr(cfg.dataset, "debug") else False, |
|
is_val=False, |
|
gaga_track_type=gaga_track_type, |
|
sample_aug_views=sample_aug_views, |
|
load_albedo=cfg.model.get("render_albedo", False) if hasattr(cfg.model, "render_albedo") else False, |
|
load_normal=load_normal, |
|
) |
|
val_dataset = MixerDataset( |
|
split="val", |
|
subsets=cfg.dataset.subsets, |
|
sample_side_views=cfg.dataset.sample_side_views, |
|
render_image_res_low=cfg.dataset.render_image.low, |
|
render_image_res_high=cfg.dataset.render_image.high, |
|
render_region_size=cfg.dataset.render_image.region, |
|
source_image_res=cfg.dataset.source_image_res, |
|
repeat_num=cfg.dataset.repeat_num if hasattr(cfg.dataset, "repeat_num") else 1, |
|
multiply=cfg.dataset.multiply if hasattr(cfg.dataset, "multiply") else 14, |
|
debug=cfg.dataset.debug if hasattr(cfg.dataset, "debug") else False, |
|
is_val=True, |
|
gaga_track_type=gaga_track_type, |
|
sample_aug_views=sample_aug_views, |
|
load_albedo=cfg.model.get("render_albedo", False) if hasattr(cfg.model, "render_albedo") else False, |
|
load_normal=load_normal, |
|
) |
|
|
|
|
|
train_loader = torch.utils.data.DataLoader( |
|
train_dataset, |
|
batch_size=cfg.train.batch_size, |
|
shuffle=True, |
|
drop_last=True, |
|
num_workers=cfg.dataset.num_train_workers, |
|
pin_memory=cfg.dataset.pin_mem, |
|
persistent_workers=True, |
|
) |
|
val_loader = torch.utils.data.DataLoader( |
|
val_dataset, |
|
batch_size=cfg.val.batch_size, |
|
shuffle=False, |
|
drop_last=False, |
|
num_workers=cfg.dataset.num_val_workers, |
|
pin_memory=cfg.dataset.pin_mem, |
|
persistent_workers=False, |
|
) |
|
|
|
return train_loader, val_loader |
|
|
|
def _build_loss_fn(self, cfg): |
|
from lam.losses import PixelLoss, LPIPSLoss, TVLoss |
|
pixel_loss_fn = PixelLoss(option=cfg.train.loss.get("pixel_loss_fn", "mse")) |
|
with self.accelerator.main_process_first(): |
|
perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True) |
|
|
|
if cfg.model.get("use_conf_map", False): |
|
assert cfg.train.loss.get("head_pl", False), "Set head_pl in train.loss to true to use faceperceptualloss when using conf_map." |
|
tv_loss_fn = TVLoss() |
|
return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn |
|
|
|
def register_hooks(self): |
|
pass |
|
|
|
def get_flame_params(self, data, is_source=False): |
|
flame_params = {} |
|
flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas',\ |
|
'rotation', 'neck_pose', 'eyes_pose', 'translation', "teeth_bs"] |
|
if is_source: |
|
flame_keys = ['source_'+item for item in flame_keys] |
|
for k, v in data.items(): |
|
if k in flame_keys: |
|
|
|
flame_params[k] = data[k] |
|
return flame_params |
|
|
|
def cross_copy(self, data): |
|
B = data.shape[0] |
|
assert data.shape[1] == 1 |
|
new_data = [] |
|
for i in range(B): |
|
B_i = [data[i]] |
|
for j in range(B): |
|
if j != i: |
|
B_i.append(data[j]) |
|
new_data.append(torch.concat(B_i, dim=0)) |
|
new_data = torch.stack(new_data, dim=0) |
|
|
|
return new_data |
|
|
|
def prepare_cross_render_data(self, data): |
|
B, N_v, C, H, W = data['render_image'].shape |
|
assert N_v == 1 |
|
|
|
|
|
data["c2ws"] = self.cross_copy(data["c2ws"]) |
|
data["intrs"] = self.cross_copy(data["intrs"]) |
|
data["render_full_resolutions"] = self.cross_copy(data["render_full_resolutions"]) |
|
data["render_image"] = self.cross_copy(data["render_image"]) |
|
data["render_mask"] = self.cross_copy(data["render_mask"]) |
|
data["render_bg_colors"] = self.cross_copy(data["render_bg_colors"]) |
|
flame_params = self.get_flame_params(data) |
|
for key in flame_params.keys(): |
|
if "betas" not in key: |
|
data[key] = self.cross_copy(data[key]) |
|
source_flame_params = self.get_flame_params(data, is_source=True) |
|
for key in source_flame_params.keys(): |
|
if "betas" not in key: |
|
data[key] = self.cross_copy(data[key]) |
|
|
|
return data |
|
|
|
def get_loss_weight(self, loss_weight): |
|
if isinstance(loss_weight, str) and ":" in loss_weight: |
|
start_step, start_value, end_value, end_step = map(float, loss_weight.split(":")) |
|
current_step = self.global_step |
|
value = start_value + (end_value - start_value) * max( |
|
min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 |
|
) |
|
return value |
|
elif isinstance(loss_weight, (float, int)): |
|
return loss_weight |
|
else: |
|
raise NotImplementedError |
|
|
|
def forward_loss_local_step(self, data): |
|
render_image = data['render_image'] |
|
render_albedo = data.get('render_albedo', None) |
|
render_mask = data['render_mask'] |
|
render_normal = data.get('render_normal', None) |
|
B, N_v, C, H, W = render_image.shape |
|
flame_params = self.get_flame_params(data) |
|
source_flame_params = self.get_flame_params(data, is_source=True) |
|
|
|
|
|
outputs = self.model( |
|
image=data['source_rgbs'], |
|
source_c2ws=data['source_c2ws'], |
|
source_intrs=data['source_intrs'], |
|
render_c2ws=data['c2ws'], |
|
render_intrs=data['intrs'], |
|
render_bg_colors=data['render_bg_colors'], |
|
flame_params=flame_params, |
|
source_flame_params=source_flame_params, |
|
render_images=render_image, |
|
data = data |
|
) |
|
|
|
|
|
loss = 0. |
|
loss_pixel = None |
|
loss_perceptual = None |
|
loss_mask = None |
|
extra_loss_dict = {} |
|
|
|
num_aug_view = self.cfg.dataset.get("sample_aug_views", 0) |
|
real_num_view = data["real_num_view"] - num_aug_view |
|
|
|
conf_sigma_l1 = outputs.get("conf_sigma_l1", None) |
|
conf_sigma_percl = outputs.get("conf_sigma_percl", None) |
|
if self.cfg.model.use_sym_proj: |
|
real_num_view *= 2 |
|
if self.cfg.model.use_conf_map: |
|
conf_sigma_l1 = rearrange(conf_sigma_l1, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view] |
|
conf_sigma_percl = rearrange(conf_sigma_percl, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view] |
|
render_image = repeat(data['render_image'], "b v c h w -> b (v r) c h w", r=2) |
|
render_albedo = repeat(render_albedo, "b v c h w -> b (v r) c h w", r=2) if render_albedo is not None else None |
|
render_mask = repeat(data['render_mask'], "b v c h w -> b (v r) c h w", r=2) |
|
if "render_normal" in data.keys(): |
|
render_normal = repeat(data['render_normal'], "b v c h w -> b (v r) c h w", r=2) |
|
for k, v in data.items(): |
|
if "bbox" in k: |
|
data[k] = repeat(v, "b v c -> b (v r) c", r=2) |
|
|
|
only_sym_conf = self.only_sym_conf |
|
|
|
if self.get_loss_weight(self.cfg.train.loss.get("masked_pixel_weight", 0)) > 0.: |
|
gt_rgb = render_image[:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view]) |
|
pred_rgb = outputs['comp_rgb'][:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view]) |
|
|
|
loss_pixel = self.pixel_loss_fn(pred_rgb, gt_rgb, conf_sigma_l1, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight) |
|
loss += loss_pixel |
|
|
|
|
|
loss_perceptual = self.perceptual_loss_fn(pred_rgb, gt_rgb, conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight) |
|
loss += loss_perceptual |
|
|
|
if self.get_loss_weight(self.cfg.train.loss.pixel_weight) > 0.: |
|
total_loss_pixel = loss_pixel |
|
if (hasattr(self.cfg.train.loss, 'rgb_weight') and self.get_loss_weight(self.cfg.train.loss.rgb_weight) > 0.) or not hasattr(self.cfg.train.loss, "rgb_weight"): |
|
loss_pixel = self.pixel_loss_fn( |
|
outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf |
|
) * self.get_loss_weight(self.cfg.train.loss.pixel_weight) |
|
loss += loss_pixel |
|
if total_loss_pixel is not None: |
|
loss_pixel += total_loss_pixel |
|
|
|
if self.get_loss_weight(self.cfg.train.loss.perceptual_weight) > 0.: |
|
total_loss_perceptual = loss_perceptual |
|
if (hasattr(self.cfg.train.loss, 'rgb_weight') and self.get_loss_weight(self.cfg.train.loss.rgb_weight) > 0.) or not hasattr(self.cfg.train.loss, "rgb_weight"): |
|
loss_perceptual = self.perceptual_loss_fn( |
|
outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf |
|
) * self.get_loss_weight(self.cfg.train.loss.perceptual_weight) |
|
loss += loss_perceptual |
|
if total_loss_perceptual is not None: |
|
loss_perceptual += total_loss_perceptual |
|
|
|
if self.get_loss_weight(self.cfg.train.loss.mask_weight) > 0. and 'comp_mask' in outputs.keys(): |
|
loss_mask = self.pixel_loss_fn(outputs['comp_mask'][:, :real_num_view], render_mask[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf |
|
) * self.get_loss_weight(self.cfg.train.loss.mask_weight) |
|
loss += loss_mask |
|
|
|
if hasattr(self.cfg.train.loss, 'offset_reg_weight') and self.get_loss_weight(self.cfg.train.loss.offset_reg_weight) > 0.: |
|
loss_offset_reg = 0 |
|
for b_idx in range(len(outputs['3dgs'])): |
|
loss_offset_reg += torch.nn.functional.mse_loss(outputs['3dgs'][b_idx][0].offset.float(), torch.zeros_like(outputs['3dgs'][b_idx][0].offset.float())) |
|
loss_offset_reg = loss_offset_reg / len(outputs['3dgs']) |
|
loss += loss_offset_reg * self.get_loss_weight(self.cfg.train.loss.offset_reg_weight) |
|
else: |
|
loss_offset_reg = None |
|
|
|
return outputs, loss, loss_pixel, loss_perceptual, loss_offset_reg, loss_mask, extra_loss_dict |
|
|
|
def adopt_weight(self, weight, global_step, threshold=0, value=0.): |
|
if global_step < threshold: |
|
weight = value |
|
return weight |
|
|
|
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer, discriminator_weight=1): |
|
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] |
|
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] |
|
|
|
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) |
|
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() |
|
d_weight = d_weight * discriminator_weight |
|
return d_weight |
|
|
|
def disc_preprocess(self, img): |
|
|
|
img = torch.flatten(img, 0, 1) |
|
|
|
|
|
img = 2 * img - 1 |
|
|
|
if hasattr(self.accelerator.unwrap_model(self.model_disc), "input_size"): |
|
tgt_size = self.accelerator.unwrap_model(self.model_disc).input_size |
|
img = nn.functional.interpolate(img, (tgt_size, tgt_size)) |
|
img = img.float() |
|
|
|
return img |
|
|
|
def forward_to_get_loss_with_gen_loss(self, data): |
|
|
|
outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data) |
|
|
|
with torch.autocast(device_type=outs["comp_rgb"].device.type, dtype=torch.float32): |
|
logits_fake = self.model_disc(self.disc_preprocess(outs["comp_rgb"])) |
|
|
|
loss_gen = -torch.mean(logits_fake) |
|
|
|
try: |
|
if loss < 1e-5: |
|
d_weight = self.cfg.model.disc.disc_weight |
|
else: |
|
nll_loss = loss_pixel |
|
if nll_loss is None: |
|
nll_loss = loss |
|
d_weight = self.calculate_adaptive_weight(nll_loss, loss_gen, |
|
last_layer=self.accelerator.unwrap_model(self.model).get_last_layer(), |
|
discriminator_weight=self.cfg.model.disc.disc_weight) |
|
except RuntimeError: |
|
print("*************Error when calculate_adaptive_weight************") |
|
d_weight = torch.tensor(0.0) |
|
|
|
disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start) |
|
|
|
|
|
loss += disc_factor * d_weight * loss_gen |
|
|
|
|
|
self.accelerator.backward(loss) |
|
if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.: |
|
self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm) |
|
|
|
self.optimizer.step() |
|
self.optimizer.zero_grad() |
|
|
|
return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict |
|
|
|
|
|
def forward_to_get_loss(self, data): |
|
|
|
outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data) |
|
|
|
|
|
self.accelerator.backward(loss) |
|
if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.: |
|
self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm) |
|
|
|
self.optimizer.step() |
|
self.optimizer.zero_grad() |
|
|
|
return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict |
|
|
|
|
|
def forward_disc_loss_local_step(self, pred_img, gt_img): |
|
|
|
with torch.autocast(device_type=pred_img.device.type, dtype=torch.float32): |
|
logits_real = self.model_disc(self.disc_preprocess(gt_img).detach()) |
|
logits_fake = self.model_disc(self.disc_preprocess(pred_img).detach()) |
|
|
|
loss_disc = hinge_d_loss(logits_real, logits_fake) |
|
return loss_disc |
|
|
|
|
|
def forward_to_get_disc_loss(self, pred_img, gt_img): |
|
|
|
loss_disc = self.forward_disc_loss_local_step(pred_img, gt_img) |
|
|
|
disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start) |
|
loss = disc_factor * loss_disc |
|
|
|
|
|
self.accelerator.backward(loss) |
|
|
|
if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.: |
|
self.accelerator.clip_grad_norm_(self.model_disc.parameters(), self.cfg.train.optim.clip_grad_norm) |
|
|
|
self.optimizer_disc.step() |
|
self.optimizer_disc.zero_grad() |
|
|
|
return loss_disc |
|
|
|
def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile, iepoch: int): |
|
|
|
self.model.train() |
|
if self.has_disc: |
|
self.model_disc.train() |
|
|
|
local_step_losses = [] |
|
global_step_losses = [] |
|
local_step_extra_losses = [] |
|
global_step_extra_losses = [] |
|
extra_loss_keys = [] |
|
|
|
logger.debug(f"======== Starting epoch {self.current_epoch} ========") |
|
loss_disc = None |
|
for idx, data in enumerate(loader): |
|
data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype) |
|
if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render: |
|
data = self.prepare_cross_render_data(data) |
|
data["real_num_view"] = 1 |
|
else: |
|
data["real_num_view"] = data["render_image"].shape[1] |
|
|
|
logger.debug(f"======== Starting global step {self.global_step} ========") |
|
|
|
if not self.has_disc: |
|
disc_step = False |
|
with self.accelerator.accumulate(self.model): |
|
outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_to_get_loss(data) |
|
|
|
|
|
loss_disc, loss_gen = None, None |
|
local_step_losses.append(torch.stack([ |
|
_loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device) |
|
for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_disc, loss_gen] |
|
])) |
|
extra_loss_keys = sorted(list(extra_loss_dict.keys())) |
|
if len(extra_loss_keys) > 0: |
|
local_step_extra_losses.append(torch.stack([ |
|
extra_loss_dict[k].detach() if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device) |
|
for k in extra_loss_keys |
|
])) |
|
else: |
|
disc_step = (idx % 5) == 0 or (iepoch * len(loader) + idx < 100 and idx % 2 == 0) |
|
local_step_losses_bak = torch.zeros(6, device=data["source_rgbs"].device) |
|
if not disc_step: |
|
with self.accelerator.accumulate(self.model): |
|
|
|
outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict = self.forward_to_get_loss_with_gen_loss(data) |
|
|
|
local_step_losses.append(torch.stack([ |
|
_loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device) |
|
for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc] |
|
])) |
|
local_step_losses_bak = local_step_losses[-1].detach() |
|
torch.cuda.empty_cache() |
|
extra_loss_keys = sorted(list(extra_loss_dict.keys())) |
|
if len(extra_loss_keys) > 0: |
|
local_step_extra_losses.append(torch.stack([ |
|
extra_loss_dict[k].detach() if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device) |
|
for k in extra_loss_keys |
|
])) |
|
else: |
|
with self.accelerator.accumulate(self.model_disc): |
|
|
|
outs, _, _, _, _, _, _ = self.forward_loss_local_step(data) |
|
loss_disc = self.forward_to_get_disc_loss(pred_img=outs["comp_rgb"], |
|
gt_img=data["render_image"]) |
|
local_step_losses.append(torch.concat([local_step_losses_bak[:6], loss_disc.unsqueeze(0)], dim=0)) |
|
torch.cuda.empty_cache() |
|
|
|
|
|
if self.accelerator.sync_gradients: |
|
profiler.step() |
|
if not disc_step: |
|
self.scheduler.step() |
|
if self.has_disc and disc_step: |
|
self.scheduler_disc.step() |
|
logger.debug(f"======== Scheduler step ========") |
|
self.global_step += 1 |
|
global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu() |
|
if len(extra_loss_keys) > 0: |
|
global_step_extra_loss = self.accelerator.gather(torch.stack(local_step_extra_losses)).mean(dim=0).cpu() |
|
global_step_extra_loss_items = global_step_extra_loss.unbind() |
|
else: |
|
global_step_extra_loss = None |
|
global_step_extra_loss_items = [] |
|
loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc_ = global_step_loss.unbind() |
|
loss_kwargs = { |
|
'loss': loss.item(), |
|
'loss_pixel': loss_pixel.item(), |
|
'loss_perceptual': loss_perceptual.item(), |
|
'loss_tv': loss_tv.item(), |
|
'loss_mask': loss_mask.item(), |
|
'loss_disc': loss_disc_.item(), |
|
'loss_gen': loss_gen.item(), |
|
} |
|
for k, loss in zip(extra_loss_keys, global_step_extra_loss_items): |
|
loss_kwargs[k] = loss.item() |
|
self.log_scalar_kwargs( |
|
step=self.global_step, split='train', |
|
**loss_kwargs |
|
) |
|
self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1]) |
|
local_step_losses = [] |
|
global_step_losses.append(global_step_loss) |
|
local_step_extra_losses = [] |
|
global_step_extra_losses.append(global_step_extra_loss) |
|
|
|
|
|
pbar.update(1) |
|
description = { |
|
**loss_kwargs, |
|
'lr': self.optimizer.param_groups[0]['lr'], |
|
} |
|
description = '[TRAIN STEP]' + \ |
|
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v)) |
|
pbar.set_description(description) |
|
|
|
|
|
if self.global_step % self.cfg.saver.checkpoint_global_steps == 0: |
|
self.save_checkpoint() |
|
if self.global_step % self.cfg.val.global_step_period == 0: |
|
self.evaluate() |
|
self.model.train() |
|
if self.has_disc: |
|
self.model_disc.train() |
|
if (self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0) or (self.global_step < 1000 and self.global_step % 20 == 0): |
|
conf_sigma_l1 = outs.get('conf_sigma_l1', None) |
|
conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None |
|
conf_sigma_percl = outs.get('conf_sigma_percl', None) |
|
conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None |
|
self.log_image_monitor( |
|
step=self.global_step, split='train', |
|
renders=outs['comp_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl, |
|
gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
) |
|
if 'comp_mask' in outs.keys(): |
|
self.log_image_monitor( |
|
step=self.global_step, split='train', |
|
renders=outs['comp_mask'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
gts=data['render_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
prefix="_mask", |
|
) |
|
|
|
|
|
if self.global_step >= self.N_max_global_steps: |
|
self.accelerator.set_trigger() |
|
break |
|
|
|
|
|
self.current_epoch += 1 |
|
epoch_losses = torch.stack(global_step_losses).mean(dim=0) |
|
epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv, epoch_loss_mask, epoch_loss_disc, epoch_loss_gen = epoch_losses.unbind() |
|
epoch_loss_dict = { |
|
'loss': epoch_loss.item(), |
|
'loss_pixel': epoch_loss_pixel.item(), |
|
'loss_perceptual': epoch_loss_perceptual.item(), |
|
'loss_tv': epoch_loss_tv.item(), |
|
'loss_mask': epoch_loss_mask.item(), |
|
'loss_disc': epoch_loss_disc.item(), |
|
'loss_gen': epoch_loss_gen.item(), |
|
} |
|
if len(extra_loss_keys) > 0: |
|
epoch_extra_losses = torch.stack(global_step_extra_losses).mean(dim=0) |
|
for k, v in zip(extra_loss_keys, epoch_extra_losses.unbind()): |
|
epoch_loss_dict[k] = v.item() |
|
self.log_scalar_kwargs( |
|
epoch=self.current_epoch, split='train', |
|
**epoch_loss_dict, |
|
) |
|
logger.info( |
|
f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \ |
|
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v)) |
|
) |
|
|
|
def train(self): |
|
|
|
starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps |
|
skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch) |
|
logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========") |
|
|
|
with tqdm( |
|
range(0, self.N_max_global_steps), |
|
initial=self.global_step, |
|
disable=(not self.accelerator.is_main_process), |
|
) as pbar: |
|
|
|
profiler = torch.profiler.profile( |
|
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], |
|
schedule=torch.profiler.schedule( |
|
wait=10, warmup=10, active=100, |
|
), |
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join( |
|
self.cfg.logger.tracker_root, |
|
self.cfg.experiment.parent, self.cfg.experiment.child, |
|
)), |
|
record_shapes=True, |
|
profile_memory=True, |
|
with_stack=True, |
|
) if self.cfg.logger.enable_profiler else DummyProfiler() |
|
|
|
with profiler: |
|
self.optimizer.zero_grad() |
|
if self.has_disc: |
|
self.optimizer_disc.zero_grad() |
|
for iepoch in range(self.current_epoch, self.cfg.train.epochs): |
|
|
|
loader = skipped_loader or self.train_loader |
|
skipped_loader = None |
|
self.train_epoch(pbar=pbar, loader=loader, profiler=profiler, iepoch=iepoch) |
|
if self.accelerator.check_trigger(): |
|
break |
|
|
|
logger.info(f"======== Training finished at global step {self.global_step} ========") |
|
|
|
|
|
self.save_checkpoint() |
|
self.evaluate() |
|
|
|
@torch.no_grad() |
|
@torch.compiler.disable |
|
def evaluate(self, epoch: int = None): |
|
self.model.eval() |
|
|
|
max_val_batches = self.cfg.val.debug_batches or len(self.val_loader) |
|
running_losses = [] |
|
running_extra_losses = [] |
|
extra_loss_keys = [] |
|
sample_data, sample_outs = None, None |
|
|
|
for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches): |
|
data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype) |
|
if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render: |
|
data = self.prepare_cross_render_data(data) |
|
data["real_num_view"] = 1 |
|
else: |
|
data["real_num_view"] = data["render_image"].shape[1] |
|
|
|
if len(running_losses) >= max_val_batches: |
|
logger.info(f"======== Early stop validation at {len(running_losses)} batches ========") |
|
break |
|
|
|
outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data) |
|
extra_loss_dict = sorted(list(extra_loss_dict.keys())) |
|
sample_data, sample_outs = data, outs |
|
|
|
running_losses.append(torch.stack([ |
|
_loss if _loss is not None else torch.tensor(float('nan'), device=self.device) |
|
for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask] |
|
])) |
|
if len(extra_loss_keys) > 0: |
|
running_extra_losses.append(torch.stack([ |
|
extra_loss_dict[k] if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device) |
|
for k in extra_loss_keys |
|
])) |
|
|
|
|
|
conf_sigma_l1 = sample_outs.get('conf_sigma_l1', None) |
|
conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None |
|
conf_sigma_percl = sample_outs.get('conf_sigma_percl', None) |
|
conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None |
|
self.log_image_monitor_each_process( |
|
step=self.global_step, split='val', |
|
renders=sample_outs['comp_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl, |
|
prefix=f"_{len(running_losses)}_rank{self.accelerator.process_index}" |
|
) |
|
if "comp_mask" in sample_outs.keys(): |
|
self.log_image_monitor_each_process( |
|
step=self.global_step, split='val', |
|
renders=sample_outs['comp_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
gts=sample_data['render_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), |
|
prefix=f"_mask_{len(running_losses)}_rank{self.accelerator.process_index}" |
|
) |
|
|
|
total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu() |
|
total_loss, total_loss_pixel, total_loss_perceptual, total_loss_offset, total_loss_mask = total_losses.unbind() |
|
total_loss_dict = { |
|
'loss': total_loss.item(), |
|
'loss_pixel': total_loss_pixel.item(), |
|
'loss_perceptual': total_loss_perceptual.item(), |
|
'loss_offset': total_loss_offset.item(), |
|
'loss_mask': total_loss_mask.item(), |
|
} |
|
if len(extra_loss_keys) > 0: |
|
total_extra_losses = self.accelerator.gather(torch.stack(running_extra_losses)).mean(dim=0).cpu() |
|
for k, v in zip(extra_loss_keys, total_extra_losses.unbind()): |
|
total_loss_dict[k] = v.item() |
|
|
|
if epoch is not None: |
|
self.log_scalar_kwargs( |
|
epoch=epoch, split='val', |
|
**total_loss_dict, |
|
) |
|
logger.info( |
|
f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \ |
|
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v)) |
|
) |
|
else: |
|
self.log_scalar_kwargs( |
|
step=self.global_step, split='val', |
|
**total_loss_dict, |
|
) |
|
logger.info( |
|
f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \ |
|
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v)) |
|
) |
|
|
|
def log_image_monitor_each_process( |
|
self, epoch: int = None, step: int = None, split: str = None, |
|
renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None, |
|
conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None |
|
): |
|
M = renders.shape[1] |
|
if gts.shape[1] != M: |
|
gts = repeat(gts, "b v c h w -> b (v r) c h w", r=2) |
|
merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:]) |
|
renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:]) |
|
renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M) |
|
log_type, log_progress = self._get_str_progress(epoch, step) |
|
split = f'/{split}' if split else '' |
|
split = split + prefix if prefix is not None else split |
|
log_img_dict = { |
|
f'Images_split{split}/rendered': renders.unsqueeze(0), |
|
f'Images_split{split}/gt': gts.unsqueeze(0), |
|
f'Images_split{split}/merged': merged.unsqueeze(0), |
|
} |
|
if conf_sigma_l1 is not None: |
|
EPS = 1e-7 |
|
vis_conf_l1 = 1/(1+conf_sigma_l1.detach()+EPS).cpu() |
|
vis_conf_percl = 1/(1+conf_sigma_percl.detach()+EPS).cpu() |
|
vis_conf_l1, vis_conf_percl = rearrange(vis_conf_l1, "b v (r c) h w -> (b v r) c h w", r=2), rearrange(vis_conf_percl, "b v (r c) h w -> (b v r) c h w", r=2) |
|
vis_conf_l1, vis_conf_percl = repeat(vis_conf_l1, "b c1 h w-> b (c1 c2) h w", c2=3), repeat(vis_conf_percl, "b c1 h w -> b (c1 c2) h w", c2=3) |
|
vis_conf_l1, vis_conf_percl = make_grid(vis_conf_l1, nrow=M), make_grid(vis_conf_percl, nrow=M) |
|
log_img_dict[f'Images_split{split}/conf_l1'] = vis_conf_l1.unsqueeze(0) |
|
log_img_dict[f'Images_split{split}/conf_percl'] = vis_conf_percl.unsqueeze(0) |
|
|
|
self.log_images_each_process(log_img_dict, log_progress, {"imwrite_image": False}) |
|
|
|
|
|
@Trainer.control('on_main_process') |
|
def log_image_monitor( |
|
self, epoch: int = None, step: int = None, split: str = None, |
|
renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None, |
|
conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None |
|
): |
|
self.log_image_monitor_each_process(epoch, step, split, renders, gts, prefix, conf_sigma_l1, conf_sigma_percl) |
|
|