Spaces:
svjack
/
Runtime error

yuandong513
feat: init
17cd746
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchvision.utils import make_grid
from einops import rearrange, repeat
from accelerate.logging import get_logger
from taming.modules.losses.vqperceptual import hinge_d_loss
from .base_trainer import Trainer
from lam.utils.profiler import DummyProfiler
from lam.runners import REGISTRY_RUNNERS
from lam.utils.hf_hub import wrap_model_hub
from safetensors.torch import load_file
from pytorch3d.ops.knn import knn_points
import torch.nn.functional as F
logger = get_logger(__name__)
# torch.autograd.set_detect_anomaly(True)
from omegaconf import OmegaConf
@REGISTRY_RUNNERS.register('train.lam')
class LAMTrainer(Trainer):
EXP_TYPE: str = 'lam'
def __init__(self):
    """Build every component of a training run.

    After the base-class initialisation this creates, in order: the
    generator (plus the discriminator when ``self.has_disc`` is set), one
    optimizer per network, the train/val data loaders, one LR scheduler per
    optimizer, and the loss callables.
    """
    super().__init__()

    # Networks come first: the optimizers below are built from their parameters.
    self.model = self._build_model(self.cfg)
    if self.has_disc:
        self.model_disc = self._build_model_disc(self.cfg)

    self.optimizer = self._build_optimizer(self.model, self.cfg)
    if self.has_disc:
        self.optimizer_disc = self._build_optimizer(self.model_disc, self.cfg)

    # Loaders must exist before the schedulers: _build_scheduler reads
    # len(self.train_loader).
    loaders = self._build_dataloader(self.cfg)
    self.train_loader, self.val_loader = loaders

    self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
    if self.has_disc:
        self.scheduler_disc = self._build_scheduler(self.optimizer_disc, self.cfg)

    loss_fns = self._build_loss_fn(self.cfg)
    self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = loss_fns

    # Forwarded verbatim to the pixel / perceptual loss calls as `only_sym_conf`.
    self.only_sym_conf = 2
    banner = "===" * 16 * 3
    print(banner, "\n" + "only_sym_conf:", self.only_sym_conf, "\n" + banner)
def _build_model(self, cfg):
    """Instantiate ``ModelLAM`` from ``cfg.model`` and optionally warm-start it.

    When ``self.cfg.train.resume`` is a non-empty path, the checkpoint is
    loaded on CPU and copied into the model parameter by parameter. Entries
    whose name is unknown to the model or whose shape differs are skipped
    with a warning, so partially compatible checkpoints can still be used.

    Args:
        cfg: experiment config; ``cfg.experiment.type`` must be ``'lrm'``.

    Returns:
        The constructed (and possibly warm-started) model.
    """
    assert cfg.experiment.type == 'lrm', \
        f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
    from lam.models import ModelLAM
    model = ModelLAM(**cfg.model)

    # Truthiness also covers resume=None, which len() would crash on.
    resume = self.cfg.train.resume
    if not resume:
        return model

    print("===" * 16 * 3)
    self.accelerator.print("loading pretrained weight from:", resume)
    if resume.endswith('safetensors'):
        ckpt = load_file(resume, device='cpu')
    else:
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # point `train.resume` at checkpoints from a trusted source.
        ckpt = torch.load(resume, map_location='cpu')

    # state_dict() tensors alias the live parameters, so copy_() below
    # updates the model in place.
    state_dict = model.state_dict()
    for k, v in ckpt.items():
        if k not in state_dict:
            self.accelerator.print(f"[WARN] unexpected param {k}: {v.shape}")
            continue
        if state_dict[k].shape != v.shape:
            self.accelerator.print(
                f"[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored."
            )
            continue
        state_dict[k].copy_(v)

    # Surface entries the checkpoint did not provide; they silently keep
    # their freshly initialised values otherwise.
    missing = [k for k in state_dict if k not in ckpt]
    if missing:
        self.accelerator.print(
            f"[WARN] {len(missing)} model entries not found in ckpt, left at their initial values."
        )
    self.accelerator.print("Finish loading ckpt:", resume, "\n" + "===" * 16 * 3)
    return model
