|
import os |
|
import imageio |
|
import torch |
|
import wandb |
|
import numpy as np |
|
import pytorch_lightning as pl |
|
import torch.nn.functional as F |
|
|
|
from module.model_2d import Encoder, Decoder, DiagonalGaussianDistribution, Encoder_GroupConv, Decoder_GroupConv, Encoder_GroupConv_LateFusion, Decoder_GroupConv_LateFusion |
|
from utility.initialize import instantiate_from_config |
|
from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr |
|
from utility.triplane_renderer.eg3d_renderer import Renderer_TriPlane |
|
|
|
class AutoencoderKL(pl.LightningModule): |
|
def __init__(self, |
|
ddconfig, |
|
lossconfig, |
|
embed_dim, |
|
learning_rate=1e-3, |
|
ckpt_path=None, |
|
ignore_keys=[], |
|
colorize_nlabels=None, |
|
monitor=None, |
|
decoder_ckpt=None, |
|
norm=False, |
|
renderer_type='nerf', |
|
renderer_config=dict( |
|
rgbnet_dim=18, |
|
rgbnet_width=128, |
|
viewpe=0, |
|
feape=0 |
|
), |
|
): |
|
super().__init__() |
|
self.save_hyperparameters() |
|
self.norm = norm |
|
self.renderer_config = renderer_config |
|
self.learning_rate = learning_rate |
|
self.encoder = Encoder(**ddconfig) |
|
self.decoder = Decoder(**ddconfig) |
|
|
|
self.lossconfig = lossconfig |
|
assert ddconfig["double_z"] |
|
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) |
|
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) |
|
|
|
self.embed_dim = embed_dim |
|
if colorize_nlabels is not None: |
|
assert type(colorize_nlabels)==int |
|
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) |
|
if monitor is not None: |
|
self.monitor = monitor |
|
if ckpt_path is not None: |
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) |
|
|
|
self.decoder_ckpt = decoder_ckpt |
|
self.renderer_type = renderer_type |
|
|
|
assert self.renderer_type in ['nerf', 'eg3d'] |
|
if self.renderer_type == 'nerf': |
|
self.triplane_decoder, self.triplane_render_kwargs = self.create_nerf(decoder_ckpt) |
|
elif self.renderer_type == 'eg3d': |
|
self.triplane_decoder, self.triplane_render_kwargs = self.create_eg3d_decoder(decoder_ckpt) |
|
else: |
|
raise NotImplementedError |
|
|
|
self.psum = torch.zeros([1]) |
|
self.psum_sq = torch.zeros([1]) |
|
self.psum_min = torch.zeros([1]) |
|
self.psum_max = torch.zeros([1]) |
|
self.count = 0 |
|
self.len_dset = 0 |
|
self.latent_list = [] |
|
|
|
def init_from_ckpt(self, path, ignore_keys=list()): |
|
sd = torch.load(path, map_location="cpu")["state_dict"] |
|
keys = list(sd.keys()) |
|
for k in keys: |
|
for ik in ignore_keys: |
|
if k.startswith(ik): |
|
print("Deleting key {} from state_dict.".format(k)) |
|
del sd[k] |
|
self.load_state_dict(sd, strict=False) |
|
print(f"Restored from {path}") |
|
|
|
def encode(self, x, rollout=False): |
|
h = self.encoder(x) |
|
moments = self.quant_conv(h) |
|
posterior = DiagonalGaussianDistribution(moments) |
|
return posterior |
|
|
|
def decode(self, z, unrollout=False): |
|
z = self.post_quant_conv(z) |
|
dec = self.decoder(z) |
|
return dec |
|
|
|
def forward(self, input, sample_posterior=True): |
|
posterior = self.encode(input) |
|
if sample_posterior: |
|
z = posterior.sample() |
|
else: |
|
z = posterior.mode() |
|
dec = self.decode(z) |
|
return dec, posterior |
|
|
|
def unrollout(self, *args, **kwargs): |
|
pass |
|
|
|
    def loss(self, inputs, reconstructions, posteriors, prefix, batch=None):
        """Reconstruction + KL loss with optional TV/L1/render regularisers.

        Args:
            inputs: target triplane batch.
            reconstructions: decoder output, same shape as ``inputs``.
            posteriors: DiagonalGaussianDistribution returned by ``encode``.
            prefix: log-key prefix such as 'train/' or 'val/'.
            batch: full batch dict; required only when ``render_weight`` > 0
                (supplies 'batch_rays' and 'img').

        Returns:
            (total_loss, dict_of_logged_terms)
        """
        reconstructions = reconstructions.contiguous()
        # L1 reconstruction error, summed over elements, averaged over batch.
        rec_loss = torch.abs(inputs.contiguous() - reconstructions)
        rec_loss = torch.sum(rec_loss) / rec_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss

        ret_dict = {
            prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(),
            prefix+'rec_loss': rec_loss,
            prefix+'kl_loss': kl_loss,
            prefix+'loss': loss,
            prefix+'mean': posteriors.mean.mean(),
            prefix+'logvar': posteriors.logvar.mean(),
        }

        # Optional regulariser weights; all default to disabled (0).
        render_weight = self.lossconfig.get("render_weight", 0)
        tv_weight = self.lossconfig.get("tv_weight", 0)
        l1_weight = self.lossconfig.get("l1_weight", 0)
        latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0)
        latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0)

        # Subclass hook: maps rolled-out planes back to the stacked layout.
        # NOTE(review): in this base class unrollout is a no-op returning
        # None, so the branches below assume a subclass override whenever
        # any of these weights is non-zero.
        triplane_rec = self.unrollout(reconstructions)
        if render_weight > 0 and batch is not None:
            # Photometric loss on a random subset of rendered pixels.
            rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img'])
            render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256
            loss += render_weight * render_loss
            ret_dict[prefix + 'render_loss'] = render_loss
        if tv_weight > 0:
            # Total variation along both spatial axes of the reconstruction.
            tvloss_y = torch.abs(triplane_rec[:, :, :-1] - triplane_rec[:, :, 1:]).sum() / triplane_rec.shape[0]
            tvloss_x = torch.abs(triplane_rec[:, :, :, :-1] - triplane_rec[:, :, :, 1:]).sum() / triplane_rec.shape[0]
            tvloss = tvloss_y + tvloss_x
            loss += tv_weight * tvloss
            ret_dict[prefix + 'tv_loss'] = tvloss
        if l1_weight > 0:
            # NOTE(review): named 'l1' but this is a squared (L2) magnitude
            # penalty on the reconstruction.
            l1 = (triplane_rec ** 2).sum() / triplane_rec.shape[0]
            loss += l1_weight * l1
            ret_dict[prefix + 'l1_loss'] = l1
        if latent_tv_weight > 0:
            # Smoothness penalty on the posterior mean.
            latent = posteriors.mean
            latent_tv_y = torch.abs(latent[:, :, :-1] - latent[:, :, 1:]).sum() / latent.shape[0]
            latent_tv_x = torch.abs(latent[:, :, :, :-1] - latent[:, :, :, 1:]).sum() / latent.shape[0]
            latent_tv_loss = latent_tv_y + latent_tv_x
            loss += latent_tv_loss * latent_tv_weight
            ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss
            ret_dict[prefix + 'latent_max'] = latent.max()
            ret_dict[prefix + 'latent_min'] = latent.min()
        if latent_l1_weight > 0:
            # Squared-magnitude penalty on the posterior mean (again
            # labelled l1 for historical reasons).
            latent = posteriors.mean
            latent_l1_loss = (latent ** 2).sum() / latent.shape[0]
            loss += latent_l1_loss * latent_l1_weight
            ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss

        return loss, ret_dict
