import torch
import cupy
import kornia
import torch.nn as nn
from modules.cupy_module.cupy_utils import cupy_launch
# Code taken from https://github.com/ShuhongChen/eisai-anime-interpolator
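# The kernel below is a brute-force 1D pass of a separable squared Euclidean distance
# transform: for each pixel it takes the minimum over its row of data[j] + (pj - j)^2.
# Running the pass once per axis (as batch_edt does) gives the exact squared distance
# to the nearest zero-cost pixel.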
_batch_edt_kernel = ('kernel_dt', '''
extern "C" __global__ void kernel_dt(
const int bs,
const int h,
const int w,
const float diam2,
float* data,
float* output
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= bs*h*w) {
return;
}
int pb = idx / (h*w);
int pi = (idx - h*w*pb) / w;
int pj = (idx - h*w*pb - w*pi);
float cost;
float mincost = diam2;
for (int j = 0; j < w; j++) {
cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
if (cost < mincost) {
mincost = cost;
}
}
output[idx] = mincost;
return;
}
''')
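

# Pure-PyTorch reference for one axis pass of the kernel above. This helper is not
# part of the original module; it is added only as an illustrative sanity check, and
# its name is hypothetical. It computes out[b, i, pj] = min_j data[b, i, j] + (pj - j)**2,
# the same brute-force 1D minimization, at O(w^2) memory per row, so it is only
# practical for small inputs. Applying it along each axis in turn and taking the
# square root reproduces what batch_edt computes on the GPU.
def _edt_pass_reference(data):
    bs, h, w = data.shape
    j = torch.arange(w, device=data.device, dtype=data.dtype)
    dist2 = (j[:, None] - j[None, :]) ** 2  # dist2[pj, j] = (pj - j)**2
    return (data.unsqueeze(-2) + dist2).amin(dim=-1)  # (bs, h, w), indexed by pj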


class NEDT(nn.Module):
    def __init__(self):
        super().__init__()

    def batch_edt(self, img, block=1024):
        # must initialize cuda/cupy after forking
        _batch_edt = cupy_launch(*_batch_edt_kernel)

        # bookkeeping
        if len(img.shape) == 4:
            assert img.shape[1] == 1
            img = img.squeeze(1)
            expand = True
        else:
            expand = False
        bs, h, w = img.shape
        diam2 = h**2 + w**2  # squared image diagonal: upper bound on any squared distance
        odtype = img.dtype
        grid = (img.nelement() + block - 1) // block

        # first pass, y-axis
        # background pixels start at "infinity" (diam2), foreground pixels at 0
        data = ((1 - img.type(torch.float32)) * diam2).contiguous()
        intermed = torch.zeros_like(data)
        _batch_edt(
            grid=(grid, 1, 1),
            block=(block, 1, 1),  # CUDA limit: at most 1024 threads per block
            args=[
                cupy.int32(bs),
                cupy.int32(h),
                cupy.int32(w),
                cupy.float32(diam2),
                data.data_ptr(),
                intermed.data_ptr(),
            ],
        )

        # second pass, x-axis: transpose and minimize along the other axis
        intermed = intermed.permute(0, 2, 1).contiguous()
        out = torch.zeros_like(intermed)
        _batch_edt(
            grid=(grid, 1, 1),
            block=(block, 1, 1),
            args=[
                cupy.int32(bs),
                cupy.int32(w),
                cupy.int32(h),
                cupy.float32(diam2),
                intermed.data_ptr(),
                out.data_ptr(),
            ],
        )

        ans = out.permute(0, 2, 1).sqrt()
        ans = ans.type(odtype) if odtype != ans.dtype else ans
        if expand:
            ans = ans.unsqueeze(1)
        return ans

    def batch_dog(self, img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
        # to grayscale if needed
        bs, ch, h, w = img.shape
        if ch in [3, 4]:
            img = kornia.color.rgb_to_grayscale(img[:, :3])
        else:
            assert ch == 1

        # calculate dog; kernel sizes are forced to be odd and at least 3
        kern0 = max(2*int(sigma*kernel_factor)+1, 3)
        kern1 = max(2*int(sigma*k*kernel_factor)+1, 3)
        g0 = kornia.filters.gaussian_blur2d(
            img, (kern0, kern0), (sigma, sigma), border_type='replicate',
        )
        g1 = kornia.filters.gaussian_blur2d(
            img, (kern1, kern1), (sigma*k, sigma*k), border_type='replicate',
        )
        # center the response at 0.5; pixels where t*(g1 - g0) exceeds epsilon land above 0.5
        out = 0.5 + t*(g1 - g0) - epsilon
        out = out.clip(0, 1) if clip else out
        return out

    def forward(
        self, img, t=2.0, sigma_factor=1/540,
        k=1.6, epsilon=0.01,
        kernel_factor=4, exp_factor=540/15,
    ):
        # soft difference-of-gaussians edge response, with sigma scaled to the input height
        dog = self.batch_dog(
            img, t=t, sigma=img.shape[-2]*sigma_factor, k=k,
            epsilon=epsilon, kernel_factor=kernel_factor, clip=False,
        )
        # euclidean distance transform of the thresholded edge map
        edt = self.batch_edt((dog > 0.5).float())
        # exponential falloff, normalized by the longer spatial dimension
        out = 1 - (-edt*exp_factor / max(edt.shape[-2:])).exp()
        return out
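

# A minimal usage sketch, not part of the original module: it assumes a CUDA-capable
# device with cupy installed, since batch_edt launches a raw CUDA kernel on tensor
# pointers. The input resolution and random tensor are illustrative only.
if __name__ == '__main__':
    if torch.cuda.is_available():
        nedt = NEDT().cuda()
        img = torch.rand(1, 3, 540, 960, device='cuda')  # (bs, ch, h, w) in [0, 1]
        with torch.no_grad():
            ndt = nedt(img)
        print(ndt.shape, ndt.min().item(), ndt.max().item())  # expected shape: (1, 1, 540, 960)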