File size: 3,824 Bytes
587665f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import cupy
import kornia
import torch.nn as nn

from modules.cupy_module.cupy_utils import cupy_launch
# Code taken from https://github.com/ShuhongChen/eisai-anime-interpolator

# (function name, CUDA source) pair; NEDT.batch_edt unpacks it into cupy_launch.
#
# The kernel performs one 1-D pass of a squared-distance computation.
# Each thread handles one output element: its flat index `idx` is decoded into
# (pb, pi, pj) = (batch, row, column) of a (bs, h, w) row-major buffer.
# The thread then scans every column j of its own row and keeps the smallest
# value of data[pb, pi, j] + (pj - j)^2, starting from `diam2` as the upper
# bound, and writes that minimum to output[idx].
# Threads whose index falls past bs*h*w return without writing anything.
_batch_edt_kernel = ('kernel_dt', '''
    extern "C" __global__ void kernel_dt(
        const int bs,
        const int h,
        const int w,
        const float diam2,
        float* data,
        float* output
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx >= bs*h*w) {
            return;
        }
        int pb = idx / (h*w);
        int pi = (idx - h*w*pb) / w;
        int pj = (idx - h*w*pb - w*pi);

        float cost;
        float mincost = diam2;
        for (int j = 0; j < w; j++) {
            cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
            if (cost < mincost) {
                mincost = cost;
            }
        }
        output[idx] = mincost;
        return;
    }
''')