|
|
|
def training_step(self, batch, batch_idx): |
|
|
|
inputs = batch['triplane'] |
|
reconstructions, posterior = self(inputs) |
|
|
|
|
|
|
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') |
|
|
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
|
return aeloss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def validation_step(self, batch, batch_idx):
        """Log AE losses and render-based PSNR metrics for one val batch.

        For each sample the ground-truth triplane ('triplane_ori') and the
        reconstruction are both rendered and compared against the reference
        images; every 4th sample of the first 10 batches is also logged to
        wandb as a side-by-side image.
        """
        inputs = batch['triplane']
        # Deterministic reconstruction (posterior mode) for evaluation.
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
        self.log_dict(log_dict_ae)

        # Latent normalisation is not supported in this evaluation path.
        assert not self.norm
        psnr_list = []          # reconstruction render vs ground-truth image
        psnr_input_list = []    # input-triplane render vs ground-truth image
        psnr_rec_list = []      # reconstruction render vs input render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            # PSNR between the two renders (input render as reference).
            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if b % 4 == 0 and batch_idx < 10:
                # Side-by-side: ground truth | input render | reconstruction
                # (view index 1 of each).
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        # NOTE(review): returns the bound log_dict *method*, not a dict.
        # Lightning ignores validation_step return values, so it is harmless.
        return self.log_dict
|
|
|
def create_eg3d_decoder(self, decoder_ckpt): |
|
triplane_decoder = Renderer_TriPlane(**self.renderer_config) |
|
if decoder_ckpt is not None: |
|
pretrain_pth = torch.load(decoder_ckpt, map_location='cpu') |
|
pretrain_pth = { |
|
'.'.join(k.split('.')[1:]): v for k, v in pretrain_pth.items() |
|
} |
|
triplane_decoder.load_state_dict(pretrain_pth) |
|
render_kwargs = { |
|
'depth_resolution': 128, |
|
'disparity_space_sampling': False, |
|
'box_warp': 2.4, |
|
'depth_resolution_importance': 128, |
|
'clamp_mode': 'softplus', |
|
'white_back': True, |
|
'det': True |
|
} |
|
return triplane_decoder, render_kwargs |
|
|
|
def render_triplane_eg3d_decoder(self, triplane, batch_rays, target): |
|
ray_o = batch_rays[:, 0] |
|
ray_d = batch_rays[:, 1] |
|
psnr_list = [] |
|
rec_img_list = [] |
|
res = triplane.shape[-2] |
|
for i in range(ray_o.shape[0]): |
|
with torch.no_grad(): |
|
render_out = self.triplane_decoder(triplane.reshape(1, 3, -1, res, res), |
|
ray_o[i:i+1], ray_d[i:i+1], self.triplane_render_kwargs, whole_img=True, tvloss=False) |
|
rec_img = render_out['rgb_marched'].permute(0, 2, 3, 1) |
|
psnr = mse2psnr(img2mse(rec_img[0], target[i])) |
|
psnr_list.append(psnr) |
|
rec_img_list.append(rec_img) |
|
return torch.cat(rec_img_list, 0), psnr_list |
|
|
|
    def render_triplane_eg3d_decoder_sample_pixel(self, triplane, batch_rays, target, sample_num=1024):
        """Render a random subset of pixels (used by the render loss).

        Args:
            triplane: (B, 3*C, H, W) reconstruction batch.
            batch_rays: ray bundle with a singleton second axis; origins and
                directions are taken from index 0/1 of the third axis —
                assumed layout, TODO confirm against the dataloader.
            target: reference images, reshapeable to (B, N, 3).
            sample_num: number of rays sampled per batch element.

        Returns:
            (rendered_rgb, matching_target_pixels)
        """
        assert batch_rays.shape[1] == 1
        # One shared random pixel subset for all batch elements.
        sel = torch.randint(batch_rays.shape[-2], [sample_num])
        ray_o = batch_rays[:, 0, 0, sel]
        ray_d = batch_rays[:, 0, 1, sel]
        res = triplane.shape[-2]
        render_out = self.triplane_decoder(triplane.reshape(triplane.shape[0], 3, -1, res, res),
                        ray_o, ray_d, self.triplane_render_kwargs, whole_img=False, tvloss=False)
        rec_img = render_out['rgb_marched']
        # Pick the same pixel subset out of the flattened target images.
        target = target.reshape(triplane.shape[0], -1, 3)[:, sel, :]
        return rec_img, target
|
|
|
def create_nerf(self, decoder_ckpt): |
|
|
|
|
|
multires = 10 |
|
netchunk = 1024*64 |
|
i_embed = 0 |
|
perturb = 0 |
|
raw_noise_std = 0 |
|
|
|
triplanechannel=18 |
|
triplanesize=256 |
|
chunk=4096 |
|
num_instance=1 |
|
batch_size=1 |
|
use_viewdirs = True |
|
white_bkgd = False |
|
lrate_decay = 6 |
|
netdepth=1 |
|
netwidth=64 |
|
N_samples = 512 |
|
N_importance = 0 |
|
N_rand = 8192 |
|
multires_views=10 |
|
precrop_iters = 0 |
|
precrop_frac = 0.5 |
|
i_weights=3000 |
|
|
|
embed_fn, input_ch = get_embedder(multires, i_embed) |
|
embeddirs_fn, input_ch_views = get_embedder(multires_views, i_embed) |
|
output_ch = 4 |
|
skips = [4] |
|
model = NeRF(D=netdepth, W=netwidth, |
|
input_ch=triplanechannel, size=triplanesize,output_ch=output_ch, skips=skips, |
|
input_ch_views=input_ch_views, use_viewdirs=use_viewdirs, num_instance=num_instance) |
|
|
|
network_query_fn = lambda inputs, viewdirs, label,network_fn : \ |
|
run_network(inputs, viewdirs, network_fn, |
|
embed_fn=embed_fn, |
|
embeddirs_fn=embeddirs_fn,label=label, |
|
netchunk=netchunk) |
|
|
|
ckpt = torch.load(decoder_ckpt) |
|
model.load_state_dict(ckpt['network_fn_state_dict']) |
|
|
|
render_kwargs_test = { |
|
'network_query_fn' : network_query_fn, |
|
'perturb' : perturb, |
|
'N_samples' : N_samples, |
|
|
|
'use_viewdirs' : use_viewdirs, |
|
'white_bkgd' : white_bkgd, |
|
'raw_noise_std' : raw_noise_std, |
|
} |
|
render_kwargs_test['ndc'] = False |
|
render_kwargs_test['lindisp'] = False |
|
render_kwargs_test['perturb'] = False |
|
render_kwargs_test['raw_noise_std'] = 0. |
|
|
|
return model, render_kwargs_test |
|
|
|
    def render_triplane(self, triplane, batch_rays, target, near, far, chunk=4096):
        """Render views from ``triplane`` with the NeRF decoder (no grad).

        Args:
            triplane: single triplane batch to install into the decoder.
            batch_rays: ray bundles for the views to render.
            target: ground-truth images for PSNR computation.
            near, far: per-view ray bounds.
            chunk: rays processed per forward pass.

        Returns:
            (rendered_images, per_view_psnr_list)
        """
        # Install the triplane to render into the shared NeRF decoder's
        # parameter buffer, and point the render kwargs at that decoder.
        self.triplane_decoder.tri_planes.copy_(triplane.detach())
        self.triplane_render_kwargs['network_fn'] = self.triplane_decoder

        with torch.no_grad():
            # label 0: the decoder was built with num_instance=1.
            rgb, _, _, psnr_list = \
                render_path1(batch_rays, chunk, self.triplane_render_kwargs, gt_imgs=target,
                             near=near, far=far, label=torch.Tensor([0]).long().to(triplane.device))
        return rgb, psnr_list
