# 3DTopia / model / sv_vae_triplane.py
# Author: HongFangzhou (commit bc2085d, "add source codes")
import os
import imageio
import numpy as np
import torch
import torchvision
import torch.nn as nn
import pytorch_lightning as pl
import wandb
import lpips
from pytorch_msssim import SSIM
from utility.initialize import instantiate_from_config
class VAE(pl.LightningModule):
    """Single-view triplane VAE trained with render-space reconstruction losses.

    An encoder image is mapped by the VAE to a triplane latent ``grid``
    (with posterior stats ``mu``/``logvar``); the triplane is rendered along
    both the input-view and target-view rays, and the renders are supervised
    against the corresponding ground-truth images with a weighted sum of
    MSE + (1-SSIM) + LPIPS, plus a KL term on the latent.

    Args:
        vae_configs: config dict consumed by ``instantiate_from_config`` to
            build the triplane VAE.
        renderer_configs: config dict for the triplane renderer.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        kld_weight, mse_weight, lpips_weight, ssim_weight: loss-term weights.
        log_image_freq: log a visualization grid to wandb every N batches.
    """

    def __init__(self, vae_configs, renderer_configs, lr=1e-3, weight_decay=1e-2,
                 kld_weight=1, mse_weight=1, lpips_weight=0.1, ssim_weight=0.1,
                 log_image_freq=50):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.kld_weight = kld_weight
        self.mse_weight = mse_weight
        self.lpips_weight = lpips_weight
        self.ssim_weight = ssim_weight
        self.log_image_freq = log_image_freq
        self.vae = instantiate_from_config(vae_configs)
        self.renderer = instantiate_from_config(renderer_configs)
        # LPIPS is a fixed perceptual metric, not a trainable module: freeze it
        # so its AlexNet weights are excluded from the optimizer below.
        # (Without this, AdamW would update and weight-decay the metric itself.)
        self.lpips_fn = lpips.LPIPS(net='alex')
        self.lpips_fn.requires_grad_(False)
        self.ssim_fn = SSIM(data_range=1, size_average=True, channel=3)
        # Fixed volume-rendering settings passed to the renderer on every call.
        self.triplane_render_kwargs = {
            'depth_resolution': 64,
            'disparity_space_sampling': False,
            'box_warp': 2.4,
            'depth_resolution_importance': 64,
            'clamp_mode': 'softplus',
            'white_back': True,
        }

    def forward(self, batch, is_train):
        """Encode, then render the triplane from input and target views.

        Args:
            batch: 7-tuple of (encoder_img, input_img, input_ray_o,
                input_ray_d, target_img, target_ray_o, target_ray_d).
            is_train: flag forwarded to the VAE (e.g. to toggle sampling).

        Returns:
            (rgb, depth, weights, mu, logvar, render_gt) where the rendered
            tensors and ``render_gt`` stack input-view and target-view results
            along the batch dimension (input first, then target).
        """
        encoder_img, input_img, input_ray_o, input_ray_d, \
            target_img, target_ray_o, target_ray_d = batch
        grid, mu, logvar = self.vae(encoder_img, is_train)
        # Render both views in a single renderer call: duplicate the triplane
        # and concatenate the two ray bundles batch-wise.
        cat_ray_o = torch.cat([input_ray_o, target_ray_o], 0)
        cat_ray_d = torch.cat([input_ray_d, target_ray_d], 0)
        render_out = self.renderer(torch.cat([grid, grid], 0), cat_ray_o, cat_ray_d, self.triplane_render_kwargs)
        render_gt = torch.cat([input_img, target_img], 0)
        return render_out['rgb_marched'], render_out['depth_final'], \
            render_out['weights'], mu, logvar, render_gt

    def calc_loss(self, render, mu, logvar, render_gt):
        """Compute the weighted training loss and its individual terms.

        Assumes ``render``/``render_gt`` are in [0, 1] (LPIPS input is
        rescaled to [-1, 1] accordingly).
        """
        mse = torch.mean((render - render_gt) ** 2)
        ssim_loss = 1 - self.ssim_fn(render, render_gt)
        lpips_loss = self.lpips_fn((render * 2) - 1, (render_gt * 2) - 1).mean()
        # Standard VAE KL divergence against a unit Gaussian prior.
        kld_loss = -0.5 * torch.mean(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), 1))
        loss = self.mse_weight * mse + self.ssim_weight * ssim_loss + \
            self.lpips_weight * lpips_loss + self.kld_weight * kld_loss
        return {
            'loss': loss,
            'mse': mse,
            'ssim': ssim_loss,
            'lpips': lpips_loss,
            'kld': kld_loss,
        }

    # NOTE(review): this overrides LightningModule.log_dict with a different
    # signature; renaming would change the class API, so it is kept but flagged.
    def log_dict(self, loss_dict, prefix):
        """Log every entry of ``loss_dict`` under ``prefix`` (per-step)."""
        for k, v in loss_dict.items():
            self.log(prefix + k, v, on_step=True, logger=True)

    def make_grid(self, render, depth, render_gt):
        """Build a uint8 HxWx3 visualization grid: GT pair, then render/depth pairs.

        Assumes the first half of the batch is the input view and the second
        half the target view (as produced by ``forward``).
        """
        bs = render.shape[0] // 2
        grid = torchvision.utils.make_grid(
            torch.stack([render_gt[0], render_gt[bs], render[0], depth[0], render[bs], depth[bs]], 0))
        # Clamp to [0, 1] before the uint8 cast: rendered values can slightly
        # overshoot and would otherwise wrap around in uint8.
        grid = (grid.detach().cpu().clamp(0, 1).permute(1, 2, 0) * 255.).numpy().astype(np.uint8)
        return grid

    def training_step(self, batch, batch_idx):
        """One optimization step: forward, loss, scalar logs, periodic images."""
        render, depth, weights, mu, logvar, render_gt = self.forward(batch, True)
        loss_dict = self.calc_loss(render, mu, logvar, render_gt)
        self.log_dict(loss_dict, 'train/')
        if batch_idx % self.log_image_freq == 0:
            self.logger.experiment.log({
                'train/vis': [wandb.Image(self.make_grid(
                    render, depth, render_gt
                ))]
            })
        return loss_dict['loss']

    def validation_step(self, batch, batch_idx):
        """Validation step: same as training but without returning a loss."""
        render, depth, _, mu, logvar, render_gt = self.forward(batch, False)
        loss_dict = self.calc_loss(render, mu, logvar, render_gt)
        self.log_dict(loss_dict, 'val/')
        if batch_idx % self.log_image_freq == 0:
            self.logger.experiment.log({
                'val/vis': [wandb.Image(self.make_grid(
                    render, depth, render_gt
                ))]
            })

    def configure_optimizers(self):
        """AdamW over trainable parameters only (excludes the frozen LPIPS net)."""
        params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)
        return optimizer