import torch
from submodules.mast3r.dust3r.dust3r.losses import *
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, JaccardIndex, Accuracy
import lpips
from src.utils.gaussian_model import GaussianModel
from src.utils.cuda_splatting import render, DummyPipeline
from einops import rearrange
from src.utils.camera_utils import get_scaled_camera
from torchvision.utils import save_image
from dust3r.inference import make_batch_symmetric


class L2Loss(LLoss):
    """Euclidean distance between 3d points."""

    def distance(self, a, b):
        # Plain (unnormalized) L2 norm over the coordinate axis: one value per point.
        return torch.norm(a - b, dim=-1)


class L1Loss(LLoss):
    """Manhattan distance between 3d points."""

    def distance(self, a, b):
        # Like L2Loss, reduce only over the coordinate axis so that one
        # distance per point is returned; LLoss applies the configured
        # reduction (mean/sum/none) itself. Collapsing everything to a
        # scalar here would break that contract.
        return torch.abs(a - b).sum(dim=-1)


L2 = L2Loss()
L1 = L1Loss()


def merge_and_split_predictions(pred1, pred2):
    """Merge the per-pixel predictions of two views and split them per sample.

    Every value of ``pred1``/``pred2`` (same keys expected) is stacked along a
    new view axis and flattened with ``'b v h w ... -> b (v h w) ...'`` so the
    pixels of both views form one set of primitives.

    Returns:
        A list with one dict per batch element, mapping each key to a tensor
        whose leading dimension is ``2 * h * w``.
    """
    merged = {
        key: rearrange(torch.stack([pred1[key], pred2[key]], dim=1),
                       'b v h w ... -> b (v h w) ...')
        for key in pred1.keys()
    }
    # Split along the batch dimension.
    batch_size = next(iter(merged.values())).shape[0]
    return [{key: value[i] for key, value in merged.items()} for i in range(batch_size)]