|
|
|
def to_rgb(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize"): |
|
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def to_rgb_triplane(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize_triplane"): |
|
self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
    def test_step(self, batch, batch_idx):
        """Evaluate reconstruction quality and accumulate latent statistics.

        Logs AE losses, writes colorized latent/triplane images for the
        first 10 batches, updates the running latent statistics reported by
        ``on_test_epoch_end``, and renders input vs reconstructed triplanes
        to compute PSNR against the ground-truth images.
        """
        inputs = batch['triplane']
        # Deterministic reconstruction; render loss disabled (batch=None).
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        psnr_list = []          # reconstruction render vs ground-truth image
        psnr_input_list = []    # input-triplane render vs ground-truth image
        psnr_rec_list = []      # reconstruction render vs input render

        # Visualise the deterministic latent and both triplanes.
        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]

        if batch_idx < 10:
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)

        # Collect the raw latent for the IQR/median statistics at epoch end.
        np_z = z.detach().cpu().numpy()
        self.latent_list.append(np_z)

        # Lazily move the running-statistic buffers to the latent's device.
        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        # Per-batch global min/max are *summed*; epoch end divides by
        # len_dset to get the mean per-batch extremes.
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        if self.norm:
            # NOTE(review): `assert NotImplementedError` asserts a truthy
            # class object and therefore always passes — this was almost
            # certainly meant to be `raise NotImplementedError`.
            assert NotImplementedError
        else:
            reconstructions_unnormalize = reconstructions

        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            # PSNR between the two renders (input render as reference).
            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if batch_idx < 10:
                # Save view index 1 of each render for visual inspection.
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
|
|
|
def configure_optimizers(self): |
|
lr = self.learning_rate |
|
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ |
|
list(self.decoder.parameters())+ |
|
list(self.quant_conv.parameters())+ |
|
list(self.post_quant_conv.parameters()), |
|
lr=lr, betas=(0.5, 0.9)) |
|
|
|
|
|
|
|
return opt_ae |
|
|
|
def on_test_epoch_end(self): |
|
mean = self.psum / self.count |
|
mean_min = self.psum_min / self.len_dset |
|
mean_max = self.psum_max / self.len_dset |
|
var = (self.psum_sq / self.count) - (mean ** 2) |
|
std = torch.sqrt(var) |
|
|
|
print("mean min: {}".format(mean_min)) |
|
print("mean max: {}".format(mean_max)) |
|
print("mean: {}".format(mean)) |
|
print("std: {}".format(std)) |
|
|
|
latent = np.concatenate(self.latent_list) |
|
q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) |
|
median = np.median(latent.reshape(-1)) |
|
iqr = q75 - q25 |
|
norm_iqr = iqr * 0.7413 |
|
print("Norm IQR: {}".format(norm_iqr)) |
|
print("Inverse Norm IQR: {}".format(1/norm_iqr)) |
|
print("Median: {}".format(median)) |
|
|
|
|
|
class AutoencoderKLRollOut(AutoencoderKL): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.psum = torch.zeros([1]) |
|
self.psum_sq = torch.zeros([1]) |
|
self.psum_min = torch.zeros([1]) |
|
self.psum_max = torch.zeros([1]) |
|
self.count = 0 |
|
self.len_dset = 0 |
|
|
|
def rollout(self, triplane): |
|
res = triplane.shape[-1] |
|
ch = triplane.shape[1] |
|
triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
|
return triplane |
|
|
|
def unrollout(self, triplane): |
|
res = triplane.shape[-2] |
|
ch = 3 * triplane.shape[1] |
|
triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
|
return triplane |
|
|
|
def encode(self, x, rollout=False): |
|
if rollout: |
|
x = self.rollout(x) |
|
h = self.encoder(x) |
|
moments = self.quant_conv(h) |
|
posterior = DiagonalGaussianDistribution(moments) |
|
return posterior |
|
|
|
def decode(self, z, unrollout=False): |
|
z = self.post_quant_conv(z) |
|
dec = self.decoder(z) |
|
if unrollout: |
|
dec = self.unrollout(dec) |
|
return dec |
|
|
|
def forward(self, input, sample_posterior=True): |
|
posterior = self.encode(input) |
|
if sample_posterior: |
|
z = posterior.sample() |
|
else: |
|
z = posterior.mode() |
|
dec = self.decode(z) |
|
return dec, posterior |
|
|
|
def training_step(self, batch, batch_idx): |
|
inputs = self.rollout(batch['triplane']) |
|
reconstructions, posterior = self(inputs) |
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') |
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
|
return aeloss |
|
|
|
    def validation_step(self, batch, batch_idx):
        """Validation on rolled-out triplanes with render-based PSNR metrics.

        The AE loss is computed in the rolled-out layout; the reconstruction
        is then unrolled back to stacked planes for rendering against the
        ground-truth images.
        """
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
        self.log_dict(log_dict_ae)

        assert not self.norm
        # Back to stacked-plane layout for the renderers.
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []          # reconstruction render vs ground-truth image
        psnr_input_list = []    # input-triplane render vs ground-truth image
        psnr_rec_list = []      # reconstruction render vs input render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if b % 4 == 0 and batch_idx < 10:
                # Side-by-side: ground truth | input render | reconstruction.
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        # NOTE(review): returns the bound log_dict method, not a dict;
        # Lightning ignores this return value.
        return self.log_dict
|
|
|
def to_rgb(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize"): |
|
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def to_rgb_triplane(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize_triplane"): |
|
self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
    def test_step(self, batch, batch_idx):
        """Accumulate latent statistics on rolled-out triplanes.

        Logs the AE loss, computes colorized visualisations, updates the
        running latent statistics used by ``on_test_epoch_end``, and
        un-normalises the reconstruction. The PSNR lists and un-normalised
        reconstruction are prepared, but the rendering comparison itself is
        disabled in this version of the method.
        """
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        # Prepared for a rendering comparison that is disabled below.
        psnr_list = []
        psnr_input_list = []
        psnr_rec_list = []

        z = posterior.mode()
        # NOTE(review): these visualisations are computed but never written
        # out in this subclass.
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]

        # Back to stacked-plane layout.
        reconstructions = self.unrollout(reconstructions)

        # Running latent statistics (see on_test_epoch_end).
        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        # Hard-coded per-channel statistics (presumably of the training
        # triplanes — TODO confirm provenance) used to undo normalisation
        # when self.norm is set.
        mean = torch.Tensor([
            -1.8449, -1.8242, 0.9667, -1.0187, 1.0647, -0.5422, -1.8632, -1.8435,
            0.9314, -1.0261, 1.0356, -0.5484, -1.8543, -1.8348, 0.9109, -1.0169,
            1.0160, -0.5467
        ]).reshape(1, 18, 1, 1).to(inputs.device)
        std = torch.Tensor([
            1.7593, 1.6127, 2.7132, 1.5500, 2.7893, 0.7707, 2.1114, 1.9198, 2.6586,
            1.8021, 2.5473, 1.0305, 1.7042, 1.7507, 2.4270, 1.4365, 2.2511, 0.8792
        ]).reshape(1, 18, 1, 1).to(inputs.device)

        if self.norm:
            reconstructions_unnormalize = reconstructions * std + mean
        else:
            # NOTE(review): unused in the visible code path (the rendering
            # block that consumed it has been removed).
            reconstructions_unnormalize = reconstructions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_test_epoch_end(self): |
|
mean = self.psum / self.count |
|
mean_min = self.psum_min / self.len_dset |
|
mean_max = self.psum_max / self.len_dset |
|
var = (self.psum_sq / self.count) - (mean ** 2) |
|
std = torch.sqrt(var) |
|
|
|
print("mean min: {}".format(mean_min)) |
|
print("mean max: {}".format(mean_max)) |
|
print("mean: {}".format(mean)) |
|
print("std: {}".format(std)) |
|
|
|
|
|
class AutoencoderKLRollOut3DAware(AutoencoderKL): |
|
def __init__(self, *args, **kwargs): |
|
try: |
|
ckpt_path = kwargs['ckpt_path'] |
|
kwargs['ckpt_path'] = None |
|
except: |
|
ckpt_path = None |
|
|
|
super().__init__(*args, **kwargs) |
|
self.psum = torch.zeros([1]) |
|
self.psum_sq = torch.zeros([1]) |
|
self.psum_min = torch.zeros([1]) |
|
self.psum_max = torch.zeros([1]) |
|
self.count = 0 |
|
self.len_dset = 0 |
|
|
|
ddconfig = kwargs['ddconfig'] |
|
ddconfig['z_channels'] *= 3 |
|
del self.decoder |
|
del self.post_quant_conv |
|
self.decoder = Decoder(**ddconfig) |
|
self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) |
|
|
|
if ckpt_path is not None: |
|
self.init_from_ckpt(ckpt_path) |
|
|
|
def rollout(self, triplane): |
|
res = triplane.shape[-1] |
|
ch = triplane.shape[1] |
|
triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
|
return triplane |
|
|
|
    def to3daware(self, triplane):
        """Augment each plane with pooled features from the other two planes.

        Input is a rolled-out triplane (B, C, H, 3H). For every plane,
        max-pooled row/column summaries of the other two planes are
        broadcast back to (H, H) and concatenated on the channel axis,
        producing (B, 3C, H, 3H). The flips/permutes align the pooled axes
        with the axes each plane pair shares in 3D (axis convention assumed
        from the training data — TODO confirm).
        """
        res = triplane.shape[-2]
        # Split the rolled-out strip back into the three planes.
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        # x_mp collapses the height axis, y_mp the width axis; the *_rep
        # helpers broadcast the pooled strip back to res x res and transpose
        # it into the receiving plane's orientation.
        x_mp = torch.nn.MaxPool2d((res, 1))
        y_mp = torch.nn.MaxPool2d((1, res))
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        # Plane 1 receives summaries of planes 2 and 3 (plane 3 flipped on
        # the width axis to match orientation).
        plane21 = x_mp_rep(plane2)
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        # Plane 2 receives summaries of planes 1 and 3.
        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        # Plane 3 receives summaries of planes 1 (flipped on height) and 2.
        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        # Re-concatenate along width to restore the rolled-out layout.
        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane
