# C2MT / train_cifar_c2mt.py
from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os
import argparse
import numpy as np
from PreResNet import *
from sklearn.mixture import GaussianMixture
import dataloader_cifar as dataloader
import matplotlib.pyplot as plt
import copy
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.cluster import Birch
import matplotlib
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--noise_mode', default='asym')
parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=150, type=float, help='weight for unsupervised loss')
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=300, type=int)
parser.add_argument('--r', default=0.3, type=float, help='noise ratio')
parser.add_argument('--id', default='')
parser.add_argument('--seed', default=123)
parser.add_argument('--gpuid', default=0, type=int)
parser.add_argument('--num_class', default=100, type=int)
# parser.add_argument('--data_path', default='./data/cifar-10-batches-py', type=str, help='path to dataset')
# parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--data_path', default='./data/cifar-100-python', type=str, help='path to dataset')
parser.add_argument('--dataset', default='cifar100', type=str)
args = parser.parse_args()
torch.cuda.set_device(args.gpuid)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
mse = torch.nn.MSELoss(reduction='none').cuda()
# Training
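# train(): one MixMatch-style co-training pass. net2 is frozen and, together with net, co-guesses
# soft labels for the unlabeled set; labeled targets are refined with the GMM clean probabilities
# w_x; the network is then updated on mixup-ed batches with SemiLoss plus a uniform-prior penalty.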
def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader, mask=None, f_G=None, new_y=None):
net.train()
net2.eval() # fix one network and train the other
unlabeled_train_iter = iter(unlabeled_trainloader)
num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
mse_total = 0
for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
        try:
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
batch_size = inputs_x.size(0)
        # transform labels to a one-hot (0/1) matrix
labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
w_x = w_x.view(-1, 1).type(torch.FloatTensor)
inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
with torch.no_grad():
# label co-guessing of unlabeled samples
outputs_u11, feat_u11 = net(inputs_u, feat_out=True)
outputs_u12, feat_u12 = net(inputs_u2, feat_out=True)
outputs_u21, feat_u21 = net2(inputs_u, feat_out=True)
outputs_u22, feat_u22 = net2(inputs_u2, feat_out=True)
            # average the softmax predictions of both networks over both augmentations
pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1)
+ torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
            ptu = pu ** (1 / args.T)  # temperature sharpening
            # sharpen(q_b, T) from Algorithm 1
targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
targets_u = targets_u.detach()
# label refinement of labeled samples
outputs_x, feat_x1 = net(inputs_x, feat_out=True)
outputs_x2, feat_x2 = net(inputs_x2, feat_out=True)
            # average the predictions over the two augmentations of the labeled samples
px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
            # label refinement, Eqs. (3)-(4)
            px = w_x * labels_x + (1 - w_x) * px
            ptx = px ** (1 / args.T)  # temperature sharpening
targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
targets_x = targets_x.detach()
# aaa = torch.argmax(labels_x, dim=1)
# mse_loss = torch.sum(mse((feat_x1+feat_x2)/2, f_G[aaa]), 1)
# mse_total = (mse_total + torch.sum(mse_loss) / len(mse_loss))/2
# mixmatch
l = np.random.beta(args.alpha, args.alpha)
        # keep the mixed sample closer to the original (labeled) input than to its mixing partner
l = max(l, 1 - l)
all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
        # random permutation of the concatenated batch for mixup pairing
idx = torch.randperm(all_inputs.size(0))
input_a, input_b = all_inputs, all_inputs[idx]
target_a, target_b = all_targets, all_targets[idx]
        # mixup, biased toward input_a / target_a because l = max(l, 1 - l)
mixed_input = l * input_a + (1 - l) * input_b
mixed_target = l * target_a + (1 - l) * target_b
logits = net(mixed_input)
        # the logits split into the labeled part (inputs_x) and the unlabeled part (inputs_u)
logits_x = logits[:batch_size * 2]
logits_u = logits[batch_size * 2:]
        # compute the semi-supervised loss (Eqs. (9)-(10)); lamb ramps up after warm-up
Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size * 2],
logits_u, mixed_target[batch_size * 2:],
epoch + batch_idx / num_iter, warm_up)
# regularization
prior = torch.ones(args.num_class) / args.num_class
prior = prior.cuda()
pred_mean = torch.softmax(logits, dim=1).mean(0)
        # the constant prior term sum(prior * log prior) is usually omitted, keeping only the last term
        # lambR = 1
penalty = torch.sum(prior * torch.log(prior / pred_mean))
        # lamb is the ramp-up fraction obtained by comparing the current epoch with warm_up, so the
        # weight of Lu grows as training proceeds; early on the loss is essentially standard CE,
        # though the penalty term is still applied
        # loss = Lx + lamb * Lu + penalty
loss = Lx + penalty + lamb * Lu
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch_idx % 200 == 0:
sys.stdout.write('\r')
sys.stdout.write(
'%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f Unlabeled loss: %.2f\n'
% (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx + 1, num_iter,
Lx.item(), Lu.item()))
sys.stdout.flush()
    # print('\r mse loss:%.4f\n' % mse_total, end='end', flush=True)
def mixup_criterion(pred, y_a, y_b, lam):
    # mixup cross-entropy on the raw logits; F.cross_entropy applies log_softmax internally
    return lam * F.cross_entropy(pred, y_a) + (1 - lam) * F.cross_entropy(pred, y_b)
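# If soft_mix_warm is True, warm-up trains on mixup-ed inputs with the mixup cross-entropy above;
# otherwise it uses plain cross-entropy (plus a negative-entropy penalty for asymmetric noise).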
soft_mix_warm = False
def warmup(epoch, net, optimizer, dataloader):
net.train()
num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1
for batch_idx, (inputs, labels, path) in enumerate(dataloader):
optimizer.zero_grad()
l = np.random.beta(args.alpha, args.alpha)
        # keep the mixed sample closer to the original input than to its mixing partner
l = max(l, 1 - l)
idx = torch.randperm(inputs.size(0))
targets = torch.zeros(inputs.size(0), args.num_class).scatter_(1, labels.view(-1, 1), 1).cuda()
targets = torch.clamp(targets, 1e-4, 1.)
inputs, labels = inputs.cuda(), labels.cuda()
if soft_mix_warm:
input_a, input_b = inputs, inputs[idx]
target_a, target_b = targets, targets[idx]
labels_a, labels_b = labels, labels[idx]
            # mixup, biased toward the original input because l = max(l, 1 - l)
mixed_input = l * input_a + (1 - l) * input_b
mixed_target = l * target_a + (1 - l) * target_b
outputs = net(mixed_input)
loss = mixup_criterion(outputs, labels_a, labels_b, l)
L = loss
else:
outputs = net(inputs)
loss = CEloss(outputs, labels)
if args.noise_mode == 'asym': # penalize confident prediction for asymmetric noise
penalty = conf_penalty(outputs)
L = loss + penalty
            else:  # symmetric noise (or any other mode): plain cross-entropy
                L = loss
L.backward()
optimizer.step()
if batch_idx % 200 == 0:
sys.stdout.write('\r')
sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f'
% (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx + 1, num_iter,
loss.item()))
sys.stdout.flush()
def test(epoch, net1, net2, best_acc, w_glob=None):
if w_glob is None:
net1.eval()
net2.eval()
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs1 = net1(inputs)
outputs2 = net2(inputs)
outputs = outputs1 + outputs2
_, predicted = torch.max(outputs, 1)
total += targets.size(0)
correct += predicted.eq(targets).cpu().sum().item()
acc = 100. * correct / total
if best_acc < acc:
best_acc = acc
print("\n| Ensemble network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
test_log.write('ensemble_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
test_log.flush()
else:
net1_w_bak = net1.state_dict()
net1.load_state_dict(w_glob)
net1.eval()
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs1 = net1(inputs)
_, predicted = torch.max(outputs1, 1)
total += targets.size(0)
correct += predicted.eq(targets).cpu().sum().item()
acc = 100. * correct / total
if best_acc < acc:
best_acc = acc
print("\n| Global network Test Epoch #%d\t Accuracy: %.2f, best_acc: %.2f%%\n" % (epoch, acc, best_acc))
test_log.write('global_Epoch:%d Accuracy:%.2f, best_acc: %.2f\n' % (epoch, acc, best_acc))
test_log.flush()
        # restore the original weights of net1
net1.load_state_dict(net1_w_bak)
return best_acc
feat_dim = 512  # feature dimension; an extra FC layer could project it down to 128
sim = torch.nn.CosineSimilarity(dim=1)
loss_func = torch.nn.CrossEntropyLoss(reduction='none')
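# Small-loss selection (co-teaching style): keep the (1 - forget_rate) fraction of samples with the
# smallest per-sample cross-entropy loss.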
def get_small_loss_samples(y_pred, y_true, forget_rate):
loss = loss_func(y_pred, y_true)
    ind_sorted = torch.argsort(loss.data.cpu()).cuda()
loss_sorted = loss[ind_sorted]
remember_rate = 1 - forget_rate
num_remember = int(remember_rate * len(loss_sorted))
ind_update = ind_sorted[:num_remember]
return ind_update
def get_small_loss_by_loss_list(loss_list, forget_rate, eval_loader):
remember_rate = 1 - forget_rate
idx_list = []
    for i in range(args.num_class):
class_idx = np.where(np.array(eval_loader.dataset.noise_label)[:] == i)[0]
# class_idx = torch.from_numpy(class_idx).cuda()
        loss_per_class = loss_list[class_idx]  # losses of the samples whose (noisy) label is class i
num_remember = int(remember_rate * len(loss_per_class))
ind_sorted = np.argsort(loss_per_class.data.cpu())
ind_update = ind_sorted[:num_remember].tolist()
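        # note: ind_update holds positions within class_idx (per-class ordering), not global dataset indices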
idx_list.append(ind_update)
return idx_list
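# eval_train(): compute normalized per-sample losses over the training set, build class prototypes
# f_G from the features of each predicted class, assign prototype-based pseudo-labels y_k_tilde by
# cosine similarity, flag samples whose pseudo-label agrees with the given label (mask), and fit a
# two-component GMM to the losses; the posterior of the low-mean component is the clean probability.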
def eval_train(model, all_loss):
model.eval()
losses = torch.zeros(50000)
f_G = torch.zeros(args.num_class, feat_dim).cuda()
f_all = torch.zeros(50000, feat_dim).cuda()
n_labels = torch.zeros(args.num_class, 1).cuda()
y_k_tilde = torch.zeros(50000)
mask = np.zeros(50000)
with torch.no_grad():
for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs, feat = model(inputs, feat_out=True)
loss = CE(outputs, targets)
_, predicted = torch.max(outputs, 1)
for b in range(inputs.size(0)):
losses[index[b]] = loss[b]
f_G[predicted[b]] += feat[b]
n_labels[predicted[b]] += 1
f_all[index] = feat
assert torch.sum(n_labels) == 50000
for i in range(len(n_labels)):
if n_labels[i] == 0:
n_labels[i] = 1
f_G = torch.div(f_G, n_labels)
f_G = F.normalize(f_G, dim=1)
f_all = F.normalize(f_all, dim=1)
temp = f_G.t()
sim_all = torch.mm(f_all, temp) # .cpu().numpy()
y_k_tilde = torch.argmax(sim_all.cpu(), dim=1)
with torch.no_grad():
for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
for i in range(len(index)):
if y_k_tilde[index[i]] == targets[i]:
mask[index[i]] = 1
losses = (losses - losses.min()) / (losses.max() - losses.min())
all_loss.append(losses)
if args.r == 0.9:
# average loss over last 5 epochs to improve convergence stability
history = torch.stack(all_loss)
input_loss = history[-5:].mean(0)
input_loss = input_loss.reshape(-1, 1)
else:
input_loss = losses.reshape(-1, 1)
# fit a two-component GMM to the loss
    # GMM parameters:
    # n_components: number of mixture components; max_iter: EM iteration cap; tol: convergence threshold;
    # reg_covar: non-negative regularization added to the covariance diagonal (kept close to 0)
gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
gmm.fit(input_loss)
prob = gmm.predict_proba(input_loss)
prob = prob[:, gmm.means_.argmin()]
return prob, all_loss, losses.numpy(), mask, f_G
def mix_data_lab(x, y, alpha=1.0):
'''Returns mixed inputs, pairs of targets, and lambda'''
if alpha > 0:
lam = np.random.beta(alpha, alpha)
else:
lam = 1
batch_size = x.size()[0]
index = torch.randperm(batch_size).cuda()
lam = max(lam, 1 - lam)
mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, index, lam
def linear_rampup(current, warm_up, rampup_length=16):
    # linear ramp-up; with symmetric noise the model first trains with standard CE for a while
    # the effective warm-up length is warm_up + rampup_length epochs
current = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
re_val = args.lambda_u * float(current)
# print(" current warm up parameters:", current)
# print("return parameters:", re_val)
return re_val
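# SemiLoss: cross-entropy with soft targets on the labeled part and mean squared error on the
# unlabeled part, with the unlabeled weight ramped up linearly by linear_rampup after warm-up.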
class SemiLoss(object):
def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
probs_u = torch.softmax(outputs_u, dim=1)
        # cross-entropy with the mixup soft targets: -sum(targets_x * log p_model)
Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        # the unlabeled loss is the mean squared error between p_model and the guessed targets
Lu = torch.mean((probs_u - targets_u) ** 2)
return Lx, Lu, linear_rampup(epoch, warm_up)
class NegEntropy(object):
def __call__(self, outputs):
probs = torch.softmax(outputs, dim=1)
return torch.mean(torch.sum(probs.log() * probs, dim=1))
def create_model():
    # actually a PreResNet-18 (ResNet-18 built from pre-activation residual blocks)
model = ResNet18(num_classes=args.num_class)
model = model.cuda()
return model
def plotHistogram(model_1_loss, model_2_loss, noise_index, clean_index, epoch, round, noise_rate):
title = 'Epoch-' + str(epoch)+':'
fig = plt.figure()
plt.subplot(121)
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
model_1_loss = np.reshape(model_1_loss, (-1, 1))
gmm.fit(model_1_loss) # fit the loss
# plot resulting fit
x_range = np.linspace(0, 1, 1000)
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
pdf_individual = responsibilities * pdf[:, np.newaxis]
plt.hist(np.array(model_1_loss[noise_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='red', label='Noisy subset')
plt.hist(np.array(model_1_loss[clean_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='blue', label='Clean subset')
plt.plot(x_range, pdf, '-k', label='Mixture')
plt.plot(x_range, pdf_individual, '--', label='Component')
plt.legend(loc='upper right', prop={'size': 12})
plt.xlabel('Normalized loss')
plt.ylabel('Estimated pdf')
plt.title(title+'Model_1')
plt.subplot(122)
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
model_2_loss = np.reshape(model_2_loss, (-1, 1))
gmm.fit(model_2_loss) # fit the loss
# plot resulting fit
x_range = np.linspace(0, 1, 1000)
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
pdf_individual = responsibilities * pdf[:, np.newaxis]
plt.hist(np.array(model_2_loss[noise_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='red', label='Noisy subset')
plt.hist(np.array(model_2_loss[clean_index]), density=True, bins=100, alpha=0.5,histtype='bar', color='blue', label='Clean subset')
plt.plot(x_range, pdf, '-k', label='Mixture')
plt.plot(x_range, pdf_individual, '--', label='Component')
plt.legend(loc='upper right', prop={'size': 12})
plt.xlabel('Normalized loss')
plt.ylabel('Estimated pdf')
plt.title(title+'Model_2')
print('\nlogging histogram...')
    title = str(args.dataset) + '_' + str(args.noise_mode) + '_moit_double_' + str(noise_rate)
    plt.savefig(os.path.join('./figure_his/', 'two_model_{}_{}_{}_{}.tif'.format(epoch, round, title, int(soft_mix_warm))), dpi=300)
# plt.show()
plt.close()
def loss_dist_plot(loss, noisy_index, clean_index, epoch, rou=None, g_file=True, model_name='', loss2=None):
    """
    Plot the per-sample loss distribution.
    :param loss: per-sample losses (array or tensor)
    :param noisy_index: indices of the truly noisy labels
    :param clean_index: indices of the truly clean labels
    :param epoch: current epoch (used in the title and file name)
    :param rou: optional local round (added to the title)
    :param g_file: whether to save the figure to disk
    :param model_name: suffix appended to the file name
    :param loss2: optional second loss list; if given, both models are plotted side by side
    :return: None
    """
if loss2 is None:
filename = 'one_model_'+str(args.dataset)+'_'+str(args.noise_mode)+'_'+str(args.r)+'_epoch='+str(epoch)
if rou is None:
title = 'Epoch-'+str(epoch) + ': ' + str(args.dataset)+' '+str(args.r*100)+'%-'+str(args.noise_mode)
else:
title = 'Epoch-' + str(epoch) + ' ' +'Round-'+str(rou)+ ': ' + str(args.dataset) + ' ' + str(int(args.r * 100)) + '%-' + str(args.noise_mode)
if type(loss) is not np.ndarray:
loss= loss.numpy()
sns.set(style='whitegrid')
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
loss = np.reshape(loss, (-1, 1))
gmm.fit(loss) # fit the loss
# plot resulting fit
x_range = np.linspace(0, 1, 1000)
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
pdf_individual = responsibilities * pdf[:, np.newaxis]
# sns.distplot(loss[noisy_index], color="red", rug=False,kde=False, label="incorrect",
# hist_kws={"color": "r", "alpha": 0.5})
# sns.distplot(loss[clean_index], color="skyblue", rug=False,kde=False, label="correct",
# hist_kws={"color": "b", "alpha": 0.5})
plt.hist(np.array(loss[noisy_index]), density=True, bins=100, histtype='bar', alpha=0.5, color='red',
label='Noisy subset')
plt.hist(np.array(loss[clean_index]), density=True, bins=100, histtype='bar', alpha=0.5, color='blue',
label='Clean subset')
plt.plot(x_range, pdf, '-k', label='Mixture')
plt.plot(x_range, pdf_individual, '--', label='Component')
# plt.plot(x_range, pdf_individual[:][1], '--', color='blue', label='Component 1')
plt.title(title, fontsize=20)
plt.xlabel('Normalized loss', fontsize=24)
plt.ylabel('Estimated pdf', fontsize=24)
plt.tick_params(labelsize=24)
plt.legend(loc='upper right', prop={'size': 12})
# plt.tight_layout()
if g_file:
plt.savefig('./figure_his/{0}.tif'.format(filename+model_name), bbox_inches='tight', dpi=300)
#plt.show()
plt.close()
else:
filename = 'noise_'+str(args.dataset) + '_' + str(args.noise_mode) + '_' + str(args.r) + '_epoch=' + str(epoch)
if rou is None:
title = 'Epoch-' + str(epoch) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(
args.noise_mode)
else:
title = 'Epoch-' + str(epoch) + ' ' + 'Round-' + str(rou) + ': ' + str(args.dataset) + ' ' + str(
args.r * 100) + '%-' + str(args.noise_mode)
if type(loss) is not np.ndarray:
loss = loss.numpy()
if type(loss2) is not np.ndarray:
loss2 = loss2.numpy()
fig = plt.figure()
plt.subplot(121)
sns.set(style='whitegrid')
sns.distplot(loss[noisy_index], color="red", rug=False, kde=False, label="incorrect",
hist_kws={"color": "r", "alpha": 0.5})
sns.distplot(loss[clean_index], color="skyblue", rug=False, kde=False, label="correct",
hist_kws={"color": "b", "alpha": 0.5})
plt.title('Model_1', fontsize=32)
plt.xlabel('Normalized loss', fontsize=32)
plt.ylabel('Sample number', fontsize=32)
plt.tick_params(labelsize=32)
plt.legend(loc='upper right', prop={'size': 24})
plt.subplot(122)
sns.set(style='whitegrid')
sns.distplot(loss2[noisy_index], color="red", rug=False, kde=False, label="incorrect",
hist_kws={"color": "r", "alpha": 0.5})
sns.distplot(loss2[clean_index], color="skyblue", rug=False, kde=False, label="correct",
hist_kws={"color": "b", "alpha": 0.5})
plt.title('Model_2', fontsize=32)
plt.xlabel('Normalized loss', fontsize=32)
plt.ylabel('Sample number', fontsize=32)
plt.tick_params(labelsize=32)
plt.legend(loc='upper right', prop={'size': 24})
# plt.tight_layout()
if g_file:
plt.savefig('./figure_his/{0}.tif'.format(filename + model_name), bbox_inches='tight', dpi=300)
# plt.show()
plt.close()
def loss_dist_plot_real(loss, epoch, rou=None, g_file=True, model_name=''):
    """
    Plot the loss distribution together with the fitted two-component GMM.
    :param loss: per-sample losses (array or tensor)
    :param epoch: current epoch (used in the title and file name)
    :param rou: optional local round (added to the title)
    :param g_file: whether to save the figure to disk
    :param model_name: suffix appended to the file name
    :return: None
    """
filename = str(args.dataset) + '_' + str(args.noise_mode) + '_' + str(args.r) + '_epoch=' + str(epoch)
if rou is None:
title = 'Epoch-' + str(epoch) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(args.noise_mode)
else:
title = 'Epoch-' + str(epoch) + ' ' + 'Round-' + str(rou) + ': ' + str(args.dataset) + ' ' + str(args.r * 100) + '%-' + str(args.noise_mode)
if type(loss) is not np.ndarray:
loss= loss.numpy()
sns.set(style='whitegrid')
gmm = GaussianMixture(n_components=2, max_iter=20, tol=1e-2, random_state=0, reg_covar=5e-4)
loss = np.reshape(loss, (-1, 1))
gmm.fit(loss) # fit the loss
# plot resulting fit
x_range = np.linspace(0, 1, 1000)
pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
responsibilities = gmm.predict_proba(x_range.reshape(-1, 1))
pdf_individual = responsibilities * pdf[:, np.newaxis]
plt.hist(loss, bins=60, density=True, histtype='bar', alpha=0.3)
plt.plot(x_range, pdf, '-k', label='Mixture')
plt.plot(x_range, pdf_individual, '--', label='Component')
plt.legend()
# plt.tight_layout()
plt.title(title, fontsize=32)
plt.xlabel('Normalized loss', fontsize=32)
plt.ylabel('Estimated PDF', fontsize=32)
plt.tick_params(labelsize=32)
plt.legend(loc='upper right', prop={'size': 22})
if g_file:
plt.savefig('./figure_his/{0}.tif'.format(filename+model_name), bbox_inches='tight', dpi=300)
#plt.show()
plt.close()
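# FedAvg: element-wise average of the given model state_dicts; here it merges the two peer networks
# into the global model w_glob.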
def FedAvg(w):
w_avg = copy.deepcopy(w[0])
for k in w_avg.keys():
for i in range(1, len(w)):
w_avg[k] += w[i][k]
        # with i.i.d. noise every client has the same number of samples, so weighting by n_k / n is unnecessary
w_avg[k] = torch.div(w_avg[k], len(w))
return w_avg
if not os.path.exists('checkpoint'):
    os.mkdir('checkpoint')
    print("created checkpoint/log directory")
stats_log = open('./checkpoint/single_%s_%.1f_%s_%d' % (args.dataset, args.r, args.noise_mode,
int(soft_mix_warm)) + '_stats.txt', 'w')
test_log = open('./checkpoint/single_%s_%.1f_%s_%d' % (args.dataset, args.r, args.noise_mode,
int(soft_mix_warm)) + '_acc.txt', 'w')
warm_up = 10
dmix_epoch = 150
args.num_epochs = dmix_epoch + 150
# warm-up epochs as described in the paper (page 6)
if args.dataset == 'cifar10':
warm_up = 10
dmix_epoch = 150
args.num_epochs = dmix_epoch + 50
elif args.dataset == 'cifar100':
warm_up = 30
dmix_epoch = 150
args.num_epochs = dmix_epoch + 50
loader = dataloader.cifar_dataloader(args.dataset, r=args.r, noise_mode=args.noise_mode,
batch_size=args.batch_size, num_workers=0,
root_dir=args.data_path, log=stats_log,
noise_file='%s/%.1f_%s.json' % (args.data_path, args.r, args.noise_mode))
print('| Building net')
net1 = create_model()
net2 = create_model()
cudnn.benchmark = True
criterion = SemiLoss()
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
CE = nn.CrossEntropyLoss(reduction='none')
CEloss = nn.CrossEntropyLoss()
if args.noise_mode == 'asym':
    # first issue: asymmetric and symmetric noise require different measures, which is impractical
    # moreover, the noisy-data handling is scattered inconsistently across the different stages
conf_penalty = NegEntropy()
all_loss = [[], []] # save the history of losses from two networks
local_round = 5
first = True
balance_crit = 'median'
exp_path = './checkpoint/single_%s_%.1f_%s_double_m2_' % (args.dataset, args.r, args.noise_mode)
save_clean_idx = exp_path + "clean_idx.npy"
boot_loader = None
w_glob = None
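# Training schedule: warm up both networks, then in every epoch run `local_round` rounds of
# GMM-based co-divide plus semi-supervised co-training for each network, average the two networks
# into w_glob with FedAvg, and reload w_glob into both networks at the start of later epochs.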
if args.r == 0.9:
args.p_threshold = 0.6
best_en_acc = 0.
best_gl_acc = 0.
resume_epoch = 0
if resume_epoch > 0:
snapLast = exp_path + str(resume_epoch-1) + "_global_model.pth"
global_state = torch.load(snapLast)
    # whether to update before or after loading the global state
w_glob = global_state
net1.load_state_dict(global_state)
net2.load_state_dict(global_state)
for epoch in range(resume_epoch, args.num_epochs + 1):
test_loader = loader.run('test')
eval_loader = loader.run('eval_train')
lr = args.lr
if epoch >= dmix_epoch:
lr /= 10
for param_group in optimizer1.param_groups:
param_group['lr'] = lr
for param_group in optimizer2.param_groups:
param_group['lr'] = lr
noise_ind, clean_ind = eval_loader.dataset.if_noise()
print(len(np.where(np.array(eval_loader.dataset.noise_label) != np.array(eval_loader.dataset.clean_label))[0])
/ len(eval_loader.dataset.clean_label))
local_weights = []
if epoch < warm_up:
        # consider whether the networks should also be merged (FedAvg) during warm-up
warmup_trainloader = loader.run('warmup')
print('Warmup Net1')
warmup(epoch, net1, optimizer1, warmup_trainloader)
print('\nWarmup Net2')
warmup(epoch, net2, optimizer2, warmup_trainloader)
if epoch == (warm_up-1):
snapLast = exp_path+str(epoch) + "_1_model.pth"
torch.save(net1.state_dict(), snapLast)
snapLast = exp_path+str(epoch) + "_2_model.pth"
            torch.save(net2.state_dict(), snapLast)
local_weights.append(net1.state_dict())
local_weights.append(net2.state_dict())
w_glob = FedAvg(local_weights)
else:
if epoch != warm_up:
net1.load_state_dict(w_glob)
net2.load_state_dict(w_glob)
for rou in range(local_round):
prob1, all_loss[0], loss1, mask1, f_G1 = eval_train(net1, all_loss[0])
prob2, all_loss[1], loss2, mask2, f_G2 = eval_train(net2, all_loss[1])
            # first evaluation after loading the global weights
if rou == 0:
# plotHistogram(np.array(loss1), np.array(loss2), noise_ind, clean_ind, epoch, rou, args.r)
loss_dist_plot(loss1, noise_ind, clean_ind, epoch, model_name='model_1')
# loss_dist_plot_real(loss1, epoch, model_name='model_1')
if rou == local_round-1:
plotHistogram(np.array(loss1), np.array(loss2), noise_ind, clean_ind, epoch, rou, args.r)
# pred1 = (prob1 > args.p_threshold) & (mask1 != 0)
# pred2 = (prob2 > args.p_threshold) & (mask2 != 0)
pred1 = (prob1 > args.p_threshold)
pred2 = (prob2 > args.p_threshold)
non_zero_idx = pred1.nonzero()[0].tolist()
aaa = len(non_zero_idx)
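            # class-balance the clean set selected by net1's GMM: for every class with fewer selected
            # samples than the per-class median, randomly add samples of that class up to the median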
if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
num_clean_per_class = np.zeros(args.num_class)
target_label = np.array(eval_loader.dataset.noise_label)[non_zero_idx]
for i in range(args.num_class):
idx_class = np.where(target_label == i)[0]
num_clean_per_class[i] = len(idx_class)
if balance_crit == "median":
num_samples2select_class = np.median(num_clean_per_class)
for i in range(args.num_class):
idx_class = np.where(np.array(eval_loader.dataset.noise_label) == i)[0]
cur_num = num_clean_per_class[i]
idx_class2 = non_zero_idx
if num_samples2select_class > cur_num:
remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
idx = list(range(len(remian_idx)))
random.shuffle(idx)
num_app = int(num_samples2select_class - cur_num)
idx = idx[:num_app]
for j in idx:
non_zero_idx.append(remian_idx[j])
non_zero_idx = np.array(non_zero_idx).reshape(-1, )
bbb = len(non_zero_idx)
num_per_class2 = []
for i in range(max(eval_loader.dataset.noise_label)):
temp = np.where(np.array(eval_loader.dataset.noise_label)[non_zero_idx.tolist()] == i)[0]
num_per_class2.append(len(temp))
print('\npred1 appended num per class:', num_per_class2, aaa, bbb)
idx_per_class = np.zeros_like(pred1).astype(bool)
for i in non_zero_idx:
idx_per_class[i] = True
pred1 = idx_per_class
non_aaa = pred1.nonzero()[0].tolist()
assert len(non_aaa) == len(non_zero_idx)
non_zero_idx2 = pred2.nonzero()[0].tolist()
aaa = len(non_zero_idx2)
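            # same class-balancing procedure for the clean set selected by net2's GMM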
if balance_crit == "max" or balance_crit == "min" or balance_crit == "median":
num_clean_per_class = np.zeros(args.num_class)
target_label = np.array(eval_loader.dataset.noise_label)[non_zero_idx2]
for i in range(args.num_class):
idx_class = np.where(target_label == i)[0]
num_clean_per_class[i] = len(idx_class)
if balance_crit == "median":
num_samples2select_class = np.median(num_clean_per_class)
for i in range(args.num_class):
idx_class = np.where(np.array(eval_loader.dataset.noise_label) == i)[0]
cur_num = num_clean_per_class[i]
idx_class2 = non_zero_idx2
if num_samples2select_class > cur_num:
remian_idx = list(set(idx_class.tolist()) - set(idx_class2))
idx = list(range(len(remian_idx)))
random.shuffle(idx)
num_app = int(num_samples2select_class - cur_num)
idx = idx[:num_app]
for j in idx:
non_zero_idx2.append(remian_idx[j])
non_zero_idx2 = np.array(non_zero_idx2).reshape(-1, )
bbb = len(non_zero_idx2)
num_per_class2 = []
for i in range(max(eval_loader.dataset.noise_label)):
temp = np.where(np.array(eval_loader.dataset.noise_label)[non_zero_idx2.tolist()] == i)[0]
num_per_class2.append(len(temp))
print('\npred2 appended num per class:', num_per_class2, aaa, bbb)
idx_per_class2 = np.zeros_like(pred2).astype(bool)
for i in non_zero_idx2:
idx_per_class2[i] = True
pred2 = idx_per_class2
non_aaa = pred2.nonzero()[0].tolist()
assert len(non_aaa) == len(non_zero_idx2)
correct_num = len(pred1.nonzero()[0])
eval_loader.dataset.if_noise(pred1)
eval_loader.dataset.if_noise(pred2)
print(f'round={rou}/{local_round}, dmix selection, Train Net1')
            # prob2 is the GMM clean probability w_i; samples above the threshold are treated as clean, the rest as noisy
labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, prob2) # co-divide
train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader) # train net1
print(f'\nround={rou}/{local_round}, dmix selection, Train Net2')
labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, prob1) # co-divide
train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader) # train net2
local_weights.append(net1.state_dict())
local_weights.append(net2.state_dict())
w_glob = FedAvg(local_weights)
if epoch % 5 == 0:
snapLast = exp_path + str(epoch) + "_global_model.pth"
torch.save(w_glob, snapLast)
best_en_acc = test(epoch, net1, net2, best_en_acc)
    best_gl_acc = test(epoch, net1, net2, best_gl_acc, w_glob)