|
from typing import List, Tuple |
|
import os |
|
import torch.distributed as dist |
|
from torch import Tensor |
|
from mmdet.registry import MODELS, TASK_UTILS |
|
from mmdet.models.dense_heads import AnchorFreeHead |
|
from mmdet.structures import SampleList |
|
from mmdet.models.dense_heads import Mask2FormerHead |
|
import math |
|
from mmengine.model.weight_init import trunc_normal_ |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import build_activation_layer, build_norm_layer |
|
|
|
from mmengine.dist import get_dist_info |
|
|
|
|
|
@MODELS.register_module()
class YOSOHead(Mask2FormerHead):
    """YOSO-style panoptic segmentation head.

    A set of learned proposal kernels is iteratively refined against a single
    fused feature map by a stack of ``CrossAttenHead`` stages; every stage
    emits per-query classification logits and mask logits.

    Args:
        num_cls_fcs (int): FC layers before classification in each stage.
        num_mask_fcs (int): FC layers before mask prediction in each stage.
        sphere_cls (bool): Use a cosine-similarity (open-vocabulary)
            classifier in each stage instead of a learned linear one.
        ov_classifier_name (str | None): Name of the pre-computed class
            embedding file used when ``sphere_cls`` is True.
        use_kernel_updator (bool): Use the K-Net style kernel updator in each
            stage instead of dynamic separable conv attention.
        num_stages (int): Number of refinement stages.
        feat_channels (int): Channels of the kernels / features.
        out_channels (int): Output channels (kept for config compatibility).
        num_things_classes (int): Number of "thing" classes.
        num_stuff_classes (int): Number of "stuff" classes.
        num_classes (int): Total number of classes.
        num_queries (int): Number of proposal kernels (queries).
        temperature (float): Divisor applied to classification logits.
        loss_cls / loss_mask / loss_dice (dict): Loss configs built via the
            MODELS registry.
        train_cfg (dict | None): Training config (assigner, sampler, points).
        test_cfg (dict | None): Testing config.
        init_cfg (dict | None): Initialization config.
    """

    def __init__(self,
                 num_cls_fcs=1,
                 num_mask_fcs=1,
                 sphere_cls=False,
                 ov_classifier_name=None,
                 use_kernel_updator=False,
                 num_stages=3,
                 feat_channels=256,
                 out_channels=256,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_classes=133,
                 num_queries=100,
                 temperature=0.1,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=2.0,
                     reduction='mean',
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=5.0),
                 loss_dice=dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     reduction='mean',
                     naive_dice=True,
                     eps=1.0,
                     loss_weight=5.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        # Deliberately skip Mask2FormerHead/AnchorFreeHead __init__ (they
        # build transformer decoders etc. that YOSO does not use) and run
        # only the BaseModule initializer.
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_stages = num_stages
        self.feat_channels = feat_channels
        self.out_channels = out_channels
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.temperature = temperature

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        # Item access works for both a plain dict (the default above) and an
        # mmengine ConfigDict; attribute access (`loss_cls.class_weight`)
        # would raise AttributeError on the plain-dict default.
        self.class_weight = loss_cls['class_weight']
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

        # Learned proposal kernels, shared across the batch.
        self.kernels = nn.Embedding(self.num_queries, self.feat_channels)

        self.mask_heads = nn.ModuleList()
        for _ in range(self.num_stages):
            self.mask_heads.append(CrossAttenHead(
                self.num_classes, self.feat_channels, self.num_queries,
                use_kernel_updator=use_kernel_updator,
                sphere_cls=sphere_cls, ov_classifier_name=ov_classifier_name,
                num_cls_fcs=num_cls_fcs, num_mask_fcs=num_mask_fcs
            ))

    def init_weights(self) -> None:
        # Skip Mask2FormerHead's weight init (it targets modules this head
        # does not build); fall through to BaseModule.init_weights.
        super(AnchorFreeHead, self).init_weights()

    def forward(self, x: Tensor,
                batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
        """Run all refinement stages.

        Args:
            x (Tensor): Fused feature map of shape (batch, C, H, W).
                (Annotated as a single Tensor: the body indexes ``x.shape[0]``
                and feeds it to 'bchw' einsums.)
            batch_data_samples (SampleList): Unused here; kept for interface
                compatibility with Mask2FormerHead.

        Returns:
            tuple(list[Tensor], list[Tensor]): Per-stage classification
            logits (batch, num_queries, num_classes + 1) and mask logits
            (batch, num_queries, H, W).
        """
        all_cls_scores = []
        all_masks_preds = []
        # Broadcast the learned kernels over the batch.
        object_kernels = self.kernels.weight[None].repeat(x.shape[0], 1, 1)
        # Initial masks: plain dot product between kernels and features.
        mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, x)

        for stage in range(self.num_stages):
            mask_head = self.mask_heads[stage]
            cls_scores, mask_preds, iou_pred, object_kernels = mask_head(
                x, object_kernels, mask_preds)
            # Temperature scaling sharpens the classification logits.
            cls_scores = cls_scores / self.temperature

            all_cls_scores.append(cls_scores)
            all_masks_preds.append(mask_preds)

        return all_cls_scores, all_masks_preds

    def predict(self, x: Tensor, batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Return last-stage predictions, masks upsampled to input shape.

        Args:
            x (Tensor): Fused feature map of shape (batch, C, H, W).
            batch_data_samples (SampleList): Data samples carrying metainfo.

        Returns:
            tuple(Tensor, Tensor): Classification logits and mask logits
            resized to ``batch_input_shape``.
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        # Only the final stage is used at inference time.
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # All images in a batch share the same padded batch_input_shape.
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        return mask_cls_results, mask_pred_results
|
|
|
|
|
class FFN(nn.Module):
    """Two-layer feed-forward block with an optional residual connection.

    The block expands to ``feedforward_channels``, applies ReLU, and projects
    back to ``embed_dims``. All dropout probabilities are fixed at 0.0, so
    the dropout layers are identity operations kept for structural parity.
    """

    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 add_identity=True):
        super().__init__()
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs

        blocks = []
        dim = embed_dims
        for _ in range(num_fcs - 1):
            # Expansion stage: Linear -> ReLU -> (no-op) Dropout.
            blocks.append(nn.Sequential(
                nn.Linear(dim, feedforward_channels),
                nn.ReLU(True),
                nn.Dropout(0.0)))
            dim = feedforward_channels
        # Final projection back to the embedding width.
        blocks.append(nn.Linear(feedforward_channels, embed_dims))
        blocks.append(nn.Dropout(0.0))
        self.layers = nn.Sequential(*blocks)
        self.add_identity = add_identity
        self.dropout_layer = nn.Dropout(0.0)

    def forward(self, x, identity=None):
        """Apply the FFN; add ``identity`` (default: ``x``) if configured."""
        projected = self.dropout_layer(self.layers(x))
        if not self.add_identity:
            return projected
        shortcut = x if identity is None else identity
        return shortcut + projected
