goryhon committed · verified
Commit dbca871 · 1 Parent(s): cf24190

Update web-demos/hugging_face/inpainter/base_inpainter.py

web-demos/hugging_face/inpainter/base_inpainter.py CHANGED
@@ -20,367 +20,367 @@ warnings.filterwarnings("ignore")

def imwrite(img, file_path, params=None, auto_mkdir=True):
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)

def resize_frames(frames, size=None):
    if size is not None:
        out_size = size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        frames = [f.resize(process_size) for f in frames]
    else:
        out_size = frames[0].size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        if out_size != process_size:
            frames = [f.resize(process_size) for f in frames]

    return frames, process_size, out_size
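
# Worked example (hypothetical numbers): size=(854, 480) yields process_size=(848, 480),
# since 854 % 8 == 6; frames are processed at 848x480 and callers can resize results back
# to out_size=(854, 480). A size like (432, 240) is already a multiple of 8 and passes
# through unchanged.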

def read_frame_from_videos(frame_root):
    if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):  # input video path
        video_name = os.path.basename(frame_root)[:-4]
        vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec')  # RGB
        frames = list(vframes.numpy())
        frames = [Image.fromarray(f) for f in frames]
        fps = info['video_fps']
    else:
        video_name = os.path.basename(frame_root)
        frames = []
        fr_lst = sorted(os.listdir(frame_root))
        for fr in fr_lst:
            frame = cv2.imread(os.path.join(frame_root, fr))
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(frame)
        fps = None
    size = frames[0].size

    return frames, fps, size, video_name

def binary_mask(mask, th=0.1):
    mask[mask > th] = 1
    mask[mask <= th] = 0
    return mask

def extrapolation(video_ori, scale):
    """Prepares the data for video outpainting."""
    nFrame = len(video_ori)
    imgW, imgH = video_ori[0].size

    # Defines new FOV.
    imgH_extr = int(scale[0] * imgH)
    imgW_extr = int(scale[1] * imgW)
    imgH_extr = imgH_extr - imgH_extr % 8
    imgW_extr = imgW_extr - imgW_extr % 8
    H_start = int((imgH_extr - imgH) / 2)
    W_start = int((imgW_extr - imgW) / 2)

    # Extrapolates the FOV for video.
    frames = []
    for v in video_ori:
        frame = np.zeros((imgH_extr, imgW_extr, 3), dtype=np.uint8)
        frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v
        frames.append(Image.fromarray(frame))

    # Generates the mask for missing region.
    masks_dilated = []
    flow_masks = []

    dilate_h = 4 if H_start > 10 else 0
    dilate_w = 4 if W_start > 10 else 0
    mask = np.ones((imgH_extr, imgW_extr), dtype=np.uint8)

    mask[H_start + dilate_h: H_start + imgH - dilate_h,
         W_start + dilate_w: W_start + imgW - dilate_w] = 0
    flow_masks.append(Image.fromarray(mask * 255))

    mask[H_start: H_start + imgH, W_start: W_start + imgW] = 0
    masks_dilated.append(Image.fromarray(mask * 255))

    flow_masks = flow_masks * nFrame
    masks_dilated = masks_dilated * nFrame

    return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)

def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
    ref_index = []
    if ref_num == -1:
        for i in range(0, length, ref_stride):
            if i not in neighbor_ids:
                ref_index.append(i)
    else:
        start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
        end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
        for i in range(start_idx, end_idx, ref_stride):
            if i not in neighbor_ids:
                if len(ref_index) > ref_num:
                    break
                ref_index.append(i)
    return ref_index
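
# Example (hypothetical numbers): length=80, ref_stride=10, ref_num=-1 and
# neighbor_ids = [0, 1, ..., 10] yield global reference frames [20, 30, 40, 50, 60, 70].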

def read_mask_demo(masks, length, size, flow_mask_dilates=8, mask_dilates=5):
    masks_img = []
    masks_dilated = []
    flow_masks = []

    for mp in masks:
        masks_img.append(Image.fromarray(mp.astype('uint8')))

    for mask_img in masks_img:
        if size is not None:
            mask_img = mask_img.resize(size, Image.NEAREST)
        mask_img = np.array(mask_img.convert('L'))

        # Dilate 8 pixels so that all known pixels are trustworthy
        if flow_mask_dilates > 0:
            flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
        else:
            flow_mask_img = binary_mask(mask_img).astype(np.uint8)

        flow_masks.append(Image.fromarray(flow_mask_img * 255))

        if mask_dilates > 0:
            mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
        else:
            mask_img = binary_mask(mask_img).astype(np.uint8)
        masks_dilated.append(Image.fromarray(mask_img * 255))

    if len(masks_img) == 1:
        flow_masks = flow_masks * length
        masks_dilated = masks_dilated * length

    return flow_masks, masks_dilated

class ProInpainter:
    def __init__(self, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, device="cuda:0", use_half=True):
        self.device = device
        self.use_half = use_half
        if self.device == torch.device('cpu'):
            self.use_half = False

        ##############################################
        # set up RAFT and flow completion model
        ##############################################
        self.fix_raft = RAFT_bi(raft_checkpoint, self.device)

        self.fix_flow_complete = RecurrentFlowCompleteNet(flow_completion_checkpoint)
        for p in self.fix_flow_complete.parameters():
            p.requires_grad = False
        self.fix_flow_complete.to(self.device)
        self.fix_flow_complete.eval()

        ##############################################
        # set up ProPainter model
        ##############################################
        self.model = InpaintGenerator(model_path=propainter_checkpoint).to(self.device)
        self.model.eval()

        if self.use_half:
            self.fix_flow_complete = self.fix_flow_complete.half()
            self.model = self.model.half()

    def inpaint(self, npframes, masks, ratio=1.0, dilate_radius=4, raft_iter=20, subvideo_length=80, neighbor_length=10, ref_stride=10):
        """
        Perform inpainting for video subsets.

        Output:
            inpainted_frames: list of T numpy arrays, each H, W, 3
        """

        frames = []
        for i in range(len(npframes)):
            frames.append(Image.fromarray(npframes[i].astype('uint8'), mode="RGB"))
        del npframes

        # Get the original size
        size = frames[0].size  # (width, height)

        # Apply ratio only if it differs from 1.0
        if ratio != 1.0:
            size = (int(ratio * size[0]) // 2 * 2, int(ratio * size[1]) // 2 * 2)
        else:
            size = (size[0] // 2 * 2, size[1] // 2 * 2)  # just round down to the nearest even number

        frames_len = len(frames)

        # ⚠️ resize_frames processes at a multiple-of-8 size; out_size keeps the even size requested above
        frames, size, out_size = resize_frames(frames, size)

        flow_masks, masks_dilated = read_mask_demo(masks, frames_len, size, dilate_radius, dilate_radius)
        w, h = size

        frames_inp = [np.array(f).astype(np.uint8) for f in frames]

        frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
        flow_masks = to_tensors()(flow_masks).unsqueeze(0)
        masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)

        frames = frames.to(self.device)
        flow_masks = flow_masks.to(self.device)
        masks_dilated = masks_dilated.to(self.device)

        ##############################################
        # ProPainter inference
        ##############################################
        video_length = frames.size(1)
        with torch.no_grad():
            # ---- compute flow ----
            if frames.size(-1) <= 640:
                short_clip_len = 12
            elif frames.size(-1) <= 720:
                short_clip_len = 8
            elif frames.size(-1) <= 1280:
                short_clip_len = 4
            else:
                short_clip_len = 2

            # use fp32 for RAFT
            if frames.size(1) > short_clip_len:
                gt_flows_f_list, gt_flows_b_list = [], []
                for f in range(0, video_length, short_clip_len):
                    end_f = min(video_length, f + short_clip_len)
                    if f == 0:
                        flows_f, flows_b = self.fix_raft(frames[:, f:end_f], iters=raft_iter)
                    else:
                        flows_f, flows_b = self.fix_raft(frames[:, f - 1:end_f], iters=raft_iter)

                    gt_flows_f_list.append(flows_f)
                    gt_flows_b_list.append(flows_b)
                    torch.cuda.empty_cache()

                gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
                gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
                gt_flows_bi = (gt_flows_f, gt_flows_b)
            else:
                gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
                torch.cuda.empty_cache()

            if self.use_half:
                frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
                gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())

            # ---- complete flow ----
            flow_length = gt_flows_bi[0].size(1)
            if flow_length > subvideo_length:
                pred_flows_f, pred_flows_b = [], []
                pad_len = 5
                for f in range(0, flow_length, subvideo_length):
                    s_f = max(0, f - pad_len)
                    e_f = min(flow_length, f + subvideo_length + pad_len)
                    pad_len_s = max(0, f) - s_f
                    pad_len_e = e_f - min(flow_length, f + subvideo_length)
                    pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
                        (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                        flow_masks[:, s_f:e_f + 1])
                    pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
                        (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                        pred_flows_bi_sub,
                        flow_masks[:, s_f:e_f + 1])

                    pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f - s_f - pad_len_e])
                    pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f - s_f - pad_len_e])
                    torch.cuda.empty_cache()

                pred_flows_f = torch.cat(pred_flows_f, dim=1)
                pred_flows_b = torch.cat(pred_flows_b, dim=1)
                pred_flows_bi = (pred_flows_f, pred_flows_b)
            else:
                pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
                pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
                torch.cuda.empty_cache()
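
            # Chunking above (hypothetical numbers): flow_length=150, subvideo_length=80,
            # pad_len=5 processes [0:85] and [75:150], then trims the padding so the
            # concatenated result covers exactly [0:80] and [80:150].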

            # ---- image propagation ----
            masked_frames = frames * (1 - masks_dilated)
            subvideo_length_img_prop = min(100, subvideo_length)  # cap each image-propagation sub-clip at 100 frames
            if video_length > subvideo_length_img_prop:
                updated_frames, updated_masks = [], []
                pad_len = 10
                for f in range(0, video_length, subvideo_length_img_prop):
                    s_f = max(0, f - pad_len)
                    e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
                    pad_len_s = max(0, f) - s_f
                    pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)

                    b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
                    pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f - 1], pred_flows_bi[1][:, s_f:e_f - 1])
                    prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
                                                                                        pred_flows_bi_sub,
                                                                                        masks_dilated[:, s_f:e_f],
                                                                                        'nearest')
                    updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
                        prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
                    updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)

                    updated_frames.append(updated_frames_sub[:, pad_len_s:e_f - s_f - pad_len_e])
                    updated_masks.append(updated_masks_sub[:, pad_len_s:e_f - s_f - pad_len_e])
                    torch.cuda.empty_cache()

                updated_frames = torch.cat(updated_frames, dim=1)
                updated_masks = torch.cat(updated_masks, dim=1)
            else:
                b, t, _, _, _ = masks_dilated.size()
                prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
                updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
                updated_masks = updated_local_masks.view(b, t, 1, h, w)
                torch.cuda.empty_cache()

        ori_frames = frames_inp
        comp_frames = [None] * video_length

        neighbor_stride = neighbor_length // 2
        if video_length > subvideo_length:
            ref_num = subvideo_length // ref_stride
        else:
            ref_num = -1

        # ---- feature propagation + transformer ----
        # windows overlap by half of neighbor_length; overlapping results are blended 50/50 below
        for f in tqdm(range(0, video_length, neighbor_stride)):
            neighbor_ids = [
                i for i in range(max(0, f - neighbor_stride),
                                 min(video_length, f + neighbor_stride + 1))
            ]
            ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
            selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
            selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
            selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
            selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])

            with torch.no_grad():
                # 1.0 indicates mask
                l_t = len(neighbor_ids)

                # pred_img = selected_imgs  # results of image propagation
                pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)

                pred_img = pred_img.view(-1, 3, h, w)

                pred_img = (pred_img + 1) / 2
                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
                    0, 2, 3, 1).numpy().astype(np.uint8)
                for i in range(len(neighbor_ids)):
                    idx = neighbor_ids[i]
                    img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
                        + ori_frames[idx] * (1 - binary_masks[i])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5

                    comp_frames[idx] = comp_frames[idx].astype(np.uint8)

            torch.cuda.empty_cache()

        # resize back to the requested output size: a list of T numpy arrays, each H, W, 3
        comp_frames = [cv2.resize(f, out_size) for f in comp_frames]

        return comp_frames
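
For reference, a minimal usage sketch of the class this file defines, not part of the commit. The import path and checkpoint paths are placeholders, and the dummy frames/masks only illustrate the expected shapes (T frames of H x W x 3 uint8, one binary H x W mask per frame):

import numpy as np

from inpainter.base_inpainter import ProInpainter  # hypothetical import path

# Placeholder checkpoint paths -- substitute the real weight files.
inpainter = ProInpainter(
    propainter_checkpoint="weights/ProPainter.pth",
    raft_checkpoint="weights/raft-things.pth",
    flow_completion_checkpoint="weights/recurrent_flow_completion.pth",
    device="cuda:0",
    use_half=True,
)

# 24 dummy frames of 240x432 RGB, with one binary mask per frame marking a box to inpaint.
frames = np.zeros((24, 240, 432, 3), dtype=np.uint8)
mask = np.zeros((240, 432), dtype=np.uint8)
mask[88:152, 184:248] = 1
masks = [mask] * 24

# Returns a list of 24 numpy arrays, each 240 x 432 x 3.
comp_frames = inpainter.inpaint(frames, masks, ratio=1.0)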
 