def _build_model_disc(self, cfg):
if cfg.model.disc.type == "pix2pix":
from lam.models.discriminator import NLayerDiscriminator, weights_init
model = NLayerDiscriminator(input_nc=cfg.model.disc.in_channels,
n_layers=cfg.model.disc.num_layers,
use_actnorm=cfg.model.disc.use_actnorm
).apply(weights_init)
elif cfg.model.disc.type == "vqgan":
from lam.models.discriminator import Discriminator
model = Discriminator(in_channels=cfg.model.disc.in_channels,
cond_channels=0, hidden_channels=512,
depth=cfg.model.disc.depth)
elif cfg.model.disc.type == "stylegan":
from lam.models.gan.stylegan_discriminator import SingleDiscriminatorV2, SingleDiscriminator
from lam.models.gan.stylegan_discriminator_torch import Discriminator
model = Discriminator(512, channel_multiplier=2)
model.input_size = cfg.model.disc.img_res
else:
raise NotImplementedError
return model
def _build_optimizer(self, model: nn.Module, cfg):
    """Create an AdamW optimizer with weight decay disabled for biases and LayerNorm.

    Args:
        model: network whose trainable parameters are optimized.
        cfg: config providing ``train.optim.{lr, beta1, beta2, weight_decay}``.

    Returns:
        ``torch.optim.AdamW`` with two parameter groups: index 0 uses
        ``weight_decay`` from the config, index 1 uses no weight decay.
    """
    no_decay_params = []
    no_decay_ids = set()

    def _exempt(param):
        # De-duplicate by identity so a shared parameter is listed once.
        if id(param) not in no_decay_ids:
            no_decay_ids.add(id(param))
            no_decay_params.append(param)

    # All LayerNorm parameters and every bias are exempt from weight decay.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                _exempt(param)
            continue
        bias = getattr(module, 'bias', None)
        # Some modules keep a plain flag in `.bias`; only real parameters
        # can be handed to the optimizer.
        if isinstance(bias, nn.Parameter):
            _exempt(bias)

    # Everything else decays; frozen parameters are dropped from both groups.
    decay_params = [
        p for p in model.parameters()
        if id(p) not in no_decay_ids and p.requires_grad
    ]
    no_decay_params = [p for p in no_decay_params if p.requires_grad]

    # Monitor this to make sure we don't miss any parameters.
    logger.info("======== Weight Decay Parameters ========")
    logger.info(f"Total: {len(decay_params)}")
    logger.info("======== No Weight Decay Parameters ========")
    logger.info(f"Total: {len(no_decay_params)}")

    opt_groups = [
        {'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    optimizer = torch.optim.AdamW(
        opt_groups,
        lr=cfg.train.optim.lr,
        betas=(cfg.train.optim.beta1, cfg.train.optim.beta2),
    )
    return optimizer
def _build_scheduler(self, optimizer, cfg):
    """Create the learning-rate scheduler for ``optimizer``.

    The schedule length is expressed in optimizer steps: batches per process
    and epoch, divided by the gradient-accumulation factor (rounded up),
    times the number of epochs. Only ``cfg.train.scheduler.type == 'cosine'``
    is supported.

    Raises:
        NotImplementedError: for any other scheduler type.
    """
    per_process_batches = math.floor(len(self.train_loader) / self.accelerator.num_processes)
    steps_per_epoch = math.ceil(per_process_batches / self.cfg.train.accum_steps)
    max_iters = cfg.train.epochs * steps_per_epoch
    warmup_iters = cfg.train.scheduler.warmup_real_iters
    logger.debug(f"======== Scheduler effective max iters: {max_iters} ========")
    logger.debug(f"======== Scheduler effective warmup iters: {warmup_iters} ========")

    scheduler_type = cfg.train.scheduler.type
    if scheduler_type != 'cosine':
        raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")

    from lam.utils.scheduler import CosineWarmupScheduler
    return CosineWarmupScheduler(
        optimizer=optimizer,
        warmup_iters=warmup_iters,
        max_iters=max_iters,
    )
def _build_dataloader(self, cfg):
    """Build the training and validation ``DataLoader`` objects.

    Both splits use ``MixerDataset`` with identical settings and differ only
    in ``split`` / ``is_val``. Normal maps are requested from the dataset
    when either ``normal_weight`` or ``surfel_normal_weight`` in
    ``cfg.train.loss`` is positive.

    Returns:
        ``(train_loader, val_loader)``.
    """
    from lam.datasets import MixerDataset

    dataset_cfg = cfg.dataset
    gaga_track_type = dataset_cfg.get("gaga_track_type", "vfhq_gagtrack")
    sample_aug_views = dataset_cfg.get("sample_aug_views", 0)

    def _loss_enabled(key):
        # A missing weight counts as disabled.
        return cfg.train.loss.get(key, 0.) > 0.

    load_normal = _loss_enabled("normal_weight") or _loss_enabled("surfel_normal_weight")
    print("===" * 16 * 3, "\nload_normal:", load_normal)

    # Settings shared by the two splits; kept in one place so they cannot drift.
    shared_kwargs = dict(
        subsets=dataset_cfg.subsets,
        sample_side_views=dataset_cfg.sample_side_views,
        render_image_res_low=dataset_cfg.render_image.low,
        render_image_res_high=dataset_cfg.render_image.high,
        render_region_size=dataset_cfg.render_image.region,
        source_image_res=dataset_cfg.source_image_res,
        repeat_num=dataset_cfg.repeat_num if hasattr(dataset_cfg, "repeat_num") else 1,
        multiply=dataset_cfg.multiply if hasattr(dataset_cfg, "multiply") else 14,
        debug=dataset_cfg.debug if hasattr(dataset_cfg, "debug") else False,
        gaga_track_type=gaga_track_type,
        sample_aug_views=sample_aug_views,
        load_albedo=cfg.model.get("render_albedo", False) if hasattr(cfg.model, "render_albedo") else False,
        load_normal=load_normal,
    )
    train_dataset = MixerDataset(split="train", is_val=False, **shared_kwargs)
    val_dataset = MixerDataset(split="val", is_val=True, **shared_kwargs)

    num_train_workers = dataset_cfg.num_train_workers
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_train_workers,
        pin_memory=dataset_cfg.pin_mem,
        # DataLoader rejects persistent_workers=True when num_workers == 0
        # (e.g. single-process debugging), so only enable it with workers.
        persistent_workers=num_train_workers > 0,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.val.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=dataset_cfg.num_val_workers,
        pin_memory=dataset_cfg.pin_mem,
        persistent_workers=False,
    )
    return train_loader, val_loader