|
|
|
def unrollout(self, triplane): |
|
res = triplane.shape[-2] |
|
ch = 3 * triplane.shape[1] |
|
triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
|
return triplane |
|
|
|
def encode(self, x, rollout=False): |
|
if rollout: |
|
x = self.to3daware(self.rollout(x)) |
|
h = self.encoder(x) |
|
moments = self.quant_conv(h) |
|
posterior = DiagonalGaussianDistribution(moments) |
|
return posterior |
|
|
|
def decode(self, z, unrollout=False): |
|
z = self.to3daware(z) |
|
z = self.post_quant_conv(z) |
|
dec = self.decoder(z) |
|
if unrollout: |
|
dec = self.unrollout(dec) |
|
return dec |
|
|
|
def forward(self, input, sample_posterior=True): |
|
posterior = self.encode(input) |
|
if sample_posterior: |
|
z = posterior.sample() |
|
else: |
|
z = posterior.mode() |
|
dec = self.decode(z) |
|
return dec, posterior |
|
|
|
def training_step(self, batch, batch_idx): |
|
inputs = self.rollout(batch['triplane']) |
|
reconstructions, posterior = self(self.to3daware(inputs)) |
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) |
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
|
return aeloss |
|
|
|
    def validation_step(self, batch, batch_idx):
        """Validation with 3D-aware encoding and render-based PSNR metrics.

        The encoder input carries the cross-plane augmentation; the loss
        compares against the plain rollout, and the reconstruction is
        unrolled back to stacked planes for rendering.
        """
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
        self.log_dict(log_dict_ae)

        assert not self.norm
        # Back to stacked-plane layout for the renderers.
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []          # reconstruction render vs ground-truth image
        psnr_input_list = []    # input-triplane render vs ground-truth image
        psnr_rec_list = []      # reconstruction render vs input render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if b % 4 == 0 and batch_idx < 10:
                # Side-by-side: ground truth | input render | reconstruction.
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        # NOTE(review): returns the bound log_dict method, not a dict;
        # Lightning ignores this return value.
        return self.log_dict
|
|
|
def to_rgb(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize"): |
|
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def to_rgb_triplane(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize_triplane"): |
|
self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def to_rgb_3daware(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize_3daware"): |
|
self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def test_step(self, batch, batch_idx): |
|
inputs = self.rollout(batch['triplane']) |
|
reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) |
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) |
|
self.log_dict(log_dict_ae) |
|
|
|
batch_size = inputs.shape[0] |
|
psnr_list = [] |
|
psnr_input_list = [] |
|
psnr_rec_list = [] |
|
|
|
z = posterior.mode() |
|
colorize_z = self.to_rgb(z)[0] |
|
colorize_triplane_input = self.to_rgb_triplane(inputs)[0] |
|
colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] |
|
colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] |
|
res = inputs.shape[1] |
|
colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] |
|
colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] |
|
if batch_idx < 10: |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) |
|
|
|
reconstructions = self.unrollout(reconstructions) |
|
|
|
if self.psum.device != z.device: |
|
self.psum = self.psum.to(z.device) |
|
self.psum_sq = self.psum_sq.to(z.device) |
|
self.psum_min = self.psum_min.to(z.device) |
|
self.psum_max = self.psum_max.to(z.device) |
|
self.psum += z.sum() |
|
self.psum_sq += (z ** 2).sum() |
|
self.psum_min += z.reshape(-1).min(-1)[0] |
|
self.psum_max += z.reshape(-1).max(-1)[0] |
|
assert len(z.shape) == 4 |
|
self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] |
|
self.len_dset += 1 |
|
|
|
if self.norm: |
|
assert NotImplementedError |
|
else: |
|
reconstructions_unnormalize = reconstructions |
|
|
|
for b in range(batch_size): |
|
if self.renderer_type == 'nerf': |
|
rgb_input, cur_psnr_list_input = self.render_triplane( |
|
batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
|
batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
|
) |
|
rgb, cur_psnr_list = self.render_triplane( |
|
reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
|
batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
|
) |
|
elif self.renderer_type == 'eg3d': |
|
rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
|
batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
|
) |
|
rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
|
reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
cur_psnr_list_rec = [] |
|
for i in range(rgb.shape[0]): |
|
cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
|
rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
|
rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
|
rgb = to8b(rgb.detach().cpu().numpy()) |
|
|
|
if batch_idx < 10: |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) |
|
|
|
psnr_list += cur_psnr_list |
|
psnr_input_list += cur_psnr_list_input |
|
psnr_rec_list += cur_psnr_list_rec |
|
|
|
self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
|
self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
|
self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
|
def on_test_epoch_end(self): |
|
mean = self.psum / self.count |
|
mean_min = self.psum_min / self.len_dset |
|
mean_max = self.psum_max / self.len_dset |
|
var = (self.psum_sq / self.count) - (mean ** 2) |
|
std = torch.sqrt(var) |
|
|
|
print("mean min: {}".format(mean_min)) |
|
print("mean max: {}".format(mean_max)) |
|
print("mean: {}".format(mean)) |
|
print("std: {}".format(std)) |
|
|
|
|
|
class AutoencoderKLRollOut3DAwareOnlyInput(AutoencoderKL):
    """AutoencoderKL over rolled-out triplanes where the 3D-aware
    (cross-plane max-pooled) augmentation is fed to the encoder ONLY;
    the decoder reconstructs the plain rollout. Latent statistics are
    accumulated during testing for later normalization.
    """

    def __init__(self, *args, **kwargs):
        # Strip ckpt_path before the parent runs so weights are loaded only
        # after this subclass has registered its extra state.
        # (Was a bare `except:`; narrowed to the exception actually expected.)
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except KeyError:
            ckpt_path = None

        super().__init__(*args, **kwargs)
        # Running sums used by test_step / on_test_epoch_end to estimate the
        # latent distribution statistics over the test set.
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0
        self.len_dset = 0

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def rollout(self, triplane):
        """(B, C, H, W) stacked triplane -> (B, C/3, H, 3*W) side-by-side."""
        res = triplane.shape[-1]
        ch = triplane.shape[1]
        triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
        return triplane

    def to3daware(self, triplane):
        """Concat each plane with max-pooled summaries of the other two
        planes, tripling the channel count of the rollout."""
        res = triplane.shape[-2]
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        x_mp = torch.nn.MaxPool2d((res, 1))
        y_mp = torch.nn.MaxPool2d((1, res))
        # Pool one spatial axis away, broadcast back, then swap the spatial
        # axes so the summary aligns with the receiving plane's layout.
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        plane21 = x_mp_rep(plane2)
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane

    def unrollout(self, triplane):
        """Inverse of rollout: (B, C/3, H, 3*W) -> (B, C, H, W)."""
        res = triplane.shape[-2]
        ch = 3 * triplane.shape[1]
        triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
        return triplane

    def encode(self, x, rollout=False):
        """Encode x to a diagonal-Gaussian posterior. With rollout=True, x is
        a stacked triplane and is rolled out + 3D-aware-augmented first."""
        if rollout:
            x = self.to3daware(self.rollout(x))
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, unrollout=False):
        """Decode a latent to a rolled-out triplane; optionally fold it back
        to the stacked (B, C, H, W) layout."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if unrollout:
            dec = self.unrollout(dec)
        return dec

    def forward(self, input, sample_posterior=True):
        """Return (reconstruction, posterior). `input` must already be the
        3D-aware rollout; the reconstruction is the plain rollout."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def training_step(self, batch, batch_idx):
        """One optimization step; the encoder sees the 3D-aware input while
        the loss compares the reconstruction to the plain rollout."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs))
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Reconstruct, render GT and reconstructed triplanes, and log PSNR
        metrics plus occasional wandb visualizations."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
        self.log_dict(log_dict_ae)

        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []        # rendered reconstruction vs. GT image
        psnr_input_list = []  # rendered input triplane vs. GT image
        psnr_rec_list = []    # rendered input vs. rendered reconstruction
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if b % 4 == 0 and batch_idx < 10:
                # Log view index 1 of GT / input-render / reconstruction-render
                # side by side.
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict

    def to_rgb(self, plane):
        """Colorize a latent map as uint8 RGB via a lazily created fixed
        random 1x1 conv, min-max scaled to [0, 255]."""
        x = plane.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_triplane(self, plane):
        """Colorize a triplane feature map as uint8 RGB (separate cached
        projection from to_rgb)."""
        x = plane.float()
        if not hasattr(self, "colorize_triplane"):
            self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def test_step(self, batch, batch_idx):
        """Evaluate reconstructions, dump visualizations for early batches,
        and accumulate latent statistics for on_test_epoch_end."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        psnr_list = []
        psnr_input_list = []
        psnr_rec_list = []

        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
        if batch_idx < 10:
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)

        reconstructions = self.unrollout(reconstructions)

        # Keep the running statistics on the same device as the latents.
        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        # NOTE(review): accumulates per-batch extrema; the epoch-end values
        # are means of batch minima/maxima, not global extrema.
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        if self.norm:
            # Was `assert NotImplementedError` — a no-op (the exception class
            # is truthy) that later crashed with a NameError; fail loudly.
            raise NotImplementedError
        reconstructions_unnormalize = reconstructions

        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if batch_idx < 10:
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

    def on_test_epoch_end(self):
        """Print latent statistics accumulated during testing."""
        mean = self.psum / self.count
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)

        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))
