Spaces:
Build error
Build error
r""" Convolutional Hough Matching Networks """ | |
import torch.nn as nn | |
import torch | |
from . import chmlearner as chmlearner | |
from .base import backbone | |
class CHMNet(nn.Module):
    """Convolutional Hough Matching network.

    A pretrained ResNet-101 backbone (from the project-local ``base.backbone``
    module) extracts features for a source/target image pair, and a
    ``chmlearner.CHMLearner`` turns the feature pair into a correlation.

    Args:
        ktype: kernel type, forwarded unchanged to ``chmlearner.CHMLearner``
            (its accepted values are defined there).
    """

    def __init__(self, ktype):
        super().__init__()
        self.backbone = backbone.resnet101(pretrained=True)
        # feat_dim=1024 must match the per-image channel count produced by
        # extract_features() (half of the layer3 output channels).
        self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)

    def forward(self, src_img, trg_img):
        """Return the learner's correlation for a source/target image pair.

        Args:
            src_img: source image tensor.
            trg_img: target image tensor; it is concatenated with ``src_img``
                along dim 1, so presumably it has the same shape — TODO confirm.

        Returns:
            Whatever ``self.learner`` returns for the extracted feature pair.
        """
        src_feat, trg_feat = self.extract_features(src_img, trg_img)
        return self.learner(src_feat, trg_feat)

    def extract_features(self, src_img, trg_img):
        """Run both images jointly through the backbone stem and layer1-layer3.

        The two images are concatenated along dim 1 and pushed through the
        backbone as a single tensor. The layer3 output is then split back into
        two equal halves along dim 1. NOTE(review): this relies on the backbone
        processing the two channel-stacked images independently (e.g. via
        grouped convolutions) — confirm against ``base.backbone``.

        Returns:
            ``(src_feat, trg_feat)``: the first and second halves of the layer3
            output along dim 1. Each half is cloned, so it does not share
            storage with the backbone output.
        """
        net = self.backbone
        # Call the modules themselves rather than ``.forward`` so that any
        # registered forward hooks are honoured.
        feat = net.conv1(torch.cat([src_img, trg_img], dim=1))
        feat = net.bn1(feat)
        feat = net.relu(feat)
        feat = net.maxpool(feat)

        # Only layer1..layer3 are needed; layer4 is never evaluated.
        for idx in range(1, 4):
            feat = getattr(net, f'layer{idx}')(feat)

        half = feat.size(1) // 2
        src_feat = feat.narrow(1, 0, half).clone()
        trg_feat = feat.narrow(1, half, half).clone()
        return src_feat, trg_feat

    @classmethod
    def training_objective(cls, prd_kps, trg_kps, npts):
        """Return the batch-averaged keypoint loss.

        For every sample, the squared differences between predicted and target
        keypoints are summed over dim 1 (presumably the coordinate axis of a
        ``(batch, 2, max_kps)`` layout — TODO confirm). The result is averaged
        over that sample's first ``npt`` keypoints, and the per-sample means
        are then averaged.

        Note: despite the historical "l2dist" naming, this is a *squared*
        distance. A sample with ``npt == 0`` contributes NaN (mean of an
        empty tensor).

        Args:
            prd_kps: predicted keypoints.
            trg_kps: target keypoints, broadcast-compatible with ``prd_kps``.
            npts: per-sample count of valid keypoints (ints or 0-d int tensors).

        Returns:
            Scalar tensor.
        """
        sq_dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
        per_sample = [dist[:npt].mean() for dist, npt in zip(sq_dist, npts)]
        return torch.stack(per_sample).mean()