def _build_loss_fn(self, cfg):
    """Create the pixel, perceptual (LPIPS) and total-variation loss callables.

    The pixel loss variant comes from ``cfg.train.loss.pixel_loss_fn``
    (default ``"mse"``). When the model uses a confidence map,
    ``train.loss.head_pl`` must be enabled as well.

    Returns:
        ``(pixel_loss_fn, perceptual_loss_fn, tv_loss_fn)``.
    """
    from lam.losses import PixelLoss, LPIPSLoss, TVLoss

    pixel_option = cfg.train.loss.get("pixel_loss_fn", "mse")
    pixel_loss_fn = PixelLoss(option=pixel_option)

    # The main process constructs the LPIPS loss before the others do.
    with self.accelerator.main_process_first():
        perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)

    uses_conf_map = cfg.model.get("use_conf_map", False)
    if uses_conf_map:
        assert cfg.train.loss.get("head_pl", False), "Set head_pl in train.loss to true to use faceperceptualloss when using conf_map."

    return pixel_loss_fn, perceptual_loss_fn, TVLoss()
def register_hooks(self):
    """Hook registration point; this trainer installs no hooks."""
    return None
def get_flame_params(self, data, is_source=False):
    """Select the FLAME parameter entries from a batch dict.

    Args:
        data: batch dictionary.
        is_source: when True, look for the same keys prefixed with
            ``source_`` instead of the plain ones.

    Returns:
        A new dict with the matching entries of ``data``, in ``data``'s
        iteration order.
    """
    base_names = (
        'root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose',
        'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas',
        'rotation', 'neck_pose', 'eyes_pose', 'translation', "teeth_bs",
    )
    prefix = 'source_' if is_source else ''
    wanted = {prefix + name for name in base_names}
    return {key: value for key, value in data.items() if key in wanted}
def cross_copy(self, data):
    """Expand a ``[B, 1, ...]`` tensor to ``[B, B, ...]`` by pairing samples.

    Row ``i`` of the result starts with sample ``i`` itself, followed by all
    other samples of the batch in their original order.
    """
    batch_size = data.shape[0]
    assert data.shape[1] == 1

    rows = []
    for own in range(batch_size):
        others = [data[other] for other in range(batch_size) if other != own]
        rows.append(torch.concat([data[own], *others], dim=0))
    return torch.stack(rows, dim=0)
def prepare_cross_render_data(self, data):
B, N_v, C, H, W = data['render_image'].shape
assert N_v == 1
# cross copy
data["c2ws"] = self.cross_copy(data["c2ws"])
data["intrs"] = self.cross_copy(data["intrs"])
data["render_full_resolutions"] = self.cross_copy(data["render_full_resolutions"])
data["render_image"] = self.cross_copy(data["render_image"])
data["render_mask"] = self.cross_copy(data["render_mask"])
data["render_bg_colors"] = self.cross_copy(data["render_bg_colors"])
flame_params = self.get_flame_params(data)
for key in flame_params.keys():
if "betas" not in key:
data[key] = self.cross_copy(data[key])
source_flame_params = self.get_flame_params(data, is_source=True)
for key in source_flame_params.keys():
if "betas" not in key:
data[key] = self.cross_copy(data[key])
return data
def get_loss_weight(self, loss_weight):
    """Resolve a loss weight that is either a constant or a linear schedule.

    Args:
        loss_weight: a number, returned unchanged, or a string
            ``"start_step:start_value:end_value:end_step"``. The scheduled
            value moves linearly from ``start_value`` to ``end_value`` as
            ``self.global_step`` goes from ``start_step`` to ``end_step``
            and is clamped outside that range.

    Returns:
        The weight to use at the current global step.

    Raises:
        NotImplementedError: for any other kind of specification.
    """
    if isinstance(loss_weight, str) and ":" in loss_weight:
        start_step, start_value, end_value, end_step = map(float, loss_weight.split(":"))
        current_step = self.global_step
        if end_step == start_step:
            # Zero-length ramp: treat it as a step function rather than
            # dividing by zero.
            return end_value if current_step >= start_step else start_value
        progress = (current_step - start_step) / (end_step - start_step)
        return start_value + (end_value - start_value) * max(min(1.0, progress), 0.0)
    if isinstance(loss_weight, (float, int)):
        return loss_weight
    raise NotImplementedError(f"Unsupported loss weight specification: {loss_weight!r}")
