# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Main class for the implementation of the global alignment
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from dust3r.cloud_opt.base_opt import BasePCOptimizer
from dust3r.utils.geometry import xy_grid, geotrf
from dust3r.utils.device import to_cpu, to_numpy


class PointCloudOptimizer(BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        super().__init__(*args, **kwargs)

        self.has_im_poses = True  # by definition of this class
        self.focal_break = focal_break

        # adding things to optimize
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        self.im_focals = nn.ParameterList(torch.FloatTensor(
            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h*w for h, w in self.imshapes]
        self.max_area = max(im_areas)

        # adding things to optimize
        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
        self.register_buffer('_grid', ParameterStack(
            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))

        # pre-compute pixel weights
        self.register_buffer('_weight_i', ParameterStack(
            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
        self.register_buffer('_weight_j', ParameterStack(
            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))

        # pre-compute stacked pairwise predictions and edge indices
        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
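    # Note on the parameterization: all per-image quantities are optimized in transformed
    # spaces that keep gradients well-scaled. Depth is stored as log(depth), a focal f as
    # focal_break * log(f) (see _set_focal / get_focals below), and the principal point as
    # (pp - image_center) / 10. For example, with the default focal_break=20 and f=500 px,
    # the stored value is 20 * ln(500) ~= 124.3 and get_focals() recovers exp(124.3 / 20) ~= 500 px.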
    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f' (setting pose #{idx} = {pose[:3,3]})')
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        # normalize scale if there is at most one known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

        self.im_poses.requires_grad_(False)
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _get_msk_indices(self, msk):
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f'bad {msk=}')

    def _no_grad(self, tensor):
        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'

    def _set_focal(self, idx, focal, force=False):
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        depth = _ravel_hw(depth, self.max_area)

        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        res = self.im_depthmaps.exp()
        if not raw:
            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        # get depths and projection params
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # get pointmaps in camera frame
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
        # project to world frame
        return geotrf(im_poses, rel_ptmaps)
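    # get_intrinsics() and depth_to_pts3d() above assume a simple pinhole model:
    #     K = [[f, 0, cx],
    #          [0, f, cy],
    #          [0, 0,  1]]
    # and every pixel (u, v) with depth z is unprojected into the camera frame as
    # ((u - cx) * z / f, (v - cy) * z / f, z) before being mapped to the world frame
    # with the cam-to-world pose (see _fast_depthmap_to_pts3d further down in this file).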
    def get_pts3d(self, raw=False):
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    # def cosine_similarity_batch(self, semantic_features, query_pixels):
    #     # expand dims so the cosine similarity is computed by broadcasting
    #     query_pixels = query_pixels.unsqueeze(1)  # [B, 1, C]
    #     semantic_features = semantic_features.unsqueeze(0)  # [1, H, W, C]
    #     cos_sim = F.cosine_similarity(query_pixels, semantic_features, dim=-1)  # [B, H, W]
    #     return cos_sim

    # def semantic_loss(self, semantic_features, predicted_depth, window_size=32, stride=16, lambda_semantic=0.1):
    #     # image size
    #     height, width, channels = semantic_features.shape
    #     # process the image window by window, in matrix form
    #     ret_loss = 0.0
    #     cnt = 0
    #     for i in range(0, height, stride):
    #         for j in range(0, width, stride):
    #             window_semantic = semantic_features[i:min(i+window_size, height), j:min(j+window_size, width), :]
    #             window_depth = predicted_depth[i:min(i+window_size, height), j:min(j+window_size, width)]
    #             # print(window_semantic.shape, window_depth.shape)
    #             window_semantic = window_semantic.reshape(-1, channels)
    #             window_depth = window_depth.reshape(-1, 1)
    #             cos_sim = torch.matmul(window_semantic, window_semantic.t())
    #             dep_dif = torch.abs(window_depth - window_depth.reshape(1, -1))
    #             # print(torch.sum(cos_sim * dep_dif))
    #             ret_loss += torch.mean(cos_sim * dep_dif)
    #             cnt += 1
    #     return ret_loss / cnt

    # def segmap_loss(self, predicted_depth, seg_map):
    #     ret_loss = 0.0
    #     cnt = 0
    #     seg_map = seg_map.view(-1)
    #     predicted_depth = predicted_depth.view(-1, 1)
    #     unique_groups = torch.unique(seg_map)
    #     for group in unique_groups:
    #         # print(group)
    #         if group == -1:
    #             continue
    #         group_indices = (seg_map == group).nonzero(as_tuple=True)[0]
    #         if len(group_indices) > 0:
    #             now_feat = predicted_depth[group_indices]
    #             dep_dif = torch.abs(now_feat - now_feat.reshape(1, -1))
    #             ret_loss += torch.mean(dep_dif)
    #             cnt += 1
    #     return ret_loss / cnt if cnt > 0 else ret_loss

    # def spatial_smoothness_loss(self, point_map, semantic_map):
    #     """
    #     Spatial smoothness loss: neighbouring pixels that belong to the same semantic
    #     class should not have strongly varying 3D positions. Uses the 8-neighbourhood.
    #     Args:
    #     - point_map: (H, W, 3), the 3D coordinates (x, y, z) of every pixel
    #     - semantic_map: (H, W), the semantic label of every pixel
    #     Returns:
    #     - the total loss value
    #     """
    #     # image height and width
    #     H, W = semantic_map.shape
    #     # flatten the point map and the semantic map
    #     point_map = point_map.view(-1, 3)  # (H * W, 3)
    #     semantic_map = semantic_map.view(-1)  # (H * W,)
    #     # build pixel indices
    #     row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W))
    #     row_idx = row_idx.flatten()
    #     col_idx = col_idx.flatten()
    #     # offsets of the 8-neighbourhood
    #     neighbor_offsets = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1],
    #                                      [-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.long)
    #     # accumulated loss
    #     total_loss = 0.0
    #     # accumulate over every pixel
    #     for offset in neighbor_offsets:
    #         # neighbour positions
    #         neighbor_row = row_idx + offset[0]
    #         neighbor_col = col_idx + offset[1]
    #         # keep only neighbours that fall inside the image
    #         valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
    #         valid_row = neighbor_row[valid_mask]
    #         valid_col = neighbor_col[valid_mask]
    #         # flat indices of the valid pixels and of their neighbours
    #         idx = valid_mask.nonzero(as_tuple=True)[0]
    #         neighbor_idx = valid_row * W + valid_col
    #         # semantic labels and 3D coordinates of each pixel/neighbour pair
    #         sem_i = semantic_map[idx]
    #         sem_j = semantic_map[neighbor_idx]
    #         p_i = point_map[idx]
    #         p_j = point_map[neighbor_idx]
    #         # squared distance between the 3D coordinates
    #         distance = torch.sum((p_i - p_j) ** 2, dim=1)
    #         # only penalize pairs that share the same semantic label
    #         loss_mask = (sem_i == sem_j)
    #         total_loss += torch.sum(loss_mask * distance)
    #     # average loss
    #     return total_loss / point_map.size(0)
    def spatial_smoothness_loss_multi_image(self, point_maps, semantic_maps, confidence_maps):
        """
        Spatial smoothness loss over several images: pixels that belong to the same object
        should have smoothly varying 3D positions.
        Args:
        - point_maps: (B, H, W, 3), the 3D coordinates (x, y, z) of every pixel, B is the batch size
        - semantic_maps: (B, H, W), the semantic label of every pixel
        - confidence_maps: (B, H, W), the per-pixel confidence
        Returns:
        - the total loss value
        """
        B, H, W = semantic_maps.shape

        # flatten the point, semantic and confidence maps
        point_maps = point_maps.view(B, -1, 3)  # (B, H*W, 3)
        semantic_maps = semantic_maps.view(B, -1)  # (B, H*W)
        confidence_maps = confidence_maps.view(B, -1)  # (B, H*W)

        # accumulated loss
        total_loss = 0.0

        # process every image of the batch
        for b in range(B):
            # maps of the current image
            point_map = point_maps[b]
            semantic_map = semantic_maps[b]
            confidence_map = confidence_maps[b]

            # build pixel indices
            row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
            row_idx = row_idx.flatten()
            col_idx = col_idx.flatten()

            # offsets of the 8-neighbourhood
            neighbor_offsets = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1],
                                             [-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.long)

            # accumulate over every pixel (neighbourhoods are only computed within the current image)
            for offset in neighbor_offsets:
                # neighbour positions
                neighbor_row = row_idx + offset[0]
                neighbor_col = col_idx + offset[1]

                # keep only neighbours that fall inside the image
                valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
                valid_row = neighbor_row[valid_mask]
                valid_col = neighbor_col[valid_mask]

                # flat indices of the valid pixels and of their neighbours
                idx = valid_mask.nonzero(as_tuple=True)[0]
                neighbor_idx = valid_row * W + valid_col

                # semantic labels, 3D coordinates and confidences of each pixel/neighbour pair
                sem_i = semantic_map[idx]
                sem_j = semantic_map[neighbor_idx]
                p_i = point_map[idx]
                p_j = point_map[neighbor_idx]
                conf_i = confidence_map[idx]
                conf_j = confidence_map[neighbor_idx]

                # squared distance between the 3D coordinates
                distance = torch.sum((p_i - p_j)**2, dim=1)

                # only penalize pairs that share the same semantic label
                loss_mask = (sem_i == sem_j)

                # inverse weighting (disabled): low-confidence points would get a higher weight
                # inverse_weight_i = 1.0 / (conf_i)  # avoid division by zero
                # inverse_weight_j = 1.0 / (conf_j)
                weighted_distance = loss_mask * distance  # weighted loss (* inverse_weight_i * inverse_weight_j)

                total_loss += torch.sum(weighted_distance)

            # cross-image term (disabled): for pixels of the same semantic class, only compare
            # the per-class means instead of all pairwise distances
            # for b2 in range(B):
            #     if b == b2:
            #         continue  # skip comparing an image with itself
            #     point_map_b2 = point_maps[b2]
            #     semantic_map_b2 = semantic_maps[b2]
            #     confidence_map_b2 = confidence_maps[b2]
            #     for sem_id in torch.unique(semantic_map):
            #         sem_mask_a = (semantic_map == sem_id)
            #         sem_mask_b2 = (semantic_map_b2 == sem_id)
            #         # pixels of this semantic class in both images
            #         shared_points_a = point_map[sem_mask_a]
            #         shared_points_b2 = point_map_b2[sem_mask_b2]
            #         shared_conf_a = confidence_map[sem_mask_a]
            #         shared_conf_b2 = confidence_map_b2[sem_mask_b2]
            #         if shared_points_a.shape[0] > 0 and shared_points_b2.shape[0] > 0:
            #             # per-class means
            #             mean_a = shared_points_a.mean(dim=0)    # mean of this class in the current image
            #             mean_b2 = shared_points_b2.mean(dim=0)  # mean of this class in image b2
            #             mean_conf_a = shared_conf_a.mean()      # mean confidence of this class in the current image
            #             mean_conf_b2 = shared_conf_b2.mean()    # mean confidence of this class in image b2
            #             # spatial difference between the means, weighted by the confidences
            #             distance_cross = torch.sum((mean_a - mean_b2) ** 2)
            #             weighted_distance_cross = distance_cross * mean_conf_a * mean_conf_b2
            #             total_loss += weighted_distance_cross

        # average loss
        return total_loss / (B * H * W)
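    # Illustrative sketch (not called anywhere): the within-image term above can also be
    # computed with plain tensor shifts instead of per-offset index gathering. It assumes
    # point_maps of shape (B, H, W, 3) and semantic_maps of shape (B, H, W), and returns
    # the same unweighted smoothness value as spatial_smoothness_loss_multi_image.
    def _spatial_smoothness_loss_shifted(self, point_maps, semantic_maps):
        B, H, W = semantic_maps.shape
        pts = point_maps.view(B, H, W, 3)
        total_loss = point_maps.new_zeros(())
        for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
            # crop the overlapping regions so that pixel (r, c) in p_i faces (r+dy, c+dx) in p_j
            r0, r1 = max(0, -dy), H - max(0, dy)
            c0, c1 = max(0, -dx), W - max(0, dx)
            p_i = pts[:, r0:r1, c0:c1]
            p_j = pts[:, r0 + dy:r1 + dy, c0 + dx:c1 + dx]
            same = (semantic_maps[:, r0:r1, c0:c1] == semantic_maps[:, r0 + dy:r1 + dy, c0 + dx:c1 + dx])
            # squared 3D distance, only for neighbours sharing the same semantic label
            total_loss = total_loss + (same * ((p_i - p_j) ** 2).sum(dim=-1)).sum()
        return total_loss / (B * H * W)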
    def forward(self, cur_iter=0):
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)
        loss = 0.0
        # depth = self.get_depthmaps(raw=True)
        # print(depth.shape)

        # if cur_iter < 100:
        #     # for i, pointmap in enumerate(proj_pts3d):
        #     #     loss += self.spatial_smoothness_loss(pointmap, seg_maps[i].cuda())
        #     # depths = self.get_depthmaps()
        #     # # cogs = self.cogs
        #     # seg_maps = self.segmaps
        #     # im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
        #     # for i, depth in enumerate(depths):
        #     #     # print(seg_maps[i].shape)
        #     #     # H, W = depth.shape
        #     #     # tmp = cogs[i].reshape(-1, 1024)
        #     #     # tmp = torch.matmul(tmp, self.cog_matrix.detach().t())
        #     #     # tmp / (tmp.norm(dim=-1, keepdim=True)+0.000000000001)
        #     #     # tmp = tmp.reshape(H, W, 3)
        #     #     loss += self.segmap_loss(depth, seg_maps[i], im_conf[i])
        #     #     loss += self.semantic_loss(cogs[i], depth)
        #     # im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
        #     # cogs = self.cogs.permute(0, 3, 1, 2)
        #     # cogs = F.interpolate(cogs, scale_factor=2, mode='nearest')
        #     # cogs = cogs.permute(0, 2, 3, 1)
        #     # cogs = torch.stack(self.cogs).view(-1, 1024)
        #     # proj = proj_pts3d.view(-1, 3)
        #     # proj = proj / proj.norm(dim=-1, keepdim=True)
        #     # img_conf = im_conf.view(-1, 1)
        #     # selected_indices = torch.where(img_conf > 2.0)[0]
        #     # img_conf = img_conf[selected_indices]
        #     # cogs = cogs[selected_indices]
        #     # proj = proj[selected_indices]
        #     # print(img_conf.shape, cogs.shape, proj.shape)
        #     # proj_dis = torch.matmul(proj, proj.t())
        #     # cogs_dis = torch.matmul(cogs, cogs.t())
        #     # loss += (im_conf * F.mse_loss(proj_dis, cogs_dis, reduction='none')).mean()
        #     # if cur_iter % 2 == 0:
        #     #     tmp = torch.matmul(cogs.detach(), self.cog_matrix.detach().t())
        #     #     tmp = tmp / (tmp.norm(dim=-1, keepdim=True)+0.000000000001)
        #     #     loss += 0/1*(img_conf * F.mse_loss(proj, tmp, reduction='none')).mean()
        #     # if cur_iter % 2 == 1:
        #     #     tmp = torch.matmul(cogs.view(-1, 1024), self.cog_matrix.detach().t())
        #     #     tmp = tmp / tmp.norm(dim=-1, keepdim=True)
        #     #     loss += (im_conf.view(-1, 1) * F.mse_loss(proj.detach(), tmp, reduction='none')).mean()
        #     # if cur_iter % 3 == 2:
        #     #     tmp = torch.matmul(cogs.view(-1, 1024).detach(), self.cog_matrix.t())
        #     #     tmp = tmp / tmp.norm(dim=-1, keepdim=True)
        #     #     loss += (im_conf.view(-1, 1) * F.mse_loss(proj.detach(), tmp, reduction='none')).mean()

        seg_maps = torch.stack(self.segmaps).cuda()
        im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
        loss += self.spatial_smoothness_loss_multi_image(proj_pts3d, seg_maps, im_conf)

        # # if cur_iter > 100:
        # # rotate pairwise predictions according to pw_poses
        # aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        # aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
        # loss += self.spatial_smoothness_loss_multi_image(aligned_pred_i, seg_maps[self._ei], im_conf[self._ei])
        # loss += self.spatial_smoothness_loss_multi_image(aligned_pred_j, seg_maps[self._ej], im_conf[self._ej])

        # # compute the loss
        # loss += self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
        # loss += self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j

        return loss


def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
    pp = pp.unsqueeze(1)
    focal = focal.unsqueeze(1)
    assert focal.shape == (len(depth), 1, 1)
    assert pp.shape == (len(depth), 1, 2)
    assert pixel_grid.shape == depth.shape + (2,)
    depth = depth.unsqueeze(-1)
    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)


def ParameterStack(params, keys=None, is_param=None, fill=0):
    if keys is not None:
        params = [params[k] for k in keys]

    if fill > 0:
        params = [_ravel_hw(p, fill) for p in params]

    requires_grad = params[0].requires_grad
    assert all(p.requires_grad == requires_grad for p in params)

    params = torch.stack(list(params)).float().detach()
    if is_param or requires_grad:
        params = nn.Parameter(params)
        params.requires_grad_(requires_grad)
    return params
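# Example (illustrative only): ParameterStack pads variable-sized per-image tensors to a
# common length and stacks them, so the optimizer can treat them as one big tensor.
#   a = torch.randn(4, 5, 3)   # a 4x5 "image" with 3 channels -> ravelled to (20, 3)
#   b = torch.randn(3, 5, 3)   # a 3x5 "image" -> ravelled to (15, 3), then zero-padded to (20, 3)
#   ParameterStack([a, b], fill=20).shape == (2, 20, 3)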
def _ravel_hw(tensor, fill=0):
    # ravel H,W
    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    if len(tensor) < fill:
        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),) + tensor.shape[1:])))
    return tensor


def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
    return minf*focal_base, maxf*focal_base


def apply_mask(img, msk):
    img = img.copy()
    img[msk] = 0
    return img
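
# Minimal usage sketch for the standalone helpers above (illustrative only; the optimizer
# itself is normally constructed from pairwise predictions elsewhere in the code base).
if __name__ == '__main__':
    lo, hi = acceptable_focal_range(480, 640)
    print(f'acceptable focal range for a 480x640 image: [{lo:.1f}, {hi:.1f}] px')

    img = np.ones((480, 640, 3), dtype=np.uint8)
    msk = np.zeros((480, 640), dtype=bool)
    msk[:240] = True  # zero out the top half of the image
    assert apply_mask(img, msk)[:240].sum() == 0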