class NEDT(nn.Module):
    """Exponentially normalised distance-to-edge map of an image batch.

    ``forward`` thresholds a difference-of-Gaussians (DoG) response at 0.5,
    computes for every pixel the Euclidean distance to the nearest pixel
    above that threshold, and squashes the distance into ``[0, 1)`` with
    ``1 - exp(-distance * exp_factor / max(h, w))``.

    The module has no parameters or buffers.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _edt_pass(kernel, src, bs, rows, cols, diam2, block):
        """Run one 1-D pass of the distance kernel along the last axis of ``src``.

        Args:
            kernel: Launchable returned by ``cupy_launch``.
            src: Contiguous float32 CUDA tensor of shape ``(bs, rows, cols)``.
            bs: Batch size.
            rows: Size of the middle axis of ``src``.
            cols: Size of the last axis of ``src`` (the axis that is scanned).
            diam2: Upper bound on the squared distance.
            block: Number of CUDA threads per block.

        Returns:
            A new tensor shaped like ``src`` holding the per-pixel minima.
        """
        # The kernel writes every element, so no zero-fill is needed.
        dst = torch.empty_like(src)
        # One thread per element, rounded up to whole blocks.
        grid = (src.nelement() + block - 1) // block
        kernel(
            grid=(grid, 1, 1),
            block=(block, 1, 1),  # CUDA allows at most 1024 threads per block
            args=[
                cupy.int32(bs),
                cupy.int32(rows),
                cupy.int32(cols),
                cupy.float32(diam2),
                src.data_ptr(),
                dst.data_ptr(),
            ],
        )
        return dst

    def batch_edt(self, img, block=1024):
        """Return the Euclidean distance to the nearest pixel equal to 1.

        The transform is computed separably: one kernel pass along the width
        axis, then one along the height axis on the transposed intermediate.
        Distances are capped at the image diagonal ``sqrt(h**2 + w**2)``.

        Args:
            img: CUDA tensor of shape ``(bs, h, w)`` or ``(bs, 1, h, w)``.
                Pixels equal to 1 act as zero-distance sources and pixels
                equal to 0 as background; presumably a binary mask, since
                other values are not given special treatment.
            block: Number of CUDA threads per block (at most 1024).

        Returns:
            Tensor with the same shape and dtype as ``img`` holding the
            distances. An empty ``img`` yields an empty tensor.

        Raises:
            AssertionError: If a 4-D ``img`` has more than one channel.
            ValueError: If a non-empty ``img`` is not on a CUDA device.
        """
        # Bookkeeping: normalise the input to (bs, h, w).
        if len(img.shape) == 4:
            assert img.shape[1] == 1, "4-D input must have exactly one channel"
            img = img.squeeze(1)
            expand = True
        else:
            expand = False
        bs, h, w = img.shape
        odtype = img.dtype

        if img.nelement() == 0:
            # Nothing to compute; a launch with a zero-sized grid is invalid.
            ans = torch.zeros_like(img)
        else:
            if not img.is_cuda:
                # data_ptr() of a CPU tensor is a host address; the kernel
                # needs device memory.
                raise ValueError(
                    "batch_edt requires a CUDA tensor, got device "
                    f"{img.device}"
                )

            # must initialize cuda/cupy after forking
            _batch_edt = cupy_launch(*_batch_edt_kernel)

            # Squared image diagonal: larger than any in-image squared distance.
            diam2 = h**2 + w**2

            # Source pixels (value 1) cost 0, background pixels cost diam2.
            data = ((1 - img.type(torch.float32)) * diam2).contiguous()

            # First pass: scan along the width (last) axis of each row.
            row_pass = self._edt_pass(_batch_edt, data, bs, h, w, diam2, block)

            # Second pass: transpose so the height axis becomes the last axis,
            # then scan again.
            transposed = row_pass.permute(0, 2, 1).contiguous()
            out = self._edt_pass(
                _batch_edt, transposed, bs, w, h, diam2, block,
            )

            # Undo the transpose and convert squared distances to distances.
            ans = out.permute(0, 2, 1).sqrt()
            ans = ans.type(odtype) if odtype != ans.dtype else ans

        if expand:
            ans = ans.unsqueeze(1)
        return ans

    def batch_dog(self, img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
        """Compute a difference-of-Gaussians response of an image batch.

        The result is ``0.5 + t * (blur(sigma * k) - blur(sigma)) - epsilon``.

        Args:
            img: Tensor of shape ``(bs, ch, h, w)`` with ``ch`` in {1, 3, 4}.
                For 3 or 4 channels the first three are converted to
                grayscale; a fourth channel is ignored.
            t: Gain applied to the difference of the two blurs.
            sigma: Standard deviation of the narrower Gaussian.
            k: Ratio between the wider and the narrower standard deviation.
            epsilon: Constant subtracted from the response.
            kernel_factor: Kernel half-width expressed in standard deviations.
            clip: If true, clip the response to ``[0, 1]``.

        Returns:
            The single-channel response tensor.

        Raises:
            AssertionError: If ``ch`` is not 1, 3 or 4.
        """
        # to grayscale if needed
        bs, ch, h, w = img.shape
        if ch in [3, 4]:
            img = kornia.color.rgb_to_grayscale(img[:, :3])
        else:
            assert ch == 1

        # Odd kernel sizes of at least 3, scaled with each standard deviation.
        kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
        kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
        g0 = kornia.filters.gaussian_blur2d(
            img, (kern0, kern0), (sigma, sigma), border_type='replicate',
        )
        g1 = kornia.filters.gaussian_blur2d(
            img, (kern1, kern1), (sigma * k, sigma * k), border_type='replicate',
        )
        out = 0.5 + t * (g1 - g0) - epsilon
        out = out.clip(0, 1) if clip else out
        return out

    def forward(
        self, img, t=2.0, sigma_factor=1/540,
        k=1.6, epsilon=0.01,
        kernel_factor=4, exp_factor=540/15
    ):
        """Return the normalised distance map for ``img``.

        Args:
            img: Image batch of shape ``(bs, ch, h, w)``; see ``batch_dog``.
            t: DoG gain.
            sigma_factor: DoG sigma as a fraction of the image height
                (``sigma = h * sigma_factor``).
            k: Ratio between the two DoG standard deviations.
            epsilon: Constant subtracted from the DoG response.
            kernel_factor: DoG kernel half-width in standard deviations.
            exp_factor: Decay rate of the exponential, applied to the distance
                divided by the longer image side.

        Returns:
            Tensor that is 0 where the DoG response exceeds 0.5 and rises
            towards 1 with distance from those pixels.
        """
        dog = self.batch_dog(
            img, t=t, sigma=img.shape[-2] * sigma_factor, k=k,
            epsilon=epsilon, kernel_factor=kernel_factor, clip=False,
        )
        # Pixels with a DoG response above 0.5 are the distance sources.
        edt = self.batch_edt((dog > 0.5).float())
        out = 1 - (-edt * exp_factor / max(edt.shape[-2:])).exp()
        return out