def forward_loss_local_step(self, data):
    """Run the generator on one batch and compute the reconstruction losses.

    Only the first ``data["real_num_view"] - sample_aug_views`` render views
    (twice that many with symmetric projection) take part in the image
    losses. With ``model.use_sym_proj`` the ground-truth images and masks
    are duplicated per view, and ``bbox`` entries of ``data`` are duplicated
    in place.

    Args:
        data: batch dict holding source images/cameras, render targets,
            FLAME parameters and ``real_num_view``.

    Returns:
        ``(outputs, loss, loss_pixel, loss_perceptual, loss_offset_reg,
        loss_mask, extra_loss_dict)``. ``loss`` is the weighted total used
        for backprop. ``loss_pixel`` / ``loss_perceptual`` are the weighted
        terms including their masked variants, for logging.
        ``loss_offset_reg`` is the unweighted offset regulariser. Terms that
        are disabled are ``None``. ``extra_loss_dict`` is currently always
        empty.
    """
    loss_cfg = self.cfg.train.loss
    model_cfg = self.cfg.model
    render_image = data['render_image']
    render_mask = data['render_mask']
    flame_params = self.get_flame_params(data)
    source_flame_params = self.get_flame_params(data, is_source=True)

    # forward
    outputs = self.model(
        image=data['source_rgbs'],
        source_c2ws=data['source_c2ws'],
        source_intrs=data['source_intrs'],
        render_c2ws=data['c2ws'],
        render_intrs=data['intrs'],
        render_bg_colors=data['render_bg_colors'],
        flame_params=flame_params,
        source_flame_params=source_flame_params,
        render_images=render_image,
        data=data,
    )

    # loss calculation
    loss = 0.
    loss_pixel = None
    loss_perceptual = None
    loss_mask = None
    extra_loss_dict = {}

    num_aug_view = self.cfg.dataset.get("sample_aug_views", 0)
    real_num_view = data["real_num_view"] - num_aug_view

    conf_sigma_l1 = outputs.get("conf_sigma_l1", None)
    conf_sigma_percl = outputs.get("conf_sigma_percl", None)
    # `.get` mirrors how _build_loss_fn reads `use_conf_map`, so configs that
    # omit these switches behave as "disabled" instead of raising.
    if model_cfg.get("use_sym_proj", False):
        # Each view is paired with a second one, so there are twice as many
        # predicted views to supervise.
        real_num_view *= 2
        if model_cfg.get("use_conf_map", False):
            conf_sigma_l1 = rearrange(conf_sigma_l1, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
            conf_sigma_percl = rearrange(conf_sigma_percl, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
        render_image = repeat(data['render_image'], "b v c h w -> b (v r) c h w", r=2)
        render_mask = repeat(data['render_mask'], "b v c h w -> b (v r) c h w", r=2)
        # Snapshot the keys first so the dict is not written while iterated.
        for key in [k for k in data if "bbox" in k]:
            data[key] = repeat(data[key], "b v c -> b (v r) c", r=2)

    only_sym_conf = self.only_sym_conf
    pred_views = outputs['comp_rgb'][:, :real_num_view]
    gt_views = render_image[:, :real_num_view]
    mask_views = render_mask[:, :real_num_view]

    def _rgb_enabled():
        # The RGB terms are on unless `rgb_weight` exists and is <= 0.
        return (not hasattr(loss_cfg, "rgb_weight")
                or self.get_loss_weight(loss_cfg.rgb_weight) > 0.)

    masked_weight = self.get_loss_weight(loss_cfg.get("masked_pixel_weight", 0))
    if masked_weight > 0.:
        # Composite prediction and target onto a white background.
        gt_rgb = gt_views * mask_views + 1.0 * (1 - mask_views)
        pred_rgb = pred_views * mask_views + 1.0 * (1 - mask_views)
        loss_pixel = self.pixel_loss_fn(pred_rgb, gt_rgb, conf_sigma_l1, only_sym_conf=only_sym_conf) * masked_weight
        loss = loss + loss_pixel
        # The masked perceptual term uses the same weight as the masked pixel term.
        loss_perceptual = self.perceptual_loss_fn(
            pred_rgb, gt_rgb, conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf
        ) * masked_weight
        loss = loss + loss_perceptual

    pixel_weight = self.get_loss_weight(loss_cfg.pixel_weight)
    if pixel_weight > 0.:
        masked_loss_pixel = loss_pixel
        if _rgb_enabled():
            loss_pixel = self.pixel_loss_fn(
                pred_views, gt_views, conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
            ) * pixel_weight
            loss = loss + loss_pixel
            if masked_loss_pixel is not None:
                # Reporting value only; built out of place so the tensor
                # already summed into `loss` is never modified.
                loss_pixel = loss_pixel + masked_loss_pixel

    perceptual_weight = self.get_loss_weight(loss_cfg.perceptual_weight)
    if perceptual_weight > 0.:
        masked_loss_perceptual = loss_perceptual
        if _rgb_enabled():
            loss_perceptual = self.perceptual_loss_fn(
                pred_views, gt_views, conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf
            ) * perceptual_weight
            loss = loss + loss_perceptual
            if masked_loss_perceptual is not None:
                loss_perceptual = loss_perceptual + masked_loss_perceptual

    mask_weight = self.get_loss_weight(loss_cfg.mask_weight)
    if mask_weight > 0. and 'comp_mask' in outputs.keys():
        loss_mask = self.pixel_loss_fn(
            outputs['comp_mask'][:, :real_num_view], mask_views,
            conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
        ) * mask_weight
        loss = loss + loss_mask

    if hasattr(loss_cfg, 'offset_reg_weight') and self.get_loss_weight(loss_cfg.offset_reg_weight) > 0.:
        # Mean squared offset of the first gaussian set of each sample,
        # averaged over the batch.
        gaussians = outputs['3dgs']
        loss_offset_reg = 0
        for per_sample in gaussians:
            offset = per_sample[0].offset.float()
            loss_offset_reg = loss_offset_reg + torch.nn.functional.mse_loss(offset, torch.zeros_like(offset))
        loss_offset_reg = loss_offset_reg / len(gaussians)
        loss = loss + loss_offset_reg * self.get_loss_weight(loss_cfg.offset_reg_weight)
    else:
        loss_offset_reg = None

    return outputs, loss, loss_pixel, loss_perceptual, loss_offset_reg, loss_mask, extra_loss_dict
def adopt_weight(self, weight, global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``weight``."""
    return value if global_step < threshold else weight
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer, discriminator_weight=1):
    """Balance the generator's adversarial loss against the reconstruction loss.

    The weight is the ratio of the gradient norms of ``nll_loss`` and
    ``g_loss`` with respect to ``last_layer``, clamped to ``[0, 1e4]``,
    detached, and scaled by ``discriminator_weight``. Both graphs are
    retained so the losses can still be back-propagated afterwards.
    """
    (nll_grads,) = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)
    (g_grads,) = torch.autograd.grad(g_loss, last_layer, retain_graph=True)

    # The small constant keeps the ratio finite when the adversarial
    # gradient vanishes.
    ratio = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    ratio = torch.clamp(ratio, 0.0, 1e4).detach()
    return ratio * discriminator_weight
def disc_preprocess(self, img):
    """Prepare images for the discriminator.

    Merges the first two dimensions (``[B, N_v, C, H, W]`` becomes
    ``[B*N_v, C, H, W]``), maps values from ``[0, 1]`` to ``[-1, 1]``,
    resizes to a square of the discriminator's ``input_size`` when that
    attribute exists, and casts to float32.
    """
    batch = torch.flatten(img, 0, 1)
    batch = 2 * batch - 1

    disc = self.accelerator.unwrap_model(self.model_disc)
    if hasattr(disc, "input_size"):
        side = disc.input_size
        batch = nn.functional.interpolate(batch, (side, side))
    return batch.float()
def forward_to_get_loss_with_gen_loss(self, data):
    """Run one generator step that includes the adversarial term.

    Computes the reconstruction losses, adds the generator loss (negative
    mean discriminator logit on the rendered images) scaled by an adaptive
    weight and by the warm-up gate, back-propagates, clips gradients and
    steps the generator optimizer.

    Returns:
        ``(outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask,
        loss_gen, extra_loss_dict)``, where ``loss`` already includes the
        adversarial term.
    """
    # forward to loss
    outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)

    fake_rgb = outs["comp_rgb"]
    with torch.autocast(device_type=fake_rgb.device.type, dtype=torch.float32):
        logits_fake = self.model_disc(self.disc_preprocess(fake_rgb))
    loss_gen = -torch.mean(logits_fake)

    try:
        if loss < 1e-5:
            # Reconstruction loss is negligible: fall back to the fixed weight.
            d_weight = self.cfg.model.disc.disc_weight
        else:
            nll_loss = loss if loss_pixel is None else loss_pixel
            last_layer = self.accelerator.unwrap_model(self.model).get_last_layer()
            d_weight = self.calculate_adaptive_weight(
                nll_loss, loss_gen,
                last_layer=last_layer,
                discriminator_weight=self.cfg.model.disc.disc_weight,
            )
    except RuntimeError:
        # Best effort: drop the adversarial term for this step.
        print("*************Error when calculate_adaptive_weight************")
        d_weight = torch.tensor(0.0)

    # The adversarial term is gated off until `disc_iter_start` global steps.
    disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start)
    loss += disc_factor * d_weight * loss_gen

    # backward
    self.accelerator.backward(loss)
    max_norm = self.cfg.train.optim.clip_grad_norm
    if self.accelerator.sync_gradients and max_norm > 0.:
        self.accelerator.clip_grad_norm_(self.model.parameters(), max_norm)
    self.optimizer.step()
    self.optimizer.zero_grad()
    return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict
def forward_to_get_loss(self, data):
    """Run one generator step without a discriminator.

    Computes the reconstruction losses, back-propagates the total, clips
    gradients when a positive ``train.optim.clip_grad_norm`` is configured
    and gradients are being synchronised, then steps and resets the
    optimizer.

    Returns:
        The 7-tuple produced by ``forward_loss_local_step``.
    """
    step_result = self.forward_loss_local_step(data)
    outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = step_result

    # backward
    self.accelerator.backward(loss)
    max_norm = self.cfg.train.optim.clip_grad_norm
    if self.accelerator.sync_gradients and max_norm > 0.:
        self.accelerator.clip_grad_norm_(self.model.parameters(), max_norm)
    self.optimizer.step()
    self.optimizer.zero_grad()

    return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict
def forward_disc_loss_local_step(self, pred_img, gt_img):
    """Compute the hinge discriminator loss for a real / rendered image pair.

    Both inputs are preprocessed and detached before entering the
    discriminator, so no gradient reaches the generator from this loss.
    """
    with torch.autocast(device_type=pred_img.device.type, dtype=torch.float32):
        real_input = self.disc_preprocess(gt_img).detach()
        logits_real = self.model_disc(real_input)
        fake_input = self.disc_preprocess(pred_img).detach()
        logits_fake = self.model_disc(fake_input)
    return hinge_d_loss(logits_real, logits_fake)
def forward_to_get_disc_loss(self, pred_img, gt_img):
    """Run one discriminator step.

    The hinge loss is multiplied by the warm-up gate (zero before
    ``model.disc.disc_iter_start`` global steps) for the backward pass.
    Gradients are clipped when configured, then the discriminator optimizer
    is stepped and reset.

    Returns:
        The ungated discriminator loss, for logging.
    """
    loss_disc = self.forward_disc_loss_local_step(pred_img, gt_img)
    warmup_gate = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start)
    gated_loss = warmup_gate * loss_disc

    # backward
    self.accelerator.backward(gated_loss)
    max_norm = self.cfg.train.optim.clip_grad_norm
    if self.accelerator.sync_gradients and max_norm > 0.:
        self.accelerator.clip_grad_norm_(self.model_disc.parameters(), max_norm)
    self.optimizer_disc.step()
    self.optimizer_disc.zero_grad()
    return loss_disc
def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile, iepoch: int):
    """Train for one epoch, alternating generator and discriminator steps.

    Without a discriminator every batch is a generator step. With one, a
    batch is a discriminator step when ``idx % 5 == 0``, or on even ``idx``
    during the first 100 batches overall; every other batch is a generator
    step. After each synchronised (global) step the losses are gathered and
    logged, and the periodic checkpoint / validation / image-logging actions
    run.

    Per-step losses are stacked in the fixed column order
    ``(loss, pixel, perceptual, tv, mask, gen, disc)``; missing values are
    NaN and are hidden from the progress bar.

    Args:
        pbar: progress bar, advanced once per global step.
        loader: batches for this epoch.
        profiler: stepped once per global step.
        iepoch: zero-based epoch index, used for the early-phase
            discriminator schedule.
    """
    self.model.train()
    if self.has_disc:
        self.model_disc.train()

    def _stack_losses(values):
        # None marks a loss that was not computed in this step.
        return torch.stack([
            v.detach() if v is not None else torch.tensor(float('nan'), device=self.device)
            for v in values
        ])

    def _name_losses(stacked):
        # Single place that maps the stacked columns to names, so the step
        # and epoch summaries cannot disagree on the gen/disc order.
        total, pixel, perceptual, tv, mask, gen, disc = (v.item() for v in stacked.unbind())
        return {
            'loss': total,
            'loss_pixel': pixel,
            'loss_perceptual': perceptual,
            'loss_tv': tv,
            'loss_mask': mask,
            'loss_disc': disc,
            'loss_gen': gen,
        }

    local_step_losses = []
    global_step_losses = []
    local_step_extra_losses = []
    global_step_extra_losses = []
    extra_loss_keys = []

    logger.debug(f"======== Starting epoch {self.current_epoch} ========")

    loss_disc = None
    # First six columns of the latest generator step; reused by discriminator
    # steps so their log rows carry the last known generator losses.
    last_gen_losses = None

    for idx, data in enumerate(loader):
        data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
        if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
            data = self.prepare_cross_render_data(data)
            data["real_num_view"] = 1
        else:
            data["real_num_view"] = data["render_image"].shape[1]

        logger.debug(f"======== Starting global step {self.global_step} ========")

        if not self.has_disc:
            disc_step = False
            with self.accelerator.accumulate(self.model):
                outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_to_get_loss(data)
                # track local losses (no adversarial terms in this mode)
                local_step_losses.append(
                    _stack_losses([loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, None, None])
                )
            extra_loss_keys = sorted(extra_loss_dict.keys())
            if len(extra_loss_keys) > 0:
                local_step_extra_losses.append(_stack_losses([extra_loss_dict[k] for k in extra_loss_keys]))
        else:
            disc_step = (idx % 5) == 0 or (iepoch * len(loader) + idx < 100 and idx % 2 == 0)
            if last_gen_losses is None:
                last_gen_losses = torch.zeros(6, device=data["source_rgbs"].device)

            if not disc_step:
                with self.accelerator.accumulate(self.model):
                    # generator step
                    outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict = \
                        self.forward_to_get_loss_with_gen_loss(data)
                    # track local losses; loss_disc is the latest discriminator loss, if any
                    local_step_losses.append(
                        _stack_losses([loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc])
                    )
                    last_gen_losses = local_step_losses[-1][:6]
                torch.cuda.empty_cache()
                extra_loss_keys = sorted(extra_loss_dict.keys())
                if len(extra_loss_keys) > 0:
                    local_step_extra_losses.append(_stack_losses([extra_loss_dict[k] for k in extra_loss_keys]))
            else:
                with self.accelerator.accumulate(self.model_disc):
                    # discriminator step
                    outs, _, _, _, _, _, _ = self.forward_loss_local_step(data)
                    loss_disc = self.forward_to_get_disc_loss(
                        pred_img=outs["comp_rgb"], gt_img=data["render_image"]
                    ).detach()
                    local_step_losses.append(
                        torch.concat([last_gen_losses[:6], loss_disc.unsqueeze(0)], dim=0)
                    )
                torch.cuda.empty_cache()

        # track global step
        if self.accelerator.sync_gradients:
            profiler.step()
            if not disc_step:
                self.scheduler.step()
            if self.has_disc and disc_step:
                self.scheduler_disc.step()
            logger.debug("======== Scheduler step ========")
            self.global_step += 1

            global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
            # Discriminator-only accumulation windows record no extra losses.
            if len(extra_loss_keys) > 0 and local_step_extra_losses:
                global_step_extra_loss = self.accelerator.gather(torch.stack(local_step_extra_losses)).mean(dim=0).cpu()
                global_step_extra_loss_items = global_step_extra_loss.unbind()
            else:
                global_step_extra_loss = None
                global_step_extra_loss_items = []

            loss_kwargs = _name_losses(global_step_loss)
            for key, extra_value in zip(extra_loss_keys, global_step_extra_loss_items):
                loss_kwargs[key] = extra_value.item()
            self.log_scalar_kwargs(
                step=self.global_step, split='train',
                **loss_kwargs
            )
            self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])

            local_step_losses = []
            global_step_losses.append(global_step_loss)
            local_step_extra_losses = []
            global_step_extra_losses.append(global_step_extra_loss)

            # manage display
            pbar.update(1)
            description = {
                **loss_kwargs,
                'lr': self.optimizer.param_groups[0]['lr'],
            }
            description = '[TRAIN STEP]' + \
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
            pbar.set_description(description)

            # periodic actions
            if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
                self.save_checkpoint()
            if self.global_step % self.cfg.val.global_step_period == 0:
                self.evaluate()
                # evaluate() leaves the generator in eval mode
                self.model.train()
                if self.has_disc:
                    self.model_disc.train()
            if (self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0) or (self.global_step < 1000 and self.global_step % 20 == 0):
                samples_per_log = self.cfg.logger.image_monitor.samples_per_log
                conf_sigma_l1 = outs.get('conf_sigma_l1', None)
                conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
                conf_sigma_percl = outs.get('conf_sigma_percl', None)
                conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
                self.log_image_monitor(
                    step=self.global_step, split='train',
                    renders=outs['comp_rgb'].detach()[:samples_per_log].cpu(),
                    conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
                    gts=data['render_image'][:samples_per_log].cpu(),
                )
                if 'comp_mask' in outs.keys():
                    self.log_image_monitor(
                        step=self.global_step, split='train',
                        renders=outs['comp_mask'].detach()[:samples_per_log].cpu(),
                        gts=data['render_mask'][:samples_per_log].cpu(),
                        prefix="_mask",
                    )

            # progress control
            if self.global_step >= self.N_max_global_steps:
                self.accelerator.set_trigger()
                break

    # track epoch
    self.current_epoch += 1
    if not global_step_losses:
        # No optimizer step completed (e.g. an empty loader): nothing to average.
        logger.warning(f"[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: no global step completed")
        return

    epoch_losses = torch.stack(global_step_losses).mean(dim=0)
    epoch_loss_dict = _name_losses(epoch_losses)
    if len(extra_loss_keys) > 0:
        recorded_extra = [x for x in global_step_extra_losses if x is not None]
        if recorded_extra:
            epoch_extra_losses = torch.stack(recorded_extra).mean(dim=0)
            for k, v in zip(extra_loss_keys, epoch_extra_losses.unbind()):
                epoch_loss_dict[k] = v.item()
    self.log_scalar_kwargs(
        epoch=self.current_epoch, split='train',
        **epoch_loss_dict,
    )
    logger.info(
        f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \
            ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
    )
def train(self):
    """Run the full training loop, then save and evaluate once more.

    On resume, the batches already consumed in the current epoch are skipped
    for the first epoch only. Training stops after the configured number of
    epochs, or earlier once ``train_epoch`` sets the accelerator trigger.
    """
    starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
    skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
    logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")

    progress = tqdm(
        range(0, self.N_max_global_steps),
        initial=self.global_step,
        disable=(not self.accelerator.is_main_process),
    )
    with progress as pbar:
        if self.cfg.logger.enable_profiler:
            trace_dir = os.path.join(
                self.cfg.logger.tracker_root,
                self.cfg.experiment.parent, self.cfg.experiment.child,
            )
            profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                schedule=torch.profiler.schedule(
                    wait=10, warmup=10, active=100,
                ),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            )
        else:
            profiler = DummyProfiler()

        with profiler:
            self.optimizer.zero_grad()
            if self.has_disc:
                self.optimizer_disc.zero_grad()
            for iepoch in range(self.current_epoch, self.cfg.train.epochs):
                # The shortened loader is used once, for the resumed epoch.
                loader, skipped_loader = (skipped_loader or self.train_loader), None
                self.train_epoch(pbar=pbar, loader=loader, profiler=profiler, iepoch=iepoch)
                if self.accelerator.check_trigger():
                    break

        logger.info(f"======== Training finished at global step {self.global_step} ========")

        # final checkpoint and evaluation
        self.save_checkpoint()
        self.evaluate()
@torch.no_grad()
@torch.compiler.disable
def evaluate(self, epoch: int = None):
    """Run validation, log images for every batch, and log the mean losses.

    At most ``cfg.val.debug_batches`` batches are evaluated when that value
    is set (truthy), otherwise the whole validation loader. The model is
    left in eval mode; callers switch back to train mode.

    Args:
        epoch: when given, scalars are logged against this epoch; otherwise
            against the current global step.
    """
    self.model.eval()

    max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
    samples_per_log = self.cfg.logger.image_monitor.samples_per_log
    running_losses = []
    running_extra_losses = []
    extra_loss_keys = []

    def _stack_losses(values):
        # None marks a loss that was not computed; it is reported as NaN.
        return torch.stack([
            v if v is not None else torch.tensor(float('nan'), device=self.device)
            for v in values
        ])

    for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):
        if len(running_losses) >= max_val_batches:
            logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
            break

        data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
        if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
            data = self.prepare_cross_render_data(data)
            data["real_num_view"] = 1
        else:
            data["real_num_view"] = data["render_image"].shape[1]

        outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
        # Keep the dict intact and store the key order separately; the dict
        # is indexed by these keys just below.
        extra_loss_keys = sorted(extra_loss_dict.keys())
        running_losses.append(_stack_losses([loss, loss_pixel, loss_perceptual, loss_tv, loss_mask]))
        if len(extra_loss_keys) > 0:
            running_extra_losses.append(_stack_losses([extra_loss_dict[k] for k in extra_loss_keys]))

        # log each step
        conf_sigma_l1 = outs.get('conf_sigma_l1', None)
        conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
        conf_sigma_percl = outs.get('conf_sigma_percl', None)
        conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
        self.log_image_monitor_each_process(
            step=self.global_step, split='val',
            renders=outs['comp_rgb'][:samples_per_log].cpu(),
            gts=data['render_image'][:samples_per_log].cpu(),
            conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
            prefix=f"_{len(running_losses)}_rank{self.accelerator.process_index}"
        )
        if "comp_mask" in outs.keys():
            self.log_image_monitor_each_process(
                step=self.global_step, split='val',
                renders=outs['comp_mask'][:samples_per_log].cpu(),
                gts=data['render_mask'][:samples_per_log].cpu(),
                prefix=f"_mask_{len(running_losses)}_rank{self.accelerator.process_index}"
            )

    if not running_losses:
        # Nothing was evaluated (empty loader); stacking an empty list would raise.
        logger.warning("======== Validation skipped: no batches were evaluated ========")
        return

    total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
    # The fourth column is the offset regulariser returned by forward_loss_local_step.
    total_loss, total_loss_pixel, total_loss_perceptual, total_loss_offset, total_loss_mask = total_losses.unbind()
    total_loss_dict = {
        'loss': total_loss.item(),
        'loss_pixel': total_loss_pixel.item(),
        'loss_perceptual': total_loss_perceptual.item(),
        'loss_offset': total_loss_offset.item(),
        'loss_mask': total_loss_mask.item(),
    }
    if len(extra_loss_keys) > 0 and running_extra_losses:
        total_extra_losses = self.accelerator.gather(torch.stack(running_extra_losses)).mean(dim=0).cpu()
        for k, v in zip(extra_loss_keys, total_extra_losses.unbind()):
            total_loss_dict[k] = v.item()

    summary = ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
    if epoch is not None:
        self.log_scalar_kwargs(
            epoch=epoch, split='val',
            **total_loss_dict,
        )
        logger.info(f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + summary)
    else:
        self.log_scalar_kwargs(
            step=self.global_step, split='val',
            **total_loss_dict,
        )
        logger.info(f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + summary)
def log_image_monitor_each_process(
    self, epoch: int = None, step: int = None, split: str = None,
    renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
    conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None
):
    """Log image grids of renders, ground truth and confidence maps on this process.

    Args:
        epoch, step: progress marker, passed to ``_get_str_progress``.
        split: data split name used in the log tag.
        renders: predicted images with views on dim 1.
        gts: ground-truth images; when their view count differs from
            ``renders`` each view is duplicated (``r=2``) to match.
        prefix: optional suffix appended to the split part of the tag.
        conf_sigma_l1, conf_sigma_percl: optional confidence sigmas; each
            one is visualised independently when it is not ``None``.
    """
    M = renders.shape[1]
    if gts.shape[1] != M:
        gts = repeat(gts, "b v c h w -> b (v r) c h w", r=2)

    # Render / ground-truth rows of the first sample only.
    # reshape (not view) so non-contiguous inputs are accepted as well.
    image_shape = renders.shape[2:]
    merged = torch.stack([renders, gts], dim=1)[0].reshape(-1, *image_shape)
    renders = renders.reshape(-1, *image_shape)
    gts = gts.reshape(-1, *gts.shape[2:])
    renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)

    log_type, log_progress = self._get_str_progress(epoch, step)
    split = f'/{split}' if split else ''
    split = split + prefix if prefix is not None else split
    log_img_dict = {
        f'Images_split{split}/rendered': renders.unsqueeze(0),
        f'Images_split{split}/gt': gts.unsqueeze(0),
        f'Images_split{split}/merged': merged.unsqueeze(0),
    }

    EPS = 1e-7

    def _conf_grid(conf_sigma):
        # Map sigma to a confidence in (0, 1], split the channel dim into its
        # two halves, and replicate each map to 3 channels for display.
        conf = 1 / (1 + conf_sigma.detach() + EPS).cpu()
        conf = rearrange(conf, "b v (r c) h w -> (b v r) c h w", r=2)
        conf = repeat(conf, "b c1 h w -> b (c1 c2) h w", c2=3)
        return make_grid(conf, nrow=M).unsqueeze(0)

    # Either map may be missing on its own; guard them separately.
    if conf_sigma_l1 is not None:
        log_img_dict[f'Images_split{split}/conf_l1'] = _conf_grid(conf_sigma_l1)
    if conf_sigma_percl is not None:
        log_img_dict[f'Images_split{split}/conf_percl'] = _conf_grid(conf_sigma_percl)

    self.log_images_each_process(log_img_dict, log_progress, {"imwrite_image": False})
@Trainer.control('on_main_process')
def log_image_monitor(
    self, epoch: int = None, step: int = None, split: str = None,
    renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
    conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None
):
    """Forward all arguments to ``log_image_monitor_each_process``.

    The method is wrapped with ``Trainer.control('on_main_process')``.
    """
    self.log_image_monitor_each_process(
        epoch=epoch,
        step=step,
        split=split,
        renders=renders,
        gts=gts,
        prefix=prefix,
        conf_sigma_l1=conf_sigma_l1,
        conf_sigma_percl=conf_sigma_percl,
    )