|
|
|
|
|
class AutoencoderKLRollOut3DAwareMeanPool(AutoencoderKL):
    """AutoencoderKL over rolled-out triplanes with mean-pooled (AvgPool)
    3D-aware augmentation applied both on the encoder input and on the
    latent before decoding; the decoder therefore consumes 3x channels.
    """

    def __init__(self, *args, **kwargs):
        # Strip ckpt_path before the parent runs so weights are loaded only
        # after the widened decoder below exists.
        # (Was a bare `except:`; narrowed to the exception actually expected.)
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except KeyError:
            ckpt_path = None

        super().__init__(*args, **kwargs)
        # Running sums used by test_step / on_test_epoch_end to estimate the
        # latent distribution statistics over the test set.
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0
        self.len_dset = 0

        # Copy before widening z_channels: the original tripled the value in
        # the caller's dict in place, which also corrupted the config saved
        # by save_hyperparameters (re-loading would triple it again).
        ddconfig = dict(kwargs['ddconfig'])
        ddconfig['z_channels'] *= 3
        # The decoder receives the 3D-aware latent (3x channels).
        self.decoder = Decoder(**ddconfig)
        self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def rollout(self, triplane):
        """(B, C, H, W) stacked triplane -> (B, C/3, H, 3*W) side-by-side."""
        res = triplane.shape[-1]
        ch = triplane.shape[1]
        triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
        return triplane

    def to3daware(self, triplane):
        """Concat each plane with mean-pooled summaries of the other two
        planes, tripling the channel count of the rollout."""
        res = triplane.shape[-2]
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        x_mp = torch.nn.AvgPool2d((res, 1))
        y_mp = torch.nn.AvgPool2d((1, res))
        # Pool one spatial axis away, broadcast back, then swap the spatial
        # axes so the summary aligns with the receiving plane's layout.
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        plane21 = x_mp_rep(plane2)
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane

    def unrollout(self, triplane):
        """Inverse of rollout: (B, C/3, H, 3*W) -> (B, C, H, W)."""
        res = triplane.shape[-2]
        ch = 3 * triplane.shape[1]
        triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
        return triplane

    def encode(self, x, rollout=False):
        """Encode x to a diagonal-Gaussian posterior. With rollout=True, x is
        a stacked triplane and is rolled out + 3D-aware-augmented first."""
        if rollout:
            x = self.to3daware(self.rollout(x))
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, unrollout=False):
        """Decode a latent: the latent itself is 3D-aware-augmented before
        the widened post-quant conv / decoder; optionally fold the result
        back to the stacked (B, C, H, W) layout."""
        z = self.to3daware(z)
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if unrollout:
            dec = self.unrollout(dec)
        return dec

    def forward(self, input, sample_posterior=True):
        """Return (reconstruction, posterior). `input` must already be the
        3D-aware rollout; the reconstruction is the plain rollout."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def training_step(self, batch, batch_idx):
        """One optimization step; the encoder sees the 3D-aware input while
        the loss compares the reconstruction to the plain rollout."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs))
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Reconstruct, render GT and reconstructed triplanes, and log PSNR
        metrics plus occasional wandb visualizations."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
        self.log_dict(log_dict_ae)

        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []        # rendered reconstruction vs. GT image
        psnr_input_list = []  # rendered input triplane vs. GT image
        psnr_rec_list = []    # rendered input vs. rendered reconstruction
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if b % 4 == 0 and batch_idx < 10:
                # Log view index 1 of GT / input-render / reconstruction-render
                # side by side.
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict

    def to_rgb(self, plane):
        """Colorize a latent map as uint8 RGB via a lazily created fixed
        random 1x1 conv, min-max scaled to [0, 255]."""
        x = plane.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_triplane(self, plane):
        """Colorize a triplane feature map as uint8 RGB (separate cached
        projection from to_rgb)."""
        x = plane.float()
        if not hasattr(self, "colorize_triplane"):
            self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_3daware(self, plane):
        """Colorize a 3D-aware (channel-stacked) triplane as uint8 RGB."""
        x = plane.float()
        if not hasattr(self, "colorize_3daware"):
            self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def test_step(self, batch, batch_idx):
        """Evaluate reconstructions, dump visualizations for early batches,
        and accumulate latent statistics for on_test_epoch_end."""
        inputs = self.rollout(batch['triplane'])
        # Compute the 3D-aware augmentation once; it is reused for the
        # colorized visualizations below.
        inputs_aware = self.to3daware(inputs)
        reconstructions, posterior = self(inputs_aware, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        psnr_list = []
        psnr_input_list = []
        psnr_rec_list = []

        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
        colorize_triplane_rollout_3daware = self.to_rgb_3daware(inputs_aware)[0]
        # Channel count of the rollout (the original named this `res`, but it
        # indexes channels, not spatial resolution).
        ch = inputs.shape[1]
        colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(inputs_aware[:, ch:2*ch])[0]
        colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(inputs_aware[:, 2*ch:3*ch])[0]
        if batch_idx < 10:
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)

        reconstructions = self.unrollout(reconstructions)

        # Keep the running statistics on the same device as the latents.
        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        # NOTE(review): accumulates per-batch extrema; the epoch-end values
        # are means of batch minima/maxima, not global extrema.
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        if self.norm:
            # Was `assert NotImplementedError` — a no-op (the exception class
            # is truthy) that later crashed with a NameError; fail loudly.
            raise NotImplementedError
        reconstructions_unnormalize = reconstructions

        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if batch_idx < 10:
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

    def on_test_epoch_end(self):
        """Print latent statistics accumulated during testing."""
        mean = self.psum / self.count
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)

        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))
