oguzakif committed on
Commit 9f4175d · 1 Parent(s): d4b77ac

absolute paths are added to video_inpainting script

Files changed (1):
  1. FGT_codes/tool/video_inpainting.py +383 -254
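The heart of the change is the import header. The old version set up `sys.path` only after the top-level imports (so `from get_flowNN_gradient import get_flowNN_gradient` only resolved when the script happened to be launched from `FGT_codes/tool`), several of its entries pointed at individual files (`region_fill.py`, `Poisson_blend_img.py`), which the import system ignores, and one entry was built from `os.path.dirname("__file__")`, where the quoted `"__file__"` is a string literal rather than the module attribute. The new version runs the bootstrap first and keeps only three package directories, resolved to absolute paths. A minimal sketch of the resulting pattern (illustration only; the `repo_root` variable and the loop are mine, not the commit's):

```python
import os
import sys

# os.path.join(__file__, "..", "..") starts at .../FGT_codes/tool/video_inpainting.py
# and climbs two levels; abspath collapses the ".." components into an absolute
# path that no longer depends on the current working directory.
repo_root = os.path.abspath(os.path.join(__file__, "..", ".."))

# Only directories are meaningful sys.path entries; a path to a .py file
# (as in the removed lines) never makes that file importable.
for subdir in ("", "FGT", "LAFC"):
    sys.path.append(os.path.join(repo_root, subdir))
```

One caveat visible in the diff as rendered: the new header calls `warnings.filterwarnings("ignore")` on line 7, while `import warnings` still sits further down with the other imports, which would raise a NameError at import time; moving `import warnings` into the new header would avoid that.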
FGT_codes/tool/video_inpainting.py CHANGED
@@ -1,3 +1,11 @@
+import sys
+import os
+
+sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
+sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "FGT")))
+sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "LAFC")))
+warnings.filterwarnings("ignore")
+
 import cvbase
 from torchvision.transforms import ToTensor
 from get_flowNN_gradient import get_flowNN_gradient
@@ -20,22 +28,6 @@ import glob
 import cv2
 import argparse
 import warnings
-import os
-import sys
-
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'tool')))
-sys.path.append(os.path.abspath(os.path.join(
-    __file__, '..', '..', 'tool', 'utils')))
-sys.path.append(os.path.abspath(os.path.join(
-    __file__, '..', '..', 'tool', 'utils', 'region_fill.py')))
-sys.path.append(os.path.abspath(os.path.join(
-    __file__, '..', '..', 'tool', 'utils', 'Poisson_blend_img.py')))
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'FGT')))
-sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'LAFC')))
-sys.path.append(os.path.abspath(
-    os.path.join(os.path.dirname("__file__"), '..')))
-warnings.filterwarnings("ignore")


 def to_tensor(img):
@@ -55,32 +47,37 @@ def diffusion(flows, masks):
     return flows_filled


-def np2tensor(array, near='c'):
+def np2tensor(array, near="c"):
     if isinstance(array, list):
         array = np.stack(array, axis=0)  # [t, h, w, c]
-    if near == 'c':
-        array = torch.from_numpy(np.transpose(array, (3, 0, 1, 2))).unsqueeze(
-            0).float()  # [1, c, t, h, w]
-    elif near == 't':
-        array = torch.from_numpy(np.transpose(
-            array, (0, 3, 1, 2))).unsqueeze(0).float()
+    if near == "c":
+        array = (
+            torch.from_numpy(np.transpose(array, (3, 0, 1, 2))).unsqueeze(0).float()
+        )  # [1, c, t, h, w]
+    elif near == "t":
+        array = torch.from_numpy(np.transpose(array, (0, 3, 1, 2))).unsqueeze(0).float()
     else:
-        raise ValueError(f'Unknown near type: {near}')
+        raise ValueError(f"Unknown near type: {near}")
     return array


 def tensor2np(array):
-    array = torch.stack(array, dim=-1).squeeze(0).permute(1,
-                                                          2, 0, 3).cpu().numpy()
+    array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy()
     return array


 def gradient_mask(mask):
-    gradient_mask = np.logical_or.reduce((mask,
-                                          np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)),
-                                                         axis=0),
-                                          np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)),
-                                                         axis=1)))
+    gradient_mask = np.logical_or.reduce(
+        (
+            mask,
+            np.concatenate(
+                (mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)), axis=0
+            ),
+            np.concatenate(
+                (mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)), axis=1
+            ),
+        )
+    )

     return gradient_mask

@@ -116,54 +113,73 @@ def get_ref_index(f, neighbor_ids, length, ref_length, num_ref):


 def save_flows(output, videoFlowF, videoFlowB):
-    create_dir(os.path.join(output, 'completed_flow', 'forward_flo'))
-    create_dir(os.path.join(output, 'completed_flow', 'backward_flo'))
-    create_dir(os.path.join(output, 'completed_flow', 'forward_png'))
-    create_dir(os.path.join(output, 'completed_flow', 'backward_png'))
+    create_dir(os.path.join(output, "completed_flow", "forward_flo"))
+    create_dir(os.path.join(output, "completed_flow", "backward_flo"))
+    create_dir(os.path.join(output, "completed_flow", "forward_png"))
+    create_dir(os.path.join(output, "completed_flow", "backward_png"))
     N = videoFlowF.shape[-1]
     for i in range(N):
         forward_flow = videoFlowF[..., i]
         backward_flow = videoFlowB[..., i]
         forward_flow_vis = cvbase.flow2rgb(forward_flow)
         backward_flow_vis = cvbase.flow2rgb(backward_flow)
-        cvbase.write_flow(forward_flow, os.path.join(
-            output, 'completed_flow', 'forward_flo', '{:05d}.flo'.format(i)))
-        cvbase.write_flow(backward_flow, os.path.join(
-            output, 'completed_flow', 'backward_flo', '{:05d}.flo'.format(i)))
-        imageio.imwrite(os.path.join(output, 'completed_flow',
-                        'forward_png', '{:05d}.png'.format(i)), forward_flow_vis)
-        imageio.imwrite(os.path.join(output, 'completed_flow',
-                        'backward_png', '{:05d}.png'.format(i)), backward_flow_vis)
+        cvbase.write_flow(
+            forward_flow,
+            os.path.join(
+                output, "completed_flow", "forward_flo", "{:05d}.flo".format(i)
+            ),
+        )
+        cvbase.write_flow(
+            backward_flow,
+            os.path.join(
+                output, "completed_flow", "backward_flo", "{:05d}.flo".format(i)
+            ),
+        )
+        imageio.imwrite(
+            os.path.join(
+                output, "completed_flow", "forward_png", "{:05d}.png".format(i)
+            ),
+            forward_flow_vis,
+        )
+        imageio.imwrite(
+            os.path.join(
+                output, "completed_flow", "backward_png", "{:05d}.png".format(i)
+            ),
+            backward_flow_vis,
+        )


 def save_fgcp(output, frames, masks):
-    create_dir(os.path.join(output, 'prop_frames'))
-    create_dir(os.path.join(output, 'masks_left'))
-    create_dir(os.path.join(output, 'prop_frames_npy'))
-    create_dir(os.path.join(output, 'masks_left_npy'))
+    create_dir(os.path.join(output, "prop_frames"))
+    create_dir(os.path.join(output, "masks_left"))
+    create_dir(os.path.join(output, "prop_frames_npy"))
+    create_dir(os.path.join(output, "masks_left_npy"))

     assert len(frames) == masks.shape[2]
     for i in range(len(frames)):
-        cv2.imwrite(os.path.join(output, 'prop_frames',
-                    '%05d.png' % i), frames[i] * 255.)
-        cv2.imwrite(os.path.join(output, 'masks_left', '%05d.png' %
-                    i), masks[:, :, i] * 255.)
-        np.save(os.path.join(output, 'prop_frames_npy',
-                '%05d.npy' % i), frames[i] * 255.)
-        np.save(os.path.join(output, 'masks_left_npy',
-                '%05d.npy' % i), masks[:, :, i] * 255.)
+        cv2.imwrite(
+            os.path.join(output, "prop_frames", "%05d.png" % i), frames[i] * 255.0
+        )
+        cv2.imwrite(
+            os.path.join(output, "masks_left", "%05d.png" % i), masks[:, :, i] * 255.0
+        )
+        np.save(
+            os.path.join(output, "prop_frames_npy", "%05d.npy" % i), frames[i] * 255.0
+        )
+        np.save(
+            os.path.join(output, "masks_left_npy", "%05d.npy" % i),
+            masks[:, :, i] * 255.0,
+        )


 def create_dir(dir):
-    """Creates a directory if not exist.
-    """
+    """Creates a directory if not exist."""
     if not os.path.exists(dir):
         os.makedirs(dir)


 def initialize_RAFT(args, device):
-    """Initializes the RAFT model.
-    """
+    """Initializes the RAFT model."""
     model = torch.nn.DataParallel(RAFT(args))
     model.load_state_dict(torch.load(args.raft_model))

@@ -177,57 +193,65 @@ def initialize_RAFT(args, device):
 def initialize_LAFC(args, device):
     print(args.lafc_ckpts)
     assert len(os.listdir(args.lafc_ckpts)) == 2
-    checkpoint, config_file = glob.glob(os.path.join(args.lafc_ckpts, '*.tar'))[0], \
-        glob.glob(os.path.join(args.lafc_ckpts, '*.yaml'))[0]
-    with open(config_file, 'r') as f:
+    checkpoint, config_file = (
+        glob.glob(os.path.join(args.lafc_ckpts, "*.tar"))[0],
+        glob.glob(os.path.join(args.lafc_ckpts, "*.yaml"))[0],
+    )
+    with open(config_file, "r") as f:
         configs = yaml.full_load(f)
-    model = configs['model']
-    pkg = import_module('LAFC.models.{}'.format(model))
+    model = configs["model"]
+    pkg = import_module("LAFC.models.{}".format(model))
     model = pkg.Model(configs)
-    state = torch.load(checkpoint, map_location=lambda storage,
-                       loc: storage.cuda(device))
-    model.load_state_dict(state['model_state_dict'])
+    state = torch.load(
+        checkpoint, map_location=lambda storage, loc: storage.cuda(device)
+    )
+    model.load_state_dict(state["model_state_dict"])
     model = model.to(device)
     return model, configs


 def initialize_FGT(args, device):
     assert len(os.listdir(args.fgt_ckpts)) == 2
-    checkpoint, config_file = glob.glob(os.path.join(args.fgt_ckpts, '*.tar'))[0], \
-        glob.glob(os.path.join(args.fgt_ckpts, '*.yaml'))[0]
-    with open(config_file, 'r') as f:
+    checkpoint, config_file = (
+        glob.glob(os.path.join(args.fgt_ckpts, "*.tar"))[0],
+        glob.glob(os.path.join(args.fgt_ckpts, "*.yaml"))[0],
+    )
+    with open(config_file, "r") as f:
         configs = yaml.full_load(f)
-    model = configs['model']
-    net = import_module('FGT.models.{}'.format(model))
+    model = configs["model"]
+    net = import_module("FGT.models.{}".format(model))
     model = net.Model(configs).to(device)
-    state = torch.load(checkpoint, map_location=lambda storage,
-                       loc: storage.cuda(device))
-    model.load_state_dict(state['model_state_dict'])
+    state = torch.load(
+        checkpoint, map_location=lambda storage, loc: storage.cuda(device)
+    )
+    model.load_state_dict(state["model_state_dict"])
     return model, configs


 def calculate_flow(args, model, video, mode):
-    """Calculates optical flow.
-    """
-    if mode not in ['forward', 'backward']:
+    """Calculates optical flow."""
+    if mode not in ["forward", "backward"]:
         raise NotImplementedError

     imgH, imgW = args.imgH, args.imgW
     Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)

     if args.vis_flows:
-        create_dir(os.path.join(args.outroot, 'flow', mode + '_flo'))
-        create_dir(os.path.join(args.outroot, 'flow', mode + '_png'))
+        create_dir(os.path.join(args.outroot, "flow", mode + "_flo"))
+        create_dir(os.path.join(args.outroot, "flow", mode + "_png"))

     with torch.no_grad():
         for i in range(video.shape[0] - 1):
             print(
-                "Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')
-            if mode == 'forward':
+                "Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1),
+                "\r",
+                end="",
+            )
+            if mode == "forward":
                 # Flow i -> i + 1
                 image1 = video[i, None]
                 image2 = video[i + 1, None]
-            elif mode == 'backward':
+            elif mode == "backward":
                 # Flow i + 1 -> i
                 image1 = video[i + 1, None]
                 image2 = video[i, None]
@@ -240,8 +264,8 @@ def calculate_flow(args, model, video, mode):
             h, w = flow.shape[:2]
             if h != imgH or w != imgW:
                 flow = cv2.resize(flow, (imgW, imgH), cv2.INTER_LINEAR)
-                flow[:, :, 0] *= (float(imgW) / float(w))
-                flow[:, :, 1] *= (float(imgH) / float(h))
+                flow[:, :, 0] *= float(imgW) / float(w)
+                flow[:, :, 1] *= float(imgH) / float(h)

             Flow = np.concatenate((Flow, flow[..., None]), axis=-1)

@@ -251,17 +275,19 @@ def calculate_flow(args, model, video, mode):
                 flow_img = Image.fromarray(flow_img)

                 # Saves the flow and flow_img.
-                flow_img.save(os.path.join(args.outroot, 'flow',
-                              mode + '_png', '%05d.png' % i))
-                utils.frame_utils.writeFlow(os.path.join(
-                    args.outroot, 'flow', mode + '_flo', '%05d.flo' % i), flow)
+                flow_img.save(
+                    os.path.join(args.outroot, "flow", mode + "_png", "%05d.png" % i)
+                )
+                utils.frame_utils.writeFlow(
+                    os.path.join(args.outroot, "flow", mode + "_flo", "%05d.flo" % i),
+                    flow,
+                )

     return Flow


 def extrapolation(args, video_ori, corrFlowF_ori, corrFlowB_ori):
-    """Prepares the data for video extrapolation.
-    """
+    """Prepares the data for video extrapolation."""
     imgH, imgW, _, nFrame = video_ori.shape

     # Defines new FOV.
@@ -274,45 +300,56 @@ def extrapolation(args, video_ori, corrFlowF_ori, corrFlowB_ori):

     # Generates the mask for missing region.
     flow_mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.bool)
-    flow_mask[H_start: H_start + imgH, W_start: W_start + imgW] = 0
+    flow_mask[H_start : H_start + imgH, W_start : W_start + imgW] = 0

     mask_dilated = gradient_mask(flow_mask)

     # Extrapolates the FOV for video.
     video = np.zeros(((imgH_extr, imgW_extr, 3, nFrame)), dtype=np.float32)
-    video[H_start: H_start + imgH, W_start: W_start + imgW, :, :] = video_ori
+    video[H_start : H_start + imgH, W_start : W_start + imgW, :, :] = video_ori

     for i in range(nFrame):
-        print("Preparing frame {0}".format(i), '\r', end='')
-        video[:, :, :, i] = cv2.inpaint((video[:, :, :, i] * 255).astype(np.uint8), flow_mask.astype(np.uint8), 3,
-                                        cv2.INPAINT_TELEA).astype(np.float32) / 255.
+        print("Preparing frame {0}".format(i), "\r", end="")
+        video[:, :, :, i] = (
+            cv2.inpaint(
+                (video[:, :, :, i] * 255).astype(np.uint8),
+                flow_mask.astype(np.uint8),
+                3,
+                cv2.INPAINT_TELEA,
+            ).astype(np.float32)
+            / 255.0
+        )

     # Extrapolates the FOV for flow.
-    corrFlowF = np.zeros(
-        ((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32)
-    corrFlowB = np.zeros(
-        ((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32)
-    corrFlowF[H_start: H_start + imgH,
-              W_start: W_start + imgW, :] = corrFlowF_ori
-    corrFlowB[H_start: H_start + imgH,
-              W_start: W_start + imgW, :] = corrFlowB_ori
-
-    return video, corrFlowF, corrFlowB, flow_mask, mask_dilated, (W_start, H_start), (W_start + imgW, H_start + imgH)
+    corrFlowF = np.zeros(((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32)
+    corrFlowB = np.zeros(((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32)
+    corrFlowF[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowF_ori
+    corrFlowB[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowB_ori
+
+    return (
+        video,
+        corrFlowF,
+        corrFlowB,
+        flow_mask,
+        mask_dilated,
+        (W_start, H_start),
+        (W_start + imgW, H_start + imgH),
+    )


 def complete_flow(config, flow_model, flows, flow_masks, mode, device):
-    if mode not in ['forward', 'backward']:
-        raise NotImplementedError(f'Error flow mode {mode}')
+    if mode not in ["forward", "backward"]:
+        raise NotImplementedError(f"Error flow mode {mode}")
     flow_masks = np.moveaxis(flow_masks, -1, 0)  # [N, H, W]
     flows = np.moveaxis(flows, -1, 0)  # [N, H, W, 2]
     if len(flow_masks.shape) == 3:
         flow_masks = flow_masks[:, :, :, np.newaxis]
-    if mode == 'forward':
+    if mode == "forward":
         flow_masks = flow_masks[0:-1]
     else:
         flow_masks = flow_masks[1:]

-    num_flows, flow_interval = config['num_flows'], config['flow_interval']
+    num_flows, flow_interval = config["num_flows"], config["flow_interval"]

     diffused_flows = diffusion(flows, flow_masks)

@@ -329,7 +366,7 @@ def complete_flow(config, flow_model, flows, flow_masks, mode, device):
     pivot = num_flows // 2
     for i in range(t):
         indices = indicesGen(i, flow_interval, num_flows, t)
-        print('Indices: ', indices, '\r', end='')
+        print("Indices: ", indices, "\r", end="")
         cand_flows = flows[:, :, indices]
         cand_masks = flow_masks[:, :, indices]
         inputs = diffused_flows[:, :, indices]
@@ -349,19 +386,19 @@ def complete_flow(config, flow_model, flows, flow_masks, mode, device):
 def read_flow(flow_dir, video):
     nFrame, _, imgH, imgW = video.shape
     Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)
-    flows = sorted(glob.glob(os.path.join(flow_dir, '*.flo')))
+    flows = sorted(glob.glob(os.path.join(flow_dir, "*.flo")))
     for flow in flows:
         flow_data = cvbase.read_flow(flow)
         h, w = flow_data.shape[:2]
         flow_data = cv2.resize(flow_data, (imgW, imgH), cv2.INTER_LINEAR)
-        flow_data[:, :, 0] *= (float(imgW) / float(w))
-        flow_data[:, :, 1] *= (float(imgH) / float(h))
+        flow_data[:, :, 0] *= float(imgW) / float(w)
+        flow_data[:, :, 1] *= float(imgH) / float(h)
         Flow = np.concatenate((Flow, flow_data[..., None]), axis=-1)
     return Flow


 def norm_flows(flows):
-    assert len(flows.shape) == 5, 'FLow shape: {}'.format(flows.shape)
+    assert len(flows.shape) == 5, "FLow shape: {}".format(flows.shape)
     flattened_flows = flows.flatten(3)
     flow_max = torch.max(flattened_flows, dim=-1, keepdim=True)[0]
     flows = flows / flow_max.unsqueeze(-1)
@@ -369,19 +406,19 @@ def norm_flows(flows):


 def save_results(outdir, comp_frames):
-    out_dir = os.path.join(outdir, 'frames')
+    out_dir = os.path.join(outdir, "frames")
     if not os.path.exists(out_dir):
         os.makedirs(out_dir)
     for i in range(len(comp_frames)):
-        out_path = os.path.join(out_dir, '{:05d}.png'.format(i))
+        out_path = os.path.join(out_dir, "{:05d}.png".format(i))
         cv2.imwrite(out_path, comp_frames[i][:, :, ::-1])


 def video_inpainting(args, imgArr, imgMaskArr):
-    device = torch.device('cuda:{}'.format(args.gpu))
+    device = torch.device("cuda:{}".format(args.gpu))
     print(args)
     if args.opt is not None:
-        with open(args.opt, 'r') as f:
+        with open(args.opt, "r") as f:
            opts = yaml.full_load(f)

     for k in opts.keys():
@@ -412,37 +449,54 @@ def video_inpainting(args, imgArr, imgMaskArr):

     # Load video.
     video, video_flow = [], []
-    if args.mode == 'watermark_removal':
-        maskname_list = glob.glob(os.path.join(args.path_mask, '*.png')) + glob.glob(
-            os.path.join(args.path_mask, '*.jpg'))
+    if args.mode == "watermark_removal":
+        maskname_list = glob.glob(os.path.join(args.path_mask, "*.png")) + glob.glob(
+            os.path.join(args.path_mask, "*.jpg")
+        )
         assert len(filename_list) == len(maskname_list)
         for filename, maskname in zip(sorted(filename_list), sorted(maskname_list)):
-            frame = torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0,
-                                                                                              1).float().unsqueeze(0)
-            mask = torch.from_numpy(np.array(Image.open(maskname)).astype(np.uint8)).permute(2, 0,
-                                                                                             1).float().unsqueeze(0)
+            frame = (
+                torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8))
+                .permute(2, 0, 1)
+                .float()
+                .unsqueeze(0)
+            )
+            mask = (
+                torch.from_numpy(np.array(Image.open(maskname)).astype(np.uint8))
+                .permute(2, 0, 1)
+                .float()
+                .unsqueeze(0)
+            )
             mask[mask > 0] = 1
             frame = frame * (1 - mask)
-            frame = F2.upsample(frame, size=(imgH, imgW),
-                                mode='bilinear', align_corners=False)
-            frame_flow = F2.upsample(frame, size=(
-                flowH, flowW), mode='bilinear', align_corners=False)
+            frame = F2.upsample(
+                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
+            )
+            frame_flow = F2.upsample(
+                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
+            )
             video.append(frame)
             video_flow.append(frame_flow)
     else:
-        '''for filename in sorted(filename_list):
-        frame = torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0, 1).float().unsqueeze(0)
-        frame = F2.upsample(frame, size=(imgH, imgW), mode='bilinear', align_corners=False)
-        frame_flow = F2.upsample(frame, size=(flowH, flowW), mode='bilinear', align_corners=False)
-        video.append(frame)
-        video_flow.append(frame_flow)'''
+        """for filename in sorted(filename_list):
+        frame = torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0, 1).float().unsqueeze(0)
+        frame = F2.upsample(frame, size=(imgH, imgW), mode='bilinear', align_corners=False)
+        frame_flow = F2.upsample(frame, size=(flowH, flowW), mode='bilinear', align_corners=False)
+        video.append(frame)
+        video_flow.append(frame_flow)"""
         for im in imgArr:
-            frame = torch.from_numpy(np.array(im).astype(
-                np.uint8)).permute(2, 0, 1).float().unsqueeze(0)
-            frame = F2.upsample(frame, size=(imgH, imgW),
-                                mode='bilinear', align_corners=False)
-            frame_flow = F2.upsample(frame, size=(
-                flowH, flowW), mode='bilinear', align_corners=False)
+            frame = (
+                torch.from_numpy(np.array(im).astype(np.uint8))
+                .permute(2, 0, 1)
+                .float()
+                .unsqueeze(0)
+            )
+            frame = F2.upsample(
+                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
+            )
+            frame_flow = F2.upsample(
+                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
+            )
            video.append(frame)
            video_flow.append(frame_flow)

@@ -454,20 +508,26 @@ def video_inpainting(args, imgArr, imgMaskArr):

     # Calcutes the corrupted flow.
     forward_flows = calculate_flow(
-        args, RAFT_model, video_flow, 'forward')  # [B, C, 2, N]
-    backward_flows = calculate_flow(args, RAFT_model, video_flow, 'backward')
+        args, RAFT_model, video_flow, "forward"
+    )  # [B, C, 2, N]
+    backward_flows = calculate_flow(args, RAFT_model, video_flow, "backward")

     # Makes sure video is in BGR (opencv) format.
-    video = video.permute(2, 3, 1, 0).cpu().numpy()[
-        :, :, ::-1, :] / 255.  # np array -> [h, w, c, N] (0~1)
-
-    if args.mode == 'video_extrapolation':
+    video = (
+        video.permute(2, 3, 1, 0).cpu().numpy()[:, :, ::-1, :] / 255.0
+    )  # np array -> [h, w, c, N] (0~1)

+    if args.mode == "video_extrapolation":
         # Creates video and flow where the extrapolated region are missing.
-        video, forward_flows, backward_flows, flow_mask, mask_dilated, start_point, end_point = extrapolation(args,
-                                                                                                              video,
-                                                                                                              forward_flows,
-                                                                                                              backward_flows)
+        (
+            video,
+            forward_flows,
+            backward_flows,
+            flow_mask,
+            mask_dilated,
+            start_point,
+            end_point,
+        ) = extrapolation(args, video, forward_flows, backward_flows)
         imgH, imgW = video.shape[:2]

         # mask indicating the missing region in the video.
@@ -477,13 +537,14 @@ def video_inpainting(args, imgArr, imgMaskArr):

     else:
         # Loads masks.
-        filename_list = glob.glob(os.path.join(args.path_mask, '*.png')) + \
-            glob.glob(os.path.join(args.path_mask, '*.jpg'))
+        filename_list = glob.glob(os.path.join(args.path_mask, "*.png")) + glob.glob(
+            os.path.join(args.path_mask, "*.jpg")
+        )

         mask = []
         mask_dilated = []
         flow_mask = []
-        '''for filename in sorted(filename_list):
+        """for filename in sorted(filename_list):
         mask_img = np.array(Image.open(filename).convert('L'))
         mask_img = cv2.resize(mask_img, dsize=(imgW, imgH), interpolation=cv2.INTER_NEAREST)

@@ -496,23 +557,26 @@ def video_inpainting(args, imgArr, imgMaskArr):
         if args.frame_dilates > 0:
             mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=args.frame_dilates)
         mask.append(mask_img)
-        mask_dilated.append(gradient_mask(mask_img))'''
+        mask_dilated.append(gradient_mask(mask_img))"""

         for f_mask in imgMaskArr:
             mask_img = np.array(f_mask)
-            mask_img = cv2.resize(mask_img, dsize=(
-                imgW, imgH), interpolation=cv2.INTER_NEAREST)
+            mask_img = cv2.resize(
+                mask_img, dsize=(imgW, imgH), interpolation=cv2.INTER_NEAREST
+            )

             if args.flow_mask_dilates > 0:
                 flow_mask_img = scipy.ndimage.binary_dilation(
-                    mask_img, iterations=args.flow_mask_dilates)
+                    mask_img, iterations=args.flow_mask_dilates
+                )
             else:
                 flow_mask_img = mask_img
             flow_mask.append(flow_mask_img)

             if args.frame_dilates > 0:
                 mask_img = scipy.ndimage.binary_dilation(
-                    mask_img, iterations=args.frame_dilates)
+                    mask_img, iterations=args.frame_dilates
+                )
             mask.append(mask_img)
             mask_dilated.append(gradient_mask(mask_img))

@@ -523,12 +587,14 @@ def video_inpainting(args, imgArr, imgMaskArr):

     # Completes the flow.
     videoFlowF = complete_flow(
-        LAFC_config, LAFC_model, forward_flows, flow_mask, 'forward', device)
+        LAFC_config, LAFC_model, forward_flows, flow_mask, "forward", device
+    )
     videoFlowB = complete_flow(
-        LAFC_config, LAFC_model, backward_flows, flow_mask, 'backward', device)
+        LAFC_config, LAFC_model, backward_flows, flow_mask, "backward", device
+    )
     videoFlowF = tensor2np(videoFlowF)
     videoFlowB = tensor2np(videoFlowB)
-    print('\nFinish flow completion.')
+    print("\nFinish flow completion.")

     if args.vis_completed_flows:
         save_flows(args.outroot, videoFlowF, videoFlowB)
@@ -540,17 +606,28 @@ def video_inpainting(args, imgArr, imgMaskArr):
     for indFrame in range(nFrame):
         img = video[:, :, :, indFrame]
         img[mask[:, :, indFrame], :] = 0
-        img = cv2.inpaint((img * 255).astype(np.uint8), mask[:, :, indFrame].astype(np.uint8), 3,
-                          cv2.INPAINT_TELEA).astype(np.float32) / 255.
-
-        gradient_x_ = np.concatenate((np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)),
-                                     axis=1)
+        img = (
+            cv2.inpaint(
+                (img * 255).astype(np.uint8),
+                mask[:, :, indFrame].astype(np.uint8),
+                3,
+                cv2.INPAINT_TELEA,
+            ).astype(np.float32)
+            / 255.0
+        )
+
+        gradient_x_ = np.concatenate(
+            (np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1
+        )
         gradient_y_ = np.concatenate(
-            (np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0)
+            (np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0
+        )
         gradient_x = np.concatenate(
-            (gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1)
+            (gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1
+        )
         gradient_y = np.concatenate(
-            (gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1)
+            (gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1
+        )

         gradient_x[mask_dilated[:, :, indFrame], :, indFrame] = 0
         gradient_y[mask_dilated[:, :, indFrame], :, indFrame] = 0
@@ -561,21 +638,23 @@ def video_inpainting(args, imgArr, imgMaskArr):
     video_comp = video

     # Gradient propagation.
-    gradient_x_filled, gradient_y_filled, mask_gradient = \
-        get_flowNN_gradient(args,
-                            gradient_x_filled,
-                            gradient_y_filled,
-                            mask,
-                            mask_gradient,
-                            videoFlowF,
-                            videoFlowB,
-                            None,
-                            None)
+    gradient_x_filled, gradient_y_filled, mask_gradient = get_flowNN_gradient(
+        args,
+        gradient_x_filled,
+        gradient_y_filled,
+        mask,
+        mask_gradient,
+        videoFlowF,
+        videoFlowB,
+        None,
+        None,
+    )

     # if there exist holes in mask, Poisson blending will fail. So I did this trick. I sacrifice some value. Another solution is to modify Poisson blending.
     for indFrame in range(nFrame):
-        mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(mask_gradient[:, :, indFrame]).astype(
-            np.bool)
+        mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(
+            mask_gradient[:, :, indFrame]
+        ).astype(np.bool)

     # After one gradient propagation iteration
     # gradient --> RGB
@@ -585,19 +664,29 @@ def video_inpainting(args, imgArr, imgMaskArr):

         if mask[:, :, indFrame].sum() > 0:
             try:
-                frameBlend, UnfilledMask = Poisson_blend_img(video_comp[:, :, :, indFrame],
-                                                             gradient_x_filled[:,
-                                                                               0: imgW - 1, :, indFrame],
-                                                             gradient_y_filled[0: imgH -
-                                                                               1, :, :, indFrame],
-                                                             mask[:, :, indFrame], mask_gradient[:, :, indFrame])
+                frameBlend, UnfilledMask = Poisson_blend_img(
+                    video_comp[:, :, :, indFrame],
+                    gradient_x_filled[:, 0 : imgW - 1, :, indFrame],
+                    gradient_y_filled[0 : imgH - 1, :, :, indFrame],
+                    mask[:, :, indFrame],
+                    mask_gradient[:, :, indFrame],
+                )
             except:
-                frameBlend, UnfilledMask = video_comp[:,
-                                                      :, :, indFrame], mask[:, :, indFrame]
+                frameBlend, UnfilledMask = (
+                    video_comp[:, :, :, indFrame],
+                    mask[:, :, indFrame],
+                )

             frameBlend = np.clip(frameBlend, 0, 1.0)
-            tmp = cv2.inpaint((frameBlend * 255).astype(np.uint8), UnfilledMask.astype(np.uint8), 3,
-                              cv2.INPAINT_TELEA).astype(np.float32) / 255.
+            tmp = (
+                cv2.inpaint(
+                    (frameBlend * 255).astype(np.uint8),
+                    UnfilledMask.astype(np.uint8),
+                    3,
+                    cv2.INPAINT_TELEA,
+                ).astype(np.float32)
+                / 255.0
+            )
             frameBlend[UnfilledMask, :] = tmp[UnfilledMask, :]

             video_comp[:, :, :, indFrame] = frameBlend
@@ -605,7 +694,7 @@ def video_inpainting(args, imgArr, imgMaskArr):

             frameBlend_ = copy.deepcopy(frameBlend)
             # Green indicates the regions that are not filled yet.
-            frameBlend_[mask[:, :, indFrame], :] = [0, 1., 0]
+            frameBlend_[mask[:, :, indFrame], :] = [0, 1.0, 0]
         else:
             frameBlend_ = video_comp[:, :, :, indFrame]
         frameBlends.append(frameBlend_)
@@ -618,10 +707,10 @@ def video_inpainting(args, imgArr, imgMaskArr):
     for i in range(len(frameBlends)):
         frameBlends[i] = frameBlends[i][:, :, ::-1]

-    frames_first = np2tensor(frameBlends, near='t').to(device)
+    frames_first = np2tensor(frameBlends, near="t").to(device)
     mask = np.moveaxis(mask, -1, 0)
     mask = mask[:, :, :, np.newaxis]
-    masks = np2tensor(mask, near='t').to(device)
+    masks = np2tensor(mask, near="t").to(device)
     normed_frames = frames_first * 2 - 1
     comp_frames = [None] * video_length

@@ -633,115 +722,155 @@ def video_inpainting(args, imgArr, imgMaskArr):

     videoFlowF = np.concatenate([videoFlowF, videoFlowF[-1:, ...]], axis=0)

-    flows = np2tensor(videoFlowF, near='t')
+    flows = np2tensor(videoFlowF, near="t")
     flows = norm_flows(flows).to(device)

     for f in range(0, video_length, neighbor_stride):
-        neighbor_ids = [i for i in range(
-            max(0, f - neighbor_stride), min(video_length, f + neighbor_stride + 1))]
-        ref_ids = get_ref_index(
-            f, neighbor_ids, video_length, ref_length, num_ref)
+        neighbor_ids = [
+            i
+            for i in range(
+                max(0, f - neighbor_stride), min(video_length, f + neighbor_stride + 1)
+            )
+        ]
+        ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_length, num_ref)
         print(f, len(neighbor_ids), len(ref_ids))
         selected_frames = normed_frames[:, neighbor_ids + ref_ids]
         selected_masks = masks[:, neighbor_ids + ref_ids]
         masked_frames = selected_frames * (1 - selected_masks)
         selected_flows = flows[:, neighbor_ids + ref_ids]
         with torch.no_grad():
-            filled_frames = FGT_model(
-                masked_frames, selected_flows, selected_masks)
+            filled_frames = FGT_model(masked_frames, selected_flows, selected_masks)
         filled_frames = (filled_frames + 1) / 2
         filled_frames = filled_frames.cpu().permute(0, 2, 3, 1).numpy() * 255
         for i in range(len(neighbor_ids)):
             idx = neighbor_ids[i]
-            valid_frame = frames_first[0, idx].cpu().permute(
-                1, 2, 0).numpy() * 255.
+            valid_frame = frames_first[0, idx].cpu().permute(1, 2, 0).numpy() * 255.0
             valid_mask = masks[0, idx].cpu().permute(1, 2, 0).numpy()
-            comp = np.array(filled_frames[i]).astype(np.uint8) * valid_mask + \
-                np.array(valid_frame).astype(np.uint8) * (1 - valid_mask)
+            comp = np.array(filled_frames[i]).astype(np.uint8) * valid_mask + np.array(
+                valid_frame
+            ).astype(np.uint8) * (1 - valid_mask)
             if comp_frames[idx] is None:
                 comp_frames[idx] = comp
             else:
-                comp_frames[idx] = comp_frames[idx].astype(
-                    np.float32) * 0.5 + comp.astype(np.float32) * 0.5
+                comp_frames[idx] = (
+                    comp_frames[idx].astype(np.float32) * 0.5
+                    + comp.astype(np.float32) * 0.5
+                )
     if args.vis_frame:
         save_results(args.outroot, comp_frames)
     create_dir(args.outroot)
     for i in range(len(comp_frames)):
         comp_frames[i] = comp_frames[i].astype(np.uint8)
-    imageio.mimwrite(os.path.join(args.outroot, 'result.mp4'),
-                     comp_frames, fps=30, quality=8)
-    print(f'Done, please check your result in {args.outroot} ')
+    imageio.mimwrite(
+        os.path.join(args.outroot, "result.mp4"), comp_frames, fps=30, quality=8
+    )
+    print(f"Done, please check your result in {args.outroot} ")


 def main(args):
-    assert args.mode in ('object_removal', 'video_extrapolation', 'watermark_removal'), (
+    assert args.mode in (
+        "object_removal",
+        "video_extrapolation",
+        "watermark_removal",
+    ), (
         "Accepted modes: 'object_removal', 'video_extrapolation', and 'watermark_removal', but input is %s"
     ) % args.mode
     video_inpainting(args)


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--opt', default='configs/object_removal.yaml',
-                        help='Please select your config file for inference')
+    parser.add_argument(
+        "--opt",
+        default="configs/object_removal.yaml",
+        help="Please select your config file for inference",
+    )
     # video completion
-    parser.add_argument('--mode', default='object_removal', choices=[
-        'object_removal', 'watermark_removal', 'video_extrapolation'], help="modes: object_removal / video_extrapolation")
     parser.add_argument(
-        '--path', default='/myData/davis_resized/walking', help="dataset for evaluation")
+        "--mode",
+        default="object_removal",
+        choices=["object_removal", "watermark_removal", "video_extrapolation"],
+        help="modes: object_removal / video_extrapolation",
+    )
+    parser.add_argument(
+        "--path", default="/myData/davis_resized/walking", help="dataset for evaluation"
+    )
     parser.add_argument(
-        '--path_mask', default='/myData/dilateAnnotations_4/walking', help="mask for object removal")
+        "--path_mask",
+        default="/myData/dilateAnnotations_4/walking",
+        help="mask for object removal",
+    )
     parser.add_argument(
-        '--outroot', default='quick_start/walking3', help="output directory")
-    parser.add_argument('--consistencyThres', dest='consistencyThres', default=5, type=float,
-                        help='flow consistency error threshold')
-    parser.add_argument('--alpha', dest='alpha', default=0.1, type=float)
-    parser.add_argument('--Nonlocal', dest='Nonlocal',
-                        default=False, type=bool)
+        "--outroot", default="quick_start/walking3", help="output directory"
+    )
+    parser.add_argument(
+        "--consistencyThres",
+        dest="consistencyThres",
+        default=5,
+        type=float,
+        help="flow consistency error threshold",
+    )
+    parser.add_argument("--alpha", dest="alpha", default=0.1, type=float)
+    parser.add_argument("--Nonlocal", dest="Nonlocal", default=False, type=bool)

     # RAFT
     parser.add_argument(
-        '--raft_model', default='../LAFC/flowCheckPoint/raft-things.pth', help="restore checkpoint")
-    parser.add_argument('--small', action='store_true', help='use small model')
-    parser.add_argument('--mixed_precision',
-                        action='store_true', help='use mixed precision')
-    parser.add_argument('--alternate_corr', action='store_true',
-                        help='use efficent correlation implementation')
+        "--raft_model",
+        default="../LAFC/flowCheckPoint/raft-things.pth",
+        help="restore checkpoint",
+    )
+    parser.add_argument("--small", action="store_true", help="use small model")
+    parser.add_argument(
+        "--mixed_precision", action="store_true", help="use mixed precision"
+    )
+    parser.add_argument(
+        "--alternate_corr",
+        action="store_true",
+        help="use efficent correlation implementation",
+    )

     # LAFC
-    parser.add_argument('--lafc_ckpts', type=str, default='../LAFC/checkpoint')
+    parser.add_argument("--lafc_ckpts", type=str, default="../LAFC/checkpoint")

     # FGT
-    parser.add_argument('--fgt_ckpts', type=str, default='../FGT/checkpoint')
+    parser.add_argument("--fgt_ckpts", type=str, default="../FGT/checkpoint")

     # extrapolation
-    parser.add_argument('--H_scale', dest='H_scale', default=2,
-                        type=float, help='H extrapolation scale')
-    parser.add_argument('--W_scale', dest='W_scale', default=2,
-                        type=float, help='W extrapolation scale')
+    parser.add_argument(
+        "--H_scale", dest="H_scale", default=2, type=float, help="H extrapolation scale"
+    )
+    parser.add_argument(
+        "--W_scale", dest="W_scale", default=2, type=float, help="W extrapolation scale"
+    )

     # Image basic information
-    parser.add_argument('--imgH', type=int, default=256)
-    parser.add_argument('--imgW', type=int, default=432)
-    parser.add_argument('--flow_mask_dilates', type=int, default=8)
-    parser.add_argument('--frame_dilates', type=int, default=0)
+    parser.add_argument("--imgH", type=int, default=256)
+    parser.add_argument("--imgW", type=int, default=432)
+    parser.add_argument("--flow_mask_dilates", type=int, default=8)
+    parser.add_argument("--frame_dilates", type=int, default=0)

-    parser.add_argument('--gpu', type=int, default=0)
+    parser.add_argument("--gpu", type=int, default=0)

     # FGT inference parameters
-    parser.add_argument('--step', type=int, default=10)
-    parser.add_argument('--num_ref', type=int, default=-1)
-    parser.add_argument('--neighbor_stride', type=int, default=5)
+    parser.add_argument("--step", type=int, default=10)
+    parser.add_argument("--num_ref", type=int, default=-1)
+    parser.add_argument("--neighbor_stride", type=int, default=5)

     # visualization
-    parser.add_argument('--vis_flows', action='store_true',
-                        help='Visualize the initialized flows')
-    parser.add_argument('--vis_completed_flows',
-                        action='store_true', help='Visualize the completed flows')
-    parser.add_argument('--vis_prop', action='store_true',
-                        help='Visualize the frames after stage-I filling (flow guided content propagation)')
-    parser.add_argument('--vis_frame', action='store_true',
-                        help='Visualize frames')
+    parser.add_argument(
+        "--vis_flows", action="store_true", help="Visualize the initialized flows"
+    )
+    parser.add_argument(
+        "--vis_completed_flows",
+        action="store_true",
+        help="Visualize the completed flows",
+    )
+    parser.add_argument(
+        "--vis_prop",
+        action="store_true",
+        help="Visualize the frames after stage-I filling (flow guided content propagation)",
+    )
+    parser.add_argument("--vis_frame", action="store_true", help="Visualize frames")

     args = parser.parse_args()
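For reference, a hedged sketch of driving the updated entry point from another module, as a hosting app presumably does: `video_inpainting(args, imgArr, imgMaskArr)` now takes in-memory frames and masks, while `main()` still forwards only `args`, so callers are expected to invoke the function directly. The `Namespace` fields below simply mirror the argparse defaults shown above; the file paths and frame counts are placeholders, and the import works only if `FGT_codes/tool` itself is importable (e.g., it is the working directory), since the new bootstrap adds the repo root, `FGT`, and `LAFC`, but not `tool`:

```python
import argparse
from PIL import Image

import video_inpainting as vi  # assumes FGT_codes/tool is importable

# Stand-in for parsed CLI args; values mirror the parser defaults in the diff.
# A YAML config passed via `opt` may override some of these at runtime.
args = argparse.Namespace(
    opt="configs/object_removal.yaml", mode="object_removal",
    path="/myData/davis_resized/walking",
    path_mask="/myData/dilateAnnotations_4/walking",
    outroot="quick_start/walking3",
    consistencyThres=5.0, alpha=0.1, Nonlocal=False,
    raft_model="../LAFC/flowCheckPoint/raft-things.pth",
    small=False, mixed_precision=False, alternate_corr=False,
    lafc_ckpts="../LAFC/checkpoint", fgt_ckpts="../FGT/checkpoint",
    H_scale=2, W_scale=2, imgH=256, imgW=432,
    flow_mask_dilates=8, frame_dilates=0, gpu=0,
    step=10, num_ref=-1, neighbor_stride=5,
    vis_flows=False, vis_completed_flows=False, vis_prop=False, vis_frame=False,
)

# Placeholder inputs: equal-length lists of RGB frames and single-channel masks.
frames = [Image.open(f"frames/{i:05d}.png").convert("RGB") for i in range(2)]
masks = [Image.open(f"masks/{i:05d}.png").convert("L") for i in range(2)]

vi.video_inpainting(args, frames, masks)  # requires a CUDA device (args.gpu)
```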