# Spaces: Running on L40S
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py
from collections import namedtuple | |
from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, Sequential | |
from .common import Flatten, SEModule, initialize_weights | |
class BasicBlockIR(Module):
    """Basic residual unit of IRNet built from two 3x3 convolutions.

    Residual branch: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv (strided) -> BN.
    Shortcut branch: a 1x1 max-pool with the given stride (pure subsampling)
    when the channel count is unchanged, otherwise a strided 1x1 conv + BN.

    Args:
        in_channel: number of input channels.
        depth: number of output channels.
        stride: stride applied by both the residual and the shortcut branch.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        keeps_width = in_channel == depth
        # The shortcut is registered before the residual branch so parameter
        # creation order (and therefore state_dict key order) stays stable.
        self.shortcut_layer = (
            MaxPool2d(1, stride)
            if keeps_width
            else Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        )
        residual = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual)

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class BottleneckIR(Module):
    """Bottleneck residual unit of IRNet (1x1 -> 3x3 -> 1x1 convolutions).

    The residual branch squeezes to ``depth // 4`` channels, applies a 3x3
    conv, then expands back to ``depth`` with a strided 1x1 conv. The
    shortcut is a strided 1x1 max-pool when the channel count is unchanged,
    otherwise a strided 1x1 conv + BN.

    Args:
        in_channel: number of input channels.
        depth: number of output channels.
        stride: stride applied by both the residual and the shortcut branch.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        squeezed = depth // 4
        if in_channel != depth:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            self.shortcut_layer = MaxPool2d(1, stride)
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            # 1x1 reduction
            Conv2d(in_channel, squeezed, (1, 1), (1, 1), 0, bias=False),
            BatchNorm2d(squeezed),
            PReLU(squeezed),
            # 3x3 spatial conv at reduced width
            Conv2d(squeezed, squeezed, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(squeezed),
            PReLU(squeezed),
            # 1x1 expansion; this conv carries the stride
            Conv2d(squeezed, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class BasicBlockIRSE(BasicBlockIR):
    """``BasicBlockIR`` with a squeeze-and-excitation stage on the residual branch.

    The SE module is appended to ``res_layer`` under the name ``'se_block'``,
    so it runs on the residual output before the shortcut is added.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # NOTE(review): 16 is presumably the SE channel-reduction ratio — confirm in .common.
        se = SEModule(depth, 16)
        self.res_layer.add_module('se_block', se)
class BottleneckIRSE(BottleneckIR):
    """``BottleneckIR`` with a squeeze-and-excitation stage on the residual branch.

    The SE module is appended to ``res_layer`` under the name ``'se_block'``,
    so it runs on the residual output before the shortcut is added.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # NOTE(review): 16 is presumably the SE channel-reduction ratio — confirm in .common.
        se = SEModule(depth, 16)
        self.res_layer.add_module('se_block', se)
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """A named tuple describing one residual unit: input channels, output
    channels and stride.

    Despite its name it is used as the spec for every unit type (basic or
    bottleneck).
    """
    # Keep instances as lightweight as a plain namedtuple (no per-instance __dict__).
    __slots__ = ()


def get_block(in_channel, depth, num_units, stride=2):
    """Return the unit specs for one stage of the network.

    The first unit maps ``in_channel`` -> ``depth`` with ``stride``; the
    remaining ``num_units - 1`` units keep ``depth`` channels with stride 1.
    At least one unit is always returned, even when ``num_units < 1``.

    Args:
        in_channel: input channels of the stage.
        depth: output channels of every unit in the stage.
        num_units: number of units in the stage.
        stride: stride of the first unit (downsampling step of the stage).

    Returns:
        A list of ``Bottleneck`` specs.
    """
    first = Bottleneck(in_channel, depth, stride)
    rest = [Bottleneck(depth, depth, 1) for _ in range(num_units - 1)]
    return [first] + rest
# (in_channel, depth, num_units) of the four stages for each supported depth.
_BLOCK_SPECS = {
    18: ((64, 64, 2), (64, 128, 2), (128, 256, 2), (256, 512, 2)),
    34: ((64, 64, 3), (64, 128, 4), (128, 256, 6), (256, 512, 3)),
    50: ((64, 64, 3), (64, 128, 4), (128, 256, 14), (256, 512, 3)),
    100: ((64, 64, 3), (64, 128, 13), (128, 256, 30), (256, 512, 3)),
    152: ((64, 256, 3), (256, 512, 8), (512, 1024, 36), (1024, 2048, 3)),
    200: ((64, 256, 3), (256, 512, 24), (512, 1024, 36), (1024, 2048, 3)),
}


def get_blocks(num_layers):
    """Return the per-stage unit specs for a backbone of the given depth.

    Every stage is built with ``get_block``'s default stride of 2 for its
    first unit.

    Args:
        num_layers: one of 18, 34, 50, 100, 152 or 200.

    Returns:
        A list of four lists of ``Bottleneck`` specs, one per stage.

    Raises:
        ValueError: if ``num_layers`` is not a supported depth (previously
            this surfaced as an ``UnboundLocalError``).
    """
    try:
        stage_specs = _BLOCK_SPECS[num_layers]
    except KeyError:
        raise ValueError(
            'num_layers should be one of {}, got {!r}'.format(
                sorted(_BLOCK_SPECS), num_layers)) from None
    return [
        get_block(in_channel=in_channel, depth=depth, num_units=num_units)
        for in_channel, depth, num_units in stage_specs
    ]
class Backbone(Module):
    """IR / IR-SE residual backbone producing a 512-dimensional output.

    Layout: a 3x3 stem conv (3 -> 64 channels, stride 1), a body of residual
    units grouped into four stages, and an output head
    (BN -> Dropout(0.4) -> Flatten -> Linear -> BN1d without affine).
    """

    def __init__(self, input_size, num_layers, mode='ir'):
        """ Args:
            input_size: input spatial size of the backbone, e.g. ``[112, 112]``;
                only ``input_size[0]`` is checked and it must be 112 or 224.
            num_layers: depth of the backbone, one of 18, 34, 50, 100, 152
                or 200. Depths up to 100 use basic units with 512 output
                channels; 152 and 200 use bottleneck units with 2048.
            mode: ``'ir'`` for plain IR units or ``'ir_se'`` for units with a
                squeeze-and-excitation module.

        Raises:
            AssertionError: if an argument is outside the supported values.
        """
        super(Backbone, self).__init__()
        assert input_size[0] in [112, 224], \
            'input_size should be [112, 112] or [224, 224]'
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            'num_layers should be 18, 34, 50, 100, 152 or 200'
        assert mode in ['ir', 'ir_se'], \
            'mode should be ir or ir_se'
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        blocks = get_blocks(num_layers)
        if num_layers <= 100:
            unit_modules = {'ir': BasicBlockIR, 'ir_se': BasicBlockIRSE}
            output_channel = 512
        else:
            unit_modules = {'ir': BottleneckIR, 'ir_se': BottleneckIRSE}
            output_channel = 2048
        # A lookup (rather than an if/elif) fails loudly on a bad mode even
        # when the assertions above are stripped by ``python -O``.
        unit_module = unit_modules[mode]
        # The stem keeps the resolution and each of the four stages starts
        # with a stride-2 unit, so the body downsamples by 16:
        # 112 -> 7x7 and 224 -> 14x14 feature maps.
        feat_size = input_size[0] // 16
        # output_layer is created before the body to keep the original
        # module registration order (state_dict keys / init RNG order).
        self.output_layer = Sequential(
            BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
            Linear(output_channel * feat_size * feat_size, 512),
            BatchNorm1d(512, affine=False))
        modules = []
        mid_layer_indices = []  # [2, 15, 45, 48], total 49 layers for IR101
        for block in blocks:
            modules.extend(
                unit_module(unit.in_channel, unit.depth, unit.stride)
                for unit in block)
            # Index (within self.body) of the last unit of this stage.
            mid_layer_indices.append(len(modules) - 1)
        self.body = Sequential(*modules)
        self.mid_layer_indices = mid_layer_indices[-4:]
        initialize_weights(self.modules())

    def device(self):
        """Return the device of the first parameter (a method, not a property)."""
        return next(self.parameters()).device

    def dtype(self):
        """Return the dtype of the first parameter (a method, not a property)."""
        return next(self.parameters()).dtype

    def forward(self, x, return_mid_feats=False):
        """Run the backbone.

        Args:
            x: input batch for the 3-channel stem conv; presumably
                (N, 3, H, W) with H == W == input_size[0] — TODO confirm.
            return_mid_feats: if True, also collect the body feature maps
                after the last unit of each stage.

        Returns:
            The 512-d output of ``output_layer``, or
            ``(output, mid_feats)`` when ``return_mid_feats`` is True.
        """
        x = self.input_layer(x)
        if not return_mid_feats:
            x = self.body(x)
            return self.output_layer(x)
        out_feats = []
        for idx, module in enumerate(self.body):
            x = module(x)
            if idx in self.mid_layer_indices:
                out_feats.append(x)
        return self.output_layer(x), out_feats
def IR_18(input_size):
    """Build an IR-18 backbone (plain IR units, 18 layers).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=18, mode='ir')
def IR_34(input_size):
    """Build an IR-34 backbone (plain IR units, 34 layers).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=34, mode='ir')
def IR_50(input_size):
    """Build an IR-50 backbone (plain IR units, 50 layers).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=50, mode='ir')
def IR_101(input_size):
    """Build an IR-101 backbone (plain IR units).

    The "101" name corresponds to ``Backbone``'s ``num_layers=100``
    configuration.

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=100, mode='ir')
def IR_152(input_size):
    """Build an IR-152 backbone (plain IR bottleneck units, 152 layers).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=152, mode='ir')
def IR_200(input_size):
    """Build an IR-200 backbone (plain IR bottleneck units, 200 layers).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=200, mode='ir')
def IR_SE_50(input_size):
    """Build an IR-SE-50 backbone (IR units with squeeze-and-excitation, 50 layers).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=50, mode='ir_se')
def IR_SE_101(input_size):
    """Build an IR-SE-101 backbone (IR units with squeeze-and-excitation).

    The "101" name corresponds to ``Backbone``'s ``num_layers=100``
    configuration.

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=100, mode='ir_se')
def IR_SE_152(input_size):
    """Build an IR-SE-152 backbone (bottleneck units with squeeze-and-excitation).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=152, mode='ir_se')
def IR_SE_200(input_size):
    """Build an IR-SE-200 backbone (bottleneck units with squeeze-and-excitation).

    Args:
        input_size: forwarded to ``Backbone``; ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, num_layers=200, mode='ir_se')