|
|
|
|
|
class AutoencoderKLGroupConv(AutoencoderKL): |
|
    def __init__(self, *args, **kwargs):
        # Temporarily strip ckpt_path so the parent __init__ does not load
        # weights before the group-conv encoder/decoder below replace the
        # defaults; the checkpoint is applied at the end instead.
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except:
            ckpt_path = None

        super().__init__(*args, **kwargs)
        self.latent_list = []
        # Running sums for latent statistics gathered during testing
        # (consumed by on_test_epoch_end).
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0
        self.len_dset = 0

        ddconfig = kwargs['ddconfig']

        # Replace the vanilla conv encoder/decoder built by the parent with
        # group-convolution variants that process the three planes jointly.
        del self.decoder
        del self.encoder
        self.encoder = Encoder_GroupConv(**ddconfig)
        self.decoder = Decoder_GroupConv(**ddconfig)

        if "mean" in ddconfig:
            print("Using mean std!!")
            # Per-channel dataset statistics, reshaped to (1, C, 1, 1) so
            # they broadcast over (B, C, H, W) triplanes in encode/decode.
            self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
            self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
        else:
            self.triplane_mean = None
            self.triplane_std = None

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)
|
|
|
def rollout(self, triplane): |
|
res = triplane.shape[-1] |
|
ch = triplane.shape[1] |
|
triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
|
return triplane |
|
|
|
def to3daware(self, triplane): |
|
res = triplane.shape[-2] |
|
plane1 = triplane[..., :res] |
|
plane2 = triplane[..., res:2*res] |
|
plane3 = triplane[..., 2*res:3*res] |
|
|
|
x_mp = torch.nn.MaxPool2d((res, 1)) |
|
y_mp = torch.nn.MaxPool2d((1, res)) |
|
x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) |
|
y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) |
|
|
|
plane21 = x_mp_rep(plane2) |
|
plane31 = torch.flip(y_mp_rep(plane3), (3,)) |
|
new_plane1 = torch.cat([plane1, plane21, plane31], 1) |
|
|
|
plane12 = y_mp_rep(plane1) |
|
plane32 = x_mp_rep(plane3) |
|
new_plane2 = torch.cat([plane2, plane12, plane32], 1) |
|
|
|
plane13 = torch.flip(x_mp_rep(plane1), (2,)) |
|
plane23 = y_mp_rep(plane2) |
|
new_plane3 = torch.cat([plane3, plane13, plane23], 1) |
|
|
|
new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() |
|
return new_plane |
|
|
|
def unrollout(self, triplane): |
|
res = triplane.shape[-2] |
|
ch = 3 * triplane.shape[1] |
|
triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
|
return triplane |
|
|
|
def encode(self, x, rollout=False): |
|
if rollout: |
|
|
|
x = self.rollout(x) |
|
if self.triplane_mean is not None: |
|
x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device) |
|
h = self.encoder(x) |
|
moments = self.quant_conv(h) |
|
posterior = DiagonalGaussianDistribution(moments) |
|
return posterior |
|
|
|
def decode(self, z, unrollout=False): |
|
|
|
z = self.post_quant_conv(z) |
|
dec = self.decoder(z) |
|
if self.triplane_mean is not None: |
|
dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device) |
|
if unrollout: |
|
dec = self.unrollout(dec) |
|
return dec |
|
|
|
def forward(self, input, sample_posterior=True): |
|
posterior = self.encode(input) |
|
if sample_posterior: |
|
z = posterior.sample() |
|
else: |
|
z = posterior.mode() |
|
dec = self.decode(z) |
|
return dec, posterior |
|
|
|
def training_step(self, batch, batch_idx): |
|
inputs = self.rollout(batch['triplane']) |
|
reconstructions, posterior = self(inputs) |
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) |
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
|
return aeloss |
|
|
|
    def validation_step(self, batch, batch_idx):
        """Reconstruct the batch triplanes, render both the ground-truth and
        reconstructed triplanes, and log PSNR metrics plus visualizations."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
        self.log_dict(log_dict_ae)

        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []  # rendered reconstruction vs. GT image
        psnr_input_list = []  # rendered input triplane vs. GT image
        psnr_rec_list = []  # rendered input vs. rendered reconstruction
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            # Render the original and the reconstructed triplane for sample b
            # with whichever renderer this model was configured with.
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # Reverse the color channel order before logging; presumably the
            # renderer emits BGR here — TODO confirm (the other autoencoder
            # variants in this file do not perform this swap).
            rgb_input = np.stack([rgb_input[..., 2], rgb_input[..., 1], rgb_input[..., 0]], -1)
            rgb = np.stack([rgb[..., 2], rgb[..., 1], rgb[..., 0]], -1)

            if b % 2 == 0 and batch_idx < 10:
                # Log view index 1 of GT / input-render / reconstruction-render
                # side by side, plus a colorized latent.
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)],
                    "val/latent_vis": [wandb.Image(colorize_z)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict
|
|
|
def to_rgb(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize"): |
|
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def to_rgb_triplane(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize_triplane"): |
|
self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def to_rgb_3daware(self, plane): |
|
x = plane.float() |
|
if not hasattr(self, "colorize_3daware"): |
|
self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) |
|
x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) |
|
x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
|
return x |
|
|
|
def test_step(self, batch, batch_idx): |
|
inputs = self.rollout(batch['triplane']) |
|
reconstructions, posterior = self(inputs, sample_posterior=False) |
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) |
|
self.log_dict(log_dict_ae) |
|
|
|
batch_size = inputs.shape[0] |
|
psnr_list = [] |
|
psnr_input_list = [] |
|
psnr_rec_list = [] |
|
|
|
z = posterior.mode() |
|
colorize_z = self.to_rgb(z)[0] |
|
colorize_triplane_input = self.to_rgb_triplane(inputs)[0] |
|
colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] |
|
|
|
import os |
|
import random |
|
import string |
|
|
|
z_np = inputs.detach().cpu().numpy() |
|
fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' |
|
with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: |
|
np.save(f, z_np) |
|
|
|
|
|
|
|
|
|
|
|
if batch_idx < 0: |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) |
|
|
|
|
|
|
|
|
|
np_z = z.detach().cpu().numpy() |
|
|
|
|
|
|
|
self.latent_list.append(np_z) |
|
|
|
reconstructions = self.unrollout(reconstructions) |
|
|
|
if self.psum.device != z.device: |
|
self.psum = self.psum.to(z.device) |
|
self.psum_sq = self.psum_sq.to(z.device) |
|
self.psum_min = self.psum_min.to(z.device) |
|
self.psum_max = self.psum_max.to(z.device) |
|
self.psum += z.sum() |
|
self.psum_sq += (z ** 2).sum() |
|
self.psum_min += z.reshape(-1).min(-1)[0] |
|
self.psum_max += z.reshape(-1).max(-1)[0] |
|
assert len(z.shape) == 4 |
|
self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] |
|
self.len_dset += 1 |
|
|
|
if self.norm: |
|
assert NotImplementedError |
|
else: |
|
reconstructions_unnormalize = reconstructions |
|
|
|
if True: |
|
for b in range(batch_size): |
|
if self.renderer_type == 'nerf': |
|
rgb_input, cur_psnr_list_input = self.render_triplane( |
|
batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
|
batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
|
) |
|
rgb, cur_psnr_list = self.render_triplane( |
|
reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
|
batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
|
) |
|
elif self.renderer_type == 'eg3d': |
|
rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
|
batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
|
) |
|
rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
|
reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
cur_psnr_list_rec = [] |
|
for i in range(rgb.shape[0]): |
|
cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
|
rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
|
rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
|
rgb = to8b(rgb.detach().cpu().numpy()) |
|
|
|
if batch_idx < 10: |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) |
|
imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) |
|
|
|
psnr_list += cur_psnr_list |
|
psnr_input_list += cur_psnr_list_input |
|
psnr_rec_list += cur_psnr_list_rec |
|
|
|
self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
|
self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
|
self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
|
def on_test_epoch_end(self): |
|
mean = self.psum / self.count |
|
mean_min = self.psum_min / self.len_dset |
|
mean_max = self.psum_max / self.len_dset |
|
var = (self.psum_sq / self.count) - (mean ** 2) |
|
std = torch.sqrt(var) |
|
|
|
print("mean min: {}".format(mean_min)) |
|
print("mean max: {}".format(mean_max)) |
|
print("mean: {}".format(mean)) |
|
print("std: {}".format(std)) |
|
|
|
latent = np.concatenate(self.latent_list) |
|
q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) |
|
median = np.median(latent.reshape(-1)) |
|
iqr = q75 - q25 |
|
norm_iqr = iqr * 0.7413 |
|
print("Norm IQR: {}".format(norm_iqr)) |
|
print("Inverse Norm IQR: {}".format(1/norm_iqr)) |
|
print("Median: {}".format(median)) |
|
|
|
def loss(self, inputs, reconstructions, posteriors, prefix, batch=None): |
|
reconstructions = reconstructions.contiguous() |
|
|
|
|
|
rec_loss = F.mse_loss(inputs.contiguous(), reconstructions) |
|
kl_loss = posteriors.kl() |
|
|
|
kl_loss = kl_loss.mean() |
|
loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss |
|
|
|
ret_dict = { |
|
prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(), |
|
prefix+'rec_loss': rec_loss, |
|
prefix+'kl_loss': kl_loss, |
|
prefix+'loss': loss, |
|
prefix+'mean': posteriors.mean.mean(), |
|
prefix+'logvar': posteriors.logvar.mean(), |
|
} |
|
|
|
|
|
latent = posteriors.mean |
|
ret_dict[prefix + 'latent_max'] = latent.max() |
|
ret_dict[prefix + 'latent_min'] = latent.min() |
|
|
|
render_weight = self.lossconfig.get("render_weight", 0) |
|
tv_weight = self.lossconfig.get("tv_weight", 0) |
|
l1_weight = self.lossconfig.get("l1_weight", 0) |
|
latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0) |
|
latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0) |
|
|
|
triplane_rec = self.unrollout(reconstructions) |
|
if render_weight > 0 and batch is not None: |
|
rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img']) |
|
|
|
render_loss = F.mse_loss(rgb_rendered, target) |
|
loss += render_weight * render_loss |
|
ret_dict[prefix + 'render_loss'] = render_loss |
|
if tv_weight > 0: |
|
tvloss_y = F.mse_loss(triplane_rec[:, :, :-1], triplane_rec[:, :, 1:]) |
|
tvloss_x = F.mse_loss(triplane_rec[:, :, :, :-1], triplane_rec[:, :, :, 1:]) |
|
tvloss = tvloss_y + tvloss_x |
|
loss += tv_weight * tvloss |
|
ret_dict[prefix + 'tv_loss'] = tvloss |
|
if l1_weight > 0: |
|
l1 = (triplane_rec ** 2).mean() |
|
loss += l1_weight * l1 |
|
ret_dict[prefix + 'l1_loss'] = l1 |
|
if latent_tv_weight > 0: |
|
latent = posteriors.mean |
|
latent_tv_y = F.mse_loss(latent[:, :, :-1], latent[:, :, 1:]) |
|
latent_tv_x = F.mse_loss(latent[:, :, :, :-1], latent[:, :, :, 1:]) |
|
latent_tv_loss = latent_tv_y + latent_tv_x |
|
loss += latent_tv_loss * latent_tv_weight |
|
ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss |
|
if latent_l1_weight > 0: |
|
latent = posteriors.mean |
|
latent_l1_loss = (latent ** 2).mean() |
|
loss += latent_l1_loss * latent_l1_weight |
|
ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss |
|
|
|
return loss, ret_dict |
|
|
|
|
|
class AutoencoderKLGroupConvLateFusion(AutoencoderKL):
    """KL autoencoder using group-convolution encoder/decoder with late fusion.

    Reuses the rendering/evaluation machinery of ``AutoencoderKL`` but swaps
    in ``Encoder_GroupConv_LateFusion`` / ``Decoder_GroupConv_LateFusion``
    and operates on the rolled-out (B, C//3, H, 3W) triplane layout.
    """

    def __init__(self, *args, **kwargs):
        # FIX: was try/except with a bare `except:` to read ckpt_path.
        # Checkpoint loading is deferred until after the encoder/decoder
        # have been replaced; otherwise the base class would load weights
        # into modules that are about to be discarded.
        ckpt_path = kwargs.pop('ckpt_path', None)

        super().__init__(*args, **kwargs)
        # Running statistics gathered during test_step.
        self.latent_list = []
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0
        self.len_dset = 0

        # Replace the vanilla conv encoder/decoder with group-conv variants.
        ddconfig = kwargs['ddconfig']
        del self.decoder
        del self.encoder
        self.encoder = Encoder_GroupConv_LateFusion(**ddconfig)
        self.decoder = Decoder_GroupConv_LateFusion(**ddconfig)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def rollout(self, triplane):
        """(B, C, H, W) -> (B, C//3, H, 3W): lay the 3 planes side by side."""
        res = triplane.shape[-1]
        ch = triplane.shape[1]
        triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
        return triplane

    def to3daware(self, triplane):
        """Augment each sub-plane with pooled features from the other two.

        For every plane, the max over each shared axis of the other two
        planes is broadcast back (with the flips needed to align axes) and
        concatenated on the channel dimension, tripling the channel count.
        """
        res = triplane.shape[-2]
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        x_mp = torch.nn.MaxPool2d((res, 1))
        y_mp = torch.nn.MaxPool2d((1, res))
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        plane21 = x_mp_rep(plane2)
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane

    def unrollout(self, triplane):
        """(B, C//3, H, 3W) -> (B, C, H, W): inverse of :meth:`rollout`."""
        res = triplane.shape[-2]
        ch = 3 * triplane.shape[1]
        triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
        return triplane

    def encode(self, x, rollout=False):
        """Encode a (optionally stacked) triplane into a diagonal Gaussian posterior."""
        if rollout:
            x = self.rollout(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, unrollout=False):
        """Decode a latent back to a triplane; optionally restack the planes."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if unrollout:
            dec = self.unrollout(dec)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full autoencode pass; samples the posterior unless told otherwise."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def training_step(self, batch, batch_idx):
        """One optimization step on the AE objective (with rendering loss)."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch)
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Deterministic reconstruction + rendered-PSNR evaluation on one batch."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
        self.log_dict(log_dict_ae)

        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []          # rendered reconstruction vs GT image
        psnr_input_list = []    # rendered original triplane vs GT image
        psnr_rec_list = []      # rendered original vs rendered reconstruction
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # Log a side-by-side (GT | input render | reconstruction render)
            # for a subset of samples in the first few batches.
            if b % 4 == 0 and batch_idx < 10:
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict

    def to_rgb(self, plane):
        """Colorize a latent feature map via a cached random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_triplane(self, plane):
        """Colorize a triplane feature map via a cached random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize_triplane"):
            self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_3daware(self, plane):
        """Colorize a 3D-aware feature map via a cached random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize_3daware"):
            self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def test_step(self, batch, batch_idx):
        """Deterministic reconstruction, latent statistics, and rendered PSNRs."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        psnr_list = []
        psnr_input_list = []
        psnr_rec_list = []

        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]

        if batch_idx < 10:
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)

        # Accumulate latents + running statistics for on_test_epoch_end.
        np_z = z.detach().cpu().numpy()
        self.latent_list.append(np_z)

        reconstructions = self.unrollout(reconstructions)

        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        if self.norm:
            # BUG FIX: was `assert NotImplementedError`, which always passes
            # (the exception class is truthy) and left
            # `reconstructions_unnormalize` undefined in this branch.
            raise NotImplementedError
        reconstructions_unnormalize = reconstructions

        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if batch_idx < 10:
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

    def on_test_epoch_end(self):
        """Reduce and print the latent statistics accumulated in test_step."""
        mean = self.psum / self.count
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)

        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))

        latent = np.concatenate(self.latent_list)
        q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
        median = np.median(latent.reshape(-1))
        iqr = q75 - q25
        # 0.7413 * IQR estimates sigma for a normal distribution.
        norm_iqr = iqr * 0.7413
        print("Norm IQR: {}".format(norm_iqr))
        print("Inverse Norm IQR: {}".format(1/norm_iqr))
        print("Median: {}".format(median))
