Gallai commited on
Commit
33f2fd7
·
verified ·
1 Parent(s): da0cc5b

Create ResNet_for_CC.py

Browse files
Files changed (1) hide show
  1. ResNet_for_CC.py +55 -0
ResNet_for_CC.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+
5
+
6
+ class ResClassifier(nn.Module):
7
+ def __init__(self, class_num=14):
8
+ super(ResClassifier, self).__init__()
9
+ self.fc1 = nn.Sequential(
10
+ nn.Linear(128, 64),
11
+ nn.BatchNorm1d(64, affine=True),
12
+ nn.ReLU(inplace=True),
13
+ nn.Dropout()
14
+ )
15
+ self.fc2 = nn.Sequential(
16
+ nn.Linear(64, 64),
17
+ nn.BatchNorm1d(64, affine=True),
18
+ nn.ReLU(inplace=True),
19
+ nn.Dropout()
20
+ )
21
+ self.fc3 = nn.Linear(64, class_num)
22
+
23
+ def forward(self, x):
24
+ fc1_emb = self.fc1(x)
25
+ fc2_emb = self.fc2(fc1_emb)
26
+ logit = self.fc3(fc2_emb)
27
+ return logit
28
+
29
+ class CC_model(nn.Module):
30
+ def __init__(self, num_classes1=14, num_classes2=None):
31
+
32
+ if num_classes2 is None:
33
+ num_classes2 = num_classes1
34
+
35
+ super(CC_model, self).__init__()
36
+ assert num_classes1 == num_classes2
37
+ self.num_classes = num_classes1
38
+ self.model_resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT')
39
+ num_ftrs = self.model_resnet.fc.in_features
40
+ self.model_resnet.fc = nn.Identity()
41
+ self.classification_fc = nn.Linear(num_ftrs, num_classes1)
42
+ self.dr = nn.Linear(num_ftrs, 128)
43
+ self.fc1 = ResClassifier(num_classes1)
44
+ self.fc2 = ResClassifier(num_classes1)
45
+
46
+ def forward(self, x, detach_feature=False):
47
+ feature = self.model_resnet(x)
48
+ res_out = self.classification_fc(feature)
49
+ if detach_feature:
50
+ feature = feature.detach()
51
+ dr_feature = self.dr(feature)
52
+ out1 = self.fc1(dr_feature)
53
+ out2 = self.fc2(dr_feature)
54
+ output_mean = (out1 + out2)
55
+ return output_mean