# Uploaded by: JinhuaL1ANG
# Revision: v1 (commit 9a6dac6)
import cv2
import torch
import numpy as np
import torch.nn.functional as F
def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None):
    """Resample an image so that both sides are multiples of 64.

    Sizing mode:
      * ``resize_short_edge`` given -> the shorter side is scaled towards
        ``resize_short_edge`` pixels;
      * otherwise -> the image is scaled so its area is about
        ``max_resolution`` pixels.

    The scaled height and width are each rounded (``np.round``) to a multiple
    of 64, and the image is resampled with Lanczos interpolation.

    Args:
        image: array whose first two dimensions are (height, width).
        max_resolution: target pixel *area*, used when ``resize_short_edge``
            is None.
        resize_short_edge: target length of the shorter side, or None.

    Returns:
        Tuple ``(resized_image, scale)`` where ``scale`` is the new width
        divided by the original width.
    """
    height, width = image.shape[:2]
    if resize_short_edge is None:
        # max_resolution is an area, so the per-side factor is its square root.
        factor = (max_resolution / (height * width)) ** 0.5
    else:
        factor = resize_short_edge / min(height, width)

    def _snap_to_64(length):
        # NOTE(review): a side shorter than ~32 px after scaling rounds to 0,
        # which cv2.resize will most likely reject -- confirm inputs avoid this.
        return int(np.round(length * factor / 64)) * 64

    new_height = _snap_to_64(height)
    new_width = _snap_to_64(width)
    resized = cv2.resize(
        image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4
    )
    return resized, new_width / width
def split_ldm(ldm):
    """Split an iterable of points into separate x and y lists.

    Each element must be indexable: item ``[0]`` is collected as x and item
    ``[1]`` as y; any further entries are ignored.

    Returns:
        Tuple ``(xs, ys)`` of two equally long lists.
    """
    points = list(ldm)  # materialize once so iterators are consumed only once
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]
    return xs, ys
