Spaces:
Running
on
Zero
Running
on
Zero
import cv2 | |
import torch | |
import numpy as np | |
import torch.nn.functional as F | |
def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None):
    """Resize an image so that both sides become multiples of 64.

    Args:
        image: Array whose first two axes are (height, width).
        max_resolution: Target pixel count (height * width) used when
            ``resize_short_edge`` is None.
        resize_short_edge: If given, the scaling factor is chosen so that the
            shorter side maps to this length (before snapping to 64).

    Returns:
        A tuple ``(resized_image, scale)`` where ``scale`` is the ratio of the
        new width to the original width.
    """
    src_h, src_w = image.shape[:2]

    if resize_short_edge is None:
        # Match the requested area: the linear factor is the square root of
        # the area ratio.
        factor = (max_resolution / (src_h * src_w)) ** 0.5
    else:
        factor = resize_short_edge / min(src_h, src_w)

    # Snap each side to the nearest multiple of 64.
    dst_h = int(np.round(src_h * factor / 64)) * 64
    dst_w = int(np.round(src_w * factor / 64)) * 64

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image, (dst_w, dst_h), interpolation=cv2.INTER_LANCZOS4)
    return resized, dst_w / src_w
def split_ldm(ldm):
    """Split a sequence of points into separate coordinate lists.

    Args:
        ldm: Iterable of indexable items; element 0 and element 1 of each item
            are collected.

    Returns:
        A tuple ``(x, y)`` of two lists holding element 0 and element 1 of
        every item, in input order.
    """
    # Single pass over the input so one-shot iterables are handled as before.
    pairs = [(point[0], point[1]) for point in ldm]
    xs = [first for first, _ in pairs]
    ys = [second for _, second in pairs]
    return xs, ys
