# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.utils import make_grid
from tqdm.auto import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf
from accelerate.logging import get_logger
from safetensors.torch import load_file
from taming.modules.losses.vqperceptual import hinge_d_loss
from pytorch3d.ops.knn import knn_points

from .base_trainer import Trainer
from lam.utils.profiler import DummyProfiler
from lam.runners import REGISTRY_RUNNERS
from lam.utils.hf_hub import wrap_model_hub

logger = get_logger(__name__)
# torch.autograd.set_detect_anomaly(True)


class LAMTrainer(Trainer):

    EXP_TYPE: str = 'lam'

    def __init__(self):
        super().__init__()

        self.model = self._build_model(self.cfg)
        if self.has_disc:
            self.model_disc = self._build_model_disc(self.cfg)
        self.optimizer = self._build_optimizer(self.model, self.cfg)
        if self.has_disc:
            self.optimizer_disc = self._build_optimizer(self.model_disc, self.cfg)
        self.train_loader, self.val_loader = self._build_dataloader(self.cfg)
        self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
        if self.has_disc:
            self.scheduler_disc = self._build_scheduler(self.optimizer_disc, self.cfg)
        self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg)
        self.only_sym_conf = 2
        print("=" * 144, "\nonly_sym_conf:", self.only_sym_conf, "\n" + "=" * 144)

    def _build_model(self, cfg):
        assert cfg.experiment.type == 'lrm', \
            f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
        from lam.models import ModelLAM
        model = ModelLAM(**cfg.model)

        # resume from a pretrained checkpoint if one is given
        if len(self.cfg.train.resume) > 0:
            resume = self.cfg.train.resume
            print("=" * 144)
            self.accelerator.print("loading pretrained weight from:", resume)
            if resume.endswith('safetensors'):
                ckpt = load_file(resume, device='cpu')
            else:
                ckpt = torch.load(resume, map_location='cpu')
            state_dict = model.state_dict()
            for k, v in ckpt.items():
                if k in state_dict:
                    if state_dict[k].shape == v.shape:
                        # state_dict() returns references, so this writes into the model
                        state_dict[k].copy_(v)
                    else:
                        self.accelerator.print(f"[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.")
                else:
                    self.accelerator.print(f"[WARN] unexpected param {k}: {v.shape}")
            self.accelerator.print("Finish loading ckpt:", resume, "\n" + "=" * 144)
        return model

    def _build_model_disc(self, cfg):
        if cfg.model.disc.type == "pix2pix":
            from lam.models.discriminator import NLayerDiscriminator, weights_init
            model = NLayerDiscriminator(input_nc=cfg.model.disc.in_channels,
                                        n_layers=cfg.model.disc.num_layers,
                                        use_actnorm=cfg.model.disc.use_actnorm,
                                        ).apply(weights_init)
        elif cfg.model.disc.type == "vqgan":
            from lam.models.discriminator import Discriminator
            model = Discriminator(in_channels=cfg.model.disc.in_channels,
                                  cond_channels=0, hidden_channels=512,
                                  depth=cfg.model.disc.depth)
        elif cfg.model.disc.type == "stylegan":
            from lam.models.gan.stylegan_discriminator import SingleDiscriminatorV2, SingleDiscriminator
            from lam.models.gan.stylegan_discriminator_torch import Discriminator
            model = Discriminator(512, channel_multiplier=2)
            model.input_size = cfg.model.disc.img_res
        else:
            raise NotImplementedError(f"Discriminator type {cfg.model.disc.type} not implemented")
        return model

    def _build_optimizer(self, model: nn.Module, cfg):
        decay_params, no_decay_params = [], []

        # add all bias and LayerNorm params to no_decay_params
        for name, module in model.named_modules():
            if isinstance(module, nn.LayerNorm):
                no_decay_params.extend(module.parameters())
            elif hasattr(module, 'bias') and module.bias is not None:
                no_decay_params.append(module.bias)

        # add remaining parameters to decay_params
        _no_decay_ids = set(map(id, no_decay_params))
        decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids]

        # filter out parameters with no grad
        decay_params = list(filter(lambda p: p.requires_grad, decay_params))
        no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))

        # monitor this to make sure we don't miss any parameters
        logger.info("======== Weight Decay Parameters ========")
        logger.info(f"Total: {len(decay_params)}")
        logger.info("======== No Weight Decay Parameters ========")
        logger.info(f"Total: {len(no_decay_params)}")

        # optimizer with two parameter groups
        opt_groups = [
            {'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(
            opt_groups,
            lr=cfg.train.optim.lr,
            betas=(cfg.train.optim.beta1, cfg.train.optim.beta2),
        )
        return optimizer
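
    # Grouping sketch (hypothetical module, only to illustrate the split above):
    # for nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4)), the LayerNorm weight
    # and bias plus the Linear bias land in the zero-weight-decay group, while
    # the Linear weight keeps cfg.train.optim.weight_decay.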

    def _build_scheduler(self, optimizer, cfg):
        local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes)
        total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps)
        effective_warmup_iters = cfg.train.scheduler.warmup_real_iters
        logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========")
        logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========")
        if cfg.train.scheduler.type == 'cosine':
            from lam.utils.scheduler import CosineWarmupScheduler
            scheduler = CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_iters=effective_warmup_iters,
                max_iters=total_global_batches,
            )
        else:
            raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")
        return scheduler

    def _build_dataloader(self, cfg):
        # dataset class
        from lam.datasets import MixerDataset

        gaga_track_type = cfg.dataset.get("gaga_track_type", "vfhq_gagtrack")
        sample_aug_views = cfg.dataset.get("sample_aug_views", 0)

        # normals are only loaded when a normal-supervised loss is active
        load_normal = cfg.train.loss.get("normal_weight", 0.) > 0. \
            or cfg.train.loss.get("surfel_normal_weight", 0.) > 0.
        print("=" * 144, "\nload_normal:", load_normal)

        # build datasets
        train_dataset = MixerDataset(
            split="train",
            subsets=cfg.dataset.subsets,
            sample_side_views=cfg.dataset.sample_side_views,
            render_image_res_low=cfg.dataset.render_image.low,
            render_image_res_high=cfg.dataset.render_image.high,
            render_region_size=cfg.dataset.render_image.region,
            source_image_res=cfg.dataset.source_image_res,
            repeat_num=cfg.dataset.get("repeat_num", 1),
            multiply=cfg.dataset.get("multiply", 14),
            debug=cfg.dataset.get("debug", False),
            is_val=False,
            gaga_track_type=gaga_track_type,
            sample_aug_views=sample_aug_views,
            load_albedo=cfg.model.get("render_albedo", False),
            load_normal=load_normal,
        )
        val_dataset = MixerDataset(
            split="val",
            subsets=cfg.dataset.subsets,
            sample_side_views=cfg.dataset.sample_side_views,
            render_image_res_low=cfg.dataset.render_image.low,
            render_image_res_high=cfg.dataset.render_image.high,
            render_region_size=cfg.dataset.render_image.region,
            source_image_res=cfg.dataset.source_image_res,
            repeat_num=cfg.dataset.get("repeat_num", 1),
            multiply=cfg.dataset.get("multiply", 14),
            debug=cfg.dataset.get("debug", False),
            is_val=True,
            gaga_track_type=gaga_track_type,
            sample_aug_views=sample_aug_views,
            load_albedo=cfg.model.get("render_albedo", False),
            load_normal=load_normal,
        )

        # build data loaders
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=cfg.train.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=cfg.dataset.num_train_workers,
            pin_memory=cfg.dataset.pin_mem,
            persistent_workers=True,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg.val.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=cfg.dataset.num_val_workers,
            pin_memory=cfg.dataset.pin_mem,
            persistent_workers=False,
        )
        return train_loader, val_loader

    def _build_loss_fn(self, cfg):
        from lam.losses import PixelLoss, LPIPSLoss, TVLoss
        pixel_loss_fn = PixelLoss(option=cfg.train.loss.get("pixel_loss_fn", "mse"))
        with self.accelerator.main_process_first():
            perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)  # kwarg spelling follows LPIPSLoss's signature
        if cfg.model.get("use_conf_map", False):
            assert cfg.train.loss.get("head_pl", False), "Set train.loss.head_pl to true to use the face perceptual loss when use_conf_map is enabled."
        tv_loss_fn = TVLoss()
        return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn

    def register_hooks(self):
        pass

    def get_flame_params(self, data, is_source=False):
        flame_params = {}
        flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose',
                      'expr', 'trans', 'betas', 'rotation', 'neck_pose', 'eyes_pose', 'translation', 'teeth_bs']
        if is_source:
            flame_keys = ['source_' + item for item in flame_keys]
        for k, v in data.items():
            if k in flame_keys:
                flame_params[k] = v
        return flame_params

    def cross_copy(self, data):
        B = data.shape[0]
        assert data.shape[1] == 1
        new_data = []
        for i in range(B):
            # each sample keeps its own view first, followed by every other sample's view
            B_i = [data[i]]
            for j in range(B):
                if j != i:
                    B_i.append(data[j])
            new_data.append(torch.concat(B_i, dim=0))
        new_data = torch.stack(new_data, dim=0)
        return new_data

    def prepare_cross_render_data(self, data):
        B, N_v, C, H, W = data['render_image'].shape
        assert N_v == 1
        # cross copy
        data["c2ws"] = self.cross_copy(data["c2ws"])
        data["intrs"] = self.cross_copy(data["intrs"])
        data["render_full_resolutions"] = self.cross_copy(data["render_full_resolutions"])
        data["render_image"] = self.cross_copy(data["render_image"])
        data["render_mask"] = self.cross_copy(data["render_mask"])
        data["render_bg_colors"] = self.cross_copy(data["render_bg_colors"])

        flame_params = self.get_flame_params(data)
        for key in flame_params.keys():
            if "betas" not in key:
                data[key] = self.cross_copy(data[key])
        source_flame_params = self.get_flame_params(data, is_source=True)
        for key in source_flame_params.keys():
            if "betas" not in key:
                data[key] = self.cross_copy(data[key])
        return data

    def get_loss_weight(self, loss_weight):
        # schedule strings take the form "start_step:start_value:end_value:end_step"
        if isinstance(loss_weight, str) and ":" in loss_weight:
            start_step, start_value, end_value, end_step = map(float, loss_weight.split(":"))
            current_step = self.global_step
            value = start_value + (end_value - start_value) * max(
                min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
            )
            return value
        elif isinstance(loss_weight, (float, int)):
            return loss_weight
        else:
            raise NotImplementedError(f"Unsupported loss weight: {loss_weight}")
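
    # A minimal usage sketch (values assumed for illustration): with
    # loss_weight = "0:0.1:1.0:1000" and self.global_step = 500, the weight
    # ramps linearly to 0.1 + (1.0 - 0.1) * 500 / 1000 == 0.55; it is clamped
    # to 0.1 before step 0 and to 1.0 after step 1000. Plain numbers pass through.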

    def forward_loss_local_step(self, data):
        render_image = data['render_image']
        render_albedo = data.get('render_albedo', None)
        render_mask = data['render_mask']
        render_normal = data.get('render_normal', None)
        B, N_v, C, H, W = render_image.shape
        flame_params = self.get_flame_params(data)
        source_flame_params = self.get_flame_params(data, is_source=True)

        # forward
        outputs = self.model(
            image=data['source_rgbs'],
            source_c2ws=data['source_c2ws'],
            source_intrs=data['source_intrs'],
            render_c2ws=data['c2ws'],
            render_intrs=data['intrs'],
            render_bg_colors=data['render_bg_colors'],
            flame_params=flame_params,
            source_flame_params=source_flame_params,
            render_images=render_image,
            data=data,
        )

        # loss calculation
        loss = 0.
        loss_pixel = None
        loss_perceptual = None
        loss_mask = None
        extra_loss_dict = {}
        num_aug_view = self.cfg.dataset.get("sample_aug_views", 0)
        real_num_view = data["real_num_view"] - num_aug_view
        conf_sigma_l1 = outputs.get("conf_sigma_l1", None)
        conf_sigma_percl = outputs.get("conf_sigma_percl", None)
        if self.cfg.model.use_sym_proj:
            # symmetric projection doubles the effective number of views
            real_num_view *= 2
            if self.cfg.model.use_conf_map:
                conf_sigma_l1 = rearrange(conf_sigma_l1, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
                conf_sigma_percl = rearrange(conf_sigma_percl, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
            render_image = repeat(data['render_image'], "b v c h w -> b (v r) c h w", r=2)
            render_albedo = repeat(render_albedo, "b v c h w -> b (v r) c h w", r=2) if render_albedo is not None else None
            render_mask = repeat(data['render_mask'], "b v c h w -> b (v r) c h w", r=2)
            if "render_normal" in data:
                render_normal = repeat(data['render_normal'], "b v c h w -> b (v r) c h w", r=2)
            for k, v in data.items():
                if "bbox" in k:
                    data[k] = repeat(v, "b v c -> b (v r) c", r=2)
        only_sym_conf = self.only_sym_conf

        if self.get_loss_weight(self.cfg.train.loss.get("masked_pixel_weight", 0)) > 0.:
            # composite prediction and ground truth onto a white background
            gt_rgb = render_image[:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view])
            pred_rgb = outputs['comp_rgb'][:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view])
            loss_pixel = self.pixel_loss_fn(pred_rgb, gt_rgb, conf_sigma_l1, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight)
            loss += loss_pixel
            # the perceptual term reuses the same weight
            loss_perceptual = self.perceptual_loss_fn(pred_rgb, gt_rgb, conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight)
            loss += loss_perceptual

        if self.get_loss_weight(self.cfg.train.loss.pixel_weight) > 0.:
            total_loss_pixel = loss_pixel
            if self.get_loss_weight(self.cfg.train.loss.get("rgb_weight", 1.0)) > 0.:
                loss_pixel = self.pixel_loss_fn(
                    outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
                ) * self.get_loss_weight(self.cfg.train.loss.pixel_weight)
                loss += loss_pixel
                if total_loss_pixel is not None:
                    loss_pixel += total_loss_pixel

        if self.get_loss_weight(self.cfg.train.loss.perceptual_weight) > 0.:
            total_loss_perceptual = loss_perceptual
            if self.get_loss_weight(self.cfg.train.loss.get("rgb_weight", 1.0)) > 0.:
                loss_perceptual = self.perceptual_loss_fn(
                    outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf
                ) * self.get_loss_weight(self.cfg.train.loss.perceptual_weight)
                loss += loss_perceptual
                if total_loss_perceptual is not None:
                    loss_perceptual += total_loss_perceptual

        if self.get_loss_weight(self.cfg.train.loss.mask_weight) > 0. and 'comp_mask' in outputs:
            loss_mask = self.pixel_loss_fn(
                outputs['comp_mask'][:, :real_num_view], render_mask[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
            ) * self.get_loss_weight(self.cfg.train.loss.mask_weight)
            loss += loss_mask

        if self.get_loss_weight(self.cfg.train.loss.get("offset_reg_weight", 0.)) > 0.:
            # penalize per-Gaussian offsets towards zero
            loss_offset_reg = 0
            for b_idx in range(len(outputs['3dgs'])):
                offset = outputs['3dgs'][b_idx][0].offset.float()
                loss_offset_reg += F.mse_loss(offset, torch.zeros_like(offset))
            loss_offset_reg = loss_offset_reg / len(outputs['3dgs'])
            loss += loss_offset_reg * self.get_loss_weight(self.cfg.train.loss.offset_reg_weight)
        else:
            loss_offset_reg = None

        # NOTE: callers bind the fifth element as ``loss_tv``; it actually carries the offset regularization
        return outputs, loss, loss_pixel, loss_perceptual, loss_offset_reg, loss_mask, extra_loss_dict

    def adopt_weight(self, weight, global_step, threshold=0, value=0.):
        # zero out (or override) the weight until ``threshold`` steps have passed
        if global_step < threshold:
            weight = value
        return weight

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer, discriminator_weight=1):
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * discriminator_weight
        return d_weight
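
    # The formula above is the VQGAN adaptive-weight heuristic from
    # taming-transformers: d_weight = ||grad_last(nll_loss)|| /
    # (||grad_last(g_loss)|| + 1e-4), clamped to [0, 1e4], so the adversarial
    # term is rescaled to the magnitude of the reconstruction gradient at the
    # generator's last layer.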

    def disc_preprocess(self, img):
        # reshape [B, N_v, C, H, W] to [B*N_v, C, H, W]
        img = torch.flatten(img, 0, 1)
        # map values from [0, 1] to [-1, 1]
        img = 2 * img - 1
        if hasattr(self.accelerator.unwrap_model(self.model_disc), "input_size"):
            tgt_size = self.accelerator.unwrap_model(self.model_disc).input_size
            img = nn.functional.interpolate(img, (tgt_size, tgt_size))
        img = img.float()
        return img

    def forward_to_get_loss_with_gen_loss(self, data):
        # forward to loss
        outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
        with torch.autocast(device_type=outs["comp_rgb"].device.type, dtype=torch.float32):
            logits_fake = self.model_disc(self.disc_preprocess(outs["comp_rgb"]))
        loss_gen = -torch.mean(logits_fake)
        try:
            if loss < 1e-5:
                d_weight = self.cfg.model.disc.disc_weight
            else:
                nll_loss = loss_pixel
                if nll_loss is None:
                    nll_loss = loss
                d_weight = self.calculate_adaptive_weight(nll_loss, loss_gen,
                                                          last_layer=self.accelerator.unwrap_model(self.model).get_last_layer(),
                                                          discriminator_weight=self.cfg.model.disc.disc_weight)
        except RuntimeError:
            print("[WARN] error in calculate_adaptive_weight, falling back to d_weight = 0.")
            d_weight = torch.tensor(0.0)
        disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start)
        loss += disc_factor * d_weight * loss_gen

        # backward
        self.accelerator.backward(loss)
        if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict

    def forward_to_get_loss(self, data):
        # forward to loss
        outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
        # backward
        self.accelerator.backward(loss)
        if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict

    def forward_disc_loss_local_step(self, pred_img, gt_img):
        # no gradients flow back to the generator: both inputs are detached
        with torch.autocast(device_type=pred_img.device.type, dtype=torch.float32):
            logits_real = self.model_disc(self.disc_preprocess(gt_img).detach())
            logits_fake = self.model_disc(self.disc_preprocess(pred_img).detach())
            loss_disc = hinge_d_loss(logits_real, logits_fake)
        return loss_disc
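
    # For reference, hinge_d_loss (from taming.modules.losses.vqperceptual) is
    # the standard hinge objective:
    #   0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))
    # pushing real logits above +1 and fake logits below -1.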

    def forward_to_get_disc_loss(self, pred_img, gt_img):
        # forward to loss
        loss_disc = self.forward_disc_loss_local_step(pred_img, gt_img)
        disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start)
        loss = disc_factor * loss_disc
        # backward
        self.accelerator.backward(loss)
        if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
            self.accelerator.clip_grad_norm_(self.model_disc.parameters(), self.cfg.train.optim.clip_grad_norm)
        self.optimizer_disc.step()
        self.optimizer_disc.zero_grad()
        return loss_disc

    def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile, iepoch: int):
        self.model.train()
        if self.has_disc:
            self.model_disc.train()

        local_step_losses = []
        global_step_losses = []
        local_step_extra_losses = []
        global_step_extra_losses = []
        extra_loss_keys = []

        logger.debug(f"======== Starting epoch {self.current_epoch} ========")
        loss_disc = None
        for idx, data in enumerate(loader):
            data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
            if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
                data = self.prepare_cross_render_data(data)
                data["real_num_view"] = 1
            else:
                data["real_num_view"] = data["render_image"].shape[1]
logger.debug(f"======== Starting global step {self.global_step} ========") | |
if not self.has_disc: | |
disc_step = False | |
with self.accelerator.accumulate(self.model): | |
outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_to_get_loss(data) | |
# track local losses | |
loss_disc, loss_gen = None, None | |
local_step_losses.append(torch.stack([ | |
_loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device) | |
for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_disc, loss_gen] | |
])) | |
extra_loss_keys = sorted(list(extra_loss_dict.keys())) | |
if len(extra_loss_keys) > 0: | |
local_step_extra_losses.append(torch.stack([ | |
extra_loss_dict[k].detach() if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device) | |
for k in extra_loss_keys | |
])) | |
else: | |
disc_step = (idx % 5) == 0 or (iepoch * len(loader) + idx < 100 and idx % 2 == 0) | |
local_step_losses_bak = torch.zeros(6, device=data["source_rgbs"].device) | |
if not disc_step: | |
with self.accelerator.accumulate(self.model): | |
# generator step | |
outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict = self.forward_to_get_loss_with_gen_loss(data) | |
# track local losses | |
local_step_losses.append(torch.stack([ | |
_loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device) | |
for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc] | |
])) | |
local_step_losses_bak = local_step_losses[-1].detach() | |
torch.cuda.empty_cache() | |
extra_loss_keys = sorted(list(extra_loss_dict.keys())) | |
if len(extra_loss_keys) > 0: | |
local_step_extra_losses.append(torch.stack([ | |
extra_loss_dict[k].detach() if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device) | |
for k in extra_loss_keys | |
])) | |
else: | |
with self.accelerator.accumulate(self.model_disc): | |
# discriminator step | |
outs, _, _, _, _, _, _ = self.forward_loss_local_step(data) | |
loss_disc = self.forward_to_get_disc_loss(pred_img=outs["comp_rgb"], | |
gt_img=data["render_image"]) | |
local_step_losses.append(torch.concat([local_step_losses_bak[:6], loss_disc.unsqueeze(0)], dim=0)) | |
torch.cuda.empty_cache() | |

            # track global step
            if self.accelerator.sync_gradients:
                profiler.step()
                if not disc_step:
                    self.scheduler.step()
                if self.has_disc and disc_step:
                    self.scheduler_disc.step()
                logger.debug("======== Scheduler step ========")
                self.global_step += 1
                global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
                if len(extra_loss_keys) > 0:
                    global_step_extra_loss = self.accelerator.gather(torch.stack(local_step_extra_losses)).mean(dim=0).cpu()
                    global_step_extra_loss_items = global_step_extra_loss.unbind()
                else:
                    global_step_extra_loss = None
                    global_step_extra_loss_items = []
                loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc_ = global_step_loss.unbind()
                loss_kwargs = {
                    'loss': loss.item(),
                    'loss_pixel': loss_pixel.item(),
                    'loss_perceptual': loss_perceptual.item(),
                    'loss_tv': loss_tv.item(),
                    'loss_mask': loss_mask.item(),
                    'loss_disc': loss_disc_.item(),
                    'loss_gen': loss_gen.item(),
                }
                for k, v in zip(extra_loss_keys, global_step_extra_loss_items):
                    loss_kwargs[k] = v.item()
                self.log_scalar_kwargs(
                    step=self.global_step, split='train',
                    **loss_kwargs,
                )
                self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])
                local_step_losses = []
                global_step_losses.append(global_step_loss)
                local_step_extra_losses = []
                global_step_extra_losses.append(global_step_extra_loss)

                # manage display
                pbar.update(1)
                description = {
                    **loss_kwargs,
                    'lr': self.optimizer.param_groups[0]['lr'],
                }
                description = '[TRAIN STEP]' + \
                              ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
                pbar.set_description(description)

                # periodic actions
                if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
                    self.save_checkpoint()
                if self.global_step % self.cfg.val.global_step_period == 0:
                    self.evaluate()
                    self.model.train()
                    if self.has_disc:
                        self.model_disc.train()
                if (self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0) or (self.global_step < 1000 and self.global_step % 20 == 0):
                    conf_sigma_l1 = outs.get('conf_sigma_l1', None)
                    conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
                    conf_sigma_percl = outs.get('conf_sigma_percl', None)
                    conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
                    self.log_image_monitor(
                        step=self.global_step, split='train',
                        renders=outs['comp_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                        conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
                        gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                    )
                    if 'comp_mask' in outs:
                        self.log_image_monitor(
                            step=self.global_step, split='train',
                            renders=outs['comp_mask'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                            gts=data['render_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                            prefix="_mask",
                        )

                # progress control
                if self.global_step >= self.N_max_global_steps:
                    self.accelerator.set_trigger()
                    break

        # track epoch
        self.current_epoch += 1
        epoch_losses = torch.stack(global_step_losses).mean(dim=0)
        # unbind order matches the per-step stacking: ..., loss_gen, loss_disc
        epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv, epoch_loss_mask, epoch_loss_gen, epoch_loss_disc = epoch_losses.unbind()
        epoch_loss_dict = {
            'loss': epoch_loss.item(),
            'loss_pixel': epoch_loss_pixel.item(),
            'loss_perceptual': epoch_loss_perceptual.item(),
            'loss_tv': epoch_loss_tv.item(),
            'loss_mask': epoch_loss_mask.item(),
            'loss_disc': epoch_loss_disc.item(),
            'loss_gen': epoch_loss_gen.item(),
        }
        if len(extra_loss_keys) > 0:
            epoch_extra_losses = torch.stack(global_step_extra_losses).mean(dim=0)
            for k, v in zip(extra_loss_keys, epoch_extra_losses.unbind()):
                epoch_loss_dict[k] = v.item()
        self.log_scalar_kwargs(
            epoch=self.current_epoch, split='train',
            **epoch_loss_dict,
        )
        logger.info(
            f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \
            ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
        )

    def train(self):
        starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
        skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
        logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")

        with tqdm(
            range(0, self.N_max_global_steps),
            initial=self.global_step,
            disable=(not self.accelerator.is_main_process),
        ) as pbar:
            profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                schedule=torch.profiler.schedule(
                    wait=10, warmup=10, active=100,
                ),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(
                    self.cfg.logger.tracker_root,
                    self.cfg.experiment.parent, self.cfg.experiment.child,
                )),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            ) if self.cfg.logger.enable_profiler else DummyProfiler()

            with profiler:
                self.optimizer.zero_grad()
                if self.has_disc:
                    self.optimizer_disc.zero_grad()
                for iepoch in range(self.current_epoch, self.cfg.train.epochs):
                    loader = skipped_loader or self.train_loader
                    skipped_loader = None
                    self.train_epoch(pbar=pbar, loader=loader, profiler=profiler, iepoch=iepoch)
                    if self.accelerator.check_trigger():
                        break
            logger.info(f"======== Training finished at global step {self.global_step} ========")

        # final checkpoint and evaluation
        self.save_checkpoint()
        self.evaluate()

    def evaluate(self, epoch: int = None):
        self.model.eval()
        max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
        running_losses = []
        running_extra_losses = []
        extra_loss_keys = []
        sample_data, sample_outs = None, None

        for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):
            data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
            if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
                data = self.prepare_cross_render_data(data)
                data["real_num_view"] = 1
            else:
                data["real_num_view"] = data["render_image"].shape[1]
            if len(running_losses) >= max_val_batches:
                logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
                break
            outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
            extra_loss_keys = sorted(list(extra_loss_dict.keys()))
            sample_data, sample_outs = data, outs
            running_losses.append(torch.stack([
                _loss if _loss is not None else torch.tensor(float('nan'), device=self.device)
                for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask]
            ]))
            if len(extra_loss_keys) > 0:
                running_extra_losses.append(torch.stack([
                    extra_loss_dict[k] if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device)
                    for k in extra_loss_keys
                ]))

            # log each step
            conf_sigma_l1 = sample_outs.get('conf_sigma_l1', None)
            conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
            conf_sigma_percl = sample_outs.get('conf_sigma_percl', None)
            conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
            self.log_image_monitor_each_process(
                step=self.global_step, split='val',
                renders=sample_outs['comp_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
                prefix=f"_{len(running_losses)}_rank{self.accelerator.process_index}",
            )
            if "comp_mask" in sample_outs:
                self.log_image_monitor_each_process(
                    step=self.global_step, split='val',
                    renders=sample_outs['comp_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                    gts=sample_data['render_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                    prefix=f"_mask_{len(running_losses)}_rank{self.accelerator.process_index}",
                )

        total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
        total_loss, total_loss_pixel, total_loss_perceptual, total_loss_offset, total_loss_mask = total_losses.unbind()
        total_loss_dict = {
            'loss': total_loss.item(),
            'loss_pixel': total_loss_pixel.item(),
            'loss_perceptual': total_loss_perceptual.item(),
            'loss_offset': total_loss_offset.item(),
            'loss_mask': total_loss_mask.item(),
        }
        if len(extra_loss_keys) > 0:
            total_extra_losses = self.accelerator.gather(torch.stack(running_extra_losses)).mean(dim=0).cpu()
            for k, v in zip(extra_loss_keys, total_extra_losses.unbind()):
                total_loss_dict[k] = v.item()

        if epoch is not None:
            self.log_scalar_kwargs(
                epoch=epoch, split='val',
                **total_loss_dict,
            )
            logger.info(
                f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
            )
        else:
            self.log_scalar_kwargs(
                step=self.global_step, split='val',
                **total_loss_dict,
            )
            logger.info(
                f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
            )

    def log_image_monitor_each_process(
        self, epoch: int = None, step: int = None, split: str = None,
        renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
        conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None,
    ):
        M = renders.shape[1]
        if gts.shape[1] != M:
            # symmetric projection doubles the rendered views; duplicate gts to match
            gts = repeat(gts, "b v c h w -> b (v r) c h w", r=2)
        merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:])
        renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:])
        renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)

        log_type, log_progress = self._get_str_progress(epoch, step)
        split = f'/{split}' if split else ''
        split = split + prefix if prefix is not None else split
        log_img_dict = {
            f'Images_split{split}/rendered': renders.unsqueeze(0),
            f'Images_split{split}/gt': gts.unsqueeze(0),
            f'Images_split{split}/merged': merged.unsqueeze(0),
        }
        if conf_sigma_l1 is not None:
            EPS = 1e-7
            vis_conf_l1 = (1 / (1 + conf_sigma_l1.detach() + EPS)).cpu()
            vis_conf_percl = (1 / (1 + conf_sigma_percl.detach() + EPS)).cpu()
            vis_conf_l1 = rearrange(vis_conf_l1, "b v (r c) h w -> (b v r) c h w", r=2)
            vis_conf_percl = rearrange(vis_conf_percl, "b v (r c) h w -> (b v r) c h w", r=2)
            vis_conf_l1 = repeat(vis_conf_l1, "b c1 h w -> b (c1 c2) h w", c2=3)
            vis_conf_percl = repeat(vis_conf_percl, "b c1 h w -> b (c1 c2) h w", c2=3)
            vis_conf_l1, vis_conf_percl = make_grid(vis_conf_l1, nrow=M), make_grid(vis_conf_percl, nrow=M)
            log_img_dict[f'Images_split{split}/conf_l1'] = vis_conf_l1.unsqueeze(0)
            log_img_dict[f'Images_split{split}/conf_percl'] = vis_conf_percl.unsqueeze(0)
        self.log_images_each_process(log_img_dict, log_progress, {"imwrite_image": False})

    def log_image_monitor(
        self, epoch: int = None, step: int = None, split: str = None,
        renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
        conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None,
    ):
        self.log_image_monitor_each_process(epoch, step, split, renders, gts, prefix, conf_sigma_l1, conf_sigma_percl)