# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Base class for the global alignment procedure
# --------------------------------------------------------
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import roma
import tqdm
from torch.nn.functional import cosine_similarity
import cv2

from dust3r.utils.geometry import inv, geotrf
from dust3r.utils.device import to_numpy
from dust3r.utils.image import rgb
from dust3r.viz import SceneViz, segment_sky, auto_cam_size
from dust3r.optim_factory import adjust_learning_rate_by_lr
from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
                                      cosine_schedule, linear_schedule, get_conf_trf, GradParamDict)
import dust3r.cloud_opt.init_im_poses as init_fun


class BasePCOptimizer (nn.Module):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph nodes: images
    Graph edges: observations = (pred1, pred2)
    """

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            other = deepcopy(args[0])
            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
                       min_conf_thr conf_thr conf_i conf_j im_conf
                       base_scale norm_pw_scale POSE_DIM pw_poses
                       pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
            self.__dict__.update({k: other[k] for k in attrs})
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_views(self, view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats,
                         dist='l2',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None,
                         verbose=True):
        super().__init__()
        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]
        self.verbose = verbose

        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[e] for e, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[e] for e, ij in enumerate(self.str_edges)})
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
        for i in range(len(self.im_conf)):
            self.im_conf[i].requires_grad = False

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1 + self.POSE_DIM)))  # pairwise poses
        self.pw_poses.requires_grad_(True)
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(True)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,) + hw) for hw in self.imshapes]
            smoothed_imgs = [torch.zeros((3,) + hw) for hw in self.imshapes]
            ori_imgs = [torch.zeros((3,) + hw) for hw in self.imshapes]
            for v in range(len(self.edges)):
                idx = view1['idx'][v]
                imgs[idx] = view1['img'][v]
                smoothed_imgs[idx] = view1['smoothed_img'][v]
                ori_imgs[idx] = view1['ori_img'][v]
                idx = view2['idx'][v]
                imgs[idx] = view2['img'][v]
                smoothed_imgs[idx] = view2['smoothed_img'][v]
                ori_imgs[idx] = view2['ori_img'][v]
            self.imgs = rgb(imgs)
            self.ori_imgs = rgb(ori_imgs)
            self.fix_imgs = rgb(ori_imgs)
            self.smoothed_imgs = rgb(smoothed_imgs)

            self.cogs = [torch.zeros((h, w, 1024), device="cuda") for h, w in self.imshapes]
            semantic_feats = semantic_feats.to("cuda")
            self.segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
            self.rev_segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
            for v in range(len(self.edges)):
                idx = view1['idx'][v]
                h, w = self.cogs[idx].shape[0], self.cogs[idx].shape[1]
                cog_seg_map = cog_seg_maps[idx]
                cog_seg_map = torch.from_numpy(cv2.resize(cog_seg_map, (w, h), interpolation=cv2.INTER_NEAREST))
                rev_seg_map = rev_cog_seg_maps[idx]
                rev_seg_map = torch.from_numpy(cv2.resize(rev_seg_map, (w, h), interpolation=cv2.INTER_NEAREST))
                # one semantic feature per pixel, looked up by segment id
                seg = cog_seg_map.reshape(-1).long()
                self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
                self.segmaps[idx] = cog_seg_map.cuda()
                self.rev_segmaps[idx] = rev_seg_map.cuda()

                idx = view2['idx'][v]
                h, w = self.cogs[idx].shape[0], self.cogs[idx].shape[1]
                cog_seg_map = cog_seg_maps[idx]
                cog_seg_map = torch.from_numpy(cv2.resize(cog_seg_map, (w, h), interpolation=cv2.INTER_NEAREST))
                rev_seg_map = rev_cog_seg_maps[idx]
                rev_seg_map = torch.from_numpy(cv2.resize(rev_seg_map, (w, h), interpolation=cv2.INTER_NEAREST))
                seg = cog_seg_map.reshape(-1).long()
                self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
                self.segmaps[idx] = cog_seg_map.cuda()
                self.rev_segmaps[idx] = rev_seg_map.cuda()

        self.rendered_imgs = []

    def render_image(self, text_feats, threshold=0.85):
        self.rendered_imgs = []

        # collect all cosine similarities to compute a global min-max normalization
        all_similarities = []
        for each_cog in self.cogs:
            similarity_map = cosine_similarity(each_cog.to("cpu"), text_feats.to("cpu").unsqueeze(1), dim=-1)
            all_similarities.append(similarity_map.squeeze().numpy())

        # flatten and normalize all similarities jointly
        total_similarities = np.concatenate(all_similarities)
        min_sim, max_sim = total_similarities.min(), total_similarities.max()
        normalized_similarities = [(sim - min_sim) / (max_sim - min_sim) for sim in all_similarities]
        # process each image with the jointly normalized similarities
        for i, heatmap in enumerate(normalized_similarities):
            mask = heatmap > threshold

            # scale heatmap for visualization (OpenCV colormaps work in BGR)
            heatmap = np.uint8(255 * heatmap)
            heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

            # prepare image
            image = self.fix_imgs[i]
            image = image * 255.0
            image = np.clip(image, 0, 255).astype(np.uint8)

            # paint masked areas red (BGR [0, 0, 255]) and overlay them on the image
            mask_indices = np.where(mask)  # indices where mask is True
            heatmap_color[mask_indices[0], mask_indices[1]] = [0, 0, 255]
            superimposed_img = np.where(np.expand_dims(mask, axis=-1), heatmap_color, image) / 255.0
            self.rendered_imgs.append(superimposed_img)
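
    # Illustrative usage sketch (not called anywhere in this file): `text_feats`
    # is expected to be a (1, C) tensor whose dimension C matches the per-pixel
    # features stored in `self.cogs`. How the features are produced is up to the
    # caller, e.g. with a hypothetical text encoder:
    #
    #   text_feats = encode_text("a wooden chair")   # assumption: returns (1, C)
    #   scene.render_image(text_feats, threshold=0.85)
    #   overlays = scene.rendered_imgs               # list of HxWx3 float arrays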

    @property
    def n_edges(self):
        return len(self.edges)

    @property
    def str_edges(self):
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        all_params = super().state_dict()
        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}

    def load_state_dict(self, data):
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), 'bad pair indices: missing values'
        return len(indices)

    @torch.no_grad()
    def _compute_img_conf(self, pred1_conf, pred2_conf):
        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        adapt = self.pw_adaptors
        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
        if self.norm_pw_scale:  # normalize so that the product == 1
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        # normalize rotation
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT
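
    # Parametrization note: each row of `poses` stores a unit quaternion in its
    # first 4 entries and a signed-log translation in the next 3. A minimal
    # round-trip sketch using only helpers already imported in this file:
    #
    #   q = roma.rotmat_to_unitquat(torch.eye(3))        # identity rotation
    #   t = signed_log1p(torch.tensor([1., 2., 3.]))     # log-space translation
    #   pose = torch.cat([q, t])[None]                   # (1, 7) parameter row
    #   RT = roma.RigidUnitQuat(pose[:, :4], signed_expm1(pose[:, 4:7]))
    #   RT = RT.normalize().to_homogeneous()             # (1, 4, 4) cam-to-world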

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        if R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is a function of scale
        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def get_pw_norm_scale_factor(self):
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def get_pw_scale(self):
        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):  # cam to world
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
        return scaled_RT
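
    # Worked example of the normalization above: raw log-scales [0.0, 2.0] with
    # base_scale=0.5 are multiplied by exp(log(0.5) - 1.0) ~= 0.184, giving
    # scales ~[0.18, 1.36] whose geometric mean is exactly base_scale = 0.5.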

    def get_masks(self):
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    def clean_pointcloud(self, **kw):
        cams = inv(self.get_im_poses())
        K = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        all_pts3d = self.get_pts3d()

        new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw)
        for i, new_conf in enumerate(new_im_confs):
            self.im_conf[i].data[:] = new_conf
        return self

    def forward(self, ret_details=False):
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()

        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj

        loss /= self.n_edges  # average over all pairs
        if ret_details:
            return loss, details
        return loss

    def spatial_select_points(self, point_maps, semantic_maps, confidence_maps):
        H, W = semantic_maps.shape
        # flatten the point map and the semantic map
        point_map = point_maps.view(-1, 3)     # (H*W, 3)
        semantic_map = semantic_maps.view(-1)  # (H*W,)
        confidence_map = confidence_maps.view(-1)

        dist_map = torch.zeros_like(semantic_map, dtype=torch.float32)
        cnt_map = torch.zeros_like(semantic_map, dtype=torch.float32)
        refresh_confidence_map = confidence_map.clone()

        # pixel indices of the image
        row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
        row_idx = row_idx.flatten()
        col_idx = col_idx.flatten()

        kernel_size = 5
        offset_range = kernel_size // 2
        neighbor_offsets = [
            (dx, dy) for dx in range(-offset_range, offset_range + 1)
            for dy in range(-offset_range, offset_range + 1)
            if not (dx == 0 and dy == 0)
        ]

        # accumulate, for each pixel, squared distances to its same-label
        # neighbors (neighborhoods are computed within the current image only)
        for offset in neighbor_offsets:
            # neighbor positions
            neighbor_row = row_idx + offset[0]
            neighbor_col = col_idx + offset[1]

            # keep only neighbors that fall inside the image
            valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
            valid_row = neighbor_row[valid_mask]
            valid_col = neighbor_col[valid_mask]

            # indices of the valid pixels and of their neighbors
            idx = valid_mask.nonzero(as_tuple=True)[0]
            neighbor_idx = valid_row * W + valid_col

            # semantic labels and 3D coordinates of both pixels
            sem_i = semantic_map[idx]
            sem_j = semantic_map[neighbor_idx]
            p_i = point_map[idx]
            p_j = point_map[neighbor_idx]

            # squared distance between the 3D coordinates
            distance = torch.sum((p_i - p_j) ** 2, dim=1)
            same_object = (sem_i == sem_j) & (sem_i != -1) & (sem_j != -1)
            dist_map[idx] += same_object * distance
            cnt_map[idx] += same_object

        # mean squared neighbor distance per pixel; pixels without any valid
        # same-label neighbor get a score of 0 (instead of 0/0 = NaN)
        anomaly_score = dist_map / cnt_map.clamp(min=1)

        # standardize and flag pixels whose score lies above the mean
        mean = torch.mean(anomaly_score)
        std = torch.std(anomaly_score)
        anomaly_score = (anomaly_score - mean) / std
        anomaly_point = anomaly_score > 0

        anomaly_point_idx = anomaly_point.nonzero(as_tuple=True)[0]
        refresh_confidence_map[anomaly_point_idx] = -1
        return refresh_confidence_map.view(H, W)
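
    # Usage sketch (shapes follow the conventions above): given an (H, W, 3)
    # point map, an (H, W) segmentation using -1 for "unlabelled", and an
    # (H, W) confidence map,
    #
    #   new_conf = self.spatial_select_points(pts, segmap, conf)
    #
    # returns `conf` with anomalous pixels (unusually large squared distance to
    # same-label neighbors) set to -1, which `compute_global_alignment` below
    # treats as a marker for unreliable regions.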

    @torch.cuda.amp.autocast(enabled=False)
    def compute_global_alignment(self, tune_flg=False, init=None, niter_PnP=10, **kw):
        if tune_flg:
            # refresh pairwise confidences by pruning spatial outliers
            for e, (i, j) in enumerate(self.edges):
                i_j = edge_str(i, j)
                self.conf_i[i_j] = self.spatial_select_points(self.pred_i[i_j], self.rev_segmaps[i], self.conf_i[i_j])
                self.conf_j[i_j] = self.spatial_select_points(self.pred_j[i_j], self.rev_segmaps[j], self.conf_j[i_j])
                self.im_conf[i] = self.conf_i[i_j]
                self.im_conf[j] = self.conf_j[i_j]

            threshold = 0.25
            for i in range(len(self.imgs)):
                anomaly_mask = (self.im_conf[i] == -1)
                unique_labels = torch.unique(self.rev_segmaps[i])
                # replace objects that are mostly anomalous with their smoothed version
                for label in unique_labels:
                    if label == -1:
                        continue
                    semantic_mask = (self.rev_segmaps[i] == label)
                    cover = (semantic_mask & anomaly_mask).sum() / semantic_mask.sum()
                    if cover > threshold:
                        self.imgs[i][semantic_mask.cpu()] = self.smoothed_imgs[i][semantic_mask.cpu()]
                        # propagate the replacement to every other image
                        for j in range(len(self.imgs)):
                            if j == i:
                                continue
                            semantic_mask = (self.rev_segmaps[j] == label)
                            self.imgs[j][semantic_mask.cpu()] = self.smoothed_imgs[j][semantic_mask.cpu()]

        if init is None:
            pass
        elif init == 'msp' or init == 'mst':
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == 'known_poses':
            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
                                           niter_PnP=niter_PnP)
        else:
            raise ValueError(f'bad value for {init=}')

        if tune_flg:
            return 0

        loss = global_alignment_loop(self, **kw)
        return loss

    @torch.no_grad()
    def mask_sky(self):
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        viz = SceneViz()
        if self.imgs is None:
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            for n in range(self.n_imgs):
                viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
        else:
            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera poses
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
        if show_pw_cams:
            pw_poses = self.get_pw_poses()
            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)

            if show_pw_pts3d:
                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
                viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz
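

# End-to-end usage sketch. The subclass name `PointCloudOptimizer` and the exact
# constructor arguments are assumptions based on the surrounding DUSt3R code
# layout; this file only defines the base class:
#
#   scene = PointCloudOptimizer(view1, view2, pred1, pred2,
#                               cog_seg_maps, rev_cog_seg_maps, semantic_feats)
#   loss = scene.compute_global_alignment(init='mst', niter=300,
#                                         schedule='cosine', lr=0.01)
#   pts3d = scene.get_pts3d()      # per-image (H, W, 3) point maps
#   poses = scene.get_im_poses()   # cam-to-world 4x4 matrices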


def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
    params = [p for p in net.parameters() if p.requires_grad]
    if not params:
        return net

    verbose = net.verbose
    if verbose:
        print('Global alignment - optimizing for:')
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))

    loss = float('inf')
    if verbose:
        with tqdm.tqdm(total=niter) as bar:
            while bar.n < bar.total:
                loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)
                bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
                bar.update()
    else:
        for n in range(niter):
            loss, _ = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule)
    return loss


def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    t = cur_iter / niter
    if schedule == 'cosine':
        lr = cosine_schedule(t, lr_base, lr_min)
    elif schedule == 'linear':
        lr = linear_schedule(t, lr_base, lr_min)
    else:
        raise ValueError(f'bad lr {schedule=}')
    adjust_learning_rate_by_lr(optimizer, lr)

    optimizer.zero_grad()
    loss = net(cur_iter)  # the iteration index is passed through to the network
    if loss == 0:
        # nothing to optimize at this step
        optimizer.step()
        return float(loss), lr
    loss.backward()
    optimizer.step()
    return float(loss), lr
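

# Schedule note: `cosine_schedule(t, lr_base, lr_min)` and `linear_schedule(...)`
# (imported from dust3r.cloud_opt.commons) map t in [0, 1] to a learning rate.
# Assuming a standard cosine decay, lr_base=0.01 and lr_min=1e-6 yield roughly
# 0.01 at t=0, 0.005 at t=0.5 and lr_min at t=1.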


@torch.no_grad()
def clean_pointcloud(im_confs, K, cams, depthmaps, all_pts3d,
                     tol=0.001, bad_conf=0, dbg=()):
    """ Method:
    1) express all 3d points in each camera coordinate frame
    2) if they're in front of a depthmap --> then lower their confidence
    """
    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d)
    assert 0 <= tol < 1
    res = [c.clone() for c in im_confs]

    # reshape appropriately
    all_pts3d = [p.view(*c.shape, 3) for p, c in zip(all_pts3d, im_confs)]
    depthmaps = [d.view(*c.shape) for d, c in zip(depthmaps, im_confs)]

    for i, pts3d in enumerate(all_pts3d):
        for j in range(len(all_pts3d)):
            if i == j:
                continue

            # project 3d points into the other view
            proj = geotrf(cams[j], pts3d)
            proj_depth = proj[:, :, 2]
            u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)

            # check which points are actually in the visible cone
            H, W = im_confs[j].shape
            msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)
            msk_j = v[msk_i], u[msk_i]

            # find bad points = those in front but less confident
            bad_points = (proj_depth[msk_i] < (1 - tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j])

            bad_msk_i = msk_i.clone()
            bad_msk_i[msk_i] = bad_points
            res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf)
    return res
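

# Standalone usage sketch for `clean_pointcloud`, mirroring the method of the
# same name above (note: `get_intrinsics` is implemented by subclasses, not in
# this file):
#
#   cams = inv(scene.get_im_poses())                 # world-to-cam poses
#   new_confs = clean_pointcloud(scene.im_conf, scene.get_intrinsics(), cams,
#                                scene.get_depthmaps(), scene.get_pts3d())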