prithivMLmods commited on
Commit
f7164a1
·
verified ·
1 Parent(s): bc74534

Update src/core.py

Browse files
Files changed (1) hide show
  1. src/core.py +465 -465
src/core.py CHANGED
@@ -1,466 +1,466 @@
1
- import base64
2
- import json
3
- import os
4
- import re
5
- import time
6
- import uuid
7
- from io import BytesIO
8
- from pathlib import Path
9
- import cv2
10
-
11
- # For inpainting
12
-
13
- import numpy as np
14
- import pandas as pd
15
- import streamlit as st
16
- from PIL import Image
17
- from streamlit_drawable_canvas import st_canvas
18
-
19
-
20
- import argparse
21
- import io
22
- import multiprocessing
23
- from typing import Union
24
-
25
- import torch
26
-
27
- try:
28
- torch._C._jit_override_can_fuse_on_cpu(False)
29
- torch._C._jit_override_can_fuse_on_gpu(False)
30
- torch._C._jit_set_texpr_fuser_enabled(False)
31
- torch._C._jit_set_nvfuser_enabled(False)
32
- except:
33
- pass
34
-
35
- from src.helper import (
36
- download_model,
37
- load_img,
38
- norm_img,
39
- numpy_to_bytes,
40
- pad_img_to_modulo,
41
- resize_max_size,
42
- )
43
-
44
- NUM_THREADS = str(multiprocessing.cpu_count())
45
-
46
- os.environ["OMP_NUM_THREADS"] = NUM_THREADS
47
- os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
48
- os.environ["MKL_NUM_THREADS"] = NUM_THREADS
49
- os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
50
- os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
51
- if os.environ.get("CACHE_DIR"):
52
- os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
53
-
54
- #BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
55
-
56
- # For Seam-carving
57
-
58
- from scipy import ndimage as ndi
59
-
60
- SEAM_COLOR = np.array([255, 200, 200]) # seam visualization color (BGR)
61
- SHOULD_DOWNSIZE = True # if True, downsize image for faster carving
62
- DOWNSIZE_WIDTH = 500 # resized image width if SHOULD_DOWNSIZE is True
63
- ENERGY_MASK_CONST = 100000.0 # large energy value for protective masking
64
- MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
65
- USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
66
-
67
- device = torch.device("cpu")
68
- model_path = "./pt/erase.pt"
69
- model = torch.jit.load(model_path, map_location="cpu")
70
- model = model.to(device)
71
- model.eval()
72
-
73
-
74
- ########################################
75
- # UTILITY CODE
76
- ########################################
77
-
78
-
79
- def visualize(im, boolmask=None, rotate=False):
80
- vis = im.astype(np.uint8)
81
- if boolmask is not None:
82
- vis[np.where(boolmask == False)] = SEAM_COLOR
83
- if rotate:
84
- vis = rotate_image(vis, False)
85
- cv2.imshow("visualization", vis)
86
- cv2.waitKey(1)
87
- return vis
88
-
89
- def resize(image, width):
90
- dim = None
91
- h, w = image.shape[:2]
92
- dim = (width, int(h * width / float(w)))
93
- image = image.astype('float32')
94
- return cv2.resize(image, dim)
95
-
96
- def rotate_image(image, clockwise):
97
- k = 1 if clockwise else 3
98
- return np.rot90(image, k)
99
-
100
-
101
- ########################################
102
- # ENERGY FUNCTIONS
103
- ########################################
104
-
105
- def backward_energy(im):
106
- """
107
- Simple gradient magnitude energy map.
108
- """
109
- xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
110
- ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')
111
-
112
- grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))
113
-
114
- # vis = visualize(grad_mag)
115
- # cv2.imwrite("backward_energy_demo.jpg", vis)
116
-
117
- return grad_mag
118
-
119
- def forward_energy(im):
120
- """
121
- Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
122
- by Rubinstein, Shamir, Avidan.
123
- Vectorized code adapted from
124
- https://github.com/axu2/improved-seam-carving.
125
- """
126
- h, w = im.shape[:2]
127
- im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)
128
-
129
- energy = np.zeros((h, w))
130
- m = np.zeros((h, w))
131
-
132
- U = np.roll(im, 1, axis=0)
133
- L = np.roll(im, 1, axis=1)
134
- R = np.roll(im, -1, axis=1)
135
-
136
- cU = np.abs(R - L)
137
- cL = np.abs(U - L) + cU
138
- cR = np.abs(U - R) + cU
139
-
140
- for i in range(1, h):
141
- mU = m[i-1]
142
- mL = np.roll(mU, 1)
143
- mR = np.roll(mU, -1)
144
-
145
- mULR = np.array([mU, mL, mR])
146
- cULR = np.array([cU[i], cL[i], cR[i]])
147
- mULR += cULR
148
-
149
- argmins = np.argmin(mULR, axis=0)
150
- m[i] = np.choose(argmins, mULR)
151
- energy[i] = np.choose(argmins, cULR)
152
-
153
- # vis = visualize(energy)
154
- # cv2.imwrite("forward_energy_demo.jpg", vis)
155
-
156
- return energy
157
-
158
- ########################################
159
- # SEAM HELPER FUNCTIONS
160
- ########################################
161
-
162
- def add_seam(im, seam_idx):
163
- """
164
- Add a vertical seam to a 3-channel color image at the indices provided
165
- by averaging the pixels values to the left and right of the seam.
166
- Code adapted from https://github.com/vivianhylee/seam-carving.
167
- """
168
- h, w = im.shape[:2]
169
- output = np.zeros((h, w + 1, 3))
170
- for row in range(h):
171
- col = seam_idx[row]
172
- for ch in range(3):
173
- if col == 0:
174
- p = np.mean(im[row, col: col + 2, ch])
175
- output[row, col, ch] = im[row, col, ch]
176
- output[row, col + 1, ch] = p
177
- output[row, col + 1:, ch] = im[row, col:, ch]
178
- else:
179
- p = np.mean(im[row, col - 1: col + 1, ch])
180
- output[row, : col, ch] = im[row, : col, ch]
181
- output[row, col, ch] = p
182
- output[row, col + 1:, ch] = im[row, col:, ch]
183
-
184
- return output
185
-
186
- def add_seam_grayscale(im, seam_idx):
187
- """
188
- Add a vertical seam to a grayscale image at the indices provided
189
- by averaging the pixels values to the left and right of the seam.
190
- """
191
- h, w = im.shape[:2]
192
- output = np.zeros((h, w + 1))
193
- for row in range(h):
194
- col = seam_idx[row]
195
- if col == 0:
196
- p = np.mean(im[row, col: col + 2])
197
- output[row, col] = im[row, col]
198
- output[row, col + 1] = p
199
- output[row, col + 1:] = im[row, col:]
200
- else:
201
- p = np.mean(im[row, col - 1: col + 1])
202
- output[row, : col] = im[row, : col]
203
- output[row, col] = p
204
- output[row, col + 1:] = im[row, col:]
205
-
206
- return output
207
-
208
- def remove_seam(im, boolmask):
209
- h, w = im.shape[:2]
210
- boolmask3c = np.stack([boolmask] * 3, axis=2)
211
- return im[boolmask3c].reshape((h, w - 1, 3))
212
-
213
- def remove_seam_grayscale(im, boolmask):
214
- h, w = im.shape[:2]
215
- return im[boolmask].reshape((h, w - 1))
216
-
217
- def get_minimum_seam(im, mask=None, remove_mask=None):
218
- """
219
- DP algorithm for finding the seam of minimum energy. Code adapted from
220
- https://karthikkaranth.me/blog/implementing-seam-carving-with-python/
221
- """
222
- h, w = im.shape[:2]
223
- energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
224
- M = energyfn(im)
225
-
226
- if mask is not None:
227
- M[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST
228
-
229
- # give removal mask priority over protective mask by using larger negative value
230
- if remove_mask is not None:
231
- M[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100
232
-
233
- seam_idx, boolmask = compute_shortest_path(M, im, h, w)
234
-
235
- return np.array(seam_idx), boolmask
236
-
237
- def compute_shortest_path(M, im, h, w):
238
- backtrack = np.zeros_like(M, dtype=np.int_)
239
-
240
-
241
- # populate DP matrix
242
- for i in range(1, h):
243
- for j in range(0, w):
244
- if j == 0:
245
- idx = np.argmin(M[i - 1, j:j + 2])
246
- backtrack[i, j] = idx + j
247
- min_energy = M[i-1, idx + j]
248
- else:
249
- idx = np.argmin(M[i - 1, j - 1:j + 2])
250
- backtrack[i, j] = idx + j - 1
251
- min_energy = M[i - 1, idx + j - 1]
252
-
253
- M[i, j] += min_energy
254
-
255
- # backtrack to find path
256
- seam_idx = []
257
- boolmask = np.ones((h, w), dtype=np.bool_)
258
- j = np.argmin(M[-1])
259
- for i in range(h-1, -1, -1):
260
- boolmask[i, j] = False
261
- seam_idx.append(j)
262
- j = backtrack[i, j]
263
-
264
- seam_idx.reverse()
265
- return seam_idx, boolmask
266
-
267
- ########################################
268
- # MAIN ALGORITHM
269
- ########################################
270
-
271
- def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
272
- for _ in range(num_remove):
273
- seam_idx, boolmask = get_minimum_seam(im, mask)
274
- if vis:
275
- visualize(im, boolmask, rotate=rot)
276
- im = remove_seam(im, boolmask)
277
- if mask is not None:
278
- mask = remove_seam_grayscale(mask, boolmask)
279
- return im, mask
280
-
281
-
282
- def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
283
- seams_record = []
284
- temp_im = im.copy()
285
- temp_mask = mask.copy() if mask is not None else None
286
-
287
- for _ in range(num_add):
288
- seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
289
- if vis:
290
- visualize(temp_im, boolmask, rotate=rot)
291
-
292
- seams_record.append(seam_idx)
293
- temp_im = remove_seam(temp_im, boolmask)
294
- if temp_mask is not None:
295
- temp_mask = remove_seam_grayscale(temp_mask, boolmask)
296
-
297
- seams_record.reverse()
298
-
299
- for _ in range(num_add):
300
- seam = seams_record.pop()
301
- im = add_seam(im, seam)
302
- if vis:
303
- visualize(im, rotate=rot)
304
- if mask is not None:
305
- mask = add_seam_grayscale(mask, seam)
306
-
307
- # update the remaining seam indices
308
- for remaining_seam in seams_record:
309
- remaining_seam[np.where(remaining_seam >= seam)] += 2
310
-
311
- return im, mask
312
-
313
- ########################################
314
- # MAIN DRIVER FUNCTIONS
315
- ########################################
316
-
317
- def seam_carve(im, dy, dx, mask=None, vis=False):
318
- im = im.astype(np.float64)
319
- h, w = im.shape[:2]
320
- assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
321
-
322
- if mask is not None:
323
- mask = mask.astype(np.float64)
324
-
325
- output = im
326
-
327
- if dx < 0:
328
- output, mask = seams_removal(output, -dx, mask, vis)
329
-
330
- elif dx > 0:
331
- output, mask = seams_insertion(output, dx, mask, vis)
332
-
333
- if dy < 0:
334
- output = rotate_image(output, True)
335
- if mask is not None:
336
- mask = rotate_image(mask, True)
337
- output, mask = seams_removal(output, -dy, mask, vis, rot=True)
338
- output = rotate_image(output, False)
339
-
340
- elif dy > 0:
341
- output = rotate_image(output, True)
342
- if mask is not None:
343
- mask = rotate_image(mask, True)
344
- output, mask = seams_insertion(output, dy, mask, vis, rot=True)
345
- output = rotate_image(output, False)
346
-
347
- return output
348
-
349
-
350
- def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
351
- im = im.astype(np.float64)
352
- rmask = rmask.astype(np.float64)
353
- if mask is not None:
354
- mask = mask.astype(np.float64)
355
- output = im
356
-
357
- h, w = im.shape[:2]
358
-
359
- if horizontal_removal:
360
- output = rotate_image(output, True)
361
- rmask = rotate_image(rmask, True)
362
- if mask is not None:
363
- mask = rotate_image(mask, True)
364
-
365
- while len(np.where(rmask > MASK_THRESHOLD)[0]) > 0:
366
- seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
367
- if vis:
368
- visualize(output, boolmask, rotate=horizontal_removal)
369
- output = remove_seam(output, boolmask)
370
- rmask = remove_seam_grayscale(rmask, boolmask)
371
- if mask is not None:
372
- mask = remove_seam_grayscale(mask, boolmask)
373
-
374
- num_add = (h if horizontal_removal else w) - output.shape[1]
375
- output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
376
- if horizontal_removal:
377
- output = rotate_image(output, False)
378
-
379
- return output
380
-
381
-
382
-
383
- def s_image(im,mask,vs,hs,mode="resize"):
384
- im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
385
- mask = 255-mask[:,:,3]
386
- h, w = im.shape[:2]
387
- if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
388
- im = resize(im, width=DOWNSIZE_WIDTH)
389
- if mask is not None:
390
- mask = resize(mask, width=DOWNSIZE_WIDTH)
391
-
392
- # image resize mode
393
- if mode=="resize":
394
- dy = hs#reverse
395
- dx = vs#reverse
396
- assert dy is not None and dx is not None
397
- output = seam_carve(im, dy, dx, mask, False)
398
-
399
-
400
- # object removal mode
401
- elif mode=="remove":
402
- assert mask is not None
403
- output = object_removal(im, mask, None, False, True)
404
-
405
- return output
406
-
407
-
408
- ##### Inpainting helper code
409
-
410
- def run(image, mask):
411
- """
412
- image: [C, H, W]
413
- mask: [1, H, W]
414
- return: BGR IMAGE
415
- """
416
- origin_height, origin_width = image.shape[1:]
417
- image = pad_img_to_modulo(image, mod=8)
418
- mask = pad_img_to_modulo(mask, mod=8)
419
-
420
- mask = (mask > 0) * 1
421
- image = torch.from_numpy(image).unsqueeze(0).to(device)
422
- mask = torch.from_numpy(mask).unsqueeze(0).to(device)
423
-
424
- start = time.time()
425
- with torch.no_grad():
426
- inpainted_image = model(image, mask)
427
-
428
- print(f"process time: {(time.time() - start)*1000}ms")
429
- cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
430
- cur_res = cur_res[0:origin_height, 0:origin_width, :]
431
- cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
432
- cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
433
- return cur_res
434
-
435
-
436
- def get_args_parser():
437
- parser = argparse.ArgumentParser()
438
- parser.add_argument("--port", default=8080, type=int)
439
- parser.add_argument("--device", default="cuda", type=str)
440
- parser.add_argument("--debug", action="store_true")
441
- return parser.parse_args()
442
-
443
-
444
- def process_inpaint(image, mask):
445
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
446
- original_shape = image.shape
447
- interpolation = cv2.INTER_CUBIC
448
-
449
- #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
450
- #if size_limit == "Original":
451
- size_limit = max(image.shape)
452
- #else:
453
- # size_limit = int(size_limit)
454
-
455
- print(f"Origin image shape: {original_shape}")
456
- image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
457
- print(f"Resized image shape: {image.shape}")
458
- image = norm_img(image)
459
-
460
- mask = 255-mask[:,:,3]
461
- mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
462
- mask = norm_img(mask)
463
-
464
- res_np_img = run(image, mask)
465
-
466
  return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
 
1
+ import base64
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+ import uuid
7
+ from io import BytesIO
8
+ from pathlib import Path
9
+ import cv2
10
+
11
+ # For inpainting
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import streamlit as st
16
+ from PIL import Image
17
+ from streamlit_drawable_canvas import st_canvas
18
+
19
+
20
+ import argparse
21
+ import io
22
+ import multiprocessing
23
+ from typing import Union
24
+
25
+ import torch
26
+
27
+ try:
28
+ torch._C._jit_override_can_fuse_on_cpu(False)
29
+ torch._C._jit_override_can_fuse_on_gpu(False)
30
+ torch._C._jit_set_texpr_fuser_enabled(False)
31
+ torch._C._jit_set_nvfuser_enabled(False)
32
+ except:
33
+ pass
34
+
35
+ from src.helper import (
36
+ download_model,
37
+ load_img,
38
+ norm_img,
39
+ numpy_to_bytes,
40
+ pad_img_to_modulo,
41
+ resize_max_size,
42
+ )
43
+
44
+ NUM_THREADS = str(multiprocessing.cpu_count())
45
+
46
+ os.environ["OMP_NUM_THREADS"] = NUM_THREADS
47
+ os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
48
+ os.environ["MKL_NUM_THREADS"] = NUM_THREADS
49
+ os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
50
+ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
51
+ if os.environ.get("CACHE_DIR"):
52
+ os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
53
+
54
+ #BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
55
+
56
+ # For Seam-carving
57
+
58
+ from scipy import ndimage as ndi
59
+
60
+ SEAM_COLOR = np.array([255, 200, 200]) # seam visualization color (BGR)
61
+ SHOULD_DOWNSIZE = True # if True, downsize image for faster carving
62
+ DOWNSIZE_WIDTH = 500 # resized image width if SHOULD_DOWNSIZE is True
63
+ ENERGY_MASK_CONST = 100000.0 # large energy value for protective masking
64
+ MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
65
+ USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
66
+
67
+ device = torch.device("cpu")
68
+ model_path = "./assets/erase.pt"
69
+ model = torch.jit.load(model_path, map_location="cpu")
70
+ model = model.to(device)
71
+ model.eval()
72
+
73
+
74
+ ########################################
75
+ # UTILITY CODE
76
+ ########################################
77
+
78
+
79
+ def visualize(im, boolmask=None, rotate=False):
80
+ vis = im.astype(np.uint8)
81
+ if boolmask is not None:
82
+ vis[np.where(boolmask == False)] = SEAM_COLOR
83
+ if rotate:
84
+ vis = rotate_image(vis, False)
85
+ cv2.imshow("visualization", vis)
86
+ cv2.waitKey(1)
87
+ return vis
88
+
89
+ def resize(image, width):
90
+ dim = None
91
+ h, w = image.shape[:2]
92
+ dim = (width, int(h * width / float(w)))
93
+ image = image.astype('float32')
94
+ return cv2.resize(image, dim)
95
+
96
+ def rotate_image(image, clockwise):
97
+ k = 1 if clockwise else 3
98
+ return np.rot90(image, k)
99
+
100
+
101
+ ########################################
102
+ # ENERGY FUNCTIONS
103
+ ########################################
104
+
105
+ def backward_energy(im):
106
+ """
107
+ Simple gradient magnitude energy map.
108
+ """
109
+ xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
110
+ ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')
111
+
112
+ grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))
113
+
114
+ # vis = visualize(grad_mag)
115
+ # cv2.imwrite("backward_energy_demo.jpg", vis)
116
+
117
+ return grad_mag
118
+
119
+ def forward_energy(im):
120
+ """
121
+ Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
122
+ by Rubinstein, Shamir, Avidan.
123
+ Vectorized code adapted from
124
+ https://github.com/axu2/improved-seam-carving.
125
+ """
126
+ h, w = im.shape[:2]
127
+ im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)
128
+
129
+ energy = np.zeros((h, w))
130
+ m = np.zeros((h, w))
131
+
132
+ U = np.roll(im, 1, axis=0)
133
+ L = np.roll(im, 1, axis=1)
134
+ R = np.roll(im, -1, axis=1)
135
+
136
+ cU = np.abs(R - L)
137
+ cL = np.abs(U - L) + cU
138
+ cR = np.abs(U - R) + cU
139
+
140
+ for i in range(1, h):
141
+ mU = m[i-1]
142
+ mL = np.roll(mU, 1)
143
+ mR = np.roll(mU, -1)
144
+
145
+ mULR = np.array([mU, mL, mR])
146
+ cULR = np.array([cU[i], cL[i], cR[i]])
147
+ mULR += cULR
148
+
149
+ argmins = np.argmin(mULR, axis=0)
150
+ m[i] = np.choose(argmins, mULR)
151
+ energy[i] = np.choose(argmins, cULR)
152
+
153
+ # vis = visualize(energy)
154
+ # cv2.imwrite("forward_energy_demo.jpg", vis)
155
+
156
+ return energy
157
+
158
+ ########################################
159
+ # SEAM HELPER FUNCTIONS
160
+ ########################################
161
+
162
+ def add_seam(im, seam_idx):
163
+ """
164
+ Add a vertical seam to a 3-channel color image at the indices provided
165
+ by averaging the pixels values to the left and right of the seam.
166
+ Code adapted from https://github.com/vivianhylee/seam-carving.
167
+ """
168
+ h, w = im.shape[:2]
169
+ output = np.zeros((h, w + 1, 3))
170
+ for row in range(h):
171
+ col = seam_idx[row]
172
+ for ch in range(3):
173
+ if col == 0:
174
+ p = np.mean(im[row, col: col + 2, ch])
175
+ output[row, col, ch] = im[row, col, ch]
176
+ output[row, col + 1, ch] = p
177
+ output[row, col + 1:, ch] = im[row, col:, ch]
178
+ else:
179
+ p = np.mean(im[row, col - 1: col + 1, ch])
180
+ output[row, : col, ch] = im[row, : col, ch]
181
+ output[row, col, ch] = p
182
+ output[row, col + 1:, ch] = im[row, col:, ch]
183
+
184
+ return output
185
+
186
+ def add_seam_grayscale(im, seam_idx):
187
+ """
188
+ Add a vertical seam to a grayscale image at the indices provided
189
+ by averaging the pixels values to the left and right of the seam.
190
+ """
191
+ h, w = im.shape[:2]
192
+ output = np.zeros((h, w + 1))
193
+ for row in range(h):
194
+ col = seam_idx[row]
195
+ if col == 0:
196
+ p = np.mean(im[row, col: col + 2])
197
+ output[row, col] = im[row, col]
198
+ output[row, col + 1] = p
199
+ output[row, col + 1:] = im[row, col:]
200
+ else:
201
+ p = np.mean(im[row, col - 1: col + 1])
202
+ output[row, : col] = im[row, : col]
203
+ output[row, col] = p
204
+ output[row, col + 1:] = im[row, col:]
205
+
206
+ return output
207
+
208
+ def remove_seam(im, boolmask):
209
+ h, w = im.shape[:2]
210
+ boolmask3c = np.stack([boolmask] * 3, axis=2)
211
+ return im[boolmask3c].reshape((h, w - 1, 3))
212
+
213
+ def remove_seam_grayscale(im, boolmask):
214
+ h, w = im.shape[:2]
215
+ return im[boolmask].reshape((h, w - 1))
216
+
217
+ def get_minimum_seam(im, mask=None, remove_mask=None):
218
+ """
219
+ DP algorithm for finding the seam of minimum energy. Code adapted from
220
+ https://karthikkaranth.me/blog/implementing-seam-carving-with-python/
221
+ """
222
+ h, w = im.shape[:2]
223
+ energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
224
+ M = energyfn(im)
225
+
226
+ if mask is not None:
227
+ M[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST
228
+
229
+ # give removal mask priority over protective mask by using larger negative value
230
+ if remove_mask is not None:
231
+ M[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100
232
+
233
+ seam_idx, boolmask = compute_shortest_path(M, im, h, w)
234
+
235
+ return np.array(seam_idx), boolmask
236
+
237
+ def compute_shortest_path(M, im, h, w):
238
+ backtrack = np.zeros_like(M, dtype=np.int_)
239
+
240
+
241
+ # populate DP matrix
242
+ for i in range(1, h):
243
+ for j in range(0, w):
244
+ if j == 0:
245
+ idx = np.argmin(M[i - 1, j:j + 2])
246
+ backtrack[i, j] = idx + j
247
+ min_energy = M[i-1, idx + j]
248
+ else:
249
+ idx = np.argmin(M[i - 1, j - 1:j + 2])
250
+ backtrack[i, j] = idx + j - 1
251
+ min_energy = M[i - 1, idx + j - 1]
252
+
253
+ M[i, j] += min_energy
254
+
255
+ # backtrack to find path
256
+ seam_idx = []
257
+ boolmask = np.ones((h, w), dtype=np.bool_)
258
+ j = np.argmin(M[-1])
259
+ for i in range(h-1, -1, -1):
260
+ boolmask[i, j] = False
261
+ seam_idx.append(j)
262
+ j = backtrack[i, j]
263
+
264
+ seam_idx.reverse()
265
+ return seam_idx, boolmask
266
+
267
+ ########################################
268
+ # MAIN ALGORITHM
269
+ ########################################
270
+
271
+ def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
272
+ for _ in range(num_remove):
273
+ seam_idx, boolmask = get_minimum_seam(im, mask)
274
+ if vis:
275
+ visualize(im, boolmask, rotate=rot)
276
+ im = remove_seam(im, boolmask)
277
+ if mask is not None:
278
+ mask = remove_seam_grayscale(mask, boolmask)
279
+ return im, mask
280
+
281
+
282
+ def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
283
+ seams_record = []
284
+ temp_im = im.copy()
285
+ temp_mask = mask.copy() if mask is not None else None
286
+
287
+ for _ in range(num_add):
288
+ seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
289
+ if vis:
290
+ visualize(temp_im, boolmask, rotate=rot)
291
+
292
+ seams_record.append(seam_idx)
293
+ temp_im = remove_seam(temp_im, boolmask)
294
+ if temp_mask is not None:
295
+ temp_mask = remove_seam_grayscale(temp_mask, boolmask)
296
+
297
+ seams_record.reverse()
298
+
299
+ for _ in range(num_add):
300
+ seam = seams_record.pop()
301
+ im = add_seam(im, seam)
302
+ if vis:
303
+ visualize(im, rotate=rot)
304
+ if mask is not None:
305
+ mask = add_seam_grayscale(mask, seam)
306
+
307
+ # update the remaining seam indices
308
+ for remaining_seam in seams_record:
309
+ remaining_seam[np.where(remaining_seam >= seam)] += 2
310
+
311
+ return im, mask
312
+
313
+ ########################################
314
+ # MAIN DRIVER FUNCTIONS
315
+ ########################################
316
+
317
+ def seam_carve(im, dy, dx, mask=None, vis=False):
318
+ im = im.astype(np.float64)
319
+ h, w = im.shape[:2]
320
+ assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
321
+
322
+ if mask is not None:
323
+ mask = mask.astype(np.float64)
324
+
325
+ output = im
326
+
327
+ if dx < 0:
328
+ output, mask = seams_removal(output, -dx, mask, vis)
329
+
330
+ elif dx > 0:
331
+ output, mask = seams_insertion(output, dx, mask, vis)
332
+
333
+ if dy < 0:
334
+ output = rotate_image(output, True)
335
+ if mask is not None:
336
+ mask = rotate_image(mask, True)
337
+ output, mask = seams_removal(output, -dy, mask, vis, rot=True)
338
+ output = rotate_image(output, False)
339
+
340
+ elif dy > 0:
341
+ output = rotate_image(output, True)
342
+ if mask is not None:
343
+ mask = rotate_image(mask, True)
344
+ output, mask = seams_insertion(output, dy, mask, vis, rot=True)
345
+ output = rotate_image(output, False)
346
+
347
+ return output
348
+
349
+
350
+ def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
351
+ im = im.astype(np.float64)
352
+ rmask = rmask.astype(np.float64)
353
+ if mask is not None:
354
+ mask = mask.astype(np.float64)
355
+ output = im
356
+
357
+ h, w = im.shape[:2]
358
+
359
+ if horizontal_removal:
360
+ output = rotate_image(output, True)
361
+ rmask = rotate_image(rmask, True)
362
+ if mask is not None:
363
+ mask = rotate_image(mask, True)
364
+
365
+ while len(np.where(rmask > MASK_THRESHOLD)[0]) > 0:
366
+ seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
367
+ if vis:
368
+ visualize(output, boolmask, rotate=horizontal_removal)
369
+ output = remove_seam(output, boolmask)
370
+ rmask = remove_seam_grayscale(rmask, boolmask)
371
+ if mask is not None:
372
+ mask = remove_seam_grayscale(mask, boolmask)
373
+
374
+ num_add = (h if horizontal_removal else w) - output.shape[1]
375
+ output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
376
+ if horizontal_removal:
377
+ output = rotate_image(output, False)
378
+
379
+ return output
380
+
381
+
382
+
383
+ def s_image(im,mask,vs,hs,mode="resize"):
384
+ im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
385
+ mask = 255-mask[:,:,3]
386
+ h, w = im.shape[:2]
387
+ if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
388
+ im = resize(im, width=DOWNSIZE_WIDTH)
389
+ if mask is not None:
390
+ mask = resize(mask, width=DOWNSIZE_WIDTH)
391
+
392
+ # image resize mode
393
+ if mode=="resize":
394
+ dy = hs#reverse
395
+ dx = vs#reverse
396
+ assert dy is not None and dx is not None
397
+ output = seam_carve(im, dy, dx, mask, False)
398
+
399
+
400
+ # object removal mode
401
+ elif mode=="remove":
402
+ assert mask is not None
403
+ output = object_removal(im, mask, None, False, True)
404
+
405
+ return output
406
+
407
+
408
+ ##### Inpainting helper code
409
+
410
+ def run(image, mask):
411
+ """
412
+ image: [C, H, W]
413
+ mask: [1, H, W]
414
+ return: BGR IMAGE
415
+ """
416
+ origin_height, origin_width = image.shape[1:]
417
+ image = pad_img_to_modulo(image, mod=8)
418
+ mask = pad_img_to_modulo(mask, mod=8)
419
+
420
+ mask = (mask > 0) * 1
421
+ image = torch.from_numpy(image).unsqueeze(0).to(device)
422
+ mask = torch.from_numpy(mask).unsqueeze(0).to(device)
423
+
424
+ start = time.time()
425
+ with torch.no_grad():
426
+ inpainted_image = model(image, mask)
427
+
428
+ print(f"process time: {(time.time() - start)*1000}ms")
429
+ cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
430
+ cur_res = cur_res[0:origin_height, 0:origin_width, :]
431
+ cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
432
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
433
+ return cur_res
434
+
435
+
436
+ def get_args_parser():
437
+ parser = argparse.ArgumentParser()
438
+ parser.add_argument("--port", default=8080, type=int)
439
+ parser.add_argument("--device", default="cuda", type=str)
440
+ parser.add_argument("--debug", action="store_true")
441
+ return parser.parse_args()
442
+
443
+
444
+ def process_inpaint(image, mask):
445
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
446
+ original_shape = image.shape
447
+ interpolation = cv2.INTER_CUBIC
448
+
449
+ #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
450
+ #if size_limit == "Original":
451
+ size_limit = max(image.shape)
452
+ #else:
453
+ # size_limit = int(size_limit)
454
+
455
+ print(f"Origin image shape: {original_shape}")
456
+ image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
457
+ print(f"Resized image shape: {image.shape}")
458
+ image = norm_img(image)
459
+
460
+ mask = 255-mask[:,:,3]
461
+ mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
462
+ mask = norm_img(mask)
463
+
464
+ res_np_img = run(image, mask)
465
+
466
  return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)