hujiecpp committed
Commit 4fea465 · 1 Parent(s): e35a902

init project

app.py CHANGED
@@ -1,56 +1,655 @@
1
  import os
2
- import tempfile
3
-
4
  import sys
5
  sys.path.append(os.path.abspath('./modules'))
6
 
7
- # import builtins
8
- # import datetime
9
- import argparse
10
-
11
- from modules.pe3r.demo import main_demo
12
- # import torch
13
-
14
- # def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
15
- # builtin_print = builtins.print
16
- # def print_with_timestamp(*args, **kwargs):
17
- # now = datetime.datetime.now()
18
- # formatted_date_time = now.strftime(time_format)
19
- # builtin_print(f'[{formatted_date_time}] ', end='') # print with time stamp
20
- # builtin_print(*args, **kwargs)
21
- # builtins.print = print_with_timestamp
22
-
23
- def get_args_parser():
24
- parser = argparse.ArgumentParser()
25
- parser_url = parser.add_mutually_exclusive_group()
26
- parser_url.add_argument("--local_network", action='store_true', default=False,
27
- help="make app accessible on local network: address will be set to 0.0.0.0")
28
- parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
29
- parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
30
- "If None, will search for an available port starting at 7860."), default=None)
31
- # parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
32
- parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
33
- parser.add_argument("--silent", action='store_true', default=False, help="silence logs")
34
- # change defaults
35
- parser.prog = 'pe3r demo'
36
- return parser
37
-
38
- if __name__ == '__main__':
39
- parser = get_args_parser()
40
- args = parser.parse_args()
41
- # set_print_with_timestamp()
42
-
43
- if args.tmp_dir is not None:
44
- tmp_path = args.tmp_dir
45
- os.makedirs(tmp_path, exist_ok=True)
46
- tempfile.tempdir = tmp_path
47
 
48
- if args.server_name is not None:
49
- server_name = args.server_name
50
  else:
51
- server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
52
 
53
- with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
54
- if not args.silent:
55
- print('Outputing stuff in', tmpdirname)
56
- main_demo(tmpdirname, server_name, args.server_port, silent=args.silent)
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # gradio demo
6
+ # --------------------------------------------------------
7
  import os
8
  import sys
9
  sys.path.append(os.path.abspath('./modules'))
10
 
11
+ import math
12
+ import tempfile
13
+ import gradio
14
+ import os
15
+ import torch
16
+ import spaces
17
+ import numpy as np
18
+ import functools
19
+ import trimesh
20
+ import copy
21
+ from PIL import Image
22
+ from scipy.spatial.transform import Rotation
23
+
24
+ from modules.pe3r.images import Images
25
+
26
+ from modules.dust3r.inference import inference
27
+ from modules.dust3r.image_pairs import make_pairs
28
+ from modules.dust3r.utils.image import load_images, rgb
29
+ from modules.dust3r.utils.device import to_numpy
30
+ from modules.dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
31
+ from modules.dust3r.cloud_opt import global_aligner, GlobalAlignerMode
32
+ from copy import deepcopy
33
+ import cv2
34
+ from typing import Any, Dict, Generator,List
35
+ import matplotlib.pyplot as pl
36
+
37
+ from modules.mobilesamv2.utils.transforms import ResizeLongestSide
38
+ from modules.pe3r.models import Models
39
+
40
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
41
+ cam_color=None, as_pointcloud=False,
42
+ transparent_cams=False, silent=False):
43
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
44
+ pts3d = to_numpy(pts3d)
45
+ imgs = to_numpy(imgs)
46
+ focals = to_numpy(focals)
47
+ cams2world = to_numpy(cams2world)
48
+
49
+ scene = trimesh.Scene()
50
+
51
+ # full pointcloud
52
+ if as_pointcloud:
53
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
54
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
55
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
56
+ scene.add_geometry(pct)
57
+ else:
58
+ meshes = []
59
+ for i in range(len(imgs)):
60
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
61
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
62
+ scene.add_geometry(mesh)
63
+
64
+ # add each camera
65
+ for i, pose_c2w in enumerate(cams2world):
66
+ if isinstance(cam_color, list):
67
+ camera_edge_color = cam_color[i]
68
+ else:
69
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
70
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
71
+ None if transparent_cams else imgs[i], focals[i],
72
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
73
+
74
+ rot = np.eye(4)
75
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
76
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
77
+ outfile = os.path.join(outdir, 'scene.glb')
78
+ if not silent:
79
+ print('(exporting 3D scene to', outfile, ')')
80
+ scene.export(file_obj=outfile)
81
+ return outfile
82
+
83
+ @spaces.GPU(duration=180)
84
+ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
85
+ clean_depth=False, transparent_cams=False, cam_size=0.05):
86
+ """
87
+ extract 3D_model (glb file) from a reconstructed scene
88
+ """
89
+ if scene is None:
90
+ return None
91
+ # post processes
92
+ if clean_depth:
93
+ scene = scene.clean_pointcloud()
94
+ if mask_sky:
95
+ scene = scene.mask_sky()
96
+
97
+ # get optimized values from scene
98
+ rgbimg = scene.ori_imgs
99
+ focals = scene.get_focals().cpu()
100
+ cams2world = scene.get_im_poses().cpu()
101
+ # 3D pointcloud from depthmap, poses and intrinsics
102
+ pts3d = to_numpy(scene.get_pts3d())
103
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
104
+ msk = to_numpy(scene.get_masks())
105
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
106
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
107
+
108
+ def mask_nms(masks, threshold=0.8):
109
+ keep = []
110
+ mask_num = len(masks)
111
+ suppressed = np.zeros((mask_num), dtype=np.int64)
112
+ for i in range(mask_num):
113
+ if suppressed[i] == 1:
114
+ continue
115
+ keep.append(i)
116
+ for j in range(i + 1, mask_num):
117
+ if suppressed[j] == 1:
118
+ continue
119
+ intersection = (masks[i] & masks[j]).sum()
120
+ if min(intersection / masks[i].sum(), intersection / masks[j].sum()) > threshold:
121
+ suppressed[j] = 1
122
+ return keep
123
+
124
+ def filter(masks, keep):
125
+ ret = []
126
+ for i, m in enumerate(masks):
127
+ if i in keep: ret.append(m)
128
+ return ret
129
+
130
+ def mask_to_box(mask):
131
+ if mask.sum() == 0:
132
+ return np.array([0, 0, 0, 0])
133
+
134
+ # Get the rows and columns where the mask is 1
135
+ rows = np.any(mask, axis=1)
136
+ cols = np.any(mask, axis=0)
137
+
138
+ # Get top, bottom, left, right edges
139
+ top = np.argmax(rows)
140
+ bottom = len(rows) - 1 - np.argmax(np.flip(rows))
141
+ left = np.argmax(cols)
142
+ right = len(cols) - 1 - np.argmax(np.flip(cols))
143
+
144
+ return np.array([left, top, right, bottom])
145
+
146
+ def box_xyxy_to_xywh(box_xyxy):
147
+ box_xywh = deepcopy(box_xyxy)
148
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
149
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
150
+ return box_xywh
151
+
152
+ def get_seg_img(mask, box, image):
153
+ image = image.copy()
154
+ x, y, w, h = box
155
+ # image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
156
+ box_area = w * h
157
+ mask_area = mask.sum()
158
+ if 1 - (mask_area / box_area) < 0.2:
159
+ image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
160
+ else:
161
+ random_values = np.random.randint(0, 255, size=image.shape, dtype=np.uint8)
162
+ image[mask == 0] = random_values[mask == 0]
163
+ seg_img = image[y:y+h, x:x+w, ...]
164
+ return seg_img
165
+
166
+ def pad_img(img):
167
+ h, w, _ = img.shape
168
+ l = max(w,h)
169
+ pad = np.zeros((l,l,3), dtype=np.uint8) #
170
+ if h > w:
171
+ pad[:,(h-w)//2:(h-w)//2 + w, :] = img
172
+ else:
173
+ pad[(w-h)//2:(w-h)//2 + h, :, :] = img
174
+ return pad
175
+
176
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
177
+ assert len(args) > 0 and all(
178
+ len(a) == len(args[0]) for a in args
179
+ ), "Batched iteration must have inputs of all the same size."
180
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
181
+ for b in range(n_batches):
182
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
183
+
184
+ def slerp(u1, u2, t):
185
+ """
186
+ Perform spherical linear interpolation (Slerp) between two unit vectors.
187
+
188
+ Args:
189
+ - u1 (torch.Tensor): First unit vector, shape (1024,)
190
+ - u2 (torch.Tensor): Second unit vector, shape (1024,)
191
+ - t (float): Interpolation parameter
192
+
193
+ Returns:
194
+ - torch.Tensor: Interpolated vector, shape (1024,)
195
+ """
196
+ # Compute the dot product
197
+ dot_product = torch.sum(u1 * u2)
198
+
199
+ # Ensure the dot product is within the valid range [-1, 1]
200
+ dot_product = torch.clamp(dot_product, -1.0, 1.0)
201
+
202
+ # Compute the angle between the vectors
203
+ theta = torch.acos(dot_product)
204
+
205
+ # Compute the coefficients for the interpolation
206
+ sin_theta = torch.sin(theta)
207
+ if sin_theta == 0:
208
+ # Vectors are parallel, return a linear interpolation
209
+ return u1 + t * (u2 - u1)
210
+
211
+ s1 = torch.sin((1 - t) * theta) / sin_theta
212
+ s2 = torch.sin(t * theta) / sin_theta
213
+
214
+ # Perform the interpolation
215
+ return s1 * u1 + s2 * u2
216
+
217
+ def slerp_multiple(vectors, t_values):
218
+ """
219
+ Perform spherical linear interpolation (Slerp) for multiple vectors.
220
+
221
+ Args:
222
+ - vectors (torch.Tensor): Tensor of vectors, shape (n, 1024)
223
+ - a_values (torch.Tensor): Tensor of values corresponding to each vector, shape (n,)
224
+
225
+ Returns:
226
+ - torch.Tensor: Interpolated vector, shape (1024,)
227
+ """
228
+ n = vectors.shape[0]
229
+
230
+ # Initialize the interpolated vector with the first vector
231
+ interpolated_vector = vectors[0]
232
+
233
+ # Perform Slerp iteratively
234
+ for i in range(1, n):
235
+ # Perform Slerp between the current interpolated vector and the next vector
236
+ t = t_values[i] / (t_values[i] + t_values[i-1])
237
+ interpolated_vector = slerp(interpolated_vector, vectors[i], t)
238
+
239
+ return interpolated_vector
240
+
241
+ @torch.no_grad
242
+ def get_mask_from_img_sam1(mobilesamv2, yolov8, sam1_image, yolov8_image, original_size, input_size, transform, device):
243
+ sam_mask=[]
244
+ img_area = original_size[0] * original_size[1]
245
+
246
+ obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
247
+ input_boxes1 = obj_results[0].boxes.xyxy
248
+ input_boxes1 = input_boxes1.cpu().numpy()
249
+ input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
250
+ input_boxes = torch.from_numpy(input_boxes1).to(device)
251
+
252
+ # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
253
+ # input_boxes2 = obj_results[0].boxes.xyxy
254
+ # input_boxes2 = input_boxes2.cpu().numpy()
255
+ # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
256
+ # input_boxes2 = torch.from_numpy(input_boxes2).to(device)
257
+
258
+ # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
259
+
260
+ input_image = mobilesamv2.preprocess(sam1_image)
261
+ image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
262
+
263
+ image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
264
+ prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
265
+ prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
266
+ for (boxes,) in batch_iterator(320, input_boxes):
267
+ with torch.no_grad():
268
+ image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
269
+ prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
270
+ sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
271
+ points=None,
272
+ boxes=boxes,
273
+ masks=None,)
274
+ low_res_masks, _ = mobilesamv2.mask_decoder(
275
+ image_embeddings=image_embedding,
276
+ image_pe=prompt_embedding,
277
+ sparse_prompt_embeddings=sparse_embeddings,
278
+ dense_prompt_embeddings=dense_embeddings,
279
+ multimask_output=False,
280
+ simple_type=True,
281
+ )
282
+ low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
283
+ sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
284
+ for mask in sam_mask_pre:
285
+ if mask.sum() / img_area > 0.002:
286
+ sam_mask.append(mask.squeeze(1))
287
+ sam_mask=torch.cat(sam_mask)
288
+ sorted_sam_mask = sorted(sam_mask, key=(lambda x: x.sum()), reverse=True)
289
+ keep = mask_nms(sorted_sam_mask)
290
+ ret_mask = filter(sorted_sam_mask, keep)
291
+
292
+ return ret_mask
293
+
294
+ @torch.no_grad
295
+ def get_cog_feats(images, pe3r, device):
296
+ cog_seg_maps = []
297
+ rev_cog_seg_maps = []
298
+ inference_state = pe3r.sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
299
+ mask_num = 0
300
+
301
+ sam1_images = images.sam1_images
302
+ sam1_images_size = images.sam1_images_size
303
+ np_images = images.np_images
304
+ np_images_size = images.np_images_size
305
+
306
+ sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform, device)
307
+ for mask in sam1_masks:
308
+ _, _, _ = pe3r.sam2.add_new_mask(
309
+ inference_state=inference_state,
310
+ frame_idx=0,
311
+ obj_id=mask_num,
312
+ mask=mask,
313
+ )
314
+ mask_num += 1
315
+
316
+ video_segments = {} # video_segments contains the per-frame segmentation results
317
+ for out_frame_idx, out_obj_ids, out_mask_logits in pe3r.sam2.propagate_in_video(inference_state):
318
+ sam2_masks = (out_mask_logits > 0.0).squeeze(1)
319
+
320
+ video_segments[out_frame_idx] = {
321
+ out_obj_id: sam2_masks[i].cpu().numpy()
322
+ for i, out_obj_id in enumerate(out_obj_ids)
323
+ }
324
+
325
+ if out_frame_idx == 0:
326
+ continue
327
+
328
+ sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform, device)
329
+
330
+ for sam1_mask in sam1_masks:
331
+ flg = 1
332
+ for sam2_mask in sam2_masks:
333
+ # print(sam1_mask.shape, sam2_mask.shape)
334
+ area1 = sam1_mask.sum()
335
+ area2 = sam2_mask.sum()
336
+ intersection = (sam1_mask & sam2_mask).sum()
337
+ if min(intersection / area1, intersection / area2) > 0.25:
338
+ flg = 0
339
+ break
340
+ if flg:
341
+ video_segments[out_frame_idx][mask_num] = sam1_mask.cpu().numpy()
342
+ mask_num += 1
343
+
344
+ multi_view_clip_feats = torch.zeros((mask_num+1, 1024))
345
+ multi_view_clip_feats_map = {}
346
+ multi_view_clip_area_map = {}
347
+ for now_frame in range(0, len(video_segments), 1):
348
+ image = np_images[now_frame]
349
+
350
+ seg_img_list = []
351
+ out_obj_id_list = []
352
+ out_obj_mask_list = []
353
+ out_obj_area_list = []
354
+ # NOTE: background: -1
355
+ rev_seg_map = -np.ones(image.shape[:2], dtype=np.int64)
356
+ sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=False)
357
+ for out_obj_id, mask in sorted_dict_items:
358
+ if mask.sum() == 0:
359
+ continue
360
+ rev_seg_map[mask] = out_obj_id
361
+ rev_cog_seg_maps.append(rev_seg_map)
362
+
363
+ seg_map = -np.ones(image.shape[:2], dtype=np.int64)
364
+ sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=True)
365
+ for out_obj_id, mask in sorted_dict_items:
366
+ if mask.sum() == 0:
367
+ continue
368
+ box = np.int32(box_xyxy_to_xywh(mask_to_box(mask)))
369
+
370
+ if box[2] == 0 and box[3] == 0:
371
+ continue
372
+ # print(box)
373
+ seg_img = get_seg_img(mask, box, image)
374
+ pad_seg_img = cv2.resize(pad_img(seg_img), (256,256))
375
+ seg_img_list.append(pad_seg_img)
376
+ seg_map[mask] = out_obj_id
377
+ out_obj_id_list.append(out_obj_id)
378
+ out_obj_area_list.append(np.count_nonzero(mask))
379
+ out_obj_mask_list.append(mask)
380
+
381
+ if len(seg_img_list) == 0:
382
+ cog_seg_maps.append(seg_map)
383
+ continue
384
+
385
+ seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
386
+ seg_imgs = torch.from_numpy(seg_imgs).permute(0,3,1,2) # / 255.0
387
+
388
+ inputs = pe3r.siglip_processor(images=seg_imgs, return_tensors="pt")
389
+ inputs = {key: value.to(device) for key, value in inputs.items()}
390
 
391
+ image_features = pe3r.siglip.get_image_features(**inputs)
392
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
393
+ image_features = image_features.detach().cpu()
394
+
395
+ for i in range(len(out_obj_mask_list)):
396
+ for j in range(i + 1, len(out_obj_mask_list)):
397
+ mask1 = out_obj_mask_list[i]
398
+ mask2 = out_obj_mask_list[j]
399
+ intersection = np.logical_and(mask1, mask2).sum()
400
+ area1 = out_obj_area_list[i]
401
+ area2 = out_obj_area_list[j]
402
+ if min(intersection / area1, intersection / area2) > 0.025:
403
+ conf1 = area1 / (area1 + area2)
404
+ # conf2 = area2 / (area1 + area2)
405
+ image_features[j] = slerp(image_features[j], image_features[i], conf1)
406
+
407
+ for i, clip_feat in enumerate(image_features):
408
+ id = out_obj_id_list[i]
409
+ if id in multi_view_clip_feats_map.keys():
410
+ multi_view_clip_feats_map[id].append(clip_feat)
411
+ multi_view_clip_area_map[id].append(out_obj_area_list[i])
412
+ else:
413
+ multi_view_clip_feats_map[id] = [clip_feat]
414
+ multi_view_clip_area_map[id] = [out_obj_area_list[i]]
415
+
416
+ cog_seg_maps.append(seg_map)
417
+ del image_features
418
+
419
+ for i in range(mask_num):
420
+ if i in multi_view_clip_feats_map.keys():
421
+ clip_feats = multi_view_clip_feats_map[i]
422
+ mask_area = multi_view_clip_area_map[i]
423
+ multi_view_clip_feats[i] = slerp_multiple(torch.stack(clip_feats), np.stack(mask_area))
424
+ else:
425
+ multi_view_clip_feats[i] = torch.zeros((1024))
426
+ multi_view_clip_feats[mask_num] = torch.zeros((1024))
427
+
428
+ return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
429
+
430
+ @spaces.GPU(duration=180)
431
+ def get_reconstructed_scene(outdir, pe3r, device, silent, filelist, schedule, niter, min_conf_thr,
432
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
433
+ scenegraph_type, winsize, refid):
434
+ """
435
+ from a list of images, run dust3r inference, global aligner.
436
+ then run get_3D_model_from_scene
437
+ """
438
+ if len(filelist) < 2:
439
+ raise gradio.Error("Please input at least 2 images.")
440
+
441
+ images = Images(filelist=filelist, device=device)
442
+
443
+ # try:
444
+ cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, pe3r, device)
445
+ imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
446
+ # except Exception as e:
447
+ # rev_cog_seg_maps = []
448
+ # for tmp_img in images.np_images:
449
+ # rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
450
+ # rev_cog_seg_maps.append(rev_seg_map)
451
+ # cog_seg_maps = rev_cog_seg_maps
452
+ # cog_feats = torch.zeros((1, 1024))
453
+ # imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
454
+
455
+ if len(imgs) == 1:
456
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
457
+ imgs[1]['idx'] = 1
458
+
459
+ if scenegraph_type == "swin":
460
+ scenegraph_type = scenegraph_type + "-" + str(winsize)
461
+ elif scenegraph_type == "oneref":
462
+ scenegraph_type = scenegraph_type + "-" + str(refid)
463
+
464
+ pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
465
+ output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)
466
+ mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
467
+ scene_1 = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)
468
+ lr = 0.01
469
+ # if mode == GlobalAlignerMode.PointCloudOptimizer:
470
+ loss = scene_1.compute_global_alignment(tune_flg=True, init='mst', niter=niter, schedule=schedule, lr=lr)
471
+
472
+ try:
473
+ import torchvision.transforms as tvf
474
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
475
+ for i in range(len(imgs)):
476
+ # print(imgs[i]['img'].shape, scene.imgs[i].shape, ImgNorm(scene.imgs[i])[None])
477
+ imgs[i]['img'] = ImgNorm(scene_1.imgs[i])[None]
478
+ pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
479
+ output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)
480
+ mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
481
+ scene = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)
482
+ ori_imgs = scene.ori_imgs
483
+ lr = 0.01
484
+ # if mode == GlobalAlignerMode.PointCloudOptimizer:
485
+ loss = scene.compute_global_alignment(tune_flg=False, init='mst', niter=niter, schedule=schedule, lr=lr)
486
+ except Exception as e:
487
+ scene = scene_1
488
+ scene.imgs = ori_imgs
489
+ scene.ori_imgs = ori_imgs
490
+ print(e)
491
+
492
+
493
+ outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
494
+ clean_depth, transparent_cams, cam_size)
495
+
496
+ # also return rgb, depth and confidence imgs
497
+ # depth is normalized with the max value for all images
498
+ # we apply the jet colormap on the confidence maps
499
+ rgbimg = scene.imgs
500
+ depths = to_numpy(scene.get_depthmaps())
501
+ confs = to_numpy([c for c in scene.im_conf])
502
+ # confs = to_numpy([c for c in scene.conf_2])
503
+ cmap = pl.get_cmap('jet')
504
+ depths_max = max([d.max() for d in depths])
505
+ depths = [d / depths_max for d in depths]
506
+ confs_max = max([d.max() for d in confs])
507
+ confs = [cmap(d / confs_max) for d in confs]
508
+
509
+ imgs = []
510
+ for i in range(len(rgbimg)):
511
+ imgs.append(rgbimg[i])
512
+ imgs.append(rgb(depths[i]))
513
+ imgs.append(rgb(confs[i]))
514
+
515
+ return scene, outfile, imgs
516
+
517
+ @spaces.GPU(duration=180)
518
+ def get_3D_object_from_scene(outdir, pe3r, silent, device, text, threshold, scene, min_conf_thr, as_pointcloud,
519
+ mask_sky, clean_depth, transparent_cams, cam_size):
520
+
521
+ texts = [text]
522
+ inputs = pe3r.siglip_tokenizer(text=texts, padding="max_length", return_tensors="pt")
523
+ inputs = {key: value.to(device) for key, value in inputs.items()}
524
+ with torch.no_grad():
525
+ text_feats =pe3r.siglip.get_text_features(**inputs)
526
+ text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
527
+ scene.render_image(text_feats, threshold)
528
+ scene.ori_imgs = scene.rendered_imgs
529
+ outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
530
+ clean_depth, transparent_cams, cam_size)
531
+ return outfile
532
+
533
+
534
+ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
535
+ num_files = len(inputfiles) if inputfiles is not None else 1
536
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
537
+ if scenegraph_type == "swin":
538
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
539
+ minimum=1, maximum=max_winsize, step=1, visible=True)
540
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
541
+ maximum=num_files - 1, step=1, visible=False)
542
+ elif scenegraph_type == "oneref":
543
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
544
+ minimum=1, maximum=max_winsize, step=1, visible=False)
545
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
546
+ maximum=num_files - 1, step=1, visible=True)
547
  else:
548
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
549
+ minimum=1, maximum=max_winsize, step=1, visible=False)
550
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
551
+ maximum=num_files - 1, step=1, visible=False)
552
+ return winsize, refid
553
+
554
+
555
+
556
+
557
+
558
+ silent = True
559
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
560
+ pe3r = Models(device)
561
+
562
+ with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
563
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, pe3r, device, silent)
564
+ model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
565
+ get_3D_object_from_scene_fun = functools.partial(get_3D_object_from_scene, tmpdirname, pe3r, silent, device)
566
+
567
+ with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="PE3R Demo") as demo:
568
+ # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
569
+ scene = gradio.State(None)
570
+ gradio.HTML('<h2 style="text-align: center;">PE3R Demo</h2>')
571
+ with gradio.Column():
572
+ inputfiles = gradio.File(file_count="multiple")
573
+ with gradio.Row():
574
+ schedule = gradio.Dropdown(["linear", "cosine"],
575
+ value='linear', label="schedule", info="For global alignment!",
576
+ visible=False)
577
+ niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
578
+ label="num_iterations", info="For global alignment!",
579
+ visible=False)
580
+ scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
581
+ ("swin: sliding window", "swin"),
582
+ ("oneref: match one image with all", "oneref")],
583
+ value='complete', label="Scenegraph",
584
+ info="Define how to make pairs",
585
+ interactive=True,
586
+ visible=False)
587
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
588
+ minimum=1, maximum=1, step=1, visible=False)
589
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
590
+
591
+ run_btn = gradio.Button("Reconstruct")
592
+
593
+ with gradio.Row():
594
+ # adjust the confidence threshold
595
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
596
+ # adjust the camera size in the output pointcloud
597
+ cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001, visible=False)
598
+ with gradio.Row():
599
+ as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
600
+ # two post process implemented
601
+ mask_sky = gradio.Checkbox(value=False, label="Mask sky", visible=False)
602
+ clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
603
+ transparent_cams = gradio.Checkbox(value=True, label="Transparent cameras")
604
+
605
+ with gradio.Row():
606
+ text_input = gradio.Textbox(label="Query Text")
607
+ threshold = gradio.Slider(label="Threshold", value=0.85, minimum=0.0, maximum=1.0, step=0.01)
608
+
609
+ find_btn = gradio.Button("Find")
610
+
611
+ outmodel = gradio.Model3D()
612
+ outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%",
613
+ visible=False)
614
 
615
+ # events
616
+ scenegraph_type.change(set_scenegraph_options,
617
+ inputs=[inputfiles, winsize, refid, scenegraph_type],
618
+ outputs=[winsize, refid])
619
+ inputfiles.change(set_scenegraph_options,
620
+ inputs=[inputfiles, winsize, refid, scenegraph_type],
621
+ outputs=[winsize, refid])
622
+ run_btn.click(fn=recon_fun,
623
+ inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
624
+ mask_sky, clean_depth, transparent_cams, cam_size,
625
+ scenegraph_type, winsize, refid],
626
+ outputs=[scene, outmodel, outgallery])
627
+ min_conf_thr.release(fn=model_from_scene_fun,
628
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
629
+ clean_depth, transparent_cams, cam_size],
630
+ outputs=outmodel)
631
+ cam_size.change(fn=model_from_scene_fun,
632
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
633
+ clean_depth, transparent_cams, cam_size],
634
+ outputs=outmodel)
635
+ as_pointcloud.change(fn=model_from_scene_fun,
636
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
637
+ clean_depth, transparent_cams, cam_size],
638
+ outputs=outmodel)
639
+ mask_sky.change(fn=model_from_scene_fun,
640
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
641
+ clean_depth, transparent_cams, cam_size],
642
+ outputs=outmodel)
643
+ clean_depth.change(fn=model_from_scene_fun,
644
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
645
+ clean_depth, transparent_cams, cam_size],
646
+ outputs=outmodel)
647
+ transparent_cams.change(model_from_scene_fun,
648
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
649
+ clean_depth, transparent_cams, cam_size],
650
+ outputs=outmodel)
651
+ find_btn.click(fn=get_3D_object_from_scene_fun,
652
+ inputs=[text_input, threshold, scene, min_conf_thr, as_pointcloud, mask_sky,
653
+ clean_depth, transparent_cams, cam_size],
654
+ outputs=outmodel)
655
+ demo.launch(show_error=True, share=None, server_name=None, server_port=None)
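
Aside (not part of the commit): the new app.py merges multi-view SigLIP features with the slerp() helper shown above. The following self-contained sketch restates that interpolation on two toy 2-D unit vectors, assuming only that PyTorch is installed; the committed helper operates on 1024-dimensional feature vectors.

    import torch

    def slerp(u1, u2, t):
        # Spherical linear interpolation between two unit vectors, mirroring the
        # helper in the committed app.py: fall back to linear interpolation when
        # the vectors are (anti)parallel.
        dot = torch.clamp(torch.sum(u1 * u2), -1.0, 1.0)
        theta = torch.acos(dot)
        sin_theta = torch.sin(theta)
        if sin_theta == 0:
            return u1 + t * (u2 - u1)
        s1 = torch.sin((1 - t) * theta) / sin_theta
        s2 = torch.sin(t * theta) / sin_theta
        return s1 * u1 + s2 * u2

    u1 = torch.tensor([1.0, 0.0])
    u2 = torch.tensor([0.0, 1.0])
    mid = slerp(u1, u2, 0.5)
    print(mid, torch.linalg.norm(mid))  # ~[0.7071, 0.7071], norm stays ~1.0

Unlike plain averaging, the result keeps unit norm, which is presumably why the demo uses it to fuse the normalized per-view image embeddings before comparing them against the text query.
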
modules/dust3r/cloud_opt/__pycache__/base_opt.cpython-312.pyc CHANGED
Binary files a/modules/dust3r/cloud_opt/__pycache__/base_opt.cpython-312.pyc and b/modules/dust3r/cloud_opt/__pycache__/base_opt.cpython-312.pyc differ
 
modules/pe3r/demo.py CHANGED
@@ -552,13 +552,6 @@ def main_demo(tmpdirname, server_name, server_port, silent=False):
552
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
553
 
554
  pe3r = Models(device)
555
- # scene, outfile, imgs = get_reconstructed_scene(
556
- # outdir=tmpdirname, pe3r=pe3r, device=device, silent=silent,
557
- # filelist=['/home/hujie/pe3r/datasets/mipnerf360_ov/bonsai/black_chair/images/DSCF5590.png',
558
- # '/home/hujie/pe3r/datasets/mipnerf360_ov/bonsai/black_chair/images/DSCF5602.png',
559
- # '/home/hujie/pe3r/datasets/mipnerf360_ov/bonsai/black_chair/images/DSCF5609.png'],
560
- # schedule="linear", niter=300, min_conf_thr=3.0, as_pointcloud=False, mask_sky=True, clean_depth=True, transparent_cams=False,
561
- # cam_size=0.05, scenegraph_type="complete", winsize=1, refid=0)
562
 
563
  recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, pe3r, device, silent)
564
  model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
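
Also illustrative only: both the new app.py and demo.py pre-bind the non-UI arguments of their callbacks with functools.partial so that Gradio events only pass widget values. A minimal sketch of that pattern, using a hypothetical stand-in with a reduced signature rather than the real get_reconstructed_scene:

    import functools

    def get_reconstructed_scene(outdir, pe3r, device, silent, filelist, niter):
        # Stand-in; the committed function also takes the scene-graph and
        # rendering options wired up in the Blocks layout above.
        return f"{outdir}: {len(filelist)} image(s), {niter} iterations on {device}"

    # Pre-bind outdir/models/device/silent, as the demo does with tmpdirname, pe3r, device, silent.
    recon_fun = functools.partial(get_reconstructed_scene, "/tmp/pe3r_demo", "models", "cpu", True)
    print(recon_fun(["a.png", "b.png"], 300))  # -> "/tmp/pe3r_demo: 2 image(s), 300 iterations on cpu"
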
 