Spaces: Running on T4
absolute paths are added to video_inpainting script
Browse files
- FGT_codes/tool/video_inpainting.py  +383  -254
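The change itself: the script previously pushed a stack of relative entries onto sys.path, including entries that named module files directly (region_fill.py, Poisson_blend_img.py) and one built from os.path.dirname("__file__"), which quotes __file__ and therefore resolves against the current working directory instead of the script. The new header keeps only three entries, each anchored to the script's own absolute location. A minimal standalone sketch of how those entries resolve (the /home/user/app prefix is an assumed install location, not from the commit):

    import os

    script = "/home/user/app/FGT_codes/tool/video_inpainting.py"  # assumed path

    # os.path.join treats the file itself as a path component, so "../.."
    # first climbs out of the file, then out of tool/:
    print(os.path.abspath(os.path.join(script, "..", "..")))          # .../FGT_codes
    print(os.path.abspath(os.path.join(script, "..", "..", "FGT")))   # .../FGT_codes/FGT
    print(os.path.abspath(os.path.join(script, "..", "..", "LAFC")))  # .../FGT_codes/LAFC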
FGT_codes/tool/video_inpainting.py
CHANGED
@@ -1,3 +1,11 @@
+import sys
+import os
+
+sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
+sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "FGT")))
+sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "LAFC")))
+warnings.filterwarnings("ignore")
+
import cvbase
from torchvision.transforms import ToTensor
from get_flowNN_gradient import get_flowNN_gradient
@@ -20,22 +28,6 @@ import glob
import cv2
import argparse
import warnings
-import os
-import sys
-
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'tool')))
-sys.path.append(os.path.abspath(os.path.join(
-    __file__, '..', '..', 'tool', 'utils')))
-sys.path.append(os.path.abspath(os.path.join(
-    __file__, '..', '..', 'tool', 'utils', 'region_fill.py')))
-sys.path.append(os.path.abspath(os.path.join(
-    __file__, '..', '..', 'tool', 'utils', 'Poisson_blend_img.py')))
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'FGT')))
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'LAFC')))
-sys.path.append(os.path.abspath(
-    os.path.join(os.path.dirname("__file__"), '..')))
-warnings.filterwarnings("ignore")


def to_tensor(img):
@@ -55,32 +47,37 @@ def diffusion(flows, masks):
    return flows_filled


-def np2tensor(array, near=
+def np2tensor(array, near="c"):
    if isinstance(array, list):
        array = np.stack(array, axis=0)  # [t, h, w, c]
-    if near ==
-        array =
+    if near == "c":
+        array = (
+            torch.from_numpy(np.transpose(array, (3, 0, 1, 2))).unsqueeze(0).float()
+        )  # [1, c, t, h, w]
+    elif near == "t":
+        array = torch.from_numpy(np.transpose(array, (0, 3, 1, 2))).unsqueeze(0).float()
    else:
-        raise ValueError(f
+        raise ValueError(f"Unknown near type: {near}")
    return array


def tensor2np(array):
-    array = torch.stack(array, dim=-1).squeeze(0).permute(1,
-        2, 0, 3).cpu().numpy()
+    array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy()
    return array


def gradient_mask(mask):
-    gradient_mask = np.logical_or.reduce(
+    gradient_mask = np.logical_or.reduce(
+        (
+            mask,
+            np.concatenate(
+                (mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)), axis=0
+            ),
+            np.concatenate(
+                (mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)), axis=1
+            ),
+        )
+    )

    return gradient_mask

@@ -116,54 +113,73 @@ def get_ref_index(f, neighbor_ids, length, ref_length, num_ref):


def save_flows(output, videoFlowF, videoFlowB):
-    create_dir(os.path.join(output,
-    create_dir(os.path.join(output,
-    create_dir(os.path.join(output,
-    create_dir(os.path.join(output,
+    create_dir(os.path.join(output, "completed_flow", "forward_flo"))
+    create_dir(os.path.join(output, "completed_flow", "backward_flo"))
+    create_dir(os.path.join(output, "completed_flow", "forward_png"))
+    create_dir(os.path.join(output, "completed_flow", "backward_png"))
    N = videoFlowF.shape[-1]
    for i in range(N):
        forward_flow = videoFlowF[..., i]
        backward_flow = videoFlowB[..., i]
        forward_flow_vis = cvbase.flow2rgb(forward_flow)
        backward_flow_vis = cvbase.flow2rgb(backward_flow)
-        cvbase.write_flow(
+        cvbase.write_flow(
+            forward_flow,
+            os.path.join(
+                output, "completed_flow", "forward_flo", "{:05d}.flo".format(i)
+            ),
+        )
+        cvbase.write_flow(
+            backward_flow,
+            os.path.join(
+                output, "completed_flow", "backward_flo", "{:05d}.flo".format(i)
+            ),
+        )
+        imageio.imwrite(
+            os.path.join(
+                output, "completed_flow", "forward_png", "{:05d}.png".format(i)
+            ),
+            forward_flow_vis,
+        )
+        imageio.imwrite(
+            os.path.join(
+                output, "completed_flow", "backward_png", "{:05d}.png".format(i)
+            ),
+            backward_flow_vis,
+        )


def save_fgcp(output, frames, masks):
-    create_dir(os.path.join(output,
-    create_dir(os.path.join(output,
-    create_dir(os.path.join(output,
-    create_dir(os.path.join(output,
+    create_dir(os.path.join(output, "prop_frames"))
+    create_dir(os.path.join(output, "masks_left"))
+    create_dir(os.path.join(output, "prop_frames_npy"))
+    create_dir(os.path.join(output, "masks_left_npy"))

    assert len(frames) == masks.shape[2]
    for i in range(len(frames)):
-        cv2.imwrite(
-        np.save(
+        cv2.imwrite(
+            os.path.join(output, "prop_frames", "%05d.png" % i), frames[i] * 255.0
+        )
+        cv2.imwrite(
+            os.path.join(output, "masks_left", "%05d.png" % i), masks[:, :, i] * 255.0
+        )
+        np.save(
+            os.path.join(output, "prop_frames_npy", "%05d.npy" % i), frames[i] * 255.0
+        )
+        np.save(
+            os.path.join(output, "masks_left_npy", "%05d.npy" % i),
+            masks[:, :, i] * 255.0,
+        )


def create_dir(dir):
-    """Creates a directory if not exist.
-    """
+    """Creates a directory if not exist."""
    if not os.path.exists(dir):
        os.makedirs(dir)


def initialize_RAFT(args, device):
-    """Initializes the RAFT model.
-    """
+    """Initializes the RAFT model."""
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.raft_model))

@@ -177,57 +193,65 @@ def initialize_RAFT(args, device):
def initialize_LAFC(args, device):
    print(args.lafc_ckpts)
    assert len(os.listdir(args.lafc_ckpts)) == 2
-    checkpoint, config_file =
-        glob.glob(os.path.join(args.lafc_ckpts,
+    checkpoint, config_file = (
+        glob.glob(os.path.join(args.lafc_ckpts, "*.tar"))[0],
+        glob.glob(os.path.join(args.lafc_ckpts, "*.yaml"))[0],
+    )
+    with open(config_file, "r") as f:
        configs = yaml.full_load(f)
-    model = configs[
-    pkg = import_module(
+    model = configs["model"]
+    pkg = import_module("LAFC.models.{}".format(model))
    model = pkg.Model(configs)
-    state = torch.load(
+    state = torch.load(
+        checkpoint, map_location=lambda storage, loc: storage.cuda(device)
+    )
+    model.load_state_dict(state["model_state_dict"])
    model = model.to(device)
    return model, configs


def initialize_FGT(args, device):
    assert len(os.listdir(args.fgt_ckpts)) == 2
-    checkpoint, config_file =
-        glob.glob(os.path.join(args.fgt_ckpts,
+    checkpoint, config_file = (
+        glob.glob(os.path.join(args.fgt_ckpts, "*.tar"))[0],
+        glob.glob(os.path.join(args.fgt_ckpts, "*.yaml"))[0],
+    )
+    with open(config_file, "r") as f:
        configs = yaml.full_load(f)
-    model = configs[
-    net = import_module(
+    model = configs["model"]
+    net = import_module("FGT.models.{}".format(model))
    model = net.Model(configs).to(device)
-    state = torch.load(
+    state = torch.load(
+        checkpoint, map_location=lambda storage, loc: storage.cuda(device)
+    )
+    model.load_state_dict(state["model_state_dict"])
    return model, configs


def calculate_flow(args, model, video, mode):
-    """Calculates optical flow.
-    """
-    if mode not in ['forward', 'backward']:
+    """Calculates optical flow."""
+    if mode not in ["forward", "backward"]:
        raise NotImplementedError

    imgH, imgW = args.imgH, args.imgW
    Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)

    if args.vis_flows:
-        create_dir(os.path.join(args.outroot,
-        create_dir(os.path.join(args.outroot,
+        create_dir(os.path.join(args.outroot, "flow", mode + "_flo"))
+        create_dir(os.path.join(args.outroot, "flow", mode + "_png"))

    with torch.no_grad():
        for i in range(video.shape[0] - 1):
            print(
-                "Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1),
+                "Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1),
+                "\r",
+                end="",
+            )
+            if mode == "forward":
                # Flow i -> i + 1
                image1 = video[i, None]
                image2 = video[i + 1, None]
-            elif mode ==
+            elif mode == "backward":
                # Flow i + 1 -> i
                image1 = video[i + 1, None]
                image2 = video[i, None]
@@ -240,8 +264,8 @@ def calculate_flow(args, model, video, mode):
            h, w = flow.shape[:2]
            if h != imgH or w != imgW:
                flow = cv2.resize(flow, (imgW, imgH), cv2.INTER_LINEAR)
-                flow[:, :, 0] *=
-                flow[:, :, 1] *=
+                flow[:, :, 0] *= float(imgW) / float(w)
+                flow[:, :, 1] *= float(imgH) / float(h)

            Flow = np.concatenate((Flow, flow[..., None]), axis=-1)

@@ -251,17 +275,19 @@ def calculate_flow(args, model, video, mode):
            flow_img = Image.fromarray(flow_img)

            # Saves the flow and flow_img.
-            flow_img.save(
+            flow_img.save(
+                os.path.join(args.outroot, "flow", mode + "_png", "%05d.png" % i)
+            )
+            utils.frame_utils.writeFlow(
+                os.path.join(args.outroot, "flow", mode + "_flo", "%05d.flo" % i),
+                flow,
+            )

    return Flow


def extrapolation(args, video_ori, corrFlowF_ori, corrFlowB_ori):
-    """Prepares the data for video extrapolation.
-    """
+    """Prepares the data for video extrapolation."""
    imgH, imgW, _, nFrame = video_ori.shape

    # Defines new FOV.
@@ -274,45 +300,56 @@ def extrapolation(args, video_ori, corrFlowF_ori, corrFlowB_ori):

    # Generates the mask for missing region.
    flow_mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.bool)
-    flow_mask[H_start: H_start + imgH, W_start: W_start + imgW] = 0
+    flow_mask[H_start : H_start + imgH, W_start : W_start + imgW] = 0

    mask_dilated = gradient_mask(flow_mask)

    # Extrapolates the FOV for video.
    video = np.zeros(((imgH_extr, imgW_extr, 3, nFrame)), dtype=np.float32)
-    video[H_start: H_start + imgH, W_start: W_start + imgW, :, :] = video_ori
+    video[H_start : H_start + imgH, W_start : W_start + imgW, :, :] = video_ori

    for i in range(nFrame):
-        print("Preparing frame {0}".format(i),
-        video[:, :, :, i] =
+        print("Preparing frame {0}".format(i), "\r", end="")
+        video[:, :, :, i] = (
+            cv2.inpaint(
+                (video[:, :, :, i] * 255).astype(np.uint8),
+                flow_mask.astype(np.uint8),
+                3,
+                cv2.INPAINT_TELEA,
+            ).astype(np.float32)
+            / 255.0
+        )

    # Extrapolates the FOV for flow.
-    corrFlowF = np.zeros(
+    corrFlowF = np.zeros(((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32)
+    corrFlowB = np.zeros(((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32)
+    corrFlowF[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowF_ori
+    corrFlowB[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowB_ori
+
+    return (
+        video,
+        corrFlowF,
+        corrFlowB,
+        flow_mask,
+        mask_dilated,
+        (W_start, H_start),
+        (W_start + imgW, H_start + imgH),
+    )


def complete_flow(config, flow_model, flows, flow_masks, mode, device):
-    if mode not in [
-        raise NotImplementedError(f
+    if mode not in ["forward", "backward"]:
+        raise NotImplementedError(f"Error flow mode {mode}")
    flow_masks = np.moveaxis(flow_masks, -1, 0)  # [N, H, W]
    flows = np.moveaxis(flows, -1, 0)  # [N, H, W, 2]
    if len(flow_masks.shape) == 3:
        flow_masks = flow_masks[:, :, :, np.newaxis]
-    if mode ==
+    if mode == "forward":
        flow_masks = flow_masks[0:-1]
    else:
        flow_masks = flow_masks[1:]

-    num_flows, flow_interval = config[
+    num_flows, flow_interval = config["num_flows"], config["flow_interval"]

    diffused_flows = diffusion(flows, flow_masks)

@@ -329,7 +366,7 @@ def complete_flow(config, flow_model, flows, flow_masks, mode, device):
    pivot = num_flows // 2
    for i in range(t):
        indices = indicesGen(i, flow_interval, num_flows, t)
-        print(
+        print("Indices: ", indices, "\r", end="")
        cand_flows = flows[:, :, indices]
        cand_masks = flow_masks[:, :, indices]
        inputs = diffused_flows[:, :, indices]
@@ -349,19 +386,19 @@ def complete_flow(config, flow_model, flows, flow_masks, mode, device):
def read_flow(flow_dir, video):
    nFrame, _, imgH, imgW = video.shape
    Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)
-    flows = sorted(glob.glob(os.path.join(flow_dir,
+    flows = sorted(glob.glob(os.path.join(flow_dir, "*.flo")))
    for flow in flows:
        flow_data = cvbase.read_flow(flow)
        h, w = flow_data.shape[:2]
        flow_data = cv2.resize(flow_data, (imgW, imgH), cv2.INTER_LINEAR)
-        flow_data[:, :, 0] *=
-        flow_data[:, :, 1] *=
+        flow_data[:, :, 0] *= float(imgW) / float(w)
+        flow_data[:, :, 1] *= float(imgH) / float(h)
        Flow = np.concatenate((Flow, flow_data[..., None]), axis=-1)
    return Flow


def norm_flows(flows):
-    assert len(flows.shape) == 5,
+    assert len(flows.shape) == 5, "FLow shape: {}".format(flows.shape)
    flattened_flows = flows.flatten(3)
    flow_max = torch.max(flattened_flows, dim=-1, keepdim=True)[0]
    flows = flows / flow_max.unsqueeze(-1)
@@ -369,19 +406,19 @@ def norm_flows(flows):


def save_results(outdir, comp_frames):
-    out_dir = os.path.join(outdir,
+    out_dir = os.path.join(outdir, "frames")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for i in range(len(comp_frames)):
-        out_path = os.path.join(out_dir,
+        out_path = os.path.join(out_dir, "{:05d}.png".format(i))
        cv2.imwrite(out_path, comp_frames[i][:, :, ::-1])


def video_inpainting(args, imgArr, imgMaskArr):
-    device = torch.device(
+    device = torch.device("cuda:{}".format(args.gpu))
    print(args)
    if args.opt is not None:
-        with open(args.opt,
+        with open(args.opt, "r") as f:
            opts = yaml.full_load(f)

    for k in opts.keys():
@@ -412,37 +449,54 @@ def video_inpainting(args, imgArr, imgMaskArr):

    # Load video.
    video, video_flow = [], []
-    if args.mode ==
-        maskname_list = glob.glob(os.path.join(args.path_mask,
-            os.path.join(args.path_mask,
+    if args.mode == "watermark_removal":
+        maskname_list = glob.glob(os.path.join(args.path_mask, "*.png")) + glob.glob(
+            os.path.join(args.path_mask, "*.jpg")
+        )
        assert len(filename_list) == len(maskname_list)
        for filename, maskname in zip(sorted(filename_list), sorted(maskname_list)):
-            frame =
+            frame = (
+                torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8))
+                .permute(2, 0, 1)
+                .float()
+                .unsqueeze(0)
+            )
+            mask = (
+                torch.from_numpy(np.array(Image.open(maskname)).astype(np.uint8))
+                .permute(2, 0, 1)
+                .float()
+                .unsqueeze(0)
+            )
            mask[mask > 0] = 1
            frame = frame * (1 - mask)
-            frame = F2.upsample(
+            frame = F2.upsample(
+                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
+            )
+            frame_flow = F2.upsample(
+                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
+            )
            video.append(frame)
            video_flow.append(frame_flow)
    else:
+        """for filename in sorted(filename_list):
+        frame = torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0, 1).float().unsqueeze(0)
+        frame = F2.upsample(frame, size=(imgH, imgW), mode='bilinear', align_corners=False)
+        frame_flow = F2.upsample(frame, size=(flowH, flowW), mode='bilinear', align_corners=False)
+        video.append(frame)
+        video_flow.append(frame_flow)"""
        for im in imgArr:
-            frame =
-                np.
+            frame = (
+                torch.from_numpy(np.array(im).astype(np.uint8))
+                .permute(2, 0, 1)
+                .float()
+                .unsqueeze(0)
+            )
+            frame = F2.upsample(
+                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
+            )
+            frame_flow = F2.upsample(
+                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
+            )
            video.append(frame)
            video_flow.append(frame_flow)

@@ -454,20 +508,26 @@ def video_inpainting(args, imgArr, imgMaskArr):

    # Calcutes the corrupted flow.
    forward_flows = calculate_flow(
-        args, RAFT_model, video_flow,
+        args, RAFT_model, video_flow, "forward"
+    )  # [B, C, 2, N]
+    backward_flows = calculate_flow(args, RAFT_model, video_flow, "backward")

    # Makes sure video is in BGR (opencv) format.
-    video =
-        :, :, ::-1, :] / 255.
-    if args.mode == 'video_extrapolation':
+    video = (
+        video.permute(2, 3, 1, 0).cpu().numpy()[:, :, ::-1, :] / 255.0
+    )  # np array -> [h, w, c, N] (0~1)

+    if args.mode == "video_extrapolation":
        # Creates video and flow where the extrapolated region are missing.
+        (
+            video,
+            forward_flows,
+            backward_flows,
+            flow_mask,
+            mask_dilated,
+            start_point,
+            end_point,
+        ) = extrapolation(args, video, forward_flows, backward_flows)
        imgH, imgW = video.shape[:2]

    # mask indicating the missing region in the video.
@@ -477,13 +537,14 @@ def video_inpainting(args, imgArr, imgMaskArr):

    else:
        # Loads masks.
-        filename_list = glob.glob(os.path.join(args.path_mask,
+        filename_list = glob.glob(os.path.join(args.path_mask, "*.png")) + glob.glob(
+            os.path.join(args.path_mask, "*.jpg")
+        )

        mask = []
        mask_dilated = []
        flow_mask = []
+        """for filename in sorted(filename_list):
        mask_img = np.array(Image.open(filename).convert('L'))
        mask_img = cv2.resize(mask_img, dsize=(imgW, imgH), interpolation=cv2.INTER_NEAREST)

@@ -496,23 +557,26 @@ def video_inpainting(args, imgArr, imgMaskArr):
        if args.frame_dilates > 0:
            mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=args.frame_dilates)
        mask.append(mask_img)
-        mask_dilated.append(gradient_mask(mask_img))
+        mask_dilated.append(gradient_mask(mask_img))"""

        for f_mask in imgMaskArr:
            mask_img = np.array(f_mask)
-            mask_img = cv2.resize(
-                imgW, imgH), interpolation=cv2.INTER_NEAREST
+            mask_img = cv2.resize(
+                mask_img, dsize=(imgW, imgH), interpolation=cv2.INTER_NEAREST
+            )

            if args.flow_mask_dilates > 0:
                flow_mask_img = scipy.ndimage.binary_dilation(
-                    mask_img, iterations=args.flow_mask_dilates
+                    mask_img, iterations=args.flow_mask_dilates
+                )
            else:
                flow_mask_img = mask_img
            flow_mask.append(flow_mask_img)

            if args.frame_dilates > 0:
                mask_img = scipy.ndimage.binary_dilation(
-                    mask_img, iterations=args.frame_dilates
+                    mask_img, iterations=args.frame_dilates
+                )
            mask.append(mask_img)
            mask_dilated.append(gradient_mask(mask_img))

@@ -523,12 +587,14 @@ def video_inpainting(args, imgArr, imgMaskArr):

    # Completes the flow.
    videoFlowF = complete_flow(
-        LAFC_config, LAFC_model, forward_flows, flow_mask,
+        LAFC_config, LAFC_model, forward_flows, flow_mask, "forward", device
+    )
    videoFlowB = complete_flow(
-        LAFC_config, LAFC_model, backward_flows, flow_mask,
+        LAFC_config, LAFC_model, backward_flows, flow_mask, "backward", device
+    )
    videoFlowF = tensor2np(videoFlowF)
    videoFlowB = tensor2np(videoFlowB)
-    print(
+    print("\nFinish flow completion.")

    if args.vis_completed_flows:
        save_flows(args.outroot, videoFlowF, videoFlowB)
@@ -540,17 +606,28 @@ def video_inpainting(args, imgArr, imgMaskArr):
    for indFrame in range(nFrame):
        img = video[:, :, :, indFrame]
        img[mask[:, :, indFrame], :] = 0
-        img =
+        img = (
+            cv2.inpaint(
+                (img * 255).astype(np.uint8),
+                mask[:, :, indFrame].astype(np.uint8),
+                3,
+                cv2.INPAINT_TELEA,
+            ).astype(np.float32)
+            / 255.0
+        )
+
+        gradient_x_ = np.concatenate(
+            (np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1
+        )
        gradient_y_ = np.concatenate(
-            (np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0
+            (np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0
+        )
        gradient_x = np.concatenate(
-            (gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1
+            (gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1
+        )
        gradient_y = np.concatenate(
-            (gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1
+            (gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1
+        )

        gradient_x[mask_dilated[:, :, indFrame], :, indFrame] = 0
        gradient_y[mask_dilated[:, :, indFrame], :, indFrame] = 0
@@ -561,21 +638,23 @@ def video_inpainting(args, imgArr, imgMaskArr):
    video_comp = video

    # Gradient propagation.
-    gradient_x_filled, gradient_y_filled, mask_gradient =
+    gradient_x_filled, gradient_y_filled, mask_gradient = get_flowNN_gradient(
+        args,
+        gradient_x_filled,
+        gradient_y_filled,
+        mask,
+        mask_gradient,
+        videoFlowF,
+        videoFlowB,
+        None,
+        None,
+    )

    # if there exist holes in mask, Poisson blending will fail. So I did this trick. I sacrifice some value. Another solution is to modify Poisson blending.
    for indFrame in range(nFrame):
-        mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(
+        mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(
+            mask_gradient[:, :, indFrame]
+        ).astype(np.bool)

    # After one gradient propagation iteration
    # gradient --> RGB
@@ -585,19 +664,29 @@ def video_inpainting(args, imgArr, imgMaskArr):

        if mask[:, :, indFrame].sum() > 0:
            try:
-                frameBlend, UnfilledMask = Poisson_blend_img(
+                frameBlend, UnfilledMask = Poisson_blend_img(
+                    video_comp[:, :, :, indFrame],
+                    gradient_x_filled[:, 0 : imgW - 1, :, indFrame],
+                    gradient_y_filled[0 : imgH - 1, :, :, indFrame],
+                    mask[:, :, indFrame],
+                    mask_gradient[:, :, indFrame],
+                )
            except:
-                frameBlend, UnfilledMask =
+                frameBlend, UnfilledMask = (
+                    video_comp[:, :, :, indFrame],
+                    mask[:, :, indFrame],
+                )

            frameBlend = np.clip(frameBlend, 0, 1.0)
-            tmp =
+            tmp = (
+                cv2.inpaint(
+                    (frameBlend * 255).astype(np.uint8),
+                    UnfilledMask.astype(np.uint8),
+                    3,
+                    cv2.INPAINT_TELEA,
+                ).astype(np.float32)
+                / 255.0
+            )
            frameBlend[UnfilledMask, :] = tmp[UnfilledMask, :]

            video_comp[:, :, :, indFrame] = frameBlend
@@ -605,7 +694,7 @@ def video_inpainting(args, imgArr, imgMaskArr):

            frameBlend_ = copy.deepcopy(frameBlend)
            # Green indicates the regions that are not filled yet.
-            frameBlend_[mask[:, :, indFrame], :] = [0, 1
+            frameBlend_[mask[:, :, indFrame], :] = [0, 1.0, 0]
        else:
            frameBlend_ = video_comp[:, :, :, indFrame]
        frameBlends.append(frameBlend_)
@@ -618,10 +707,10 @@ def video_inpainting(args, imgArr, imgMaskArr):
    for i in range(len(frameBlends)):
        frameBlends[i] = frameBlends[i][:, :, ::-1]

-    frames_first = np2tensor(frameBlends, near=
+    frames_first = np2tensor(frameBlends, near="t").to(device)
    mask = np.moveaxis(mask, -1, 0)
    mask = mask[:, :, :, np.newaxis]
-    masks = np2tensor(mask, near=
+    masks = np2tensor(mask, near="t").to(device)
    normed_frames = frames_first * 2 - 1
    comp_frames = [None] * video_length

@@ -633,115 +722,155 @@ def video_inpainting(args, imgArr, imgMaskArr):

    videoFlowF = np.concatenate([videoFlowF, videoFlowF[-1:, ...]], axis=0)

-    flows = np2tensor(videoFlowF, near=
+    flows = np2tensor(videoFlowF, near="t")
    flows = norm_flows(flows).to(device)

    for f in range(0, video_length, neighbor_stride):
-        neighbor_ids = [
+        neighbor_ids = [
+            i
+            for i in range(
+                max(0, f - neighbor_stride), min(video_length, f + neighbor_stride + 1)
+            )
+        ]
+        ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_length, num_ref)
        print(f, len(neighbor_ids), len(ref_ids))
        selected_frames = normed_frames[:, neighbor_ids + ref_ids]
        selected_masks = masks[:, neighbor_ids + ref_ids]
        masked_frames = selected_frames * (1 - selected_masks)
        selected_flows = flows[:, neighbor_ids + ref_ids]
        with torch.no_grad():
-            filled_frames = FGT_model(
-                masked_frames, selected_flows, selected_masks)
+            filled_frames = FGT_model(masked_frames, selected_flows, selected_masks)
        filled_frames = (filled_frames + 1) / 2
        filled_frames = filled_frames.cpu().permute(0, 2, 3, 1).numpy() * 255
        for i in range(len(neighbor_ids)):
            idx = neighbor_ids[i]
-            valid_frame = frames_first[0, idx].cpu().permute(
-                1, 2, 0).numpy() * 255.
+            valid_frame = frames_first[0, idx].cpu().permute(1, 2, 0).numpy() * 255.0
            valid_mask = masks[0, idx].cpu().permute(1, 2, 0).numpy()
-            comp = np.array(filled_frames[i]).astype(np.uint8) * valid_mask +
+            comp = np.array(filled_frames[i]).astype(np.uint8) * valid_mask + np.array(
+                valid_frame
+            ).astype(np.uint8) * (1 - valid_mask)
            if comp_frames[idx] is None:
                comp_frames[idx] = comp
            else:
-                comp_frames[idx] =
+                comp_frames[idx] = (
+                    comp_frames[idx].astype(np.float32) * 0.5
+                    + comp.astype(np.float32) * 0.5
+                )
    if args.vis_frame:
        save_results(args.outroot, comp_frames)
    create_dir(args.outroot)
    for i in range(len(comp_frames)):
        comp_frames[i] = comp_frames[i].astype(np.uint8)
-    imageio.mimwrite(
+    imageio.mimwrite(
+        os.path.join(args.outroot, "result.mp4"), comp_frames, fps=30, quality=8
+    )
+    print(f"Done, please check your result in {args.outroot} ")


def main(args):
-    assert args.mode in (
+    assert args.mode in (
+        "object_removal",
+        "video_extrapolation",
+        "watermark_removal",
+    ), (
        "Accepted modes: 'object_removal', 'video_extrapolation', and 'watermark_removal', but input is %s"
    ) % args.mode
    video_inpainting(args)


-if __name__ ==
+if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument(
+    parser.add_argument(
+        "--opt",
+        default="configs/object_removal.yaml",
+        help="Please select your config file for inference",
+    )
    # video completion
-    parser.add_argument('--mode', default='object_removal', choices=[
-        'object_removal', 'watermark_removal', 'video_extrapolation'], help="modes: object_removal / video_extrapolation")
    parser.add_argument(
+        "--mode",
+        default="object_removal",
+        choices=["object_removal", "watermark_removal", "video_extrapolation"],
+        help="modes: object_removal / video_extrapolation",
+    )
+    parser.add_argument(
+        "--path", default="/myData/davis_resized/walking", help="dataset for evaluation"
+    )
    parser.add_argument(
+        "--path_mask",
+        default="/myData/dilateAnnotations_4/walking",
+        help="mask for object removal",
+    )
    parser.add_argument(
+        "--outroot", default="quick_start/walking3", help="output directory"
+    )
+    parser.add_argument(
+        "--consistencyThres",
+        dest="consistencyThres",
+        default=5,
+        type=float,
+        help="flow consistency error threshold",
+    )
+    parser.add_argument("--alpha", dest="alpha", default=0.1, type=float)
+    parser.add_argument("--Nonlocal", dest="Nonlocal", default=False, type=bool)

    # RAFT
    parser.add_argument(
-    parser.add_argument(
+        "--raft_model",
+        default="../LAFC/flowCheckPoint/raft-things.pth",
+        help="restore checkpoint",
+    )
+    parser.add_argument("--small", action="store_true", help="use small model")
+    parser.add_argument(
+        "--mixed_precision", action="store_true", help="use mixed precision"
+    )
+    parser.add_argument(
+        "--alternate_corr",
+        action="store_true",
+        help="use efficent correlation implementation",
+    )

    # LAFC
-    parser.add_argument(
+    parser.add_argument("--lafc_ckpts", type=str, default="../LAFC/checkpoint")

    # FGT
-    parser.add_argument(
+    parser.add_argument("--fgt_ckpts", type=str, default="../FGT/checkpoint")

    # extrapolation
-    parser.add_argument(
+    parser.add_argument(
+        "--H_scale", dest="H_scale", default=2, type=float, help="H extrapolation scale"
+    )
+    parser.add_argument(
+        "--W_scale", dest="W_scale", default=2, type=float, help="W extrapolation scale"
+    )

    # Image basic information
-    parser.add_argument(
-    parser.add_argument(
-    parser.add_argument(
-    parser.add_argument(
+    parser.add_argument("--imgH", type=int, default=256)
+    parser.add_argument("--imgW", type=int, default=432)
+    parser.add_argument("--flow_mask_dilates", type=int, default=8)
+    parser.add_argument("--frame_dilates", type=int, default=0)

-    parser.add_argument(
+    parser.add_argument("--gpu", type=int, default=0)

    # FGT inference parameters
-    parser.add_argument(
-    parser.add_argument(
-    parser.add_argument(
+    parser.add_argument("--step", type=int, default=10)
+    parser.add_argument("--num_ref", type=int, default=-1)
+    parser.add_argument("--neighbor_stride", type=int, default=5)

    # visualization
-    parser.add_argument(
+    parser.add_argument(
+        "--vis_flows", action="store_true", help="Visualize the initialized flows"
+    )
+    parser.add_argument(
+        "--vis_completed_flows",
+        action="store_true",
+        help="Visualize the completed flows",
+    )
+    parser.add_argument(
+        "--vis_prop",
+        action="store_true",
+        help="Visualize the frames after stage-I filling (flow guided content propagation)",
+    )
+    parser.add_argument("--vis_frame", action="store_true", help="Visualize frames")

    args = parser.parse_args()
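Note that main() still calls video_inpainting(args) even though the function now takes (args, imgArr, imgMaskArr), so running the file as a CLI would raise a TypeError; the Space presumably imports the module and calls video_inpainting() directly. A hedged sketch of such a driver (the frame/mask paths and the reuse of the parser defaults are assumptions, not part of the commit):

    import glob
    from PIL import Image

    # Load frames and masks as PIL images; video_inpainting() converts them
    # with np.array() internally, so any PIL-compatible source works.
    frames = [Image.open(p).convert("RGB") for p in sorted(glob.glob("inputs/frames/*.png"))]
    masks = [Image.open(p).convert("L") for p in sorted(glob.glob("inputs/masks/*.png"))]

    args = parser.parse_args([])           # take the defaults defined above
    video_inpainting(args, frames, masks)  # writes <outroot>/result.mp4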