# D-FINE / src/nn/backbone/hgnetv2.py
# Uploaded by developer0hye ("Upload 76 files", commit e85fecb, verified).
"""
reference
- https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...core import register
from .common import FrozenBatchNorm2d
# Module-level aliases for torch's in-place parameter initializers.
# NOTE(review): none of these aliases is referenced by the code in this file.
kaiming_normal_, zeros_, ones_ = (
    nn.init.kaiming_normal_,
    nn.init.zeros_,
    nn.init.ones_,
)

# Only the backbone class is exported by `from ... import *`.
__all__ = ["HGNetv2"]
def safe_barrier():
    """Block until every rank reaches this point, if a process group is active.

    Does nothing when torch.distributed is unavailable or has not been
    initialized, so it is safe to call from single-process code.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.barrier()
def safe_get_rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when torch.distributed is unavailable or not initialized,
    so single-process runs behave like rank 0.
    """
    dist = torch.distributed
    in_process_group = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if in_process_group else 0
class LearnableAffineBlock(nn.Module):
    """Learnable scalar affine map ``y = scale * x + bias``.

    ``scale`` and ``bias`` are trainable single-element parameters of shape
    ``(1,)`` that broadcast over the whole input tensor.

    Args:
        scale_value: initial value of ``scale`` (default 1.0).
        bias_value: initial value of ``bias`` (default 0.0).
    """

    def __init__(self, scale_value=1.0, bias_value=0.0):
        super().__init__()
        # float() keeps the parameters floating-point even when integer initial
        # values are passed; an integer tensor cannot require gradients.
        self.scale = nn.Parameter(torch.tensor([float(scale_value)]), requires_grad=True)
        self.bias = nn.Parameter(torch.tensor([float(bias_value)]), requires_grad=True)

    def forward(self, x):
        return self.scale * x + self.bias
class ConvBNAct(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> optional ReLU -> optional LearnableAffineBlock.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: square kernel size (int).
        stride: convolution stride.
        groups: convolution groups.
        padding: ``"same"`` zero-pads one row/column on the bottom/right via
            ``nn.ZeroPad2d([0, 1, 0, 1])`` and convolves with no padding; any
            other value uses symmetric padding of ``(kernel_size - 1) // 2``.
        use_act: apply ReLU after the norm; otherwise ``act`` is an Identity.
        use_lab: append a LearnableAffineBlock; only takes effect when
            ``use_act`` is also true.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        kernel_size,
        stride=1,
        groups=1,
        padding="",
        use_act=True,
        use_lab=False,
    ):
        super().__init__()
        self.use_act = use_act
        self.use_lab = use_lab

        pad_bottom_right = padding == "same"
        conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size,
            stride,
            padding=0 if pad_bottom_right else (kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        # The "same" variant wraps the conv in a Sequential, so its weight lives
        # under the key "conv.1.weight" rather than "conv.weight".
        self.conv = nn.Sequential(nn.ZeroPad2d([0, 1, 0, 1]), conv) if pad_bottom_right else conv
        self.bn = nn.BatchNorm2d(out_chs)
        self.act = nn.ReLU() if use_act else nn.Identity()
        self.lab = LearnableAffineBlock() if (use_act and use_lab) else nn.Identity()

    def forward(self, x):
        for op in (self.conv, self.bn, self.act, self.lab):
            x = op(x)
        return x
class LightConvBNAct(nn.Module):
    """Lightweight conv unit: 1x1 ConvBNAct (no activation) then a depthwise ConvBNAct.

    Args:
        in_chs: input channels.
        out_chs: output channels of both convs.
        kernel_size: kernel size of the depthwise (``groups=out_chs``) conv.
        groups: accepted for signature compatibility but NOT used; the second
            conv always uses ``groups=out_chs``.
        use_lab: forwarded to both ConvBNAct layers.
    """

    def __init__(self, in_chs, out_chs, kernel_size, groups=1, use_lab=False):
        super().__init__()
        # Pointwise projection; no activation (and therefore no LAB) here.
        self.conv1 = ConvBNAct(in_chs, out_chs, kernel_size=1, use_act=False, use_lab=use_lab)
        # Depthwise spatial conv with activation.
        self.conv2 = ConvBNAct(
            out_chs,
            out_chs,
            kernel_size=kernel_size,
            groups=out_chs,
            use_act=True,
            use_lab=use_lab,
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
class StemBlock(nn.Module):
    """HGNetv2 stem: two stride-2 convs with a parallel pool/conv branch between them.

    Data flow (every conv is a ConvBNAct with its default ReLU):
        stem1 (3x3, s2) -> pad right/bottom by 1 -> feat
        conv branch: stem2a (2x2) -> pad right/bottom by 1 -> stem2b (2x2)
        pool branch: MaxPool2d (2x2, s1, ceil_mode) on feat
        concat(pool, conv) [2 * mid_chs channels] -> stem3 (3x3, s2) -> stem4 (1x1)
    The one-pixel pads make the 2x2 conv / pool keep the spatial size, so only
    stem1 and stem3 reduce resolution.

    Args:
        in_chs: input channels.
        mid_chs: working channels of the stem.
        out_chs: output channels.
        use_lab: forwarded to every ConvBNAct.
    """

    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
        super().__init__()

        def cba(c_in, c_out, k, s):
            return ConvBNAct(c_in, c_out, kernel_size=k, stride=s, use_lab=use_lab)

        half = mid_chs // 2
        self.stem1 = cba(in_chs, mid_chs, 3, 2)
        self.stem2a = cba(mid_chs, half, 2, 1)
        self.stem2b = cba(half, mid_chs, 2, 1)
        # stem3 consumes the pool and conv branches concatenated channel-wise.
        self.stem3 = cba(mid_chs * 2, mid_chs, 3, 2)
        self.stem4 = cba(mid_chs, out_chs, 1, 1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)

    def forward(self, x):
        pad_br = (0, 1, 0, 1)  # (left, right, top, bottom) for the last two dims
        feat = F.pad(self.stem1(x), pad_br)
        conv_branch = self.stem2b(F.pad(self.stem2a(feat), pad_br))
        pool_branch = self.pool(feat)
        fused = torch.cat([pool_branch, conv_branch], dim=1)
        return self.stem4(self.stem3(fused))
class EseModule(nn.Module):
    """eSE-style channel gating: x * sigmoid(conv1x1(mean of x over dims 2 and 3)).

    For NCHW input the mean is a global spatial average; the 1x1 conv (with
    bias) maps it to a per-channel gate that rescales the input.

    Args:
        chs: number of channels (conv input and output).
    """

    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(chs, chs, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = x.mean((2, 3), keepdim=True)
        gate = self.sigmoid(self.conv(pooled))
        return x * gate
class HG_Block(nn.Module):
    """HGNet block: chained convs whose outputs are all concatenated and fused.

    The input plus the output of each of the ``layer_num`` chained convs are
    concatenated along dim 1 (``in_chs + layer_num * mid_chs`` channels) and
    fused to ``out_chs`` channels by ``aggregation``:
      * ``agg == "se"``: 1x1 ConvBNAct to ``out_chs // 2``, then 1x1 ConvBNAct
        to ``out_chs``;
      * any other value: 1x1 ConvBNAct to ``out_chs`` followed by EseModule.
    With ``residual=True`` the block input is added to the fused output, so
    the input must already have ``out_chs`` channels.

    Args:
        in_chs: input channels.
        mid_chs: output channels of each chained conv.
        out_chs: output channels of the block.
        layer_num: number of chained convs.
        kernel_size: kernel size of the chained convs.
        residual: add the block input to the output.
        light_block: use LightConvBNAct instead of ConvBNAct for the chain.
        use_lab: forwarded to all ConvBNAct-based layers.
        agg: aggregation mode, see above.
        drop_path: rate applied to the fused output on the residual path only.
            NOTE(review): implemented as element-wise ``nn.Dropout``, not
            per-sample stochastic depth.
    """

    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        layer_num,
        kernel_size=3,
        residual=False,
        light_block=False,
        use_lab=False,
        agg="ese",
        drop_path=0.0,
    ):
        super().__init__()
        self.residual = residual

        def make_layer(c_in):
            if light_block:
                return LightConvBNAct(c_in, mid_chs, kernel_size=kernel_size, use_lab=use_lab)
            return ConvBNAct(c_in, mid_chs, kernel_size=kernel_size, stride=1, use_lab=use_lab)

        # Only the first layer sees the block input; the rest see mid_chs.
        self.layers = nn.ModuleList(
            [make_layer(mid_chs if idx else in_chs) for idx in range(layer_num)]
        )

        concat_chs = in_chs + layer_num * mid_chs
        if agg == "se":
            squeeze = ConvBNAct(concat_chs, out_chs // 2, kernel_size=1, stride=1, use_lab=use_lab)
            excite = ConvBNAct(out_chs // 2, out_chs, kernel_size=1, stride=1, use_lab=use_lab)
            self.aggregation = nn.Sequential(squeeze, excite)
        else:
            fuse = ConvBNAct(concat_chs, out_chs, kernel_size=1, stride=1, use_lab=use_lab)
            self.aggregation = nn.Sequential(fuse, EseModule(out_chs))

        if drop_path:
            self.drop_path = nn.Dropout(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            features.append(layer(features[-1]))
        fused = self.aggregation(torch.cat(features, dim=1))
        return self.drop_path(fused) + x if self.residual else fused
class HG_Stage(nn.Module):
    """One HGNetv2 stage: optional stride-2 depthwise downsample, then HG_Blocks.

    Args:
        in_chs: channels entering the stage (and of the downsample conv).
        mid_chs: per-layer channels inside each HG_Block.
        out_chs: channels produced by every HG_Block.
        block_num: number of HG_Blocks; every block after the first is residual.
        layer_num: number of chained convs per HG_Block.
        downsample: if truthy, apply a 3x3 stride-2 depthwise ConvBNAct
            (``groups=in_chs``, no activation) before the blocks; otherwise
            an Identity.
        light_block: use LightConvBNAct layers inside the blocks.
        kernel_size: kernel size of the block layers.
        use_lab: forwarded to the conv layers.
        agg: aggregation mode forwarded to HG_Block.
        drop_path: one rate for all blocks, or a list/tuple indexed per block.
    """

    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        block_num,
        layer_num,
        downsample=True,
        light_block=False,
        kernel_size=3,
        use_lab=False,
        agg="se",
        drop_path=0.0,
    ):
        super().__init__()
        if downsample:
            self.downsample = ConvBNAct(
                in_chs,
                in_chs,
                kernel_size=3,
                stride=2,
                groups=in_chs,
                use_act=False,
                use_lab=use_lab,
            )
        else:
            self.downsample = nn.Identity()

        per_block_drop = isinstance(drop_path, (list, tuple))
        self.blocks = nn.Sequential(
            *[
                HG_Block(
                    in_chs if i == 0 else out_chs,
                    mid_chs,
                    out_chs,
                    layer_num,
                    # Only blocks after the first receive out_chs inputs, so
                    # only they carry a skip connection.
                    residual=i > 0,
                    kernel_size=kernel_size,
                    light_block=light_block,
                    use_lab=use_lab,
                    agg=agg,
                    drop_path=drop_path[i] if per_block_drop else drop_path,
                )
                for i in range(block_num)
            ]
        )

    def forward(self, x):
        return self.blocks(self.downsample(x))
@register()
class HGNetv2(nn.Module):
    """HGNetV2 backbone (variants B0-B6).

    Args:
        name: key into ``arch_configs`` ("B0" ... "B6") selecting the stem
            channels, the per-stage configuration and the checkpoint URL.
        use_lab: whether ConvBNAct layers append a LearnableAffineBlock after
            their activation.
        return_idx: indices of the stages (0-3) whose outputs ``forward``
            returns, in stage order.
        freeze_stem_only: when freezing is enabled, freeze only the stem;
            if False, stages ``0..freeze_at`` are frozen as well.
        freeze_at: freezing is enabled when ``freeze_at >= 0``.
        freeze_norm: replace every ``nn.BatchNorm2d`` with ``FrozenBatchNorm2d``.
        pretrained: load ``PPHGNetV2_<name>_stage1.pth`` from
            ``local_model_dir``, downloading it from the config URL if absent.
            On failure the process exits with status 1.
        local_model_dir: directory where the checkpoint is looked up / saved.

    ``forward(x)`` returns a list of feature maps, one per entry of
    ``return_idx``. Stage i has stride ``_out_strides[i]`` (4, 8, 16, 32) and
    ``_out_channels[i]`` channels.
    """

    arch_configs = {
        "B0": {
            "stem_channels": [3, 16, 16],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [16, 16, 64, 1, False, False, 3, 3],
                "stage2": [64, 32, 256, 1, True, False, 3, 3],
                "stage3": [256, 64, 512, 2, True, True, 5, 3],
                "stage4": [512, 128, 1024, 1, True, True, 5, 3],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth",
        },
        "B1": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 64, 1, False, False, 3, 3],
                "stage2": [64, 48, 256, 1, True, False, 3, 3],
                "stage3": [256, 96, 512, 2, True, True, 5, 3],
                "stage4": [512, 192, 1024, 1, True, True, 5, 3],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth",
        },
        "B2": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 96, 1, False, False, 3, 4],
                "stage2": [96, 64, 384, 1, True, False, 3, 4],
                "stage3": [384, 128, 768, 3, True, True, 5, 4],
                "stage4": [768, 256, 1536, 1, True, True, 5, 4],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth",
        },
        "B3": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 128, 1, False, False, 3, 5],
                "stage2": [128, 64, 512, 1, True, False, 3, 5],
                "stage3": [512, 128, 1024, 3, True, True, 5, 5],
                "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth",
        },
        "B4": {
            "stem_channels": [3, 32, 48],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [48, 48, 128, 1, False, False, 3, 6],
                "stage2": [128, 96, 512, 1, True, False, 3, 6],
                "stage3": [512, 192, 1024, 3, True, True, 5, 6],
                "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth",
        },
        "B5": {
            "stem_channels": [3, 32, 64],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [64, 64, 128, 1, False, False, 3, 6],
                "stage2": [128, 128, 512, 2, True, False, 3, 6],
                "stage3": [512, 256, 1024, 5, True, True, 5, 6],
                "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth",
        },
        "B6": {
            "stem_channels": [3, 48, 96],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [96, 96, 192, 2, False, False, 3, 6],
                "stage2": [192, 192, 512, 3, True, False, 3, 6],
                "stage3": [512, 384, 1024, 6, True, True, 5, 6],
                "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth",
        },
    }

    def __init__(
        self,
        name,
        use_lab=False,
        return_idx=(1, 2, 3),
        freeze_stem_only=True,
        freeze_at=0,
        freeze_norm=True,
        pretrained=True,
        local_model_dir="weight/hgnetv2/",
    ):
        super().__init__()
        self.use_lab = use_lab
        # Copy so an instance never aliases (or mutates) a shared default.
        self.return_idx = list(return_idx)

        arch = self.arch_configs[name]
        stem_channels = arch["stem_channels"]
        stage_config = arch["stage_config"]
        download_url = arch["url"]

        self._out_strides = [4, 8, 16, 32]
        self._out_channels = [cfg[2] for cfg in stage_config.values()]

        # stem
        self.stem = StemBlock(
            in_chs=stem_channels[0],
            mid_chs=stem_channels[1],
            out_chs=stem_channels[2],
            use_lab=use_lab,
        )

        # stages
        self.stages = nn.ModuleList()
        for cfg in stage_config.values():
            (
                in_channels,
                mid_channels,
                out_channels,
                block_num,
                downsample,
                light_block,
                kernel_size,
                layer_num,
            ) = cfg
            self.stages.append(
                HG_Stage(
                    in_channels,
                    mid_channels,
                    out_channels,
                    block_num,
                    layer_num,
                    downsample=downsample,
                    light_block=light_block,
                    kernel_size=kernel_size,
                    use_lab=use_lab,
                )
            )

        if freeze_at >= 0:
            self._freeze_parameters(self.stem)
            if not freeze_stem_only:
                for i in range(min(freeze_at + 1, len(self.stages))):
                    self._freeze_parameters(self.stages[i])

        # Norm layers are swapped BEFORE pretrained weights are loaded, so the
        # checkpoint populates the frozen layers directly.
        if freeze_norm:
            self._freeze_norm(self)

        if pretrained:
            self._load_pretrained(name, download_url, local_model_dir)

    def _load_pretrained(self, name, download_url, local_model_dir):
        """Load ``PPHGNetV2_<name>_stage1.pth`` into this module, or exit(1).

        If the file is not in ``local_model_dir``, rank 0 downloads it there
        with ``torch.hub.load_state_dict_from_url`` while the other ranks wait
        at a barrier and then read the file rank 0 saved. Any failure
        (including KeyboardInterrupt) is reported on rank 0 and every rank
        raises ``SystemExit(1)``.
        """
        RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
        manual_hint = (
            GREEN
            + "Please check your network connection. Or download the model manually from "
            + RESET
            + f"{download_url}"
            + GREEN
            + " to "
            + RESET
            + f"{local_model_dir}."
            + RESET
        )
        try:
            # Same file name torch.hub derives from the URL's basename.
            model_path = os.path.join(local_model_dir, "PPHGNetV2_" + name + "_stage1.pth")
            if os.path.exists(model_path):
                # NOTE(review): torch.load unpickles the file; only point
                # local_model_dir at trusted checkpoints.
                state = torch.load(model_path, map_location="cpu")
                print(f"Loaded stage1 {name} HGNetV2 from local file.")
            else:
                # If the file doesn't exist locally, download from the URL
                if safe_get_rank() == 0:
                    print(
                        GREEN
                        + "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection."
                        + RESET
                    )
                    print(manual_hint)
                    try:
                        state = torch.hub.load_state_dict_from_url(
                            download_url, map_location="cpu", model_dir=local_model_dir
                        )
                    finally:
                        # Always release the waiting ranks, even if the download
                        # failed, so they fail fast instead of hanging.
                        safe_barrier()
                else:
                    safe_barrier()
                    # Read the checkpoint file rank 0 just saved (not the directory).
                    state = torch.load(model_path, map_location="cpu")
                print(f"Loaded stage1 {name} HGNetV2 from URL.")
            self.load_state_dict(state)
        except (Exception, KeyboardInterrupt) as e:
            if safe_get_rank() == 0:
                print(f"{str(e)}")
                logging.error(
                    RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET
                )
                logging.error(manual_hint)
            # Non-zero status so launchers / CI see the failure.
            raise SystemExit(1) from e

    def _freeze_norm(self, m: nn.Module):
        """Recursively replace every ``nn.BatchNorm2d`` under ``m`` with a
        ``FrozenBatchNorm2d`` of the same ``num_features``.

        The replacement is a fresh module; the BatchNorm's existing weights and
        running statistics are NOT copied into it. Returns the (possibly
        replaced) module so the caller can re-attach it.
        """
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def _freeze_parameters(self, m: nn.Module):
        """Set ``requires_grad = False`` on every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Run the stem and all stages; return the outputs of stages in ``return_idx``."""
        x = self.stem(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs