LanXiaoPang613
commited on
Add files via upload
Browse files- PreResNet.py +184 -0
- Train_animal10N.py +487 -0
- dataloader_animal10N.py +200 -0
- dataloader_cifar.py +276 -0
- img/framework.tif +0 -0
- models/CNN.py +193 -0
- models/InceptionResNetV2.py +345 -0
- models/ResNet_Imagenet.py +289 -0
- models/ResNet_cifar.py +688 -0
- models/__init__.py +4 -0
- models/__pycache__/CNN.cpython-310.pyc +0 -0
- models/__pycache__/InceptionResNetV2.cpython-310.pyc +0 -0
- models/__pycache__/ResNet_Imagenet.cpython-310.pyc +0 -0
- models/__pycache__/ResNet_cifar.cpython-310.pyc +0 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- requirements.txt +7 -0
- train_cifar_c2mt.py +807 -0
PreResNet.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from torch.autograd import Variable
|
6 |
+
|
7 |
+
|
8 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
9 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
10 |
+
|
11 |
+
|
12 |
+
class BasicBlock(nn.Module):
|
13 |
+
expansion = 1
|
14 |
+
|
15 |
+
def __init__(self, in_planes, planes, stride=1):
|
16 |
+
super(BasicBlock, self).__init__()
|
17 |
+
self.conv1 = conv3x3(in_planes, planes, stride)
|
18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
19 |
+
self.conv2 = conv3x3(planes, planes)
|
20 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
21 |
+
|
22 |
+
self.shortcut = nn.Sequential()
|
23 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
24 |
+
self.shortcut = nn.Sequential(
|
25 |
+
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
26 |
+
nn.BatchNorm2d(self.expansion*planes)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
31 |
+
out = self.bn2(self.conv2(out))
|
32 |
+
out += self.shortcut(x)
|
33 |
+
out = F.relu(out)
|
34 |
+
return out
|
35 |
+
|
36 |
+
|
37 |
+
class PreActBlock(nn.Module):
|
38 |
+
'''Pre-activation version of the BasicBlock.'''
|
39 |
+
expansion = 1
|
40 |
+
|
41 |
+
def __init__(self, in_planes, planes, stride=1):
|
42 |
+
super(PreActBlock, self).__init__()
|
43 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
44 |
+
self.conv1 = conv3x3(in_planes, planes, stride)
|
45 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
46 |
+
self.conv2 = conv3x3(planes, planes)
|
47 |
+
|
48 |
+
self.shortcut = nn.Sequential()
|
49 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
50 |
+
self.shortcut = nn.Sequential(
|
51 |
+
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
|
52 |
+
)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
out = F.relu(self.bn1(x))
|
56 |
+
shortcut = self.shortcut(out)
|
57 |
+
out = self.conv1(out)
|
58 |
+
out = self.conv2(F.relu(self.bn2(out)))
|
59 |
+
out += shortcut
|
60 |
+
return out
|
61 |
+
|
62 |
+
|
63 |
+
class Bottleneck(nn.Module):
|
64 |
+
expansion = 4
|
65 |
+
|
66 |
+
def __init__(self, in_planes, planes, stride=1):
|
67 |
+
super(Bottleneck, self).__init__()
|
68 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
69 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
70 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
71 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
72 |
+
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
73 |
+
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
|
74 |
+
|
75 |
+
self.shortcut = nn.Sequential()
|
76 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
77 |
+
self.shortcut = nn.Sequential(
|
78 |
+
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
79 |
+
nn.BatchNorm2d(self.expansion*planes)
|
80 |
+
)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
84 |
+
out = F.relu(self.bn2(self.conv2(out)))
|
85 |
+
out = self.bn3(self.conv3(out))
|
86 |
+
out += self.shortcut(x)
|
87 |
+
out = F.relu(out)
|
88 |
+
return out
|
89 |
+
|
90 |
+
|
91 |
+
class PreActBottleneck(nn.Module):
|
92 |
+
'''Pre-activation version of the original Bottleneck module.'''
|
93 |
+
expansion = 4
|
94 |
+
|
95 |
+
def __init__(self, in_planes, planes, stride=1):
|
96 |
+
super(PreActBottleneck, self).__init__()
|
97 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
98 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
99 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
100 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
101 |
+
self.bn3 = nn.BatchNorm2d(planes)
|
102 |
+
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
103 |
+
|
104 |
+
self.shortcut = nn.Sequential()
|
105 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
106 |
+
self.shortcut = nn.Sequential(
|
107 |
+
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
|
108 |
+
)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
out = F.relu(self.bn1(x))
|
112 |
+
shortcut = self.shortcut(out)
|
113 |
+
out = self.conv1(out)
|
114 |
+
out = self.conv2(F.relu(self.bn2(out)))
|
115 |
+
out = self.conv3(F.relu(self.bn3(out)))
|
116 |
+
out += shortcut
|
117 |
+
return out
|
118 |
+
|
119 |
+
|
120 |
+
class ResNet(nn.Module):
|
121 |
+
def __init__(self, block, num_blocks, num_classes=10):
|
122 |
+
super(ResNet, self).__init__()
|
123 |
+
self.in_planes = 64
|
124 |
+
|
125 |
+
self.conv1 = conv3x3(3,64)
|
126 |
+
self.bn1 = nn.BatchNorm2d(64)
|
127 |
+
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
128 |
+
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
129 |
+
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
130 |
+
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
131 |
+
self.linear = nn.Linear(512*block.expansion, num_classes)
|
132 |
+
|
133 |
+
def _make_layer(self, block, planes, num_blocks, stride):
|
134 |
+
strides = [stride] + [1]*(num_blocks-1)
|
135 |
+
layers = []
|
136 |
+
for stride in strides:
|
137 |
+
layers.append(block(self.in_planes, planes, stride))
|
138 |
+
self.in_planes = planes * block.expansion
|
139 |
+
return nn.Sequential(*layers)
|
140 |
+
|
141 |
+
def forward(self, x, lin=0, lout=5, feat_out=False):
|
142 |
+
out = x
|
143 |
+
if lin < 1 and lout > -1:
|
144 |
+
out = self.conv1(out)
|
145 |
+
out = self.bn1(out)
|
146 |
+
out = F.relu(out)
|
147 |
+
if lin < 2 and lout > 0:
|
148 |
+
out = self.layer1(out)
|
149 |
+
if lin < 3 and lout > 1:
|
150 |
+
out = self.layer2(out)
|
151 |
+
if lin < 4 and lout > 2:
|
152 |
+
out = self.layer3(out)
|
153 |
+
if lin < 5 and lout > 3:
|
154 |
+
out = self.layer4(out)
|
155 |
+
if lout > 4:
|
156 |
+
out = F.avg_pool2d(out, 4)
|
157 |
+
feat = out.view(out.size(0), -1)
|
158 |
+
out = self.linear(feat)
|
159 |
+
if feat_out:
|
160 |
+
return out, feat
|
161 |
+
else:
|
162 |
+
return out
|
163 |
+
|
164 |
+
|
165 |
+
def ResNet18(num_classes=10):
|
166 |
+
return ResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
|
167 |
+
|
168 |
+
def ResNet34(num_classes=10):
|
169 |
+
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
|
170 |
+
|
171 |
+
def ResNet50(num_classes=10):
|
172 |
+
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes)
|
173 |
+
|
174 |
+
def ResNet101(num_classes=10):
|
175 |
+
return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes)
|
176 |
+
|
177 |
+
def ResNet152(num_classes=10):
|
178 |
+
return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes)
|
179 |
+
|
180 |
+
|
181 |
+
def test():
|
182 |
+
net = ResNet18()
|
183 |
+
y = net(Variable(torch.randn(1,3,32,32)))
|
184 |
+
print(y.size())
|
Train_animal10N.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.backends.cudnn as cudnn
|
8 |
+
import torchvision
|
9 |
+
import torchvision.models as models
|
10 |
+
from models.CNN import CNN
|
11 |
+
import random
|
12 |
+
import os
|
13 |
+
import argparse
|
14 |
+
import numpy as np
|
15 |
+
import dataloader_animal10N as animal_dataloader
|
16 |
+
from sklearn.mixture import GaussianMixture
|
17 |
+
import copy
|
18 |
+
|
19 |
+
parser = argparse.ArgumentParser(description='PyTorch Clothing1M Training')
|
20 |
+
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
|
21 |
+
parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
|
22 |
+
parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
|
23 |
+
parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
|
24 |
+
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
|
25 |
+
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
|
26 |
+
parser.add_argument('--num_epochs', default=300, type=int)
|
27 |
+
parser.add_argument('--id', default='animal10N')
|
28 |
+
# parser.add_argument('--data_path', default='E:/Dataset_All/clothing1M/images', type=str, help='path to dataset')
|
29 |
+
parser.add_argument('--data_path', default='C:/Users/Administrator/Desktop/DatasetAll/Animal-10N', type=str, help='path to dataset')
|
30 |
+
parser.add_argument('--seed', default=123)
|
31 |
+
parser.add_argument('--gpuid', default=0, type=int)
|
32 |
+
parser.add_argument('--num_class', default=10, type=int)
|
33 |
+
# parser.add_argument('--num_batches', default=1000, type=int)
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
torch.cuda.set_device(args.gpuid)
|
37 |
+
random.seed(args.seed)
|
38 |
+
torch.manual_seed(args.seed)
|
39 |
+
torch.cuda.manual_seed_all(args.seed)
|
40 |
+
|
41 |
+
|
42 |
+
# Training
|
43 |
+
def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader):
|
44 |
+
net.train()
|
45 |
+
net2.eval() # fix one network and train the other
|
46 |
+
|
47 |
+
unlabeled_train_iter = iter(unlabeled_trainloader)
|
48 |
+
num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
|
49 |
+
for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
|
50 |
+
try:
|
51 |
+
inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
|
52 |
+
except:
|
53 |
+
unlabeled_train_iter = iter(unlabeled_trainloader)
|
54 |
+
inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
|
55 |
+
batch_size = inputs_x.size(0)
|
56 |
+
|
57 |
+
# Transform label to one-hot
|
58 |
+
labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
|
59 |
+
w_x = w_x.view(-1, 1).type(torch.FloatTensor)
|
60 |
+
|
61 |
+
inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
|
62 |
+
inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
# label co-guessing of unlabeled samples
|
66 |
+
outputs_u11 = net(inputs_u)
|
67 |
+
outputs_u12 = net(inputs_u2)
|
68 |
+
outputs_u21 = net2(inputs_u)
|
69 |
+
outputs_u22 = net2(inputs_u2)
|
70 |
+
|
71 |
+
pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) +
|
72 |
+
torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
|
73 |
+
ptu = pu ** (1 / args.T) # temparature sharpening
|
74 |
+
|
75 |
+
targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
|
76 |
+
targets_u = targets_u.detach()
|
77 |
+
|
78 |
+
# label refinement of labeled samples
|
79 |
+
outputs_x = net(inputs_x)
|
80 |
+
outputs_x2 = net(inputs_x2)
|
81 |
+
|
82 |
+
px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
|
83 |
+
px = w_x * labels_x + (1 - w_x) * px
|
84 |
+
ptx = px ** (1 / args.T) # temparature sharpening
|
85 |
+
|
86 |
+
targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
|
87 |
+
targets_x = targets_x.detach()
|
88 |
+
|
89 |
+
# mixmatch
|
90 |
+
l = np.random.beta(args.alpha, args.alpha)
|
91 |
+
l = max(l, 1 - l)
|
92 |
+
|
93 |
+
all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
|
94 |
+
all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
|
95 |
+
|
96 |
+
idx = torch.randperm(all_inputs.size(0))
|
97 |
+
|
98 |
+
input_a, input_b = all_inputs, all_inputs[idx]
|
99 |
+
target_a, target_b = all_targets, all_targets[idx]
|
100 |
+
|
101 |
+
mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2]
|
102 |
+
mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2]
|
103 |
+
|
104 |
+
logits = net(mixed_input)
|
105 |
+
|
106 |
+
Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
|
107 |
+
|
108 |
+
# regularization
|
109 |
+
prior = torch.ones(args.num_class) / args.num_class
|
110 |
+
prior = prior.cuda()
|
111 |
+
pred_mean = torch.softmax(logits, dim=1).mean(0)
|
112 |
+
penalty = torch.sum(prior * torch.log(prior / pred_mean))
|
113 |
+
|
114 |
+
loss = Lx + penalty
|
115 |
+
|
116 |
+
# compute gradient and do SGD step
|
117 |
+
optimizer.zero_grad()
|
118 |
+
loss.backward()
|
119 |
+
optimizer.step()
|
120 |
+
|
121 |
+
sys.stdout.write('\r')
|
122 |
+
sys.stdout.write('Animal10N | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.4f '
|
123 |
+
% (epoch, args.num_epochs, batch_idx + 1, num_iter, Lx.item()))
|
124 |
+
sys.stdout.flush()
|
125 |
+
|
126 |
+
|
127 |
+
def warmup(net, optimizer, dataloader):
|
128 |
+
net.train()
|
129 |
+
num_batches = 50000/args.batch_size
|
130 |
+
for batch_idx, (inputs, labels, path) in enumerate(dataloader):
|
131 |
+
inputs, labels = inputs.cuda(), labels.cuda()
|
132 |
+
optimizer.zero_grad()
|
133 |
+
outputs = net(inputs)
|
134 |
+
loss = CEloss(outputs, labels)
|
135 |
+
|
136 |
+
penalty = conf_penalty(outputs)
|
137 |
+
L = loss + penalty
|
138 |
+
L.backward()
|
139 |
+
optimizer.step()
|
140 |
+
|
141 |
+
sys.stdout.write('\r')
|
142 |
+
sys.stdout.write('|Warm-up: Iter[%3d/%3d]\t CE-loss: %.4f Conf-Penalty: %.4f'
|
143 |
+
% (2*(batch_idx + 1), num_batches, loss.item(), penalty.item()))
|
144 |
+
sys.stdout.flush()
|
145 |
+
|
146 |
+
|
147 |
+
def val(net, val_loader, best_acc, w_glob=None):
|
148 |
+
net.eval()
|
149 |
+
correct = 0
|
150 |
+
total = 0
|
151 |
+
with torch.no_grad():
|
152 |
+
for batch_idx, (inputs, targets) in enumerate(val_loader):
|
153 |
+
inputs, targets = inputs.cuda(), targets.cuda()
|
154 |
+
outputs = net(inputs)
|
155 |
+
_, predicted = torch.max(outputs, 1)
|
156 |
+
|
157 |
+
total += targets.size(0)
|
158 |
+
correct += predicted.eq(targets).cpu().sum().item()
|
159 |
+
acc = 100. * correct / total
|
160 |
+
print("\n| Validation\t Net%d Acc: %.2f%%" % (k, acc))
|
161 |
+
if acc > best_acc[k - 1]:
|
162 |
+
best_acc[k - 1] = acc
|
163 |
+
print('| Saving Best Net%d ...' % k)
|
164 |
+
save_point = './checkpoint/%s_net%d.pth.tar' % (args.id, k)
|
165 |
+
torch.save(net.state_dict(), save_point)
|
166 |
+
return acc
|
167 |
+
|
168 |
+
|
169 |
+
def test(epoch, net1, net2, test_loader, best_acc, w_glob=None):
|
170 |
+
if w_glob is None:
|
171 |
+
net1.eval()
|
172 |
+
net2.eval()
|
173 |
+
correct = 0
|
174 |
+
correct2 = 0
|
175 |
+
correct1 = 0
|
176 |
+
total = 0
|
177 |
+
with torch.no_grad():
|
178 |
+
for batch_idx, (inputs, targets) in enumerate(test_loader):
|
179 |
+
inputs, targets = inputs.cuda(), targets.cuda()
|
180 |
+
outputs1 = net1(inputs)
|
181 |
+
outputs2 = net2(inputs)
|
182 |
+
outputs = outputs1 + outputs2
|
183 |
+
_, predicted = torch.max(outputs, 1)
|
184 |
+
_, predicted1 = torch.max(outputs1, 1)
|
185 |
+
_, predicted2 = torch.max(outputs2, 1)
|
186 |
+
|
187 |
+
total += targets.size(0)
|
188 |
+
correct += predicted.eq(targets).cpu().sum().item()
|
189 |
+
correct1 += predicted1.eq(targets).cpu().sum().item()
|
190 |
+
correct2 += predicted2.eq(targets).cpu().sum().item()
|
191 |
+
acc = 100. * correct / total
|
192 |
+
acc1 = 100. * correct / total
|
193 |
+
acc2 = 100. * correct / total
|
194 |
+
if best_acc < acc:
|
195 |
+
best_acc = acc
|
196 |
+
print(
|
197 |
+
"\n| Ensemble network Test Epoch #%d\t Accuracy: %.2f, Accuracy1: %.2f, Accuracy2: %.2f, best_acc: %.2f%%\n" % (
|
198 |
+
epoch, acc, acc1, acc2, best_acc))
|
199 |
+
log.write('ensemble_Epoch:%d Accuracy:%.2f, Accuracy1: %.2f, Accuracy2: %.2f, best_acc: %.2f\n' % (
|
200 |
+
epoch, acc, acc1, acc2, best_acc))
|
201 |
+
log.flush()
|
202 |
+
else:
|
203 |
+
net1_w_bak = net1.state_dict()
|
204 |
+
net1.load_state_dict(w_glob)
|
205 |
+
net1.eval()
|
206 |
+
correct = 0
|
207 |
+
total = 0
|
208 |
+
with torch.no_grad():
|
209 |
+
for batch_idx, (inputs, targets) in enumerate(test_loader):
|
210 |
+
inputs, targets = inputs.cuda(), targets.cuda()
|
211 |
+
outputs1 = net1(inputs)
|
212 |
+
_, predicted = torch.max(outputs1, 1)
|
213 |
+
total += targets.size(0)
|
214 |
+
correct += predicted.eq(targets).cpu().sum().item()
|
215 |
+
acc = 100. * correct / total
|
216 |
+
if best_acc < acc:
|
217 |
+
best_acc = acc
|
218 |
+
print("\n| Global network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
|
219 |
+
log.write('global_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
|
220 |
+
log.flush()
|
221 |
+
# 恢复权重
|
222 |
+
net1.load_state_dict(net1_w_bak)
|
223 |
+
return best_acc
|
224 |
+
|
225 |
+
|
226 |
+
def eval_train(epoch, model):
|
227 |
+
model.eval()
|
228 |
+
num_samples = eval_loader.dataset.__len__()
|
229 |
+
losses = torch.zeros(num_samples)
|
230 |
+
paths = []
|
231 |
+
n = 0
|
232 |
+
with torch.no_grad():
|
233 |
+
for batch_idx, (inputs, targets, path) in enumerate(eval_loader):
|
234 |
+
inputs, targets = inputs.cuda(), targets.cuda()
|
235 |
+
outputs = model(inputs)
|
236 |
+
loss = CE(outputs, targets)
|
237 |
+
for b in range(inputs.size(0)):
|
238 |
+
losses[n] = loss[b]
|
239 |
+
paths.append(path[b])
|
240 |
+
n += 1
|
241 |
+
sys.stdout.write('\r')
|
242 |
+
sys.stdout.write('| Evaluating loss Iter %3d\t' % (batch_idx))
|
243 |
+
sys.stdout.flush()
|
244 |
+
|
245 |
+
losses = (losses - losses.min()) / (losses.max() - losses.min())
|
246 |
+
losses = losses.reshape(-1, 1)
|
247 |
+
gmm = GaussianMixture(n_components=2, max_iter=10, reg_covar=5e-4, tol=1e-2)
|
248 |
+
gmm.fit(losses)
|
249 |
+
prob = gmm.predict_proba(losses)
|
250 |
+
prob = prob[:, gmm.means_.argmin()]
|
251 |
+
return prob, paths
|
252 |
+
|
253 |
+
|
254 |
+
class NegEntropy(object):
|
255 |
+
def __call__(self, outputs):
|
256 |
+
probs = torch.softmax(outputs, dim=1)
|
257 |
+
return torch.mean(torch.sum(probs.log() * probs, dim=1))
|
258 |
+
|
259 |
+
|
260 |
+
def create_model():
|
261 |
+
use_cnn = False
|
262 |
+
if use_cnn:
|
263 |
+
model = CNN()
|
264 |
+
model = model.cuda()
|
265 |
+
else:
|
266 |
+
model = models.vgg19_bn(pretrained=False)
|
267 |
+
model.classifier._modules['6'] = nn.Linear(4096, 10)
|
268 |
+
model = model.cuda()
|
269 |
+
return model
|
270 |
+
|
271 |
+
|
272 |
+
def FedAvg(w):
|
273 |
+
w_avg = copy.deepcopy(w[0])
|
274 |
+
for k in w_avg.keys():
|
275 |
+
for i in range(1, len(w)):
|
276 |
+
w_avg[k] += w[i][k]
|
277 |
+
# 只考虑iid noise的话,每个client训练样本数一样,所以不用做nk/n
|
278 |
+
w_avg[k] = torch.div(w_avg[k], len(w))
|
279 |
+
|
280 |
+
return w_avg
|
281 |
+
|
282 |
+
|
283 |
+
log = open('./checkpoint/%s.txt' % args.id, 'w')
|
284 |
+
log.flush()
|
285 |
+
|
286 |
+
loader = animal_dataloader.animal_dataloader(root=args.data_path, batch_size=args.batch_size, num_workers=0)
|
287 |
+
|
288 |
+
print('| Building net')
|
289 |
+
net1 = create_model()
|
290 |
+
net2 = create_model()
|
291 |
+
cudnn.benchmark = True
|
292 |
+
|
293 |
+
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
|
294 |
+
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
|
295 |
+
|
296 |
+
CE = nn.CrossEntropyLoss(reduction='none')
|
297 |
+
CEloss = nn.CrossEntropyLoss()
|
298 |
+
conf_penalty = NegEntropy()
|
299 |
+
|
300 |
+
local_round = 5
|
301 |
+
balance_crit = 'median' # 'median'
|
302 |
+
exp_path = './checkpoint/c2mt_animal10N'
|
303 |
+
|
304 |
+
boot_loader = None
|
305 |
+
w_glob = None
|
306 |
+
best_en_acc = 0.
|
307 |
+
best_gl_acc = 0.
|
308 |
+
resume_epoch = 0
|
309 |
+
warm_up = 10
|
310 |
+
if resume_epoch > 0:
|
311 |
+
snapLast = exp_path + str(resume_epoch - 1) + "_global_model.pth"
|
312 |
+
global_state = torch.load(snapLast)
|
313 |
+
# 先更新还是后跟新
|
314 |
+
w_glob = global_state
|
315 |
+
net1.load_state_dict(global_state)
|
316 |
+
net2.load_state_dict(global_state)
|
317 |
+
|
318 |
+
# if True:
|
319 |
+
# snapLast = exp_path + "0_1_model.pth"
|
320 |
+
# global_state = torch.load(snapLast)
|
321 |
+
# net1.load_state_dict(global_state)
|
322 |
+
# snapLast = exp_path + "0_2_model.pth"
|
323 |
+
# global_state = torch.load(snapLast)
|
324 |
+
# net2.load_state_dict(global_state)
|
325 |
+
# test_loader = loader.run('test')
|
326 |
+
# best_en_acc = test(0, net1, net2, test_loader, best_en_acc)
|
327 |
+
|
328 |
+
for epoch in range(resume_epoch, args.num_epochs + 1):
|
329 |
+
lr = args.lr
|
330 |
+
if 100 <= epoch < 150:
|
331 |
+
lr /= 10
|
332 |
+
elif epoch >= 150:
|
333 |
+
lr /= 10
|
334 |
+
# if 15 <= epoch:
|
335 |
+
# lr /= 2
|
336 |
+
for param_group in optimizer1.param_groups:
|
337 |
+
param_group['lr'] = lr
|
338 |
+
for param_group in optimizer2.param_groups:
|
339 |
+
param_group['lr'] = lr
|
340 |
+
|
341 |
+
local_weights = []
|
342 |
+
if epoch < warm_up: # warm up
|
343 |
+
train_loader = loader.run('warmup')
|
344 |
+
print('Warmup Net1')
|
345 |
+
warmup(net1, optimizer1, train_loader)
|
346 |
+
train_loader = loader.run('warmup')
|
347 |
+
print('\nWarmup Net2')
|
348 |
+
warmup(net2, optimizer2, train_loader)
|
349 |
+
if epoch == (warm_up - 1):
|
350 |
+
snapLast = exp_path + str(epoch) + "_1_model.pth"
|
351 |
+
torch.save(net1.state_dict(), snapLast)
|
352 |
+
snapLast = exp_path + str(epoch) + "_2_model.pth"
|
353 |
+
torch.save(net1.state_dict(), snapLast)
|
354 |
+
local_weights.append(net1.state_dict())
|
355 |
+
local_weights.append(net2.state_dict())
|
356 |
+
w_glob = FedAvg(local_weights)
|
357 |
+
else:
|
358 |
+
if epoch != warm_up:
|
359 |
+
net1.load_state_dict(w_glob)
|
360 |
+
net2.load_state_dict(w_glob)
|
361 |
+
|
362 |
+
for rou in range(local_round):
|
363 |
+
print('\n==== net 1 evaluate next epoch training data loss ====')
|
364 |
+
eval_loader = loader.run('eval_train') # evaluate training data loss for next epoch
|
365 |
+
prob1, paths1 = eval_train(epoch, net1)
|
366 |
+
print('\n==== net 2 evaluate next epoch training data loss ====')
|
367 |
+
eval_loader = loader.run('eval_train')
|
368 |
+
prob2, paths2 = eval_train(epoch, net2)
|
369 |
+
|
370 |
+
pred1 = (prob1 > args.p_threshold) # divide dataset
|
371 |
+
pred2 = (prob2 > args.p_threshold)
|
372 |
+
|
373 |
+
non_zero_idx = pred1.nonzero()[0].tolist()
|
374 |
+
aaa = len(non_zero_idx)
|
375 |
+
if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
|
376 |
+
num_clean_per_class = np.zeros(args.num_class)
|
377 |
+
ppp = np.array(paths1)[non_zero_idx].tolist()
|
378 |
+
target_label = np.array([eval_loader.dataset.train_labels[it] for it in ppp])
|
379 |
+
# target_label = np.array(eval_loader.dataset.train_labels[paths1])[non_zero_idx]
|
380 |
+
for i in range(args.num_class):
|
381 |
+
idx_class = np.where(target_label == i)[0]
|
382 |
+
num_clean_per_class[i] = len(idx_class)
|
383 |
+
|
384 |
+
if balance_crit == "max":
|
385 |
+
num_samples2select_class = np.max(num_clean_per_class)
|
386 |
+
elif balance_crit == "min":
|
387 |
+
num_samples2select_class = np.min(num_clean_per_class)
|
388 |
+
elif balance_crit == "median":
|
389 |
+
num_samples2select_class = np.median(num_clean_per_class)
|
390 |
+
|
391 |
+
for i in range(args.num_class):
|
392 |
+
idx_class = np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths1]) == i)[0]
|
393 |
+
cur_num = num_clean_per_class[i]
|
394 |
+
idx_class2 = non_zero_idx
|
395 |
+
if num_samples2select_class > cur_num:
|
396 |
+
remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
|
397 |
+
idx = list(range(len(remian_idx)))
|
398 |
+
random.shuffle(idx)
|
399 |
+
num_app = int(num_samples2select_class - cur_num)
|
400 |
+
idx = idx[:num_app]
|
401 |
+
for j in idx:
|
402 |
+
non_zero_idx.append(remian_idx[j])
|
403 |
+
non_zero_idx = np.array(non_zero_idx).reshape(-1, )
|
404 |
+
bbb = len(non_zero_idx)
|
405 |
+
num_per_class2 = []
|
406 |
+
for i in range(10):
|
407 |
+
temp = \
|
408 |
+
np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths1])[non_zero_idx.tolist()] == i)[
|
409 |
+
0]
|
410 |
+
num_per_class2.append(len(temp))
|
411 |
+
print('\npred1 appended num per class:', num_per_class2, aaa, bbb)
|
412 |
+
idx_per_class = np.zeros_like(pred1).astype(bool)
|
413 |
+
for i in non_zero_idx:
|
414 |
+
idx_per_class[i] = True
|
415 |
+
pred1 = idx_per_class
|
416 |
+
non_aaa = pred1.nonzero()[0].tolist()
|
417 |
+
assert len(non_aaa) == len(non_zero_idx)
|
418 |
+
|
419 |
+
non_zero_idx2 = pred2.nonzero()[0].tolist()
|
420 |
+
aaa = len(non_zero_idx2)
|
421 |
+
if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
|
422 |
+
num_clean_per_class = np.zeros(args.num_class)
|
423 |
+
ppp = np.array(paths2)[non_zero_idx].tolist()
|
424 |
+
target_label = np.array([eval_loader.dataset.train_labels[it] for it in ppp])
|
425 |
+
for i in range(args.num_class):
|
426 |
+
idx_class = np.where(target_label == i)[0]
|
427 |
+
num_clean_per_class[i] = len(idx_class)
|
428 |
+
|
429 |
+
if balance_crit == "max":
|
430 |
+
num_samples2select_class = np.max(num_clean_per_class)
|
431 |
+
elif balance_crit == "min":
|
432 |
+
num_samples2select_class = np.min(num_clean_per_class)
|
433 |
+
elif balance_crit == "median":
|
434 |
+
num_samples2select_class = np.median(num_clean_per_class)
|
435 |
+
|
436 |
+
for i in range(args.num_class):
|
437 |
+
idx_class = np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths1]) == i)[0]
|
438 |
+
cur_num = num_clean_per_class[i]
|
439 |
+
idx_class2 = non_zero_idx2
|
440 |
+
if num_samples2select_class > cur_num:
|
441 |
+
remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
|
442 |
+
idx = list(range(len(remian_idx)))
|
443 |
+
random.shuffle(idx)
|
444 |
+
num_app = int(num_samples2select_class - cur_num)
|
445 |
+
idx = idx[:num_app]
|
446 |
+
for j in idx:
|
447 |
+
non_zero_idx2.append(remian_idx[j])
|
448 |
+
non_zero_idx2 = np.array(non_zero_idx2).reshape(-1, )
|
449 |
+
bbb = len(non_zero_idx2)
|
450 |
+
num_per_class2 = []
|
451 |
+
for i in range(10):
|
452 |
+
temp = np.where(
|
453 |
+
np.array([eval_loader.dataset.train_labels[it] for it in paths1])[non_zero_idx2.tolist()] == i)[0]
|
454 |
+
num_per_class2.append(len(temp))
|
455 |
+
print('\npred2 appended num per class:', num_per_class2, aaa, bbb)
|
456 |
+
idx_per_class2 = np.zeros_like(pred2).astype(bool)
|
457 |
+
for i in non_zero_idx2:
|
458 |
+
idx_per_class2[i] = True
|
459 |
+
pred2 = idx_per_class2
|
460 |
+
non_aaa = pred2.nonzero()[0].tolist()
|
461 |
+
assert len(non_aaa) == len(non_zero_idx2)
|
462 |
+
|
463 |
+
print(f'round={rou}/{local_round}, dmix selection, Train Net1')
|
464 |
+
labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, prob2, paths=paths2) # co-divide
|
465 |
+
train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader) # train net1
|
466 |
+
|
467 |
+
print(f'\nround={rou}/{local_round}, dmix selection, Train Net2')
|
468 |
+
labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, prob1, paths=paths1) # co-divide
|
469 |
+
train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader) # train net2
|
470 |
+
|
471 |
+
test_loader = loader.run('test')
|
472 |
+
if rou != local_round-1:
|
473 |
+
best_en_acc = test(epoch, net1, net2, test_loader, best_en_acc)
|
474 |
+
# best_gl_acc = test(epoch, net1, net2, test_loader, best_gl_acc, w_glob=w_glob)
|
475 |
+
|
476 |
+
print(f'c2m, get global network\n')
|
477 |
+
local_weights.append(net1.state_dict())
|
478 |
+
local_weights.append(net2.state_dict())
|
479 |
+
w_glob = FedAvg(local_weights)
|
480 |
+
if epoch % 1 == 0:
|
481 |
+
snapLast = exp_path + str(epoch) + "_global_model.pth"
|
482 |
+
torch.save(w_glob, snapLast)
|
483 |
+
|
484 |
+
test_loader = loader.run('test')
|
485 |
+
best_en_acc = test(epoch, net1, net2, test_loader, best_en_acc)
|
486 |
+
best_gl_acc = test(epoch, net1, net2, test_loader, best_gl_acc, w_glob=w_glob)
|
487 |
+
|
dataloader_animal10N.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset, DataLoader
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import json
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
import matplotlib
|
10 |
+
|
11 |
+
def unpickle(file):
|
12 |
+
fo = open(file, 'rb').read()
|
13 |
+
size = 64 * 64 * 3 + 1
|
14 |
+
for i in range(50000):
|
15 |
+
arr = np.fromstring(fo[i * size:(i + 1) * size], dtype=np.uint8)
|
16 |
+
lab = np.identity(10)[arr[0]]
|
17 |
+
img = arr[1:].reshape((3, 64, 64)).transpose((1, 2, 0))
|
18 |
+
return img, lab
|
19 |
+
|
20 |
+
class animal_dataset(Dataset):
|
21 |
+
def __init__(self, root, transform, mode, pred=[], path=[], probability=[], num_class=10):
|
22 |
+
|
23 |
+
self.root = root
|
24 |
+
self.transform = transform
|
25 |
+
self.mode = mode
|
26 |
+
|
27 |
+
self.train_dir = root + '/training/'
|
28 |
+
self.test_dir = root + '/testing/'
|
29 |
+
train_imgs = os.listdir(self.train_dir)
|
30 |
+
test_imgs = os.listdir(self.test_dir)
|
31 |
+
self.test_data = []
|
32 |
+
self.test_labels = []
|
33 |
+
noise_file1 = './training_batch.json'
|
34 |
+
noise_file2 = './testing_batch.json'
|
35 |
+
if mode == 'test':
|
36 |
+
if os.path.exists(noise_file2):
|
37 |
+
dict = json.load(open(noise_file2, "r"))
|
38 |
+
self.test_labels = dict['data']
|
39 |
+
self.test_data = dict['label']
|
40 |
+
else:
|
41 |
+
for img in test_imgs:
|
42 |
+
self.test_data.append(self.test_dir+img)
|
43 |
+
self.test_labels.append(int(img[0]))
|
44 |
+
dicts = {}
|
45 |
+
dicts['data'] = self.test_data
|
46 |
+
dicts['label'] = self.test_labels
|
47 |
+
# json.dump(dicts, open(noise_file2, "w"))
|
48 |
+
else:
|
49 |
+
if os.path.exists(noise_file1):
|
50 |
+
dict = json.load(open(noise_file1, "r"))
|
51 |
+
train_data = dict['data']
|
52 |
+
train_labels = dict['label']
|
53 |
+
else:
|
54 |
+
train_data = []
|
55 |
+
train_labels = {}
|
56 |
+
for img in train_imgs:
|
57 |
+
img_path = self.train_dir+img
|
58 |
+
train_data.append(img_path)
|
59 |
+
train_labels[img_path] = (int(img[0]))
|
60 |
+
dicts = {}
|
61 |
+
dicts['data'] = train_data
|
62 |
+
dicts['label'] = train_labels
|
63 |
+
# json.dump(dicts, open(noise_file1, "w"))
|
64 |
+
if self.mode == "all":
|
65 |
+
self.train_data = train_data
|
66 |
+
self.train_labels = train_labels
|
67 |
+
elif self.mode == "labeled":
|
68 |
+
pred_idx = pred.nonzero()[0]
|
69 |
+
train_img = path
|
70 |
+
self.train_data = [train_img[i] for i in pred_idx]
|
71 |
+
self.probability = probability[pred_idx]
|
72 |
+
# self.train_labels = train_labels[pred_idx]
|
73 |
+
self.train_labels = train_labels
|
74 |
+
print("%s data has a size of %d" % (self.mode, len(self.train_data)))
|
75 |
+
elif self.mode == "unlabeled":
|
76 |
+
pred_idx = (1 - pred).nonzero()[0]
|
77 |
+
train_img = path
|
78 |
+
self.train_data = [train_img[i] for i in pred_idx]
|
79 |
+
self.probability = probability[pred_idx]
|
80 |
+
# self.train_labels = train_labels[pred_idx]
|
81 |
+
print("%s data has a size of %d" % (self.mode, len(self.train_data)))
|
82 |
+
self.train_labels = train_labels
|
83 |
+
|
84 |
+
def __getitem__(self, index):
|
85 |
+
if self.mode == 'labeled':
|
86 |
+
img_path = self.train_data[index]
|
87 |
+
target = self.train_labels[img_path]
|
88 |
+
prob = self.probability[index]
|
89 |
+
image = Image.open(img_path).convert('RGB')
|
90 |
+
img1 = self.transform(image)
|
91 |
+
img2 = self.transform(image)
|
92 |
+
return img1, img2, target, prob
|
93 |
+
elif self.mode == 'unlabeled':
|
94 |
+
img_path = self.train_data[index]
|
95 |
+
image = Image.open(img_path).convert('RGB')
|
96 |
+
img1 = self.transform(image)
|
97 |
+
img2 = self.transform(image)
|
98 |
+
return img1, img2
|
99 |
+
elif self.mode == 'all':
|
100 |
+
img_path = self.train_data[index]
|
101 |
+
target = self.train_labels[img_path]
|
102 |
+
image = Image.open(img_path).convert('RGB')
|
103 |
+
img = self.transform(image)
|
104 |
+
return img, target,img_path
|
105 |
+
elif self.mode == 'test':
|
106 |
+
img_path = self.test_data[index]
|
107 |
+
target = self.test_labels[index]
|
108 |
+
image = Image.open(img_path).convert('RGB')
|
109 |
+
img = self.transform(image)
|
110 |
+
return img, target
|
111 |
+
|
112 |
+
def __len__(self):
|
113 |
+
if self.mode == 'test':
|
114 |
+
return len(self.test_data)
|
115 |
+
else:
|
116 |
+
return len(self.train_data)
|
117 |
+
|
118 |
+
|
119 |
+
class animal_dataloader():
|
120 |
+
def __init__(self, root='E:/2_Dataset_All/Animal-10N', batch_size=32, num_workers=0):
|
121 |
+
self.batch_size = batch_size
|
122 |
+
self.num_workers = num_workers
|
123 |
+
self.root = root
|
124 |
+
|
125 |
+
self.transform_train = transforms.Compose([
|
126 |
+
transforms.Resize(64),
|
127 |
+
transforms.RandomCrop(64),
|
128 |
+
transforms.RandomHorizontalFlip(),
|
129 |
+
transforms.ToTensor(),
|
130 |
+
transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
|
131 |
+
])
|
132 |
+
self.transform_test = transforms.Compose([
|
133 |
+
# transforms.Resize(64),
|
134 |
+
# transforms.CenterCrop(64),
|
135 |
+
transforms.ToTensor(),
|
136 |
+
transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
|
137 |
+
])
|
138 |
+
|
139 |
+
def run(self, mode, pred=[], prob=[], paths=[]):
|
140 |
+
if mode == 'warmup':
|
141 |
+
warmup_dataset = animal_dataset(self.root, transform=self.transform_train, mode='all')
|
142 |
+
warmup_loader = DataLoader(
|
143 |
+
dataset=warmup_dataset,
|
144 |
+
batch_size=self.batch_size * 2,
|
145 |
+
shuffle=True,
|
146 |
+
num_workers=self.num_workers,
|
147 |
+
pin_memory=True)
|
148 |
+
return warmup_loader
|
149 |
+
elif mode == 'train':
|
150 |
+
labeled_dataset = animal_dataset(self.root, transform=self.transform_train, mode='labeled', pred=pred, path=paths,
|
151 |
+
probability=prob)
|
152 |
+
labeled_loader = DataLoader(
|
153 |
+
dataset=labeled_dataset,
|
154 |
+
batch_size=self.batch_size,
|
155 |
+
shuffle=True,
|
156 |
+
num_workers=self.num_workers,
|
157 |
+
pin_memory=True)
|
158 |
+
unlabeled_dataset = animal_dataset(self.root, transform=self.transform_train, mode='unlabeled', pred=pred,path=paths,
|
159 |
+
probability=prob)
|
160 |
+
unlabeled_loader = DataLoader(
|
161 |
+
dataset=unlabeled_dataset,
|
162 |
+
batch_size=int(self.batch_size),
|
163 |
+
shuffle=True,
|
164 |
+
num_workers=self.num_workers,
|
165 |
+
pin_memory=True)
|
166 |
+
return labeled_loader, unlabeled_loader
|
167 |
+
elif mode == 'eval_train':
|
168 |
+
eval_dataset = animal_dataset(self.root, transform=self.transform_test, mode='all')
|
169 |
+
eval_loader = DataLoader(
|
170 |
+
dataset=eval_dataset,
|
171 |
+
batch_size=self.batch_size,
|
172 |
+
shuffle=False,
|
173 |
+
num_workers=self.num_workers,
|
174 |
+
pin_memory=True)
|
175 |
+
return eval_loader
|
176 |
+
elif mode == 'test':
|
177 |
+
test_dataset = animal_dataset(self.root, transform=self.transform_test, mode='test')
|
178 |
+
test_loader = DataLoader(
|
179 |
+
dataset=test_dataset,
|
180 |
+
batch_size=1000,
|
181 |
+
shuffle=False,
|
182 |
+
num_workers=self.num_workers,
|
183 |
+
pin_memory=True)
|
184 |
+
return test_loader
|
185 |
+
|
186 |
+
# if __name__ == '__main__':
|
187 |
+
# loader = animal_dataloader()
|
188 |
+
# train_loader = loader.run('warmup')
|
189 |
+
# import matplotlib.pyplot as plt
|
190 |
+
# for batch_idx, (inputs, labels, idx, img_path) in enumerate(train_loader):
|
191 |
+
# print(img_path[0])
|
192 |
+
# plt.figure(dpi=300)
|
193 |
+
# # plt.imshow(inputs[0])
|
194 |
+
# plt.imshow(inputs[0].reshape(64, 64, 3))
|
195 |
+
# plt.show()
|
196 |
+
# plt.close()
|
197 |
+
# print(inputs.shape())
|
198 |
+
# print(idx)
|
199 |
+
# print(labels, len(labels))
|
200 |
+
# # print(train_loader.dataset.__len__())
|
dataloader_cifar.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset, DataLoader
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
from torchnet.meter import AUCMeter
|
10 |
+
|
11 |
+
|
12 |
+
def unpickle(file):
|
13 |
+
import _pickle as cPickle
|
14 |
+
with open(file, 'rb') as fo:
|
15 |
+
dict = cPickle.load(fo, encoding='latin1')
|
16 |
+
return dict
|
17 |
+
|
18 |
+
class cifar_dataset(Dataset):
|
19 |
+
def __init__(self, dataset, r, noise_mode, root_dir, transform, mode, noise_file='', pred=[], probability=[], log='', clean_idx=[], test_form = None):
|
20 |
+
|
21 |
+
self.r = r # noise ratio
|
22 |
+
self.transform = transform
|
23 |
+
self.test_form = test_form
|
24 |
+
self.mode = mode
|
25 |
+
self.transition = {0:0,2:0,4:7,7:7,1:1,9:1,3:5,5:3,6:6,8:8} # class transition for asymmetric noise
|
26 |
+
self.noise_file = noise_file
|
27 |
+
|
28 |
+
if self.mode=='test':
|
29 |
+
if dataset=='cifar10':
|
30 |
+
test_dic = unpickle('%s/test_batch'%root_dir)
|
31 |
+
self.test_data = test_dic['data']
|
32 |
+
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
|
33 |
+
self.test_data = self.test_data.transpose((0, 2, 3, 1))
|
34 |
+
self.test_label = test_dic['labels']
|
35 |
+
elif dataset=='cifar100':
|
36 |
+
test_dic = unpickle('%s/test'%root_dir)
|
37 |
+
self.test_data = test_dic['data']
|
38 |
+
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
|
39 |
+
self.test_data = self.test_data.transpose((0, 2, 3, 1))
|
40 |
+
self.test_label = test_dic['fine_labels']
|
41 |
+
else:
|
42 |
+
train_data=[]
|
43 |
+
train_label=[]
|
44 |
+
if dataset=='cifar10':
|
45 |
+
for n in range(1,6):
|
46 |
+
dpath = '%s/data_batch_%d'%(root_dir,n)
|
47 |
+
data_dic = unpickle(dpath)
|
48 |
+
train_data.append(data_dic['data'])
|
49 |
+
train_label = train_label+data_dic['labels']
|
50 |
+
train_data = np.concatenate(train_data)
|
51 |
+
elif dataset=='cifar100':
|
52 |
+
train_dic = unpickle('%s/train'%root_dir)
|
53 |
+
train_data = train_dic['data']
|
54 |
+
train_label = train_dic['fine_labels']
|
55 |
+
train_data = train_data.reshape((50000, 3, 32, 32))
|
56 |
+
train_data = train_data.transpose((0, 2, 3, 1))
|
57 |
+
|
58 |
+
self.clean_label = np.array(train_label)
|
59 |
+
|
60 |
+
if os.path.exists(noise_file):
|
61 |
+
noise_label = json.load(open(noise_file,"r"))
|
62 |
+
else: #inject noise
|
63 |
+
noise_label = []
|
64 |
+
idx = list(range(50000))
|
65 |
+
random.shuffle(idx)
|
66 |
+
num_noise = int(self.r*50000)
|
67 |
+
noise_idx = idx[:num_noise]
|
68 |
+
for i in range(50000):
|
69 |
+
if i in noise_idx:
|
70 |
+
if noise_mode=='sym':
|
71 |
+
if dataset=='cifar10':
|
72 |
+
noiselabel = random.randint(0,9)
|
73 |
+
elif dataset=='cifar100':
|
74 |
+
noiselabel = random.randint(0,99)
|
75 |
+
noise_label.append(noiselabel)
|
76 |
+
elif noise_mode=='asym':
|
77 |
+
noiselabel = self.transition[train_label[i]]
|
78 |
+
noise_label.append(noiselabel)
|
79 |
+
else:
|
80 |
+
noise_label.append(train_label[i])
|
81 |
+
print("save noisy labels to %s ..."%noise_file)
|
82 |
+
json.dump(noise_label,open(noise_file,"w"))
|
83 |
+
|
84 |
+
if self.mode == 'all':
|
85 |
+
self.train_data = train_data
|
86 |
+
self.noise_label = np.array(noise_label).astype(np.int64)
|
87 |
+
else:
|
88 |
+
if self.mode == "labeled":
|
89 |
+
pred_idx = pred.nonzero()[0]
|
90 |
+
self.probability = [probability[i] for i in pred_idx]
|
91 |
+
|
92 |
+
clean = (np.array(noise_label)==np.array(train_label))
|
93 |
+
auc_meter = AUCMeter()
|
94 |
+
auc_meter.reset()
|
95 |
+
auc_meter.add(probability,clean)
|
96 |
+
auc,_,_ = auc_meter.value()
|
97 |
+
clean_index = np.where(np.array(noise_label)[pred_idx.tolist()] == np.array(self.clean_label)[pred_idx.tolist()])[0]
|
98 |
+
|
99 |
+
num_per_class = []
|
100 |
+
for i in range(max(noise_label)):
|
101 |
+
temp = np.where(np.array(noise_label)[clean_index.tolist()] == i)[0]
|
102 |
+
num_per_class.append(len(temp))
|
103 |
+
num_per_class2 = []
|
104 |
+
for i in range(max(noise_label)):
|
105 |
+
temp = np.where(np.array(noise_label)[pred_idx.tolist()] == i)[0]
|
106 |
+
num_per_class2.append(len(temp))
|
107 |
+
print('clean num per class:', num_per_class, num_per_class2)
|
108 |
+
|
109 |
+
log.write('Numer of labeled samples:%d AUC:%.3f corrected clean num:%d, uncorrected noisy num:%d\n'
|
110 |
+
% (pred.sum(), auc, len(clean_index), len(pred_idx) - len(clean_index)))
|
111 |
+
log.flush()
|
112 |
+
|
113 |
+
elif self.mode == "unlabeled":
|
114 |
+
pred_idx = (1-pred).nonzero()[0]
|
115 |
+
noise_index = np.where(np.array(noise_label)[pred_idx.tolist()] != np.array(self.clean_label)[pred_idx.tolist()])[0]
|
116 |
+
log.write('Numer of unlabeled samples:%d corrected noisy num:%d, uncorrected clean num:%d\n'
|
117 |
+
% (pred.sum(), len(noise_index), len(pred_idx) - len(noise_index)))
|
118 |
+
log.flush()
|
119 |
+
elif self.mode == 'boost':
|
120 |
+
pred_idx = clean_idx
|
121 |
+
|
122 |
+
self.train_data = train_data[pred_idx]
|
123 |
+
self.noise_label = [noise_label[i] for i in pred_idx]
|
124 |
+
print("%s data has a size of %d"%(self.mode,len(self.noise_label)))
|
125 |
+
|
126 |
+
def if_noise(self, pred=None):
|
127 |
+
if pred is None:
|
128 |
+
noise_index = np.where(self.noise_label[:] != self.clean_label[:])[0]
|
129 |
+
clean_index = np.where(self.noise_label[:] == self.clean_label[:])[0]
|
130 |
+
return noise_index, clean_index
|
131 |
+
else:
|
132 |
+
pred_idx1 = pred.nonzero()[0].tolist()
|
133 |
+
clean_index = np.where(np.array(self.noise_label)[pred_idx1] == np.array(self.clean_label)[pred_idx1])[0]
|
134 |
+
pred_idx = (1 - pred).nonzero()[0].tolist()
|
135 |
+
noise_index = np.where(np.array(self.noise_label)[pred_idx] != np.array(self.clean_label)[pred_idx])[0]
|
136 |
+
print(
|
137 |
+
f'选择的非mask样本中正确选取的干净标签数量{len(clean_index)}, 不正确选取的非干净数量{len(pred_idx1) - len(clean_index)}.\t '
|
138 |
+
f'选择的mask样本中正确选取的不干净标签数量{len(noise_index)}, 不正确选取的干净数量{len(pred_idx) - len(noise_index)}')
|
139 |
+
return len(clean_index), (len(pred_idx1) - len(clean_index)), len(noise_index), len(pred_idx) - len(
|
140 |
+
noise_index)
|
141 |
+
def print_noise_rate(self, new_y):
|
142 |
+
temp_y = np.array(new_y.reshape(1, -1).squeeze())
|
143 |
+
clean_index = np.where(temp_y[:] == np.array(self.clean_label)[:])
|
144 |
+
print(f'clean rate is: {len(clean_index[0]) / len(self.clean_label)}')
|
145 |
+
|
146 |
+
def load_train_label(self, new_y):
|
147 |
+
temp_y = np.array(new_y.reshape(1, -1).squeeze()).astype(np.int64)
|
148 |
+
self.noise_label[:] = np.array(temp_y)[:]
|
149 |
+
if os.path.exists(self.noise_file):
|
150 |
+
result = os.path.splitext(self.noise_file)
|
151 |
+
noise_file_temp = result[0]+'_old'+result[1]
|
152 |
+
if not os.path.exists(noise_file_temp):
|
153 |
+
os.rename(self.noise_file, noise_file_temp)
|
154 |
+
# 覆盖原来的noise_file
|
155 |
+
json.dump(self.noise_label.tolist(), open(self.noise_file, "w"))
|
156 |
+
|
157 |
+
def __getitem__(self, index):
|
158 |
+
if self.mode=='labeled':
|
159 |
+
img, target, prob = self.train_data[index], self.noise_label[index], self.probability[index]
|
160 |
+
img = Image.fromarray(img)
|
161 |
+
img1 = self.transform(img)
|
162 |
+
img2 = self.transform(img)
|
163 |
+
return img1, img2, target, prob
|
164 |
+
elif self.mode=='unlabeled':
|
165 |
+
img = self.train_data[index]
|
166 |
+
img = Image.fromarray(img)
|
167 |
+
img1 = self.transform(img)
|
168 |
+
img2 = self.transform(img)
|
169 |
+
return img1, img2
|
170 |
+
elif self.mode=='all':
|
171 |
+
img, target = self.train_data[index], self.noise_label[index]
|
172 |
+
img = Image.fromarray(img)
|
173 |
+
img = self.transform(img)
|
174 |
+
return img, target, index
|
175 |
+
elif self.mode=='test':
|
176 |
+
img, target = self.test_data[index], self.test_label[index]
|
177 |
+
img = Image.fromarray(img)
|
178 |
+
img = self.transform(img)
|
179 |
+
return img, target
|
180 |
+
elif self.mode=='boost':
|
181 |
+
img, target = self.train_data[index], self.noise_label[index]
|
182 |
+
img = Image.fromarray(img)
|
183 |
+
img_no_da = self.test_form(img)
|
184 |
+
img = self.transform(img)
|
185 |
+
return img, img_no_da, target, index
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
if self.mode!='test':
|
189 |
+
return len(self.train_data)
|
190 |
+
else:
|
191 |
+
return len(self.test_data)
|
192 |
+
|
193 |
+
|
194 |
+
class cifar_dataloader():
|
195 |
+
def __init__(self, dataset, r, noise_mode, batch_size, num_workers, root_dir, log, noise_file=''):
|
196 |
+
self.dataset = dataset
|
197 |
+
self.r = r
|
198 |
+
self.noise_mode = noise_mode
|
199 |
+
self.batch_size = batch_size
|
200 |
+
self.num_workers = num_workers
|
201 |
+
self.root_dir = root_dir
|
202 |
+
self.log = log
|
203 |
+
self.noise_file = noise_file
|
204 |
+
if self.dataset=='cifar10':
|
205 |
+
self.transform_train = transforms.Compose([
|
206 |
+
transforms.RandomCrop(32, padding=4),
|
207 |
+
transforms.RandomHorizontalFlip(),
|
208 |
+
transforms.ToTensor(),
|
209 |
+
transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
|
210 |
+
])
|
211 |
+
self.transform_test = transforms.Compose([
|
212 |
+
transforms.ToTensor(),
|
213 |
+
transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
|
214 |
+
])
|
215 |
+
elif self.dataset=='cifar100':
|
216 |
+
self.transform_train = transforms.Compose([
|
217 |
+
transforms.RandomCrop(32, padding=4),
|
218 |
+
transforms.RandomHorizontalFlip(),
|
219 |
+
transforms.ToTensor(),
|
220 |
+
transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
|
221 |
+
])
|
222 |
+
self.transform_test = transforms.Compose([
|
223 |
+
transforms.ToTensor(),
|
224 |
+
transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
|
225 |
+
])
|
226 |
+
def run(self,mode,pred=[],prob=[], clean_idx=[]):
|
227 |
+
if mode=='warmup':
|
228 |
+
all_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="all",noise_file=self.noise_file)
|
229 |
+
trainloader = DataLoader(
|
230 |
+
dataset=all_dataset,
|
231 |
+
batch_size=self.batch_size*2,
|
232 |
+
shuffle=True,
|
233 |
+
num_workers=self.num_workers)
|
234 |
+
return trainloader
|
235 |
+
|
236 |
+
elif mode=='train':
|
237 |
+
labeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="labeled", noise_file=self.noise_file, pred=pred, probability=prob,log=self.log)
|
238 |
+
labeled_trainloader = DataLoader(
|
239 |
+
dataset=labeled_dataset,
|
240 |
+
batch_size=self.batch_size,
|
241 |
+
shuffle=True,
|
242 |
+
num_workers=self.num_workers)
|
243 |
+
|
244 |
+
unlabeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="unlabeled", noise_file=self.noise_file, pred=pred, log=self.log)
|
245 |
+
unlabeled_trainloader = DataLoader(
|
246 |
+
dataset=unlabeled_dataset,
|
247 |
+
batch_size=self.batch_size,
|
248 |
+
shuffle=True,
|
249 |
+
num_workers=self.num_workers)
|
250 |
+
return labeled_trainloader, unlabeled_trainloader
|
251 |
+
|
252 |
+
elif mode=='test':
|
253 |
+
test_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='test')
|
254 |
+
test_loader = DataLoader(
|
255 |
+
dataset=test_dataset,
|
256 |
+
batch_size=self.batch_size,
|
257 |
+
shuffle=False,
|
258 |
+
num_workers=self.num_workers)
|
259 |
+
return test_loader
|
260 |
+
|
261 |
+
elif mode=='eval_train':
|
262 |
+
eval_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='all', noise_file=self.noise_file)
|
263 |
+
eval_loader = DataLoader(
|
264 |
+
dataset=eval_dataset,
|
265 |
+
batch_size=self.batch_size,
|
266 |
+
shuffle=False,
|
267 |
+
num_workers=self.num_workers)
|
268 |
+
return eval_loader
|
269 |
+
elif mode=='boost':
|
270 |
+
eval_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode=mode, noise_file=self.noise_file, clean_idx=clean_idx, test_form=self.transform_test)
|
271 |
+
eval_loader = DataLoader(
|
272 |
+
dataset=eval_dataset,
|
273 |
+
batch_size=self.batch_size,
|
274 |
+
shuffle=False,
|
275 |
+
num_workers=self.num_workers)
|
276 |
+
return eval_loader
|
img/framework.tif
ADDED
|
models/CNN.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.init as init
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.utils.model_zoo as model_zoo
|
7 |
+
|
8 |
+
class HiddenLayer(nn.Module):
|
9 |
+
def __init__(self, input_size, output_size):
|
10 |
+
super(HiddenLayer, self).__init__()
|
11 |
+
self.fc = nn.Linear(input_size, output_size)
|
12 |
+
self.relu = nn.ReLU()
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return self.relu(self.fc(x))
|
16 |
+
|
17 |
+
|
18 |
+
class VNet(nn.Module):
|
19 |
+
def __init__(self, hidden_size=100, num_layers=1):
|
20 |
+
super(VNet, self).__init__()
|
21 |
+
self.first_hidden_layer = HiddenLayer(1, hidden_size)
|
22 |
+
self.rest_hidden_layers = nn.Sequential(*[HiddenLayer(hidden_size, hidden_size) for _ in range(num_layers - 1)])
|
23 |
+
self.output_layer = nn.Linear(hidden_size, 1)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.first_hidden_layer(x)
|
27 |
+
x = self.rest_hidden_layers(x)
|
28 |
+
x = self.output_layer(x)
|
29 |
+
return torch.sigmoid(x)
|
30 |
+
|
31 |
+
|
32 |
+
class CNN(nn.Module):
|
33 |
+
def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25):
|
34 |
+
self.dropout_rate = dropout_rate
|
35 |
+
super(CNN, self).__init__()
|
36 |
+
|
37 |
+
#block1
|
38 |
+
self.conv1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)
|
39 |
+
self.bn1=nn.BatchNorm2d(128)
|
40 |
+
self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
41 |
+
self.bn2=nn.BatchNorm2d(128)
|
42 |
+
self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
43 |
+
self.bn3=nn.BatchNorm2d(128)
|
44 |
+
|
45 |
+
#block2
|
46 |
+
self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
47 |
+
self.bn4=nn.BatchNorm2d(256)
|
48 |
+
self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
49 |
+
self.bn5=nn.BatchNorm2d(256)
|
50 |
+
self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
51 |
+
self.bn6=nn.BatchNorm2d(256)
|
52 |
+
|
53 |
+
#block3
|
54 |
+
self.conv7 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0)
|
55 |
+
self.bn7=nn.BatchNorm2d(512)
|
56 |
+
self.conv8 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
|
57 |
+
self.bn8=nn.BatchNorm2d(256)
|
58 |
+
self.conv9 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)
|
59 |
+
self.bn9=nn.BatchNorm2d(128)
|
60 |
+
|
61 |
+
self.pool = nn.MaxPool2d(2, 2)
|
62 |
+
self.avgpool = nn.AvgPool2d(kernel_size=2)
|
63 |
+
|
64 |
+
self.fc=nn.Linear(128,n_outputs)
|
65 |
+
|
66 |
+
for m in self.modules():
|
67 |
+
if isinstance(m, nn.Conv2d):
|
68 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
69 |
+
elif isinstance(m, nn.BatchNorm2d):
|
70 |
+
nn.init.constant_(m.weight, 1)
|
71 |
+
nn.init.constant_(m.bias, 0)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
|
75 |
+
#block1
|
76 |
+
x=F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01)
|
77 |
+
x=F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.01)
|
78 |
+
x=F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.01)
|
79 |
+
x=self.pool(x)
|
80 |
+
x=F.dropout2d(x, p=self.dropout_rate)
|
81 |
+
|
82 |
+
#block2
|
83 |
+
x=F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.01)
|
84 |
+
x=F.leaky_relu(self.bn5(self.conv5(x)), negative_slope=0.01)
|
85 |
+
x=F.leaky_relu(self.bn6(self.conv6(x)), negative_slope=0.01)
|
86 |
+
x=self.pool(x)
|
87 |
+
x=F.dropout2d(x, p=self.dropout_rate)
|
88 |
+
|
89 |
+
#block3
|
90 |
+
x=F.leaky_relu(self.bn7(self.conv7(x)), negative_slope=0.01)
|
91 |
+
x=F.leaky_relu(self.bn8(self.conv8(x)), negative_slope=0.01)
|
92 |
+
x=F.leaky_relu(self.bn9(self.conv9(x)), negative_slope=0.01)
|
93 |
+
x=self.avgpool(x)
|
94 |
+
|
95 |
+
x = x.view(x.size(0), x.size(1))
|
96 |
+
x=self.fc(x)
|
97 |
+
return x
|
98 |
+
|
99 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
100 |
+
"""3x3 convolution with padding"""
|
101 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
102 |
+
padding=1, bias=False)
|
103 |
+
|
104 |
+
class BasicBlock(nn.Module):
|
105 |
+
expansion = 1
|
106 |
+
|
107 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
108 |
+
super(BasicBlock, self).__init__()
|
109 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
110 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
111 |
+
self.relu = nn.ReLU(inplace=True)
|
112 |
+
self.conv2 = conv3x3(planes, planes)
|
113 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
114 |
+
self.downsample = downsample
|
115 |
+
self.stride = stride
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
residual = x
|
119 |
+
|
120 |
+
out = self.conv1(x)
|
121 |
+
out = self.bn1(out)
|
122 |
+
out = self.relu(out)
|
123 |
+
|
124 |
+
out = self.conv2(out)
|
125 |
+
out = self.bn2(out)
|
126 |
+
|
127 |
+
if self.downsample is not None:
|
128 |
+
residual = self.downsample(x)
|
129 |
+
|
130 |
+
out += residual
|
131 |
+
out = self.relu(out)
|
132 |
+
|
133 |
+
return out
|
134 |
+
|
135 |
+
class ResNet(nn.Module):
|
136 |
+
|
137 |
+
def __init__(self, block, layers, num_classes=14):
|
138 |
+
self.inplanes = 64
|
139 |
+
super(ResNet, self).__init__()
|
140 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
141 |
+
self.bn1 = nn.BatchNorm2d(64)
|
142 |
+
self.relu = nn.ReLU(inplace=True)
|
143 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
144 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
145 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
146 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
147 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
148 |
+
self.avgpool = nn.AvgPool2d(7, stride=1)
|
149 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
150 |
+
|
151 |
+
for m in self.modules():
|
152 |
+
if isinstance(m, nn.Conv2d):
|
153 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
154 |
+
elif isinstance(m, nn.BatchNorm2d):
|
155 |
+
nn.init.constant_(m.weight, 1)
|
156 |
+
nn.init.constant_(m.bias, 0)
|
157 |
+
|
158 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
159 |
+
downsample = None
|
160 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
161 |
+
downsample = nn.Sequential(
|
162 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
163 |
+
kernel_size=1, stride=stride, bias=False),
|
164 |
+
nn.BatchNorm2d(planes * block.expansion),
|
165 |
+
)
|
166 |
+
|
167 |
+
layers = []
|
168 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
169 |
+
self.inplanes = planes * block.expansion
|
170 |
+
for i in range(1, blocks):
|
171 |
+
layers.append(block(self.inplanes, planes))
|
172 |
+
|
173 |
+
return nn.Sequential(*layers)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
x = self.conv1(x)
|
177 |
+
x = self.bn1(x)
|
178 |
+
x = self.relu(x)
|
179 |
+
x = self.maxpool(x)
|
180 |
+
|
181 |
+
x = self.layer1(x)
|
182 |
+
x = self.layer2(x)
|
183 |
+
x = self.layer3(x)
|
184 |
+
x = self.layer4(x)
|
185 |
+
|
186 |
+
x = self.avgpool(x)
|
187 |
+
x = x.view(x.size(0), -1)
|
188 |
+
x = self.fc(x)
|
189 |
+
return x
|
190 |
+
|
191 |
+
def resnet18(pretrained=False, **kwargs):
|
192 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
193 |
+
return model
|
models/InceptionResNetV2.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function, division, absolute_import
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
|
8 |
+
class BasicConv2d(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
|
11 |
+
super(BasicConv2d, self).__init__()
|
12 |
+
self.conv = nn.Conv2d(in_planes, out_planes,
|
13 |
+
kernel_size=kernel_size, stride=stride,
|
14 |
+
padding=padding, bias=False) # verify bias false
|
15 |
+
self.bn = nn.BatchNorm2d(out_planes,
|
16 |
+
eps=0.001, # value found in tensorflow
|
17 |
+
momentum=0.1, # default pytorch value
|
18 |
+
affine=True)
|
19 |
+
self.relu = nn.ReLU(inplace=False)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
x = self.conv(x)
|
23 |
+
x = self.bn(x)
|
24 |
+
x = self.relu(x)
|
25 |
+
return x
|
26 |
+
|
27 |
+
|
28 |
+
class Mixed_5b(nn.Module):
|
29 |
+
|
30 |
+
def __init__(self):
|
31 |
+
super(Mixed_5b, self).__init__()
|
32 |
+
|
33 |
+
self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
|
34 |
+
|
35 |
+
self.branch1 = nn.Sequential(
|
36 |
+
BasicConv2d(192, 48, kernel_size=1, stride=1),
|
37 |
+
BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
|
38 |
+
)
|
39 |
+
|
40 |
+
self.branch2 = nn.Sequential(
|
41 |
+
BasicConv2d(192, 64, kernel_size=1, stride=1),
|
42 |
+
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
|
43 |
+
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
|
44 |
+
)
|
45 |
+
|
46 |
+
self.branch3 = nn.Sequential(
|
47 |
+
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
48 |
+
BasicConv2d(192, 64, kernel_size=1, stride=1)
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x0 = self.branch0(x)
|
53 |
+
x1 = self.branch1(x)
|
54 |
+
x2 = self.branch2(x)
|
55 |
+
x3 = self.branch3(x)
|
56 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
class Block35(nn.Module):
|
61 |
+
|
62 |
+
def __init__(self, scale=1.0):
|
63 |
+
super(Block35, self).__init__()
|
64 |
+
|
65 |
+
self.scale = scale
|
66 |
+
|
67 |
+
self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
|
68 |
+
|
69 |
+
self.branch1 = nn.Sequential(
|
70 |
+
BasicConv2d(320, 32, kernel_size=1, stride=1),
|
71 |
+
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
|
72 |
+
)
|
73 |
+
|
74 |
+
self.branch2 = nn.Sequential(
|
75 |
+
BasicConv2d(320, 32, kernel_size=1, stride=1),
|
76 |
+
BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
|
77 |
+
BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
|
78 |
+
)
|
79 |
+
|
80 |
+
self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
|
81 |
+
self.relu = nn.ReLU(inplace=False)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x0 = self.branch0(x)
|
85 |
+
x1 = self.branch1(x)
|
86 |
+
x2 = self.branch2(x)
|
87 |
+
out = torch.cat((x0, x1, x2), 1)
|
88 |
+
out = self.conv2d(out)
|
89 |
+
out = out * self.scale + x
|
90 |
+
out = self.relu(out)
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class Mixed_6a(nn.Module):
|
95 |
+
|
96 |
+
def __init__(self):
|
97 |
+
super(Mixed_6a, self).__init__()
|
98 |
+
|
99 |
+
self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
|
100 |
+
|
101 |
+
self.branch1 = nn.Sequential(
|
102 |
+
BasicConv2d(320, 256, kernel_size=1, stride=1),
|
103 |
+
BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
104 |
+
BasicConv2d(256, 384, kernel_size=3, stride=2)
|
105 |
+
)
|
106 |
+
|
107 |
+
self.branch2 = nn.MaxPool2d(3, stride=2)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
x0 = self.branch0(x)
|
111 |
+
x1 = self.branch1(x)
|
112 |
+
x2 = self.branch2(x)
|
113 |
+
out = torch.cat((x0, x1, x2), 1)
|
114 |
+
return out
|
115 |
+
|
116 |
+
|
117 |
+
class Block17(nn.Module):
|
118 |
+
|
119 |
+
def __init__(self, scale=1.0):
|
120 |
+
super(Block17, self).__init__()
|
121 |
+
|
122 |
+
self.scale = scale
|
123 |
+
|
124 |
+
self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
|
125 |
+
|
126 |
+
self.branch1 = nn.Sequential(
|
127 |
+
BasicConv2d(1088, 128, kernel_size=1, stride=1),
|
128 |
+
BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),
|
129 |
+
BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0))
|
130 |
+
)
|
131 |
+
|
132 |
+
self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
|
133 |
+
self.relu = nn.ReLU(inplace=False)
|
134 |
+
|
135 |
+
def forward(self, x):
|
136 |
+
x0 = self.branch0(x)
|
137 |
+
x1 = self.branch1(x)
|
138 |
+
out = torch.cat((x0, x1), 1)
|
139 |
+
out = self.conv2d(out)
|
140 |
+
out = out * self.scale + x
|
141 |
+
out = self.relu(out)
|
142 |
+
return out
|
143 |
+
|
144 |
+
|
145 |
+
class Mixed_7a(nn.Module):
|
146 |
+
|
147 |
+
def __init__(self):
|
148 |
+
super(Mixed_7a, self).__init__()
|
149 |
+
|
150 |
+
self.branch0 = nn.Sequential(
|
151 |
+
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
152 |
+
BasicConv2d(256, 384, kernel_size=3, stride=2)
|
153 |
+
)
|
154 |
+
|
155 |
+
self.branch1 = nn.Sequential(
|
156 |
+
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
157 |
+
BasicConv2d(256, 288, kernel_size=3, stride=2)
|
158 |
+
)
|
159 |
+
|
160 |
+
self.branch2 = nn.Sequential(
|
161 |
+
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
162 |
+
BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
|
163 |
+
BasicConv2d(288, 320, kernel_size=3, stride=2)
|
164 |
+
)
|
165 |
+
|
166 |
+
self.branch3 = nn.MaxPool2d(3, stride=2)
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
x0 = self.branch0(x)
|
170 |
+
x1 = self.branch1(x)
|
171 |
+
x2 = self.branch2(x)
|
172 |
+
x3 = self.branch3(x)
|
173 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
174 |
+
return out
|
175 |
+
|
176 |
+
|
177 |
+
class Block8(nn.Module):
|
178 |
+
|
179 |
+
def __init__(self, scale=1.0, noReLU=False):
|
180 |
+
super(Block8, self).__init__()
|
181 |
+
|
182 |
+
self.scale = scale
|
183 |
+
self.noReLU = noReLU
|
184 |
+
|
185 |
+
self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
|
186 |
+
|
187 |
+
self.branch1 = nn.Sequential(
|
188 |
+
BasicConv2d(2080, 192, kernel_size=1, stride=1),
|
189 |
+
BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),
|
190 |
+
BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0))
|
191 |
+
)
|
192 |
+
|
193 |
+
self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
|
194 |
+
if not self.noReLU:
|
195 |
+
self.relu = nn.ReLU(inplace=False)
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
x0 = self.branch0(x)
|
199 |
+
x1 = self.branch1(x)
|
200 |
+
out = torch.cat((x0, x1), 1)
|
201 |
+
out = self.conv2d(out)
|
202 |
+
out = out * self.scale + x
|
203 |
+
if not self.noReLU:
|
204 |
+
out = self.relu(out)
|
205 |
+
return out
|
206 |
+
|
207 |
+
|
208 |
+
class InceptionResNetV2(nn.Module):
|
209 |
+
|
210 |
+
def __init__(self, num_classes=50):
|
211 |
+
super(InceptionResNetV2, self).__init__()
|
212 |
+
# Special attributs
|
213 |
+
self.num_classes = num_classes
|
214 |
+
self.input_space = None
|
215 |
+
self.input_size = (299, 299, 3)
|
216 |
+
self.mean = None
|
217 |
+
self.std = None
|
218 |
+
# Modules
|
219 |
+
self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
|
220 |
+
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
|
221 |
+
self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
|
222 |
+
self.maxpool_3a = nn.MaxPool2d(3, stride=2)
|
223 |
+
self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
|
224 |
+
self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
|
225 |
+
self.maxpool_5a = nn.MaxPool2d(3, stride=2)
|
226 |
+
self.mixed_5b = Mixed_5b()
|
227 |
+
self.repeat = nn.Sequential(
|
228 |
+
Block35(scale=0.17),
|
229 |
+
Block35(scale=0.17),
|
230 |
+
Block35(scale=0.17),
|
231 |
+
Block35(scale=0.17),
|
232 |
+
Block35(scale=0.17),
|
233 |
+
Block35(scale=0.17),
|
234 |
+
Block35(scale=0.17),
|
235 |
+
Block35(scale=0.17),
|
236 |
+
Block35(scale=0.17),
|
237 |
+
Block35(scale=0.17)
|
238 |
+
)
|
239 |
+
self.mixed_6a = Mixed_6a()
|
240 |
+
self.repeat_1 = nn.Sequential(
|
241 |
+
Block17(scale=0.10),
|
242 |
+
Block17(scale=0.10),
|
243 |
+
Block17(scale=0.10),
|
244 |
+
Block17(scale=0.10),
|
245 |
+
Block17(scale=0.10),
|
246 |
+
Block17(scale=0.10),
|
247 |
+
Block17(scale=0.10),
|
248 |
+
Block17(scale=0.10),
|
249 |
+
Block17(scale=0.10),
|
250 |
+
Block17(scale=0.10),
|
251 |
+
Block17(scale=0.10),
|
252 |
+
Block17(scale=0.10),
|
253 |
+
Block17(scale=0.10),
|
254 |
+
Block17(scale=0.10),
|
255 |
+
Block17(scale=0.10),
|
256 |
+
Block17(scale=0.10),
|
257 |
+
Block17(scale=0.10),
|
258 |
+
Block17(scale=0.10),
|
259 |
+
Block17(scale=0.10),
|
260 |
+
Block17(scale=0.10)
|
261 |
+
)
|
262 |
+
self.mixed_7a = Mixed_7a()
|
263 |
+
self.repeat_2 = nn.Sequential(
|
264 |
+
Block8(scale=0.20),
|
265 |
+
Block8(scale=0.20),
|
266 |
+
Block8(scale=0.20),
|
267 |
+
Block8(scale=0.20),
|
268 |
+
Block8(scale=0.20),
|
269 |
+
Block8(scale=0.20),
|
270 |
+
Block8(scale=0.20),
|
271 |
+
Block8(scale=0.20),
|
272 |
+
Block8(scale=0.20)
|
273 |
+
)
|
274 |
+
self.block8 = Block8(noReLU=True)
|
275 |
+
self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
|
276 |
+
self.avgpool_1a = nn.AdaptiveAvgPool2d((1, 1))#nn.AvgPool2d(8, count_include_pad=False)
|
277 |
+
self.last_linear = nn.Linear(1536, num_classes)
|
278 |
+
|
279 |
+
self.branch = self._make_branch(320, 1536, 3)
|
280 |
+
self.branch1 = self._make_branch(1088, 1536, 3)
|
281 |
+
self.branch2 = self._make_branch(2080, 1536, 3)
|
282 |
+
|
283 |
+
def _make_branch(self, channel_in, channel_out, kernel_size):
|
284 |
+
middle_channel = channel_out // 4
|
285 |
+
return nn.Sequential(
|
286 |
+
nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
|
287 |
+
nn.BatchNorm2d(middle_channel),
|
288 |
+
nn.ReLU(),
|
289 |
+
|
290 |
+
nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
|
291 |
+
nn.BatchNorm2d(middle_channel),
|
292 |
+
nn.ReLU(),
|
293 |
+
|
294 |
+
nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
|
295 |
+
nn.BatchNorm2d(channel_out),
|
296 |
+
nn.ReLU(),
|
297 |
+
|
298 |
+
nn.AdaptiveAvgPool2d((1,1)),
|
299 |
+
nn.Flatten(),
|
300 |
+
nn.Linear(channel_out, self.num_classes)
|
301 |
+
)
|
302 |
+
|
303 |
+
|
304 |
+
def features(self, input):
|
305 |
+
x = self.conv2d_1a(input)
|
306 |
+
x = self.conv2d_2a(x)
|
307 |
+
x = self.conv2d_2b(x)
|
308 |
+
x = self.maxpool_3a(x)
|
309 |
+
x = self.conv2d_3b(x)
|
310 |
+
x = self.conv2d_4a(x)
|
311 |
+
x = self.maxpool_5a(x)
|
312 |
+
x = self.mixed_5b(x)
|
313 |
+
x = self.repeat(x)
|
314 |
+
x1 = self.branch(x)
|
315 |
+
|
316 |
+
x = self.mixed_6a(x)
|
317 |
+
x = self.repeat_1(x)
|
318 |
+
x2 = self.branch1(x)
|
319 |
+
|
320 |
+
x = self.mixed_7a(x)
|
321 |
+
x = self.repeat_2(x)
|
322 |
+
x3 = self.branch2(x)
|
323 |
+
|
324 |
+
x = self.block8(x)
|
325 |
+
x = self.conv2d_7b(x)
|
326 |
+
return x, x1, x2, x3
|
327 |
+
|
328 |
+
def logits(self, features):
|
329 |
+
x = self.avgpool_1a(features)
|
330 |
+
x = x.view(x.size(0), -1)
|
331 |
+
out = self.last_linear(x)
|
332 |
+
return out
|
333 |
+
|
334 |
+
|
335 |
+
def forward(self, input):
|
336 |
+
x, x1, x2, x3, = self.features(input)
|
337 |
+
out = self.logits(x)
|
338 |
+
return {'outputs': [out, x1, x2, x3]}
|
339 |
+
|
340 |
+
|
341 |
+
def test():
|
342 |
+
net = InceptionResNetV2().cuda()
|
343 |
+
y = net(torch.randn(1,3,227,227).cuda())
|
344 |
+
print(y.size())
|
345 |
+
#test()
|
models/ResNet_Imagenet.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
try:
|
5 |
+
from torch.hub import load_state_dict_from_url
|
6 |
+
except ImportError:
|
7 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
8 |
+
|
9 |
+
model_urls = {
|
10 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
11 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
12 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
13 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
14 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
15 |
+
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
16 |
+
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
17 |
+
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
18 |
+
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
23 |
+
"""3x3 convolution with padding"""
|
24 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
25 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
26 |
+
|
27 |
+
|
28 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
29 |
+
"""1x1 convolution"""
|
30 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
31 |
+
|
32 |
+
|
33 |
+
def branchBottleNeck(channel_in, channel_out, kernel_size):
|
34 |
+
middle_channel = channel_out//4
|
35 |
+
return nn.Sequential(
|
36 |
+
nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
|
37 |
+
nn.BatchNorm2d(middle_channel),
|
38 |
+
nn.ReLU(),
|
39 |
+
|
40 |
+
nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
|
41 |
+
nn.BatchNorm2d(middle_channel),
|
42 |
+
nn.ReLU(),
|
43 |
+
|
44 |
+
nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
|
45 |
+
nn.BatchNorm2d(channel_out),
|
46 |
+
nn.ReLU(),
|
47 |
+
)
|
48 |
+
|
49 |
+
class LambdaLayer(nn.Module):
|
50 |
+
def __init__(self, lambd):
|
51 |
+
super(LambdaLayer, self).__init__()
|
52 |
+
self.lambd = lambd
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
return self.lambd(x)
|
56 |
+
|
57 |
+
class BasicBlock(nn.Module):
|
58 |
+
expansion = 1
|
59 |
+
|
60 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
61 |
+
base_width=64, dilation=1, norm_layer=None):
|
62 |
+
super(BasicBlock, self).__init__()
|
63 |
+
if norm_layer is None:
|
64 |
+
norm_layer = nn.BatchNorm2d
|
65 |
+
if groups != 1 or base_width != 64:
|
66 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
67 |
+
if dilation > 1:
|
68 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
69 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
70 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
71 |
+
self.bn1 = norm_layer(planes)
|
72 |
+
self.relu = nn.ReLU(inplace=True)
|
73 |
+
self.conv2 = conv3x3(planes, planes)
|
74 |
+
self.bn2 = norm_layer(planes)
|
75 |
+
self.downsample = downsample
|
76 |
+
self.stride = stride
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
identity = x
|
80 |
+
|
81 |
+
out = self.conv1(x)
|
82 |
+
out = self.bn1(out)
|
83 |
+
out = self.relu(out)
|
84 |
+
|
85 |
+
out = self.conv2(out)
|
86 |
+
out = self.bn2(out)
|
87 |
+
|
88 |
+
if self.downsample is not None:
|
89 |
+
identity = self.downsample(x)
|
90 |
+
|
91 |
+
out += identity
|
92 |
+
out = self.relu(out)
|
93 |
+
|
94 |
+
return out
|
95 |
+
|
96 |
+
|
97 |
+
class Bottleneck(nn.Module):
|
98 |
+
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
99 |
+
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
100 |
+
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
101 |
+
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
102 |
+
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
103 |
+
|
104 |
+
expansion = 4
|
105 |
+
|
106 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
107 |
+
base_width=64, dilation=1, norm_layer=None):
|
108 |
+
super(Bottleneck, self).__init__()
|
109 |
+
if norm_layer is None:
|
110 |
+
norm_layer = nn.BatchNorm2d
|
111 |
+
width = int(planes * (base_width / 64.)) * groups
|
112 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
113 |
+
self.conv1 = conv1x1(inplanes, width)
|
114 |
+
self.bn1 = norm_layer(width)
|
115 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
116 |
+
self.bn2 = norm_layer(width)
|
117 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
118 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
119 |
+
self.relu = nn.ReLU(inplace=True)
|
120 |
+
self.downsample = downsample
|
121 |
+
self.stride = stride
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
identity = x
|
125 |
+
|
126 |
+
out = self.conv1(x)
|
127 |
+
out = self.bn1(out)
|
128 |
+
out = self.relu(out)
|
129 |
+
|
130 |
+
out = self.conv2(out)
|
131 |
+
out = self.bn2(out)
|
132 |
+
out = self.relu(out)
|
133 |
+
|
134 |
+
out = self.conv3(out)
|
135 |
+
out = self.bn3(out)
|
136 |
+
|
137 |
+
if self.downsample is not None:
|
138 |
+
identity = self.downsample(x)
|
139 |
+
|
140 |
+
out += identity
|
141 |
+
out = self.relu(out)
|
142 |
+
|
143 |
+
return out
|
144 |
+
|
145 |
+
|
146 |
+
class ResNet(nn.Module):
|
147 |
+
|
148 |
+
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
149 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
150 |
+
norm_layer=None):
|
151 |
+
super(ResNet, self).__init__()
|
152 |
+
if norm_layer is None:
|
153 |
+
norm_layer = nn.BatchNorm2d
|
154 |
+
self._norm_layer = norm_layer
|
155 |
+
self.num_classes = num_classes
|
156 |
+
|
157 |
+
self.inplanes = 64
|
158 |
+
self.dilation = 1
|
159 |
+
if replace_stride_with_dilation is None:
|
160 |
+
# each element in the tuple indicates if we should replace
|
161 |
+
# the 2x2 stride with a dilated convolution instead
|
162 |
+
replace_stride_with_dilation = [False, False, False]
|
163 |
+
if len(replace_stride_with_dilation) != 3:
|
164 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
165 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
166 |
+
self.groups = groups
|
167 |
+
self.base_width = width_per_group
|
168 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
169 |
+
bias=False)
|
170 |
+
self.bn1 = norm_layer(self.inplanes)
|
171 |
+
self.relu = nn.ReLU(inplace=True)
|
172 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
173 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
174 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
175 |
+
dilate=replace_stride_with_dilation[0])
|
176 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
177 |
+
dilate=replace_stride_with_dilation[1])
|
178 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
179 |
+
dilate=replace_stride_with_dilation[2])
|
180 |
+
|
181 |
+
self.branch1 = self._make_branch(64*block.expansion, 512*block.expansion, kernel_size=8)
|
182 |
+
self.branch2 = self._make_branch(128*block.expansion, 512*block.expansion, kernel_size=4)
|
183 |
+
self.branch3 = self._make_branch(256*block.expansion, 512*block.expansion, kernel_size=2)
|
184 |
+
self.branch4 = self._make_branch(512*block.expansion, 512*block.expansion, kernel_size=1)
|
185 |
+
|
186 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
187 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
188 |
+
|
189 |
+
for m in self.modules():
|
190 |
+
if isinstance(m, nn.Conv2d):
|
191 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
192 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
193 |
+
nn.init.constant_(m.weight, 1)
|
194 |
+
nn.init.constant_(m.bias, 0)
|
195 |
+
|
196 |
+
# Zero-initialize the last BN in each residual branch,
|
197 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
198 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
199 |
+
if zero_init_residual:
|
200 |
+
for m in self.modules():
|
201 |
+
if isinstance(m, Bottleneck):
|
202 |
+
nn.init.constant_(m.bn3.weight, 0)
|
203 |
+
elif isinstance(m, BasicBlock):
|
204 |
+
nn.init.constant_(m.bn2.weight, 0)
|
205 |
+
|
206 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
207 |
+
norm_layer = self._norm_layer
|
208 |
+
downsample = None
|
209 |
+
previous_dilation = self.dilation
|
210 |
+
if dilate:
|
211 |
+
self.dilation *= stride
|
212 |
+
stride = 1
|
213 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
214 |
+
downsample = nn.Sequential(
|
215 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
216 |
+
norm_layer(planes * block.expansion),
|
217 |
+
)
|
218 |
+
|
219 |
+
layers = []
|
220 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
221 |
+
self.base_width, previous_dilation, norm_layer))
|
222 |
+
self.inplanes = planes * block.expansion
|
223 |
+
for _ in range(1, blocks):
|
224 |
+
layers.append(block(self.inplanes, planes, groups=self.groups,
|
225 |
+
base_width=self.base_width, dilation=self.dilation,
|
226 |
+
norm_layer=norm_layer))
|
227 |
+
|
228 |
+
return nn.Sequential(*layers)
|
229 |
+
|
230 |
+
def _make_branch(self, channel_in, channel_out, kernel_size):
|
231 |
+
middle_channel = channel_out // 4
|
232 |
+
return nn.Sequential(
|
233 |
+
nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
|
234 |
+
nn.BatchNorm2d(middle_channel),
|
235 |
+
nn.ReLU(),
|
236 |
+
|
237 |
+
nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
|
238 |
+
nn.BatchNorm2d(middle_channel),
|
239 |
+
nn.ReLU(),
|
240 |
+
|
241 |
+
nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
|
242 |
+
nn.BatchNorm2d(channel_out),
|
243 |
+
nn.ReLU(),
|
244 |
+
|
245 |
+
nn.AdaptiveAvgPool2d((1,1)),
|
246 |
+
nn.Flatten(),
|
247 |
+
nn.Linear(channel_out, self.num_classes)
|
248 |
+
)
|
249 |
+
|
250 |
+
def _forward_impl(self, x):
|
251 |
+
# See note [TorchScript super()]
|
252 |
+
x = self.conv1(x)
|
253 |
+
x = self.bn1(x)
|
254 |
+
x = self.relu(x)
|
255 |
+
x = self.maxpool(x)
|
256 |
+
|
257 |
+
x = self.layer1(x)
|
258 |
+
x1 = self.branch1(x)
|
259 |
+
|
260 |
+
x = self.layer2(x)
|
261 |
+
x2 = self.branch2(x)
|
262 |
+
|
263 |
+
x = self.layer3(x)
|
264 |
+
x3 = self.branch3(x)
|
265 |
+
|
266 |
+
x = self.layer4(x)
|
267 |
+
x = self.avgpool(x)
|
268 |
+
final_fea = x
|
269 |
+
x = torch.flatten(x, 1)
|
270 |
+
x = self.fc(x)
|
271 |
+
|
272 |
+
return {'outputs': [x, x1, x2, x3]}
|
273 |
+
|
274 |
+
def forward(self, x):
|
275 |
+
return self._forward_impl(x)
|
276 |
+
|
277 |
+
def sdresnet50(num_classes=14, pretrained=True):
|
278 |
+
if pretrained:
|
279 |
+
model = ResNet(Bottleneck, [3,4,6,3], num_classes=14)
|
280 |
+
num_ftrs = model.fc.in_features
|
281 |
+
model.fc = nn.Linear(num_ftrs, 1000)
|
282 |
+
state_dict = load_state_dict_from_url(model_urls['resnet50'], progress=True)
|
283 |
+
model.load_state_dict(state_dict, strict=False)
|
284 |
+
|
285 |
+
num_ftrs = model.fc.in_features
|
286 |
+
model.fc = nn.Linear(num_ftrs, num_classes)
|
287 |
+
else:
|
288 |
+
model = ResNet(Bottleneck, [3,4,6,3], num_classes=50)
|
289 |
+
return model
|
models/ResNet_cifar.py
ADDED
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.nn.init as init
|
5 |
+
from torch.nn import Parameter
|
6 |
+
|
7 |
+
model_urls = {
|
8 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
9 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
10 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
11 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
12 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
13 |
+
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
14 |
+
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
15 |
+
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
16 |
+
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
17 |
+
}
|
18 |
+
|
19 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
20 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
21 |
+
stride=stride, padding=1, bias=False)
|
22 |
+
|
23 |
+
def conv1x1(in_planes, planes, stride=1):
|
24 |
+
return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
|
25 |
+
|
26 |
+
def branchBottleNeck(channel_in, channel_out, kernel_size):
|
27 |
+
middle_channel = channel_out//4
|
28 |
+
return nn.Sequential(
|
29 |
+
nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
|
30 |
+
nn.BatchNorm2d(middle_channel),
|
31 |
+
nn.ReLU(),
|
32 |
+
|
33 |
+
nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
|
34 |
+
nn.BatchNorm2d(middle_channel),
|
35 |
+
nn.ReLU(),
|
36 |
+
|
37 |
+
nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
|
38 |
+
nn.BatchNorm2d(channel_out),
|
39 |
+
nn.ReLU(),
|
40 |
+
)
|
41 |
+
|
42 |
+
def branchMLP(channel_in, channel_out):
|
43 |
+
middle_channel = channel_out//4
|
44 |
+
return nn.Sequential(
|
45 |
+
conv1x1(channel_in, channel_in, stride=8),
|
46 |
+
nn.BatchNorm2d(512 * block.expansion),
|
47 |
+
nn.ReLU(),
|
48 |
+
)
|
49 |
+
|
50 |
+
def invertedBottleNeck(channel_in, channel_out, kernel_size):
|
51 |
+
middle_channel = channel_out * 2
|
52 |
+
return nn.Sequential(
|
53 |
+
nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
|
54 |
+
nn.BatchNorm2d(middle_channel),
|
55 |
+
nn.ReLU(),
|
56 |
+
|
57 |
+
nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
|
58 |
+
nn.BatchNorm2d(middle_channel),
|
59 |
+
nn.ReLU(),
|
60 |
+
|
61 |
+
nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
|
62 |
+
nn.BatchNorm2d(channel_out),
|
63 |
+
nn.ReLU(),
|
64 |
+
)
|
65 |
+
|
66 |
+
class BatchNorm2dMul(nn.Module):
|
67 |
+
def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
|
68 |
+
super(BatchNorm2dMul, self).__init__()
|
69 |
+
self.bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=False, track_running_stats=track_running_stats)
|
70 |
+
self.gamma = nn.Parameter(torch.ones(num_features))
|
71 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
72 |
+
self.affine = affine
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
bn_out = self.bn(x)
|
76 |
+
if self.affine:
|
77 |
+
out = self.gamma[None, :, None, None] * bn_out + self.beta[None, :, None, None]
|
78 |
+
return out, bn_out
|
79 |
+
|
80 |
+
def _weights_init(m):
|
81 |
+
classname = m.__class__.__name__
|
82 |
+
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
|
83 |
+
init.kaiming_normal_(m.weight)
|
84 |
+
|
85 |
+
class LambdaLayer(nn.Module):
|
86 |
+
def __init__(self, lambd):
|
87 |
+
super(LambdaLayer, self).__init__()
|
88 |
+
self.lambd = lambd
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
return self.lambd(x)
|
92 |
+
|
93 |
+
class BasicBlock_s(nn.Module):
|
94 |
+
expansion = 1
|
95 |
+
|
96 |
+
def __init__(self, in_planes, planes, stride=1):
|
97 |
+
super(BasicBlock_s, self).__init__()
|
98 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
99 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
100 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
101 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
102 |
+
|
103 |
+
self.shortcut = nn.Sequential()
|
104 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
105 |
+
self.shortcut = nn.Sequential(
|
106 |
+
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
107 |
+
nn.BatchNorm2d(self.expansion*planes)
|
108 |
+
)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
112 |
+
out = self.bn2(self.conv2(out))
|
113 |
+
out += self.shortcut(x)
|
114 |
+
out = F.relu(out)
|
115 |
+
return out
|
116 |
+
|
117 |
+
class BasicBlock(nn.Module):
|
118 |
+
expansion = 1
|
119 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
120 |
+
super(BasicBlock, self).__init__()
|
121 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
122 |
+
self.bn1 = BatchNorm2dMul(planes)
|
123 |
+
self.relu = nn.ReLU(inplace=True)
|
124 |
+
self.conv2 = conv3x3(planes, planes)
|
125 |
+
self.bn2 = BatchNorm2dMul(planes)
|
126 |
+
self.downsample = downsample
|
127 |
+
self.stride = stride
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
bn_outputs = []
|
131 |
+
|
132 |
+
residual = x
|
133 |
+
output = self.conv1(x)
|
134 |
+
output, bn_out = self.bn1(output)
|
135 |
+
bn_outputs.append(bn_out)
|
136 |
+
output = self.relu(output)
|
137 |
+
|
138 |
+
output = self.conv2(output)
|
139 |
+
output, bn_out = self.bn2(output)
|
140 |
+
bn_outputs.append(bn_out)
|
141 |
+
|
142 |
+
if self.downsample is not None:
|
143 |
+
residual = self.downsample(x)
|
144 |
+
|
145 |
+
output += residual
|
146 |
+
output = self.relu(output)
|
147 |
+
return output, bn_outputs
|
148 |
+
|
149 |
+
class BottleneckBlock(nn.Module):
|
150 |
+
expansion = 4
|
151 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
152 |
+
super(BottleneckBlock, self).__init__()
|
153 |
+
self.conv1 = conv1x1(inplanes, planes)
|
154 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
155 |
+
self.relu = nn.ReLU(inplace=True)
|
156 |
+
|
157 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
158 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
159 |
+
|
160 |
+
self.conv3 = conv1x1(planes, planes*self.expansion)
|
161 |
+
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
162 |
+
|
163 |
+
self.downsample = downsample
|
164 |
+
self.stride = stride
|
165 |
+
|
166 |
+
def forward(self, x):
|
167 |
+
residual = x
|
168 |
+
|
169 |
+
output = self.conv1(x)
|
170 |
+
output = self.bn1(output)
|
171 |
+
output = self.relu(output)
|
172 |
+
|
173 |
+
output = self.conv2(output)
|
174 |
+
output = self.bn2(output)
|
175 |
+
output = self.relu(output)
|
176 |
+
|
177 |
+
output = self.conv3(output)
|
178 |
+
output = self.bn3(output)
|
179 |
+
|
180 |
+
if self.downsample is not None:
|
181 |
+
residual = self.downsample(x)
|
182 |
+
|
183 |
+
output += residual
|
184 |
+
output = self.relu(output)
|
185 |
+
|
186 |
+
return output
|
187 |
+
|
188 |
+
class LayerBlock(nn.Module):
|
189 |
+
def __init__(self, block, inplanes, planes, num_blocks, stride):
|
190 |
+
super(LayerBlock, self).__init__()
|
191 |
+
downsample = None
|
192 |
+
if stride !=1 or inplanes != planes * block.expansion:
|
193 |
+
downsample = nn.Sequential(
|
194 |
+
conv1x1(inplanes, planes * block.expansion, stride),
|
195 |
+
nn.BatchNorm2d(planes * block.expansion),
|
196 |
+
)
|
197 |
+
layer = []
|
198 |
+
layer.append(block(inplanes, planes, stride=stride, downsample=downsample))
|
199 |
+
inplanes = planes * block.expansion
|
200 |
+
for i in range(1, num_blocks):
|
201 |
+
layer.append(block(inplanes, planes))
|
202 |
+
self.layers = nn.Sequential(*layer)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
bn_outputs = []
|
206 |
+
for layer in self.layers:
|
207 |
+
x, bn_output = layer(x)
|
208 |
+
bn_outputs.extend(bn_output)
|
209 |
+
return x, bn_outputs
|
210 |
+
|
211 |
+
class SDResNet(nn.Module):
|
212 |
+
"""
|
213 |
+
Resnet model
|
214 |
+
|
215 |
+
Args:
|
216 |
+
block (class): block type, BasicBlock or BottlenetckBlock
|
217 |
+
layers (int list): layer num in each block
|
218 |
+
num_classes (int): class num
|
219 |
+
"""
|
220 |
+
|
221 |
+
def __init__(self, block, layers, num_classes=10, position_all=True):
|
222 |
+
super(SDResNet, self).__init__()
|
223 |
+
|
224 |
+
self.position_all = position_all
|
225 |
+
|
226 |
+
self.inplanes = 64
|
227 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
228 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes)
|
229 |
+
self.relu = nn.ReLU(inplace=True)
|
230 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
231 |
+
|
232 |
+
self.layer1 = LayerBlock(block, 64, 64, layers[0], stride=1)
|
233 |
+
self.layer2 = LayerBlock(block, 64, 128, layers[1], stride=2)
|
234 |
+
self.layer3 = LayerBlock(block, 128, 256, layers[2], stride=2)
|
235 |
+
self.layer4 = LayerBlock(block, 256, 512, layers[3], stride=2)
|
236 |
+
|
237 |
+
self.downsample1_1 = nn.Sequential(
|
238 |
+
conv1x1(64 * block.expansion, 512 * block.expansion, stride=8),
|
239 |
+
nn.BatchNorm2d(512 * block.expansion),
|
240 |
+
nn.ReLU(),
|
241 |
+
)
|
242 |
+
self.bottleneck1_1 = branchBottleNeck(64 * block.expansion, 512 * block.expansion, kernel_size=8)
|
243 |
+
self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
|
244 |
+
self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)
|
245 |
+
|
246 |
+
|
247 |
+
self.downsample2_1 = nn.Sequential(
|
248 |
+
conv1x1(128 * block.expansion, 512 * block.expansion, stride=4),
|
249 |
+
nn.BatchNorm2d(512 * block.expansion),
|
250 |
+
)
|
251 |
+
self.bottleneck2_1 = branchBottleNeck(128 * block.expansion, 512 * block.expansion, kernel_size=4)
|
252 |
+
self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
|
253 |
+
self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)
|
254 |
+
|
255 |
+
|
256 |
+
self.downsample3_1 = nn.Sequential(
|
257 |
+
conv1x1(256 * block.expansion, 512 * block.expansion, stride=2),
|
258 |
+
nn.BatchNorm2d(512 * block.expansion),
|
259 |
+
)
|
260 |
+
self.bottleneck3_1 = branchBottleNeck(256 * block.expansion, 512 * block.expansion, kernel_size=2)
|
261 |
+
self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
|
262 |
+
self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)
|
263 |
+
|
264 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
265 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
266 |
+
|
267 |
+
self.apply(_weights_init)
|
268 |
+
|
269 |
+
def _make_layer(self, block, planes, layers, stride=1):
|
270 |
+
"""A block with 'layers' layers
|
271 |
+
Args:
|
272 |
+
block (class): block type
|
273 |
+
planes (int): output channels = planes * expansion
|
274 |
+
layers (int): layer num in the block
|
275 |
+
stride (int): the first layer stride in the block
|
276 |
+
"""
|
277 |
+
downsample = None
|
278 |
+
if stride !=1 or self.inplanes != planes * block.expansion:
|
279 |
+
downsample = nn.Sequential(
|
280 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
281 |
+
nn.BatchNorm2d(planes * block.expansion),
|
282 |
+
)
|
283 |
+
layer = []
|
284 |
+
layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
|
285 |
+
self.inplanes = planes * block.expansion
|
286 |
+
for i in range(1, layers):
|
287 |
+
layer.append(block(self.inplanes, planes))
|
288 |
+
|
289 |
+
return nn.Sequential(*layer)
|
290 |
+
|
291 |
+
def forward(self, x, feat_out=False):
|
292 |
+
all_bn_outputs = []
|
293 |
+
|
294 |
+
x = self.conv1(x)
|
295 |
+
x = self.bn1(x)
|
296 |
+
x = self.relu(x)
|
297 |
+
# x = self.maxpool(x)
|
298 |
+
|
299 |
+
x, bn_outputs = self.layer1(x)
|
300 |
+
all_bn_outputs.extend(bn_outputs)
|
301 |
+
middle_output1 = self.bottleneck1_1(x)
|
302 |
+
middle_output1 = self.avgpool1(middle_output1)
|
303 |
+
middle1_fea = middle_output1
|
304 |
+
middle_output1 = torch.flatten(middle_output1, 1)
|
305 |
+
middle_output1 = self.middle_fc1(middle_output1)
|
306 |
+
|
307 |
+
x, bn_outputs = self.layer2(x)
|
308 |
+
all_bn_outputs.extend(bn_outputs)
|
309 |
+
middle_output2 = self.bottleneck2_1(x)
|
310 |
+
middle_output2 = self.avgpool2(middle_output2)
|
311 |
+
middle2_fea = middle_output2
|
312 |
+
middle_output2 = torch.flatten(middle_output2, 1)
|
313 |
+
middle_output2 = self.middle_fc2(middle_output2)
|
314 |
+
|
315 |
+
x, bn_outputs = self.layer3(x)
|
316 |
+
all_bn_outputs.extend(bn_outputs)
|
317 |
+
middle_output3 = self.bottleneck3_1(x)
|
318 |
+
middle_output3 = self.avgpool3(middle_output3)
|
319 |
+
middle3_fea = middle_output3
|
320 |
+
middle_output3 = torch.flatten(middle_output3, 1)
|
321 |
+
middle_output3 = self.middle_fc3(middle_output3)
|
322 |
+
|
323 |
+
x, bn_outputs = self.layer4(x)
|
324 |
+
all_bn_outputs.extend(bn_outputs)
|
325 |
+
x = self.avgpool(x)
|
326 |
+
final_fea = x
|
327 |
+
x = torch.flatten(x, 1)
|
328 |
+
x = self.fc(x)
|
329 |
+
|
330 |
+
if self.position_all and feat_out:
|
331 |
+
return {'outputs': [x, middle_output1, middle_output2, middle_output3],
|
332 |
+
'features': [final_fea, middle1_fea, middle2_fea, middle3_fea],
|
333 |
+
'bn_outputs': all_bn_outputs}
|
334 |
+
else:
|
335 |
+
return x
|
336 |
+
|
337 |
+
class SDResNet_mlp(nn.Module):
|
338 |
+
"""
|
339 |
+
Resnet model
|
340 |
+
|
341 |
+
Args:
|
342 |
+
block (class): block type, BasicBlock or BottlenetckBlock
|
343 |
+
layers (int list): layer num in each block
|
344 |
+
num_classes (int): class num
|
345 |
+
"""
|
346 |
+
|
347 |
+
def __init__(self, block, layers, num_classes=10, position_all=True):
|
348 |
+
super(SDResNet_mlp, self).__init__()
|
349 |
+
|
350 |
+
self.position_all = position_all
|
351 |
+
|
352 |
+
self.inplanes = 64
|
353 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
354 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes)
|
355 |
+
self.relu = nn.ReLU(inplace=True)
|
356 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
357 |
+
|
358 |
+
self.layer1 = LayerBlock(block, 64, 64, layers[0], stride=1)
|
359 |
+
self.layer2 = LayerBlock(block, 64, 128, layers[1], stride=2)
|
360 |
+
self.layer3 = LayerBlock(block, 128, 256, layers[2], stride=2)
|
361 |
+
self.layer4 = LayerBlock(block, 256, 512, layers[3], stride=2)
|
362 |
+
|
363 |
+
self.downsample1_1 = nn.Sequential(
|
364 |
+
conv1x1(64 * block.expansion, 512 * block.expansion),
|
365 |
+
nn.BatchNorm2d(512 * block.expansion),
|
366 |
+
nn.ReLU(),
|
367 |
+
)
|
368 |
+
self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
|
369 |
+
self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)
|
370 |
+
|
371 |
+
|
372 |
+
self.downsample2_1 = nn.Sequential(
|
373 |
+
conv1x1(128 * block.expansion, 512 * block.expansion),
|
374 |
+
nn.BatchNorm2d(512 * block.expansion),
|
375 |
+
nn.ReLU()
|
376 |
+
)
|
377 |
+
self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
|
378 |
+
self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)
|
379 |
+
|
380 |
+
|
381 |
+
self.downsample3_1 = nn.Sequential(
|
382 |
+
conv1x1(256 * block.expansion, 512 * block.expansion),
|
383 |
+
nn.BatchNorm2d(512 * block.expansion),
|
384 |
+
nn.ReLU()
|
385 |
+
)
|
386 |
+
self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
|
387 |
+
self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)
|
388 |
+
|
389 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
390 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
391 |
+
|
392 |
+
self.apply(_weights_init)
|
393 |
+
|
394 |
+
def _make_layer(self, block, planes, layers, stride=1):
|
395 |
+
"""A block with 'layers' layers
|
396 |
+
Args:
|
397 |
+
block (class): block type
|
398 |
+
planes (int): output channels = planes * expansion
|
399 |
+
layers (int): layer num in the block
|
400 |
+
stride (int): the first layer stride in the block
|
401 |
+
"""
|
402 |
+
downsample = None
|
403 |
+
if stride !=1 or self.inplanes != planes * block.expansion:
|
404 |
+
downsample = nn.Sequential(
|
405 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
406 |
+
nn.BatchNorm2d(planes * block.expansion),
|
407 |
+
)
|
408 |
+
layer = []
|
409 |
+
layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
|
410 |
+
self.inplanes = planes * block.expansion
|
411 |
+
for i in range(1, layers):
|
412 |
+
layer.append(block(self.inplanes, planes))
|
413 |
+
|
414 |
+
return nn.Sequential(*layer)
|
415 |
+
|
416 |
+
def forward(self, x):
|
417 |
+
all_bn_outputs = []
|
418 |
+
|
419 |
+
x = self.conv1(x)
|
420 |
+
x = self.bn1(x)
|
421 |
+
x = self.relu(x)
|
422 |
+
|
423 |
+
x, bn_outputs = self.layer1(x)
|
424 |
+
all_bn_outputs.extend(bn_outputs)
|
425 |
+
# middle_output1 = self.downsample1_1(x)
|
426 |
+
# middle_output1 = self.avgpool1(middle_output1)
|
427 |
+
# middle1_fea = middle_output1
|
428 |
+
# middle_output1 = torch.flatten(middle_output1, 1)
|
429 |
+
# middle_output1 = self.middle_fc1(middle_output1)
|
430 |
+
|
431 |
+
x, bn_outputs = self.layer2(x)
|
432 |
+
all_bn_outputs.extend(bn_outputs)
|
433 |
+
# middle_output2 = self.downsample2_1(x)
|
434 |
+
# middle_output2 = self.avgpool2(middle_output2)
|
435 |
+
# middle2_fea = middle_output2
|
436 |
+
# middle_output2 = torch.flatten(middle_output2, 1)
|
437 |
+
# middle_output2 = self.middle_fc2(middle_output2)
|
438 |
+
|
439 |
+
x, bn_outputs = self.layer3(x)
|
440 |
+
all_bn_outputs.extend(bn_outputs)
|
441 |
+
# middle_output3 = self.downsample3_1(x)
|
442 |
+
# middle_output3 = self.avgpool3(middle_output3)
|
443 |
+
# middle3_fea = middle_output3
|
444 |
+
# middle_output3 = torch.flatten(middle_output3, 1)
|
445 |
+
# middle_output3 = self.middle_fc3(middle_output3)
|
446 |
+
|
447 |
+
x, bn_outputs = self.layer4(x)
|
448 |
+
all_bn_outputs.extend(bn_outputs)
|
449 |
+
x = self.avgpool(x)
|
450 |
+
final_fea = x
|
451 |
+
x = torch.flatten(x, 1)
|
452 |
+
x = self.fc(x)
|
453 |
+
|
454 |
+
if self.position_all:
|
455 |
+
return {'outputs': [x, middle_output1, middle_output2, middle_output3],
|
456 |
+
'bn_outputs': all_bn_outputs}
|
457 |
+
else:
|
458 |
+
return {'outputs': [x, x],
|
459 |
+
'bn_outputs': all_bn_outputs}
|
460 |
+
|
461 |
+
class SDResNet_residual(nn.Module):
|
462 |
+
"""
|
463 |
+
Resnet model
|
464 |
+
|
465 |
+
Args:
|
466 |
+
block (class): block type, BasicBlock or BottlenetckBlock
|
467 |
+
layers (int list): layer num in each block
|
468 |
+
num_classes (int): class num
|
469 |
+
"""
|
470 |
+
|
471 |
+
def __init__(self, block, layers, num_classes=10, position_all=True):
|
472 |
+
super(SDResNet_residual, self).__init__()
|
473 |
+
|
474 |
+
self.position_all = position_all
|
475 |
+
|
476 |
+
self.inplanes = 64
|
477 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
478 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes)
|
479 |
+
self.relu = nn.ReLU(inplace=True)
|
480 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
481 |
+
|
482 |
+
self.layer1 = LayerBlock(block, 64, 64, layers[0], stride=1)
|
483 |
+
self.layer2 = LayerBlock(block, 64, 128, layers[1], stride=2)
|
484 |
+
self.layer3 = LayerBlock(block, 128, 256, layers[2], stride=2)
|
485 |
+
self.layer4 = LayerBlock(block, 256, 512, layers[3], stride=2)
|
486 |
+
|
487 |
+
self.bottleneck1_1 = LayerBlock(block, 64, 512, 1, stride=8)
|
488 |
+
# branchBottleNeck(64 * block.expansion, 512 * block.expansion, kernel_size=8)
|
489 |
+
self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
|
490 |
+
self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)
|
491 |
+
|
492 |
+
self.bottleneck2_1 = LayerBlock(block, 128, 512, 1, stride=4)
|
493 |
+
# branchBottleNeck(128 * block.expansion, 512 * block.expansion, kernel_size=4)
|
494 |
+
self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
|
495 |
+
self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)
|
496 |
+
|
497 |
+
|
498 |
+
# self.downsample3_1 = nn.Sequential(
|
499 |
+
# conv1x1(256 * block.expansion, 512 * block.expansion, stride=2),
|
500 |
+
# nn.BatchNorm2d(512 * block.expansion),
|
501 |
+
# )
|
502 |
+
self.bottleneck3_1 = LayerBlock(block, 256, 512, 1, stride=2)
|
503 |
+
self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
|
504 |
+
self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)
|
505 |
+
|
506 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
507 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
508 |
+
|
509 |
+
self.apply(_weights_init)
|
510 |
+
|
511 |
+
def _make_layer(self, block, planes, layers, stride=1):
|
512 |
+
"""A block with 'layers' layers
|
513 |
+
Args:
|
514 |
+
block (class): block type
|
515 |
+
planes (int): output channels = planes * expansion
|
516 |
+
layers (int): layer num in the block
|
517 |
+
stride (int): the first layer stride in the block
|
518 |
+
"""
|
519 |
+
downsample = None
|
520 |
+
if stride !=1 or self.inplanes != planes * block.expansion:
|
521 |
+
downsample = nn.Sequential(
|
522 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
523 |
+
nn.BatchNorm2d(planes * block.expansion),
|
524 |
+
)
|
525 |
+
layer = []
|
526 |
+
layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
|
527 |
+
self.inplanes = planes * block.expansion
|
528 |
+
for i in range(1, layers):
|
529 |
+
layer.append(block(self.inplanes, planes))
|
530 |
+
|
531 |
+
return nn.Sequential(*layer)
|
532 |
+
|
533 |
+
def forward(self, x):
|
534 |
+
all_bn_outputs = []
|
535 |
+
|
536 |
+
x = self.conv1(x)
|
537 |
+
x = self.bn1(x)
|
538 |
+
x = self.relu(x)
|
539 |
+
# x = self.maxpool(x)
|
540 |
+
|
541 |
+
x, bn_outputs = self.layer1(x)
|
542 |
+
all_bn_outputs.extend(bn_outputs)
|
543 |
+
middle_output1, _ = self.bottleneck1_1(x)
|
544 |
+
middle_output1 = self.avgpool1(middle_output1)
|
545 |
+
middle1_fea = middle_output1
|
546 |
+
middle_output1 = torch.flatten(middle_output1, 1)
|
547 |
+
middle_output1 = self.middle_fc1(middle_output1)
|
548 |
+
|
549 |
+
x, bn_outputs = self.layer2(x)
|
550 |
+
all_bn_outputs.extend(bn_outputs)
|
551 |
+
middle_output2, _ = self.bottleneck2_1(x)
|
552 |
+
middle_output2 = self.avgpool2(middle_output2)
|
553 |
+
middle2_fea = middle_output2
|
554 |
+
middle_output2 = torch.flatten(middle_output2, 1)
|
555 |
+
middle_output2 = self.middle_fc2(middle_output2)
|
556 |
+
|
557 |
+
x, bn_outputs = self.layer3(x)
|
558 |
+
all_bn_outputs.extend(bn_outputs)
|
559 |
+
middle_output3, _ = self.bottleneck3_1(x)
|
560 |
+
middle_output3 = self.avgpool3(middle_output3)
|
561 |
+
middle3_fea = middle_output3
|
562 |
+
middle_output3 = torch.flatten(middle_output3, 1)
|
563 |
+
middle_output3 = self.middle_fc3(middle_output3)
|
564 |
+
|
565 |
+
x, bn_outputs = self.layer4(x)
|
566 |
+
all_bn_outputs.extend(bn_outputs)
|
567 |
+
x = self.avgpool(x)
|
568 |
+
final_fea = x
|
569 |
+
x = torch.flatten(x, 1)
|
570 |
+
x = self.fc(x)
|
571 |
+
|
572 |
+
if self.position_all:
|
573 |
+
return {'outputs': [x, middle_output1, middle_output2, middle_output3],
|
574 |
+
'features': [final_fea, middle1_fea, middle2_fea, middle3_fea],
|
575 |
+
'bn_outputs': all_bn_outputs}
|
576 |
+
else:
|
577 |
+
return {'outputs': [x, middle_output3],
|
578 |
+
'features': [final_fea, middle1_fea, middle2_fea, middle3_fea],
|
579 |
+
'bn_outputs': all_bn_outputs}
|
580 |
+
|
581 |
+
class SDResNet_s(nn.Module):
|
582 |
+
"""
|
583 |
+
Resnet model small
|
584 |
+
|
585 |
+
Args:
|
586 |
+
block (class): block type, BasicBlock or BottlenetckBlock
|
587 |
+
layers (int list): layer num in each block
|
588 |
+
num_classes (int): class num
|
589 |
+
"""
|
590 |
+
|
591 |
+
def __init__(self, block, layers, num_classes=10):
|
592 |
+
super(SDResNet_s, self).__init__()
|
593 |
+
|
594 |
+
self.inplanes = 16
|
595 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
596 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes)
|
597 |
+
self.relu = nn.ReLU(inplace=True)
|
598 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
599 |
+
|
600 |
+
self.layer1 = self._make_layer(block, 16, layers[0])
|
601 |
+
self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
|
602 |
+
self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
|
603 |
+
|
604 |
+
self.downsample1_1 = nn.Sequential(
|
605 |
+
conv1x1(16 * block.expansion, 64 * block.expansion, stride=4),
|
606 |
+
nn.BatchNorm2d(64 * block.expansion),
|
607 |
+
)
|
608 |
+
self.bottleneck1_1 = branchBottleNeck(16 * block.expansion, 64 * block.expansion, kernel_size=4)
|
609 |
+
self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
|
610 |
+
self.middle_fc1 = nn.Linear(64 * block.expansion, num_classes)
|
611 |
+
|
612 |
+
|
613 |
+
self.downsample2_1 = nn.Sequential(
|
614 |
+
conv1x1(32 * block.expansion, 64 * block.expansion, stride=2),
|
615 |
+
nn.BatchNorm2d(64 * block.expansion),
|
616 |
+
)
|
617 |
+
self.bottleneck2_1 = branchBottleNeck(32 * block.expansion, 64 * block.expansion, kernel_size=2)
|
618 |
+
self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
|
619 |
+
self.middle_fc2 = nn.Linear(64 * block.expansion, num_classes)
|
620 |
+
|
621 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
622 |
+
self.fc = nn.Linear(64 * block.expansion, num_classes)
|
623 |
+
|
624 |
+
for m in self.modules():
|
625 |
+
if isinstance(m, nn.Conv2d):
|
626 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
627 |
+
elif isinstance(m, nn.BatchNorm2d):
|
628 |
+
nn.init.constant_(m.weight, 1)
|
629 |
+
nn.init.constant_(m.bias, 0)
|
630 |
+
|
631 |
+
def _make_layer(self, block, planes, layers, stride=1):
|
632 |
+
"""A block with 'layers' layers
|
633 |
+
Args:
|
634 |
+
block (class): block type
|
635 |
+
planes (int): output channels = planes * expansion
|
636 |
+
layers (int): layer num in the block
|
637 |
+
stride (int): the first layer stride in the block
|
638 |
+
"""
|
639 |
+
strides = [stride] + [1]*(layers-1)
|
640 |
+
layers = []
|
641 |
+
for stride in strides:
|
642 |
+
layers.append(block(self.inplanes, planes, stride))
|
643 |
+
self.inplanes = planes * block.expansion
|
644 |
+
|
645 |
+
return nn.Sequential(*layers)
|
646 |
+
|
647 |
+
def forward(self, x):
|
648 |
+
x = self.conv1(x)
|
649 |
+
x = self.bn1(x)
|
650 |
+
x = self.relu(x)
|
651 |
+
|
652 |
+
x = self.layer1(x)
|
653 |
+
middle_output1 = self.bottleneck1_1(x)
|
654 |
+
middle_output1 = self.avgpool1(middle_output1)
|
655 |
+
middle1_fea = middle_output1
|
656 |
+
middle_output1 = torch.flatten(middle_output1, 1)
|
657 |
+
middle_output1 = self.middle_fc1(middle_output1)
|
658 |
+
|
659 |
+
x = self.layer2(x)
|
660 |
+
middle_output2 = self.bottleneck2_1(x)
|
661 |
+
middle_output2 = self.avgpool2(middle_output2)
|
662 |
+
middle2_fea = middle_output2
|
663 |
+
middle_output2 = torch.flatten(middle_output2, 1)
|
664 |
+
middle_output2 = self.middle_fc2(middle_output2)
|
665 |
+
|
666 |
+
x = self.layer3(x)
|
667 |
+
x = self.avgpool(x)
|
668 |
+
final_fea = x
|
669 |
+
x = torch.flatten(x, 1)
|
670 |
+
x = self.fc(x)
|
671 |
+
|
672 |
+
return {'outputs': [x, middle_output1, middle_output2],
|
673 |
+
'features': [final_fea, middle1_fea, middle2_fea]}
|
674 |
+
|
675 |
+
def sdresnet18(num_classes=10, position_all=True):
|
676 |
+
return SDResNet(BasicBlock, [2,2,2,2], num_classes=num_classes, position_all=position_all)
|
677 |
+
|
678 |
+
def sdresnet34(num_classes=10, position_all=True):
|
679 |
+
return SDResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, position_all=position_all)
|
680 |
+
|
681 |
+
def sdresnet34_mlp(num_classes=10, position_all=True):
|
682 |
+
return SDResNet_mlp(BasicBlock, [3,4,6,3], num_classes=num_classes, position_all=position_all)
|
683 |
+
|
684 |
+
def sdresnet34_residual(num_classes=10, position_all=True):
|
685 |
+
return SDResNet_residual(BasicBlock, [3,4,6,3], num_classes=num_classes, position_all=position_all)
|
686 |
+
|
687 |
+
def sdresnet32(num_classes=10):
|
688 |
+
return SDResNet_s(BasicBlock_s, [5,5,5], num_classes=num_classes)
|
models/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ResNet_Imagenet import sdresnet50
|
2 |
+
from .ResNet_cifar import sdresnet18, sdresnet34, sdresnet32
|
3 |
+
from .ResNet_cifar import sdresnet34_mlp, sdresnet34_residual
|
4 |
+
from .InceptionResNetV2 import InceptionResNetV2
|
models/__pycache__/CNN.cpython-310.pyc
ADDED
Binary file (6.44 kB). View file
|
|
models/__pycache__/InceptionResNetV2.cpython-310.pyc
ADDED
Binary file (8.65 kB). View file
|
|
models/__pycache__/ResNet_Imagenet.cpython-310.pyc
ADDED
Binary file (8.04 kB). View file
|
|
models/__pycache__/ResNet_cifar.cpython-310.pyc
ADDED
Binary file (16 kB). View file
|
|
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (426 Bytes). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
pillow
|
3 |
+
matplotlib
|
4 |
+
scikit-learn
|
5 |
+
scipy
|
6 |
+
torch
|
7 |
+
torchvision
|
train_cifar_c2mt.py
ADDED
@@ -0,0 +1,807 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.backends.cudnn as cudnn
|
8 |
+
import random
|
9 |
+
import os
|
10 |
+
import argparse
|
11 |
+
import numpy as np
|
12 |
+
from PreResNet import *
|
13 |
+
from sklearn.mixture import GaussianMixture
|
14 |
+
import dataloader_cifar as dataloader
|
15 |
+
import matplotlib.pyplot as plt
|
16 |
+
import copy
|
17 |
+
import seaborn as sns
|
18 |
+
# from sklearn.mixture import GaussianMixture
|
19 |
+
from sklearn.cluster import KMeans
|
20 |
+
from sklearn.cluster import Birch
|
21 |
+
import matplotlib
|
22 |
+
|
23 |
+
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
|
24 |
+
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
|
25 |
+
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
|
26 |
+
parser.add_argument('--noise_mode', default='asym')
|
27 |
+
parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
|
28 |
+
parser.add_argument('--lambda_u', default=150, type=float, help='weight for unsupervised loss')
|
29 |
+
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
|
30 |
+
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
|
31 |
+
parser.add_argument('--num_epochs', default=300, type=int)
|
32 |
+
parser.add_argument('--r', default=0.3, type=float, help='noise ratio')
|
33 |
+
parser.add_argument('--id', default='')
|
34 |
+
parser.add_argument('--seed', default=123)
|
35 |
+
parser.add_argument('--gpuid', default=0, type=int)
|
36 |
+
parser.add_argument('--num_class', default=100, type=int)
|
37 |
+
# parser.add_argument('--data_path', default='./data/cifar-10-batches-py', type=str, help='path to dataset')
|
38 |
+
# parser.add_argument('--dataset', default='cifar10', type=str)
|
39 |
+
parser.add_argument('--data_path', default='./data/cifar-100-python', type=str, help='path to dataset')
|
40 |
+
parser.add_argument('--dataset', default='cifar100', type=str)
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
torch.cuda.set_device(args.gpuid)
|
44 |
+
random.seed(args.seed)
|
45 |
+
torch.manual_seed(args.seed)
|
46 |
+
torch.cuda.manual_seed_all(args.seed)
|
47 |
+
|
48 |
+
mse = torch.nn.MSELoss(reduction='none').cuda()
|
49 |
+
|
50 |
+
|
51 |
+
# Training
|
52 |
+
def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader, mask=None, f_G=None, new_y=None):
|
53 |
+
net.train()
|
54 |
+
net2.eval() # fix one network and train the other
|
55 |
+
|
56 |
+
unlabeled_train_iter = iter(unlabeled_trainloader)
|
57 |
+
num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
|
58 |
+
mse_total = 0
|
59 |
+
for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
|
60 |
+
try:
|
61 |
+
inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
|
62 |
+
except:
|
63 |
+
unlabeled_train_iter = iter(unlabeled_trainloader)
|
64 |
+
inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
|
65 |
+
batch_size = inputs_x.size(0)
|
66 |
+
|
67 |
+
# Transform label to one-hot,转为0-1矩阵
|
68 |
+
labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
|
69 |
+
w_x = w_x.view(-1, 1).type(torch.FloatTensor)
|
70 |
+
|
71 |
+
inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
|
72 |
+
inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
|
73 |
+
|
74 |
+
with torch.no_grad():
|
75 |
+
# label co-guessing of unlabeled samples
|
76 |
+
outputs_u11, feat_u11 = net(inputs_u, feat_out=True)
|
77 |
+
outputs_u12, feat_u12 = net(inputs_u2, feat_out=True)
|
78 |
+
outputs_u21, feat_u21 = net2(inputs_u, feat_out=True)
|
79 |
+
outputs_u22, feat_u22 = net2(inputs_u2, feat_out=True)
|
80 |
+
|
81 |
+
# 取average of 所有网络的输出,作者利用了所谓的augmentation
|
82 |
+
pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1)
|
83 |
+
+ torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
|
84 |
+
ptu = pu ** (1 / args.T) # temparature sharpening
|
85 |
+
|
86 |
+
# Algorithm 1 中的shapen(qb,T)
|
87 |
+
targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
|
88 |
+
targets_u = targets_u.detach()
|
89 |
+
|
90 |
+
# label refinement of labeled samples
|
91 |
+
outputs_x, feat_x1 = net(inputs_x, feat_out=True)
|
92 |
+
outputs_x2, feat_x2 = net(inputs_x2, feat_out=True)
|
93 |
+
|
94 |
+
# 取labeled的输出平均值
|
95 |
+
px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
|
96 |
+
|
97 |
+
# 公式(3)(4)退火
|
98 |
+
px = w_x * labels_x + (1 - w_x) * px
|
99 |
+
ptx = px ** (1 / args.T) # temparature sharpening
|
100 |
+
|
101 |
+
targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
|
102 |
+
targets_x = targets_x.detach()
|
103 |
+
# aaa = torch.argmax(labels_x, dim=1)
|
104 |
+
# mse_loss = torch.sum(mse((feat_x1+feat_x2)/2, f_G[aaa]), 1)
|
105 |
+
# mse_total = (mse_total + torch.sum(mse_loss) / len(mse_loss))/2
|
106 |
+
# mixmatch
|
107 |
+
l = np.random.beta(args.alpha, args.alpha)
|
108 |
+
# 促使X'更加靠近labeled sample而不是无监督样本
|
109 |
+
l = max(l, 1 - l)
|
110 |
+
|
111 |
+
all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
|
112 |
+
all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
|
113 |
+
|
114 |
+
# 随机输出mini batch的序号,来mixup
|
115 |
+
idx = torch.randperm(all_inputs.size(0))
|
116 |
+
|
117 |
+
input_a, input_b = all_inputs, all_inputs[idx]
|
118 |
+
target_a, target_b = all_targets, all_targets[idx]
|
119 |
+
|
120 |
+
# 利用mix但是促使模型更偏向于label而不是UNlabel
|
121 |
+
mixed_input = l * input_a + (1 - l) * input_b
|
122 |
+
mixed_target = l * target_a + (1 - l) * target_b
|
123 |
+
|
124 |
+
logits = net(mixed_input)
|
125 |
+
# 输出被排列成两部分,input_x、Input_u
|
126 |
+
logits_x = logits[:batch_size * 2]
|
127 |
+
logits_u = logits[batch_size * 2:]
|
128 |
+
|
129 |
+
# 利用公式(9)-(10)计算损失函数,其中lamb是所谓的warm up
|
130 |
+
Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size * 2],
|
131 |
+
logits_u, mixed_target[batch_size * 2:],
|
132 |
+
epoch + batch_idx / num_iter, warm_up)
|
133 |
+
|
134 |
+
# regularization
|
135 |
+
prior = torch.ones(args.num_class) / args.num_class
|
136 |
+
prior = prior.cuda()
|
137 |
+
pred_mean = torch.softmax(logits, dim=1).mean(0)
|
138 |
+
# 一般来说会省略固定的prior部分,只取last term
|
139 |
+
# lambR=1
|
140 |
+
penalty = torch.sum(prior * torch.log(prior / pred_mean))
|
141 |
+
|
142 |
+
# lamb是通过warm和current epoch比较得出的百分数,意味着随着epoch进行,Lu所占比重会逐渐增加
|
143 |
+
# 前期需要保持标准CE损失,但是实际还有penalty
|
144 |
+
# loss = Lx + lamb * Lu + penalty
|
145 |
+
loss = Lx + penalty + lamb * Lu
|
146 |
+
# compute gradient and do SGD step
|
147 |
+
optimizer.zero_grad()
|
148 |
+
loss.backward()
|
149 |
+
optimizer.step()
|
150 |
+
if batch_idx % 200 == 0:
|
151 |
+
sys.stdout.write('\r')
|
152 |
+
sys.stdout.write(
|
153 |
+
'%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f Unlabeled loss: %.2f\n'
|
154 |
+
% (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx + 1, num_iter,
|
155 |
+
Lx.item(), Lu.item()))
|
156 |
+
sys.stdout.flush()
|
157 |
+
# print('\r mse loss:%.4f\n' % mse_total, end='end', flush=True)
|
158 |
+
# print('\r mse loss:%.4f\n' % mse_total, end='end', flush=True)
|
159 |
+
|
160 |
+
def mixup_criterion(pred, y_a, y_b, lam):
|
161 |
+
c = F.log_softmax(pred, 1)
|
162 |
+
return lam * F.cross_entropy(c, y_a) + (1 - lam) * F.cross_entropy(c, y_b)
|
163 |
+
|
164 |
+
|
165 |
+
soft_mix_warm = False
|
166 |
+
|
167 |
+
def warmup(epoch, net, optimizer, dataloader):
|
168 |
+
net.train()
|
169 |
+
num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1
|
170 |
+
for batch_idx, (inputs, labels, path) in enumerate(dataloader):
|
171 |
+
optimizer.zero_grad()
|
172 |
+
l = np.random.beta(args.alpha, args.alpha)
|
173 |
+
# 促使X'更加靠近labeled sample而不是无监督样本
|
174 |
+
l = max(l, 1 - l)
|
175 |
+
idx = torch.randperm(inputs.size(0))
|
176 |
+
targets = torch.zeros(inputs.size(0), args.num_class).scatter_(1, labels.view(-1, 1), 1).cuda()
|
177 |
+
targets = torch.clamp(targets, 1e-4, 1.)
|
178 |
+
inputs, labels = inputs.cuda(), labels.cuda()
|
179 |
+
if soft_mix_warm:
|
180 |
+
input_a, input_b = inputs, inputs[idx]
|
181 |
+
target_a, target_b = targets, targets[idx]
|
182 |
+
labels_a, labels_b = labels, labels[idx]
|
183 |
+
|
184 |
+
# 利用mix但是促使模型更偏向于label而不是UNlabel
|
185 |
+
mixed_input = l * input_a + (1 - l) * input_b
|
186 |
+
mixed_target = l * target_a + (1 - l) * target_b
|
187 |
+
|
188 |
+
outputs = net(mixed_input)
|
189 |
+
loss = mixup_criterion(outputs, labels_a, labels_b, l)
|
190 |
+
L = loss
|
191 |
+
else:
|
192 |
+
outputs = net(inputs)
|
193 |
+
loss = CEloss(outputs, labels)
|
194 |
+
if args.noise_mode == 'asym': # penalize confident prediction for asymmetric noise
|
195 |
+
penalty = conf_penalty(outputs)
|
196 |
+
L = loss + penalty
|
197 |
+
elif args.noise_mode == 'sym':
|
198 |
+
L = loss
|
199 |
+
L.backward()
|
200 |
+
optimizer.step()
|
201 |
+
if batch_idx % 200 == 0:
|
202 |
+
sys.stdout.write('\r')
|
203 |
+
sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f'
|
204 |
+
% (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx + 1, num_iter,
|
205 |
+
loss.item()))
|
206 |
+
sys.stdout.flush()
|
207 |
+
|
208 |
+
|
209 |
+
def test(epoch, net1, net2, best_acc, w_glob=None):
|
210 |
+
if w_glob is None:
|
211 |
+
net1.eval()
|
212 |
+
net2.eval()
|
213 |
+
correct = 0
|
214 |
+
total = 0
|
215 |
+
with torch.no_grad():
|
216 |
+
for batch_idx, (inputs, targets) in enumerate(test_loader):
|
217 |
+
inputs, targets = inputs.cuda(), targets.cuda()
|
218 |
+
outputs1 = net1(inputs)
|
219 |
+
outputs2 = net2(inputs)
|
220 |
+
outputs = outputs1 + outputs2
|
221 |
+
_, predicted = torch.max(outputs, 1)
|
222 |
+
|
223 |
+
total += targets.size(0)
|
224 |
+
correct += predicted.eq(targets).cpu().sum().item()
|
225 |
+
acc = 100. * correct / total
|
226 |
+
if best_acc < acc:
|
227 |
+
best_acc = acc
|
228 |
+
print("\n| Ensemble network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
|
229 |
+
test_log.write('ensemble_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
|
230 |
+
test_log.flush()
|
231 |
+
else:
|
232 |
+
net1_w_bak = net1.state_dict()
|
233 |
+
net1.load_state_dict(w_glob)
|
234 |
+
net1.eval()
|
235 |
+
correct = 0
|
236 |
+
total = 0
|
237 |
+
with torch.no_grad():
|
238 |
+
for batch_idx, (inputs, targets) in enumerate(test_loader):
|
239 |
+
inputs, targets = inputs.cuda(), targets.cuda()
|
240 |
+
outputs1 = net1(inputs)
|
241 |
+
_, predicted = torch.max(outputs1, 1)
|
242 |
+
total += targets.size(0)
|
243 |
+
correct += predicted.eq(targets).cpu().sum().item()
|
244 |
+
acc = 100. * correct / total
|
245 |
+
if best_acc < acc:
|
246 |
+
best_acc = acc
|
247 |
+
print("\n| Global network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
|
248 |
+
test_log.write('global_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
|
249 |
+
test_log.flush()
|
250 |
+
# 恢复权重
|
251 |
+
net1.load_state_dict(net1_w_bak)
|
252 |
+
return best_acc
|
253 |
+
|
254 |
+
feat_dim = 512 #是否可以加个全连接改成128
|
255 |
+
sim = torch.nn.CosineSimilarity(dim=1)
|
256 |
+
|
257 |
+
loss_func = torch.nn.CrossEntropyLoss(reduction='none')
|
258 |
+
def get_small_loss_samples(y_pred, y_true, forget_rate):
|
259 |
+
loss = loss_func(y_pred, y_true)
|
260 |
+
ind_sorted = np.argsort(loss.data.cpu()).cuda()
|
261 |
+
loss_sorted = loss[ind_sorted]
|
262 |
+
|
263 |
+
remember_rate = 1 - forget_rate
|
264 |
+
num_remember = int(remember_rate * len(loss_sorted))
|
265 |
+
|
266 |
+
ind_update = ind_sorted[:num_remember]
|
267 |
+
|
268 |
+
return ind_update
|
269 |
+
|
270 |
+
def get_small_loss_by_loss_list(loss_list, forget_rate, eval_loader):
|
271 |
+
remember_rate = 1 - forget_rate
|
272 |
+
idx_list = []
|
273 |
+
for i in range(10):
|
274 |
+
class_idx = np.where(np.array(eval_loader.dataset.noise_label)[:] == i)[0]
|
275 |
+
# class_idx = torch.from_numpy(class_idx).cuda()
|
276 |
+
loss_per_class = loss_list[class_idx] #取对应target的loss
|
277 |
+
num_remember = int(remember_rate * len(loss_per_class))
|
278 |
+
ind_sorted = np.argsort(loss_per_class.data.cpu())
|
279 |
+
ind_update = ind_sorted[:num_remember].tolist()
|
280 |
+
idx_list.append(ind_update)
|
281 |
+
|
282 |
+
return idx_list
|
283 |
+
|
284 |
+
def eval_train(model, all_loss):
|
285 |
+
model.eval()
|
286 |
+
losses = torch.zeros(50000)
|
287 |
+
f_G = torch.zeros(args.num_class, feat_dim).cuda()
|
288 |
+
f_all = torch.zeros(50000, feat_dim).cuda()
|
289 |
+
n_labels = torch.zeros(args.num_class, 1).cuda()
|
290 |
+
y_k_tilde = torch.zeros(50000)
|
291 |
+
mask = np.zeros(50000)
|
292 |
+
with torch.no_grad():
|
293 |
+
for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
|
294 |
+
inputs, targets = inputs.cuda(), targets.cuda()
|
295 |
+
outputs, feat = model(inputs, feat_out=True)
|
296 |
+
loss = CE(outputs, targets)
|
297 |
+
_, predicted = torch.max(outputs, 1)
|
298 |
+
for b in range(inputs.size(0)):
|
299 |
+
losses[index[b]] = loss[b]
|
300 |
+
f_G[predicted[b]] += feat[b]
|
301 |
+
n_labels[predicted[b]] += 1
|
302 |
+
f_all[index] = feat
|
303 |
+
assert torch.sum(n_labels) == 50000
|
304 |
+
for i in range(len(n_labels)):
|
305 |
+
if n_labels[i] == 0:
|
306 |
+
n_labels[i] = 1
|
307 |
+
f_G = torch.div(f_G, n_labels)
|
308 |
+
f_G = F.normalize(f_G, dim=1)
|
309 |
+
f_all = F.normalize(f_all, dim=1)
|
310 |
+
temp = f_G.t()
|
311 |
+
sim_all = torch.mm(f_all, temp) # .cpu().numpy()
|
312 |
+
y_k_tilde = torch.argmax(sim_all.cpu(), dim=1)
|
313 |
+
with torch.no_grad():
|
314 |
+
for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
|
315 |
+
for i in range(len(index)):
|
316 |
+
if y_k_tilde[index[i]] == targets[i]:
|
317 |
+
mask[index[i]] = 1
|
318 |
+
losses = (losses - losses.min()) / (losses.max() - losses.min())
|
319 |
+
all_loss.append(losses)
|
320 |
+
|
321 |
+
if args.r == 0.9:
|
322 |
+
# average loss over last 5 epochs to improve convergence stability
|
323 |
+
history = torch.stack(all_loss)
|
324 |
+
input_loss = history[-5:].mean(0)
|
325 |
+
input_loss = input_loss.reshape(-1, 1)
|
326 |
+
else:
|
327 |
+
input_loss = losses.reshape(-1, 1)
|
328 |
+
|
329 |
+
# fit a two-component GMM to the loss
|
330 |
+
# 参数如下:
|
331 |
+
# n_components 聚类数量,max_iter 最大迭代次数,tol 阈值低于停止,reg_covar 协方差矩阵对角线上非负正则化参数,接近0即可
|
332 |
+
gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
|
333 |
+
gmm.fit(input_loss)
|
334 |
+
prob = gmm.predict_proba(input_loss)
|
335 |
+
prob = prob[:, gmm.means_.argmin()]
|
336 |
+
return prob, all_loss, losses.numpy(), mask, f_G
|
337 |
+
|
338 |
+
def mix_data_lab(x, y, alpha=1.0):
|
339 |
+
'''Returns mixed inputs, pairs of targets, and lambda'''
|
340 |
+
if alpha > 0:
|
341 |
+
lam = np.random.beta(alpha, alpha)
|
342 |
+
else:
|
343 |
+
lam = 1
|
344 |
+
|
345 |
+
batch_size = x.size()[0]
|
346 |
+
index = torch.randperm(batch_size).cuda()
|
347 |
+
|
348 |
+
lam = max(lam, 1 - lam)
|
349 |
+
mixed_x = lam * x + (1 - lam) * x[index, :]
|
350 |
+
y_a, y_b = y, y[index]
|
351 |
+
|
352 |
+
return mixed_x, y_a, y_b, index, lam
|
353 |
+
|
354 |
+
|
355 |
+
def linear_rampup(current, warm_up, rampup_length=16):
|
356 |
+
# 线性warm_up,对sym噪声使用标准CE训练一段时间
|
357 |
+
# 实际warm up epoch是warm_up+rampup_length
|
358 |
+
current = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
|
359 |
+
re_val = args.lambda_u * float(current)
|
360 |
+
# print(" current warm up parameters:", current)
|
361 |
+
# print("return parameters:", re_val)
|
362 |
+
return re_val
|
363 |
+
|
364 |
+
|
365 |
+
class SemiLoss(object):
|
366 |
+
def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
|
367 |
+
probs_u = torch.softmax(outputs_u, dim=1)
|
368 |
+
|
369 |
+
# 利用mixup后的交叉熵,px输出*log(px_model)
|
370 |
+
Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
|
371 |
+
# 而UNlabel则是均方误差,p_u输出-pu_model
|
372 |
+
Lu = torch.mean((probs_u - targets_u) ** 2)
|
373 |
+
|
374 |
+
return Lx, Lu, linear_rampup(epoch, warm_up)
|
375 |
+
|
376 |
+
|
377 |
+
class NegEntropy(object):
|
378 |
+
def __call__(self, outputs):
|
379 |
+
probs = torch.softmax(outputs, dim=1)
|
380 |
+
return torch.mean(torch.sum(probs.log() * probs, dim=1))
|
381 |
+
|
382 |
+
|
383 |
+
def create_model():
|
384 |
+
# 其实是pre-resnet18,使用的是pre-resnet block
|
385 |
+
model = ResNet18(num_classes=args.num_class)
|
386 |
+
model = model.cuda()
|
387 |
+
return model
|
388 |
+
|
389 |
+
def plotHistogram(model_1_loss, model_2_loss, noise_index, clean_index, epoch, round, noise_rate):
|
390 |
+
title = 'Epoch-' + str(epoch)+':'
|
391 |
+
fig = plt.figure()
|
392 |
+
plt.subplot(121)
|
393 |
+
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
|
394 |
+
model_1_loss = np.reshape(model_1_loss, (-1, 1))
|
395 |
+
gmm.fit(model_1_loss) # fit the loss
|
396 |
+
|
397 |
+
# plot resulting fit
|
398 |
+
x_range = np.linspace(0, 1, 1000)
|
399 |
+
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
|
400 |
+
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
|
401 |
+
pdf_individual = responsibilities * pdf[:, np.newaxis]
|
402 |
+
plt.hist(np.array(model_1_loss[noise_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='red', label='Noisy subset')
|
403 |
+
plt.hist(np.array(model_1_loss[clean_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='blue', label='Clean subset')
|
404 |
+
plt.plot(x_range, pdf, '-k', label='Mixture')
|
405 |
+
plt.plot(x_range, pdf_individual, '--', label='Component')
|
406 |
+
plt.legend(loc='upper right', prop={'size': 12})
|
407 |
+
plt.xlabel('Normalized loss')
|
408 |
+
plt.ylabel('Estimated pdf')
|
409 |
+
plt.title(title+'Model_1')
|
410 |
+
|
411 |
+
plt.subplot(122)
|
412 |
+
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
|
413 |
+
model_2_loss = np.reshape(model_2_loss, (-1, 1))
|
414 |
+
gmm.fit(model_2_loss) # fit the loss
|
415 |
+
|
416 |
+
# plot resulting fit
|
417 |
+
x_range = np.linspace(0, 1, 1000)
|
418 |
+
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
|
419 |
+
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
|
420 |
+
pdf_individual = responsibilities * pdf[:, np.newaxis]
|
421 |
+
plt.hist(np.array(model_2_loss[noise_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='red', label='Noisy subset')
|
422 |
+
plt.hist(np.array(model_2_loss[clean_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='blue', label='Clean subset')
|
423 |
+
plt.plot(x_range, pdf, '-k', label='Mixture')
|
424 |
+
plt.plot(x_range, pdf_individual, '--', label='Component')
|
425 |
+
plt.legend(loc='upper right', prop={'size': 12})
|
426 |
+
plt.xlabel('Normalized loss')
|
427 |
+
plt.ylabel('Estimated pdf')
|
428 |
+
plt.title(title+'Model_2')
|
429 |
+
|
430 |
+
print('\nlogging histogram...')
|
431 |
+
title = 'cifar10_' + str(args.noise_mode) + '_moit_double_' + str(noise_rate)
|
432 |
+
plt.savefig(os.path.join('./figure_his/', 'two_model_{}_{}_{}_{}.{}'.format(epoch, round, title, int(soft_mix_warm), ".tif")), dpi=300)
|
433 |
+
# plt.show()
|
434 |
+
plt.close()
|
435 |
+
|
436 |
+
|
437 |
+
def loss_dist_plot(loss, noisy_index, clean_index, epoch, rou=None, g_file=True, model_name='', loss2=None):
|
438 |
+
"""
|
439 |
+
plot the loss distribution
|
440 |
+
:param loss: the list contains the loss per sample
|
441 |
+
:param noisy_index: contains the indices of real noisy label
|
442 |
+
:param clean_index: contains the indices of real clean label
|
443 |
+
:param filename: the generated pdf file name
|
444 |
+
:param title: the figure title
|
445 |
+
:param g_file: whether to generate the pdf figure file
|
446 |
+
:return: None
|
447 |
+
"""
|
448 |
+
if loss2 is None:
|
449 |
+
filename = 'one_model_'+str(args.dataset)+'_'+str(args.noise_mode)+'_'+str(args.r)+'_epoch='+str(epoch)
|
450 |
+
if rou is None:
|
451 |
+
title = 'Epoch-'+str(epoch) + ': ' + str(args.dataset)+' '+str(args.r*100)+'%-'+str(args.noise_mode)
|
452 |
+
else:
|
453 |
+
title = 'Epoch-' + str(epoch) + ' ' +'Round-'+str(rou)+ ': ' + str(args.dataset) + ' ' + str(int(args.r * 100)) + '%-' + str(args.noise_mode)
|
454 |
+
if type(loss) is not np.ndarray:
|
455 |
+
loss= loss.numpy()
|
456 |
+
sns.set(style='whitegrid')
|
457 |
+
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
|
458 |
+
loss = np.reshape(loss, (-1, 1))
|
459 |
+
gmm.fit(loss) # fit the loss
|
460 |
+
|
461 |
+
# plot resulting fit
|
462 |
+
x_range = np.linspace(0, 1, 1000)
|
463 |
+
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
|
464 |
+
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
|
465 |
+
pdf_individual = responsibilities * pdf[:, np.newaxis]
|
466 |
+
# sns.distplot(loss[noisy_index], color="red", rug=False,kde=False, label="incorrect",
|
467 |
+
# hist_kws={"color": "r", "alpha": 0.5})
|
468 |
+
# sns.distplot(loss[clean_index], color="skyblue", rug=False,kde=False, label="correct",
|
469 |
+
# hist_kws={"color": "b", "alpha": 0.5})
|
470 |
+
|
471 |
+
plt.hist(np.array(loss[noisy_index]), density=True, bins=100, histtype='bar', alpha=0.5, color='red',
|
472 |
+
label='Noisy subset')
|
473 |
+
plt.hist(np.array(loss[clean_index]), density=True, bins=100, histtype='bar', alpha=0.5, color='blue',
|
474 |
+
label='Clean subset')
|
475 |
+
plt.plot(x_range, pdf, '-k', label='Mixture')
|
476 |
+
plt.plot(x_range, pdf_individual, '--', label='Component')
|
477 |
+
# plt.plot(x_range, pdf_individual[:][1], '--', color='blue', label='Component 1')
|
478 |
+
|
479 |
+
plt.title(title, fontsize=20)
|
480 |
+
plt.xlabel('Normalized loss', fontsize=24)
|
481 |
+
plt.ylabel('Estimated pdf', fontsize=24)
|
482 |
+
|
483 |
+
plt.tick_params(labelsize=24)
|
484 |
+
plt.legend(loc='upper right', prop={'size': 12})
|
485 |
+
# plt.tight_layout()
|
486 |
+
if g_file:
|
487 |
+
plt.savefig('./figure_his/{0}.tif'.format(filename+model_name), bbox_inches='tight', dpi=300)
|
488 |
+
#plt.show()
|
489 |
+
plt.close()
|
490 |
+
else:
|
491 |
+
filename = 'noise_'+str(args.dataset) + '_' + str(args.noise_mode) + '_' + str(args.r) + '_epoch=' + str(epoch)
|
492 |
+
if rou is None:
|
493 |
+
title = 'Epoch-' + str(epoch) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(
|
494 |
+
args.noise_mode)
|
495 |
+
else:
|
496 |
+
title = 'Epoch-' + str(epoch) + ' ' + 'Round-' + str(rou) + ': ' + str(args.dataset) + ' ' + str(
|
497 |
+
args.r * 100) + '%-' + str(args.noise_mode)
|
498 |
+
if type(loss) is not np.ndarray:
|
499 |
+
loss = loss.numpy()
|
500 |
+
if type(loss2) is not np.ndarray:
|
501 |
+
loss2 = loss2.numpy()
|
502 |
+
fig = plt.figure()
|
503 |
+
plt.subplot(121)
|
504 |
+
sns.set(style='whitegrid')
|
505 |
+
sns.distplot(loss[noisy_index], color="red", rug=False, kde=False, label="incorrect",
|
506 |
+
hist_kws={"color": "r", "alpha": 0.5})
|
507 |
+
sns.distplot(loss[clean_index], color="skyblue", rug=False, kde=False, label="correct",
|
508 |
+
hist_kws={"color": "b", "alpha": 0.5})
|
509 |
+
plt.title('Model_1', fontsize=32)
|
510 |
+
plt.xlabel('Normalized loss', fontsize=32)
|
511 |
+
plt.ylabel('Sample number', fontsize=32)
|
512 |
+
plt.tick_params(labelsize=32)
|
513 |
+
plt.legend(loc='upper right', prop={'size': 24})
|
514 |
+
plt.subplot(122)
|
515 |
+
sns.set(style='whitegrid')
|
516 |
+
sns.distplot(loss2[noisy_index], color="red", rug=False, kde=False, label="incorrect",
|
517 |
+
hist_kws={"color": "r", "alpha": 0.5})
|
518 |
+
sns.distplot(loss2[clean_index], color="skyblue", rug=False, kde=False, label="correct",
|
519 |
+
hist_kws={"color": "b", "alpha": 0.5})
|
520 |
+
plt.title('Model_2', fontsize=32)
|
521 |
+
plt.xlabel('Normalized loss', fontsize=32)
|
522 |
+
plt.ylabel('Sample number', fontsize=32)
|
523 |
+
plt.tick_params(labelsize=32)
|
524 |
+
plt.legend(loc='upper right', prop={'size': 24})
|
525 |
+
# plt.tight_layout()
|
526 |
+
if g_file:
|
527 |
+
plt.savefig('./figure_his/{0}.tif'.format(filename + model_name), bbox_inches='tight', dpi=300)
|
528 |
+
# plt.show()
|
529 |
+
plt.close()
|
530 |
+
|
531 |
+
|
532 |
+
def loss_dist_plot_real(loss, epoch, rou=None, g_file=True, model_name=''):
|
533 |
+
"""
|
534 |
+
plot the loss distribution
|
535 |
+
:param loss: the list contains the loss per sample
|
536 |
+
:param noisy_index: contains the indices of real noisy label
|
537 |
+
:param clean_index: contains the indices of real clean label
|
538 |
+
:param filename: the generated pdf file name
|
539 |
+
:param title: the figure title
|
540 |
+
:param g_file: whether to generate the pdf figure file
|
541 |
+
:return: None
|
542 |
+
"""
|
543 |
+
filename = str(args.dataset) + '_' + str(args.noise_mode) + '_' + str(args.r) + '_epoch=' + str(epoch)
|
544 |
+
if rou is None:
|
545 |
+
title = 'Epoch-' + str(epoch) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(args.noise_mode)
|
546 |
+
else:
|
547 |
+
title = 'Epoch-' + str(epoch) + ' ' + 'Round-' + str(rou) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(args.noise_mode)
|
548 |
+
|
549 |
+
if type(loss) is not np.ndarray:
|
550 |
+
loss= loss.numpy()
|
551 |
+
sns.set(style='whitegrid')
|
552 |
+
|
553 |
+
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
|
554 |
+
loss = np.reshape(loss, (-1, 1))
|
555 |
+
gmm.fit(loss) # fit the loss
|
556 |
+
|
557 |
+
# plot resulting fit
|
558 |
+
x_range = np.linspace(0, 1, 1000)
|
559 |
+
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
|
560 |
+
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
|
561 |
+
pdf_individual = responsibilities * pdf[:, np.newaxis]
|
562 |
+
|
563 |
+
plt.hist(loss, bins=60, density=True, histtype='bar', alpha=0.3)
|
564 |
+
plt.plot(x_range, pdf, '-k', label='Mixture')
|
565 |
+
plt.plot(x_range, pdf_individual, '--', label='Component')
|
566 |
+
plt.legend()
|
567 |
+
# plt.tight_layout()
|
568 |
+
|
569 |
+
plt.title(title, fontsize=32)
|
570 |
+
plt.xlabel('Normalized loss', fontsize=32)
|
571 |
+
plt.ylabel('Estimated PDF', fontsize=32)
|
572 |
+
plt.tick_params(labelsize=32)
|
573 |
+
plt.legend(loc='upper right', prop={'size': 22})
|
574 |
+
if g_file:
|
575 |
+
plt.savefig('./figure_his/{0}.tif'.format(filename+model_name), bbox_inches='tight', dpi=300)
|
576 |
+
#plt.show()
|
577 |
+
plt.close()
|
578 |
+
|
579 |
+
|
580 |
+
def FedAvg(w):
|
581 |
+
w_avg = copy.deepcopy(w[0])
|
582 |
+
for k in w_avg.keys():
|
583 |
+
for i in range(1, len(w)):
|
584 |
+
w_avg[k] += w[i][k]
|
585 |
+
# 只考虑iid noise的话,每个client训练样本数一样,所以不用做nk/n
|
586 |
+
w_avg[k] = torch.div(w_avg[k], len(w))
|
587 |
+
|
588 |
+
return w_avg
|
589 |
+
|
590 |
+
|
591 |
+
if os.path.exists('checkpoint') == False:
|
592 |
+
os.mkdir('checkpoint')
|
593 |
+
print("新建日志文件夹")
|
594 |
+
stats_log = open('./checkpoint/single_%s_%.1f_%s_%d' % (args.dataset, args.r, args.noise_mode,
|
595 |
+
int(soft_mix_warm)) + '_stats.txt', 'w')
|
596 |
+
test_log = open('./checkpoint/single_%s_%.1f_%s_%d' % (args.dataset, args.r, args.noise_mode,
|
597 |
+
int(soft_mix_warm)) + '_acc.txt', 'w')
|
598 |
+
|
599 |
+
warm_up = 10
|
600 |
+
dmix_epoch = 150
|
601 |
+
args.num_epochs = dmix_epoch + 150
|
602 |
+
# 第6页提及的warm up的epoch
|
603 |
+
if args.dataset == 'cifar10':
|
604 |
+
warm_up = 10
|
605 |
+
dmix_epoch = 150
|
606 |
+
args.num_epochs = dmix_epoch + 50
|
607 |
+
elif args.dataset == 'cifar100':
|
608 |
+
warm_up = 30
|
609 |
+
dmix_epoch = 150
|
610 |
+
args.num_epochs = dmix_epoch + 50
|
611 |
+
|
612 |
+
loader = dataloader.cifar_dataloader(args.dataset, r=args.r, noise_mode=args.noise_mode,
|
613 |
+
batch_size=args.batch_size, num_workers=0,
|
614 |
+
root_dir=args.data_path, log=stats_log,
|
615 |
+
noise_file='%s/%.1f_%s.json' % (args.data_path, args.r, args.noise_mode))
|
616 |
+
|
617 |
+
print('| Building net')
|
618 |
+
net1 = create_model()
|
619 |
+
net2 = create_model()
|
620 |
+
cudnn.benchmark = True
|
621 |
+
|
622 |
+
criterion = SemiLoss()
|
623 |
+
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
|
624 |
+
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
|
625 |
+
|
626 |
+
CE = nn.CrossEntropyLoss(reduction='none')
|
627 |
+
CEloss = nn.CrossEntropyLoss()
|
628 |
+
if args.noise_mode == 'asym':
|
629 |
+
# 本文第一个问题,对于非对称和对称需要不同措施,这很不适用
|
630 |
+
# 其次本文在不同步骤中噪声数据处理措施很凌乱
|
631 |
+
conf_penalty = NegEntropy()
|
632 |
+
|
633 |
+
all_loss = [[], []] # save the history of losses from two networks
|
634 |
+
|
635 |
+
local_round = 5
|
636 |
+
first = True
|
637 |
+
balance_crit = 'median'
|
638 |
+
exp_path = './checkpoint/single_%s_%.1f_%s_double_m2_' % (args.dataset, args.r, args.noise_mode)
|
639 |
+
save_clean_idx = exp_path + "clean_idx.npy"
|
640 |
+
boot_loader = None
|
641 |
+
w_glob = None
|
642 |
+
if args.r == 0.9:
|
643 |
+
args.p_threshold = 0.6
|
644 |
+
best_en_acc = 0.
|
645 |
+
best_gl_acc = 0.
|
646 |
+
resume_epoch = 0
|
647 |
+
if resume_epoch > 0:
|
648 |
+
snapLast = exp_path + str(resume_epoch-1) + "_global_model.pth"
|
649 |
+
global_state = torch.load(snapLast)
|
650 |
+
# 先更新还是后跟新
|
651 |
+
w_glob = global_state
|
652 |
+
net1.load_state_dict(global_state)
|
653 |
+
net2.load_state_dict(global_state)
|
654 |
+
for epoch in range(resume_epoch, args.num_epochs + 1):
|
655 |
+
test_loader = loader.run('test')
|
656 |
+
eval_loader = loader.run('eval_train')
|
657 |
+
lr = args.lr
|
658 |
+
if epoch >= dmix_epoch:
|
659 |
+
lr /= 10
|
660 |
+
for param_group in optimizer1.param_groups:
|
661 |
+
param_group['lr'] = lr
|
662 |
+
for param_group in optimizer2.param_groups:
|
663 |
+
param_group['lr'] = lr
|
664 |
+
|
665 |
+
noise_ind, clean_ind = eval_loader.dataset.if_noise()
|
666 |
+
print(len(np.where(np.array(eval_loader.dataset.noise_label) != np.array(eval_loader.dataset.clean_label))[0])
|
667 |
+
/ len(eval_loader.dataset.clean_label))
|
668 |
+
local_weights = []
|
669 |
+
if epoch < warm_up:
|
670 |
+
# 考虑warm up时是否需要merge
|
671 |
+
warmup_trainloader = loader.run('warmup')
|
672 |
+
print('Warmup Net1')
|
673 |
+
warmup(epoch, net1, optimizer1, warmup_trainloader)
|
674 |
+
print('\nWarmup Net2')
|
675 |
+
warmup(epoch, net2, optimizer2, warmup_trainloader)
|
676 |
+
if epoch == (warm_up-1):
|
677 |
+
snapLast = exp_path+str(epoch) + "_1_model.pth"
|
678 |
+
torch.save(net1.state_dict(), snapLast)
|
679 |
+
snapLast = exp_path+str(epoch) + "_2_model.pth"
|
680 |
+
torch.save(net1.state_dict(), snapLast)
|
681 |
+
local_weights.append(net1.state_dict())
|
682 |
+
local_weights.append(net2.state_dict())
|
683 |
+
w_glob = FedAvg(local_weights)
|
684 |
+
|
685 |
+
else:
|
686 |
+
if epoch != warm_up:
|
687 |
+
net1.load_state_dict(w_glob)
|
688 |
+
net2.load_state_dict(w_glob)
|
689 |
+
|
690 |
+
for rou in range(local_round):
|
691 |
+
prob1, all_loss[0], loss1, mask1, f_G1 = eval_train(net1, all_loss[0])
|
692 |
+
prob2, all_loss[1], loss2, mask2, f_G2 = eval_train(net2, all_loss[1])
|
693 |
+
|
694 |
+
# 加载完global后第一次评估
|
695 |
+
if rou == 0:
|
696 |
+
# plotHistogram(np.array(loss1), np.array(loss2), noise_ind, clean_ind, epoch, rou, args.r)
|
697 |
+
loss_dist_plot(loss1, noise_ind, clean_ind, epoch, model_name='model_1')
|
698 |
+
# loss_dist_plot_real(loss1, epoch, model_name='model_1')
|
699 |
+
if rou == local_round-1:
|
700 |
+
plotHistogram(np.array(loss1), np.array(loss2), noise_ind, clean_ind, epoch, rou, args.r)
|
701 |
+
|
702 |
+
# pred1 = (prob1 > args.p_threshold) & (mask1 != 0)
|
703 |
+
# pred2 = (prob2 > args.p_threshold) & (mask2 != 0)
|
704 |
+
pred1 = (prob1 > args.p_threshold)
|
705 |
+
pred2 = (prob2 > args.p_threshold)
|
706 |
+
|
707 |
+
non_zero_idx = pred1.nonzero()[0].tolist()
|
708 |
+
aaa = len(non_zero_idx)
|
709 |
+
if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
|
710 |
+
num_clean_per_class = np.zeros(args.num_class)
|
711 |
+
target_label = np.array(eval_loader.dataset.noise_label)[non_zero_idx]
|
712 |
+
for i in range(args.num_class):
|
713 |
+
idx_class = np.where(target_label == i)[0]
|
714 |
+
num_clean_per_class[i] = len(idx_class)
|
715 |
+
|
716 |
+
if balance_crit == "median":
|
717 |
+
num_samples2select_class = np.median(num_clean_per_class)
|
718 |
+
|
719 |
+
for i in range(args.num_class):
|
720 |
+
idx_class = np.where(np.array(eval_loader.dataset.noise_label) == i)[0]
|
721 |
+
cur_num = num_clean_per_class[i]
|
722 |
+
idx_class2 = non_zero_idx
|
723 |
+
if num_samples2select_class > cur_num:
|
724 |
+
remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
|
725 |
+
idx = list(range(len(remian_idx)))
|
726 |
+
random.shuffle(idx)
|
727 |
+
num_app = int(num_samples2select_class - cur_num)
|
728 |
+
idx = idx[:num_app]
|
729 |
+
for j in idx:
|
730 |
+
non_zero_idx.append(remian_idx[j])
|
731 |
+
non_zero_idx = np.array(non_zero_idx).reshape(-1, )
|
732 |
+
bbb = len(non_zero_idx)
|
733 |
+
num_per_class2 = []
|
734 |
+
for i in range(max(eval_loader.dataset.noise_label)):
|
735 |
+
temp = np.where(np.array(eval_loader.dataset.noise_label)[non_zero_idx.tolist()] == i)[0]
|
736 |
+
num_per_class2.append(len(temp))
|
737 |
+
print('\npred1 appended num per class:', num_per_class2, aaa, bbb)
|
738 |
+
idx_per_class = np.zeros_like(pred1).astype(bool)
|
739 |
+
for i in non_zero_idx:
|
740 |
+
idx_per_class[i] = True
|
741 |
+
pred1 = idx_per_class
|
742 |
+
non_aaa = pred1.nonzero()[0].tolist()
|
743 |
+
assert len(non_aaa) == len(non_zero_idx)
|
744 |
+
|
745 |
+
non_zero_idx2 = pred2.nonzero()[0].tolist()
|
746 |
+
aaa = len(non_zero_idx2)
|
747 |
+
if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
|
748 |
+
num_clean_per_class = np.zeros(args.num_class)
|
749 |
+
target_label = np.array(eval_loader.dataset.noise_label)[non_zero_idx2]
|
750 |
+
for i in range(args.num_class):
|
751 |
+
idx_class = np.where(target_label == i)[0]
|
752 |
+
num_clean_per_class[i] = len(idx_class)
|
753 |
+
|
754 |
+
if balance_crit == "median":
|
755 |
+
num_samples2select_class = np.median(num_clean_per_class)
|
756 |
+
|
757 |
+
for i in range(args.num_class):
|
758 |
+
idx_class = np.where(np.array(eval_loader.dataset.noise_label) == i)[0]
|
759 |
+
cur_num = num_clean_per_class[i]
|
760 |
+
idx_class2 = non_zero_idx2
|
761 |
+
if num_samples2select_class > cur_num:
|
762 |
+
remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
|
763 |
+
idx = list(range(len(remian_idx)))
|
764 |
+
random.shuffle(idx)
|
765 |
+
num_app = int(num_samples2select_class - cur_num)
|
766 |
+
idx = idx[:num_app]
|
767 |
+
for j in idx:
|
768 |
+
non_zero_idx2.append(remian_idx[j])
|
769 |
+
non_zero_idx2 = np.array(non_zero_idx2).reshape(-1, )
|
770 |
+
bbb = len(non_zero_idx2)
|
771 |
+
num_per_class2 = []
|
772 |
+
for i in range(max(eval_loader.dataset.noise_label)):
|
773 |
+
temp = np.where(np.array(eval_loader.dataset.noise_label)[non_zero_idx2.tolist()] == i)[0]
|
774 |
+
num_per_class2.append(len(temp))
|
775 |
+
print('\npred2 appended num per class:', num_per_class2, aaa, bbb)
|
776 |
+
idx_per_class2 = np.zeros_like(pred2).astype(bool)
|
777 |
+
for i in non_zero_idx2:
|
778 |
+
idx_per_class2[i] = True
|
779 |
+
pred2 = idx_per_class2
|
780 |
+
non_aaa = pred2.nonzero()[0].tolist()
|
781 |
+
assert len(non_aaa) == len(non_zero_idx2)
|
782 |
+
|
783 |
+
correct_num = len(pred1.nonzero()[0])
|
784 |
+
eval_loader.dataset.if_noise(pred1)
|
785 |
+
eval_loader.dataset.if_noise(pred2)
|
786 |
+
|
787 |
+
print(f'round={rou}/{local_round}, dmix selection, Train Net1')
|
788 |
+
# prob2就是先验概率wi,通过GMM拟合出来的,大于阈值就认为是clean,否则noisy
|
789 |
+
labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, prob2) # co-divide
|
790 |
+
train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader) # train net1
|
791 |
+
|
792 |
+
print(f'\nround={rou}/{local_round}, dmix selection, Train Net2')
|
793 |
+
labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, prob1) # co-divide
|
794 |
+
train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader) # train net2
|
795 |
+
|
796 |
+
local_weights.append(net1.state_dict())
|
797 |
+
local_weights.append(net2.state_dict())
|
798 |
+
w_glob = FedAvg(local_weights)
|
799 |
+
if epoch % 5 == 0:
|
800 |
+
snapLast = exp_path + str(epoch) + "_global_model.pth"
|
801 |
+
torch.save(w_glob, snapLast)
|
802 |
+
|
803 |
+
best_en_acc = test(epoch, net1, net2, best_en_acc)
|
804 |
+
best_gl_acc= test(epoch, net1, net2, best_gl_acc, w_glob)
|
805 |
+
|
806 |
+
|
807 |
+
|