def process_move(
    path_mask,  # target region of original map
    h,
    w,
    dx,
    dy,
    scale,
    input_scale,
    resize_scale_x,
    resize_scale_y,
    up_scale,
    up_ft_index,
    w_edit,
    w_content,
    w_contrast,
    w_inpaint,
    precision,
    path_mask_ref=None,
    path_mask_keep=None,
    device="cuda",
):
    """Build the masks needed to move (and optionally resize) a region.

    The input mask is binarised at 0.5, down-sampled by ``scale`` to the
    feature resolution, resized by ``resize_scale_x`` / ``resize_scale_y``,
    shifted by the (scaled) offset and finally centre-padded or centre-cropped
    back to the feature resolution.

    Args:
        path_mask: Mask tensor of the region to move. It is indexed with
            ``[None, None]`` before interpolation, so it is presumably 2-D
            (rows, cols) -- TODO confirm against callers.
        h, w: Unused; kept for interface compatibility.
        dx, dy: Offset along the last / second-to-last axis, multiplied by
            ``input_scale`` before use.
        scale: Down-sampling factor between mask and feature resolution.
        input_scale: Multiplier applied to ``dx`` and ``dy``.
        resize_scale_x, resize_scale_y: Resize factors of the moved region
            along the last / second-to-last axis.
        up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint:
            Passed through unchanged into the returned dict.
        precision: dtype the binarised input masks are cast to.
        path_mask_ref: Optional reference mask; only binarised and returned.
        path_mask_keep: Optional mask; when given, its down-sampled version
            is used as ``mask_other`` instead of the complement of the
            source and destination regions.
        device: Device the masks are placed on. Defaults to ``"cuda"``, the
            previously hard-coded value.

    Returns:
        dict with the masks (``mask_x0``, ``mask_x0_ref``, ``mask_x0_keep``,
        ``mask_tar``, ``mask_cur``, ``mask_other``, ``mask_overlap``,
        ``mask_non_overlap``, ``mask_keep``) and the pass-through settings.

    Raises:
        ValueError: If, when up-scaling, the pixel count of the target mask
            and of the fitted destination mask disagree.
    """
    dx, dy = dx * input_scale, dy * input_scale

    def _binarize(mask):
        # Threshold at 0.5, then move to the working device / dtype.
        return (mask > 0.5).float().to(device, dtype=precision)

    mask_x0 = _binarize(path_mask)
    mask_x0_ref = None if path_mask_ref is None else _binarize(path_mask_ref)

    # Define region to keep if `path_mask_keep` is given
    if path_mask_keep is not None:
        mask_x0_keep = _binarize(path_mask_keep)
        mask_keep = (
            F.interpolate(
                mask_x0_keep[None, None],
                (
                    int(mask_x0_keep.shape[-2] // scale),
                    int(mask_x0_keep.shape[-1] // scale),
                ),
            )
            > 0.5
        ).float()
    else:
        mask_x0_keep = None
        mask_keep = None

    # Source region at feature resolution.
    mask_org = (
        F.interpolate(
            mask_x0[None, None],
            (int(mask_x0.shape[-2] // scale), int(mask_x0.shape[-1] // scale)),
        )
        > 0.5
    )
    # Source region at feature resolution, additionally resized.
    mask_tar = (
        F.interpolate(
            mask_x0[None, None],
            (
                int(mask_x0.shape[-2] // scale * resize_scale_y),
                int(mask_x0.shape[-1] // scale * resize_scale_x),
            ),
        )
        > 0.5
    )

    # Compute the shift once so that the forward roll and its inverse below
    # are exact opposites. (Negating *before* the floor division, as the
    # previous x-axis code did, rounds the other way and is off by one
    # whenever dx is not a multiple of `scale`.)
    shift_y = int(dy // scale * resize_scale_y)
    shift_x = int(dx // scale * resize_scale_x)
    mask_cur = torch.roll(mask_tar, (shift_y, shift_x), (-2, -1))

    # Centre-pad (resize < 1) or centre-crop (resize > 1) the shifted mask
    # back onto a canvas with the feature resolution.
    temp = torch.zeros(1, 1, mask_org.shape[-2], mask_org.shape[-1]).to(
        mask_org.device
    )
    pad_x = (temp.shape[-1] - mask_cur.shape[-1]) // 2
    pad_y = (temp.shape[-2] - mask_cur.shape[-2]) // 2
    px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
    px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
    temp[
        :,
        :,
        py_tmp : py_tmp + mask_cur.shape[-2],
        px_tmp : px_tmp + mask_cur.shape[-1],
    ] = mask_cur[
        :,
        :,
        py_tar : py_tar + temp.shape[-2],
        px_tar : px_tar + temp.shape[-1],
    ]

    # To avoid mask misaligned by shifting and cropping: mark the window of
    # the shifted mask that survives the crop, roll it back by the exact
    # inverse shift, and keep only that part of the target mask.
    _mask_valid = torch.zeros_like(mask_cur)
    _mask_valid[
        :,
        :,
        py_tar : py_tar + temp.shape[-2],
        px_tar : px_tar + temp.shape[-1],
    ] = True
    _mask_valid = torch.roll(_mask_valid, (-shift_y, -shift_x), (-2, -1))
    mask_tar = torch.logical_and(mask_tar, _mask_valid)

    # Ensure the editing region is within the spectrogram
    # NOTE(review): with matching forward/inverse shifts these two sums are
    # equal by construction; the check is kept as a guard.
    if resize_scale_x > 1 or resize_scale_y > 1:
        sum_before = torch.sum(mask_tar)
        sum_after = torch.sum(temp)
        if sum_after != sum_before:
            raise ValueError("Resize out of bounds, exiting.")

    mask_cur = temp > 0.5

    # Region of uninterested region is selected region when `mask_keep` is given
    if mask_keep is not None:
        mask_other = mask_keep > 0.5
    else:
        # Everything that is neither source nor destination.
        mask_other = (1 - ((mask_cur + mask_org) > 0.5).float()) > 0.5
    mask_overlap = ((mask_cur.float() + mask_org.float()) > 1.5).float()
    mask_non_overlap = (mask_org.float() - mask_overlap) > 0.5

    return {
        "mask_x0": mask_x0,
        "mask_x0_ref": mask_x0_ref,
        "mask_x0_keep": mask_x0_keep,
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "mask_other": mask_other,
        "mask_overlap": mask_overlap,
        "mask_non_overlap": mask_non_overlap,
        "mask_keep": mask_keep,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "resize_scale_x": resize_scale_x,
        "resize_scale_y": resize_scale_y,
        "w_edit": w_edit,
        "w_content": w_content,
        "w_contrast": w_contrast,
        "w_inpaint": w_inpaint,
    }
def process_paste(
    path_mask,
    h,
    w,
    dx,
    dy,
    scale,
    input_scale,
    up_scale,
    up_ft_index,
    w_edit,
    w_content,
    precision,
    resize_scale_x,
    resize_scale_y,
    device="cuda",
):
    """Build the masks for pasting a (resized, shifted) region.

    Args:
        path_mask: Either a mask tensor or a path to a mask image. The mask
            is indexed with ``[None, None]``, so a tensor is presumably 2-D
            (rows, cols) -- TODO confirm against callers.
        h, w: Unused; kept for interface compatibility.
        dx, dy: Offset along the last / second-to-last axis, multiplied by
            ``input_scale`` before use.
        scale: Down-sampling factor between mask and feature resolution.
        input_scale: Multiplier applied to ``dx`` and ``dy``.
        up_scale, up_ft_index, w_edit, w_content: Passed through unchanged.
        precision: dtype the binarised mask is cast to.
        resize_scale_x, resize_scale_y: Resize factors along the last /
            second-to-last axis.
        device: Device the masks are placed on. Defaults to ``"cuda"``, the
            previously hard-coded value.

    Returns:
        dict with ``dict_mask`` (full-resolution ``"base"`` = shifted mask and
        ``"replace"`` = unshifted mask), ``mask_base_cur`` and
        ``mask_replace_cur`` (boolean masks at feature resolution) and the
        pass-through settings.

    Raises:
        FileNotFoundError: If ``path_mask`` is a string and the image cannot
            be read.
    """
    dx, dy = dx * input_scale, dy * input_scale

    if isinstance(path_mask, str):
        # cv2.imread yields a NumPy array (or None on failure); load a single
        # channel and convert so the tensor operations below work.
        mask_img = cv2.imread(path_mask, cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            raise FileNotFoundError(f"Cannot read mask image: {path_mask}")
        mask_base = torch.from_numpy(mask_img)
    else:
        mask_base = path_mask
    mask_base = (mask_base[None, None] > 0.5).to(device, dtype=precision)

    # Resize the mask, then centre-pad (resize < 1) or centre-crop
    # (resize > 1) it back onto a canvas of the original size.
    hi, wi = mask_base.shape[-2], mask_base.shape[-1]
    mask_base = F.interpolate(
        mask_base, (int(hi * resize_scale_y), int(wi * resize_scale_x))
    )
    temp = torch.zeros(1, 1, hi, wi).to(mask_base.device)
    pad_x = (wi - mask_base.shape[-1]) // 2
    pad_y = (hi - mask_base.shape[-2]) // 2
    px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
    px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
    temp[
        :,
        :,
        py_tmp : py_tmp + mask_base.shape[-2],
        px_tmp : px_tmp + mask_base.shape[-1],
    ] = mask_base[
        :,
        :,
        py_tar : py_tar + temp.shape[-2],
        px_tar : px_tar + temp.shape[-1],
    ]
    mask_base = temp

    # "replace" is the unshifted mask, "base" the shifted one.
    mask_replace = mask_base.clone()
    mask_base = torch.roll(
        mask_base, (int(dy * resize_scale_y), int(dx * resize_scale_x)), (-2, -1)
    )  # (C,T,F)
    dict_mask = {"base": mask_base[0, 0], "replace": mask_replace[0, 0]}

    mask_base_cur = (
        F.interpolate(
            mask_base,
            (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
        )
        > 0.5
    )
    # NOTE(review): the forward shift above is multiplied by resize_scale_*,
    # this roll-back is not; the two only cancel when the resize scales are 1.
    # Confirm which behaviour downstream code expects.
    mask_replace_cur = torch.roll(
        mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1)
    )

    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_content": w_content,
    }
# def process_paste( | |
# path_mask, | |
# h, | |
# w, | |
# dx, | |
# dy, | |
# scale, | |
# input_scale, | |
# up_scale, | |
# up_ft_index, | |
# w_edit, | |
# w_content, | |
# precision, | |
# resize_scale=None, | |
# ): | |
# dx, dy = dx * input_scale, dy * input_scale | |
# if isinstance(path_mask, str): | |
# mask_base = cv2.imread(path_mask) | |
# else: | |
# mask_base = path_mask | |
# mask_base = mask_base[None, None] | |
# dict_mask = {} | |
# mask_base = (mask_base > 0.5).to("cuda", dtype=precision) | |
# if resize_scale is not None and resize_scale != 1: | |
# hi, wi = mask_base.shape[-2], mask_base.shape[-1] | |
# mask_base = F.interpolate( | |
# mask_base, (int(hi * resize_scale), int(wi * resize_scale)) | |
# ) | |
# pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2 | |
# pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2 | |
# if resize_scale > 1: | |
# mask_base = mask_base[ | |
# :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi | |
# ] | |
# else: | |
# temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) | |
# temp[ | |
# :, | |
# :, | |
# pad_size_y : pad_size_y + mask_base.shape[-2], | |
# pad_size_x : pad_size_x + mask_base.shape[-1], | |
# ] = mask_base | |
# mask_base = temp | |
# mask_replace = mask_base.clone() | |
# mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1)) # (C,T,F) | |
# dict_mask["base"] = mask_base[0, 0] | |
# dict_mask["replace"] = mask_replace[0, 0] | |
# mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision) | |
# mask_base_cur = ( | |
# F.interpolate( | |
# mask_base, | |
# (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)), | |
# ) | |
# > 0.5 | |
# ) | |
# mask_replace_cur = torch.roll( | |
# mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1) | |
# ) | |
# return { | |
# "dict_mask": dict_mask, | |
# "mask_base_cur": mask_base_cur, | |
# "mask_replace_cur": mask_replace_cur, | |
# "up_scale": up_scale, | |
# "up_ft_index": up_ft_index, | |
# "w_edit": w_edit, | |
# "w_content": w_content, | |
# "w_edit": w_edit, | |
# "w_content": w_content, | |
# } | |
def process_remove(
    path_mask,
    h,
    w,
    dx,
    dy,
    scale,
    input_scale,
    up_scale,
    up_ft_index,
    w_edit,
    w_contrast,
    w_content,
    precision,
    resize_scale_x,
    resize_scale_y,
    device="cuda",
):
    """Build the masks for removing a (resized, shifted) region.

    Args:
        path_mask: Either a mask tensor or a path to a mask image. The mask
            is indexed with ``[None, None]``, so a tensor is presumably 2-D
            (rows, cols) -- TODO confirm against callers.
        h, w: Unused; kept for interface compatibility.
        dx, dy: Offset along the last / second-to-last axis, multiplied by
            ``input_scale`` before use.
        scale: Down-sampling factor between mask and feature resolution.
        input_scale: Multiplier applied to ``dx`` and ``dy``.
        up_scale, up_ft_index, w_edit, w_contrast, w_content: Passed through
            unchanged.
        precision: dtype the binarised mask is cast to.
        resize_scale_x, resize_scale_y: Resize factors along the last /
            second-to-last axis.
        device: Device the masks are placed on. Defaults to ``"cuda"``, the
            previously hard-coded value.

    Returns:
        dict with ``dict_mask`` (full-resolution ``"base"`` = shifted mask and
        ``"replace"`` = unshifted mask), ``mask_base_cur`` and
        ``mask_replace_cur`` (boolean masks at feature resolution) and the
        pass-through settings.

    Raises:
        FileNotFoundError: If ``path_mask`` is a string and the image cannot
            be read.
    """
    dx, dy = dx * input_scale, dy * input_scale

    if isinstance(path_mask, str):
        # cv2.imread yields a NumPy array (or None on failure); load a single
        # channel and convert so the tensor operations below work.
        mask_img = cv2.imread(path_mask, cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            raise FileNotFoundError(f"Cannot read mask image: {path_mask}")
        mask_base = torch.from_numpy(mask_img)
    else:
        mask_base = path_mask
    mask_base = (mask_base[None, None] > 0.5).to(device, dtype=precision)

    # Resize the mask, then centre-pad (resize < 1) or centre-crop
    # (resize > 1) it back onto a canvas of the original size.
    hi, wi = mask_base.shape[-2], mask_base.shape[-1]
    mask_base = F.interpolate(
        mask_base, (int(hi * resize_scale_y), int(wi * resize_scale_x))
    )
    temp = torch.zeros(1, 1, hi, wi).to(mask_base.device)
    pad_x = (wi - mask_base.shape[-1]) // 2
    pad_y = (hi - mask_base.shape[-2]) // 2
    px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
    px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
    temp[
        :,
        :,
        py_tmp : py_tmp + mask_base.shape[-2],
        px_tmp : px_tmp + mask_base.shape[-1],
    ] = mask_base[
        :,
        :,
        py_tar : py_tar + temp.shape[-2],
        px_tar : px_tar + temp.shape[-1],
    ]
    mask_base = temp

    # "replace" is the unshifted mask, "base" the shifted one.
    mask_replace = mask_base.clone()
    mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1))  # (C,T,F)
    dict_mask = {"base": mask_base[0, 0], "replace": mask_replace[0, 0]}

    mask_base_cur = (
        F.interpolate(
            mask_base,
            (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
        )
        > 0.5
    )
    # Undo the shift at feature resolution.
    mask_replace_cur = torch.roll(
        mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1)
    )

    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_contrast": w_contrast,
        "w_content": w_content,
    }
# def process_remove( | |
# path_mask, | |
# h, | |
# w, | |
# dx, | |
# dy, | |
# scale, | |
# input_scale, | |
# up_scale, | |
# up_ft_index, | |
# w_edit, | |
# w_contrast, | |
# w_content, | |
# precision, | |
# resize_scale=None, | |
# ): | |
# dx, dy = dx * input_scale, dy * input_scale | |
# if isinstance(path_mask, str): | |
# mask_base = cv2.imread(path_mask) | |
# else: | |
# mask_base = path_mask | |
# mask_base = mask_base[None, None] | |
# dict_mask = {} | |
# mask_base = (mask_base > 0.5).to("cuda", dtype=precision) | |
# if resize_scale is not None and resize_scale != 1: | |
# hi, wi = mask_base.shape[-2], mask_base.shape[-1] | |
# mask_base = F.interpolate( | |
# mask_base, (int(hi * resize_scale), int(wi * resize_scale)) | |
# ) | |
# pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2 | |
# pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2 | |
# if resize_scale > 1: | |
# mask_base = mask_base[ | |
# :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi | |
# ] | |
# else: | |
# temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) | |
# temp[ | |
# :, | |
# :, | |
# pad_size_y : pad_size_y + mask_base.shape[-2], | |
# pad_size_x : pad_size_x + mask_base.shape[-1], | |
# ] = mask_base | |
# mask_base = temp | |
# mask_replace = mask_base.clone() | |
# mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1)) # (C,T,F) | |
# dict_mask["base"] = mask_base[0, 0] | |
# dict_mask["replace"] = mask_replace[0, 0] | |
# mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision) | |
# mask_base_cur = ( | |
# F.interpolate( | |
# mask_base, | |
# (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)), | |
# ) | |
# > 0.5 | |
# ) | |
# mask_replace_cur = torch.roll( | |
# mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1) | |
# ) | |
# return { | |
# "dict_mask": dict_mask, | |
# "mask_base_cur": mask_base_cur, | |
# "mask_replace_cur": mask_replace_cur, | |
# "up_scale": up_scale, | |
# "up_ft_index": up_ft_index, | |
# "w_edit": w_edit, | |
# "w_contrast": w_contrast, | |
# "w_content": w_content, | |
# } |