|
|
|
|
|
from module.model_2d import ViTEncoder, ViTDecoder |
|
|
|
class AutoencoderVIT(AutoencoderKL):
    """KL autoencoder with ViT encoder/decoder over rolled-out triplanes.

    Replaces the conv backbone of ``AutoencoderKL`` with ``ViTEncoder`` /
    ``ViTDecoder`` operating on the (B, 8, 256, 768) rolled-out layout, and
    optionally normalizes triplanes with per-channel mean/std supplied in
    ``ddconfig``.
    """

    def __init__(self, *args, **kwargs):
        # FIX: was try/except with a bare `except:` to read ckpt_path.
        # Checkpoint loading is deferred until after the backbone has been
        # replaced, otherwise weights would be loaded into modules that are
        # about to be discarded.
        ckpt_path = kwargs.pop('ckpt_path', None)

        super().__init__(*args, **kwargs)
        # Running statistics gathered during test_step.
        self.latent_list = []
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0
        self.len_dset = 0

        ddconfig = kwargs['ddconfig']

        # Drop the conv backbone built by the base class.
        del self.decoder
        del self.encoder
        del self.quant_conv
        del self.post_quant_conv

        assert ddconfig["z_channels"] == 256
        self.encoder = ViTEncoder(
            image_size=(256, 256*3),
            patch_size=(256//32, 256//32),
            dim=768,
            depth=12,
            heads=12,
            mlp_dim=3072,
            channels=8)
        self.decoder = ViTDecoder(
            image_size=(256, 256*3),
            patch_size=(256//32, 256//32),
            dim=768,
            depth=12,
            heads=12,
            mlp_dim=3072,
            channels=8)

        # ViT token dimension (768) replaces z_channels for the 1x1 convs.
        self.quant_conv = torch.nn.Conv2d(768, 2*self.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, 768, 1)

        # Optional per-channel normalization of triplane values.
        if "mean" in ddconfig:
            print("Using mean std!!")
            self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
            self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
        else:
            self.triplane_mean = None
            self.triplane_std = None

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def rollout(self, triplane):
        """(B, C, H, W) -> (B, C//3, H, 3W): lay the 3 planes side by side."""
        res = triplane.shape[-1]
        ch = triplane.shape[1]
        triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
        return triplane

    def to3daware(self, triplane):
        """Augment each sub-plane with pooled features from the other two.

        For every plane, the max over each shared axis of the other two
        planes is broadcast back (with the flips needed to align axes) and
        concatenated on the channel dimension, tripling the channel count.
        """
        res = triplane.shape[-2]
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        x_mp = torch.nn.MaxPool2d((res, 1))
        y_mp = torch.nn.MaxPool2d((1, res))
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        plane21 = x_mp_rep(plane2)
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane

    def unrollout(self, triplane):
        """(B, C//3, H, 3W) -> (B, C, H, W): inverse of :meth:`rollout`."""
        res = triplane.shape[-2]
        ch = 3 * triplane.shape[1]
        triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
        return triplane

    def encode(self, x, rollout=False):
        """Normalize (if configured) and encode into a diagonal Gaussian posterior."""
        if rollout:
            x = self.rollout(x)
        if self.triplane_mean is not None:
            x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, unrollout=False):
        """Decode a latent, undo normalization, and optionally restack planes."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if self.triplane_mean is not None:
            dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device)
        if unrollout:
            dec = self.unrollout(dec)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full autoencode pass; samples the posterior unless told otherwise."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def training_step(self, batch, batch_idx):
        """One optimization step on the AE objective (with rendering loss)."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch)
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Deterministic reconstruction + rendered-PSNR evaluation on one batch."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
        self.log_dict(log_dict_ae)

        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []          # rendered reconstruction vs GT image
        psnr_input_list = []    # rendered original triplane vs GT image
        psnr_rec_list = []      # rendered original vs rendered reconstruction
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # Log a side-by-side (GT | input render | reconstruction render)
            # for a subset of samples in the first few batches.
            if b % 4 == 0 and batch_idx < 10:
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict

    def to_rgb(self, plane):
        """Colorize a latent feature map via a cached random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_triplane(self, plane):
        """Colorize a triplane feature map via a cached random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize_triplane"):
            self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_3daware(self, plane):
        """Colorize a 3D-aware feature map via a cached random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize_3daware"):
            self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def test_step(self, batch, batch_idx):
        """Deterministic reconstruction, latent statistics, and rendered PSNRs."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        psnr_list = []
        psnr_input_list = []
        psnr_rec_list = []

        z = posterior.mode()
        # Kept for their side effect of lazily initializing the colorize
        # projection weights; the images themselves are not written here.
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]

        # FIX: removed a debug dump of the raw triplanes to a hard-coded
        # cluster path ('/mnt/lustre/...'); it crashed on any other machine.

        # Accumulate latents + running statistics for on_test_epoch_end.
        np_z = z.detach().cpu().numpy()
        self.latent_list.append(np_z)

        reconstructions = self.unrollout(reconstructions)

        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        if self.norm:
            # BUG FIX: was `assert NotImplementedError`, which always passes
            # (the exception class is truthy) and left
            # `reconstructions_unnormalize` undefined in this branch.
            raise NotImplementedError
        reconstructions_unnormalize = reconstructions

        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

    def on_test_epoch_end(self):
        """Reduce and print the latent statistics accumulated in test_step."""
        mean = self.psum / self.count
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)

        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))

        latent = np.concatenate(self.latent_list)
        q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
        median = np.median(latent.reshape(-1))
        iqr = q75 - q25
        # 0.7413 * IQR estimates sigma for a normal distribution.
        norm_iqr = iqr * 0.7413
        print("Norm IQR: {}".format(norm_iqr))
        print("Inverse Norm IQR: {}".format(1/norm_iqr))
        print("Median: {}".format(median))
|
|