|
|
|
|
|
class DySepConvAtten(nn.Module):
    """Dynamic separable 1-D convolution attention.

    For every sample, a depthwise kernel and a pointwise (proposal-mixing)
    kernel are predicted from ``query`` and applied to ``value`` as two
    stacked 1-D convolutions, followed by LayerNorm.
    """

    def __init__(self, hidden_dim, num_proposals, conv_kernel_size_1d):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_proposals = num_proposals
        self.kernel_size = conv_kernel_size_1d

        # Predicts, per proposal, kernel_size depthwise weights plus
        # num_proposals pointwise weights.
        self.weight_linear = nn.Linear(self.hidden_dim, self.num_proposals + self.kernel_size)
        self.norm = nn.LayerNorm(self.hidden_dim)

    def forward(self, query, value):
        assert query.shape == value.shape
        batch, num_prop, _ = query.shape

        # Split predicted parameters into the two convolution kernels.
        params = self.weight_linear(query)
        depth_weight = params[..., :self.kernel_size].view(
            batch, self.num_proposals, 1, self.kernel_size)
        point_weight = params[..., self.kernel_size:].view(
            batch, self.num_proposals, self.num_proposals, 1)

        # Per-sample convs: each sample has its own dynamic weights, so the
        # batch dimension is processed one element at a time.
        value = value.unsqueeze(1)
        per_sample = []
        for b in range(batch):
            depth_out = F.relu(F.conv1d(
                value[b], depth_weight[b], groups=num_prop, padding='same'))
            per_sample.append(F.conv1d(depth_out, point_weight[b], padding='same'))

        stacked = torch.cat(per_sample, dim=0)
        return self.norm(stacked)
