# Hugging Face Space: dylanebert/EDGS (status: Running on Zero)
# File: EDGS/source/trainer.py
# Commit: 5f9d349 "Initial commit" (author: Olga)
import math
import sys
from random import randint

import torch
import wandb
from tqdm import tqdm as tqdm
from tqdm.rich import trange

from source.networks import Warper3DGS
sys.path.append('./submodules/gaussian-splatting/')
import lpips
from source.losses import ssim, l1_loss, psnr
from rich.console import Console
from rich.theme import Theme
# Named rich styles used with the trainer's Console (e.g. ``style="info"``).
custom_theme = Theme(dict(
    info="dim cyan",
    warning="magenta",
    danger="bold red",
))
#from source.corr_init import init_gaussians_with_corr
from source.corr_init_new import init_gaussians_with_corr_profiled as init_gaussians_with_corr
from source.utils_aux import log_samples
from source.timer import Timer
class EDGSTrainer:
    """Optimisation loop for a 3D Gaussian Splatting model wrapped in ``Warper3DGS``.

    Each step renders one randomly drawn training camera, minimises a weighted
    L1 + D-SSIM loss and then densifies/prunes the Gaussians. The trainer also
    runs periodic evaluation (L1, PSNR, SSIM, LPIPS), saves checkpoints, and can
    initialise Gaussians from image correspondences (``init_with_corr``).
    Losses and metrics are optionally logged to Weights & Biases.
    """

    def __init__(self,
                 GS: Warper3DGS,
                 training_config,
                 dataset_white_background=False,
                 device=torch.device('cuda'),
                 log_wandb=True,
                 ):
        """Store references to the model pieces and set up logging helpers.

        Args:
            GS: wrapper exposing ``scene``, ``viewpoint_stack`` and ``gaussians``.
                ``GS.gaussians.optimizer`` is cached here, so it should already
                exist. Calling ``GS(camera)`` renders that camera.
            training_config: optimisation settings (``save_iterations``,
                ``batch_size``, ``lambda_dssim``, the densification schedule, ...).
            dataset_white_background: if True, opacities are additionally reset
                at ``densify_from_iter`` (see ``densify_and_prune``).
            device: device for ground-truth images and the LPIPS network.
            log_wandb: whether to send losses, metrics and samples to W&B.
        """
        self.GS = GS
        self.scene = GS.scene
        self.viewpoint_stack = GS.viewpoint_stack
        self.gaussians = GS.gaussians
        self.training_config = training_config
        self.GS_optimizer = GS.gaussians.optimizer
        self.dataset_white_background = dataset_white_background
        self.training_step = 1  # iteration counter driven by train()'s progress bar
        self.gs_step = 0  # optimisation steps taken (incremented in train_step_gs)
        self.CONSOLE = Console(width=120, theme=custom_theme)
        self.saving_iterations = training_config.save_iterations
        # None -> train() falls back to its built-in evaluation schedule.
        self.evaluate_iterations = None
        self.batch_size = training_config.batch_size
        self.ema_loss_for_log = 0.0
        # Logs in the format {step: {"loss1": loss1_value, "loss2": loss2_value}}
        self.logs_losses = {}
        self.lpips = lpips.LPIPS(net='vgg').to(device)
        self.device = device
        self.timer = Timer()
        self.log_wandb = log_wandb

    def load_checkpoints(self, load_cfg):
        """Restore the Gaussians from a checkpoint written by ``save_model``.

        Args:
            load_cfg: config with ``gs`` (checkpoint directory; falsy to skip
                loading) and ``gs_step`` (iteration number of the checkpoint).
        """
        if not load_cfg.gs:
            return
        checkpoint_path = f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth"
        # save_model stores a (gaussians.capture(), gs_step) tuple; restore()
        # only needs the captured model parameters.
        model_params = torch.load(checkpoint_path)[0]
        self.GS.gaussians.restore(model_params, self.training_config)
        # restore() may rebuild the optimiser, so refresh the cached reference.
        self.GS_optimizer = self.GS.gaussians.optimizer
        self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
                           style="info")
        # Continue counting from the checkpoint's iteration.
        self.training_step += load_cfg.gs_step
        self.gs_step += load_cfg.gs_step

    def train(self, train_cfg):
        """Run ``train_cfg.gs_epochs`` optimisation steps on the Gaussians.

        Args:
            train_cfg: config with ``gs_epochs``, ``max_lr``, ``no_densify`` and
                ``reduce_opacity``.
        """
        self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
        with trange(self.training_step, self.training_step + train_cfg.gs_epochs,
                    desc="[green]Train gaussians") as progress_bar:
            for self.training_step in progress_bar:
                radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
                with torch.no_grad():
                    if train_cfg.no_densify:
                        self.prune(radii)
                    else:
                        self.densify_and_prune(radii)
                    if train_cfg.reduce_opacity:
                        self._decay_opacity()
                    # Progress reporting, saving and evaluation are excluded from the timer.
                    self.timer.pause()
                    if self.training_step % 10 == 0:
                        progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
                    if self.training_step in self.saving_iterations:
                        self.save_model()
                    if self._should_evaluate(self.training_step):
                        self.evaluate()
                    self.timer.start()

    def _should_evaluate(self, step):
        """Return True if ``evaluate`` should run at training step ``step``."""
        if self.evaluate_iterations is not None:
            return step in self.evaluate_iterations
        # Default schedule: every 500 steps up to 3000, afterwards once per
        # thousand (at steps ...228).
        if step <= 3000:
            return step % 500 == 0
        return step % 1000 == 228

    def _decay_opacity(self, factor=0.99):
        """Multiply ``exp(_opacity)`` by ``factor`` every 10 steps while densifying."""
        if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
            opacity = self.GS.gaussians._opacity
            # log(exp(x) * factor) == x + log(factor); the additive form avoids
            # exp() overflow/underflow for large-magnitude values. Assigning via
            # .data keeps the same tensor object in place.
            opacity.data = opacity.data + math.log(factor)

    def evaluate(self):
        """Compute L1 / PSNR / SSIM / LPIPS on the test cameras and a training subset.

        Results are printed to the console and, if ``log_wandb`` is set, logged to
        W&B together with the Gaussian count, one real/rendered image pair and
        the elapsed training time.
        """
        torch.cuda.empty_cache()
        log_gen_images, log_real_images = [], []
        train_cameras = self.scene.getTrainCameras()
        # Every 5th index in [0, 150), wrapped around the training set.
        train_subset = ([train_cameras[idx % len(train_cameras)] for idx in range(0, 150, 5)]
                        if train_cameras else [])
        validation_configs = (
            {'name': 'test', 'cameras': self.scene.getTestCameras(),
             'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
            {'name': 'train', 'cameras': train_subset, 'cam_idx': 10},
        )
        if self.log_wandb:
            wandb.log({"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)

        with torch.no_grad():
            for config in validation_configs:
                cameras = config['cameras']
                if not cameras:
                    continue
                l1_sum = psnr_sum = ssim_sum = lpips_sum = 0.0
                for idx, viewpoint in enumerate(cameras):
                    image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
                    gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)
                    l1_sum += l1_loss(image, gt_image).double()
                    psnr_sum += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
                    ssim_sum += ssim(image, gt_image).double()
                    # lpips.LPIPS expects inputs in [-1, 1]; normalize=True rescales
                    # our [0, 1] images accordingly.
                    lpips_sum += self.lpips(image, gt_image, normalize=True).double()
                    if idx == config['cam_idx']:
                        log_gen_images.append(image)
                        log_real_images.append(gt_image)

                n_cameras = len(cameras)
                l1_test = (l1_sum / n_cameras).item()
                psnr_test = (psnr_sum / n_cameras).item()
                ssim_test = (ssim_sum / n_cameras).item()
                lpips_splat_test = (lpips_sum / n_cameras).item()
                if self.log_wandb:
                    wandb.log({f"{config['name']}/L1": l1_test, f"{config['name']}/PSNR": psnr_test,
                               f"{config['name']}/SSIM": ssim_test, f"{config['name']}/LPIPS_splat": lpips_splat_test},
                              step=self.training_step)
                self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f} ".format(
                    self.training_step, len(self.GS.gaussians._xyz), config['name'],
                    l1_test, psnr_test, ssim_test, lpips_splat_test), style="info")

            if self.log_wandb:
                # No sample pair exists if no camera index matched 'cam_idx'.
                if log_real_images:
                    log_samples(torch.stack((log_real_images[0], log_gen_images[0])), [],
                                self.training_step, caption="Real and Generated Samples")
                wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
        torch.cuda.empty_cache()

    def train_step_gs(self, max_lr = False, no_densify = False):
        """Run one optimisation step on a randomly drawn training camera.

        Args:
            max_lr: if True, the learning-rate schedule is queried at
                ``max(gs_step, 8000)`` instead of ``gs_step``.
            no_densify: if True, densification statistics are not accumulated.

        Returns:
            The ``radii`` entry of this step's render package.
        """
        self.gs_step += 1
        lr_step = max(self.gs_step, 8_000) if max_lr else self.gs_step
        self.GS.gaussians.update_learning_rate(lr_step)
        # Every 1000 steps increase the SH degree, up to the model's maximum.
        if self.gs_step % 1000 == 0:
            self.GS.gaussians.oneupSHdegree()

        # Draw cameras without replacement; refill once every camera has been used.
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))
        render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
        image = render_pkg["render"]

        # Photometric loss: (1 - lambda) * L1 + lambda * (1 - SSIM).
        gt_image = viewpoint_cam.original_image.to(self.device)
        L1_loss = l1_loss(image, gt_image)
        ssim_loss = 1.0 - ssim(image, gt_image)
        lambda_dssim = self.training_config.lambda_dssim
        loss = (1.0 - lambda_dssim) * L1_loss + lambda_dssim * ssim_loss

        # Logging is excluded from the training timer.
        self.timer.pause()
        step_losses = {"loss": loss.item(),
                       "L1_loss": L1_loss.item(),
                       "ssim_loss": ssim_loss.item()}
        self.logs_losses[self.training_step] = step_losses
        if self.log_wandb:
            # Log all losses of this step in a single call.
            wandb.log({f"train/{k}": v for k, v in step_losses.items()}, step=self.training_step)
        self.ema_loss_for_log = 0.4 * step_losses["loss"] + 0.6 * self.ema_loss_for_log
        self.timer.start()

        self.GS_optimizer.zero_grad(set_to_none=True)
        loss.backward()
        with torch.no_grad():
            if self.gs_step < self.training_config.densify_until_iter and not no_densify:
                visible = render_pkg["visibility_filter"]
                # Keep the largest 2D radius observed for each visible Gaussian.
                self.GS.gaussians.max_radii2D[visible] = torch.max(
                    self.GS.gaussians.max_radii2D[visible],
                    render_pkg["radii"][visible])
                self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"], visible)
            self.GS_optimizer.step()
            self.GS_optimizer.zero_grad(set_to_none=True)
        return render_pkg["radii"]

    def densify_and_prune(self, radii = None, min_opacity=0.005):
        """Apply the densification / opacity-reset schedule for the current ``gs_step``.

        Only active while ``gs_step < densify_until_iter``. Every
        ``densification_interval`` steps after ``densify_from_iter`` it calls
        ``gaussians.densify_and_prune``. Every ``opacity_reset_interval`` steps,
        and once at ``densify_from_iter`` with a white background, it resets
        the opacities.

        Args:
            radii: per-Gaussian radii from the last render, forwarded to
                ``gaussians.densify_and_prune``.
            min_opacity: opacity threshold forwarded to ``gaussians.densify_and_prune``
                (default 0.005, the value previously hard-coded here).
        """
        cfg = self.training_config
        if self.gs_step >= cfg.densify_until_iter:
            return
        if self.gs_step > cfg.densify_from_iter and self.gs_step % cfg.densification_interval == 0:
            # A size threshold is only passed once gs_step exceeds opacity_reset_interval.
            size_threshold = 20 if self.gs_step > cfg.opacity_reset_interval else None
            self.GS.gaussians.densify_and_prune(cfg.densify_grad_threshold,
                                                min_opacity,
                                                self.GS.scene.cameras_extent,
                                                size_threshold, radii)
        if self.gs_step % cfg.opacity_reset_interval == 0 or (
                self.dataset_white_background and self.gs_step == cfg.densify_from_iter):
            self.GS.gaussians.reset_opacity()

    def save_model(self):
        """Save the scene via ``scene.save`` and a ``(gaussians.capture(), gs_step)``
        checkpoint to ``<model_path>/chkpnt<gs_step>.pth`` (read by ``load_checkpoints``)."""
        print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
        self.scene.save(self.gs_step)
        print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
        torch.save((self.GS.gaussians.capture(), self.gs_step),
                   f"{self.scene.model_path}/chkpnt{self.gs_step}.pth")

    def init_with_corr(self, cfg, verbose=False, roma_model=None):
        """Initialise Gaussians from image matchings, optionally dropping the SfM points.

        Args:
            cfg: configuration part named init_wC (see train.yaml). Uses ``use``
                (skip entirely if falsy) and ``add_SfM_init`` (keep the original
                SfM-initialised Gaussians if truthy).
            verbose: print intermediate results; useful for debugging.
            roma_model: optional pre-initialised RoMA model, to avoid re-creating
                it on every call.

        Returns:
            The visualisation dict produced by ``init_gaussians_with_corr``, or
            None when ``cfg.use`` is falsy.
        """
        if not cfg.use:
            return None
        gaussians = self.GS.gaussians
        n_splats_at_init = len(gaussians._xyz)
        print("N_splats_at_init:", n_splats_at_init)
        _camera_set, _selected_indices, visualization_dict = init_gaussians_with_corr(
            gaussians,
            self.scene,
            cfg,
            self.device,
            verbose=verbose,
            roma_model=roma_model)

        # Remove SfM points and leave only the matching-based Gaussians, which are
        # appended after the original ones.
        if not cfg.add_SfM_init:
            with torch.no_grad():
                n_splats_after_init = len(gaussians._xyz)
                print("N_splats_after_init:", n_splats_after_init)
                # prune_points() is given tmp_radii sized to the current count
                # (prune() follows the same pattern).
                gaussians.tmp_radii = torch.zeros(n_splats_after_init, device=self.device)
                sfm_mask = torch.zeros(n_splats_after_init, dtype=torch.bool, device=self.device)
                sfm_mask[:n_splats_at_init] = True
                gaussians.prune_points(sfm_mask)

        with torch.no_grad():
            # Halve every Gaussian's activated scale. Write through ``.data`` so the
            # existing tensor object is kept (presumably the nn.Parameter registered
            # with the optimiser, as in reference 3DGS); rebinding ``_scaling`` to a
            # new tensor would detach it from the optimiser, and the next
            # prune/densify would restore the un-halved values.
            gaussians._scaling.data = gaussians.scaling_inverse_activation(
                gaussians.scaling_activation(gaussians._scaling) * 0.5)
        return visualization_dict

    def prune(self, radii, min_opacity=0.005):
        """Remove Gaussians whose opacity is below ``min_opacity``.

        Pruning only happens while ``gs_step < densify_until_iter``. ``radii`` is
        exposed as ``gaussians.tmp_radii`` during the call and cleared afterwards.
        """
        gaussians = self.GS.gaussians
        gaussians.tmp_radii = radii
        if self.gs_step < self.training_config.densify_until_iter:
            # squeeze(-1) (not squeeze()) keeps a 1-D mask even with a single Gaussian.
            prune_mask = (gaussians.get_opacity < min_opacity).squeeze(-1)
            gaussians.prune_points(prune_mask)
            torch.cuda.empty_cache()
        gaussians.tmp_radii = None