LanXiaoPang613 commited on
Commit
d3cde70
·
unverified ·
1 Parent(s): 55eda5e

Add files via upload

Browse files
PreResNet.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from torch.autograd import Variable
6
+
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1):
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
10
+
11
+
12
+ class BasicBlock(nn.Module):
13
+ expansion = 1
14
+
15
+ def __init__(self, in_planes, planes, stride=1):
16
+ super(BasicBlock, self).__init__()
17
+ self.conv1 = conv3x3(in_planes, planes, stride)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.conv2 = conv3x3(planes, planes)
20
+ self.bn2 = nn.BatchNorm2d(planes)
21
+
22
+ self.shortcut = nn.Sequential()
23
+ if stride != 1 or in_planes != self.expansion*planes:
24
+ self.shortcut = nn.Sequential(
25
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
26
+ nn.BatchNorm2d(self.expansion*planes)
27
+ )
28
+
29
+ def forward(self, x):
30
+ out = F.relu(self.bn1(self.conv1(x)))
31
+ out = self.bn2(self.conv2(out))
32
+ out += self.shortcut(x)
33
+ out = F.relu(out)
34
+ return out
35
+
36
+
37
+ class PreActBlock(nn.Module):
38
+ '''Pre-activation version of the BasicBlock.'''
39
+ expansion = 1
40
+
41
+ def __init__(self, in_planes, planes, stride=1):
42
+ super(PreActBlock, self).__init__()
43
+ self.bn1 = nn.BatchNorm2d(in_planes)
44
+ self.conv1 = conv3x3(in_planes, planes, stride)
45
+ self.bn2 = nn.BatchNorm2d(planes)
46
+ self.conv2 = conv3x3(planes, planes)
47
+
48
+ self.shortcut = nn.Sequential()
49
+ if stride != 1 or in_planes != self.expansion*planes:
50
+ self.shortcut = nn.Sequential(
51
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
52
+ )
53
+
54
+ def forward(self, x):
55
+ out = F.relu(self.bn1(x))
56
+ shortcut = self.shortcut(out)
57
+ out = self.conv1(out)
58
+ out = self.conv2(F.relu(self.bn2(out)))
59
+ out += shortcut
60
+ return out
61
+
62
+
63
+ class Bottleneck(nn.Module):
64
+ expansion = 4
65
+
66
+ def __init__(self, in_planes, planes, stride=1):
67
+ super(Bottleneck, self).__init__()
68
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
69
+ self.bn1 = nn.BatchNorm2d(planes)
70
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
71
+ self.bn2 = nn.BatchNorm2d(planes)
72
+ self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
73
+ self.bn3 = nn.BatchNorm2d(self.expansion*planes)
74
+
75
+ self.shortcut = nn.Sequential()
76
+ if stride != 1 or in_planes != self.expansion*planes:
77
+ self.shortcut = nn.Sequential(
78
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
79
+ nn.BatchNorm2d(self.expansion*planes)
80
+ )
81
+
82
+ def forward(self, x):
83
+ out = F.relu(self.bn1(self.conv1(x)))
84
+ out = F.relu(self.bn2(self.conv2(out)))
85
+ out = self.bn3(self.conv3(out))
86
+ out += self.shortcut(x)
87
+ out = F.relu(out)
88
+ return out
89
+
90
+
91
+ class PreActBottleneck(nn.Module):
92
+ '''Pre-activation version of the original Bottleneck module.'''
93
+ expansion = 4
94
+
95
+ def __init__(self, in_planes, planes, stride=1):
96
+ super(PreActBottleneck, self).__init__()
97
+ self.bn1 = nn.BatchNorm2d(in_planes)
98
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
99
+ self.bn2 = nn.BatchNorm2d(planes)
100
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
101
+ self.bn3 = nn.BatchNorm2d(planes)
102
+ self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
103
+
104
+ self.shortcut = nn.Sequential()
105
+ if stride != 1 or in_planes != self.expansion*planes:
106
+ self.shortcut = nn.Sequential(
107
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
108
+ )
109
+
110
+ def forward(self, x):
111
+ out = F.relu(self.bn1(x))
112
+ shortcut = self.shortcut(out)
113
+ out = self.conv1(out)
114
+ out = self.conv2(F.relu(self.bn2(out)))
115
+ out = self.conv3(F.relu(self.bn3(out)))
116
+ out += shortcut
117
+ return out
118
+
119
+
120
+ class ResNet(nn.Module):
121
+ def __init__(self, block, num_blocks, num_classes=10):
122
+ super(ResNet, self).__init__()
123
+ self.in_planes = 64
124
+
125
+ self.conv1 = conv3x3(3,64)
126
+ self.bn1 = nn.BatchNorm2d(64)
127
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
128
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
129
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
130
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
131
+ self.linear = nn.Linear(512*block.expansion, num_classes)
132
+
133
+ def _make_layer(self, block, planes, num_blocks, stride):
134
+ strides = [stride] + [1]*(num_blocks-1)
135
+ layers = []
136
+ for stride in strides:
137
+ layers.append(block(self.in_planes, planes, stride))
138
+ self.in_planes = planes * block.expansion
139
+ return nn.Sequential(*layers)
140
+
141
+ def forward(self, x, lin=0, lout=5, feat_out=False):
142
+ out = x
143
+ if lin < 1 and lout > -1:
144
+ out = self.conv1(out)
145
+ out = self.bn1(out)
146
+ out = F.relu(out)
147
+ if lin < 2 and lout > 0:
148
+ out = self.layer1(out)
149
+ if lin < 3 and lout > 1:
150
+ out = self.layer2(out)
151
+ if lin < 4 and lout > 2:
152
+ out = self.layer3(out)
153
+ if lin < 5 and lout > 3:
154
+ out = self.layer4(out)
155
+ if lout > 4:
156
+ out = F.avg_pool2d(out, 4)
157
+ feat = out.view(out.size(0), -1)
158
+ out = self.linear(feat)
159
+ if feat_out:
160
+ return out, feat
161
+ else:
162
+ return out
163
+
164
+
165
+ def ResNet18(num_classes=10):
166
+ return ResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
167
+
168
+ def ResNet34(num_classes=10):
169
+ return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
170
+
171
+ def ResNet50(num_classes=10):
172
+ return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes)
173
+
174
+ def ResNet101(num_classes=10):
175
+ return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes)
176
+
177
+ def ResNet152(num_classes=10):
178
+ return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes)
179
+
180
+
181
+ def test():
182
+ net = ResNet18()
183
+ y = net(Variable(torch.randn(1,3,32,32)))
184
+ print(y.size())
Train_animal10N.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
7
+ import torch.backends.cudnn as cudnn
8
+ import torchvision
9
+ import torchvision.models as models
10
+ from models.CNN import CNN
11
+ import random
12
+ import os
13
+ import argparse
14
+ import numpy as np
15
+ import dataloader_animal10N as animal_dataloader
16
+ from sklearn.mixture import GaussianMixture
17
+ import copy
18
+
19
+ parser = argparse.ArgumentParser(description='PyTorch Clothing1M Training')
20
+ parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
21
+ parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
22
+ parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
23
+ parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
24
+ parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
25
+ parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
26
+ parser.add_argument('--num_epochs', default=300, type=int)
27
+ parser.add_argument('--id', default='animal10N')
28
+ # parser.add_argument('--data_path', default='E:/Dataset_All/clothing1M/images', type=str, help='path to dataset')
29
+ parser.add_argument('--data_path', default='C:/Users/Administrator/Desktop/DatasetAll/Animal-10N', type=str, help='path to dataset')
30
+ parser.add_argument('--seed', default=123)
31
+ parser.add_argument('--gpuid', default=0, type=int)
32
+ parser.add_argument('--num_class', default=10, type=int)
33
+ # parser.add_argument('--num_batches', default=1000, type=int)
34
+ args = parser.parse_args()
35
+
36
+ torch.cuda.set_device(args.gpuid)
37
+ random.seed(args.seed)
38
+ torch.manual_seed(args.seed)
39
+ torch.cuda.manual_seed_all(args.seed)
40
+
41
+
42
+ # Training
43
+ def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader):
44
+ net.train()
45
+ net2.eval() # fix one network and train the other
46
+
47
+ unlabeled_train_iter = iter(unlabeled_trainloader)
48
+ num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
49
+ for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
50
+ try:
51
+ inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
52
+ except:
53
+ unlabeled_train_iter = iter(unlabeled_trainloader)
54
+ inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
55
+ batch_size = inputs_x.size(0)
56
+
57
+ # Transform label to one-hot
58
+ labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
59
+ w_x = w_x.view(-1, 1).type(torch.FloatTensor)
60
+
61
+ inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
62
+ inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
63
+
64
+ with torch.no_grad():
65
+ # label co-guessing of unlabeled samples
66
+ outputs_u11 = net(inputs_u)
67
+ outputs_u12 = net(inputs_u2)
68
+ outputs_u21 = net2(inputs_u)
69
+ outputs_u22 = net2(inputs_u2)
70
+
71
+ pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) +
72
+ torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
73
+ ptu = pu ** (1 / args.T) # temparature sharpening
74
+
75
+ targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
76
+ targets_u = targets_u.detach()
77
+
78
+ # label refinement of labeled samples
79
+ outputs_x = net(inputs_x)
80
+ outputs_x2 = net(inputs_x2)
81
+
82
+ px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
83
+ px = w_x * labels_x + (1 - w_x) * px
84
+ ptx = px ** (1 / args.T) # temparature sharpening
85
+
86
+ targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
87
+ targets_x = targets_x.detach()
88
+
89
+ # mixmatch
90
+ l = np.random.beta(args.alpha, args.alpha)
91
+ l = max(l, 1 - l)
92
+
93
+ all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
94
+ all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
95
+
96
+ idx = torch.randperm(all_inputs.size(0))
97
+
98
+ input_a, input_b = all_inputs, all_inputs[idx]
99
+ target_a, target_b = all_targets, all_targets[idx]
100
+
101
+ mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2]
102
+ mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2]
103
+
104
+ logits = net(mixed_input)
105
+
106
+ Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
107
+
108
+ # regularization
109
+ prior = torch.ones(args.num_class) / args.num_class
110
+ prior = prior.cuda()
111
+ pred_mean = torch.softmax(logits, dim=1).mean(0)
112
+ penalty = torch.sum(prior * torch.log(prior / pred_mean))
113
+
114
+ loss = Lx + penalty
115
+
116
+ # compute gradient and do SGD step
117
+ optimizer.zero_grad()
118
+ loss.backward()
119
+ optimizer.step()
120
+
121
+ sys.stdout.write('\r')
122
+ sys.stdout.write('Animal10N | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.4f '
123
+ % (epoch, args.num_epochs, batch_idx + 1, num_iter, Lx.item()))
124
+ sys.stdout.flush()
125
+
126
+
127
+ def warmup(net, optimizer, dataloader):
128
+ net.train()
129
+ num_batches = 50000/args.batch_size
130
+ for batch_idx, (inputs, labels, path) in enumerate(dataloader):
131
+ inputs, labels = inputs.cuda(), labels.cuda()
132
+ optimizer.zero_grad()
133
+ outputs = net(inputs)
134
+ loss = CEloss(outputs, labels)
135
+
136
+ penalty = conf_penalty(outputs)
137
+ L = loss + penalty
138
+ L.backward()
139
+ optimizer.step()
140
+
141
+ sys.stdout.write('\r')
142
+ sys.stdout.write('|Warm-up: Iter[%3d/%3d]\t CE-loss: %.4f Conf-Penalty: %.4f'
143
+ % (2*(batch_idx + 1), num_batches, loss.item(), penalty.item()))
144
+ sys.stdout.flush()
145
+
146
+
147
+ def val(net, val_loader, best_acc, w_glob=None):
148
+ net.eval()
149
+ correct = 0
150
+ total = 0
151
+ with torch.no_grad():
152
+ for batch_idx, (inputs, targets) in enumerate(val_loader):
153
+ inputs, targets = inputs.cuda(), targets.cuda()
154
+ outputs = net(inputs)
155
+ _, predicted = torch.max(outputs, 1)
156
+
157
+ total += targets.size(0)
158
+ correct += predicted.eq(targets).cpu().sum().item()
159
+ acc = 100. * correct / total
160
+ print("\n| Validation\t Net%d Acc: %.2f%%" % (k, acc))
161
+ if acc > best_acc[k - 1]:
162
+ best_acc[k - 1] = acc
163
+ print('| Saving Best Net%d ...' % k)
164
+ save_point = './checkpoint/%s_net%d.pth.tar' % (args.id, k)
165
+ torch.save(net.state_dict(), save_point)
166
+ return acc
167
+
168
+
169
+ def test(epoch, net1, net2, test_loader, best_acc, w_glob=None):
170
+ if w_glob is None:
171
+ net1.eval()
172
+ net2.eval()
173
+ correct = 0
174
+ correct2 = 0
175
+ correct1 = 0
176
+ total = 0
177
+ with torch.no_grad():
178
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
179
+ inputs, targets = inputs.cuda(), targets.cuda()
180
+ outputs1 = net1(inputs)
181
+ outputs2 = net2(inputs)
182
+ outputs = outputs1 + outputs2
183
+ _, predicted = torch.max(outputs, 1)
184
+ _, predicted1 = torch.max(outputs1, 1)
185
+ _, predicted2 = torch.max(outputs2, 1)
186
+
187
+ total += targets.size(0)
188
+ correct += predicted.eq(targets).cpu().sum().item()
189
+ correct1 += predicted1.eq(targets).cpu().sum().item()
190
+ correct2 += predicted2.eq(targets).cpu().sum().item()
191
+ acc = 100. * correct / total
192
+ acc1 = 100. * correct / total
193
+ acc2 = 100. * correct / total
194
+ if best_acc < acc:
195
+ best_acc = acc
196
+ print(
197
+ "\n| Ensemble network Test Epoch #%d\t Accuracy: %.2f, Accuracy1: %.2f, Accuracy2: %.2f, best_acc: %.2f%%\n" % (
198
+ epoch, acc, acc1, acc2, best_acc))
199
+ log.write('ensemble_Epoch:%d Accuracy:%.2f, Accuracy1: %.2f, Accuracy2: %.2f, best_acc: %.2f\n' % (
200
+ epoch, acc, acc1, acc2, best_acc))
201
+ log.flush()
202
+ else:
203
+ net1_w_bak = net1.state_dict()
204
+ net1.load_state_dict(w_glob)
205
+ net1.eval()
206
+ correct = 0
207
+ total = 0
208
+ with torch.no_grad():
209
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
210
+ inputs, targets = inputs.cuda(), targets.cuda()
211
+ outputs1 = net1(inputs)
212
+ _, predicted = torch.max(outputs1, 1)
213
+ total += targets.size(0)
214
+ correct += predicted.eq(targets).cpu().sum().item()
215
+ acc = 100. * correct / total
216
+ if best_acc < acc:
217
+ best_acc = acc
218
+ print("\n| Global network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
219
+ log.write('global_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
220
+ log.flush()
221
+ # 恢复权重
222
+ net1.load_state_dict(net1_w_bak)
223
+ return best_acc
224
+
225
+
226
+ def eval_train(epoch, model):
227
+ model.eval()
228
+ num_samples = eval_loader.dataset.__len__()
229
+ losses = torch.zeros(num_samples)
230
+ paths = []
231
+ n = 0
232
+ with torch.no_grad():
233
+ for batch_idx, (inputs, targets, path) in enumerate(eval_loader):
234
+ inputs, targets = inputs.cuda(), targets.cuda()
235
+ outputs = model(inputs)
236
+ loss = CE(outputs, targets)
237
+ for b in range(inputs.size(0)):
238
+ losses[n] = loss[b]
239
+ paths.append(path[b])
240
+ n += 1
241
+ sys.stdout.write('\r')
242
+ sys.stdout.write('| Evaluating loss Iter %3d\t' % (batch_idx))
243
+ sys.stdout.flush()
244
+
245
+ losses = (losses - losses.min()) / (losses.max() - losses.min())
246
+ losses = losses.reshape(-1, 1)
247
+ gmm = GaussianMixture(n_components=2, max_iter=10, reg_covar=5e-4, tol=1e-2)
248
+ gmm.fit(losses)
249
+ prob = gmm.predict_proba(losses)
250
+ prob = prob[:, gmm.means_.argmin()]
251
+ return prob, paths
252
+
253
+
254
+ class NegEntropy(object):
255
+ def __call__(self, outputs):
256
+ probs = torch.softmax(outputs, dim=1)
257
+ return torch.mean(torch.sum(probs.log() * probs, dim=1))
258
+
259
+
260
+ def create_model():
261
+ use_cnn = False
262
+ if use_cnn:
263
+ model = CNN()
264
+ model = model.cuda()
265
+ else:
266
+ model = models.vgg19_bn(pretrained=False)
267
+ model.classifier._modules['6'] = nn.Linear(4096, 10)
268
+ model = model.cuda()
269
+ return model
270
+
271
+
272
+ def FedAvg(w):
273
+ w_avg = copy.deepcopy(w[0])
274
+ for k in w_avg.keys():
275
+ for i in range(1, len(w)):
276
+ w_avg[k] += w[i][k]
277
+ # 只考虑iid noise的话,每个client训练样本数一样,所以不用做nk/n
278
+ w_avg[k] = torch.div(w_avg[k], len(w))
279
+
280
+ return w_avg
281
+
282
+
283
+ log = open('./checkpoint/%s.txt' % args.id, 'w')
284
+ log.flush()
285
+
286
+ loader = animal_dataloader.animal_dataloader(root=args.data_path, batch_size=args.batch_size, num_workers=0)
287
+
288
+ print('| Building net')
289
+ net1 = create_model()
290
+ net2 = create_model()
291
+ cudnn.benchmark = True
292
+
293
+ optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
294
+ optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
295
+
296
+ CE = nn.CrossEntropyLoss(reduction='none')
297
+ CEloss = nn.CrossEntropyLoss()
298
+ conf_penalty = NegEntropy()
299
+
300
+ local_round = 5
301
+ balance_crit = 'median' # 'median'
302
+ exp_path = './checkpoint/c2mt_animal10N'
303
+
304
+ boot_loader = None
305
+ w_glob = None
306
+ best_en_acc = 0.
307
+ best_gl_acc = 0.
308
+ resume_epoch = 0
309
+ warm_up = 10
310
+ if resume_epoch > 0:
311
+ snapLast = exp_path + str(resume_epoch - 1) + "_global_model.pth"
312
+ global_state = torch.load(snapLast)
313
+ # 先更新还是后跟新
314
+ w_glob = global_state
315
+ net1.load_state_dict(global_state)
316
+ net2.load_state_dict(global_state)
317
+
318
+ # if True:
319
+ # snapLast = exp_path + "0_1_model.pth"
320
+ # global_state = torch.load(snapLast)
321
+ # net1.load_state_dict(global_state)
322
+ # snapLast = exp_path + "0_2_model.pth"
323
+ # global_state = torch.load(snapLast)
324
+ # net2.load_state_dict(global_state)
325
+ # test_loader = loader.run('test')
326
+ # best_en_acc = test(0, net1, net2, test_loader, best_en_acc)
327
+
328
+ for epoch in range(resume_epoch, args.num_epochs + 1):
329
+ lr = args.lr
330
+ if 100 <= epoch < 150:
331
+ lr /= 10
332
+ elif epoch >= 150:
333
+ lr /= 10
334
+ # if 15 <= epoch:
335
+ # lr /= 2
336
+ for param_group in optimizer1.param_groups:
337
+ param_group['lr'] = lr
338
+ for param_group in optimizer2.param_groups:
339
+ param_group['lr'] = lr
340
+
341
+ local_weights = []
342
+ if epoch < warm_up: # warm up
343
+ train_loader = loader.run('warmup')
344
+ print('Warmup Net1')
345
+ warmup(net1, optimizer1, train_loader)
346
+ train_loader = loader.run('warmup')
347
+ print('\nWarmup Net2')
348
+ warmup(net2, optimizer2, train_loader)
349
+ if epoch == (warm_up - 1):
350
+ snapLast = exp_path + str(epoch) + "_1_model.pth"
351
+ torch.save(net1.state_dict(), snapLast)
352
+ snapLast = exp_path + str(epoch) + "_2_model.pth"
353
+ torch.save(net1.state_dict(), snapLast)
354
+ local_weights.append(net1.state_dict())
355
+ local_weights.append(net2.state_dict())
356
+ w_glob = FedAvg(local_weights)
357
+ else:
358
+ if epoch != warm_up:
359
+ net1.load_state_dict(w_glob)
360
+ net2.load_state_dict(w_glob)
361
+
362
+ for rou in range(local_round):
363
+ print('\n==== net 1 evaluate next epoch training data loss ====')
364
+ eval_loader = loader.run('eval_train') # evaluate training data loss for next epoch
365
+ prob1, paths1 = eval_train(epoch, net1)
366
+ print('\n==== net 2 evaluate next epoch training data loss ====')
367
+ eval_loader = loader.run('eval_train')
368
+ prob2, paths2 = eval_train(epoch, net2)
369
+
370
+ pred1 = (prob1 > args.p_threshold) # divide dataset
371
+ pred2 = (prob2 > args.p_threshold)
372
+
373
+ non_zero_idx = pred1.nonzero()[0].tolist()
374
+ aaa = len(non_zero_idx)
375
+ if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
376
+ num_clean_per_class = np.zeros(args.num_class)
377
+ ppp = np.array(paths1)[non_zero_idx].tolist()
378
+ target_label = np.array([eval_loader.dataset.train_labels[it] for it in ppp])
379
+ # target_label = np.array(eval_loader.dataset.train_labels[paths1])[non_zero_idx]
380
+ for i in range(args.num_class):
381
+ idx_class = np.where(target_label == i)[0]
382
+ num_clean_per_class[i] = len(idx_class)
383
+
384
+ if balance_crit == "max":
385
+ num_samples2select_class = np.max(num_clean_per_class)
386
+ elif balance_crit == "min":
387
+ num_samples2select_class = np.min(num_clean_per_class)
388
+ elif balance_crit == "median":
389
+ num_samples2select_class = np.median(num_clean_per_class)
390
+
391
+ for i in range(args.num_class):
392
+ idx_class = np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths1]) == i)[0]
393
+ cur_num = num_clean_per_class[i]
394
+ idx_class2 = non_zero_idx
395
+ if num_samples2select_class > cur_num:
396
+ remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
397
+ idx = list(range(len(remian_idx)))
398
+ random.shuffle(idx)
399
+ num_app = int(num_samples2select_class - cur_num)
400
+ idx = idx[:num_app]
401
+ for j in idx:
402
+ non_zero_idx.append(remian_idx[j])
403
+ non_zero_idx = np.array(non_zero_idx).reshape(-1, )
404
+ bbb = len(non_zero_idx)
405
+ num_per_class2 = []
406
+ for i in range(10):
407
+ temp = \
408
+ np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths1])[non_zero_idx.tolist()] == i)[
409
+ 0]
410
+ num_per_class2.append(len(temp))
411
+ print('\npred1 appended num per class:', num_per_class2, aaa, bbb)
412
+ idx_per_class = np.zeros_like(pred1).astype(bool)
413
+ for i in non_zero_idx:
414
+ idx_per_class[i] = True
415
+ pred1 = idx_per_class
416
+ non_aaa = pred1.nonzero()[0].tolist()
417
+ assert len(non_aaa) == len(non_zero_idx)
418
+
419
+ non_zero_idx2 = pred2.nonzero()[0].tolist()
420
+ aaa = len(non_zero_idx2)
421
+ if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
422
+ num_clean_per_class = np.zeros(args.num_class)
423
+ ppp = np.array(paths2)[non_zero_idx].tolist()
424
+ target_label = np.array([eval_loader.dataset.train_labels[it] for it in ppp])
425
+ for i in range(args.num_class):
426
+ idx_class = np.where(target_label == i)[0]
427
+ num_clean_per_class[i] = len(idx_class)
428
+
429
+ if balance_crit == "max":
430
+ num_samples2select_class = np.max(num_clean_per_class)
431
+ elif balance_crit == "min":
432
+ num_samples2select_class = np.min(num_clean_per_class)
433
+ elif balance_crit == "median":
434
+ num_samples2select_class = np.median(num_clean_per_class)
435
+
436
+ for i in range(args.num_class):
437
+ idx_class = np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths1]) == i)[0]
438
+ cur_num = num_clean_per_class[i]
439
+ idx_class2 = non_zero_idx2
440
+ if num_samples2select_class > cur_num:
441
+ remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
442
+ idx = list(range(len(remian_idx)))
443
+ random.shuffle(idx)
444
+ num_app = int(num_samples2select_class - cur_num)
445
+ idx = idx[:num_app]
446
+ for j in idx:
447
+ non_zero_idx2.append(remian_idx[j])
448
+ non_zero_idx2 = np.array(non_zero_idx2).reshape(-1, )
449
+ bbb = len(non_zero_idx2)
450
+ num_per_class2 = []
451
+ for i in range(10):
452
+ temp = np.where(
453
+ np.array([eval_loader.dataset.train_labels[it] for it in paths1])[non_zero_idx2.tolist()] == i)[0]
454
+ num_per_class2.append(len(temp))
455
+ print('\npred2 appended num per class:', num_per_class2, aaa, bbb)
456
+ idx_per_class2 = np.zeros_like(pred2).astype(bool)
457
+ for i in non_zero_idx2:
458
+ idx_per_class2[i] = True
459
+ pred2 = idx_per_class2
460
+ non_aaa = pred2.nonzero()[0].tolist()
461
+ assert len(non_aaa) == len(non_zero_idx2)
462
+
463
+ print(f'round={rou}/{local_round}, dmix selection, Train Net1')
464
+ labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, prob2, paths=paths2) # co-divide
465
+ train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader) # train net1
466
+
467
+ print(f'\nround={rou}/{local_round}, dmix selection, Train Net2')
468
+ labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, prob1, paths=paths1) # co-divide
469
+ train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader) # train net2
470
+
471
+ test_loader = loader.run('test')
472
+ if rou != local_round-1:
473
+ best_en_acc = test(epoch, net1, net2, test_loader, best_en_acc)
474
+ # best_gl_acc = test(epoch, net1, net2, test_loader, best_gl_acc, w_glob=w_glob)
475
+
476
+ print(f'c2m, get global network\n')
477
+ local_weights.append(net1.state_dict())
478
+ local_weights.append(net2.state_dict())
479
+ w_glob = FedAvg(local_weights)
480
+ if epoch % 1 == 0:
481
+ snapLast = exp_path + str(epoch) + "_global_model.pth"
482
+ torch.save(w_glob, snapLast)
483
+
484
+ test_loader = loader.run('test')
485
+ best_en_acc = test(epoch, net1, net2, test_loader, best_en_acc)
486
+ best_gl_acc = test(epoch, net1, net2, test_loader, best_gl_acc, w_glob=w_glob)
487
+
dataloader_animal10N.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset, DataLoader
2
+ import torchvision.transforms as transforms
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ import json
7
+ import torch
8
+ import os
9
+ import matplotlib
10
+
11
+ def unpickle(file):
12
+ fo = open(file, 'rb').read()
13
+ size = 64 * 64 * 3 + 1
14
+ for i in range(50000):
15
+ arr = np.fromstring(fo[i * size:(i + 1) * size], dtype=np.uint8)
16
+ lab = np.identity(10)[arr[0]]
17
+ img = arr[1:].reshape((3, 64, 64)).transpose((1, 2, 0))
18
+ return img, lab
19
+
20
+ class animal_dataset(Dataset):
21
+ def __init__(self, root, transform, mode, pred=[], path=[], probability=[], num_class=10):
22
+
23
+ self.root = root
24
+ self.transform = transform
25
+ self.mode = mode
26
+
27
+ self.train_dir = root + '/training/'
28
+ self.test_dir = root + '/testing/'
29
+ train_imgs = os.listdir(self.train_dir)
30
+ test_imgs = os.listdir(self.test_dir)
31
+ self.test_data = []
32
+ self.test_labels = []
33
+ noise_file1 = './training_batch.json'
34
+ noise_file2 = './testing_batch.json'
35
+ if mode == 'test':
36
+ if os.path.exists(noise_file2):
37
+ dict = json.load(open(noise_file2, "r"))
38
+ self.test_labels = dict['data']
39
+ self.test_data = dict['label']
40
+ else:
41
+ for img in test_imgs:
42
+ self.test_data.append(self.test_dir+img)
43
+ self.test_labels.append(int(img[0]))
44
+ dicts = {}
45
+ dicts['data'] = self.test_data
46
+ dicts['label'] = self.test_labels
47
+ # json.dump(dicts, open(noise_file2, "w"))
48
+ else:
49
+ if os.path.exists(noise_file1):
50
+ dict = json.load(open(noise_file1, "r"))
51
+ train_data = dict['data']
52
+ train_labels = dict['label']
53
+ else:
54
+ train_data = []
55
+ train_labels = {}
56
+ for img in train_imgs:
57
+ img_path = self.train_dir+img
58
+ train_data.append(img_path)
59
+ train_labels[img_path] = (int(img[0]))
60
+ dicts = {}
61
+ dicts['data'] = train_data
62
+ dicts['label'] = train_labels
63
+ # json.dump(dicts, open(noise_file1, "w"))
64
+ if self.mode == "all":
65
+ self.train_data = train_data
66
+ self.train_labels = train_labels
67
+ elif self.mode == "labeled":
68
+ pred_idx = pred.nonzero()[0]
69
+ train_img = path
70
+ self.train_data = [train_img[i] for i in pred_idx]
71
+ self.probability = probability[pred_idx]
72
+ # self.train_labels = train_labels[pred_idx]
73
+ self.train_labels = train_labels
74
+ print("%s data has a size of %d" % (self.mode, len(self.train_data)))
75
+ elif self.mode == "unlabeled":
76
+ pred_idx = (1 - pred).nonzero()[0]
77
+ train_img = path
78
+ self.train_data = [train_img[i] for i in pred_idx]
79
+ self.probability = probability[pred_idx]
80
+ # self.train_labels = train_labels[pred_idx]
81
+ print("%s data has a size of %d" % (self.mode, len(self.train_data)))
82
+ self.train_labels = train_labels
83
+
84
+ def __getitem__(self, index):
85
+ if self.mode == 'labeled':
86
+ img_path = self.train_data[index]
87
+ target = self.train_labels[img_path]
88
+ prob = self.probability[index]
89
+ image = Image.open(img_path).convert('RGB')
90
+ img1 = self.transform(image)
91
+ img2 = self.transform(image)
92
+ return img1, img2, target, prob
93
+ elif self.mode == 'unlabeled':
94
+ img_path = self.train_data[index]
95
+ image = Image.open(img_path).convert('RGB')
96
+ img1 = self.transform(image)
97
+ img2 = self.transform(image)
98
+ return img1, img2
99
+ elif self.mode == 'all':
100
+ img_path = self.train_data[index]
101
+ target = self.train_labels[img_path]
102
+ image = Image.open(img_path).convert('RGB')
103
+ img = self.transform(image)
104
+ return img, target,img_path
105
+ elif self.mode == 'test':
106
+ img_path = self.test_data[index]
107
+ target = self.test_labels[index]
108
+ image = Image.open(img_path).convert('RGB')
109
+ img = self.transform(image)
110
+ return img, target
111
+
112
+ def __len__(self):
113
+ if self.mode == 'test':
114
+ return len(self.test_data)
115
+ else:
116
+ return len(self.train_data)
117
+
118
+
119
+ class animal_dataloader():
120
+ def __init__(self, root='E:/2_Dataset_All/Animal-10N', batch_size=32, num_workers=0):
121
+ self.batch_size = batch_size
122
+ self.num_workers = num_workers
123
+ self.root = root
124
+
125
+ self.transform_train = transforms.Compose([
126
+ transforms.Resize(64),
127
+ transforms.RandomCrop(64),
128
+ transforms.RandomHorizontalFlip(),
129
+ transforms.ToTensor(),
130
+ transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
131
+ ])
132
+ self.transform_test = transforms.Compose([
133
+ # transforms.Resize(64),
134
+ # transforms.CenterCrop(64),
135
+ transforms.ToTensor(),
136
+ transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
137
+ ])
138
+
139
+ def run(self, mode, pred=[], prob=[], paths=[]):
140
+ if mode == 'warmup':
141
+ warmup_dataset = animal_dataset(self.root, transform=self.transform_train, mode='all')
142
+ warmup_loader = DataLoader(
143
+ dataset=warmup_dataset,
144
+ batch_size=self.batch_size * 2,
145
+ shuffle=True,
146
+ num_workers=self.num_workers,
147
+ pin_memory=True)
148
+ return warmup_loader
149
+ elif mode == 'train':
150
+ labeled_dataset = animal_dataset(self.root, transform=self.transform_train, mode='labeled', pred=pred, path=paths,
151
+ probability=prob)
152
+ labeled_loader = DataLoader(
153
+ dataset=labeled_dataset,
154
+ batch_size=self.batch_size,
155
+ shuffle=True,
156
+ num_workers=self.num_workers,
157
+ pin_memory=True)
158
+ unlabeled_dataset = animal_dataset(self.root, transform=self.transform_train, mode='unlabeled', pred=pred,path=paths,
159
+ probability=prob)
160
+ unlabeled_loader = DataLoader(
161
+ dataset=unlabeled_dataset,
162
+ batch_size=int(self.batch_size),
163
+ shuffle=True,
164
+ num_workers=self.num_workers,
165
+ pin_memory=True)
166
+ return labeled_loader, unlabeled_loader
167
+ elif mode == 'eval_train':
168
+ eval_dataset = animal_dataset(self.root, transform=self.transform_test, mode='all')
169
+ eval_loader = DataLoader(
170
+ dataset=eval_dataset,
171
+ batch_size=self.batch_size,
172
+ shuffle=False,
173
+ num_workers=self.num_workers,
174
+ pin_memory=True)
175
+ return eval_loader
176
+ elif mode == 'test':
177
+ test_dataset = animal_dataset(self.root, transform=self.transform_test, mode='test')
178
+ test_loader = DataLoader(
179
+ dataset=test_dataset,
180
+ batch_size=1000,
181
+ shuffle=False,
182
+ num_workers=self.num_workers,
183
+ pin_memory=True)
184
+ return test_loader
185
+
186
+ # if __name__ == '__main__':
187
+ # loader = animal_dataloader()
188
+ # train_loader = loader.run('warmup')
189
+ # import matplotlib.pyplot as plt
190
+ # for batch_idx, (inputs, labels, idx, img_path) in enumerate(train_loader):
191
+ # print(img_path[0])
192
+ # plt.figure(dpi=300)
193
+ # # plt.imshow(inputs[0])
194
+ # plt.imshow(inputs[0].reshape(64, 64, 3))
195
+ # plt.show()
196
+ # plt.close()
197
+ # print(inputs.shape())
198
+ # print(idx)
199
+ # print(labels, len(labels))
200
+ # # print(train_loader.dataset.__len__())
dataloader_cifar.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset, DataLoader
2
+ import torchvision.transforms as transforms
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ import json
7
+ import os
8
+ import torch
9
+ from torchnet.meter import AUCMeter
10
+
11
+
12
+ def unpickle(file):
13
+ import _pickle as cPickle
14
+ with open(file, 'rb') as fo:
15
+ dict = cPickle.load(fo, encoding='latin1')
16
+ return dict
17
+
18
+ class cifar_dataset(Dataset):
19
+ def __init__(self, dataset, r, noise_mode, root_dir, transform, mode, noise_file='', pred=[], probability=[], log='', clean_idx=[], test_form = None):
20
+
21
+ self.r = r # noise ratio
22
+ self.transform = transform
23
+ self.test_form = test_form
24
+ self.mode = mode
25
+ self.transition = {0:0,2:0,4:7,7:7,1:1,9:1,3:5,5:3,6:6,8:8} # class transition for asymmetric noise
26
+ self.noise_file = noise_file
27
+
28
+ if self.mode=='test':
29
+ if dataset=='cifar10':
30
+ test_dic = unpickle('%s/test_batch'%root_dir)
31
+ self.test_data = test_dic['data']
32
+ self.test_data = self.test_data.reshape((10000, 3, 32, 32))
33
+ self.test_data = self.test_data.transpose((0, 2, 3, 1))
34
+ self.test_label = test_dic['labels']
35
+ elif dataset=='cifar100':
36
+ test_dic = unpickle('%s/test'%root_dir)
37
+ self.test_data = test_dic['data']
38
+ self.test_data = self.test_data.reshape((10000, 3, 32, 32))
39
+ self.test_data = self.test_data.transpose((0, 2, 3, 1))
40
+ self.test_label = test_dic['fine_labels']
41
+ else:
42
+ train_data=[]
43
+ train_label=[]
44
+ if dataset=='cifar10':
45
+ for n in range(1,6):
46
+ dpath = '%s/data_batch_%d'%(root_dir,n)
47
+ data_dic = unpickle(dpath)
48
+ train_data.append(data_dic['data'])
49
+ train_label = train_label+data_dic['labels']
50
+ train_data = np.concatenate(train_data)
51
+ elif dataset=='cifar100':
52
+ train_dic = unpickle('%s/train'%root_dir)
53
+ train_data = train_dic['data']
54
+ train_label = train_dic['fine_labels']
55
+ train_data = train_data.reshape((50000, 3, 32, 32))
56
+ train_data = train_data.transpose((0, 2, 3, 1))
57
+
58
+ self.clean_label = np.array(train_label)
59
+
60
+ if os.path.exists(noise_file):
61
+ noise_label = json.load(open(noise_file,"r"))
62
+ else: #inject noise
63
+ noise_label = []
64
+ idx = list(range(50000))
65
+ random.shuffle(idx)
66
+ num_noise = int(self.r*50000)
67
+ noise_idx = idx[:num_noise]
68
+ for i in range(50000):
69
+ if i in noise_idx:
70
+ if noise_mode=='sym':
71
+ if dataset=='cifar10':
72
+ noiselabel = random.randint(0,9)
73
+ elif dataset=='cifar100':
74
+ noiselabel = random.randint(0,99)
75
+ noise_label.append(noiselabel)
76
+ elif noise_mode=='asym':
77
+ noiselabel = self.transition[train_label[i]]
78
+ noise_label.append(noiselabel)
79
+ else:
80
+ noise_label.append(train_label[i])
81
+ print("save noisy labels to %s ..."%noise_file)
82
+ json.dump(noise_label,open(noise_file,"w"))
83
+
84
+ if self.mode == 'all':
85
+ self.train_data = train_data
86
+ self.noise_label = np.array(noise_label).astype(np.int64)
87
+ else:
88
+ if self.mode == "labeled":
89
+ pred_idx = pred.nonzero()[0]
90
+ self.probability = [probability[i] for i in pred_idx]
91
+
92
+ clean = (np.array(noise_label)==np.array(train_label))
93
+ auc_meter = AUCMeter()
94
+ auc_meter.reset()
95
+ auc_meter.add(probability,clean)
96
+ auc,_,_ = auc_meter.value()
97
+ clean_index = np.where(np.array(noise_label)[pred_idx.tolist()] == np.array(self.clean_label)[pred_idx.tolist()])[0]
98
+
99
+ num_per_class = []
100
+ for i in range(max(noise_label)):
101
+ temp = np.where(np.array(noise_label)[clean_index.tolist()] == i)[0]
102
+ num_per_class.append(len(temp))
103
+ num_per_class2 = []
104
+ for i in range(max(noise_label)):
105
+ temp = np.where(np.array(noise_label)[pred_idx.tolist()] == i)[0]
106
+ num_per_class2.append(len(temp))
107
+ print('clean num per class:', num_per_class, num_per_class2)
108
+
109
+ log.write('Numer of labeled samples:%d AUC:%.3f corrected clean num:%d, uncorrected noisy num:%d\n'
110
+ % (pred.sum(), auc, len(clean_index), len(pred_idx) - len(clean_index)))
111
+ log.flush()
112
+
113
+ elif self.mode == "unlabeled":
114
+ pred_idx = (1-pred).nonzero()[0]
115
+ noise_index = np.where(np.array(noise_label)[pred_idx.tolist()] != np.array(self.clean_label)[pred_idx.tolist()])[0]
116
+ log.write('Numer of unlabeled samples:%d corrected noisy num:%d, uncorrected clean num:%d\n'
117
+ % (pred.sum(), len(noise_index), len(pred_idx) - len(noise_index)))
118
+ log.flush()
119
+ elif self.mode == 'boost':
120
+ pred_idx = clean_idx
121
+
122
+ self.train_data = train_data[pred_idx]
123
+ self.noise_label = [noise_label[i] for i in pred_idx]
124
+ print("%s data has a size of %d"%(self.mode,len(self.noise_label)))
125
+
126
+ def if_noise(self, pred=None):
127
+ if pred is None:
128
+ noise_index = np.where(self.noise_label[:] != self.clean_label[:])[0]
129
+ clean_index = np.where(self.noise_label[:] == self.clean_label[:])[0]
130
+ return noise_index, clean_index
131
+ else:
132
+ pred_idx1 = pred.nonzero()[0].tolist()
133
+ clean_index = np.where(np.array(self.noise_label)[pred_idx1] == np.array(self.clean_label)[pred_idx1])[0]
134
+ pred_idx = (1 - pred).nonzero()[0].tolist()
135
+ noise_index = np.where(np.array(self.noise_label)[pred_idx] != np.array(self.clean_label)[pred_idx])[0]
136
+ print(
137
+ f'选择的非mask样本中正确选取的干净标签数量{len(clean_index)}, 不正确选取的非干净数量{len(pred_idx1) - len(clean_index)}.\t '
138
+ f'选择的mask样本中正确选取的不干净标签数量{len(noise_index)}, 不正确选取的干净数量{len(pred_idx) - len(noise_index)}')
139
+ return len(clean_index), (len(pred_idx1) - len(clean_index)), len(noise_index), len(pred_idx) - len(
140
+ noise_index)
141
+ def print_noise_rate(self, new_y):
142
+ temp_y = np.array(new_y.reshape(1, -1).squeeze())
143
+ clean_index = np.where(temp_y[:] == np.array(self.clean_label)[:])
144
+ print(f'clean rate is: {len(clean_index[0]) / len(self.clean_label)}')
145
+
146
+ def load_train_label(self, new_y):
147
+ temp_y = np.array(new_y.reshape(1, -1).squeeze()).astype(np.int64)
148
+ self.noise_label[:] = np.array(temp_y)[:]
149
+ if os.path.exists(self.noise_file):
150
+ result = os.path.splitext(self.noise_file)
151
+ noise_file_temp = result[0]+'_old'+result[1]
152
+ if not os.path.exists(noise_file_temp):
153
+ os.rename(self.noise_file, noise_file_temp)
154
+ # 覆盖原来的noise_file
155
+ json.dump(self.noise_label.tolist(), open(self.noise_file, "w"))
156
+
157
+ def __getitem__(self, index):
158
+ if self.mode=='labeled':
159
+ img, target, prob = self.train_data[index], self.noise_label[index], self.probability[index]
160
+ img = Image.fromarray(img)
161
+ img1 = self.transform(img)
162
+ img2 = self.transform(img)
163
+ return img1, img2, target, prob
164
+ elif self.mode=='unlabeled':
165
+ img = self.train_data[index]
166
+ img = Image.fromarray(img)
167
+ img1 = self.transform(img)
168
+ img2 = self.transform(img)
169
+ return img1, img2
170
+ elif self.mode=='all':
171
+ img, target = self.train_data[index], self.noise_label[index]
172
+ img = Image.fromarray(img)
173
+ img = self.transform(img)
174
+ return img, target, index
175
+ elif self.mode=='test':
176
+ img, target = self.test_data[index], self.test_label[index]
177
+ img = Image.fromarray(img)
178
+ img = self.transform(img)
179
+ return img, target
180
+ elif self.mode=='boost':
181
+ img, target = self.train_data[index], self.noise_label[index]
182
+ img = Image.fromarray(img)
183
+ img_no_da = self.test_form(img)
184
+ img = self.transform(img)
185
+ return img, img_no_da, target, index
186
+
187
+ def __len__(self):
188
+ if self.mode!='test':
189
+ return len(self.train_data)
190
+ else:
191
+ return len(self.test_data)
192
+
193
+
194
+ class cifar_dataloader():
195
+ def __init__(self, dataset, r, noise_mode, batch_size, num_workers, root_dir, log, noise_file=''):
196
+ self.dataset = dataset
197
+ self.r = r
198
+ self.noise_mode = noise_mode
199
+ self.batch_size = batch_size
200
+ self.num_workers = num_workers
201
+ self.root_dir = root_dir
202
+ self.log = log
203
+ self.noise_file = noise_file
204
+ if self.dataset=='cifar10':
205
+ self.transform_train = transforms.Compose([
206
+ transforms.RandomCrop(32, padding=4),
207
+ transforms.RandomHorizontalFlip(),
208
+ transforms.ToTensor(),
209
+ transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
210
+ ])
211
+ self.transform_test = transforms.Compose([
212
+ transforms.ToTensor(),
213
+ transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
214
+ ])
215
+ elif self.dataset=='cifar100':
216
+ self.transform_train = transforms.Compose([
217
+ transforms.RandomCrop(32, padding=4),
218
+ transforms.RandomHorizontalFlip(),
219
+ transforms.ToTensor(),
220
+ transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
221
+ ])
222
+ self.transform_test = transforms.Compose([
223
+ transforms.ToTensor(),
224
+ transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
225
+ ])
226
+ def run(self,mode,pred=[],prob=[], clean_idx=[]):
227
+ if mode=='warmup':
228
+ all_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="all",noise_file=self.noise_file)
229
+ trainloader = DataLoader(
230
+ dataset=all_dataset,
231
+ batch_size=self.batch_size*2,
232
+ shuffle=True,
233
+ num_workers=self.num_workers)
234
+ return trainloader
235
+
236
+ elif mode=='train':
237
+ labeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="labeled", noise_file=self.noise_file, pred=pred, probability=prob,log=self.log)
238
+ labeled_trainloader = DataLoader(
239
+ dataset=labeled_dataset,
240
+ batch_size=self.batch_size,
241
+ shuffle=True,
242
+ num_workers=self.num_workers)
243
+
244
+ unlabeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="unlabeled", noise_file=self.noise_file, pred=pred, log=self.log)
245
+ unlabeled_trainloader = DataLoader(
246
+ dataset=unlabeled_dataset,
247
+ batch_size=self.batch_size,
248
+ shuffle=True,
249
+ num_workers=self.num_workers)
250
+ return labeled_trainloader, unlabeled_trainloader
251
+
252
+ elif mode=='test':
253
+ test_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='test')
254
+ test_loader = DataLoader(
255
+ dataset=test_dataset,
256
+ batch_size=self.batch_size,
257
+ shuffle=False,
258
+ num_workers=self.num_workers)
259
+ return test_loader
260
+
261
+ elif mode=='eval_train':
262
+ eval_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='all', noise_file=self.noise_file)
263
+ eval_loader = DataLoader(
264
+ dataset=eval_dataset,
265
+ batch_size=self.batch_size,
266
+ shuffle=False,
267
+ num_workers=self.num_workers)
268
+ return eval_loader
269
+ elif mode=='boost':
270
+ eval_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode=mode, noise_file=self.noise_file, clean_idx=clean_idx, test_form=self.transform_test)
271
+ eval_loader = DataLoader(
272
+ dataset=eval_dataset,
273
+ batch_size=self.batch_size,
274
+ shuffle=False,
275
+ num_workers=self.num_workers)
276
+ return eval_loader
img/framework.tif ADDED
models/CNN.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.init as init
5
+ import torch.nn.functional as F
6
+ import torch.utils.model_zoo as model_zoo
7
+
8
+ class HiddenLayer(nn.Module):
9
+ def __init__(self, input_size, output_size):
10
+ super(HiddenLayer, self).__init__()
11
+ self.fc = nn.Linear(input_size, output_size)
12
+ self.relu = nn.ReLU()
13
+
14
+ def forward(self, x):
15
+ return self.relu(self.fc(x))
16
+
17
+
18
+ class VNet(nn.Module):
19
+ def __init__(self, hidden_size=100, num_layers=1):
20
+ super(VNet, self).__init__()
21
+ self.first_hidden_layer = HiddenLayer(1, hidden_size)
22
+ self.rest_hidden_layers = nn.Sequential(*[HiddenLayer(hidden_size, hidden_size) for _ in range(num_layers - 1)])
23
+ self.output_layer = nn.Linear(hidden_size, 1)
24
+
25
+ def forward(self, x):
26
+ x = self.first_hidden_layer(x)
27
+ x = self.rest_hidden_layers(x)
28
+ x = self.output_layer(x)
29
+ return torch.sigmoid(x)
30
+
31
+
32
+ class CNN(nn.Module):
33
+ def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25):
34
+ self.dropout_rate = dropout_rate
35
+ super(CNN, self).__init__()
36
+
37
+ #block1
38
+ self.conv1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)
39
+ self.bn1=nn.BatchNorm2d(128)
40
+ self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
41
+ self.bn2=nn.BatchNorm2d(128)
42
+ self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
43
+ self.bn3=nn.BatchNorm2d(128)
44
+
45
+ #block2
46
+ self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
47
+ self.bn4=nn.BatchNorm2d(256)
48
+ self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
49
+ self.bn5=nn.BatchNorm2d(256)
50
+ self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
51
+ self.bn6=nn.BatchNorm2d(256)
52
+
53
+ #block3
54
+ self.conv7 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0)
55
+ self.bn7=nn.BatchNorm2d(512)
56
+ self.conv8 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
57
+ self.bn8=nn.BatchNorm2d(256)
58
+ self.conv9 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)
59
+ self.bn9=nn.BatchNorm2d(128)
60
+
61
+ self.pool = nn.MaxPool2d(2, 2)
62
+ self.avgpool = nn.AvgPool2d(kernel_size=2)
63
+
64
+ self.fc=nn.Linear(128,n_outputs)
65
+
66
+ for m in self.modules():
67
+ if isinstance(m, nn.Conv2d):
68
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
69
+ elif isinstance(m, nn.BatchNorm2d):
70
+ nn.init.constant_(m.weight, 1)
71
+ nn.init.constant_(m.bias, 0)
72
+
73
+ def forward(self, x):
74
+
75
+ #block1
76
+ x=F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01)
77
+ x=F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.01)
78
+ x=F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.01)
79
+ x=self.pool(x)
80
+ x=F.dropout2d(x, p=self.dropout_rate)
81
+
82
+ #block2
83
+ x=F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.01)
84
+ x=F.leaky_relu(self.bn5(self.conv5(x)), negative_slope=0.01)
85
+ x=F.leaky_relu(self.bn6(self.conv6(x)), negative_slope=0.01)
86
+ x=self.pool(x)
87
+ x=F.dropout2d(x, p=self.dropout_rate)
88
+
89
+ #block3
90
+ x=F.leaky_relu(self.bn7(self.conv7(x)), negative_slope=0.01)
91
+ x=F.leaky_relu(self.bn8(self.conv8(x)), negative_slope=0.01)
92
+ x=F.leaky_relu(self.bn9(self.conv9(x)), negative_slope=0.01)
93
+ x=self.avgpool(x)
94
+
95
+ x = x.view(x.size(0), x.size(1))
96
+ x=self.fc(x)
97
+ return x
98
+
99
+ def conv3x3(in_planes, out_planes, stride=1):
100
+ """3x3 convolution with padding"""
101
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
102
+ padding=1, bias=False)
103
+
104
+ class BasicBlock(nn.Module):
105
+ expansion = 1
106
+
107
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
108
+ super(BasicBlock, self).__init__()
109
+ self.conv1 = conv3x3(inplanes, planes, stride)
110
+ self.bn1 = nn.BatchNorm2d(planes)
111
+ self.relu = nn.ReLU(inplace=True)
112
+ self.conv2 = conv3x3(planes, planes)
113
+ self.bn2 = nn.BatchNorm2d(planes)
114
+ self.downsample = downsample
115
+ self.stride = stride
116
+
117
+ def forward(self, x):
118
+ residual = x
119
+
120
+ out = self.conv1(x)
121
+ out = self.bn1(out)
122
+ out = self.relu(out)
123
+
124
+ out = self.conv2(out)
125
+ out = self.bn2(out)
126
+
127
+ if self.downsample is not None:
128
+ residual = self.downsample(x)
129
+
130
+ out += residual
131
+ out = self.relu(out)
132
+
133
+ return out
134
+
135
+ class ResNet(nn.Module):
136
+
137
+ def __init__(self, block, layers, num_classes=14):
138
+ self.inplanes = 64
139
+ super(ResNet, self).__init__()
140
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
141
+ self.bn1 = nn.BatchNorm2d(64)
142
+ self.relu = nn.ReLU(inplace=True)
143
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
144
+ self.layer1 = self._make_layer(block, 64, layers[0])
145
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
146
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
147
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
148
+ self.avgpool = nn.AvgPool2d(7, stride=1)
149
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
150
+
151
+ for m in self.modules():
152
+ if isinstance(m, nn.Conv2d):
153
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
154
+ elif isinstance(m, nn.BatchNorm2d):
155
+ nn.init.constant_(m.weight, 1)
156
+ nn.init.constant_(m.bias, 0)
157
+
158
+ def _make_layer(self, block, planes, blocks, stride=1):
159
+ downsample = None
160
+ if stride != 1 or self.inplanes != planes * block.expansion:
161
+ downsample = nn.Sequential(
162
+ nn.Conv2d(self.inplanes, planes * block.expansion,
163
+ kernel_size=1, stride=stride, bias=False),
164
+ nn.BatchNorm2d(planes * block.expansion),
165
+ )
166
+
167
+ layers = []
168
+ layers.append(block(self.inplanes, planes, stride, downsample))
169
+ self.inplanes = planes * block.expansion
170
+ for i in range(1, blocks):
171
+ layers.append(block(self.inplanes, planes))
172
+
173
+ return nn.Sequential(*layers)
174
+
175
+ def forward(self, x):
176
+ x = self.conv1(x)
177
+ x = self.bn1(x)
178
+ x = self.relu(x)
179
+ x = self.maxpool(x)
180
+
181
+ x = self.layer1(x)
182
+ x = self.layer2(x)
183
+ x = self.layer3(x)
184
+ x = self.layer4(x)
185
+
186
+ x = self.avgpool(x)
187
+ x = x.view(x.size(0), -1)
188
+ x = self.fc(x)
189
+ return x
190
+
191
+ def resnet18(pretrained=False, **kwargs):
192
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
193
+ return model
models/InceptionResNetV2.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, division, absolute_import
2
+ import torch
3
+ import torch.nn as nn
4
+ import os
5
+ import sys
6
+
7
+
8
+ class BasicConv2d(nn.Module):
9
+
10
+ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
11
+ super(BasicConv2d, self).__init__()
12
+ self.conv = nn.Conv2d(in_planes, out_planes,
13
+ kernel_size=kernel_size, stride=stride,
14
+ padding=padding, bias=False) # verify bias false
15
+ self.bn = nn.BatchNorm2d(out_planes,
16
+ eps=0.001, # value found in tensorflow
17
+ momentum=0.1, # default pytorch value
18
+ affine=True)
19
+ self.relu = nn.ReLU(inplace=False)
20
+
21
+ def forward(self, x):
22
+ x = self.conv(x)
23
+ x = self.bn(x)
24
+ x = self.relu(x)
25
+ return x
26
+
27
+
28
+ class Mixed_5b(nn.Module):
29
+
30
+ def __init__(self):
31
+ super(Mixed_5b, self).__init__()
32
+
33
+ self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
34
+
35
+ self.branch1 = nn.Sequential(
36
+ BasicConv2d(192, 48, kernel_size=1, stride=1),
37
+ BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
38
+ )
39
+
40
+ self.branch2 = nn.Sequential(
41
+ BasicConv2d(192, 64, kernel_size=1, stride=1),
42
+ BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
43
+ BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
44
+ )
45
+
46
+ self.branch3 = nn.Sequential(
47
+ nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
48
+ BasicConv2d(192, 64, kernel_size=1, stride=1)
49
+ )
50
+
51
+ def forward(self, x):
52
+ x0 = self.branch0(x)
53
+ x1 = self.branch1(x)
54
+ x2 = self.branch2(x)
55
+ x3 = self.branch3(x)
56
+ out = torch.cat((x0, x1, x2, x3), 1)
57
+ return out
58
+
59
+
60
+ class Block35(nn.Module):
61
+
62
+ def __init__(self, scale=1.0):
63
+ super(Block35, self).__init__()
64
+
65
+ self.scale = scale
66
+
67
+ self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
68
+
69
+ self.branch1 = nn.Sequential(
70
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
71
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
72
+ )
73
+
74
+ self.branch2 = nn.Sequential(
75
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
76
+ BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
77
+ BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
78
+ )
79
+
80
+ self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
81
+ self.relu = nn.ReLU(inplace=False)
82
+
83
+ def forward(self, x):
84
+ x0 = self.branch0(x)
85
+ x1 = self.branch1(x)
86
+ x2 = self.branch2(x)
87
+ out = torch.cat((x0, x1, x2), 1)
88
+ out = self.conv2d(out)
89
+ out = out * self.scale + x
90
+ out = self.relu(out)
91
+ return out
92
+
93
+
94
+ class Mixed_6a(nn.Module):
95
+
96
+ def __init__(self):
97
+ super(Mixed_6a, self).__init__()
98
+
99
+ self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
100
+
101
+ self.branch1 = nn.Sequential(
102
+ BasicConv2d(320, 256, kernel_size=1, stride=1),
103
+ BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
104
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
105
+ )
106
+
107
+ self.branch2 = nn.MaxPool2d(3, stride=2)
108
+
109
+ def forward(self, x):
110
+ x0 = self.branch0(x)
111
+ x1 = self.branch1(x)
112
+ x2 = self.branch2(x)
113
+ out = torch.cat((x0, x1, x2), 1)
114
+ return out
115
+
116
+
117
+ class Block17(nn.Module):
118
+
119
+ def __init__(self, scale=1.0):
120
+ super(Block17, self).__init__()
121
+
122
+ self.scale = scale
123
+
124
+ self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
125
+
126
+ self.branch1 = nn.Sequential(
127
+ BasicConv2d(1088, 128, kernel_size=1, stride=1),
128
+ BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),
129
+ BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0))
130
+ )
131
+
132
+ self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
133
+ self.relu = nn.ReLU(inplace=False)
134
+
135
+ def forward(self, x):
136
+ x0 = self.branch0(x)
137
+ x1 = self.branch1(x)
138
+ out = torch.cat((x0, x1), 1)
139
+ out = self.conv2d(out)
140
+ out = out * self.scale + x
141
+ out = self.relu(out)
142
+ return out
143
+
144
+
145
+ class Mixed_7a(nn.Module):
146
+
147
+ def __init__(self):
148
+ super(Mixed_7a, self).__init__()
149
+
150
+ self.branch0 = nn.Sequential(
151
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
152
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
153
+ )
154
+
155
+ self.branch1 = nn.Sequential(
156
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
157
+ BasicConv2d(256, 288, kernel_size=3, stride=2)
158
+ )
159
+
160
+ self.branch2 = nn.Sequential(
161
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
162
+ BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
163
+ BasicConv2d(288, 320, kernel_size=3, stride=2)
164
+ )
165
+
166
+ self.branch3 = nn.MaxPool2d(3, stride=2)
167
+
168
+ def forward(self, x):
169
+ x0 = self.branch0(x)
170
+ x1 = self.branch1(x)
171
+ x2 = self.branch2(x)
172
+ x3 = self.branch3(x)
173
+ out = torch.cat((x0, x1, x2, x3), 1)
174
+ return out
175
+
176
+
177
+ class Block8(nn.Module):
178
+
179
+ def __init__(self, scale=1.0, noReLU=False):
180
+ super(Block8, self).__init__()
181
+
182
+ self.scale = scale
183
+ self.noReLU = noReLU
184
+
185
+ self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
186
+
187
+ self.branch1 = nn.Sequential(
188
+ BasicConv2d(2080, 192, kernel_size=1, stride=1),
189
+ BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),
190
+ BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0))
191
+ )
192
+
193
+ self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
194
+ if not self.noReLU:
195
+ self.relu = nn.ReLU(inplace=False)
196
+
197
+ def forward(self, x):
198
+ x0 = self.branch0(x)
199
+ x1 = self.branch1(x)
200
+ out = torch.cat((x0, x1), 1)
201
+ out = self.conv2d(out)
202
+ out = out * self.scale + x
203
+ if not self.noReLU:
204
+ out = self.relu(out)
205
+ return out
206
+
207
+
208
+ class InceptionResNetV2(nn.Module):
209
+
210
+ def __init__(self, num_classes=50):
211
+ super(InceptionResNetV2, self).__init__()
212
+ # Special attributs
213
+ self.num_classes = num_classes
214
+ self.input_space = None
215
+ self.input_size = (299, 299, 3)
216
+ self.mean = None
217
+ self.std = None
218
+ # Modules
219
+ self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
220
+ self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
221
+ self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
222
+ self.maxpool_3a = nn.MaxPool2d(3, stride=2)
223
+ self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
224
+ self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
225
+ self.maxpool_5a = nn.MaxPool2d(3, stride=2)
226
+ self.mixed_5b = Mixed_5b()
227
+ self.repeat = nn.Sequential(
228
+ Block35(scale=0.17),
229
+ Block35(scale=0.17),
230
+ Block35(scale=0.17),
231
+ Block35(scale=0.17),
232
+ Block35(scale=0.17),
233
+ Block35(scale=0.17),
234
+ Block35(scale=0.17),
235
+ Block35(scale=0.17),
236
+ Block35(scale=0.17),
237
+ Block35(scale=0.17)
238
+ )
239
+ self.mixed_6a = Mixed_6a()
240
+ self.repeat_1 = nn.Sequential(
241
+ Block17(scale=0.10),
242
+ Block17(scale=0.10),
243
+ Block17(scale=0.10),
244
+ Block17(scale=0.10),
245
+ Block17(scale=0.10),
246
+ Block17(scale=0.10),
247
+ Block17(scale=0.10),
248
+ Block17(scale=0.10),
249
+ Block17(scale=0.10),
250
+ Block17(scale=0.10),
251
+ Block17(scale=0.10),
252
+ Block17(scale=0.10),
253
+ Block17(scale=0.10),
254
+ Block17(scale=0.10),
255
+ Block17(scale=0.10),
256
+ Block17(scale=0.10),
257
+ Block17(scale=0.10),
258
+ Block17(scale=0.10),
259
+ Block17(scale=0.10),
260
+ Block17(scale=0.10)
261
+ )
262
+ self.mixed_7a = Mixed_7a()
263
+ self.repeat_2 = nn.Sequential(
264
+ Block8(scale=0.20),
265
+ Block8(scale=0.20),
266
+ Block8(scale=0.20),
267
+ Block8(scale=0.20),
268
+ Block8(scale=0.20),
269
+ Block8(scale=0.20),
270
+ Block8(scale=0.20),
271
+ Block8(scale=0.20),
272
+ Block8(scale=0.20)
273
+ )
274
+ self.block8 = Block8(noReLU=True)
275
+ self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
276
+ self.avgpool_1a = nn.AdaptiveAvgPool2d((1, 1))#nn.AvgPool2d(8, count_include_pad=False)
277
+ self.last_linear = nn.Linear(1536, num_classes)
278
+
279
+ self.branch = self._make_branch(320, 1536, 3)
280
+ self.branch1 = self._make_branch(1088, 1536, 3)
281
+ self.branch2 = self._make_branch(2080, 1536, 3)
282
+
283
+ def _make_branch(self, channel_in, channel_out, kernel_size):
284
+ middle_channel = channel_out // 4
285
+ return nn.Sequential(
286
+ nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
287
+ nn.BatchNorm2d(middle_channel),
288
+ nn.ReLU(),
289
+
290
+ nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
291
+ nn.BatchNorm2d(middle_channel),
292
+ nn.ReLU(),
293
+
294
+ nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
295
+ nn.BatchNorm2d(channel_out),
296
+ nn.ReLU(),
297
+
298
+ nn.AdaptiveAvgPool2d((1,1)),
299
+ nn.Flatten(),
300
+ nn.Linear(channel_out, self.num_classes)
301
+ )
302
+
303
+
304
+ def features(self, input):
305
+ x = self.conv2d_1a(input)
306
+ x = self.conv2d_2a(x)
307
+ x = self.conv2d_2b(x)
308
+ x = self.maxpool_3a(x)
309
+ x = self.conv2d_3b(x)
310
+ x = self.conv2d_4a(x)
311
+ x = self.maxpool_5a(x)
312
+ x = self.mixed_5b(x)
313
+ x = self.repeat(x)
314
+ x1 = self.branch(x)
315
+
316
+ x = self.mixed_6a(x)
317
+ x = self.repeat_1(x)
318
+ x2 = self.branch1(x)
319
+
320
+ x = self.mixed_7a(x)
321
+ x = self.repeat_2(x)
322
+ x3 = self.branch2(x)
323
+
324
+ x = self.block8(x)
325
+ x = self.conv2d_7b(x)
326
+ return x, x1, x2, x3
327
+
328
+ def logits(self, features):
329
+ x = self.avgpool_1a(features)
330
+ x = x.view(x.size(0), -1)
331
+ out = self.last_linear(x)
332
+ return out
333
+
334
+
335
+ def forward(self, input):
336
+ x, x1, x2, x3, = self.features(input)
337
+ out = self.logits(x)
338
+ return {'outputs': [out, x1, x2, x3]}
339
+
340
+
341
+ def test():
342
+ net = InceptionResNetV2().cuda()
343
+ y = net(torch.randn(1,3,227,227).cuda())
344
+ print(y.size())
345
+ #test()
models/ResNet_Imagenet.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ try:
5
+ from torch.hub import load_state_dict_from_url
6
+ except ImportError:
7
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
8
+
9
+ model_urls = {
10
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
17
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
18
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
19
+ }
20
+
21
+
22
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
23
+ """3x3 convolution with padding"""
24
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
26
+
27
+
28
+ def conv1x1(in_planes, out_planes, stride=1):
29
+ """1x1 convolution"""
30
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
31
+
32
+
33
+ def branchBottleNeck(channel_in, channel_out, kernel_size):
34
+ middle_channel = channel_out//4
35
+ return nn.Sequential(
36
+ nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
37
+ nn.BatchNorm2d(middle_channel),
38
+ nn.ReLU(),
39
+
40
+ nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
41
+ nn.BatchNorm2d(middle_channel),
42
+ nn.ReLU(),
43
+
44
+ nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
45
+ nn.BatchNorm2d(channel_out),
46
+ nn.ReLU(),
47
+ )
48
+
49
+ class LambdaLayer(nn.Module):
50
+ def __init__(self, lambd):
51
+ super(LambdaLayer, self).__init__()
52
+ self.lambd = lambd
53
+
54
+ def forward(self, x):
55
+ return self.lambd(x)
56
+
57
+ class BasicBlock(nn.Module):
58
+ expansion = 1
59
+
60
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
61
+ base_width=64, dilation=1, norm_layer=None):
62
+ super(BasicBlock, self).__init__()
63
+ if norm_layer is None:
64
+ norm_layer = nn.BatchNorm2d
65
+ if groups != 1 or base_width != 64:
66
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
67
+ if dilation > 1:
68
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
69
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
70
+ self.conv1 = conv3x3(inplanes, planes, stride)
71
+ self.bn1 = norm_layer(planes)
72
+ self.relu = nn.ReLU(inplace=True)
73
+ self.conv2 = conv3x3(planes, planes)
74
+ self.bn2 = norm_layer(planes)
75
+ self.downsample = downsample
76
+ self.stride = stride
77
+
78
+ def forward(self, x):
79
+ identity = x
80
+
81
+ out = self.conv1(x)
82
+ out = self.bn1(out)
83
+ out = self.relu(out)
84
+
85
+ out = self.conv2(out)
86
+ out = self.bn2(out)
87
+
88
+ if self.downsample is not None:
89
+ identity = self.downsample(x)
90
+
91
+ out += identity
92
+ out = self.relu(out)
93
+
94
+ return out
95
+
96
+
97
+ class Bottleneck(nn.Module):
98
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
99
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
100
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
101
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
102
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
103
+
104
+ expansion = 4
105
+
106
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
107
+ base_width=64, dilation=1, norm_layer=None):
108
+ super(Bottleneck, self).__init__()
109
+ if norm_layer is None:
110
+ norm_layer = nn.BatchNorm2d
111
+ width = int(planes * (base_width / 64.)) * groups
112
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
113
+ self.conv1 = conv1x1(inplanes, width)
114
+ self.bn1 = norm_layer(width)
115
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
116
+ self.bn2 = norm_layer(width)
117
+ self.conv3 = conv1x1(width, planes * self.expansion)
118
+ self.bn3 = norm_layer(planes * self.expansion)
119
+ self.relu = nn.ReLU(inplace=True)
120
+ self.downsample = downsample
121
+ self.stride = stride
122
+
123
+ def forward(self, x):
124
+ identity = x
125
+
126
+ out = self.conv1(x)
127
+ out = self.bn1(out)
128
+ out = self.relu(out)
129
+
130
+ out = self.conv2(out)
131
+ out = self.bn2(out)
132
+ out = self.relu(out)
133
+
134
+ out = self.conv3(out)
135
+ out = self.bn3(out)
136
+
137
+ if self.downsample is not None:
138
+ identity = self.downsample(x)
139
+
140
+ out += identity
141
+ out = self.relu(out)
142
+
143
+ return out
144
+
145
+
146
+ class ResNet(nn.Module):
147
+
148
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
149
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
150
+ norm_layer=None):
151
+ super(ResNet, self).__init__()
152
+ if norm_layer is None:
153
+ norm_layer = nn.BatchNorm2d
154
+ self._norm_layer = norm_layer
155
+ self.num_classes = num_classes
156
+
157
+ self.inplanes = 64
158
+ self.dilation = 1
159
+ if replace_stride_with_dilation is None:
160
+ # each element in the tuple indicates if we should replace
161
+ # the 2x2 stride with a dilated convolution instead
162
+ replace_stride_with_dilation = [False, False, False]
163
+ if len(replace_stride_with_dilation) != 3:
164
+ raise ValueError("replace_stride_with_dilation should be None "
165
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
166
+ self.groups = groups
167
+ self.base_width = width_per_group
168
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
169
+ bias=False)
170
+ self.bn1 = norm_layer(self.inplanes)
171
+ self.relu = nn.ReLU(inplace=True)
172
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
173
+ self.layer1 = self._make_layer(block, 64, layers[0])
174
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
175
+ dilate=replace_stride_with_dilation[0])
176
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
177
+ dilate=replace_stride_with_dilation[1])
178
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
179
+ dilate=replace_stride_with_dilation[2])
180
+
181
+ self.branch1 = self._make_branch(64*block.expansion, 512*block.expansion, kernel_size=8)
182
+ self.branch2 = self._make_branch(128*block.expansion, 512*block.expansion, kernel_size=4)
183
+ self.branch3 = self._make_branch(256*block.expansion, 512*block.expansion, kernel_size=2)
184
+ self.branch4 = self._make_branch(512*block.expansion, 512*block.expansion, kernel_size=1)
185
+
186
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
187
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
188
+
189
+ for m in self.modules():
190
+ if isinstance(m, nn.Conv2d):
191
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
192
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
193
+ nn.init.constant_(m.weight, 1)
194
+ nn.init.constant_(m.bias, 0)
195
+
196
+ # Zero-initialize the last BN in each residual branch,
197
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
198
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
199
+ if zero_init_residual:
200
+ for m in self.modules():
201
+ if isinstance(m, Bottleneck):
202
+ nn.init.constant_(m.bn3.weight, 0)
203
+ elif isinstance(m, BasicBlock):
204
+ nn.init.constant_(m.bn2.weight, 0)
205
+
206
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
207
+ norm_layer = self._norm_layer
208
+ downsample = None
209
+ previous_dilation = self.dilation
210
+ if dilate:
211
+ self.dilation *= stride
212
+ stride = 1
213
+ if stride != 1 or self.inplanes != planes * block.expansion:
214
+ downsample = nn.Sequential(
215
+ conv1x1(self.inplanes, planes * block.expansion, stride),
216
+ norm_layer(planes * block.expansion),
217
+ )
218
+
219
+ layers = []
220
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
221
+ self.base_width, previous_dilation, norm_layer))
222
+ self.inplanes = planes * block.expansion
223
+ for _ in range(1, blocks):
224
+ layers.append(block(self.inplanes, planes, groups=self.groups,
225
+ base_width=self.base_width, dilation=self.dilation,
226
+ norm_layer=norm_layer))
227
+
228
+ return nn.Sequential(*layers)
229
+
230
+ def _make_branch(self, channel_in, channel_out, kernel_size):
231
+ middle_channel = channel_out // 4
232
+ return nn.Sequential(
233
+ nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
234
+ nn.BatchNorm2d(middle_channel),
235
+ nn.ReLU(),
236
+
237
+ nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
238
+ nn.BatchNorm2d(middle_channel),
239
+ nn.ReLU(),
240
+
241
+ nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
242
+ nn.BatchNorm2d(channel_out),
243
+ nn.ReLU(),
244
+
245
+ nn.AdaptiveAvgPool2d((1,1)),
246
+ nn.Flatten(),
247
+ nn.Linear(channel_out, self.num_classes)
248
+ )
249
+
250
+ def _forward_impl(self, x):
251
+ # See note [TorchScript super()]
252
+ x = self.conv1(x)
253
+ x = self.bn1(x)
254
+ x = self.relu(x)
255
+ x = self.maxpool(x)
256
+
257
+ x = self.layer1(x)
258
+ x1 = self.branch1(x)
259
+
260
+ x = self.layer2(x)
261
+ x2 = self.branch2(x)
262
+
263
+ x = self.layer3(x)
264
+ x3 = self.branch3(x)
265
+
266
+ x = self.layer4(x)
267
+ x = self.avgpool(x)
268
+ final_fea = x
269
+ x = torch.flatten(x, 1)
270
+ x = self.fc(x)
271
+
272
+ return {'outputs': [x, x1, x2, x3]}
273
+
274
+ def forward(self, x):
275
+ return self._forward_impl(x)
276
+
277
+ def sdresnet50(num_classes=14, pretrained=True):
278
+ if pretrained:
279
+ model = ResNet(Bottleneck, [3,4,6,3], num_classes=14)
280
+ num_ftrs = model.fc.in_features
281
+ model.fc = nn.Linear(num_ftrs, 1000)
282
+ state_dict = load_state_dict_from_url(model_urls['resnet50'], progress=True)
283
+ model.load_state_dict(state_dict, strict=False)
284
+
285
+ num_ftrs = model.fc.in_features
286
+ model.fc = nn.Linear(num_ftrs, num_classes)
287
+ else:
288
+ model = ResNet(Bottleneck, [3,4,6,3], num_classes=50)
289
+ return model
models/ResNet_cifar.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.nn.init as init
5
+ from torch.nn import Parameter
6
+
7
+ model_urls = {
8
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
9
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
10
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
11
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
12
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
13
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
14
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
15
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
16
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
17
+ }
18
+
19
+ def conv3x3(in_planes, out_planes, stride=1):
20
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
21
+ stride=stride, padding=1, bias=False)
22
+
23
+ def conv1x1(in_planes, planes, stride=1):
24
+ return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
25
+
26
+ def branchBottleNeck(channel_in, channel_out, kernel_size):
27
+ middle_channel = channel_out//4
28
+ return nn.Sequential(
29
+ nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
30
+ nn.BatchNorm2d(middle_channel),
31
+ nn.ReLU(),
32
+
33
+ nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
34
+ nn.BatchNorm2d(middle_channel),
35
+ nn.ReLU(),
36
+
37
+ nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
38
+ nn.BatchNorm2d(channel_out),
39
+ nn.ReLU(),
40
+ )
41
+
42
+ def branchMLP(channel_in, channel_out):
43
+ middle_channel = channel_out//4
44
+ return nn.Sequential(
45
+ conv1x1(channel_in, channel_in, stride=8),
46
+ nn.BatchNorm2d(512 * block.expansion),
47
+ nn.ReLU(),
48
+ )
49
+
50
+ def invertedBottleNeck(channel_in, channel_out, kernel_size):
51
+ middle_channel = channel_out * 2
52
+ return nn.Sequential(
53
+ nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
54
+ nn.BatchNorm2d(middle_channel),
55
+ nn.ReLU(),
56
+
57
+ nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
58
+ nn.BatchNorm2d(middle_channel),
59
+ nn.ReLU(),
60
+
61
+ nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
62
+ nn.BatchNorm2d(channel_out),
63
+ nn.ReLU(),
64
+ )
65
+
66
+ class BatchNorm2dMul(nn.Module):
67
+ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
68
+ super(BatchNorm2dMul, self).__init__()
69
+ self.bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=False, track_running_stats=track_running_stats)
70
+ self.gamma = nn.Parameter(torch.ones(num_features))
71
+ self.beta = nn.Parameter(torch.zeros(num_features))
72
+ self.affine = affine
73
+
74
+ def forward(self, x):
75
+ bn_out = self.bn(x)
76
+ if self.affine:
77
+ out = self.gamma[None, :, None, None] * bn_out + self.beta[None, :, None, None]
78
+ return out, bn_out
79
+
80
+ def _weights_init(m):
81
+ classname = m.__class__.__name__
82
+ if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
83
+ init.kaiming_normal_(m.weight)
84
+
85
+ class LambdaLayer(nn.Module):
86
+ def __init__(self, lambd):
87
+ super(LambdaLayer, self).__init__()
88
+ self.lambd = lambd
89
+
90
+ def forward(self, x):
91
+ return self.lambd(x)
92
+
93
+ class BasicBlock_s(nn.Module):
94
+ expansion = 1
95
+
96
+ def __init__(self, in_planes, planes, stride=1):
97
+ super(BasicBlock_s, self).__init__()
98
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
99
+ self.bn1 = nn.BatchNorm2d(planes)
100
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
101
+ self.bn2 = nn.BatchNorm2d(planes)
102
+
103
+ self.shortcut = nn.Sequential()
104
+ if stride != 1 or in_planes != self.expansion*planes:
105
+ self.shortcut = nn.Sequential(
106
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
107
+ nn.BatchNorm2d(self.expansion*planes)
108
+ )
109
+
110
+ def forward(self, x):
111
+ out = F.relu(self.bn1(self.conv1(x)))
112
+ out = self.bn2(self.conv2(out))
113
+ out += self.shortcut(x)
114
+ out = F.relu(out)
115
+ return out
116
+
117
+ class BasicBlock(nn.Module):
118
+ expansion = 1
119
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
120
+ super(BasicBlock, self).__init__()
121
+ self.conv1 = conv3x3(inplanes, planes, stride)
122
+ self.bn1 = BatchNorm2dMul(planes)
123
+ self.relu = nn.ReLU(inplace=True)
124
+ self.conv2 = conv3x3(planes, planes)
125
+ self.bn2 = BatchNorm2dMul(planes)
126
+ self.downsample = downsample
127
+ self.stride = stride
128
+
129
+ def forward(self, x):
130
+ bn_outputs = []
131
+
132
+ residual = x
133
+ output = self.conv1(x)
134
+ output, bn_out = self.bn1(output)
135
+ bn_outputs.append(bn_out)
136
+ output = self.relu(output)
137
+
138
+ output = self.conv2(output)
139
+ output, bn_out = self.bn2(output)
140
+ bn_outputs.append(bn_out)
141
+
142
+ if self.downsample is not None:
143
+ residual = self.downsample(x)
144
+
145
+ output += residual
146
+ output = self.relu(output)
147
+ return output, bn_outputs
148
+
149
+ class BottleneckBlock(nn.Module):
150
+ expansion = 4
151
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
152
+ super(BottleneckBlock, self).__init__()
153
+ self.conv1 = conv1x1(inplanes, planes)
154
+ self.bn1 = nn.BatchNorm2d(planes)
155
+ self.relu = nn.ReLU(inplace=True)
156
+
157
+ self.conv2 = conv3x3(planes, planes, stride)
158
+ self.bn2 = nn.BatchNorm2d(planes)
159
+
160
+ self.conv3 = conv1x1(planes, planes*self.expansion)
161
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
162
+
163
+ self.downsample = downsample
164
+ self.stride = stride
165
+
166
+ def forward(self, x):
167
+ residual = x
168
+
169
+ output = self.conv1(x)
170
+ output = self.bn1(output)
171
+ output = self.relu(output)
172
+
173
+ output = self.conv2(output)
174
+ output = self.bn2(output)
175
+ output = self.relu(output)
176
+
177
+ output = self.conv3(output)
178
+ output = self.bn3(output)
179
+
180
+ if self.downsample is not None:
181
+ residual = self.downsample(x)
182
+
183
+ output += residual
184
+ output = self.relu(output)
185
+
186
+ return output
187
+
188
+ class LayerBlock(nn.Module):
189
+ def __init__(self, block, inplanes, planes, num_blocks, stride):
190
+ super(LayerBlock, self).__init__()
191
+ downsample = None
192
+ if stride !=1 or inplanes != planes * block.expansion:
193
+ downsample = nn.Sequential(
194
+ conv1x1(inplanes, planes * block.expansion, stride),
195
+ nn.BatchNorm2d(planes * block.expansion),
196
+ )
197
+ layer = []
198
+ layer.append(block(inplanes, planes, stride=stride, downsample=downsample))
199
+ inplanes = planes * block.expansion
200
+ for i in range(1, num_blocks):
201
+ layer.append(block(inplanes, planes))
202
+ self.layers = nn.Sequential(*layer)
203
+
204
+ def forward(self, x):
205
+ bn_outputs = []
206
+ for layer in self.layers:
207
+ x, bn_output = layer(x)
208
+ bn_outputs.extend(bn_output)
209
+ return x, bn_outputs
210
+
211
+ class SDResNet(nn.Module):
212
+ """
213
+ Resnet model
214
+
215
+ Args:
216
+ block (class): block type, BasicBlock or BottlenetckBlock
217
+ layers (int list): layer num in each block
218
+ num_classes (int): class num
219
+ """
220
+
221
+ def __init__(self, block, layers, num_classes=10, position_all=True):
222
+ super(SDResNet, self).__init__()
223
+
224
+ self.position_all = position_all
225
+
226
+ self.inplanes = 64
227
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
228
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
229
+ self.relu = nn.ReLU(inplace=True)
230
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
231
+
232
+ self.layer1 = LayerBlock(block, 64, 64, layers[0], stride=1)
233
+ self.layer2 = LayerBlock(block, 64, 128, layers[1], stride=2)
234
+ self.layer3 = LayerBlock(block, 128, 256, layers[2], stride=2)
235
+ self.layer4 = LayerBlock(block, 256, 512, layers[3], stride=2)
236
+
237
+ self.downsample1_1 = nn.Sequential(
238
+ conv1x1(64 * block.expansion, 512 * block.expansion, stride=8),
239
+ nn.BatchNorm2d(512 * block.expansion),
240
+ nn.ReLU(),
241
+ )
242
+ self.bottleneck1_1 = branchBottleNeck(64 * block.expansion, 512 * block.expansion, kernel_size=8)
243
+ self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
244
+ self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)
245
+
246
+
247
+ self.downsample2_1 = nn.Sequential(
248
+ conv1x1(128 * block.expansion, 512 * block.expansion, stride=4),
249
+ nn.BatchNorm2d(512 * block.expansion),
250
+ )
251
+ self.bottleneck2_1 = branchBottleNeck(128 * block.expansion, 512 * block.expansion, kernel_size=4)
252
+ self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
253
+ self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)
254
+
255
+
256
+ self.downsample3_1 = nn.Sequential(
257
+ conv1x1(256 * block.expansion, 512 * block.expansion, stride=2),
258
+ nn.BatchNorm2d(512 * block.expansion),
259
+ )
260
+ self.bottleneck3_1 = branchBottleNeck(256 * block.expansion, 512 * block.expansion, kernel_size=2)
261
+ self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
262
+ self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)
263
+
264
+ self.avgpool = nn.AdaptiveAvgPool2d((1,1))
265
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
266
+
267
+ self.apply(_weights_init)
268
+
269
+ def _make_layer(self, block, planes, layers, stride=1):
270
+ """A block with 'layers' layers
271
+ Args:
272
+ block (class): block type
273
+ planes (int): output channels = planes * expansion
274
+ layers (int): layer num in the block
275
+ stride (int): the first layer stride in the block
276
+ """
277
+ downsample = None
278
+ if stride !=1 or self.inplanes != planes * block.expansion:
279
+ downsample = nn.Sequential(
280
+ conv1x1(self.inplanes, planes * block.expansion, stride),
281
+ nn.BatchNorm2d(planes * block.expansion),
282
+ )
283
+ layer = []
284
+ layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
285
+ self.inplanes = planes * block.expansion
286
+ for i in range(1, layers):
287
+ layer.append(block(self.inplanes, planes))
288
+
289
+ return nn.Sequential(*layer)
290
+
291
+ def forward(self, x, feat_out=False):
292
+ all_bn_outputs = []
293
+
294
+ x = self.conv1(x)
295
+ x = self.bn1(x)
296
+ x = self.relu(x)
297
+ # x = self.maxpool(x)
298
+
299
+ x, bn_outputs = self.layer1(x)
300
+ all_bn_outputs.extend(bn_outputs)
301
+ middle_output1 = self.bottleneck1_1(x)
302
+ middle_output1 = self.avgpool1(middle_output1)
303
+ middle1_fea = middle_output1
304
+ middle_output1 = torch.flatten(middle_output1, 1)
305
+ middle_output1 = self.middle_fc1(middle_output1)
306
+
307
+ x, bn_outputs = self.layer2(x)
308
+ all_bn_outputs.extend(bn_outputs)
309
+ middle_output2 = self.bottleneck2_1(x)
310
+ middle_output2 = self.avgpool2(middle_output2)
311
+ middle2_fea = middle_output2
312
+ middle_output2 = torch.flatten(middle_output2, 1)
313
+ middle_output2 = self.middle_fc2(middle_output2)
314
+
315
+ x, bn_outputs = self.layer3(x)
316
+ all_bn_outputs.extend(bn_outputs)
317
+ middle_output3 = self.bottleneck3_1(x)
318
+ middle_output3 = self.avgpool3(middle_output3)
319
+ middle3_fea = middle_output3
320
+ middle_output3 = torch.flatten(middle_output3, 1)
321
+ middle_output3 = self.middle_fc3(middle_output3)
322
+
323
+ x, bn_outputs = self.layer4(x)
324
+ all_bn_outputs.extend(bn_outputs)
325
+ x = self.avgpool(x)
326
+ final_fea = x
327
+ x = torch.flatten(x, 1)
328
+ x = self.fc(x)
329
+
330
+ if self.position_all and feat_out:
331
+ return {'outputs': [x, middle_output1, middle_output2, middle_output3],
332
+ 'features': [final_fea, middle1_fea, middle2_fea, middle3_fea],
333
+ 'bn_outputs': all_bn_outputs}
334
+ else:
335
+ return x
336
+
337
+ class SDResNet_mlp(nn.Module):
338
+ """
339
+ Resnet model
340
+
341
+ Args:
342
+ block (class): block type, BasicBlock or BottlenetckBlock
343
+ layers (int list): layer num in each block
344
+ num_classes (int): class num
345
+ """
346
+
347
+ def __init__(self, block, layers, num_classes=10, position_all=True):
348
+ super(SDResNet_mlp, self).__init__()
349
+
350
+ self.position_all = position_all
351
+
352
+ self.inplanes = 64
353
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
354
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
355
+ self.relu = nn.ReLU(inplace=True)
356
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
357
+
358
+ self.layer1 = LayerBlock(block, 64, 64, layers[0], stride=1)
359
+ self.layer2 = LayerBlock(block, 64, 128, layers[1], stride=2)
360
+ self.layer3 = LayerBlock(block, 128, 256, layers[2], stride=2)
361
+ self.layer4 = LayerBlock(block, 256, 512, layers[3], stride=2)
362
+
363
+ self.downsample1_1 = nn.Sequential(
364
+ conv1x1(64 * block.expansion, 512 * block.expansion),
365
+ nn.BatchNorm2d(512 * block.expansion),
366
+ nn.ReLU(),
367
+ )
368
+ self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
369
+ self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)
370
+
371
+
372
+ self.downsample2_1 = nn.Sequential(
373
+ conv1x1(128 * block.expansion, 512 * block.expansion),
374
+ nn.BatchNorm2d(512 * block.expansion),
375
+ nn.ReLU()
376
+ )
377
+ self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
378
+ self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)
379
+
380
+
381
+ self.downsample3_1 = nn.Sequential(
382
+ conv1x1(256 * block.expansion, 512 * block.expansion),
383
+ nn.BatchNorm2d(512 * block.expansion),
384
+ nn.ReLU()
385
+ )
386
+ self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
387
+ self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)
388
+
389
+ self.avgpool = nn.AdaptiveAvgPool2d((1,1))
390
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
391
+
392
+ self.apply(_weights_init)
393
+
394
+ def _make_layer(self, block, planes, layers, stride=1):
395
+ """A block with 'layers' layers
396
+ Args:
397
+ block (class): block type
398
+ planes (int): output channels = planes * expansion
399
+ layers (int): layer num in the block
400
+ stride (int): the first layer stride in the block
401
+ """
402
+ downsample = None
403
+ if stride !=1 or self.inplanes != planes * block.expansion:
404
+ downsample = nn.Sequential(
405
+ conv1x1(self.inplanes, planes * block.expansion, stride),
406
+ nn.BatchNorm2d(planes * block.expansion),
407
+ )
408
+ layer = []
409
+ layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
410
+ self.inplanes = planes * block.expansion
411
+ for i in range(1, layers):
412
+ layer.append(block(self.inplanes, planes))
413
+
414
+ return nn.Sequential(*layer)
415
+
416
+ def forward(self, x):
417
+ all_bn_outputs = []
418
+
419
+ x = self.conv1(x)
420
+ x = self.bn1(x)
421
+ x = self.relu(x)
422
+
423
+ x, bn_outputs = self.layer1(x)
424
+ all_bn_outputs.extend(bn_outputs)
425
+ # middle_output1 = self.downsample1_1(x)
426
+ # middle_output1 = self.avgpool1(middle_output1)
427
+ # middle1_fea = middle_output1
428
+ # middle_output1 = torch.flatten(middle_output1, 1)
429
+ # middle_output1 = self.middle_fc1(middle_output1)
430
+
431
+ x, bn_outputs = self.layer2(x)
432
+ all_bn_outputs.extend(bn_outputs)
433
+ # middle_output2 = self.downsample2_1(x)
434
+ # middle_output2 = self.avgpool2(middle_output2)
435
+ # middle2_fea = middle_output2
436
+ # middle_output2 = torch.flatten(middle_output2, 1)
437
+ # middle_output2 = self.middle_fc2(middle_output2)
438
+
439
+ x, bn_outputs = self.layer3(x)
440
+ all_bn_outputs.extend(bn_outputs)
441
+ # middle_output3 = self.downsample3_1(x)
442
+ # middle_output3 = self.avgpool3(middle_output3)
443
+ # middle3_fea = middle_output3
444
+ # middle_output3 = torch.flatten(middle_output3, 1)
445
+ # middle_output3 = self.middle_fc3(middle_output3)
446
+
447
+ x, bn_outputs = self.layer4(x)
448
+ all_bn_outputs.extend(bn_outputs)
449
+ x = self.avgpool(x)
450
+ final_fea = x
451
+ x = torch.flatten(x, 1)
452
+ x = self.fc(x)
453
+
454
+ if self.position_all:
455
+ return {'outputs': [x, middle_output1, middle_output2, middle_output3],
456
+ 'bn_outputs': all_bn_outputs}
457
+ else:
458
+ return {'outputs': [x, x],
459
+ 'bn_outputs': all_bn_outputs}
460
+
461
+ class SDResNet_residual(nn.Module):
462
+ """
463
+ Resnet model
464
+
465
+ Args:
466
+ block (class): block type, BasicBlock or BottlenetckBlock
467
+ layers (int list): layer num in each block
468
+ num_classes (int): class num
469
+ """
470
+
471
+ def __init__(self, block, layers, num_classes=10, position_all=True):
472
+ super(SDResNet_residual, self).__init__()
473
+
474
+ self.position_all = position_all
475
+
476
+ self.inplanes = 64
477
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
478
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
479
+ self.relu = nn.ReLU(inplace=True)
480
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
481
+
482
+ self.layer1 = LayerBlock(block, 64, 64, layers[0], stride=1)
483
+ self.layer2 = LayerBlock(block, 64, 128, layers[1], stride=2)
484
+ self.layer3 = LayerBlock(block, 128, 256, layers[2], stride=2)
485
+ self.layer4 = LayerBlock(block, 256, 512, layers[3], stride=2)
486
+
487
+ self.bottleneck1_1 = LayerBlock(block, 64, 512, 1, stride=8)
488
+ # branchBottleNeck(64 * block.expansion, 512 * block.expansion, kernel_size=8)
489
+ self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
490
+ self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)
491
+
492
+ self.bottleneck2_1 = LayerBlock(block, 128, 512, 1, stride=4)
493
+ # branchBottleNeck(128 * block.expansion, 512 * block.expansion, kernel_size=4)
494
+ self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
495
+ self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)
496
+
497
+
498
+ # self.downsample3_1 = nn.Sequential(
499
+ # conv1x1(256 * block.expansion, 512 * block.expansion, stride=2),
500
+ # nn.BatchNorm2d(512 * block.expansion),
501
+ # )
502
+ self.bottleneck3_1 = LayerBlock(block, 256, 512, 1, stride=2)
503
+ self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
504
+ self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)
505
+
506
+ self.avgpool = nn.AdaptiveAvgPool2d((1,1))
507
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
508
+
509
+ self.apply(_weights_init)
510
+
511
+ def _make_layer(self, block, planes, layers, stride=1):
512
+ """A block with 'layers' layers
513
+ Args:
514
+ block (class): block type
515
+ planes (int): output channels = planes * expansion
516
+ layers (int): layer num in the block
517
+ stride (int): the first layer stride in the block
518
+ """
519
+ downsample = None
520
+ if stride !=1 or self.inplanes != planes * block.expansion:
521
+ downsample = nn.Sequential(
522
+ conv1x1(self.inplanes, planes * block.expansion, stride),
523
+ nn.BatchNorm2d(planes * block.expansion),
524
+ )
525
+ layer = []
526
+ layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
527
+ self.inplanes = planes * block.expansion
528
+ for i in range(1, layers):
529
+ layer.append(block(self.inplanes, planes))
530
+
531
+ return nn.Sequential(*layer)
532
+
533
+ def forward(self, x):
534
+ all_bn_outputs = []
535
+
536
+ x = self.conv1(x)
537
+ x = self.bn1(x)
538
+ x = self.relu(x)
539
+ # x = self.maxpool(x)
540
+
541
+ x, bn_outputs = self.layer1(x)
542
+ all_bn_outputs.extend(bn_outputs)
543
+ middle_output1, _ = self.bottleneck1_1(x)
544
+ middle_output1 = self.avgpool1(middle_output1)
545
+ middle1_fea = middle_output1
546
+ middle_output1 = torch.flatten(middle_output1, 1)
547
+ middle_output1 = self.middle_fc1(middle_output1)
548
+
549
+ x, bn_outputs = self.layer2(x)
550
+ all_bn_outputs.extend(bn_outputs)
551
+ middle_output2, _ = self.bottleneck2_1(x)
552
+ middle_output2 = self.avgpool2(middle_output2)
553
+ middle2_fea = middle_output2
554
+ middle_output2 = torch.flatten(middle_output2, 1)
555
+ middle_output2 = self.middle_fc2(middle_output2)
556
+
557
+ x, bn_outputs = self.layer3(x)
558
+ all_bn_outputs.extend(bn_outputs)
559
+ middle_output3, _ = self.bottleneck3_1(x)
560
+ middle_output3 = self.avgpool3(middle_output3)
561
+ middle3_fea = middle_output3
562
+ middle_output3 = torch.flatten(middle_output3, 1)
563
+ middle_output3 = self.middle_fc3(middle_output3)
564
+
565
+ x, bn_outputs = self.layer4(x)
566
+ all_bn_outputs.extend(bn_outputs)
567
+ x = self.avgpool(x)
568
+ final_fea = x
569
+ x = torch.flatten(x, 1)
570
+ x = self.fc(x)
571
+
572
+ if self.position_all:
573
+ return {'outputs': [x, middle_output1, middle_output2, middle_output3],
574
+ 'features': [final_fea, middle1_fea, middle2_fea, middle3_fea],
575
+ 'bn_outputs': all_bn_outputs}
576
+ else:
577
+ return {'outputs': [x, middle_output3],
578
+ 'features': [final_fea, middle1_fea, middle2_fea, middle3_fea],
579
+ 'bn_outputs': all_bn_outputs}
580
+
581
+ class SDResNet_s(nn.Module):
582
+ """
583
+ Resnet model small
584
+
585
+ Args:
586
+ block (class): block type, BasicBlock or BottlenetckBlock
587
+ layers (int list): layer num in each block
588
+ num_classes (int): class num
589
+ """
590
+
591
+ def __init__(self, block, layers, num_classes=10):
592
+ super(SDResNet_s, self).__init__()
593
+
594
+ self.inplanes = 16
595
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
596
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
597
+ self.relu = nn.ReLU(inplace=True)
598
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
599
+
600
+ self.layer1 = self._make_layer(block, 16, layers[0])
601
+ self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
602
+ self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
603
+
604
+ self.downsample1_1 = nn.Sequential(
605
+ conv1x1(16 * block.expansion, 64 * block.expansion, stride=4),
606
+ nn.BatchNorm2d(64 * block.expansion),
607
+ )
608
+ self.bottleneck1_1 = branchBottleNeck(16 * block.expansion, 64 * block.expansion, kernel_size=4)
609
+ self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
610
+ self.middle_fc1 = nn.Linear(64 * block.expansion, num_classes)
611
+
612
+
613
+ self.downsample2_1 = nn.Sequential(
614
+ conv1x1(32 * block.expansion, 64 * block.expansion, stride=2),
615
+ nn.BatchNorm2d(64 * block.expansion),
616
+ )
617
+ self.bottleneck2_1 = branchBottleNeck(32 * block.expansion, 64 * block.expansion, kernel_size=2)
618
+ self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
619
+ self.middle_fc2 = nn.Linear(64 * block.expansion, num_classes)
620
+
621
+ self.avgpool = nn.AdaptiveAvgPool2d((1,1))
622
+ self.fc = nn.Linear(64 * block.expansion, num_classes)
623
+
624
+ for m in self.modules():
625
+ if isinstance(m, nn.Conv2d):
626
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
627
+ elif isinstance(m, nn.BatchNorm2d):
628
+ nn.init.constant_(m.weight, 1)
629
+ nn.init.constant_(m.bias, 0)
630
+
631
+ def _make_layer(self, block, planes, layers, stride=1):
632
+ """A block with 'layers' layers
633
+ Args:
634
+ block (class): block type
635
+ planes (int): output channels = planes * expansion
636
+ layers (int): layer num in the block
637
+ stride (int): the first layer stride in the block
638
+ """
639
+ strides = [stride] + [1]*(layers-1)
640
+ layers = []
641
+ for stride in strides:
642
+ layers.append(block(self.inplanes, planes, stride))
643
+ self.inplanes = planes * block.expansion
644
+
645
+ return nn.Sequential(*layers)
646
+
647
+ def forward(self, x):
648
+ x = self.conv1(x)
649
+ x = self.bn1(x)
650
+ x = self.relu(x)
651
+
652
+ x = self.layer1(x)
653
+ middle_output1 = self.bottleneck1_1(x)
654
+ middle_output1 = self.avgpool1(middle_output1)
655
+ middle1_fea = middle_output1
656
+ middle_output1 = torch.flatten(middle_output1, 1)
657
+ middle_output1 = self.middle_fc1(middle_output1)
658
+
659
+ x = self.layer2(x)
660
+ middle_output2 = self.bottleneck2_1(x)
661
+ middle_output2 = self.avgpool2(middle_output2)
662
+ middle2_fea = middle_output2
663
+ middle_output2 = torch.flatten(middle_output2, 1)
664
+ middle_output2 = self.middle_fc2(middle_output2)
665
+
666
+ x = self.layer3(x)
667
+ x = self.avgpool(x)
668
+ final_fea = x
669
+ x = torch.flatten(x, 1)
670
+ x = self.fc(x)
671
+
672
+ return {'outputs': [x, middle_output1, middle_output2],
673
+ 'features': [final_fea, middle1_fea, middle2_fea]}
674
+
675
+ def sdresnet18(num_classes=10, position_all=True):
676
+ return SDResNet(BasicBlock, [2,2,2,2], num_classes=num_classes, position_all=position_all)
677
+
678
+ def sdresnet34(num_classes=10, position_all=True):
679
+ return SDResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, position_all=position_all)
680
+
681
+ def sdresnet34_mlp(num_classes=10, position_all=True):
682
+ return SDResNet_mlp(BasicBlock, [3,4,6,3], num_classes=num_classes, position_all=position_all)
683
+
684
+ def sdresnet34_residual(num_classes=10, position_all=True):
685
+ return SDResNet_residual(BasicBlock, [3,4,6,3], num_classes=num_classes, position_all=position_all)
686
+
687
+ def sdresnet32(num_classes=10):
688
+ return SDResNet_s(BasicBlock_s, [5,5,5], num_classes=num_classes)
models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .ResNet_Imagenet import sdresnet50
2
+ from .ResNet_cifar import sdresnet18, sdresnet34, sdresnet32
3
+ from .ResNet_cifar import sdresnet34_mlp, sdresnet34_residual
4
+ from .InceptionResNetV2 import InceptionResNetV2
models/__pycache__/CNN.cpython-310.pyc ADDED
Binary file (6.44 kB). View file
 