|
|
|
|
|
class KernelUpdator(nn.Module):
    """K-Net-style adaptive kernel update.

    Fuses each proposal kernel (``update_feature``) with its mask-pooled
    image feature (``input_feature``) through learned input/update gates,
    producing refined kernels.

    Args:
        in_channels (int): Channels of the incoming kernels/features.
        feat_channels (int): Internal channel width of the gating branch.
        out_channels (int | None): Output channels; defaults to
            ``in_channels`` when None.
        input_feat_shape (int | list[int]): Stored for config compatibility;
            not used by the computation in this class.
        gate_sigmoid (bool): Apply sigmoid to both gates.
        gate_norm_act (bool): Extra norm + activation on the gate features.
        activate_out (bool): Activate the gated branches before fusion.
        act_cfg (dict): Activation config for mmcv's builder.
        norm_cfg (dict): Norm config for mmcv's builder.
    """

    def __init__(self,
                 in_channels=256,
                 feat_channels=64,
                 out_channels=None,
                 input_feat_shape=3,
                 gate_sigmoid=True,
                 gate_norm_act=False,
                 activate_out=False,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN')):
        super(KernelUpdator, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.gate_sigmoid = gate_sigmoid
        self.gate_norm_act = gate_norm_act
        self.activate_out = activate_out
        if isinstance(input_feat_shape, int):
            input_feat_shape = [input_feat_shape] * 2
        self.input_feat_shape = input_feat_shape
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.feat_channels
        self.num_params_out = self.feat_channels
        # Produces per-kernel parameters for the input and output transforms.
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)
        # NOTE: the original code passed a stray positional ``1`` as the third
        # nn.Linear argument (the ``bias`` flag); truthy == bias=True, the
        # default, so dropping it preserves behavior exactly.
        self.input_layer = nn.Linear(self.in_channels,
                                     self.num_params_in + self.num_params_out)
        # Gates consume gate_feats of width feat_channels; with the 256/256
        # configuration used by CrossAttenHead this matches in_channels.
        self.input_gate = nn.Linear(self.in_channels, self.feat_channels)
        self.update_gate = nn.Linear(self.in_channels, self.feat_channels)
        if self.gate_norm_act:
            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, update_feature, input_feature):
        """Fuse kernels with their pooled features.

        Args:
            update_feature (torch.Tensor): [bs, num_proposals, in_channels]
                current proposal kernels.
            input_feature (torch.Tensor): [bs, num_proposals, in_channels]
                mask-pooled image features.

        Returns:
            torch.Tensor: [bs, num_proposals, out_channels] updated kernels.
        """
        bs, num_proposals, _ = update_feature.shape

        # Kernel-conditioned parameters, split into input/output halves.
        parameters = self.dynamic_layer(update_feature)
        param_in = parameters[..., :self.num_params_in]
        param_out = parameters[..., -self.num_params_out:]

        # Feature-conditioned parameters, split the same way.
        input_feats = self.input_layer(input_feature)
        input_in = input_feats[..., :self.num_params_in]
        input_out = input_feats[..., -self.num_params_out:]

        # Elementwise interaction drives both gates.
        gate_feats = input_in * param_in
        if self.gate_norm_act:
            gate_feats = self.activation(self.gate_norm(gate_feats))

        input_gate = self.input_norm_in(self.input_gate(gate_feats))
        update_gate = self.norm_in(self.update_gate(gate_feats))
        if self.gate_sigmoid:
            input_gate = input_gate.sigmoid()
            update_gate = update_gate.sigmoid()
        param_out = self.norm_out(param_out)
        input_out = self.input_norm_out(input_out)

        if self.activate_out:
            param_out = self.activation(param_out)
            input_out = self.activation(input_out)

        # Gated fusion of the kernel branch and the feature branch.
        features = update_gate * param_out + input_gate * input_out

        features = self.fc_layer(features)
        features = self.fc_norm(features)
        features = self.activation(features)

        return features