def process_move(
    path_mask,  # target region of original map
    h,
    w,
    dx,
    dy,
    scale,
    input_scale,
    resize_scale_x,
    resize_scale_y,
    up_scale,
    up_ft_index,
    w_edit,
    w_content,
    w_contrast,
    w_inpaint,
    precision,
    path_mask_ref=None,
    path_mask_keep=None,
    device="cuda",
):
    """Build the masks needed to move (and optionally resize) a masked region.

    The input mask is binarized, then:
      * ``mask_org``: downsampled by ``scale`` (original position);
      * ``mask_tar``: downsampled by ``scale`` and resized per axis by
        ``resize_scale_y`` / ``resize_scale_x``;
      * ``mask_cur``: ``mask_tar`` shifted by the scaled displacement and
        centre-fitted (cropped or zero-padded) onto the ``mask_org`` grid
        (edited position).

    Args:
        path_mask: 2-D tensor of the region to move; values > 0.5 are inside.
        h, w: not used here; kept for interface compatibility.
        dx, dy: displacement along the last / second-to-last axis; multiplied
            by ``input_scale`` before use.
        scale: downsampling factor from mask resolution to the working grid
            (presumably the latent resolution -- confirm with caller).
        input_scale: factor applied to ``dx`` / ``dy``.
        resize_scale_x, resize_scale_y: per-axis resize factor of the region.
        up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint:
            returned unchanged.
        precision: torch dtype for the binarized full-resolution masks.
        path_mask_ref: optional 2-D tensor; binarized and returned.
        path_mask_keep: optional 2-D tensor; when given, its downsampled form
            (``mask_keep``) is used as ``mask_other``.
        device: device the masks are created on (default ``"cuda"``).

    Returns:
        dict with the binarized inputs (``mask_x0``, ``mask_x0_ref``,
        ``mask_x0_keep``), the working-grid masks of shape (1, 1, H', W')
        (``mask_tar``, ``mask_cur``, ``mask_other``, ``mask_overlap``,
        ``mask_non_overlap``, ``mask_keep``) and the pass-through parameters.

    Raises:
        ValueError: if a resize factor exceeds 1 and the pixel count of the
            trimmed ``mask_tar`` differs from that of the fitted ``mask_cur``.
    """
    dx, dy = dx * input_scale, dy * input_scale

    mask_x0 = (path_mask > 0.5).float().to(device, dtype=precision)
    mask_x0_ref = path_mask_ref
    if mask_x0_ref is not None:
        mask_x0_ref = (mask_x0_ref > 0.5).float().to(device, dtype=precision)

    # Region to keep untouched, if `path_mask_keep` is given.
    mask_x0_keep = path_mask_keep
    if mask_x0_keep is not None:
        mask_x0_keep = (mask_x0_keep > 0.5).float().to(device, dtype=precision)
        mask_keep = (
            F.interpolate(
                mask_x0_keep[None, None],
                (
                    int(mask_x0_keep.shape[-2] // scale),
                    int(mask_x0_keep.shape[-1] // scale),
                ),
            )
            > 0.5
        ).float()
    else:
        mask_keep = None

    height, width = mask_x0.shape[-2], mask_x0.shape[-1]
    mask_org = (
        F.interpolate(
            mask_x0[None, None], (int(height // scale), int(width // scale))
        )
        > 0.5
    )
    mask_tar = (
        F.interpolate(
            mask_x0[None, None],
            (
                int(height // scale * resize_scale_y),
                int(width // scale * resize_scale_x),
            ),
        )
        > 0.5
    )

    # Shift on the resized working grid. Computed once so that the inverse
    # roll below is exactly its negation. Previously the x inverse used
    # `int(-dx // scale * ...)`, whose floor-division does not undo the
    # forward shift when dx is not a multiple of `scale`.
    shift_y = int(dy // scale * resize_scale_y)
    shift_x = int(dx // scale * resize_scale_x)
    # NOTE: torch.roll is circular -- pixels pushed past one edge re-enter
    # at the opposite edge unless they are cropped away below.
    mask_cur = torch.roll(mask_tar, (shift_y, shift_x), (-2, -1))

    # Centre-fit the shifted mask onto the mask_org grid: crop axes that are
    # larger, zero-pad axes that are smaller.
    temp = torch.zeros(
        1, 1, mask_org.shape[-2], mask_org.shape[-1], device=mask_org.device
    )
    pad_x = (temp.shape[-1] - mask_cur.shape[-1]) // 2
    pad_y = (temp.shape[-2] - mask_cur.shape[-2]) // 2
    px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
    px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
    crop_y = slice(py_tar, py_tar + temp.shape[-2])
    crop_x = slice(px_tar, px_tar + temp.shape[-1])
    crop = mask_cur[:, :, crop_y, crop_x]
    temp[
        :, :, py_tmp:py_tmp + crop.shape[-2], px_tmp:px_tmp + crop.shape[-1]
    ] = crop

    # Trim mask_tar to the pixels that survived the crop, so it stays aligned
    # with mask_cur: mark the kept window in mask_cur coordinates, then undo
    # the shift to map it back to mask_tar coordinates.
    mask_valid = torch.zeros_like(mask_cur)
    mask_valid[:, :, crop_y, crop_x] = True
    mask_valid = torch.roll(mask_valid, (-shift_y, -shift_x), (-2, -1))
    mask_tar = torch.logical_and(mask_tar, mask_valid)

    # Ensure the editing region is within the spectrogram.
    # NOTE(review): after the trimming above these counts appear to always
    # match; the check is kept as a safety net -- confirm intent.
    if resize_scale_x > 1 or resize_scale_y > 1:
        if torch.sum(mask_tar) != torch.sum(temp):
            raise ValueError("Resize out of bounds, exiting.")
    mask_cur = temp > 0.5

    # When `mask_keep` is given it defines the "other" region; otherwise it is
    # everything outside both the original and the edited position.
    if mask_keep is not None:
        mask_other = mask_keep > 0.5
    else:
        mask_other = torch.logical_not(torch.logical_or(mask_cur, mask_org))
    mask_overlap = torch.logical_and(mask_cur, mask_org).float()
    mask_non_overlap = torch.logical_and(mask_org, torch.logical_not(mask_cur))
    return {
        "mask_x0": mask_x0,
        "mask_x0_ref": mask_x0_ref,
        "mask_x0_keep": mask_x0_keep,
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "mask_other": mask_other,
        "mask_overlap": mask_overlap,
        "mask_non_overlap": mask_non_overlap,
        "mask_keep": mask_keep,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "resize_scale_x": resize_scale_x,
        "resize_scale_y": resize_scale_y,
        "w_edit": w_edit,
        "w_content": w_content,
        "w_contrast": w_contrast,
        "w_inpaint": w_inpaint,
    }
def process_paste(
    path_mask,
    h,
    w,
    dx,
    dy,
    scale,
    input_scale,
    up_scale,
    up_ft_index,
    w_edit,
    w_content,
    precision,
    resize_scale_x,
    resize_scale_y,
    device="cuda",
):
    """Build the masks for pasting a resized, shifted region.

    The mask is binarized, resized per axis, and centre-fitted back onto a
    canvas of its original size (``dict_mask["replace"]``). The canvas is then
    shifted by the displacement scaled by the resize factors
    (``dict_mask["base"]``). Finally the shifted mask is downsampled by
    ``scale`` (``mask_base_cur``) and un-shifted on that grid
    (``mask_replace_cur``).

    Args:
        path_mask: 2-D tensor (values > 0.5 are inside) or a path to a mask
            image. Image pixels are normalised from 8-bit to [0, 1] before
            thresholding.
        h, w: not used here; kept for interface compatibility.
        dx, dy: displacement along the last / second-to-last axis; multiplied
            by ``input_scale`` before use.
        scale: downsampling factor for the ``*_cur`` masks.
        input_scale: factor applied to ``dx`` / ``dy``.
        up_scale, up_ft_index, w_edit, w_content: returned unchanged.
        precision: torch dtype for the binarized input mask.
        resize_scale_x, resize_scale_y: per-axis resize factor of the region.
        device: device the masks are created on (default ``"cuda"``).

    Returns:
        dict with ``dict_mask`` (full-resolution 2-D ``"base"`` and
        ``"replace"`` masks), the boolean (1, 1, H', W') masks
        ``mask_base_cur`` and ``mask_replace_cur``, and the pass-through
        parameters.

    Raises:
        FileNotFoundError: if ``path_mask`` is a path that cv2 cannot read.
    """
    dx, dy = dx * input_scale, dy * input_scale
    if isinstance(path_mask, str):
        # cv2.imread yields a NumPy array (None on failure), which has no
        # `.to`; convert it to a [0, 1] tensor before the torch ops below.
        image = cv2.imread(path_mask, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read mask image: {path_mask}")
        mask_base = torch.from_numpy(image.astype(np.float32) / 255.0)
    else:
        mask_base = path_mask
    mask_base = (mask_base[None, None] > 0.5).to(device, dtype=precision)

    # Resize each axis independently, then centre the result on a canvas of
    # the original size: axes that grew are centre-cropped, axes that shrank
    # are zero-padded.
    hi, wi = mask_base.shape[-2], mask_base.shape[-1]
    resized = F.interpolate(
        mask_base, (int(hi * resize_scale_y), int(wi * resize_scale_x))
    )
    canvas = torch.zeros(1, 1, hi, wi, device=mask_base.device)
    pad_x = (wi - resized.shape[-1]) // 2
    pad_y = (hi - resized.shape[-2]) // 2
    dst_x, dst_y = max(pad_x, 0), max(pad_y, 0)
    src_x, src_y = max(-pad_x, 0), max(-pad_y, 0)
    crop = resized[:, :, src_y:src_y + hi, src_x:src_x + wi]
    canvas[
        :, :, dst_y:dst_y + crop.shape[-2], dst_x:dst_x + crop.shape[-1]
    ] = crop

    # torch.roll returns a new tensor, so `canvas` keeps the un-shifted mask.
    # Dims -2/-1 are presumably (time, frequency) -- confirm with caller.
    mask_replace = canvas
    mask_base = torch.roll(
        canvas, (int(dy * resize_scale_y), int(dx * resize_scale_x)), (-2, -1)
    )
    dict_mask = {"base": mask_base[0, 0], "replace": mask_replace[0, 0]}

    mask_base_cur = (
        F.interpolate(
            mask_base,
            (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
        )
        > 0.5
    )
    # NOTE(review): the full-resolution shift above is scaled by
    # resize_scale_*, but this inverse shift is not, so for resize factors
    # != 1 mask_replace_cur may not line up with dict_mask["replace"].
    # Kept unchanged -- confirm intent.
    mask_replace_cur = torch.roll(
        mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1)
    )
    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_content": w_content,
    }
# def process_paste(
# path_mask,
# h,
# w,
# dx,
# dy,
# scale,
# input_scale,
# up_scale,
# up_ft_index,
# w_edit,
# w_content,
# precision,
# resize_scale=None,
# ):
# dx, dy = dx * input_scale, dy * input_scale
# if isinstance(path_mask, str):
# mask_base = cv2.imread(path_mask)
# else:
# mask_base = path_mask
# mask_base = mask_base[None, None]
# dict_mask = {}
# mask_base = (mask_base > 0.5).to("cuda", dtype=precision)
# if resize_scale is not None and resize_scale != 1:
# hi, wi = mask_base.shape[-2], mask_base.shape[-1]
# mask_base = F.interpolate(
# mask_base, (int(hi * resize_scale), int(wi * resize_scale))
# )
# pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2
# pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2
# if resize_scale > 1:
# mask_base = mask_base[
# :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi
# ]
# else:
# temp = torch.zeros(1, 1, hi, wi).to(mask_base.device)
# temp[
# :,
# :,
# pad_size_y : pad_size_y + mask_base.shape[-2],
# pad_size_x : pad_size_x + mask_base.shape[-1],
# ] = mask_base
# mask_base = temp
# mask_replace = mask_base.clone()
# mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1)) # (C,T,F)
# dict_mask["base"] = mask_base[0, 0]
# dict_mask["replace"] = mask_replace[0, 0]
# mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision)
# mask_base_cur = (
# F.interpolate(
# mask_base,
# (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
# )
# > 0.5
# )
# mask_replace_cur = torch.roll(
# mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1)
# )
# return {
# "dict_mask": dict_mask,
# "mask_base_cur": mask_base_cur,
# "mask_replace_cur": mask_replace_cur,
# "up_scale": up_scale,
# "up_ft_index": up_ft_index,
# "w_edit": w_edit,
# "w_content": w_content,
# "w_edit": w_edit,
# "w_content": w_content,
# }
def process_remove(
    path_mask,
    h,
    w,
    dx,
    dy,
    scale,
    input_scale,
    up_scale,
    up_ft_index,
    w_edit,
    w_contrast,
    w_content,
    precision,
    resize_scale_x,
    resize_scale_y,
    device="cuda",
):
    """Build the masks for removing a (resized, shifted) region.

    The mask is binarized, resized per axis, and centre-fitted back onto a
    canvas of its original size (``dict_mask["replace"]``). The canvas is then
    shifted by the (unscaled) displacement (``dict_mask["base"]``). Finally the
    shifted mask is downsampled by ``scale`` (``mask_base_cur``) and
    un-shifted on that grid (``mask_replace_cur``).

    Args:
        path_mask: 2-D tensor (values > 0.5 are inside) or a path to a mask
            image. Image pixels are normalised from 8-bit to [0, 1] before
            thresholding.
        h, w: not used here; kept for interface compatibility.
        dx, dy: displacement along the last / second-to-last axis; multiplied
            by ``input_scale`` before use.
        scale: downsampling factor for the ``*_cur`` masks.
        input_scale: factor applied to ``dx`` / ``dy``.
        up_scale, up_ft_index, w_edit, w_contrast, w_content: returned
            unchanged.
        precision: torch dtype for the binarized input mask.
        resize_scale_x, resize_scale_y: per-axis resize factor of the region.
        device: device the masks are created on (default ``"cuda"``).

    Returns:
        dict with ``dict_mask`` (full-resolution 2-D ``"base"`` and
        ``"replace"`` masks), the boolean (1, 1, H', W') masks
        ``mask_base_cur`` and ``mask_replace_cur``, and the pass-through
        parameters.

    Raises:
        FileNotFoundError: if ``path_mask`` is a path that cv2 cannot read.
    """
    dx, dy = dx * input_scale, dy * input_scale
    if isinstance(path_mask, str):
        # cv2.imread yields a NumPy array (None on failure), which has no
        # `.to`; convert it to a [0, 1] tensor before the torch ops below.
        image = cv2.imread(path_mask, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read mask image: {path_mask}")
        mask_base = torch.from_numpy(image.astype(np.float32) / 255.0)
    else:
        mask_base = path_mask
    mask_base = (mask_base[None, None] > 0.5).to(device, dtype=precision)

    # Resize each axis independently, then centre the result on a canvas of
    # the original size: axes that grew are centre-cropped, axes that shrank
    # are zero-padded.
    hi, wi = mask_base.shape[-2], mask_base.shape[-1]
    resized = F.interpolate(
        mask_base, (int(hi * resize_scale_y), int(wi * resize_scale_x))
    )
    canvas = torch.zeros(1, 1, hi, wi, device=mask_base.device)
    pad_x = (wi - resized.shape[-1]) // 2
    pad_y = (hi - resized.shape[-2]) // 2
    dst_x, dst_y = max(pad_x, 0), max(pad_y, 0)
    src_x, src_y = max(-pad_x, 0), max(-pad_y, 0)
    crop = resized[:, :, src_y:src_y + hi, src_x:src_x + wi]
    canvas[
        :, :, dst_y:dst_y + crop.shape[-2], dst_x:dst_x + crop.shape[-1]
    ] = crop

    # torch.roll returns a new tensor, so `canvas` keeps the un-shifted mask.
    # Dims -2/-1 are presumably (time, frequency) -- confirm with caller.
    mask_replace = canvas
    mask_base = torch.roll(canvas, (int(dy), int(dx)), (-2, -1))
    dict_mask = {"base": mask_base[0, 0], "replace": mask_replace[0, 0]}

    mask_base_cur = (
        F.interpolate(
            mask_base,
            (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
        )
        > 0.5
    )
    # Undo the shift on the downsampled grid.
    mask_replace_cur = torch.roll(
        mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1)
    )
    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_contrast": w_contrast,
        "w_content": w_content,
    }
# def process_remove(
# path_mask,
# h,
# w,
# dx,
# dy,
# scale,
# input_scale,
# up_scale,
# up_ft_index,
# w_edit,
# w_contrast,
# w_content,
# precision,
# resize_scale=None,
# ):
# dx, dy = dx * input_scale, dy * input_scale
# if isinstance(path_mask, str):
# mask_base = cv2.imread(path_mask)
# else:
# mask_base = path_mask
# mask_base = mask_base[None, None]
# dict_mask = {}
# mask_base = (mask_base > 0.5).to("cuda", dtype=precision)
# if resize_scale is not None and resize_scale != 1:
# hi, wi = mask_base.shape[-2], mask_base.shape[-1]
# mask_base = F.interpolate(
# mask_base, (int(hi * resize_scale), int(wi * resize_scale))
# )
# pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2
# pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2
# if resize_scale > 1:
# mask_base = mask_base[
# :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi
# ]
# else:
# temp = torch.zeros(1, 1, hi, wi).to(mask_base.device)
# temp[
# :,
# :,
# pad_size_y : pad_size_y + mask_base.shape[-2],
# pad_size_x : pad_size_x + mask_base.shape[-1],
# ] = mask_base
# mask_base = temp
# mask_replace = mask_base.clone()
# mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1)) # (C,T,F)
# dict_mask["base"] = mask_base[0, 0]
# dict_mask["replace"] = mask_replace[0, 0]
# mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision)
# mask_base_cur = (
# F.interpolate(
# mask_base,
# (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
# )
# > 0.5
# )
# mask_replace_cur = torch.roll(
# mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1)
# )
# return {
# "dict_mask": dict_mask,
# "mask_base_cur": mask_base_cur,
# "mask_replace_cur": mask_replace_cur,
# "up_scale": up_scale,
# "up_ft_index": up_ft_index,
# "w_edit": w_edit,
# "w_contrast": w_contrast,
# "w_content": w_content,
# }