""" | |
reference | |
- https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py | |
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. | |
""" | |
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core import register
from .common import FrozenBatchNorm2d
# Aliases for the weight initializers used in this file
kaiming_normal_ = nn.init.kaiming_normal_
zeros_ = nn.init.zeros_
ones_ = nn.init.ones_

__all__ = ["HGNetv2"]
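

# Distributed-safe helpers: behave as no-ops (or as rank 0) when
# torch.distributed is not available or not initialized.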
def safe_barrier():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


def safe_get_rank():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0


class LearnableAffineBlock(nn.Module):
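    """Learnable affine transform ``y = scale * x + bias`` with two scalar
    parameters broadcast over the whole input.
    """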
    def __init__(self, scale_value=1.0, bias_value=0.0):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor([scale_value]))
        self.bias = nn.Parameter(torch.tensor([bias_value]))

    def forward(self, x):
        return self.scale * x + self.bias


class ConvBNAct(nn.Module):
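    """Conv2d + BatchNorm2d with optional ReLU and LearnableAffineBlock.

    ``padding="same"`` applies asymmetric zero-padding (one pixel on the right
    and bottom only); any other value uses symmetric padding of
    ``(kernel_size - 1) // 2``.
    """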
    def __init__(
        self,
        in_chs,
        out_chs,
        kernel_size,
        stride=1,
        groups=1,
        padding="",
        use_act=True,
        use_lab=False,
    ):
        super().__init__()
        self.use_act = use_act
        self.use_lab = use_lab
        if padding == "same":
            self.conv = nn.Sequential(
                nn.ZeroPad2d([0, 1, 0, 1]),
                nn.Conv2d(in_chs, out_chs, kernel_size, stride, groups=groups, bias=False),
            )
        else:
            self.conv = nn.Conv2d(
                in_chs,
                out_chs,
                kernel_size,
                stride,
                padding=(kernel_size - 1) // 2,
                groups=groups,
                bias=False,
            )
        self.bn = nn.BatchNorm2d(out_chs)
        if self.use_act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()
        if self.use_act and self.use_lab:
            self.lab = LearnableAffineBlock()
        else:
            self.lab = nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.lab(x)
        return x


class LightConvBNAct(nn.Module):
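    """Lightweight ConvBNAct: a 1x1 pointwise conv without activation followed
    by a depthwise conv (``groups == out_chs``) with activation.
    """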
    def __init__(
        self,
        in_chs,
        out_chs,
        kernel_size,
        groups=1,
        use_lab=False,
    ):
        super().__init__()
        self.conv1 = ConvBNAct(
            in_chs,
            out_chs,
            kernel_size=1,
            use_act=False,
            use_lab=use_lab,
        )
        self.conv2 = ConvBNAct(
            out_chs,
            out_chs,
            kernel_size=kernel_size,
            groups=out_chs,
            use_act=True,
            use_lab=use_lab,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class StemBlock(nn.Module):
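    """HGNetv2 stem: a stride-2 conv, then two parallel branches (a max-pool and
    a pair of 2x2 convs) concatenated along channels, followed by another
    stride-2 conv and a 1x1 projection. Overall spatial reduction is 4x.
    """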
    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
        super().__init__()
        self.stem1 = ConvBNAct(
            in_chs,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem2a = ConvBNAct(
            mid_chs,
            mid_chs // 2,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem2b = ConvBNAct(
            mid_chs // 2,
            mid_chs,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem3 = ConvBNAct(
            mid_chs * 2,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem4 = ConvBNAct(
            mid_chs,
            out_chs,
            kernel_size=1,
            stride=1,
            use_lab=use_lab,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)

    def forward(self, x):
        x = self.stem1(x)
        x = F.pad(x, (0, 1, 0, 1))
        x2 = self.stem2a(x)
        x2 = F.pad(x2, (0, 1, 0, 1))
        x2 = self.stem2b(x2)
        x1 = self.pool(x)
        x = torch.cat([x1, x2], dim=1)
        x = self.stem3(x)
        x = self.stem4(x)
        return x


class EseModule(nn.Module):
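    """Effective Squeeze-and-Excitation: channel attention computed as
    sigmoid(conv1x1(global_avg_pool(x))) and multiplied with the input.
    """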
    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(
            chs,
            chs,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        x = x.mean((2, 3), keepdim=True)
        x = self.conv(x)
        x = self.sigmoid(x)
        return torch.mul(identity, x)


class HG_Block(nn.Module):
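    """HGNet building block: ``layer_num`` stacked convs (light or regular)
    whose outputs, together with the block input, are concatenated and fused by
    1x1 aggregation convs ("se": a squeeze/excitation conv pair; otherwise a
    single conv followed by EseModule attention). Optionally adds a residual
    connection.
    """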
    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        layer_num,
        kernel_size=3,
        residual=False,
        light_block=False,
        use_lab=False,
        agg="ese",
        drop_path=0.0,
    ):
        super().__init__()
        self.residual = residual
        self.layers = nn.ModuleList()
        for i in range(layer_num):
            if light_block:
                self.layers.append(
                    LightConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        use_lab=use_lab,
                    )
                )
            else:
                self.layers.append(
                    ConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        stride=1,
                        use_lab=use_lab,
                    )
                )

        # feature aggregation
        total_chs = in_chs + layer_num * mid_chs
        if agg == "se":
            aggregation_squeeze_conv = ConvBNAct(
                total_chs,
                out_chs // 2,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            aggregation_excitation_conv = ConvBNAct(
                out_chs // 2,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            self.aggregation = nn.Sequential(
                aggregation_squeeze_conv,
                aggregation_excitation_conv,
            )
        else:
            aggregation_conv = ConvBNAct(
                total_chs,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            att = EseModule(out_chs)
            self.aggregation = nn.Sequential(
                aggregation_conv,
                att,
            )
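        # NOTE: nn.Dropout zeroes individual elements of the residual branch;
        # this is plain dropout, not per-sample stochastic depth (DropPath).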
        self.drop_path = nn.Dropout(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        identity = x
        output = [x]
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        x = torch.cat(output, dim=1)
        x = self.aggregation(x)
        if self.residual:
            x = self.drop_path(x) + identity
        return x


class HG_Stage(nn.Module):
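    """One HGNetv2 stage: an optional stride-2 depthwise conv for downsampling,
    followed by ``block_num`` HG_Blocks (every block after the first uses a
    residual connection).
    """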
    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        block_num,
        layer_num,
        downsample=True,
        light_block=False,
        kernel_size=3,
        use_lab=False,
        agg="se",
        drop_path=0.0,
    ):
        super().__init__()
        if downsample:
            self.downsample = ConvBNAct(
                in_chs,
                in_chs,
                kernel_size=3,
                stride=2,
                groups=in_chs,
                use_act=False,
                use_lab=use_lab,
            )
        else:
            self.downsample = nn.Identity()

        blocks_list = []
        for i in range(block_num):
            blocks_list.append(
                HG_Block(
                    in_chs if i == 0 else out_chs,
                    mid_chs,
                    out_chs,
                    layer_num,
                    residual=(i != 0),
                    kernel_size=kernel_size,
                    light_block=light_block,
                    use_lab=use_lab,
                    agg=agg,
                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                )
            )
        self.blocks = nn.Sequential(*blocks_list)

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x


@register()
class HGNetv2(nn.Module):
""" | |
HGNetV2 | |
Args: | |
stem_channels: list. Number of channels for the stem block. | |
stage_type: str. The stage configuration of HGNet. such as the number of channels, stride, etc. | |
use_lab: boolean. Whether to use LearnableAffineBlock in network. | |
lr_mult_list: list. Control the learning rate of different stages. | |
Returns: | |
model: nn.Layer. Specific HGNetV2 model depends on args. | |
""" | |

    arch_configs = {
        "B0": {
            "stem_channels": [3, 16, 16],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [16, 16, 64, 1, False, False, 3, 3],
                "stage2": [64, 32, 256, 1, True, False, 3, 3],
                "stage3": [256, 64, 512, 2, True, True, 5, 3],
                "stage4": [512, 128, 1024, 1, True, True, 5, 3],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth",
        },
        "B1": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 64, 1, False, False, 3, 3],
                "stage2": [64, 48, 256, 1, True, False, 3, 3],
                "stage3": [256, 96, 512, 2, True, True, 5, 3],
                "stage4": [512, 192, 1024, 1, True, True, 5, 3],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth",
        },
        "B2": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 96, 1, False, False, 3, 4],
                "stage2": [96, 64, 384, 1, True, False, 3, 4],
                "stage3": [384, 128, 768, 3, True, True, 5, 4],
                "stage4": [768, 256, 1536, 1, True, True, 5, 4],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth",
        },
        "B3": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 128, 1, False, False, 3, 5],
                "stage2": [128, 64, 512, 1, True, False, 3, 5],
                "stage3": [512, 128, 1024, 3, True, True, 5, 5],
                "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth",
        },
        "B4": {
            "stem_channels": [3, 32, 48],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [48, 48, 128, 1, False, False, 3, 6],
                "stage2": [128, 96, 512, 1, True, False, 3, 6],
                "stage3": [512, 192, 1024, 3, True, True, 5, 6],
                "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth",
        },
        "B5": {
            "stem_channels": [3, 32, 64],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [64, 64, 128, 1, False, False, 3, 6],
                "stage2": [128, 128, 512, 2, True, False, 3, 6],
                "stage3": [512, 256, 1024, 5, True, True, 5, 6],
                "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth",
        },
        "B6": {
            "stem_channels": [3, 48, 96],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [96, 96, 192, 2, False, False, 3, 6],
                "stage2": [192, 192, 512, 3, True, False, 3, 6],
                "stage3": [512, 384, 1024, 6, True, True, 5, 6],
                "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth",
        },
    }

    def __init__(
        self,
        name,
        use_lab=False,
        return_idx=[1, 2, 3],
        freeze_stem_only=True,
        freeze_at=0,
        freeze_norm=True,
        pretrained=True,
        local_model_dir="weight/hgnetv2/",
    ):
        super().__init__()
        self.use_lab = use_lab
        self.return_idx = return_idx

        stem_channels = self.arch_configs[name]["stem_channels"]
        stage_config = self.arch_configs[name]["stage_config"]
        download_url = self.arch_configs[name]["url"]

        self._out_strides = [4, 8, 16, 32]
        self._out_channels = [stage_config[k][2] for k in stage_config]

        # stem
        self.stem = StemBlock(
            in_chs=stem_channels[0],
            mid_chs=stem_channels[1],
            out_chs=stem_channels[2],
            use_lab=use_lab,
        )

        # stages
        self.stages = nn.ModuleList()
        for i, k in enumerate(stage_config):
            (
                in_channels,
                mid_channels,
                out_channels,
                block_num,
                downsample,
                light_block,
                kernel_size,
                layer_num,
            ) = stage_config[k]
            self.stages.append(
                HG_Stage(
                    in_channels,
                    mid_channels,
                    out_channels,
                    block_num,
                    layer_num,
                    downsample,
                    light_block,
                    kernel_size,
                    use_lab,
                )
            )

        if freeze_at >= 0:
            self._freeze_parameters(self.stem)
            if not freeze_stem_only:
                for i in range(min(freeze_at + 1, len(self.stages))):
                    self._freeze_parameters(self.stages[i])

        if freeze_norm:
            self._freeze_norm(self)
        if pretrained:
            RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
            try:
                model_path = local_model_dir + "PPHGNetV2_" + name + "_stage1.pth"
                if os.path.exists(model_path):
                    state = torch.load(model_path, map_location="cpu")
                    print(f"Loaded stage1 {name} HGNetV2 from local file.")
                else:
                    # The file doesn't exist locally; rank 0 downloads it from the URL
                    if safe_get_rank() == 0:
                        print(
                            GREEN
                            + "If the pretrained HGNetV2 can't be downloaded automatically, please check your network connection."
                            + RESET
                        )
                        print(
                            GREEN
                            + "Alternatively, download the model manually from "
                            + RESET
                            + f"{download_url}"
                            + GREEN
                            + " to "
                            + RESET
                            + f"{local_model_dir}."
                        )
                        state = torch.hub.load_state_dict_from_url(
                            download_url, map_location="cpu", model_dir=local_model_dir
                        )
                        safe_barrier()
                    else:
                        # Other ranks wait for rank 0 to finish downloading,
                        # then load the cached checkpoint file
                        safe_barrier()
                        state = torch.load(model_path, map_location="cpu")
                    print(f"Loaded stage1 {name} HGNetV2 from URL.")

                self.load_state_dict(state)
            except (Exception, KeyboardInterrupt) as e:
                if safe_get_rank() == 0:
                    print(f"{str(e)}")
                    logging.error(
                        RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET
                    )
                    logging.error(
                        GREEN
                        + "Please check your network connection, or download the model manually from "
                        + RESET
                        + f"{download_url}"
                        + GREEN
                        + " to "
                        + RESET
                        + f"{local_model_dir}."
                    )
                exit()

    def _freeze_norm(self, m: nn.Module):
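        """Recursively replace each nn.BatchNorm2d in ``m`` with a
        FrozenBatchNorm2d of the same width, so its statistics and affine
        parameters become fixed buffers (populated when the pretrained state
        dict is loaded)."""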
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def _freeze_parameters(self, m: nn.Module):
        for p in m.parameters():
            p.requires_grad = False

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs
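

if __name__ == "__main__":
    # Minimal smoke-test sketch. This module uses relative imports, so it must
    # be run as part of its package (e.g. ``python -m <package>.hgnetv2``); the
    # entry point here is illustrative, not part of the original file.
    # With the B0 config and return_idx=[1, 2, 3], a 640x640 input yields
    # feature maps at strides 8, 16, and 32 with 256, 512, and 1024 channels.
    model = HGNetv2("B0", pretrained=False)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 640, 640))
    for f in feats:
        print(tuple(f.shape))  # (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)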