models/__pycache__/InceptionResNetV2.cpython-310.pyc ADDED
Binary file (8.65 kB). View file
 
models/__pycache__/ResNet_Imagenet.cpython-310.pyc ADDED
Binary file (8.04 kB). View file
 
models/__pycache__/ResNet_cifar.cpython-310.pyc ADDED
Binary file (16 kB). View file
 
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (426 Bytes). View file
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pillow
3
+ matplotlib
4
+ scikit-learn
5
+ scipy
6
+ torch
7
+ torchvision
train_cifar_c2mt.py ADDED
@@ -0,0 +1,807 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
7
+ import torch.backends.cudnn as cudnn
8
+ import random
9
+ import os
10
+ import argparse
11
+ import numpy as np
12
+ from PreResNet import *
13
+ from sklearn.mixture import GaussianMixture
14
+ import dataloader_cifar as dataloader
15
+ import matplotlib.pyplot as plt
16
+ import copy
17
+ import seaborn as sns
18
+ # from sklearn.mixture import GaussianMixture
19
+ from sklearn.cluster import KMeans
20
+ from sklearn.cluster import Birch
21
+ import matplotlib
22
+
23
+ parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
24
+ parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
25
+ parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
26
+ parser.add_argument('--noise_mode', default='asym')
27
+ parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
28
+ parser.add_argument('--lambda_u', default=150, type=float, help='weight for unsupervised loss')
29
+ parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
30
+ parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
31
+ parser.add_argument('--num_epochs', default=300, type=int)
32
+ parser.add_argument('--r', default=0.3, type=float, help='noise ratio')
33
+ parser.add_argument('--id', default='')
34
+ parser.add_argument('--seed', default=123)
35
+ parser.add_argument('--gpuid', default=0, type=int)
36
+ parser.add_argument('--num_class', default=100, type=int)
37
+ # parser.add_argument('--data_path', default='./data/cifar-10-batches-py', type=str, help='path to dataset')
38
+ # parser.add_argument('--dataset', default='cifar10', type=str)
39
+ parser.add_argument('--data_path', default='./data/cifar-100-python', type=str, help='path to dataset')
40
+ parser.add_argument('--dataset', default='cifar100', type=str)
41
+ args = parser.parse_args()
42
+
43
+ torch.cuda.set_device(args.gpuid)
44
+ random.seed(args.seed)
45
+ torch.manual_seed(args.seed)
46
+ torch.cuda.manual_seed_all(args.seed)
47
+
48
+ mse = torch.nn.MSELoss(reduction='none').cuda()
49
+
50
+
51
+ # Training
52
+ def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader, mask=None, f_G=None, new_y=None):
53
+ net.train()
54
+ net2.eval() # fix one network and train the other
55
+
56
+ unlabeled_train_iter = iter(unlabeled_trainloader)
57
+ num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
58
+ mse_total = 0
59
+ for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
60
+ try:
61
+ inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
62
+ except:
63
+ unlabeled_train_iter = iter(unlabeled_trainloader)
64
+ inputs_u, inputs_u2 = unlabeled_train_iter.__next__()
65
+ batch_size = inputs_x.size(0)
66
+
67
+ # Transform label to one-hot,转为0-1矩阵
68
+ labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
69
+ w_x = w_x.view(-1, 1).type(torch.FloatTensor)
70
+
71
+ inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
72
+ inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
73
+
74
+ with torch.no_grad():
75
+ # label co-guessing of unlabeled samples
76
+ outputs_u11, feat_u11 = net(inputs_u, feat_out=True)
77
+ outputs_u12, feat_u12 = net(inputs_u2, feat_out=True)
78
+ outputs_u21, feat_u21 = net2(inputs_u, feat_out=True)
79
+ outputs_u22, feat_u22 = net2(inputs_u2, feat_out=True)
80
+
81
+ # 取average of 所有网络的输出,作者利用了所谓的augmentation
82
+ pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1)
83
+ + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
84
+ ptu = pu ** (1 / args.T) # temparature sharpening
85
+
86
+ # Algorithm 1 中的shapen(qb,T)
87
+ targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
88
+ targets_u = targets_u.detach()
89
+
90
+ # label refinement of labeled samples
91
+ outputs_x, feat_x1 = net(inputs_x, feat_out=True)
92
+ outputs_x2, feat_x2 = net(inputs_x2, feat_out=True)
93
+
94
+ # 取labeled的输出平均值
95
+ px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
96
+
97
+ # 公式(3)(4)退火
98
+ px = w_x * labels_x + (1 - w_x) * px
99
+ ptx = px ** (1 / args.T) # temparature sharpening
100
+
101
+ targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
102
+ targets_x = targets_x.detach()
103
+ # aaa = torch.argmax(labels_x, dim=1)
104
+ # mse_loss = torch.sum(mse((feat_x1+feat_x2)/2, f_G[aaa]), 1)
105
+ # mse_total = (mse_total + torch.sum(mse_loss) / len(mse_loss))/2
106
+ # mixmatch
107
+ l = np.random.beta(args.alpha, args.alpha)
108
+ # 促使X'更加靠近labeled sample而不是无监督样本
109
+ l = max(l, 1 - l)
110
+
111
+ all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
112
+ all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
113
+
114
+ # 随机输出mini batch的序号,来mixup
115
+ idx = torch.randperm(all_inputs.size(0))
116
+
117
+ input_a, input_b = all_inputs, all_inputs[idx]
118
+ target_a, target_b = all_targets, all_targets[idx]
119
+
120
+ # 利用mix但是促使模型更偏向于label而不是UNlabel
121
+ mixed_input = l * input_a + (1 - l) * input_b
122
+ mixed_target = l * target_a + (1 - l) * target_b
123
+
124
+ logits = net(mixed_input)
125
+ # 输出被排列成两部分,input_x、Input_u
126
+ logits_x = logits[:batch_size * 2]
127
+ logits_u = logits[batch_size * 2:]
128
+
129
+ # 利用公式(9)-(10)计算损失函数,其中lamb是所谓的warm up
130
+ Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size * 2],
131
+ logits_u, mixed_target[batch_size * 2:],
132
+ epoch + batch_idx / num_iter, warm_up)
133
+
134
+ # regularization
135
+ prior = torch.ones(args.num_class) / args.num_class
136
+ prior = prior.cuda()
137
+ pred_mean = torch.softmax(logits, dim=1).mean(0)
138
+ # 一般来说会省略固定的prior部分,只取last term
139
+ # lambR=1
140
+ penalty = torch.sum(prior * torch.log(prior / pred_mean))
141
+
142
+ # lamb是通过warm和current epoch比较得出的百分数,意味着随着epoch进行,Lu所占比重会逐渐增加
143
+ # 前期需要保持标准CE损失,但是实际还有penalty
144
+ # loss = Lx + lamb * Lu + penalty
145
+ loss = Lx + penalty + lamb * Lu
146
+ # compute gradient and do SGD step
147
+ optimizer.zero_grad()
148
+ loss.backward()
149
+ optimizer.step()
150
+ if batch_idx % 200 == 0:
151
+ sys.stdout.write('\r')
152
+ sys.stdout.write(
153
+ '%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f Unlabeled loss: %.2f\n'
154
+ % (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx + 1, num_iter,
155
+ Lx.item(), Lu.item()))
156
+ sys.stdout.flush()
157
+ # print('\r mse loss:%.4f\n' % mse_total, end='end', flush=True)
158
+ # print('\r mse loss:%.4f\n' % mse_total, end='end', flush=True)
159
+
160
+ def mixup_criterion(pred, y_a, y_b, lam):
161
+ c = F.log_softmax(pred, 1)
162
+ return lam * F.cross_entropy(c, y_a) + (1 - lam) * F.cross_entropy(c, y_b)
163
+
164
+
165
+ soft_mix_warm = False
166
+
167
+ def warmup(epoch, net, optimizer, dataloader):
168
+ net.train()
169
+ num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1
170
+ for batch_idx, (inputs, labels, path) in enumerate(dataloader):
171
+ optimizer.zero_grad()
172
+ l = np.random.beta(args.alpha, args.alpha)
173
+ # 促使X'更加靠近labeled sample而不是无监督样本
174
+ l = max(l, 1 - l)
175
+ idx = torch.randperm(inputs.size(0))
176
+ targets = torch.zeros(inputs.size(0), args.num_class).scatter_(1, labels.view(-1, 1), 1).cuda()
177
+ targets = torch.clamp(targets, 1e-4, 1.)
178
+ inputs, labels = inputs.cuda(), labels.cuda()
179
+ if soft_mix_warm:
180
+ input_a, input_b = inputs, inputs[idx]
181
+ target_a, target_b = targets, targets[idx]
182
+ labels_a, labels_b = labels, labels[idx]
183
+
184
+ # 利用mix但是促使模型更偏向于label而不是UNlabel
185
+ mixed_input = l * input_a + (1 - l) * input_b
186
+ mixed_target = l * target_a + (1 - l) * target_b
187
+
188
+ outputs = net(mixed_input)
189
+ loss = mixup_criterion(outputs, labels_a, labels_b, l)
190
+ L = loss
191
+ else:
192
+ outputs = net(inputs)
193
+ loss = CEloss(outputs, labels)
194
+ if args.noise_mode == 'asym': # penalize confident prediction for asymmetric noise
195
+ penalty = conf_penalty(outputs)
196
+ L = loss + penalty
197
+ elif args.noise_mode == 'sym':
198
+ L = loss
199
+ L.backward()
200
+ optimizer.step()
201
+ if batch_idx % 200 == 0:
202
+ sys.stdout.write('\r')
203
+ sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f'
204
+ % (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx + 1, num_iter,
205
+ loss.item()))
206
+ sys.stdout.flush()
207
+
208
+
209
+ def test(epoch, net1, net2, best_acc, w_glob=None):
210
+ if w_glob is None:
211
+ net1.eval()
212
+ net2.eval()
213
+ correct = 0
214
+ total = 0
215
+ with torch.no_grad():
216
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
217
+ inputs, targets = inputs.cuda(), targets.cuda()
218
+ outputs1 = net1(inputs)
219
+ outputs2 = net2(inputs)
220
+ outputs = outputs1 + outputs2
221
+ _, predicted = torch.max(outputs, 1)
222
+
223
+ total += targets.size(0)
224
+ correct += predicted.eq(targets).cpu().sum().item()
225
+ acc = 100. * correct / total
226
+ if best_acc < acc:
227
+ best_acc = acc
228
+ print("\n| Ensemble network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
229
+ test_log.write('ensemble_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
230
+ test_log.flush()
231
+ else:
232
+ net1_w_bak = net1.state_dict()
233
+ net1.load_state_dict(w_glob)
234
+ net1.eval()
235
+ correct = 0
236
+ total = 0
237
+ with torch.no_grad():
238
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
239
+ inputs, targets = inputs.cuda(), targets.cuda()
240
+ outputs1 = net1(inputs)
241
+ _, predicted = torch.max(outputs1, 1)
242
+ total += targets.size(0)
243
+ correct += predicted.eq(targets).cpu().sum().item()
244
+ acc = 100. * correct / total
245
+ if best_acc < acc:
246
+ best_acc = acc
247
+ print("\n| Global network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
248
+ test_log.write('global_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
249
+ test_log.flush()
250
+ # 恢复权重
251
+ net1.load_state_dict(net1_w_bak)
252
+ return best_acc
253
+
254
+ feat_dim = 512 #是否可以加个全连接改成128
255
+ sim = torch.nn.CosineSimilarity(dim=1)
256
+
257
+ loss_func = torch.nn.CrossEntropyLoss(reduction='none')
258
+ def get_small_loss_samples(y_pred, y_true, forget_rate):
259
+ loss = loss_func(y_pred, y_true)
260
+ ind_sorted = np.argsort(loss.data.cpu()).cuda()
261
+ loss_sorted = loss[ind_sorted]
262
+
263
+ remember_rate = 1 - forget_rate
264
+ num_remember = int(remember_rate * len(loss_sorted))
265
+
266
+ ind_update = ind_sorted[:num_remember]
267
+
268
+ return ind_update
269
+
270
+ def get_small_loss_by_loss_list(loss_list, forget_rate, eval_loader):
271
+ remember_rate = 1 - forget_rate
272
+ idx_list = []
273
+ for i in range(10):
274
+ class_idx = np.where(np.array(eval_loader.dataset.noise_label)[:] == i)[0]
275
+ # class_idx = torch.from_numpy(class_idx).cuda()
276
+ loss_per_class = loss_list[class_idx] #取对应target的loss
277
+ num_remember = int(remember_rate * len(loss_per_class))
278
+ ind_sorted = np.argsort(loss_per_class.data.cpu())
279
+ ind_update = ind_sorted[:num_remember].tolist()
280
+ idx_list.append(ind_update)
281
+
282
+ return idx_list
283
+
284
+ def eval_train(model, all_loss):
285
+ model.eval()
286
+ losses = torch.zeros(50000)
287
+ f_G = torch.zeros(args.num_class, feat_dim).cuda()
288
+ f_all = torch.zeros(50000, feat_dim).cuda()
289
+ n_labels = torch.zeros(args.num_class, 1).cuda()
290
+ y_k_tilde = torch.zeros(50000)
291
+ mask = np.zeros(50000)
292
+ with torch.no_grad():
293
+ for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
294
+ inputs, targets = inputs.cuda(), targets.cuda()
295
+ outputs, feat = model(inputs, feat_out=True)
296
+ loss = CE(outputs, targets)
297
+ _, predicted = torch.max(outputs, 1)
298
+ for b in range(inputs.size(0)):
299
+ losses[index[b]] = loss[b]
300
+ f_G[predicted[b]] += feat[b]
301
+ n_labels[predicted[b]] += 1
302
+ f_all[index] = feat
303
+ assert torch.sum(n_labels) == 50000
304
+ for i in range(len(n_labels)):
305
+ if n_labels[i] == 0:
306
+ n_labels[i] = 1
307
+ f_G = torch.div(f_G, n_labels)
308
+ f_G = F.normalize(f_G, dim=1)
309
+ f_all = F.normalize(f_all, dim=1)
310
+ temp = f_G.t()
311
+ sim_all = torch.mm(f_all, temp) # .cpu().numpy()
312
+ y_k_tilde = torch.argmax(sim_all.cpu(), dim=1)
313
+ with torch.no_grad():
314
+ for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
315
+ for i in range(len(index)):
316
+ if y_k_tilde[index[i]] == targets[i]:
317
+ mask[index[i]] = 1
318
+ losses = (losses - losses.min()) / (losses.max() - losses.min())
319
+ all_loss.append(losses)
320
+
321
+ if args.r == 0.9:
322
+ # average loss over last 5 epochs to improve convergence stability
323
+ history = torch.stack(all_loss)
324
+ input_loss = history[-5:].mean(0)
325
+ input_loss = input_loss.reshape(-1, 1)
326
+ else:
327
+ input_loss = losses.reshape(-1, 1)
328
+
329
+ # fit a two-component GMM to the loss
330
+ # 参数如下:
331
+ # n_components 聚类数量,max_iter 最大迭代次数,tol 阈值低于停止,reg_covar 协方差矩阵对角线上非负正则化参数,接近0即可
332
+ gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
333
+ gmm.fit(input_loss)
334
+ prob = gmm.predict_proba(input_loss)
335
+ prob = prob[:, gmm.means_.argmin()]
336
+ return prob, all_loss, losses.numpy(), mask, f_G
337
+
338
+ def mix_data_lab(x, y, alpha=1.0):
339
+ '''Returns mixed inputs, pairs of targets, and lambda'''
340
+ if alpha > 0:
341
+ lam = np.random.beta(alpha, alpha)
342
+ else:
343
+ lam = 1
344
+
345
+ batch_size = x.size()[0]
346
+ index = torch.randperm(batch_size).cuda()
347
+
348
+ lam = max(lam, 1 - lam)
349
+ mixed_x = lam * x + (1 - lam) * x[index, :]
350
+ y_a, y_b = y, y[index]
351
+
352
+ return mixed_x, y_a, y_b, index, lam
353
+
354
+
355
+ def linear_rampup(current, warm_up, rampup_length=16):
356
+ # 线性warm_up,对sym噪声使用标准CE训练一段时间
357
+ # 实际warm up epoch是warm_up+rampup_length
358
+ current = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
359
+ re_val = args.lambda_u * float(current)
360
+ # print(" current warm up parameters:", current)
361
+ # print("return parameters:", re_val)
362
+ return re_val
363
+
364
+
365
+ class SemiLoss(object):
366
+ def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
367
+ probs_u = torch.softmax(outputs_u, dim=1)
368
+
369
+ # 利用mixup后的交叉熵,px输出*log(px_model)
370
+ Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
371
+ # 而UNlabel则是均方误差,p_u输出-pu_model
372
+ Lu = torch.mean((probs_u - targets_u) ** 2)
373
+
374
+ return Lx, Lu, linear_rampup(epoch, warm_up)
375
+
376
+
377
+ class NegEntropy(object):
378
+ def __call__(self, outputs):
379
+ probs = torch.softmax(outputs, dim=1)
380
+ return torch.mean(torch.sum(probs.log() * probs, dim=1))
381
+
382
+
383
+ def create_model():
384
+ # 其实是pre-resnet18,使用的是pre-resnet block
385
+ model = ResNet18(num_classes=args.num_class)
386
+ model = model.cuda()
387
+ return model
388
+
389
+ def plotHistogram(model_1_loss, model_2_loss, noise_index, clean_index, epoch, round, noise_rate):
390
+ title = 'Epoch-' + str(epoch)+':'
391
+ fig = plt.figure()
392
+ plt.subplot(121)
393
+ gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
394
+ model_1_loss = np.reshape(model_1_loss, (-1, 1))
395
+ gmm.fit(model_1_loss) # fit the loss
396
+
397
+ # plot resulting fit
398
+ x_range = np.linspace(0, 1, 1000)
399
+ pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
400
+ responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
401
+ pdf_individual = responsibilities * pdf[:, np.newaxis]
402
+ plt.hist(np.array(model_1_loss[noise_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='red', label='Noisy subset')
403
+ plt.hist(np.array(model_1_loss[clean_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='blue', label='Clean subset')
404
+ plt.plot(x_range, pdf, '-k', label='Mixture')
405
+ plt.plot(x_range, pdf_individual, '--', label='Component')
406
+ plt.legend(loc='upper right', prop={'size': 12})
407
+ plt.xlabel('Normalized loss')
408
+ plt.ylabel('Estimated pdf')
409
+ plt.title(title+'Model_1')
410
+
411
+ plt.subplot(122)
412
+ gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
413
+ model_2_loss = np.reshape(model_2_loss, (-1, 1))
414
+ gmm.fit(model_2_loss) # fit the loss
415
+
416
+ # plot resulting fit
417
+ x_range = np.linspace(0, 1, 1000)
418
+ pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
419
+ responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
420
+ pdf_individual = responsibilities * pdf[:, np.newaxis]
421
+ plt.hist(np.array(model_2_loss[noise_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='red', label='Noisy subset')
422
+ plt.hist(np.array(model_2_loss[clean_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='blue', label='Clean subset')
423
+ plt.plot(x_range, pdf, '-k', label='Mixture')
424
+ plt.plot(x_range, pdf_individual, '--', label='Component')
425
+ plt.legend(loc='upper right', prop={'size': 12})
426
+ plt.xlabel('Normalized loss')
427
+ plt.ylabel('Estimated pdf')
428
+ plt.title(title+'Model_2')
429
+
430
+ print('\nlogging histogram...')
431
+ title = 'cifar10_' + str(args.noise_mode) + '_moit_double_' + str(noise_rate)
432
+ plt.savefig(os.path.join('./figure_his/', 'two_model_{}_{}_{}_{}.{}'.format(epoch, round, title, int(soft_mix_warm), ".tif")), dpi=300)
433
+ # plt.show()
434
+ plt.close()
435
+
436
+
437
+ def loss_dist_plot(loss, noisy_index, clean_index, epoch, rou=None, g_file=True, model_name='', loss2=None):
438
+ """
439
+ plot the loss distribution
440
+ :param loss: the list contains the loss per sample
441
+ :param noisy_index: contains the indices of real noisy label
442
+ :param clean_index: contains the indices of real clean label
443
+ :param filename: the generated pdf file name
444
+ :param title: the figure title
445
+ :param g_file: whether to generate the pdf figure file
446
+ :return: None
447
+ """
448
+ if loss2 is None:
449
+ filename = 'one_model_'+str(args.dataset)+'_'+str(args.noise_mode)+'_'+str(args.r)+'_epoch='+str(epoch)
450
+ if rou is None:
451
+ title = 'Epoch-'+str(epoch) + ': ' + str(args.dataset)+' '+str(args.r*100)+'%-'+str(args.noise_mode)
452
+ else:
453
+ title = 'Epoch-' + str(epoch) + ' ' +'Round-'+str(rou)+ ': ' + str(args.dataset) + ' ' + str(int(args.r * 100)) + '%-' + str(args.noise_mode)
454
+ if type(loss) is not np.ndarray:
455
+ loss= loss.numpy()
456
+ sns.set(style='whitegrid')
457
+ gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
458
+ loss = np.reshape(loss, (-1, 1))
459
+ gmm.fit(loss) # fit the loss
460
+
461
+ # plot resulting fit
462
+ x_range = np.linspace(0, 1, 1000)
463
+ pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
464
+ responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
465
+ pdf_individual = responsibilities * pdf[:, np.newaxis]
466
+ # sns.distplot(loss[noisy_index], color="red", rug=False,kde=False, label="incorrect",
467
+ # hist_kws={"color": "r", "alpha": 0.5})
468
+ # sns.distplot(loss[clean_index], color="skyblue", rug=False,kde=False, label="correct",
469
+ # hist_kws={"color": "b", "alpha": 0.5})
470
+
471
+ plt.hist(np.array(loss[noisy_index]), density=True, bins=100, histtype='bar', alpha=0.5, color='red',
472
+ label='Noisy subset')
473
+ plt.hist(np.array(loss[clean_index]), density=True, bins=100, histtype='bar', alpha=0.5, color='blue',
474
+ label='Clean subset')
475
+ plt.plot(x_range, pdf, '-k', label='Mixture')
476
+ plt.plot(x_range, pdf_individual, '--', label='Component')
477
+ # plt.plot(x_range, pdf_individual[:][1], '--', color='blue', label='Component 1')
478
+
479
+ plt.title(title, fontsize=20)
480
+ plt.xlabel('Normalized loss', fontsize=24)
481
+ plt.ylabel('Estimated pdf', fontsize=24)
482
+
483
+ plt.tick_params(labelsize=24)
484
+ plt.legend(loc='upper right', prop={'size': 12})
485
+ # plt.tight_layout()
486
+ if g_file:
487
+ plt.savefig('./figure_his/{0}.tif'.format(filename+model_name), bbox_inches='tight', dpi=300)
488
+ #plt.show()
489
+ plt.close()
490
+ else:
491
+ filename = 'noise_'+str(args.dataset) + '_' + str(args.noise_mode) + '_' + str(args.r) + '_epoch=' + str(epoch)
492
+ if rou is None:
493
+ title = 'Epoch-' + str(epoch) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(
494
+ args.noise_mode)
495
+ else:
496
+ title = 'Epoch-' + str(epoch) + ' ' + 'Round-' + str(rou) + ': ' + str(args.dataset) + ' ' + str(
497
+ args.r * 100) + '%-' + str(args.noise_mode)
498
+ if type(loss) is not np.ndarray:
499
+ loss = loss.numpy()
500
+ if type(loss2) is not np.ndarray:
501
+ loss2 = loss2.numpy()
502
+ fig = plt.figure()
503
+ plt.subplot(121)
504
+ sns.set(style='whitegrid')
505
+ sns.distplot(loss[noisy_index], color="red", rug=False, kde=False, label="incorrect",
506
+ hist_kws={"color": "r", "alpha": 0.5})
507
+ sns.distplot(loss[clean_index], color="skyblue", rug=False, kde=False, label="correct",
508
+ hist_kws={"color": "b", "alpha": 0.5})
509
+ plt.title('Model_1', fontsize=32)
510
+ plt.xlabel('Normalized loss', fontsize=32)
511
+ plt.ylabel('Sample number', fontsize=32)
512
+ plt.tick_params(labelsize=32)
513
+ plt.legend(loc='upper right', prop={'size': 24})
514
+ plt.subplot(122)
515
+ sns.set(style='whitegrid')
516
+ sns.distplot(loss2[noisy_index], color="red", rug=False, kde=False, label="incorrect",
517
+ hist_kws={"color": "r", "alpha": 0.5})
518
+ sns.distplot(loss2[clean_index], color="skyblue", rug=False, kde=False, label="correct",
519
+ hist_kws={"color": "b", "alpha": 0.5})
520
+ plt.title('Model_2', fontsize=32)
521
+ plt.xlabel('Normalized loss', fontsize=32)
522
+ plt.ylabel('Sample number', fontsize=32)
523
+ plt.tick_params(labelsize=32)
524
+ plt.legend(loc='upper right', prop={'size': 24})
525
+ # plt.tight_layout()
526
+ if g_file:
527
+ plt.savefig('./figure_his/{0}.tif'.format(filename + model_name), bbox_inches='tight', dpi=300)
528
+ # plt.show()
529
+ plt.close()
530
+
531
+
532
+ def loss_dist_plot_real(loss, epoch, rou=None, g_file=True, model_name=''):
533
+ """
534
+ plot the loss distribution
535
+ :param loss: the list contains the loss per sample
536
+ :param noisy_index: contains the indices of real noisy label
537
+ :param clean_index: contains the indices of real clean label
538
+ :param filename: the generated pdf file name
539
+ :param title: the figure title
540
+ :param g_file: whether to generate the pdf figure file
541
+ :return: None
542
+ """
543
+ filename = str(args.dataset) + '_' + str(args.noise_mode) + '_' + str(args.r) + '_epoch=' + str(epoch)
544
+ if rou is None:
545
+ title = 'Epoch-' + str(epoch) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(args.noise_mode)
546
+ else:
547
+ title = 'Epoch-' + str(epoch) + ' ' + 'Round-' + str(rou) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(args.noise_mode)
548
+
549
+ if type(loss) is not np.ndarray:
550
+ loss= loss.numpy()
551
+ sns.set(style='whitegrid')
552
+
553
+ gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
554
+ loss = np.reshape(loss, (-1, 1))
555
+ gmm.fit(loss) # fit the loss
556
+
557
+ # plot resulting fit
558
+ x_range = np.linspace(0, 1, 1000)
559
+ pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
560
+ responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
561
+ pdf_individual = responsibilities * pdf[:, np.newaxis]
562
+
563
+ plt.hist(loss, bins=60, density=True, histtype='bar', alpha=0.3)
564
+ plt.plot(x_range, pdf, '-k', label='Mixture')
565
+ plt.plot(x_range, pdf_individual, '--', label='Component')
566
+ plt.legend()
567
+ # plt.tight_layout()
568
+
569
+ plt.title(title, fontsize=32)
570
+ plt.xlabel('Normalized loss', fontsize=32)
571
+ plt.ylabel('Estimated PDF', fontsize=32)
572
+ plt.tick_params(labelsize=32)
573
+ plt.legend(loc='upper right', prop={'size': 22})
574
+ if g_file:
575
+ plt.savefig('./figure_his/{0}.tif'.format(filename+model_name), bbox_inches='tight', dpi=300)
576
+ #plt.show()
577
+ plt.close()
578
+
579
+
580
+ def FedAvg(w):
581
+ w_avg = copy.deepcopy(w[0])
582
+ for k in w_avg.keys():
583
+ for i in range(1, len(w)):
584
+ w_avg[k] += w[i][k]
585
+ # 只考虑iid noise的话,每个client训练样本数一样,所以不用做nk/n
586
+ w_avg[k] = torch.div(w_avg[k], len(w))
587
+
588
+ return w_avg
589
+
590
+
591
+ if os.path.exists('checkpoint') == False:
592
+ os.mkdir('checkpoint')
593
+ print("新建日志文件夹")
594
+ stats_log = open('./checkpoint/single_%s_%.1f_%s_%d' % (args.dataset, args.r, args.noise_mode,
595
+ int(soft_mix_warm)) + '_stats.txt', 'w')
596
+ test_log = open('./checkpoint/single_%s_%.1f_%s_%d' % (args.dataset, args.r, args.noise_mode,
597
+ int(soft_mix_warm)) + '_acc.txt', 'w')
598
+
599
+ warm_up = 10
600
+ dmix_epoch = 150
601
+ args.num_epochs = dmix_epoch + 150
602
+ # 第6页提及的warm up的epoch
603
+ if args.dataset == 'cifar10':
604
+ warm_up = 10
605
+ dmix_epoch = 150
606
+ args.num_epochs = dmix_epoch + 50
607
+ elif args.dataset == 'cifar100':
608
+ warm_up = 30
609
+ dmix_epoch = 150
610
+ args.num_epochs = dmix_epoch + 50
611
+
612
+ loader = dataloader.cifar_dataloader(args.dataset, r=args.r, noise_mode=args.noise_mode,
613
+ batch_size=args.batch_size, num_workers=0,
614
+ root_dir=args.data_path, log=stats_log,
615
+ noise_file='%s/%.1f_%s.json' % (args.data_path, args.r, args.noise_mode))
616
+
617
+ print('| Building net')
618
+ net1 = create_model()
619
+ net2 = create_model()
620
+ cudnn.benchmark = True
621
+
622
+ criterion = SemiLoss()
623
+ optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
624
+ optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
625
+
626
+ CE = nn.CrossEntropyLoss(reduction='none')
627
+ CEloss = nn.CrossEntropyLoss()
628
+ if args.noise_mode == 'asym':
629
+ # 本文第一个问题,对于非对称和对称需要不同措施,这很不适用
630
+ # 其次本文在不同步骤中噪声数据处理措施很凌乱
631
+ conf_penalty = NegEntropy()
632
+
633
+ all_loss = [[], []] # save the history of losses from two networks
634
+
635
+ local_round = 5
636
+ first = True
637
+ balance_crit = 'median'
638
+ exp_path = './checkpoint/single_%s_%.1f_%s_double_m2_' % (args.dataset, args.r, args.noise_mode)
639
+ save_clean_idx = exp_path + "clean_idx.npy"
640
+ boot_loader = None
641
+ w_glob = None
642
+ if args.r == 0.9:
643
+ args.p_threshold = 0.6
644
+ best_en_acc = 0.
645
+ best_gl_acc = 0.
646
+ resume_epoch = 0
647
+ if resume_epoch > 0:
648
+ snapLast = exp_path + str(resume_epoch-1) + "_global_model.pth"
649
+ global_state = torch.load(snapLast)
650
+ # 先更新还是后跟新
651
+ w_glob = global_state
652
+ net1.load_state_dict(global_state)
653
+ net2.load_state_dict(global_state)
654
+ for epoch in range(resume_epoch, args.num_epochs + 1):
655
+ test_loader = loader.run('test')
656
+ eval_loader = loader.run('eval_train')
657
+ lr = args.lr
658
+ if epoch >= dmix_epoch:
659
+ lr /= 10
660
+ for param_group in optimizer1.param_groups:
661
+ param_group['lr'] = lr
662
+ for param_group in optimizer2.param_groups:
663
+ param_group['lr'] = lr
664
+
665
+ noise_ind, clean_ind = eval_loader.dataset.if_noise()
666
+ print(len(np.where(np.array(eval_loader.dataset.noise_label) != np.array(eval_loader.dataset.clean_label))[0])
667
+ / len(eval_loader.dataset.clean_label))
668
+ local_weights = []
669
+ if epoch < warm_up:
670
+ # 考虑warm up时是否需要merge
671
+ warmup_trainloader = loader.run('warmup')
672
+ print('Warmup Net1')
673
+ warmup(epoch, net1, optimizer1, warmup_trainloader)
674
+ print('\nWarmup Net2')
675
+ warmup(epoch, net2, optimizer2, warmup_trainloader)
676
+ if epoch == (warm_up-1):
677
+ snapLast = exp_path+str(epoch) + "_1_model.pth"
678
+ torch.save(net1.state_dict(), snapLast)
679
+ snapLast = exp_path+str(epoch) + "_2_model.pth"
680
+ torch.save(net1.state_dict(), snapLast)
681
+ local_weights.append(net1.state_dict())
682
+ local_weights.append(net2.state_dict())
683
+ w_glob = FedAvg(local_weights)
684
+
685
+ else:
686
+ if epoch != warm_up:
687
+ net1.load_state_dict(w_glob)
688
+ net2.load_state_dict(w_glob)
689
+
690
+ for rou in range(local_round):
691
+ prob1, all_loss[0], loss1, mask1, f_G1 = eval_train(net1, all_loss[0])
692
+ prob2, all_loss[1], loss2, mask2, f_G2 = eval_train(net2, all_loss[1])
693
+
694
+ # 加载完global后第一次评估
695
+ if rou == 0:
696
+ # plotHistogram(np.array(loss1), np.array(loss2), noise_ind, clean_ind, epoch, rou, args.r)
697
+ loss_dist_plot(loss1, noise_ind, clean_ind, epoch, model_name='model_1')
698
+ # loss_dist_plot_real(loss1, epoch, model_name='model_1')
699
+ if rou == local_round-1:
700
+ plotHistogram(np.array(loss1), np.array(loss2), noise_ind, clean_ind, epoch, rou, args.r)
701
+
702
+ # pred1 = (prob1 > args.p_threshold) & (mask1 != 0)
703
+ # pred2 = (prob2 > args.p_threshold) & (mask2 != 0)
704
+ pred1 = (prob1 > args.p_threshold)
705
+ pred2 = (prob2 > args.p_threshold)
706
+
707
+ non_zero_idx = pred1.nonzero()[0].tolist()
708
+ aaa = len(non_zero_idx)
709
+ if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
710
+ num_clean_per_class = np.zeros(args.num_class)
711
+ target_label = np.array(eval_loader.dataset.noise_label)[non_zero_idx]
712
+ for i in range(args.num_class):
713
+ idx_class = np.where(target_label == i)[0]
714
+ num_clean_per_class[i] = len(idx_class)
715
+
716
+ if balance_crit == "median":
717
+ num_samples2select_class = np.median(num_clean_per_class)
718
+
719
+ for i in range(args.num_class):
720
+ idx_class = np.where(np.array(eval_loader.dataset.noise_label) == i)[0]
721
+ cur_num = num_clean_per_class[i]
722
+ idx_class2 = non_zero_idx
723
+ if num_samples2select_class > cur_num:
724
+ remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
725
+ idx = list(range(len(remian_idx)))
726
+ random.shuffle(idx)
727
+ num_app = int(num_samples2select_class - cur_num)
728
+ idx = idx[:num_app]
729
+ for j in idx:
730
+ non_zero_idx.append(remian_idx[j])
731
+ non_zero_idx = np.array(non_zero_idx).reshape(-1, )
732
+ bbb = len(non_zero_idx)
733
+ num_per_class2 = []
734
+ for i in range(max(eval_loader.dataset.noise_label)):
735
+ temp = np.where(np.array(eval_loader.dataset.noise_label)[non_zero_idx.tolist()] == i)[0]
736
+ num_per_class2.append(len(temp))
737
+ print('\npred1 appended num per class:', num_per_class2, aaa, bbb)
738
+ idx_per_class = np.zeros_like(pred1).astype(bool)
739
+ for i in non_zero_idx:
740
+ idx_per_class[i] = True
741
+ pred1 = idx_per_class
742
+ non_aaa = pred1.nonzero()[0].tolist()
743
+ assert len(non_aaa) == len(non_zero_idx)
744
+
745
+ non_zero_idx2 = pred2.nonzero()[0].tolist()
746
+ aaa = len(non_zero_idx2)
747
+ if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
748
+ num_clean_per_class = np.zeros(args.num_class)
749
+ target_label = np.array(eval_loader.dataset.noise_label)[non_zero_idx2]
750
+ for i in range(args.num_class):
751
+ idx_class = np.where(target_label == i)[0]
752
+ num_clean_per_class[i] = len(idx_class)
753
+
754
+ if balance_crit == "median":
755
+ num_samples2select_class = np.median(num_clean_per_class)
756
+
757
+ for i in range(args.num_class):
758
+ idx_class = np.where(np.array(eval_loader.dataset.noise_label) == i)[0]
759
+ cur_num = num_clean_per_class[i]
760
+ idx_class2 = non_zero_idx2
761
+ if num_samples2select_class > cur_num:
762
+ remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
763
+ idx = list(range(len(remian_idx)))
764
+ random.shuffle(idx)
765
+ num_app = int(num_samples2select_class - cur_num)
766
+ idx = idx[:num_app]
767
+ for j in idx:
768
+ non_zero_idx2.append(remian_idx[j])
769
+ non_zero_idx2 = np.array(non_zero_idx2).reshape(-1, )
770
+ bbb = len(non_zero_idx2)
771
+ num_per_class2 = []
772
+ for i in range(max(eval_loader.dataset.noise_label)):
773
+ temp = np.where(np.array(eval_loader.dataset.noise_label)[non_zero_idx2.tolist()] == i)[0]
774
+ num_per_class2.append(len(temp))
775
+ print('\npred2 appended num per class:', num_per_class2, aaa, bbb)
776
+ idx_per_class2 = np.zeros_like(pred2).astype(bool)
777
+ for i in non_zero_idx2:
778
+ idx_per_class2[i] = True
779
+ pred2 = idx_per_class2
780
+ non_aaa = pred2.nonzero()[0].tolist()
781
+ assert len(non_aaa) == len(non_zero_idx2)
782
+
783
+ correct_num = len(pred1.nonzero()[0])
784
+ eval_loader.dataset.if_noise(pred1)
785
+ eval_loader.dataset.if_noise(pred2)
786
+
787
+ print(f'round={rou}/{local_round}, dmix selection, Train Net1')
788
+ # prob2就是先验概率wi,通过GMM拟合出来的,大于阈值就认为是clean,否则noisy
789
+ labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, prob2) # co-divide
790
+ train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader) # train net1
791
+
792
+ print(f'\nround={rou}/{local_round}, dmix selection, Train Net2')
793
+ labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, prob1) # co-divide
794
+ train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader) # train net2
795
+
796
+ local_weights.append(net1.state_dict())
797
+ local_weights.append(net2.state_dict())
798
+ w_glob = FedAvg(local_weights)
799
+ if epoch % 5 == 0:
800
+ snapLast = exp_path + str(epoch) + "_global_model.pth"
801
+ torch.save(w_glob, snapLast)
802
+
803
+ best_en_acc = test(epoch, net1, net2, best_en_acc)
804
+ best_gl_acc= test(epoch, net1, net2, best_gl_acc, w_glob)
805
+
806
+
807
+