|
|
|
|
|
class CrossAttenHead(nn.Module):
    """Single YOSO refinement stage.

    Updates proposal kernels from mask-pooled image features (either via a
    K-Net ``KernelUpdator`` or dynamic separable conv attention), applies
    self-attention and an FFN over the kernels, and predicts per-query class
    scores and mask logits.

    Args:
        num_classes (int): Number of classes (background slot added inside).
        in_channels (int): Channel width of features and kernels.
        num_proposals (int): Number of proposal queries.
        frozen_head (bool): Freeze the kernel-update/attention/FFN part.
        frozen_pred (bool): Freeze the classification/mask prediction part.
        with_iou_pred (bool): Also predict a per-query IoU score.
        sphere_cls (bool): Cosine-similarity classifier against fixed class
            embeddings instead of a learned linear classifier.
        ov_classifier_name (str | None): Name of the embedding file loaded
            from ``./models/{name}.pth`` when ``sphere_cls`` is True.
        num_cls_fcs (int): FC layers before classification.
        num_mask_fcs (int): FC layers before mask prediction.
        conv_kernel_size_1d (int): Kernel size for the dynamic conv attention.
        conv_kernel_size_2d (int): Spatial kernel size factor for layer widths.
        use_kernel_updator (bool): Select KernelUpdator over conv attention.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_proposals,
                 frozen_head=False,
                 frozen_pred=False,
                 with_iou_pred=False,
                 sphere_cls=False,
                 ov_classifier_name=None,
                 num_cls_fcs=1,
                 num_mask_fcs=1,
                 conv_kernel_size_1d=3,
                 conv_kernel_size_2d=1,
                 use_kernel_updator=False):
        super(CrossAttenHead, self).__init__()
        self.sphere_cls = sphere_cls
        self.with_iou_pred = with_iou_pred
        self.frozen_head = frozen_head
        self.frozen_pred = frozen_pred
        self.num_cls_fcs = num_cls_fcs
        self.num_mask_fcs = num_mask_fcs
        self.num_classes = num_classes
        self.conv_kernel_size_2d = conv_kernel_size_2d

        self.hidden_dim = in_channels
        self.feat_channels = in_channels
        self.num_proposals = num_proposals
        # Threshold for binarising sigmoid masks before mask-pooling.
        self.hard_mask_thr = 0.5
        self.use_kernel_updator = use_kernel_updator

        if use_kernel_updator:
            self.kernel_update = KernelUpdator(
                in_channels=256,
                feat_channels=256,
                out_channels=256,
                input_feat_shape=3,
                act_cfg=dict(type='ReLU', inplace=True),
                norm_cfg=dict(type='LN')
            )
        else:
            self.f_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d)
            self.f_dropout = nn.Dropout(0.0)
            self.f_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)
            self.k_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d)
            self.k_dropout = nn.Dropout(0.0)
            self.k_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)

        self.s_atten = nn.MultiheadAttention(embed_dim=self.hidden_dim *
                                             self.conv_kernel_size_2d ** 2,
                                             num_heads=8,
                                             dropout=0.0)
        self.s_dropout = nn.Dropout(0.0)
        self.s_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)

        self.ffn = FFN(self.hidden_dim, feedforward_channels=2048, num_fcs=2)
        self.ffn_norm = nn.LayerNorm(self.hidden_dim)

        self.cls_fcs = nn.ModuleList()
        for _ in range(self.num_cls_fcs):
            self.cls_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
            self.cls_fcs.append(nn.LayerNorm(self.hidden_dim))
            self.cls_fcs.append(nn.ReLU(True))

        if sphere_cls:
            if ov_classifier_name is None:
                # Random orthogonal class embeddings (no text classifier).
                _dim = 1024
                cls_embed = torch.empty(self.num_classes, _dim)
                torch.nn.init.orthogonal_(cls_embed)
                cls_embed = cls_embed[:, None]  # -> (num_classes, 1, _dim)
            else:
                # Pre-computed class embeddings; asserted to be L2-normalised.
                # Expected shape: (num_classes, num_prototypes, dim).
                ov_path = os.path.join('./models/', f"{ov_classifier_name}.pth")
                cls_embed = torch.load(ov_path)
                cls_embed_norm = cls_embed.norm(p=2, dim=-1)
                assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))

            _dim = cls_embed.size(2)
            _prototypes = cls_embed.size(1)

            # Append an all-zero "background / no object" class token.
            back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
            cls_embed = torch.cat([
                cls_embed, back_token.repeat(_prototypes, 1)[None]
            ], dim=0)
            # Stored as (dim, num_classes + 1, num_prototypes) for the
            # 'bnc,ckp->bnkp' einsum in forward(); excluded from checkpoints.
            self.register_buffer('fc_cls', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)

            cls_embed_dim = self.fc_cls.size(0)
            self.cls_proj = nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, cls_embed_dim)
            )

            # Fixed temperature for cosine logits; exp(4.6052) ~= 100.
            logit_scale = torch.tensor(4.6052, dtype=torch.float32)
            self.register_buffer('logit_scale', logit_scale, persistent=False)
        else:
            self.fc_cls = nn.Linear(self.hidden_dim, self.num_classes + 1)

        self.mask_fcs = nn.ModuleList()
        for _ in range(self.num_mask_fcs):
            self.mask_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
            self.mask_fcs.append(nn.LayerNorm(self.hidden_dim))
            self.mask_fcs.append(nn.ReLU(True))
        self.fc_mask = nn.Linear(self.hidden_dim, self.hidden_dim)

        if self.with_iou_pred:
            self.iou_embed = nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, 1),
            )
        # Focal-loss style bias so classification starts near 1% foreground.
        prior_prob = 0.01
        self.bias_value = -math.log((1 - prior_prob) / prior_prob)

        self.apply(self._init_weights)
        if not sphere_cls:
            nn.init.constant_(self.fc_cls.bias, self.bias_value)

        if self.frozen_head:
            self._frozen_head()
        if self.frozen_pred:
            self._frozen_pred()

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers, standard LayerNorm init."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _frozen_head(self):
        """Stop gradients through the kernel-update / attention / FFN part.

        NOTE(review): references ``self.kernel_update``, which only exists
        when ``use_kernel_updator=True`` — frozen_head appears to presume
        that configuration; confirm against callers.
        """
        for n, p in self.kernel_update.named_parameters():
            p.requires_grad = False
        for n, p in self.s_atten.named_parameters():
            p.requires_grad = False
        for n, p in self.s_dropout.named_parameters():
            p.requires_grad = False
        for n, p in self.s_atten_norm.named_parameters():
            p.requires_grad = False
        for n, p in self.ffn.named_parameters():
            p.requires_grad = False
        for n, p in self.ffn_norm.named_parameters():
            p.requires_grad = False

    def _frozen_pred(self):
        """Stop gradients through the classification / mask prediction part."""
        for n, p in self.cls_fcs.named_parameters():
            p.requires_grad = False
        for n, p in self.fc_cls.named_parameters():
            p.requires_grad = False
        for n, p in self.mask_fcs.named_parameters():
            p.requires_grad = False
        for n, p in self.fc_mask.named_parameters():
            p.requires_grad = False

    def train(self, mode=True):
        """Set train/eval state while keeping frozen sub-modules in eval.

        ``mode`` defaults to True to match ``nn.Module.train`` (the original
        override had no default, so a plain ``model.train()`` raised
        TypeError). Returns ``self`` per the nn.Module contract.
        """
        super().train(mode)
        if self.frozen_head:
            self.kernel_update.eval()
            self.s_atten.eval()
            self.s_dropout.eval()
            self.s_atten_norm.eval()
            self.ffn.eval()
            self.ffn_norm.eval()
        if self.frozen_pred:
            self.cls_fcs.eval()
            self.fc_cls.eval()
            self.mask_fcs.eval()
            self.fc_mask.eval()
        return self

    def forward(self, features, proposal_kernels, mask_preds, self_attn_mask=None):
        """Refine kernels once and predict class scores and masks.

        Args:
            features (Tensor): Image features, shape (B, C, H, W).
            proposal_kernels (Tensor): Current kernels, (B, N, C).
            mask_preds (Tensor): Mask logits from the previous stage,
                shape (B, N, H, W).
            self_attn_mask (Tensor | None): Optional attention mask for the
                kernel self-attention.

        Returns:
            tuple: ``(cls_score, new_mask_preds, iou_pred, obj_feat)`` where
            ``iou_pred`` is None unless ``with_iou_pred`` is set.
        """
        B, C, H, W = features.shape

        # Binarise previous masks and mask-pool (sum) features per proposal.
        soft_sigmoid_masks = mask_preds.sigmoid()
        nonzero_inds = soft_sigmoid_masks > self.hard_mask_thr
        hard_sigmoid_masks = nonzero_inds.float()

        f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, features)

        num_proposals = proposal_kernels.shape[1]
        k = proposal_kernels.view(B, num_proposals, -1)

        if self.use_kernel_updator:
            k = self.kernel_update(f, k)
        else:
            # Feature update followed by kernel update; the kernels take the
            # updated-feature path (normalised f becomes the new k).
            f_tmp = self.f_atten(k, f)
            f = f + self.f_dropout(f_tmp)
            f = self.f_atten_norm(f)

            f_tmp = self.k_atten(k, f)
            f = f + self.k_dropout(f_tmp)
            k = self.k_atten_norm(f)

        # nn.MultiheadAttention expects (seq, batch, dim).
        k = k.permute(1, 0, 2)

        k_tmp = self.s_atten(query=k, key=k, value=k, attn_mask=self_attn_mask)[0]
        k = k + self.s_dropout(k_tmp)
        k = self.s_atten_norm(k.permute(1, 0, 2))

        obj_feat = self.ffn_norm(self.ffn(k))

        cls_feat = obj_feat
        mask_feat = obj_feat

        for cls_layer in self.cls_fcs:
            cls_feat = cls_layer(cls_feat)

        if self.sphere_cls:
            # Cosine similarity against fixed class embeddings; max over the
            # prototype axis, then scale.
            cls_embd = self.cls_proj(cls_feat)
            cls_score = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.fc_cls)
            cls_score = cls_score.max(-1).values
            cls_score = self.logit_scale.exp() * cls_score
        else:
            cls_score = self.fc_cls(cls_feat)
        for reg_layer in self.mask_fcs:
            mask_feat = reg_layer(mask_feat)

        mask_kernels = self.fc_mask(mask_feat)

        # New mask logits: dot product of refined kernels with the features.
        new_mask_preds = torch.einsum("bqc,bchw->bqhw", mask_kernels, features)
        if self.with_iou_pred:
            iou_pred = self.iou_embed(mask_feat)
        else:
            iou_pred = None
        return cls_score, new_mask_preds, iou_pred, obj_feat
|
|