import torch
from random import randint
from tqdm.rich import trange
from tqdm import tqdm as tqdm
from source.networks import Warper3DGS
import wandb
import sys

# Must run before the imports below so modules from the submodule resolve.
sys.path.append('./submodules/gaussian-splatting/')
import lpips
from source.losses import ssim, l1_loss, psnr
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme({
    "info": "dim cyan",
    "warning": "magenta",
    "danger": "bold red"
})

#from source.corr_init import init_gaussians_with_corr
from source.corr_init_new import init_gaussians_with_corr_profiled as init_gaussians_with_corr
from source.utils_aux import log_samples
from source.timer import Timer


class EDGSTrainer:
    """Training loop around a ``Warper3DGS`` model.

    Owns the step counters, loss logging, periodic evaluation, checkpointing
    and the densification / pruning schedule driven by ``training_config``.
    """

    def __init__(self,
                 GS: Warper3DGS,
                 training_config,
                 dataset_white_background=False,
                 device=torch.device('cuda'),
                 log_wandb=True,
                 ):
        """Store the model handles and initialise counters and loggers.

        Args:
            GS: model wrapper; its ``scene``, ``viewpoint_stack``,
                ``gaussians`` and ``gaussians.optimizer`` are cached here.
            training_config: config object; this class reads
                ``save_iterations``, ``batch_size``, ``lambda_dssim``,
                ``TEST_CAM_IDX_TO_LOG`` and the densification settings.
            dataset_white_background: enables the extra opacity reset at
                ``densify_from_iter`` (see ``densify_and_prune``).
            device: device used for LPIPS and for ground-truth images.
            log_wandb: when True, metrics are sent to Weights & Biases.
        """
        self.GS = GS
        self.scene = GS.scene
        self.viewpoint_stack = GS.viewpoint_stack
        self.gaussians = GS.gaussians

        self.training_config = training_config
        self.GS_optimizer = GS.gaussians.optimizer
        self.dataset_white_background = dataset_white_background

        self.training_step = 1
        self.gs_step = 0
        self.CONSOLE = Console(width=120, theme=custom_theme)
        self.saving_iterations = training_config.save_iterations
        # None -> use the built-in evaluation schedule in ``train``.
        self.evaluate_iterations = None
        self.batch_size = training_config.batch_size
        self.ema_loss_for_log = 0.0

        # Logs in the format {step:{"loss1":loss1_value, "loss2":loss2_value}}
        self.logs_losses = {}
        self.lpips = lpips.LPIPS(net='vgg').to(device)
        self.device = device
        self.timer = Timer()
        self.log_wandb = log_wandb

    def load_checkpoints(self, load_cfg):
        """Restore the gaussians from a checkpoint and advance the counters.

        Args:
            load_cfg: config with ``gs`` (checkpoint directory; falsy to skip)
                and ``gs_step`` (iteration number of the checkpoint file).
        """
        # Load 3DGS checkpoint
        if load_cfg.gs:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth")
            # The attribute is ``self.GS``; the former ``self.gs`` raised AttributeError.
            self.GS.gaussians.restore(checkpoint[0], self.training_config)
            self.GS_optimizer = self.GS.gaussians.optimizer
            self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
                               style="info")
            self.training_step += load_cfg.gs_step
            self.gs_step += load_cfg.gs_step

    def train(self, train_cfg):
        """Run ``train_cfg.gs_epochs`` optimisation steps.

        Each step is followed by pruning or densification, an optional
        opacity decay, progress reporting, checkpointing and evaluation.

        Args:
            train_cfg: config with ``gs_epochs``, ``max_lr``, ``no_densify``
                and ``reduce_opacity``.
        """
        # 3DGS training
        self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
        with trange(self.training_step, self.training_step + train_cfg.gs_epochs,
                    desc="[green]Train gaussians") as progress_bar:
            # The loop variable is the attribute itself, so every helper sees the current step.
            for self.training_step in progress_bar:
                radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
                with torch.no_grad():
                    if train_cfg.no_densify:
                        self.prune(radii)
                    else:
                        self.densify_and_prune(radii)
                    if train_cfg.reduce_opacity:
                        # Slightly reduce opacity every few steps:
                        if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
                            # log(exp(x) * 0.99) == x + log(0.99): shifts the raw opacity down a little.
                            opacities_new = torch.log(torch.exp(self.GS.gaussians._opacity.data) * 0.99)
                            self.GS.gaussians._opacity.data = opacities_new
                    # Logging, saving and evaluation are excluded from the timer.
                    self.timer.pause()
                    # Progress bar
                    if self.training_step % 10 == 0:
                        progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
                    # Log and save
                    if self.training_step in self.saving_iterations:
                        self.save_model()
                    if self.evaluate_iterations is not None:
                        if self.training_step in self.evaluate_iterations:
                            self.evaluate()
                    else:
                        # NOTE(review): the late schedule fires at steps ending in 228
                        # (3228, 4228, ...) - confirm this offset is intentional.
                        if (self.training_step <= 3000 and self.training_step % 500 == 0) or \
                                (self.training_step > 3000 and self.training_step % 1000 == 228):
                            self.evaluate()
                    self.timer.start()

    def evaluate(self):
        """Compute L1 / PSNR / SSIM / LPIPS on test and train cameras and log them.

        The train set is a selection of 30 cameras (indices 0, 5, ..., 145,
        wrapped around the list length). One real/generated image pair is sent
        to wandb when logging is enabled and a sample was collected.
        """
        torch.cuda.empty_cache()
        log_gen_images, log_real_images = [], []
        train_cameras = self.scene.getTrainCameras()
        # Guard the modulo: an empty train set previously raised ZeroDivisionError.
        if train_cameras:
            train_subset = [train_cameras[idx % len(train_cameras)] for idx in range(0, 150, 5)]
        else:
            train_subset = []
        validation_configs = ({'name': 'test',
                               'cameras': self.scene.getTestCameras(),
                               'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
                              {'name': 'train',
                               'cameras': train_subset,
                               'cam_idx': 10})
        if self.log_wandb:
            wandb.log({"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)
        # Evaluation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for config in validation_configs:
                cameras = config['cameras']
                if not cameras:
                    continue
                l1_test = 0.0
                psnr_test = 0.0
                ssim_test = 0.0
                lpips_splat_test = 0.0
                for idx, viewpoint in enumerate(cameras):
                    image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
                    gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)

                    l1_test += l1_loss(image, gt_image).double()
                    psnr_test += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
                    ssim_test += ssim(image, gt_image).double()
                    lpips_splat_test += self.lpips(image, gt_image).detach().double()
                    if idx == config['cam_idx']:
                        log_gen_images.append(image)
                        log_real_images.append(gt_image)
                n_cameras = len(cameras)
                psnr_test /= n_cameras
                l1_test /= n_cameras
                ssim_test /= n_cameras
                lpips_splat_test /= n_cameras
                if self.log_wandb:
                    wandb.log({f"{config['name']}/L1": l1_test.item(),
                               f"{config['name']}/PSNR": psnr_test.item(),
                               f"{config['name']}/SSIM": ssim_test.item(),
                               f"{config['name']}/LPIPS_splat": lpips_splat_test.item()},
                              step=self.training_step)
                self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f} ".format(
                    self.training_step, len(self.GS.gaussians._xyz), config['name'],
                    l1_test.item(), psnr_test.item(), ssim_test.item(), lpips_splat_test.item()),
                    style="info")
            if self.log_wandb:
                # Only upload a sample when one was actually collected.
                if log_real_images and log_gen_images:
                    log_samples(torch.stack((log_real_images[0], log_gen_images[0])), [],
                                self.training_step, caption="Real and Generated Samples")
                wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
        torch.cuda.empty_cache()

    def train_step_gs(self, max_lr=False, no_densify=False):
        """Run one optimisation step on a randomly chosen training camera.

        Args:
            max_lr: when True, the learning-rate schedule is queried with
                ``max(gs_step, 8000)`` instead of ``gs_step``.
            no_densify: when True, densification statistics are not collected.

        Returns:
            ``render_pkg["radii"]`` from the rendered view.
        """
        self.gs_step += 1
        if max_lr:
            # NOTE(review): clamps the scheduler step to at least 8000 - confirm
            # this matches what the ``max_lr`` flag is meant to do.
            self.GS.gaussians.update_learning_rate(max(self.gs_step, 8_000))
        else:
            self.GS.gaussians.update_learning_rate(self.gs_step)
        # Every 1000 its we increase the levels of SH up to a maximum degree
        if self.gs_step % 1000 == 0:
            self.GS.gaussians.oneupSHdegree()

        # Pick a random Camera; refill the stack once every camera was used.
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))

        render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
        image = render_pkg["render"]
        # Loss
        gt_image = viewpoint_cam.original_image.to(self.device)
        L1_loss = l1_loss(image, gt_image)
        ssim_loss = (1.0 - ssim(image, gt_image))
        loss = (1.0 - self.training_config.lambda_dssim) * L1_loss + \
            self.training_config.lambda_dssim * ssim_loss

        # Loss bookkeeping is excluded from the timer.
        self.timer.pause()
        step_losses = {"loss": loss.item(),
                       "L1_loss": L1_loss.item(),
                       "ssim_loss": ssim_loss.item()}
        self.logs_losses[self.training_step] = step_losses
        if self.log_wandb:
            for k, v in step_losses.items():
                wandb.log({f"train/{k}": v}, step=self.training_step)
        self.ema_loss_for_log = 0.4 * step_losses["loss"] + 0.6 * self.ema_loss_for_log
        self.timer.start()

        self.GS_optimizer.zero_grad(set_to_none=True)
        loss.backward()
        with torch.no_grad():
            if self.gs_step < self.training_config.densify_until_iter and not no_densify:
                visible = render_pkg["visibility_filter"]
                # Track the largest 2D radius seen for each visible gaussian.
                self.GS.gaussians.max_radii2D[visible] = torch.max(
                    self.GS.gaussians.max_radii2D[visible],
                    render_pkg["radii"][visible])
                self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"], visible)

        # Optimizer step
        self.GS_optimizer.step()
        self.GS_optimizer.zero_grad(set_to_none=True)
        return render_pkg["radii"]

    def densify_and_prune(self, radii=None):
        """Apply the scheduled densification / pruning and opacity reset.

        Nothing happens once ``gs_step`` reaches ``densify_until_iter``.

        Args:
            radii: radii of the last render, forwarded to the gaussians'
                ``densify_and_prune``.
        """
        cfg = self.training_config
        # Densification or pruning
        if self.gs_step < cfg.densify_until_iter:
            if (self.gs_step > cfg.densify_from_iter) and \
                    (self.gs_step % cfg.densification_interval == 0):
                size_threshold = 20 if self.gs_step > cfg.opacity_reset_interval else None
                self.GS.gaussians.densify_and_prune(cfg.densify_grad_threshold,
                                                    0.005,
                                                    self.GS.scene.cameras_extent,
                                                    size_threshold, radii)
            if self.gs_step % cfg.opacity_reset_interval == 0 or (
                    self.dataset_white_background and self.gs_step == cfg.densify_from_iter):
                self.GS.gaussians.reset_opacity()

    def save_model(self):
        """Save the scene and a ``(gaussians state, gs_step)`` checkpoint."""
        print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
        self.scene.save(self.gs_step)
        print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
        # File name format must stay in sync with ``load_checkpoints``.
        torch.save((self.GS.gaussians.capture(), self.gs_step),
                   self.scene.model_path + "/chkpnt" + str(self.gs_step) + ".pth")

    def init_with_corr(self, cfg, verbose=False, roma_model=None):
        """
        Initializes gaussians from image matchings. Optionally removes SfM init points.
        Args:
            cfg: configuration part named init_wC. Check train.yaml
            verbose: whether you want to print intermediate results. Useful for debug.
            roma_model: optionally you can pass here a pre-initialized RoMA model
                to avoid re-initializing it every time.
        Returns:
            The visualization dict produced by the initializer, or None when
            ``cfg.use`` is falsy.
        """
        if not cfg.use:
            return None
        N_splats_at_init = len(self.GS.gaussians._xyz)
        print("N_splats_at_init:", N_splats_at_init)
        camera_set, selected_indices, visualization_dict = init_gaussians_with_corr(
            self.GS.gaussians,
            self.scene,
            cfg,
            self.device,
            verbose=verbose,
            roma_model=roma_model)

        # Remove SfM points and leave only matchings inits
        if not cfg.add_SfM_init:
            with torch.no_grad():
                N_splats_after_init = len(self.GS.gaussians._xyz)
                print("N_splats_after_init:", N_splats_after_init)
                self.gaussians.tmp_radii = torch.zeros(self.gaussians._xyz.shape[0]).to(self.device)
                # True marks the first N_splats_at_init gaussians (those present
                # before the correspondence init) for removal.
                mask = torch.zeros(N_splats_after_init, dtype=torch.bool)
                mask[:N_splats_at_init] = True
                self.GS.gaussians.prune_points(mask)
        with torch.no_grad():
            gaussians = self.gaussians
            # Halve the activated scale, then map back to the raw parameterisation.
            gaussians._scaling = gaussians.scaling_inverse_activation(
                gaussians.scaling_activation(gaussians._scaling) * 0.5)
        return visualization_dict

    def prune(self, radii, min_opacity=0.005):
        """Remove gaussians whose opacity is below ``min_opacity``.

        Only active while ``gs_step < densify_until_iter``.

        Args:
            radii: radii of the last render, exposed as ``tmp_radii`` on the
                gaussians for the duration of the call.
            min_opacity: opacity threshold below which a gaussian is pruned.
        """
        self.GS.gaussians.tmp_radii = radii
        if self.gs_step < self.training_config.densify_until_iter:
            prune_mask = (self.GS.gaussians.get_opacity < min_opacity).squeeze()
            self.GS.gaussians.prune_points(prune_mask)
            torch.cuda.empty_cache()
        self.GS.gaussians.tmp_radii = None