from __future__ import print_function

import argparse
import copy
import os
import random
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from sklearn.mixture import GaussianMixture

import dataloader_animal10N as animal_dataloader
from models.CNN import CNN

parser = argparse.ArgumentParser(description='PyTorch Animal-10N Training')
parser.add_argument('--batch_size', default=128, type=int, help='train batch size')
parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
parser.add_argument('--alpha', default=4, type=float, help='parameter of the Beta distribution used for mixup')
parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=300, type=int)
parser.add_argument('--id', default='animal10N')
parser.add_argument('--data_path', default='C:/Users/USSTz/Desktop/Animal-10N', type=str, help='path to dataset')
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid', default=0, type=int)
parser.add_argument('--num_class', default=10, type=int)

args = parser.parse_args()

torch.cuda.set_device(args.gpuid)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

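# One co-training step in the DivideMix style: `net` is trained while `net2`
# stays frozen and only contributes to label co-guessing. Targets for the
# labeled set are refined with the GMM clean probability w_x, targets for the
# unlabeled set are co-guessed by both networks, and both are sharpened with
# temperature T before mixup.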
def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader):
    net.train()
    net2.eval()  # fix one network, train the other

    unlabeled_train_iter = iter(unlabeled_trainloader)
    num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
    for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
        try:
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        except StopIteration:
            # the unlabeled loader is usually shorter; restart it when exhausted
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        batch_size = inputs_x.size(0)

        # one-hot encode the (noisy) labels
        labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
        w_x = w_x.view(-1, 1).type(torch.FloatTensor)

        inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
        inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

        with torch.no_grad():
            # label co-guessing for unlabeled samples: average the softmax of
            # both networks over both augmentations
            outputs_u11 = net(inputs_u)
            outputs_u12 = net(inputs_u2)
            outputs_u21 = net2(inputs_u)
            outputs_u22 = net2(inputs_u2)

            pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) +
                  torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
            ptu = pu ** (1 / args.T)  # temperature sharpening

            targets_u = ptu / ptu.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

            # label refinement for labeled samples: interpolate between the
            # given label and the prediction, weighted by the clean probability
            outputs_x = net(inputs_x)
            outputs_x2 = net(inputs_x2)

            px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
            px = w_x * labels_x + (1 - w_x) * px
            ptx = px ** (1 / args.T)  # temperature sharpening

            targets_x = ptx / ptx.sum(dim=1, keepdim=True)
            targets_x = targets_x.detach()

        # mixup: draw the interpolation weight from Beta(alpha, alpha) and keep
        # l >= 0.5 so each mixed sample stays closer to its own target
        l = np.random.beta(args.alpha, args.alpha)
        l = max(l, 1 - l)

        all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        # only the labeled half (the first 2 * batch_size samples) is mixed and
        # trained on; with lambda_u = 0 no unsupervised loss term is added
        mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2]
        mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2]

        logits = net(mixed_input)

        # cross-entropy with soft targets
        Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))

        # regularization: KL(prior || mean prediction) with a uniform prior,
        # which discourages collapsing onto a few classes
        prior = torch.ones(args.num_class) / args.num_class
        prior = prior.cuda()
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        penalty = torch.sum(prior * torch.log(prior / pred_mean))

        loss = Lx + penalty

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sys.stdout.write('\r')
        sys.stdout.write('Animal10N | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.4f '
                         % (epoch, args.num_epochs, batch_idx + 1, num_iter, Lx.item()))
        sys.stdout.flush()

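# Warm-up: standard cross-entropy training on all (noisy) labels, plus a
# negative-entropy penalty that keeps predictions from becoming overconfident
# before the clean/noisy separation starts.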
def warmup(net, optimizer, dataloader):
    net.train()
    num_batches = len(dataloader.dataset) // args.batch_size + 1
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = CEloss(outputs, labels)

        penalty = conf_penalty(outputs)
        L = loss + penalty
        L.backward()
        optimizer.step()

        sys.stdout.write('\r')
        sys.stdout.write('|Warm-up: Iter[%3d/%3d]\t CE-loss: %.4f Conf-Penalty: %.4f'
                         % (batch_idx + 1, num_batches, loss.item(), penalty.item()))
        sys.stdout.flush()

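# Per-network validation helper; keeps the best checkpoint for network k.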
def val(net, val_loader, best_acc, k, w_glob=None):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)

            total += targets.size(0)
            correct += predicted.eq(targets).cpu().sum().item()
    acc = 100. * correct / total
    print("\n| Validation\t Net%d Acc: %.2f%%" % (k, acc))
    if acc > best_acc[k - 1]:
        best_acc[k - 1] = acc
        print('| Saving Best Net%d ...' % k)
        save_point = './checkpoint/%s_net%d.pth.tar' % (args.id, k)
        torch.save(net.state_dict(), save_point)
    return acc

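# Test either the two-network ensemble (w_glob is None) or the averaged global
# model: load w_glob into net1 temporarily, evaluate, then restore net1's own
# weights.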
def test(epoch, net1, net2, test_loader, best_acc, w_glob=None):
    if w_glob is None:
        # evaluate the ensemble (sum of logits) and each network separately
        net1.eval()
        net2.eval()
        correct = 0
        correct1 = 0
        correct2 = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.cuda(), targets.cuda()
                outputs1 = net1(inputs)
                outputs2 = net2(inputs)
                outputs = outputs1 + outputs2
                _, predicted = torch.max(outputs, 1)
                _, predicted1 = torch.max(outputs1, 1)
                _, predicted2 = torch.max(outputs2, 1)

                total += targets.size(0)
                correct += predicted.eq(targets).cpu().sum().item()
                correct1 += predicted1.eq(targets).cpu().sum().item()
                correct2 += predicted2.eq(targets).cpu().sum().item()
        acc = 100. * correct / total
        acc1 = 100. * correct1 / total
        acc2 = 100. * correct2 / total
        if best_acc < acc:
            best_acc = acc
        print("\n| Ensemble network Test Epoch #%d\t Accuracy: %.2f, Accuracy1: %.2f, Accuracy2: %.2f, best_acc: %.2f%%\n"
              % (epoch, acc, acc1, acc2, best_acc))
        log.write('ensemble_Epoch:%d Accuracy:%.2f, Accuracy1: %.2f, Accuracy2: %.2f, best_acc: %.2f\n' % (
            epoch, acc, acc1, acc2, best_acc))
        log.flush()
    else:
        # evaluate the global (federated-averaged) model through net1
        net1_w_bak = net1.state_dict()
        net1.load_state_dict(w_glob)
        net1.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.cuda(), targets.cuda()
                outputs1 = net1(inputs)
                _, predicted = torch.max(outputs1, 1)
                total += targets.size(0)
                correct += predicted.eq(targets).cpu().sum().item()
        acc = 100. * correct / total
        if best_acc < acc:
            best_acc = acc
        print("\n| Global network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
        log.write('global_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
        log.flush()

        net1.load_state_dict(net1_w_bak)
    return best_acc

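# Per-sample loss modeling: fit a two-component GMM to the normalized
# cross-entropy losses; the component with the smaller mean is treated as the
# clean one, and its posterior becomes each sample's clean probability
# (as in DivideMix).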
def eval_train(epoch, model):
    model.eval()
    num_samples = eval_loader.dataset.__len__()
    losses = torch.zeros(num_samples)
    paths = []
    n = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets, path) in enumerate(eval_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = CE(outputs, targets)  # per-sample losses (reduction='none')
            for b in range(inputs.size(0)):
                losses[n] = loss[b]
                paths.append(path[b])
                n += 1
            sys.stdout.write('\r')
            sys.stdout.write('| Evaluating loss Iter %3d\t' % (batch_idx))
            sys.stdout.flush()

    # min-max normalize the losses before fitting the mixture
    losses = (losses - losses.min()) / (losses.max() - losses.min())
    losses = losses.reshape(-1, 1)
    gmm = GaussianMixture(n_components=2, max_iter=10, reg_covar=5e-4, tol=1e-2)
    gmm.fit(losses)
    prob = gmm.predict_proba(losses)
    prob = prob[:, gmm.means_.argmin()]  # posterior of the low-loss (clean) component
    return prob, paths

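# Negative entropy of the softmax predictions; adding it to the warm-up loss
# penalizes overconfident outputs.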
class NegEntropy(object):
    def __call__(self, outputs):
        probs = torch.softmax(outputs, dim=1)
        return torch.mean(torch.sum(probs.log() * probs, dim=1))

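# Backbone factory: the project's CNN by default, with VGG-19-BN kept as an
# alternative backbone.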
def create_model():
    use_cnn = True
    if use_cnn:
        model = CNN()
        model = model.cuda()
    else:
        model = models.vgg19_bn(pretrained=False)
        model.classifier._modules['6'] = nn.Linear(4096, 10)  # 10-class head
        model = model.cuda()
    return model

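# Federated averaging (FedAvg): element-wise average of a list of state_dicts.
# Here it merges the two peer networks into one global model.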
def FedAvg(w):
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg

os.makedirs('./checkpoint', exist_ok=True)  # log and snapshots are written here
log = open('./checkpoint/%s.txt' % args.id, 'w')
log.flush()

loader = animal_dataloader.animal_dataloader(root=args.data_path, batch_size=args.batch_size, num_workers=0)

print('| Building net')
net1 = create_model()
net2 = create_model()
cudnn.benchmark = True

optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)

CE = nn.CrossEntropyLoss(reduction='none')  # per-sample losses for eval_train
CEloss = nn.CrossEntropyLoss()              # mean loss for warm-up
conf_penalty = NegEntropy()

local_round = 5          # local co-training rounds between two FedAvg merges
balance_crit = 'median'  # per-class target size for the balanced selection
exp_path = './checkpoint/c2mt_animal10N'

boot_loader = None
w_glob = None
best_en_acc = 0.
best_gl_acc = 0.
resume_epoch = 0
warm_up = 10
if resume_epoch > 0:
    # resume both networks from the last saved global model
    snapLast = exp_path + str(resume_epoch - 1) + "_global_model.pth"
    global_state = torch.load(snapLast)

    w_glob = global_state
    net1.load_state_dict(global_state)
    net2.load_state_dict(global_state)

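# Main loop: warm up both networks on the noisy labels, then alternate between
# local co-training rounds (GMM selection + class balancing + mixup training)
# and FedAvg merges of the two networks into a global model.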
for epoch in range(resume_epoch, args.num_epochs + 1):
    # step schedule; note that, as written, the base lr is restored for
    # epochs 100-129 before dropping again at 130
    lr = args.lr
    if 50 <= epoch < 100:
        lr /= 10
    elif epoch >= 130:
        lr /= 10

    for param_group in optimizer1.param_groups:
        param_group['lr'] = lr
    for param_group in optimizer2.param_groups:
        param_group['lr'] = lr

    local_weights = []
    if epoch < warm_up:
        train_loader = loader.run('warmup')
        print('Warmup Net1')
        warmup(net1, optimizer1, train_loader)
        train_loader = loader.run('warmup')
        print('\nWarmup Net2')
        warmup(net2, optimizer2, train_loader)
        if epoch == (warm_up - 1):
            # snapshot both networks at the end of warm-up
            snapLast = exp_path + str(epoch) + "_1_model.pth"
            torch.save(net1.state_dict(), snapLast)
            snapLast = exp_path + str(epoch) + "_2_model.pth"
            torch.save(net2.state_dict(), snapLast)
        local_weights.append(net1.state_dict())
        local_weights.append(net2.state_dict())
        w_glob = FedAvg(local_weights)
    else:
        if epoch != warm_up:
            # start every federated round from the current global model
            net1.load_state_dict(w_glob)
            net2.load_state_dict(w_glob)

        for rou in range(local_round):
            print('\n==== net 1 evaluate next epoch training data loss ====')
            eval_loader = loader.run('eval_train')
            prob1, paths1 = eval_train(epoch, net1)
            print('\n==== net 2 evaluate next epoch training data loss ====')
            eval_loader = loader.run('eval_train')
            prob2, paths2 = eval_train(epoch, net2)
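
            # Class-balanced selection: after thresholding the GMM clean
            # probabilities, under-represented classes are padded with
            # randomly chosen extra samples until each class reaches the
            # median (or min/max) per-class clean count.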
            pred1 = (prob1 > args.p_threshold)
            pred2 = (prob2 > args.p_threshold)

            # ---- balance the selection of net1 (used to train net2) ----
            non_zero_idx = pred1.nonzero()[0].tolist()
            aaa = len(non_zero_idx)
            if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
                num_clean_per_class = np.zeros(args.num_class)
                ppp = np.array(paths1)[non_zero_idx].tolist()
                target_label = np.array([eval_loader.dataset.train_labels[it] for it in ppp])

                for i in range(args.num_class):
                    idx_class = np.where(target_label == i)[0]
                    num_clean_per_class[i] = len(idx_class)

                if balance_crit == "max":
                    num_samples2select_class = np.max(num_clean_per_class)
                elif balance_crit == "min":
                    num_samples2select_class = np.min(num_clean_per_class)
                elif balance_crit == "median":
                    num_samples2select_class = np.median(num_clean_per_class)

                for i in range(args.num_class):
                    idx_class = np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths1]) == i)[0]
                    cur_num = num_clean_per_class[i]
                    idx_class2 = non_zero_idx
                    if num_samples2select_class > cur_num:
                        # pad class i with randomly chosen, not-yet-selected samples
                        remain_idx = list(set(idx_class.tolist()) - set(idx_class2))
                        idx = list(range(len(remain_idx)))
                        random.shuffle(idx)
                        num_app = int(num_samples2select_class - cur_num)
                        idx = idx[:num_app]
                        for j in idx:
                            non_zero_idx.append(remain_idx[j])
                non_zero_idx = np.array(non_zero_idx).reshape(-1, )
                bbb = len(non_zero_idx)
                num_per_class2 = []
                for i in range(args.num_class):
                    temp = np.where(
                        np.array([eval_loader.dataset.train_labels[it] for it in paths1])[non_zero_idx.tolist()] == i)[0]
                    num_per_class2.append(len(temp))
                print('\npred1 appended num per class:', num_per_class2, aaa, bbb)
                idx_per_class = np.zeros_like(pred1).astype(bool)
                for i in non_zero_idx:
                    idx_per_class[i] = True
                pred1 = idx_per_class
                non_aaa = pred1.nonzero()[0].tolist()
                assert len(non_aaa) == len(non_zero_idx)

            # ---- balance the selection of net2 (used to train net1) ----
            non_zero_idx2 = pred2.nonzero()[0].tolist()
            aaa = len(non_zero_idx2)
            if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
                num_clean_per_class = np.zeros(args.num_class)
                ppp = np.array(paths2)[non_zero_idx2].tolist()
                target_label = np.array([eval_loader.dataset.train_labels[it] for it in ppp])
                for i in range(args.num_class):
                    idx_class = np.where(target_label == i)[0]
                    num_clean_per_class[i] = len(idx_class)

                if balance_crit == "max":
                    num_samples2select_class = np.max(num_clean_per_class)
                elif balance_crit == "min":
                    num_samples2select_class = np.min(num_clean_per_class)
                elif balance_crit == "median":
                    num_samples2select_class = np.median(num_clean_per_class)

                for i in range(args.num_class):
                    idx_class = np.where(np.array([eval_loader.dataset.train_labels[it] for it in paths2]) == i)[0]
                    cur_num = num_clean_per_class[i]
                    idx_class2 = non_zero_idx2
                    if num_samples2select_class > cur_num:
                        # pad class i with randomly chosen, not-yet-selected samples
                        remain_idx = list(set(idx_class.tolist()) - set(idx_class2))
                        idx = list(range(len(remain_idx)))
                        random.shuffle(idx)
                        num_app = int(num_samples2select_class - cur_num)
                        idx = idx[:num_app]
                        for j in idx:
                            non_zero_idx2.append(remain_idx[j])
                non_zero_idx2 = np.array(non_zero_idx2).reshape(-1, )
                bbb = len(non_zero_idx2)
                num_per_class2 = []
                for i in range(args.num_class):
                    temp = np.where(
                        np.array([eval_loader.dataset.train_labels[it] for it in paths2])[non_zero_idx2.tolist()] == i)[0]
                    num_per_class2.append(len(temp))
                print('\npred2 appended num per class:', num_per_class2, aaa, bbb)
                idx_per_class2 = np.zeros_like(pred2).astype(bool)
                for i in non_zero_idx2:
                    idx_per_class2[i] = True
                pred2 = idx_per_class2
                non_aaa = pred2.nonzero()[0].tolist()
                assert len(non_aaa) == len(non_zero_idx2)

            # cross-training: each network learns from the peer's selection
            print(f'round={rou}/{local_round}, dmix selection, Train Net1')
            labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, prob2, paths=paths2)
            train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader)

            print(f'\nround={rou}/{local_round}, dmix selection, Train Net2')
            labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, prob1, paths=paths1)
            train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader)

            test_loader = loader.run('test')
            if rou != local_round - 1:
                best_en_acc = test(epoch, net1, net2, test_loader, best_en_acc)

        # merge the two locally trained networks into the global model
        print('c2m, get global network\n')
        local_weights.append(net1.state_dict())
        local_weights.append(net2.state_dict())
        w_glob = FedAvg(local_weights)
        if epoch % 1 == 0:  # snapshot the global model every epoch
            snapLast = exp_path + str(epoch) + "_global_model.pth"
            torch.save(w_glob, snapLast)

    # end-of-epoch evaluation: ensemble of the two peers, then the global model
    test_loader = loader.run('test')
    best_en_acc = test(epoch, net1, net2, test_loader, best_en_acc)
    best_gl_acc = test(epoch, net1, net2, test_loader, best_gl_acc, w_glob=w_glob)
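
# Example invocation (hypothetical script name and paths; assumes the
# Animal-10N images and noisy labels are laid out as dataloader_animal10N
# expects):
#
#   python Train_animal10N.py --data_path /data/Animal-10N --batch_size 128 \
#       --lr 0.01 --num_epochs 300 --gpuid 0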