class GaussianLoss(MultiLoss):
    """Rendering loss for per-pixel 3D Gaussian predictions.

    The merged Gaussians of both input views are rendered into the cameras of
    the two input views and of the target view. The loss is the L1 error of
    the rendered RGB images plus the L1 error of the expanded rendered feature
    maps against features extracted from the ground-truth images. SSIM, PSNR
    and LPIPS are reported as (gradient-free) metrics.
    """

    def __init__(self, ssim_weight=0.2):
        super().__init__()
        # NOTE(review): ssim_weight only appears in get_name(); compute_loss
        # does not use it.
        self.ssim_weight = ssim_weight
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()
        self.lpips_vgg = lpips.LPIPS(net='vgg').cuda()
        self.pipeline = DummyPipeline()
        # Background color passed to the renderer (black).
        self.register_buffer('bg_color', torch.tensor([0.0, 0.0, 0.0]).cuda())

    def get_name(self):
        return f'GaussianLoss(ssim_weight={self.ssim_weight})'

    def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
        """Render the predicted Gaussians into every supervised view and compare.

        Args:
            gt1, gt2: ground-truth dicts of the two input views.
            target_view: ground-truth dict of the additional target view.
            pred1, pred2: per-view prediction dicts (must contain ``'means'``).
            model: the unwrapped model; must provide ``feature_expansion`` and
                ``lseg_feature_extractor``.

        Returns:
            ``(loss, details)`` where ``details`` holds ``ssim``, ``psnr``,
            ``lpips`` (computed under ``torch.no_grad``), ``image_loss`` and
            ``feature_loss``.
        """
        # 1. One merged set of Gaussians per batch element.
        pred = merge_and_split_predictions(pred1, pred2)

        # 2. Optimal scale between predicted means and GT points; everything
        #    is expressed relative to the camera of view 1.
        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()
        in_camera1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))  # B,H,W,3
        gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))  # B,H,W,3
        scaling = find_opt_scaling(gt_pts1, gt_pts2, pred1['means'], pred2['means'],
                                   valid1=valid1, valid2=valid2)

        # 3. Render every supervised view (view1, view2, target) of every sample.
        #    Output order: sample-major, then view.
        supervised_views = (gt1, gt2, target_view)
        rendered_images, rendered_feats, gt_images = [], [], []
        for i, sample_pred in enumerate(pred):
            gaussians = GaussianModel.from_predictions(sample_pred, sh_degree=3)
            ref_camera_extrinsics = gt1['camera_pose'][i]
            scale = scaling[i]
            for view in supervised_views:
                camera = get_scaled_camera(ref_camera_extrinsics,
                                           view['camera_pose'][i],
                                           view['camera_intrinsics'][i],
                                           scale,
                                           view['true_shape'][i])
                rendered_output = render(camera, gaussians, self.pipeline, self.bg_color)
                rendered_images.append(rendered_output['render'])
                rendered_feats.append(rendered_output['feature_map'])
                # `img` is assumed to be normalized to [-1, 1]; map it to [0, 1]
                # to match the metrics' data_range=1.0.
                gt_images.append(view['img'][i] * 0.5 + 0.5)

        rendered_images = torch.stack(rendered_images, dim=0)  # 3*B, 3, H, W
        gt_images = torch.stack(gt_images, dim=0)  # 3*B, 3, H, W
        rendered_feats = torch.stack(rendered_feats, dim=0)  # 3*B, d_feats, H, W
        rendered_feats = model.feature_expansion(rendered_feats)  # 3*B, 512, H//2, W//2
        # NOTE(review): gt_images are in [0, 1] here; confirm extract_features
        # expects that range rather than the raw [-1, 1] `img` tensors.
        gt_feats = model.lseg_feature_extractor.extract_features(gt_images)  # 3*B, 512, H//2, W//2

        image_loss = torch.abs(rendered_images - gt_images).mean()
        feature_loss = torch.abs(rendered_feats - gt_feats).mean()
        loss = image_loss + feature_loss

        # 4. Metrics. LPIPS expects inputs in [-1, 1] unless normalize=True,
        #    so tell it the images are in [0, 1].
        with torch.no_grad():
            ssim = self.ssim(rendered_images, gt_images)
            psnr = self.psnr(rendered_images, gt_images)
            lpips_value = self.lpips_vgg(rendered_images, gt_images, normalize=True).mean()

        return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips_value,
                      'image_loss': image_loss, 'feature_loss': feature_loss}


def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    """Run the model on one ``(view1, view2, target_view)`` batch and compute the loss.

    Tensors of every view (except the keys in ``ignore_keys``) are moved to
    ``device`` in place. If ``symmetrize_batch`` is set, the input pair is
    made symmetric and ``target_view`` is duplicated to stay aligned with it.

    Returns:
        A dict with ``view1``, ``view2``, ``target_view``, ``pred1``, ``pred2``
        and ``loss`` (``None`` when ``criterion`` is ``None``), or only the
        entry named by ``ret`` when given.
    """
    view1, view2, target_view = batch
    ignore_keys = {'depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng', 'pts3d'}
    for view in batch:
        for name in view.keys():
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        # make_batch_symmetric takes a (view1, view2) pair, so only the input
        # views are passed. Symmetrizing (target_view, target_view) duplicates
        # each target sample so it stays aligned with the enlarged input batch.
        view1, view2 = make_batch_symmetric((view1, view2))
        target_view, _ = make_batch_symmetric((target_view, target_view))

    # Unwrap a distributed wrapper so the criterion can access model attributes
    # (feature_expansion, lseg_feature_extractor).
    # NOTE(review): the forward pass below also calls the unwrapped module,
    # which bypasses DistributedDataParallel.forward; confirm gradients are
    # synchronized elsewhere, otherwise call `model(view1, view2)` instead.
    actual_model = model.module if hasattr(model, 'module') else model

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = actual_model(view1, view2)

        # The loss itself is always evaluated in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, target_view, pred1, pred2, actual_model) if criterion is not None else None

    result = dict(view1=view1, view2=view2, target_view=target_view, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result