Spaces:
Build error
Build error
r""" CHM 4D kernel (psi, iso, and full) generator """ | |
import torch | |
from .geometry import Geometry | |
class KernelGenerator:
    """Build parameter-sharing groups for a CHM 4D kernel of size ``ksz``.

    Kernel positions are indexed by ``(src_i, src_j, trg_i, trg_j)``.
    ``generate`` groups their flattened 1D coordinates by a key whose form
    depends on ``ktype``:

    * ``'iso'``  -- key is the source-to-target offset distance only.
    * ``'psi'``  -- key is the (max, min) of the source/target distances to
      the centre, plus the offset distance.
    * ``'full'`` -- no grouping; ``generate`` returns ``None``.

    Any other ``ktype`` raises ``NotImplementedError`` once keys are built.
    """

    def __init__(self, ksz, ktype):
        self.ksz = ksz
        # 4-tuples (src_i, src_j, trg_i, trg_j) from the project's Geometry
        # helper; presumably every position of the ksz^4 kernel -- TODO confirm.
        self.idx4d = Geometry.init_idx4d(ksz)
        # Dense zero kernel; not read by this class, kept as a public attribute.
        self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
        # Centre cell of a ksz x ksz grid.
        self.center = (ksz // 2, ksz // 2)
        self.ktype = ktype

    def quadrant(self, crd):
        """Return the (horizontal, vertical) side of ``crd`` relative to the centre.

        Each component is -1 when the coordinate is smaller than the matching
        centre coordinate, 1 when it is larger, and 0 when it is equal.

        Args:
            crd: a 2-sequence; ``crd[0]`` is compared with ``self.center[0]``
                and ``crd[1]`` with ``self.center[1]``.
        """
        def side(value, pivot):
            if value < pivot:
                return -1
            if value > pivot:
                return 1
            return 0

        return side(crd[0], self.center[0]), side(crd[1], self.center[1])

    def generate(self):
        """Return the parameter-sharing groups for this kernel type.

        Returns:
            ``None`` when ``ktype == 'full'``; otherwise the dict built by
            ``generate_chm_kernel``.
        """
        if self.ktype == 'full':
            return None
        return self.generate_chm_kernel()

    def generate_chm_kernel(self):
        """Group flattened 4D kernel positions by their sharing key.

        Returns:
            dict mapping each key string from ``build_key`` to the list of 1D
            coordinates (from ``Geometry.get_coord1d``) sharing that key, in
            ``self.idx4d`` iteration order.

        Raises:
            NotImplementedError: if ``ktype`` is not ``'iso'`` or ``'psi'``
                and ``self.idx4d`` yields at least one index.
        """
        param_dict = {}
        for src_i, src_j, trg_i, trg_j in self.idx4d:
            src_crd = (src_i, src_j)
            trg_crd = (trg_i, trg_j)
            d_tail = Geometry.get_distance(src_crd, self.center)
            d_head = Geometry.get_distance(trg_crd, self.center)
            d_off = Geometry.get_distance(src_crd, trg_crd)
            # Horizontal side is taken from src_j, vertical from src_i.
            horz_quad, vert_quad = self.quadrant((src_j, src_i))
            key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
            coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)
            param_dict.setdefault(key, []).append(coord1d)
        return param_dict

    def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
        """Return the parameter-sharing key string for one kernel position.

        ``horz_quad``, ``vert_quad``, ``src_crd`` and ``trg_crd`` are accepted
        but not used by the currently implemented key types. Distances are
        rendered with ``%d``, so non-integer values are truncated toward zero.

        Raises:
            NotImplementedError: for any ``ktype`` other than ``'iso'``/``'psi'``.
        """
        if self.ktype == 'iso':
            return '%d' % d_off
        if self.ktype == 'psi':
            d_max = max(d_head, d_tail)
            d_min = min(d_head, d_tail)
            return '%d_%d_%d' % (d_max, d_min, d_off)
        raise NotImplementedError('not implemented.')