Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmengine.model import BaseModule | |
from opencd.registry import MODELS | |
class FeatureFusionNeck(BaseModule):
    """Feature Fusion Neck.

    Fuses two equally long sequences of features level by level with the
    configured ``policy`` and returns the fused levels selected by
    ``out_indices``.

    Args:
        policy (str): The operation to fuse features. candidates
            are `concat`, `sum`, `diff` and `abs_diff`.
        in_channels (Sequence(int)): Input channels. Stored on the
            instance only; not used by the code in this class.
        channels (int): Channels after modules, before conv_seg. Stored on
            the instance only; not used by the code in this class.
        out_indices (tuple[int]): Output from which layer.
    """

    def __init__(self,
                 policy,
                 in_channels=None,
                 channels=None,
                 out_indices=(0, 1, 2, 3)):
        super().__init__()
        self.policy = policy
        self.in_channels = in_channels
        self.channels = channels
        self.out_indices = out_indices

    # Must be a staticmethod: the signature has no ``self`` and ``forward``
    # calls it through the instance as ``self.fusion(a, b, policy)``.
    @staticmethod
    def fusion(x1, x2, policy):
        """Specify the form of feature fusion.

        Args:
            x1: First feature.
            x2: Second feature.
            policy (str): One of `concat`, `sum`, `diff`, `abs_diff`.

        Returns:
            The fused feature: concatenation along dim 1 for `concat`,
            ``x1 + x2`` for `sum`, ``x2 - x1`` for `diff`, and
            ``|x1 - x2|`` for `abs_diff`.

        Raises:
            AssertionError: If ``policy`` is not a supported policy.
        """
        _fusion_policies = ['concat', 'sum', 'diff', 'abs_diff']
        assert policy in _fusion_policies, 'The fusion policies {} are ' \
            'supported'.format(_fusion_policies)

        if policy == 'concat':
            x = torch.cat([x1, x2], dim=1)
        elif policy == 'sum':
            x = x1 + x2
        elif policy == 'diff':
            # Order matters: second input minus first input.
            x = x2 - x1
        elif policy == 'abs_diff':
            x = torch.abs(x1 - x2)

        return x

    def forward(self, x1, x2):
        """Forward function.

        Args:
            x1: Sequence of features for the first input.
            x2: Sequence of features for the second input; must have the
                same length as ``x1``.

        Returns:
            tuple: Fused features at the positions listed in
            ``self.out_indices``.

        Raises:
            AssertionError: If ``x1`` and ``x2`` differ in length.
        """
        assert len(x1) == len(x2), "The features x1 and x2 from the " \
            "backbone should be of equal length"

        # Fuse every level first so that out_indices can refer to any level.
        fused = [
            self.fusion(feat1, feat2, self.policy)
            for feat1, feat2 in zip(x1, x2)
        ]
        return tuple(fused[i] for i in self.out_indices)