hujiecpp committed
Commit f22f2cc · 1 Parent(s): 6caa8a9

init project

modules/dust3r/cloud_opt/__init__.py CHANGED
@@ -22,11 +22,11 @@ def global_aligner(dust3r_output, cog_seg_maps, rev_cog_seg_maps, semantic_feats
     view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]
     # build the optimizer
     if mode == GlobalAlignerMode.PointCloudOptimizer:
-        net = PointCloudOptimizer(view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats, **optim_kw).to(device)
+        net = PointCloudOptimizer(view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats, device, **optim_kw).to(device)
     elif mode == GlobalAlignerMode.ModularPointCloudOptimizer:
-        net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
+        net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, device, **optim_kw).to(device)
     elif mode == GlobalAlignerMode.PairViewer:
-        net = PairViewer(view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats, **optim_kw).to(device)
+        net = PairViewer(view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats, device, **optim_kw).to(device)
     else:
         raise NotImplementedError(f'Unknown mode {mode}')
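The hunk above threads the caller's device through every optimizer constructor instead of leaving placement to hard-coded CUDA calls further down the stack. A minimal sketch of that pattern, assuming a stub class (the real constructors take the full view/prediction dicts shown in the diff, and the diff passes device positionally; a keyword is used here only for clarity):

import torch

class PointCloudOptimizer(torch.nn.Module):
    # stub standing in for the real optimizer; only the device plumbing is shown
    def __init__(self, *views_and_preds, device='cpu', **optim_kw):
        super().__init__()
        # internal buffers can now be created directly on the target device
        self.register_buffer('im_conf', torch.zeros(1, device=device))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = PointCloudOptimizer(device=device).to(device)  # parameters and buffers land on one device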
modules/dust3r/cloud_opt/__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/modules/dust3r/cloud_opt/__pycache__/__init__.cpython-312.pyc and b/modules/dust3r/cloud_opt/__pycache__/__init__.cpython-312.pyc differ
 
modules/dust3r/cloud_opt/__pycache__/base_opt.cpython-312.pyc CHANGED
Binary files a/modules/dust3r/cloud_opt/__pycache__/base_opt.cpython-312.pyc and b/modules/dust3r/cloud_opt/__pycache__/base_opt.cpython-312.pyc differ
 
modules/dust3r/cloud_opt/base_opt.py CHANGED
@@ -44,7 +44,7 @@ class BasePCOptimizer (nn.Module):
         else:
             self._init_from_views(*args, **kwargs)
 
-    def _init_from_views(self, view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats,
+    def _init_from_views(self, view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats, device,
                          dist='l2',
                          conf='log',
                          min_conf_thr=3,
@@ -121,10 +121,10 @@ class BasePCOptimizer (nn.Module):
             self.fix_imgs = rgb(ori_imgs)
             self.smoothed_imgs = rgb(smoothed_imgs)
 
-        self.cogs = [torch.zeros((h, w, 1024)) for h, w in self.imshapes]
-        # semantic_feats = semantic_feats.to("cuda")
-        self.segmaps = [-torch.ones((h, w)) for h, w in self.imshapes]
-        self.rev_segmaps = [-torch.ones((h, w)) for h, w in self.imshapes]
+        self.cogs = [torch.zeros((h, w, 1024), device=device) for h, w in self.imshapes]
+        semantic_feats = semantic_feats.to(device)
+        self.segmaps = [-torch.ones((h, w), device=device) for h, w in self.imshapes]
+        self.rev_segmaps = [-torch.ones((h, w), device=device) for h, w in self.imshapes]
 
         for v in range(len(self.edges)):
             idx = view1['idx'][v]
@@ -141,8 +141,8 @@ class BasePCOptimizer (nn.Module):
             seg = cog_seg_map[y, x].squeeze(-1).long()
 
             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-            self.segmaps[idx] = cog_seg_map#.cuda()
-            self.rev_segmaps[idx] = rev_seg_map#.cuda()
+            self.segmaps[idx] = cog_seg_map.to(device)
+            self.rev_segmaps[idx] = rev_seg_map.to(device)
 
             idx = view2['idx'][v]
             h, w = self.cogs[idx].shape[0], self.cogs[idx].shape[1]
@@ -157,8 +157,8 @@ class BasePCOptimizer (nn.Module):
             seg = cog_seg_map[y, x].squeeze(-1).long()
 
             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-            self.segmaps[idx] = cog_seg_map#.cuda()
-            self.rev_segmaps[idx] = rev_seg_map#.cuda()
+            self.segmaps[idx] = cog_seg_map.to(device)
+            self.rev_segmaps[idx] = rev_seg_map.to(device)
 
         self.rendered_imgs = []
 
@@ -612,609 +612,4 @@ def clean_pointcloud( im_confs, K, cams, depthmaps, all_pts3d,
             bad_msk_i[msk_i] = bad_points
             res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf)
 
-    return res
-
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# Base class for the global alignment procedure
-# --------------------------------------------------------
-# from copy import deepcopy
-
-# import numpy as np
-# import torch
-# import torch.nn as nn
-# import roma
-# from copy import deepcopy
-# import tqdm
-
-# from torch.nn.functional import cosine_similarity
-# import cv2
-
-# from dust3r.utils.geometry import inv, geotrf
-# from dust3r.utils.device import to_numpy
-# from dust3r.utils.image import rgb
-# from dust3r.viz import SceneViz, segment_sky, auto_cam_size
-# from dust3r.optim_factory import adjust_learning_rate_by_lr
-
-# from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
-#                                       cosine_schedule, linear_schedule, get_conf_trf, GradParamDict)
-# import dust3r.cloud_opt.init_im_poses as init_fun
-
-
-# class BasePCOptimizer (nn.Module):
-#     """ Optimize a global scene, given a list of pairwise observations.
-#     Graph node: images
-#     Graph edges: observations = (pred1, pred2)
-#     """
-
-#     def __init__(self, *args, **kwargs):
-#         if len(args) == 1 and len(kwargs) == 0:
-#             other = deepcopy(args[0])
-#             attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
-#                        min_conf_thr conf_thr conf_i conf_j im_conf
-#                        base_scale norm_pw_scale POSE_DIM pw_poses
-#                        pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
-#             self.__dict__.update({k: other[k] for k in attrs})
-#         else:
-#             self._init_from_views(*args, **kwargs)
-
-#     def _init_from_views(self, view1, view2, pred1, pred2, cog_seg_maps, rev_cog_seg_maps, semantic_feats,
-#                          dist='l2',
-#                          conf='log',
-#                          min_conf_thr=3,
-#                          base_scale=0.5,
-#                          allow_pw_adaptors=False,
-#                          pw_break=20,
-#                          rand_pose=torch.randn,
-#                          iterationsCount=None,
-#                          verbose=True):
-#         super().__init__()
-#         if not isinstance(view1['idx'], list):
-#             view1['idx'] = view1['idx'].tolist()
-#         if not isinstance(view2['idx'], list):
-#             view2['idx'] = view2['idx'].tolist()
-#         self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
-#         self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
-#         self.dist = ALL_DISTS[dist]
-#         self.verbose = verbose
-
-#         self.n_imgs = self._check_edges()
-
-#         # input data
-#         pred1_pts = pred1['pts3d']
-#         pred2_pts = pred2['pts3d_in_other_view']
-#         self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
-#         self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
-#         # self.ori_pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
-#         # self.ori_pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
-#         self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
-
-#         # work in log-scale with conf
-#         pred1_conf = pred1['conf']
-#         pred2_conf = pred2['conf']
-#         self.min_conf_thr = min_conf_thr
-#         self.conf_trf = get_conf_trf(conf)
-
-#         self.conf_i = NoGradParamDict({ij: pred1_conf[e] for e, ij in enumerate(self.str_edges)})
-#         self.conf_j = NoGradParamDict({ij: pred2_conf[e] for e, ij in enumerate(self.str_edges)})
-#         self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
-#         for i in range(len(self.im_conf)):
-#             self.im_conf[i].requires_grad = False
-
-#         # pairwise pose parameters
-#         self.base_scale = base_scale
-#         self.norm_pw_scale = True
-#         self.pw_break = pw_break
-#         self.POSE_DIM = 7
-#         self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
-#         self.pw_poses.requires_grad_(True)
-#         self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
-#         self.pw_adaptors.requires_grad_(True)
-#         self.has_im_poses = False
-#         self.rand_pose = rand_pose
-
-#         # possibly store images for show_pointcloud
-#         self.imgs = None
-#         if 'img' in view1 and 'img' in view2:
-#             imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
-#             smoothed_imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
-#             ori_imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
-#             for v in range(len(self.edges)):
-#                 idx = view1['idx'][v]
-#                 imgs[idx] = view1['img'][v]
-#                 smoothed_imgs[idx] = view1['smoothed_img'][v]
-#                 ori_imgs[idx] = view1['ori_img'][v]
-
-#                 idx = view2['idx'][v]
-#                 imgs[idx] = view2['img'][v]
-#                 smoothed_imgs[idx] = view2['smoothed_img'][v]
-#                 ori_imgs[idx] = view2['ori_img'][v]
-
-#             self.imgs = rgb(imgs)
-#             self.ori_imgs = rgb(ori_imgs)
-#             self.fix_imgs = rgb(ori_imgs)
-#             self.smoothed_imgs = rgb(smoothed_imgs)
-
-#         self.cogs = [torch.zeros((h, w, 1024), device="cuda") for h, w in self.imshapes]
-#         semantic_feats = semantic_feats.to("cuda")
-#         self.segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
-#         self.rev_segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
-#         # self.conf_1 = [torch.zeros((h, w), device="cuda") for h, w in self.imshapes]
-#         # self.conf_2 = [torch.zeros((h, w), device="cuda") for h, w in self.imshapes]
-#         for v in range(len(self.edges)):
-#             idx = view1['idx'][v]
-
-#             h, w = self.cogs[idx].shape[0], self.cogs[idx].shape[1]
-#             cog_seg_map = cog_seg_maps[idx]
-#             cog_seg_map = torch.from_numpy(cv2.resize(cog_seg_map, [w, h], interpolation=cv2.INTER_NEAREST))
-#             rev_seg_map = rev_cog_seg_maps[idx]
-#             rev_seg_map = torch.from_numpy(cv2.resize(rev_seg_map, [w, h], interpolation=cv2.INTER_NEAREST))
-
-#             y, x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
-#             x = x.reshape(-1, 1)
-#             y = y.reshape(-1, 1)
-#             seg = cog_seg_map[y, x].squeeze(-1).long()
-
-#             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-#             self.segmaps[idx] = cog_seg_map.cuda()
-#             self.rev_segmaps[idx] = rev_seg_map.cuda()
-
-#             idx = view2['idx'][v]
-#             h, w = self.cogs[idx].shape[0], self.cogs[idx].shape[1]
-#             cog_seg_map = cog_seg_maps[idx]
-#             cog_seg_map = torch.from_numpy(cv2.resize(cog_seg_map, [w, h], interpolation=cv2.INTER_NEAREST))
-#             rev_seg_map = rev_cog_seg_maps[idx]
-#             rev_seg_map = torch.from_numpy(cv2.resize(rev_seg_map, [w, h], interpolation=cv2.INTER_NEAREST))
-
-#             y, x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
-#             x = x.reshape(-1, 1)
-#             y = y.reshape(-1, 1)
-#             seg = cog_seg_map[y, x].squeeze(-1).long()
-
-#             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-#             self.segmaps[idx] = cog_seg_map.cuda()
-#             self.rev_segmaps[idx] = rev_seg_map.cuda()
-
-#         self.rendered_imgs = []
-
-#     def render_image(self, text_feats, threshold=0.85):
-#         self.rendered_imgs = []
-
-#         # Collect all cosine similarities to compute min-max normalization
-#         all_similarities = []
-#         for each_cog in self.cogs:
-#             similarity_map = cosine_similarity(each_cog.to("cpu"), text_feats.to("cpu").unsqueeze(1), dim=-1)
-#             all_similarities.append(similarity_map.squeeze().numpy())
-
-#         # Flatten and normalize all similarities
-#         total_similarities = np.concatenate(all_similarities)
-#         min_sim, max_sim = total_similarities.min(), total_similarities.max()
-#         normalized_similarities = [(sim - min_sim) / (max_sim - min_sim) for sim in all_similarities]
-
-#         # Process each image with normalized similarities
-#         for i, (each_cog, heatmap) in enumerate(zip(self.cogs, normalized_similarities)):
-#             mask = heatmap > threshold
-
-#             # Scale heatmap for visualization
-#             heatmap = np.uint8(255 * heatmap)
-#             heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
-
-#             # Prepare image
-#             image = self.fix_imgs[i]
-#             image = image * 255.0
-#             image = np.clip(image, 0, 255).astype(np.uint8)
-
-#             # Apply mask and overlay heatmap with red RGB for masked areas
-#             mask_indices = np.where(mask)  # Get indices where mask is True
-#             heatmap_color[mask_indices[0], mask_indices[1]] = [0, 0, 255]  # Red color for masked regions
-
-#             superimposed_img = np.where(np.expand_dims(mask, axis=-1), heatmap_color, image) / 255.0
-
-#             self.rendered_imgs.append(superimposed_img)
-
-#     @property
-#     def n_edges(self):
-#         return len(self.edges)
-
-#     @property
-#     def str_edges(self):
-#         return [edge_str(i, j) for i, j in self.edges]
-
-#     @property
-#     def imsizes(self):
-#         return [(w, h) for h, w in self.imshapes]
-
-#     @property
-#     def device(self):
-#         return next(iter(self.parameters())).device
-
-#     def state_dict(self, trainable=True):
-#         all_params = super().state_dict()
-#         return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
-
-#     def load_state_dict(self, data):
-#         return super().load_state_dict(self.state_dict(trainable=False) | data)
-
-#     def _check_edges(self):
-#         indices = sorted({i for edge in self.edges for i in edge})
-#         assert indices == list(range(len(indices))), 'bad pair indices: missing values'
-#         return len(indices)
-
-#     @torch.no_grad()
-#     def _compute_img_conf(self, pred1_conf, pred2_conf):
-#         im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
-#         for e, (i, j) in enumerate(self.edges):
-#             im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
-#             im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
-#         return im_conf
-
-#     def get_adaptors(self):
-#         adapt = self.pw_adaptors
-#         adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
-#         if self.norm_pw_scale:  # normalize so that the product == 1
-#             adapt = adapt - adapt.mean(dim=1, keepdim=True)
-#         return (adapt / self.pw_break).exp()
-
-#     def _get_poses(self, poses):
-#         # normalize rotation
-#         Q = poses[:, :4]
-#         T = signed_expm1(poses[:, 4:7])
-#         RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
-#         return RT
-
-#     def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
-#         # all poses == cam-to-world
-#         pose = poses[idx]
-#         if not (pose.requires_grad or force):
-#             return pose
-
-#         if R.shape == (4, 4):
-#             assert T is None
-#             T = R[:3, 3]
-#             R = R[:3, :3]
-
-#         if R is not None:
-#             pose.data[0:4] = roma.rotmat_to_unitquat(R)
-#         if T is not None:
-#             pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale
-
-#         if scale is not None:
-#             assert poses.shape[-1] in (8, 13)
-#             pose.data[-1] = np.log(float(scale))
-#         return pose
-
-#     def get_pw_norm_scale_factor(self):
-#         if self.norm_pw_scale:
-#             # normalize scales so that things cannot go south
-#             # we want that exp(scale) ~= self.base_scale
-#             return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
-#         else:
-#             return 1  # don't norm scale for known poses
-
-#     def get_pw_scale(self):
-#         scale = self.pw_poses[:, -1].exp()  # (n_edges,)
-#         scale = scale * self.get_pw_norm_scale_factor()
-#         return scale
-
-#     def get_pw_poses(self):  # cam to world
-#         RT = self._get_poses(self.pw_poses)
-#         scaled_RT = RT.clone()
-#         scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
-#         return scaled_RT
-
-#     def get_masks(self):
-#         return [(conf > self.min_conf_thr) for conf in self.im_conf]
-
-#     def depth_to_pts3d(self):
-#         raise NotImplementedError()
-
-#     def get_pts3d(self, raw=False):
-#         res = self.depth_to_pts3d()
-#         if not raw:
-#             res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
-#         return res
-
-#     def _set_focal(self, idx, focal, force=False):
-#         raise NotImplementedError()
-
-#     def get_focals(self):
-#         raise NotImplementedError()
-
-#     def get_known_focal_mask(self):
-#         raise NotImplementedError()
-
-#     def get_principal_points(self):
-#         raise NotImplementedError()
-
-#     def get_conf(self, mode=None):
-#         trf = self.conf_trf if mode is None else get_conf_trf(mode)
-#         return [trf(c) for c in self.im_conf]
-
-#     def get_im_poses(self):
-#         raise NotImplementedError()
-
-#     def _set_depthmap(self, idx, depth, force=False):
-#         raise NotImplementedError()
-
-#     def get_depthmaps(self, raw=False):
-#         raise NotImplementedError()
-
-#     def clean_pointcloud(self, **kw):
-#         cams = inv(self.get_im_poses())
-#         K = self.get_intrinsics()
-#         depthmaps = self.get_depthmaps()
-#         all_pts3d = self.get_pts3d()
-
-#         new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw)
-
-#         for i, new_conf in enumerate(new_im_confs):
-#             self.im_conf[i].data[:] = new_conf
-#         return self
-
-#     def forward(self, ret_details=False):
-#         pw_poses = self.get_pw_poses()  # cam-to-world
-#         pw_adapt = self.get_adaptors()
-#         proj_pts3d = self.get_pts3d()
-#         # pre-compute pixel weights
-#         weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
-#         weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
-
-#         loss = 0
-#         if ret_details:
-#             details = -torch.ones((self.n_imgs, self.n_imgs))
-
-#         for e, (i, j) in enumerate(self.edges):
-#             i_j = edge_str(i, j)
-#             # distance in image i and j
-#             aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
-#             aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
-#             li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
-#             lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
-#             loss = loss + li + lj
-
-#             if ret_details:
-#                 details[i, j] = li + lj
-#         loss /= self.n_edges  # average over all pairs
-
-#         if ret_details:
-#             return loss, details
-#         return loss
-
-
-#     def spatial_select_points(self, point_maps, semantic_maps, confidence_maps):
-#         H, W = semantic_maps.shape
-
-#         # reshape the point map and the semantic map to 2-D form
-#         point_map = point_maps.view(-1, 3)  # (H*W, 3)
-#         semantic_map = semantic_maps.view(-1)  # (H*W)
-#         confidence_map = confidence_maps.view(-1)
-
-#         dist_map = torch.zeros_like(semantic_map, dtype=torch.float32)
-#         cnt_map = torch.zeros_like(semantic_map, dtype=torch.float32)
-#         # near_point_map = torch.zeros_like(point_map, dtype=torch.float32)
-
-#         # refresh_point_map = point_map.clone()
-#         refresh_confidence_map = confidence_map.clone()
-
-#         # build pixel indices for the image
-#         row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W))
-#         row_idx = row_idx.flatten()
-#         col_idx = col_idx.flatten()
-
-#         kernel_size = 7
-#         offset_range = kernel_size // 2
-#         neighbor_offsets = [
-#             (dx, dy) for dx in range(-offset_range, offset_range + 1)
-#             for dy in range(-offset_range, offset_range + 1)
-#             if not (dx == 0 and dy == 0)
-#         ]
-
-#         # process every pixel (neighborhoods are restricted to the current image)
-#         for offset in neighbor_offsets:
-#             # neighbor positions
-#             neighbor_row = row_idx + offset[0]
-#             neighbor_col = col_idx + offset[1]
-
-#             # make sure the neighbors fall inside the image
-#             valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
-#             valid_row = neighbor_row[valid_mask]
-#             valid_col = neighbor_col[valid_mask]
-
-#             # indices of the valid pixels
-#             idx = valid_mask.nonzero(as_tuple=True)[0]
-#             neighbor_idx = valid_row * W + valid_col
-
-#             # semantic labels and spatial coordinates of the neighboring pixels
-#             sem_i = semantic_map[idx]
-#             sem_j = semantic_map[neighbor_idx]
-#             p_i = point_map[idx]
-#             p_j = point_map[neighbor_idx]
-
-#             # squared difference of the spatial coordinates
-#             distance = torch.sum((p_i - p_j)**2, dim=1)
-
-#             same_object = (sem_i == sem_j) & (sem_i != -1) & (sem_j != -1)
-
-#             dist_map[idx] += same_object * distance
-#             cnt_map[idx] += same_object
-
-#         anomaly_point = (dist_map / (cnt_map + 1e-6))
-#         print(anomaly_point, anomaly_point.shape)
-#         anomaly_point = (anomaly_point > 0.001) & (cnt_map != 0)
-#         anomaly_point_idx = anomaly_point.nonzero(as_tuple=True)[0]
-
-#         refresh_confidence_map[anomaly_point_idx] = 0
-
-#         return refresh_confidence_map.view(H, W)
-
-#     @torch.cuda.amp.autocast(enabled=False)
-#     def compute_global_alignment(self, tune_flg=False, init=None, niter_PnP=10, **kw):
-
-#         if tune_flg:
-#             im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
-#             for e, (i, j) in enumerate(self.edges):
-#                 i_j = edge_str(i, j)
-#                 im_conf[i] = self.spatial_select_points(self.pred_i[i_j], self.rev_segmaps[i], self.conf_i[i_j])
-#                 im_conf[j] = self.spatial_select_points(self.pred_j[i_j], self.rev_segmaps[j], self.conf_j[i_j])
-
-#             for i in range(len(self.imgs)):
-#                 self.imgs[i] = self.ori_imgs[i]
-#                 anomaly_mask = (im_conf[i] == 0)
-#                 unique_labels = torch.unique(self.rev_segmaps[i])
-#                 for label in unique_labels:
-#                     semantic_mask = (self.rev_segmaps[i] == label)
-#                     if label == -1:
-#                         continue
-#                     cover = (semantic_mask & anomaly_mask).sum() / semantic_mask.sum()
-#                     if cover > 0.3:
-#                         self.imgs[i][semantic_mask.cpu()] = self.smoothed_imgs[i][semantic_mask.cpu()]
-#                         for j in range(len(self.imgs)):
-#                             if j == i:
-#                                 continue
-#                             semantic_mask = (self.rev_segmaps[j] == label)
-#                             self.imgs[j][semantic_mask.cpu()] = self.smoothed_imgs[j][semantic_mask.cpu()]
-
-#         if init is None:
-#             pass
-#         elif init == 'msp' or init == 'mst':
-#             init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
-#         elif init == 'known_poses':
-#             init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
-#                                            niter_PnP=niter_PnP)
-#         else:
-#             raise ValueError(f'bad value for {init=}')
-
-#         if tune_flg:
-#             return 0
-#         # loss = 0
-#         loss = global_alignment_loop(self, **kw)
-#         #
-#         # init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
-#         return loss
-
-#     @torch.no_grad()
-#     def mask_sky(self):
-#         res = deepcopy(self)
-#         for i in range(self.n_imgs):
-#             sky = segment_sky(self.imgs[i])
-#             res.im_conf[i][sky] = 0
-#         return res
-
-#     def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
-#         viz = SceneViz()
-#         if self.imgs is None:
-#             colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
-#             colors = list(map(tuple, colors.tolist()))
-#             for n in range(self.n_imgs):
-#                 viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
-#         else:
-#             viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
-#             colors = np.random.randint(256, size=(self.n_imgs, 3))
-
-#         # camera poses
-#         im_poses = to_numpy(self.get_im_poses())
-#         if cam_size is None:
-#             cam_size = auto_cam_size(im_poses)
-#         viz.add_cameras(im_poses, self.get_focals(), colors=colors,
-#                         images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
-#         if show_pw_cams:
-#             pw_poses = self.get_pw_poses()
-#             viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
-
-#             if show_pw_pts3d:
-#                 pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
-#                 viz.add_pointcloud(pts, (128, 0, 128))
-
-#         viz.show(**kw)
-#         return viz
-
-
-# def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
-#     # return net
-#     params = [p for p in net.parameters() if p.requires_grad]
-#     for param in params:
-#         print(param.shape)
-#     if not params:
-#         return net
-
-#     verbose = net.verbose
-#     if verbose:
-#         print('Global alignment - optimizing for:')
-#         print([name for name, value in net.named_parameters() if value.requires_grad])
-
-#     lr_base = lr
-#     optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))
-
-#     loss = float('inf')
-#     if verbose:
-#         with tqdm.tqdm(total=niter) as bar:
-#             while bar.n < bar.total:
-#                 loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)
-#                 bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
-#                 bar.update()
-#     else:
-#         for n in range(niter):
-#             loss, _ = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule)
-#     return loss
-
-
-# def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
-#     t = cur_iter / niter
-#     if schedule == 'cosine':
-#         lr = cosine_schedule(t, lr_base, lr_min)
-#     elif schedule == 'linear':
-#         lr = linear_schedule(t, lr_base, lr_min)
-#     else:
-#         raise ValueError(f'bad lr {schedule=}')
-#     adjust_learning_rate_by_lr(optimizer, lr)
-#     optimizer.zero_grad()
-#     loss = net(cur_iter)
-#     if loss == 0:
-#         optimizer.step()
-#         return float(loss), lr
-
-#     loss.backward()
-#     optimizer.step()
-
-#     return float(loss), lr
-
-
-# @torch.no_grad()
-# def clean_pointcloud( im_confs, K, cams, depthmaps, all_pts3d,
-#                       tol=0.001, bad_conf=0, dbg=()):
-#     """ Method:
-#     1) express all 3d points in each camera coordinate frame
-#     2) if they're in front of a depthmap --> then lower their confidence
-#     """
-#     assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d)
-#     assert 0 <= tol < 1
-#     res = [c.clone() for c in im_confs]
-
-#     # reshape appropriately
-#     all_pts3d = [p.view(*c.shape,3) for p,c in zip(all_pts3d, im_confs)]
-#     depthmaps = [d.view(*c.shape) for d,c in zip(depthmaps, im_confs)]
-
-#     for i, pts3d in enumerate(all_pts3d):
-#         for j in range(len(all_pts3d)):
-#             if i == j: continue
-
-#             # project 3dpts in other view
-#             proj = geotrf(cams[j], pts3d)
-#             proj_depth = proj[:,:,2]
-#             u,v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
-
-#             # check which points are actually in the visible cone
-#             H, W = im_confs[j].shape
-#             msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)
-#             msk_j = v[msk_i], u[msk_i]
-
-#             # find bad points = those in front but less confident
-#             bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j])
-
-#             bad_msk_i = msk_i.clone()
-#             bad_msk_i[msk_i] = bad_points
-#             res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf)
-
-#     return res
+    return res
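The surviving hunks in this file swap hard-coded "cuda" placement for the new device argument when _init_from_views builds the per-pixel semantic feature maps. A condensed, hypothetical sketch of that lookup (seg_map_np is a float32/uint8 segment-id map, semantic_feats a (num_segments, feat_dim) table; the function name and signature are illustrative, not the repo's API):

import cv2
import torch

def lift_seg_to_features(seg_map_np, semantic_feats, h, w, device):
    # nearest-neighbor resize keeps the integer segment ids intact
    seg = cv2.resize(seg_map_np, (w, h), interpolation=cv2.INTER_NEAREST)
    seg = torch.from_numpy(seg).long().to(device)   # (h, w) per-pixel ids
    # index the feature table with the ids, with everything on one device
    return semantic_feats.to(device)[seg]           # (h, w, feat_dim)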
 
modules/pe3r/__pycache__/demo.cpython-312.pyc CHANGED
Binary files a/modules/pe3r/__pycache__/demo.cpython-312.pyc and b/modules/pe3r/__pycache__/demo.cpython-312.pyc differ
 
modules/pe3r/__pycache__/models.cpython-312.pyc CHANGED
Binary files a/modules/pe3r/__pycache__/models.cpython-312.pyc and b/modules/pe3r/__pycache__/models.cpython-312.pyc differ
 
modules/pe3r/demo.py CHANGED
@@ -257,8 +257,6 @@ def get_mask_from_img_sam1(mobilesamv2, yolov8, sam1_image, yolov8_image, origin
     input_image = mobilesamv2.preprocess(sam1_image)
    image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
 
-    print(image_embedding.shape)
-
     image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
     prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
     prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
@@ -324,7 +322,7 @@ def get_cog_feats(images, pe3r, device):
         if out_frame_idx == 0:
             continue
 
-        sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
+        sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform, device)
 
         for sam1_mask in sam1_masks:
             flg = 1
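For context on the unchanged lines above: MobileSAMv2 decodes many mask prompts against a single image embedding, so the embedding is tiled along the batch dimension with torch.repeat_interleave to match the prompt count (320 in this diff). A self-contained illustration with made-up shapes:

import torch

image_embedding = torch.randn(1, 256, 64, 64)                   # one encoded image
batched = torch.repeat_interleave(image_embedding, 320, dim=0)  # tile along dim 0
print(batched.shape)                                            # torch.Size([320, 256, 64, 64])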
modules/ultralytics/yolo/utils/callbacks/__pycache__/clearml.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/clearml.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/clearml.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/comet.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/comet.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/comet.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/dvc.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/dvc.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/dvc.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/hub.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/hub.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/hub.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/mlflow.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/mlflow.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/mlflow.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/neptune.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/neptune.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/neptune.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/raytune.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/raytune.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/raytune.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/tensorboard.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/tensorboard.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/tensorboard.cpython-312.pyc differ
 
modules/ultralytics/yolo/utils/callbacks/__pycache__/wb.cpython-312.pyc CHANGED
Binary files a/modules/ultralytics/yolo/utils/callbacks/__pycache__/wb.cpython-312.pyc and b/modules/ultralytics/yolo/utils/callbacks/__pycache__/wb.cpython-312.pyc differ