import cupy
import kornia
import torch
import torch.nn as nn

from modules.cupy_module.cupy_utils import cupy_launch
_batch_edt_kernel = ('kernel_dt', ''' |
|
extern "C" __global__ void kernel_dt( |
|
const int bs, |
|
const int h, |
|
const int w, |
|
const float diam2, |
|
float* data, |
|
float* output |
|
) { |
|
int idx = blockIdx.x * blockDim.x + threadIdx.x; |
|
if (idx >= bs*h*w) { |
|
return; |
|
} |
|
int pb = idx / (h*w); |
|
int pi = (idx - h*w*pb) / w; |
|
int pj = (idx - h*w*pb - w*pi); |
|
|
|
float cost; |
|
float mincost = diam2; |
|
for (int j = 0; j < w; j++) { |
|
cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j); |
|
if (cost < mincost) { |
|
mincost = cost; |
|
} |
|
} |
|
output[idx] = mincost; |
|
return; |
|
} |
|
''') |
|
|
|
class NEDT(nn.Module):
    """Distance-transform map built from difference-of-Gaussians responses.

    ``forward`` thresholds a DoG response at 0.5, takes the Euclidean distance
    transform of that mask with a custom CUDA kernel, and squashes the
    distances into ``[0, 1)`` with ``1 - exp(-d * exp_factor / max(h, w))``.
    The module holds no parameters or buffers.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _edt_pass(launch, data, bs, h, w, diam2, block):
        """Run one kernel pass over the last axis of a contiguous float32 volume."""
        result = torch.zeros_like(data)
        grid = (data.nelement() + block - 1) // block
        launch(
            grid=(grid, 1, 1),
            block=(block, 1, 1),
            args=[
                cupy.int32(bs),
                cupy.int32(h),
                cupy.int32(w),
                cupy.float32(diam2),
                data.data_ptr(),
                result.data_ptr(),
            ],
        )
        return result

    def batch_edt(self, img, block=1024):
        """Euclidean distance transform of a batch of masks on the GPU.

        For a binary mask every output pixel holds the distance to the nearest
        pixel whose value is 1 (so pixels equal to 1 map to 0).  Distances are
        capped at ``sqrt(h**2 + w**2)``, which is also the value returned
        everywhere for a mask that contains no 1-pixels.

        Args:
            img: CUDA tensor of shape ``(bs, h, w)`` or ``(bs, 1, h, w)``.
            block: Threads per CUDA block used for the kernel launches.

        Returns:
            Tensor with the same shape and dtype as ``img``.

        Raises:
            AssertionError: ``img`` is 4-D but its channel dimension is not 1.
            ValueError: ``img`` is non-empty and does not live on a CUDA device.
        """
        if img.dim() == 4:
            assert img.shape[1] == 1, 'batch_edt expects a single-channel image'
            img = img.squeeze(1)
            expand = True
        else:
            expand = False
        bs, h, w = img.shape
        odtype = img.dtype

        if img.nelement() == 0:
            # A zero-sized grid is not a valid launch configuration, and there
            # is nothing to compute anyway.
            ans = torch.zeros_like(img)
        else:
            if not img.is_cuda:
                # The kernel dereferences raw device pointers.
                raise ValueError('batch_edt requires a CUDA tensor')

            diam2 = h**2 + w**2
            # 1-pixels cost 0, 0-pixels cost the squared image diagonal.
            data = ((1 - img.type(torch.float32)) * diam2).contiguous()

            # Make the tensor's device current for the launches.
            # NOTE(review): assumes cupy_launch follows the current CUDA
            # device -- confirm against cupy_utils.
            with torch.cuda.device_of(data):
                _batch_edt = cupy_launch(*_batch_edt_kernel)

                # First pass along the width axis.
                intermed = self._edt_pass(_batch_edt, data, bs, h, w, diam2, block)

                # Second pass along the height axis: transpose so that axis is
                # the contiguous one, and swap h/w in the kernel arguments.
                intermed = intermed.permute(0, 2, 1).contiguous()
                out = self._edt_pass(_batch_edt, intermed, bs, w, h, diam2, block)

            # Back to (bs, h, w); the kernel produced squared distances.
            ans = out.permute(0, 2, 1).sqrt()
            ans = ans.type(odtype) if odtype != ans.dtype else ans

        if expand:
            ans = ans.unsqueeze(1)
        return ans

    def batch_dog(self, img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
        """Difference-of-Gaussians response of a batch of images.

        Computes ``0.5 + t * (blur(sigma * k) - blur(sigma)) - epsilon``.

        Args:
            img: Tensor of shape ``(bs, ch, h, w)`` with ``ch`` in {1, 3, 4}.
                For 3 or 4 channels the first three are converted to grayscale
                and any fourth channel is ignored.
            t: Gain applied to the difference of the two blurs.
            sigma: Standard deviation of the narrower Gaussian.
            k: Ratio between the wider and the narrower standard deviation.
            epsilon: Constant subtracted from the response.
            kernel_factor: Kernel half-width expressed in standard deviations.
            clip: Clamp the result to ``[0, 1]`` when true.

        Returns:
            Single-channel tensor of shape ``(bs, 1, h, w)``.

        Raises:
            AssertionError: ``ch`` is not 1, 3 or 4.
        """
        _, ch, _, _ = img.shape
        if ch in (3, 4):
            img = kornia.color.rgb_to_grayscale(img[:, :3])
        else:
            assert ch == 1, 'batch_dog expects 1, 3 or 4 channels'

        # Odd kernel sizes covering kernel_factor sigmas per side, at least 3.
        kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
        kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
        g0 = kornia.filters.gaussian_blur2d(
            img, (kern0, kern0), (sigma, sigma), border_type='replicate',
        )
        g1 = kornia.filters.gaussian_blur2d(
            img, (kern1, kern1), (sigma * k, sigma * k), border_type='replicate',
        )
        out = 0.5 + t * (g1 - g0) - epsilon
        return out.clip(0, 1) if clip else out

    def forward(
        self, img, t=2.0, sigma_factor=1/540,
        k=1.6, epsilon=0.01,
        kernel_factor=4, exp_factor=540/15
    ):
        """Map images to a soft distance-to-edge field in ``[0, 1)``.

        Args:
            img: Tensor of shape ``(bs, ch, h, w)`` accepted by ``batch_dog``.
            t, k, epsilon, kernel_factor: Forwarded to ``batch_dog``.
            sigma_factor: The DoG sigma is ``h * sigma_factor``.
            exp_factor: Decay rate of the exponential applied to the distances
                after dividing them by ``max(h, w)``.

        Returns:
            Tensor of shape ``(bs, 1, h, w)``; 0 where the unclipped DoG
            response exceeds 0.5 and growing towards 1 with distance from
            those pixels.
        """
        dog = self.batch_dog(
            img, t=t, sigma=img.shape[-2] * sigma_factor, k=k,
            epsilon=epsilon, kernel_factor=kernel_factor, clip=False,
        )
        edt = self.batch_edt((dog > 0.5).float())
        out = 1 - (-edt * exp_factor / max(edt.shape[-2:])).exp()
        return out