Moditha24 commited on
Commit
c4dada4
·
verified ·
1 Parent(s): b16395d

Delete resnet.py

Browse files
Files changed (1) hide show
  1. resnet.py +0 -217
resnet.py DELETED
@@ -1,217 +0,0 @@
1
- """ResNet in PyTorch.
2
- ImageNet-Style ResNet
3
- [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4
- Deep Residual Learning for Image Recognition. arXiv:1512.03385
5
- Adapted from: https://github.com/bearpaw/pytorch-classification
6
- """
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
-
12
- class BasicBlock(nn.Module):
13
- expansion = 1
14
-
15
- def __init__(self, in_planes, planes, stride=1, is_last=False):
16
- super(BasicBlock, self).__init__()
17
- self.is_last = is_last
18
- self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
19
- self.bn1 = nn.BatchNorm2d(planes)
20
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
21
- self.bn2 = nn.BatchNorm2d(planes)
22
-
23
- self.shortcut = nn.Sequential()
24
- if stride != 1 or in_planes != self.expansion * planes:
25
- self.shortcut = nn.Sequential(
26
- nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
27
- nn.BatchNorm2d(self.expansion * planes)
28
- )
29
-
30
- def forward(self, x):
31
- out = F.relu(self.bn1(self.conv1(x)))
32
- out = self.bn2(self.conv2(out))
33
- out += self.shortcut(x)
34
- preact = out
35
- out = F.relu(out)
36
- if self.is_last:
37
- return out, preact
38
- else:
39
- return out
40
-
41
-
42
- class Bottleneck(nn.Module):
43
- expansion = 4
44
-
45
- def __init__(self, in_planes, planes, stride=1, is_last=False):
46
- super(Bottleneck, self).__init__()
47
- self.is_last = is_last
48
- self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
49
- self.bn1 = nn.BatchNorm2d(planes)
50
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
51
- self.bn2 = nn.BatchNorm2d(planes)
52
- self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
53
- self.bn3 = nn.BatchNorm2d(self.expansion * planes)
54
-
55
- self.shortcut = nn.Sequential()
56
- if stride != 1 or in_planes != self.expansion * planes:
57
- self.shortcut = nn.Sequential(
58
- nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
59
- nn.BatchNorm2d(self.expansion * planes)
60
- )
61
-
62
- def forward(self, x):
63
- out = F.relu(self.bn1(self.conv1(x)))
64
- out = F.relu(self.bn2(self.conv2(out)))
65
- out = self.bn3(self.conv3(out))
66
- out += self.shortcut(x)
67
- preact = out
68
- out = F.relu(out)
69
- if self.is_last:
70
- return out, preact
71
- else:
72
- return out
73
-
74
-
75
- class ResNet(nn.Module):
76
- def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False, pool=False):
77
- super(ResNet, self).__init__()
78
- self.in_planes = 64
79
-
80
- if pool:
81
- self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
82
- else:
83
- self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
84
- self.bn1 = nn.BatchNorm2d(64)
85
-
86
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if pool else nn.Identity()
87
- self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
88
- self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
89
- self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
90
- self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
91
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
92
-
93
- for m in self.modules():
94
- if isinstance(m, nn.Conv2d):
95
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
96
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
97
- nn.init.constant_(m.weight, 1)
98
- nn.init.constant_(m.bias, 0)
99
-
100
- # Zero-initialize the last BN in each residual branch,
101
- # so that the residual branch starts with zeros, and each residual block behaves
102
- # like an identity. This improves the model by 0.2~0.3% according to:
103
- # https://arxiv.org/abs/1706.02677
104
- if zero_init_residual:
105
- for m in self.modules():
106
- if isinstance(m, Bottleneck):
107
- nn.init.constant_(m.bn3.weight, 0)
108
- elif isinstance(m, BasicBlock):
109
- nn.init.constant_(m.bn2.weight, 0)
110
-
111
- def _make_layer(self, block, planes, num_blocks, stride):
112
- strides = [stride] + [1] * (num_blocks - 1)
113
- layers = []
114
- for i in range(num_blocks):
115
- stride = strides[i]
116
- layers.append(block(self.in_planes, planes, stride))
117
- self.in_planes = planes * block.expansion
118
- return nn.Sequential(*layers)
119
-
120
- def forward(self, x, layer=100):
121
- out = self.maxpool(F.relu(self.bn1(self.conv1(x))))
122
- out = self.layer1(out)
123
- out = self.layer2(out)
124
- out = self.layer3(out)
125
- out = self.layer4(out)
126
- out = self.avgpool(out)
127
- out = torch.flatten(out, 1)
128
- return out
129
-
130
-
131
- def resnet18(**kwargs):
132
- return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
133
-
134
-
135
- def resnet34(**kwargs):
136
- return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
137
-
138
-
139
- def resnet50(**kwargs):
140
- return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
141
-
142
-
143
- def resnet101(**kwargs):
144
- return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
145
-
146
-
147
- model_dict = {
148
- 'resnet18': [resnet18, 512],
149
- 'resnet34': [resnet34, 512],
150
- 'resnet50': [resnet50, 2048],
151
- 'resnet101': [resnet101, 2048],
152
- }
153
-
154
-
155
- class LinearBatchNorm(nn.Module):
156
- """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""
157
-
158
- def __init__(self, dim, affine=True):
159
- super(LinearBatchNorm, self).__init__()
160
- self.dim = dim
161
- self.bn = nn.BatchNorm2d(dim, affine=affine)
162
-
163
- def forward(self, x):
164
- x = x.view(-1, self.dim, 1, 1)
165
- x = self.bn(x)
166
- x = x.view(-1, self.dim)
167
- return x
168
-
169
-
170
- class SupConResNet(nn.Module):
171
- """backbone + projection head"""
172
-
173
- def __init__(self, name='resnet50', head='mlp', feat_dim=128, pool=False):
174
- super(SupConResNet, self).__init__()
175
- model_fun, dim_in = model_dict[name]
176
- self.encoder = model_fun(pool=pool)
177
- if head == 'linear':
178
- self.head = nn.Linear(dim_in, feat_dim)
179
- elif head == 'mlp':
180
- self.head = nn.Sequential(
181
- nn.Linear(dim_in, dim_in),
182
- nn.ReLU(inplace=True),
183
- nn.Linear(dim_in, feat_dim)
184
- )
185
- else:
186
- raise NotImplementedError(
187
- 'head not supported: {}'.format(head))
188
-
189
- def forward(self, x):
190
- feat = self.encoder(x)
191
- feat = F.normalize(self.head(feat), dim=1)
192
- return feat
193
-
194
-
195
- class SupCEResNet(nn.Module):
196
- """encoder + classifier"""
197
-
198
- def __init__(self, name='resnet50', num_classes=10, pool=False):
199
- super(SupCEResNet, self).__init__()
200
- model_fun, dim_in = model_dict[name]
201
- self.encoder = model_fun(pool=pool)
202
- self.fc = nn.Linear(dim_in, num_classes)
203
-
204
- def forward(self, x):
205
- return self.fc(self.encoder(x))
206
-
207
-
208
- class LinearClassifier(nn.Module):
209
- """Linear classifier"""
210
-
211
- def __init__(self, name='resnet50', num_classes=10):
212
- super(LinearClassifier, self).__init__()
213
- _, feat_dim = model_dict[name]
214
- self.fc = nn.Linear(feat_dim, num_classes)
215
-
216
- def forward(self, features):
217
- return self.fc(features)