import cv2 import torch import numpy as np import torch.nn.functional as F def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None): h, w = image.shape[:2] w_org = image.shape[1] if resize_short_edge is not None: k = resize_short_edge / min(h, w) else: k = max_resolution / (h * w) k = k**0.5 h = int(np.round(h * k / 64)) * 64 w = int(np.round(w * k / 64)) * 64 image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) scale = w / w_org return image, scale def split_ldm(ldm): x = [] y = [] for p in ldm: x.append(p[0]) y.append(p[1]) return x, y def process_move( path_mask, # target region of original map h, w, dx, dy, scale, input_scale, resize_scale_x, resize_scale_y, up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint, precision, path_mask_ref=None, path_mask_keep=None, ): dx, dy = dx * input_scale, dy * input_scale mask_x0 = path_mask mask_x0_ref = path_mask_ref mask_x0_keep = path_mask_keep mask_x0 = (mask_x0 > 0.5).float().to("cuda", dtype=precision) if mask_x0_ref is not None: mask_x0_ref = (mask_x0_ref > 0.5).float().to("cuda", dtype=precision) # Define region to keep if `path_mask_keep` is given if mask_x0_keep is not None: mask_x0_keep = (mask_x0_keep > 0.5).float().to("cuda", dtype=precision) mask_keep = ( F.interpolate( mask_x0_keep[None, None], (int(mask_x0_keep.shape[-2] // scale), int(mask_x0_keep.shape[-1] // scale)), ) > 0.5 ).float() else: mask_keep = None mask_org = ( F.interpolate( mask_x0[None, None], (int(mask_x0.shape[-2] // scale), int(mask_x0.shape[-1] // scale)), ) > 0.5 ) mask_tar = ( F.interpolate( mask_x0[None, None], ( int(mask_x0.shape[-2] // scale * resize_scale_y), int(mask_x0.shape[-1] // scale * resize_scale_x), ), ) > 0.5 ) mask_cur = torch.roll( mask_tar, (int(dy // scale * resize_scale_y), int(dx // scale * resize_scale_x)), (-2, -1), ) temp = torch.zeros(1, 1, mask_org.shape[-2], mask_org.shape[-1]).to( mask_org.device ) pad_x = (temp.shape[-1] - mask_cur.shape[-1]) // 2 pad_y = (temp.shape[-2] - mask_cur.shape[-2]) // 2 px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0) px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0) temp[:,:,py_tmp:py_tmp+mask_cur.shape[-2],px_tmp:px_tmp+mask_cur.shape[-1]] = mask_cur[ :,:,py_tar:py_tar+temp.shape[-2],px_tar:px_tar+temp.shape[-1]] # To avoid mask misaligned by shifting and cropping _mask_valid = torch.zeros_like(mask_cur) _mask_valid[:,:,py_tar:py_tar+temp.shape[-2],px_tar:px_tar+temp.shape[-1]] = 1 _mask_valid = (torch.roll( _mask_valid, (-int(dy // scale * resize_scale_y), int(-dx // scale * resize_scale_x)), (-2, -1), ) > 0.5) mask_tar = torch.logical_and(mask_tar, _mask_valid) # Ensure the editing region is within the spectrogram if resize_scale_x > 1 or resize_scale_y > 1: sum_before = torch.sum(mask_tar) # replace `mask_cur` here sum_after = torch.sum(temp) if sum_after != sum_before: raise ValueError("Resize out of bounds, exiting.") mask_cur = temp > 0.5 # Region of uninterested region is selected region when `mask_keep` is given if mask_keep is not None: mask_other = mask_keep > 0.5 else: mask_other = (1 - ((mask_cur + mask_org) > 0.5).float()) > 0.5 mask_overlap = ((mask_cur.float() + mask_org.float()) > 1.5).float() mask_non_overlap = (mask_org.float() - mask_overlap) > 0.5 return { "mask_x0": mask_x0, "mask_x0_ref": mask_x0_ref, "mask_x0_keep": mask_x0_keep, "mask_tar": mask_tar, "mask_cur": mask_cur, "mask_other": mask_other, "mask_overlap": mask_overlap, "mask_non_overlap": mask_non_overlap, "mask_keep": mask_keep, "up_scale": up_scale, "up_ft_index": up_ft_index, "resize_scale_x": resize_scale_x, "resize_scale_y": resize_scale_y, "w_edit": w_edit, "w_content": w_content, "w_contrast": w_contrast, "w_inpaint": w_inpaint, } def process_paste( path_mask, h, w, dx, dy, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision, resize_scale_x, resize_scale_y, ): dx, dy = dx * input_scale, dy * input_scale if isinstance(path_mask, str): mask_base = cv2.imread(path_mask) else: mask_base = path_mask mask_base = mask_base[None, None] dict_mask = {} mask_base = (mask_base > 0.5).to("cuda", dtype=precision) #####[START] Original rescale and fit method.##### # if resize_scale is not None and resize_scale != 1: # hi, wi = mask_base.shape[-2], mask_base.shape[-1] # mask_base = F.interpolate( # mask_base, (int(hi * resize_scale), int(wi * resize_scale)) # ) # pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2 # pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2 # if resize_scale > 1: # mask_base = mask_base[ # :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi # ] # else: # temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) # temp[ # :, # :, # pad_size_y : pad_size_y + mask_base.shape[-2], # pad_size_x : pad_size_x + mask_base.shape[-1], # ] = mask_base # mask_base = temp #####[END] Original rescale and fit method.##### hi, wi = mask_base.shape[-2], mask_base.shape[-1] mask_base = F.interpolate( mask_base, (int(hi*resize_scale_y), int(wi*resize_scale_x)) ) temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) pad_x, pad_y = (wi - mask_base.shape[-1]) // 2, (hi - mask_base.shape[-2]) // 2 px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0) px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0) temp[:,:,py_tmp:py_tmp+mask_base.shape[-2],px_tmp:px_tmp+mask_base.shape[-1]] = mask_base[ :,:,py_tar:py_tar+temp.shape[-2],px_tar:px_tar+temp.shape[-1]] mask_base = temp mask_replace = mask_base.clone() mask_base = torch.roll( mask_base, (int(dy*resize_scale_y), int(dx*resize_scale_x)), (-2, -1)) # (C,T,F) dict_mask["base"] = mask_base[0, 0] dict_mask["replace"] = mask_replace[0, 0] mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision) mask_base_cur = ( F.interpolate( mask_base, (int(mask_base.shape[-2]//scale), int(mask_base.shape[-1]//scale)), ) > 0.5 ) mask_replace_cur = torch.roll( mask_base_cur, (-int(dy/scale), -int(dx/scale)), (-2, -1) ) return { "dict_mask": dict_mask, "mask_base_cur": mask_base_cur, "mask_replace_cur": mask_replace_cur, "up_scale": up_scale, "up_ft_index": up_ft_index, "w_edit": w_edit, "w_content": w_content, "w_edit": w_edit, "w_content": w_content, } # def process_paste( # path_mask, # h, # w, # dx, # dy, # scale, # input_scale, # up_scale, # up_ft_index, # w_edit, # w_content, # precision, # resize_scale=None, # ): # dx, dy = dx * input_scale, dy * input_scale # if isinstance(path_mask, str): # mask_base = cv2.imread(path_mask) # else: # mask_base = path_mask # mask_base = mask_base[None, None] # dict_mask = {} # mask_base = (mask_base > 0.5).to("cuda", dtype=precision) # if resize_scale is not None and resize_scale != 1: # hi, wi = mask_base.shape[-2], mask_base.shape[-1] # mask_base = F.interpolate( # mask_base, (int(hi * resize_scale), int(wi * resize_scale)) # ) # pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2 # pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2 # if resize_scale > 1: # mask_base = mask_base[ # :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi # ] # else: # temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) # temp[ # :, # :, # pad_size_y : pad_size_y + mask_base.shape[-2], # pad_size_x : pad_size_x + mask_base.shape[-1], # ] = mask_base # mask_base = temp # mask_replace = mask_base.clone() # mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1)) # (C,T,F) # dict_mask["base"] = mask_base[0, 0] # dict_mask["replace"] = mask_replace[0, 0] # mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision) # mask_base_cur = ( # F.interpolate( # mask_base, # (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)), # ) # > 0.5 # ) # mask_replace_cur = torch.roll( # mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1) # ) # return { # "dict_mask": dict_mask, # "mask_base_cur": mask_base_cur, # "mask_replace_cur": mask_replace_cur, # "up_scale": up_scale, # "up_ft_index": up_ft_index, # "w_edit": w_edit, # "w_content": w_content, # "w_edit": w_edit, # "w_content": w_content, # } def process_remove( path_mask, h, w, dx, dy, scale, input_scale, up_scale, up_ft_index, w_edit, w_contrast, w_content, precision, resize_scale_x, resize_scale_y, ): dx, dy = dx * input_scale, dy * input_scale if isinstance(path_mask, str): mask_base = cv2.imread(path_mask) else: mask_base = path_mask mask_base = mask_base[None, None] dict_mask = {} mask_base = (mask_base > 0.5).to("cuda", dtype=precision) #####[START] Original rescale and fit method.##### # if resize_scale is not None and resize_scale != 1: # hi, wi = mask_base.shape[-2], mask_base.shape[-1] # mask_base = F.interpolate( # mask_base, (int(hi * resize_scale), int(wi * resize_scale)) # ) # pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2 # pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2 # if resize_scale > 1: # mask_base = mask_base[ # :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi # ] # else: # temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) # temp[ # :, # :, # pad_size_y : pad_size_y + mask_base.shape[-2], # pad_size_x : pad_size_x + mask_base.shape[-1], # ] = mask_base # mask_base = temp #####[END] Original rescale and fit method.##### hi, wi = mask_base.shape[-2], mask_base.shape[-1] mask_base = F.interpolate( mask_base, (int(hi*resize_scale_y), int(wi*resize_scale_x)) ) temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) pad_x, pad_y = (wi - mask_base.shape[-1]) // 2, (hi - mask_base.shape[-2]) // 2 px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0) px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0) temp[:,:,py_tmp:py_tmp+mask_base.shape[-2],px_tmp:px_tmp+mask_base.shape[-1]] = mask_base[ :,:,py_tar:py_tar+temp.shape[-2],px_tar:px_tar+temp.shape[-1]] mask_base = temp mask_replace = mask_base.clone() mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1)) # (C,T,F) dict_mask["base"] = mask_base[0, 0] dict_mask["replace"] = mask_replace[0, 0] mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision) mask_base_cur = ( F.interpolate( mask_base, (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)), ) > 0.5 ) mask_replace_cur = torch.roll( mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1) ) return { "dict_mask": dict_mask, "mask_base_cur": mask_base_cur, "mask_replace_cur": mask_replace_cur, "up_scale": up_scale, "up_ft_index": up_ft_index, "w_edit": w_edit, "w_contrast": w_contrast, "w_content": w_content, } # def process_remove( # path_mask, # h, # w, # dx, # dy, # scale, # input_scale, # up_scale, # up_ft_index, # w_edit, # w_contrast, # w_content, # precision, # resize_scale=None, # ): # dx, dy = dx * input_scale, dy * input_scale # if isinstance(path_mask, str): # mask_base = cv2.imread(path_mask) # else: # mask_base = path_mask # mask_base = mask_base[None, None] # dict_mask = {} # mask_base = (mask_base > 0.5).to("cuda", dtype=precision) # if resize_scale is not None and resize_scale != 1: # hi, wi = mask_base.shape[-2], mask_base.shape[-1] # mask_base = F.interpolate( # mask_base, (int(hi * resize_scale), int(wi * resize_scale)) # ) # pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2 # pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2 # if resize_scale > 1: # mask_base = mask_base[ # :, :, pad_size_y : pad_size_y + hi, pad_size_x : pad_size_x + wi # ] # else: # temp = torch.zeros(1, 1, hi, wi).to(mask_base.device) # temp[ # :, # :, # pad_size_y : pad_size_y + mask_base.shape[-2], # pad_size_x : pad_size_x + mask_base.shape[-1], # ] = mask_base # mask_base = temp # mask_replace = mask_base.clone() # mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1)) # (C,T,F) # dict_mask["base"] = mask_base[0, 0] # dict_mask["replace"] = mask_replace[0, 0] # mask_replace = (mask_replace > 0.5).to("cuda", dtype=precision) # mask_base_cur = ( # F.interpolate( # mask_base, # (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)), # ) # > 0.5 # ) # mask_replace_cur = torch.roll( # mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1) # ) # return { # "dict_mask": dict_mask, # "mask_base_cur": mask_base_cur, # "mask_replace_cur": mask_replace_cur, # "up_scale": up_scale, # "up_ft_index": up_ft_index, # "w_edit": w_edit, # "w_contrast": w_contrast, # "w_content": w_content, # }