# EdgeTA/new_impl/cv/resnet/extract_submodel.py
from abc import ABC, abstractmethod
from copy import deepcopy
import torch
from torch import nn
import os
from .model_fbs import DomainDynamicConv2d
#from methods.utils.data import get_source_dataloader, get_source_normal_aug_dataloader, get_target_dataloaders
#from models.resnet_cifar.model_manager import ResNetCIFARManager
from utils.common.others import get_cur_time_str
from utils.dl.common.env import set_random_seed
from utils.dl.common.model import get_model_latency, get_model_size, get_module, set_module
from utils.common.log import logger
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
from utils.third_party.nni_new.compression.pytorch.utils.mask_conflict import GroupMaskConflict, ChannelMaskConflict, CatMaskPadding
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None, fix_group=False, fix_channel=True, fix_padding=False):
if isinstance(masks, str):
# if the input is the path of the mask_file
assert os.path.exists(masks)
masks = torch.load(masks)
assert len(masks) > 0, 'Mask tensor cannot be empty'
    # If the user passes model and dummy_input but no traced model, trace the
    # model once here so that GroupMaskConflict and ChannelMaskConflict can
    # reuse the same traced graph.
if traced is None:
assert model is not None and dummy_input is not None
training = model.training
model.eval()
# We need to trace the model in eval mode
traced = torch.jit.trace(model, dummy_input)
model.train(training)
if fix_group:
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
masks = fix_group_mask.fix_mask()
if fix_channel:
fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
masks = fix_channel_mask.fix_mask()
if fix_padding:
padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
masks = padding_cat_mask.fix_mask()
return masks
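
# A minimal usage sketch for fix_mask_conflict. The toy nn.Sequential and the
# {module_name: {'weight': 0/1 mask}} layout are illustrative assumptions; the
# extractors below produce masks in the same convention before ModelSpeedup.
def _demo_fix_mask_conflict():
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    dummy_input = torch.rand(1, 3, 32, 32)
    weight_mask = torch.ones_like(model[0].weight.data)
    weight_mask[4:] = 0.  # prune the last 4 output filters of the conv
    masks = {'0': {'weight': weight_mask}}
    # resolve group/channel dependencies so the masks stay structurally consistent
    return fix_mask_conflict(masks, model, dummy_input)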
class FeatureBoosting(nn.Module):
def __init__(self, w: torch.Tensor):
super(FeatureBoosting, self).__init__()
assert w.dim() == 1
        # store as a frozen (1, C, 1, 1) parameter so it broadcasts over NCHW features
        self.w = nn.Parameter(w.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)
def forward(self, x):
return x * self.w
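
# FeatureBoosting re-applies the FBS gate's channel saliencies as a fixed
# per-channel scale after the gate itself has been stripped away. A small
# self-contained sketch of its effect:
def _demo_feature_boosting():
    boost = FeatureBoosting(torch.tensor([0.5, 2.0]))  # frozen (1, C, 1, 1) weight
    x = torch.ones(1, 2, 4, 4)
    y = boost(x)  # channel 0 is halved, channel 1 is doubled
    assert torch.allclose(y[0, 0], torch.full((4, 4), 0.5))
    assert torch.allclose(y[0, 1], torch.full((4, 4), 2.0))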
class FBSSubModelExtractor(ABC):
    def extract_submodel_via_a_sample(self, fbs_model: nn.Module, sample: torch.Tensor):
        # expects a single sample with an explicit batch dimension: shape (1, C, H, W)
        assert sample.dim() == 4 and sample.size(0) == 1
fbs_model.eval()
o1 = fbs_model(sample)
pruning_info = {}
pruning_masks = {}
for layer_name, layer in fbs_model.named_modules():
if not isinstance(layer, DomainDynamicConv2d):
continue
cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
if layer.raw_conv2d.bias is not None:
cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
w = get_module(fbs_model, layer_name).cached_w.squeeze()
unpruned_filters_index = w.nonzero(as_tuple=True)[0]
pruning_info[layer_name] = w
cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
if layer.raw_conv2d.bias is not None:
cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            # key the mask by the conv's future name inside the nn.Sequential built below ('.0')
            pruning_masks[layer_name + '.0'] = cur_pruning_mask
no_gate_model = deepcopy(fbs_model)
for name, layer in no_gate_model.named_modules():
if not isinstance(layer, DomainDynamicConv2d):
continue
# layer.bn.weight.data.mul_(pruning_info[name])
            # the trailing Identity is a placeholder, replaced by FeatureBoosting after speedup
            set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))
# fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
torch.save(pruning_masks, tmp_mask_path)
pruned_model = no_gate_model
pruned_model.eval()
model_speedup = ModelSpeedup(pruned_model, sample, tmp_mask_path, sample.device)
model_speedup.speedup_model()
os.remove(tmp_mask_path)
# add feature boosting module
for layer_name, feature_boosting_w in pruning_info.items():
feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))
pruned_model_size = get_model_size(pruned_model, True)
pruned_model.eval()
o2 = pruned_model(sample)
diff = ((o1 - o2) ** 2).sum()
logger.info(f'pruned model size: {pruned_model_size:.3f}MB, diff: {diff}')
return pruned_model
@abstractmethod
def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
pass
@abstractmethod
def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
pass
def extract_submodel_via_samples(self, fbs_model: nn.Module, samples: torch.Tensor):
assert samples.dim() == 4
fbs_model = deepcopy(fbs_model)
# fbs_model.eval()
# fbs_model(samples)
self.generate_pruning_strategy(fbs_model, samples)
pruning_info = {}
pruning_masks = {}
for layer_name, layer in fbs_model.named_modules():
if not isinstance(layer, DomainDynamicConv2d):
continue
cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
if layer.raw_conv2d.bias is not None:
cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
            w = get_module(fbs_model, layer_name).cached_w.squeeze()  # shape (num_samples, num_filters)
w = self.get_final_w(fbs_model, samples, layer_name, w)
unpruned_filters_index = w.nonzero(as_tuple=True)[0]
pruning_info[layer_name] = w
cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
if layer.raw_conv2d.bias is not None:
cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
pruning_masks[layer_name + '.0'] = cur_pruning_mask
no_gate_model = deepcopy(fbs_model)
for name, layer in no_gate_model.named_modules():
if not isinstance(layer, DomainDynamicConv2d):
continue
# layer.bn.weight.data.mul_(pruning_info[name])
set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))
# fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
torch.save(pruning_masks, tmp_mask_path)
pruned_model = no_gate_model
pruned_model.eval()
model_speedup = ModelSpeedup(pruned_model, samples[0:1], tmp_mask_path, samples.device)
model_speedup.speedup_model()
os.remove(tmp_mask_path)
# add feature boosting module
for layer_name, feature_boosting_w in pruning_info.items():
feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))
return pruned_model, pruning_info
def extract_submodel_via_samples_and_last_submodel(self, fbs_model: nn.Module, samples: torch.Tensor,
last_submodel: nn.Module, last_pruning_info: dict):
assert samples.dim() == 4
fbs_model = deepcopy(fbs_model)
# fbs_model.eval()
# fbs_model(samples)
self.generate_pruning_strategy(fbs_model, samples)
pruning_info = {}
pruning_masks = {}
        # trick: record which layers can inherit weights from the last sub-model
incrementally_updated_layers = []
for layer_name, layer in fbs_model.named_modules():
if not isinstance(layer, DomainDynamicConv2d):
continue
cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
if layer.raw_conv2d.bias is not None:
cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
            w = get_module(fbs_model, layer_name).cached_w.squeeze()  # shape (num_samples, num_filters)
w = self.get_final_w(fbs_model, samples, layer_name, w)
unpruned_filters_index = w.nonzero(as_tuple=True)[0]
pruning_info[layer_name] = w
cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
if layer.raw_conv2d.bias is not None:
cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
pruning_masks[layer_name + '.0'] = cur_pruning_mask
            # trick: check filter overlap with the last sub-model
            if last_pruning_info is not None:
                last_w = last_pruning_info[layer_name]
                intersection_ratio = ((w > 0) * (last_w > 0)).sum() / (last_w > 0).sum()
                if intersection_ratio > 0.:
                    # only layers sharing at least one kept filter are considered transferable
                    incrementally_updated_layers += [layer_name]
no_gate_model = deepcopy(fbs_model)
for name, layer in no_gate_model.named_modules():
if not isinstance(layer, DomainDynamicConv2d):
continue
# layer.bn.weight.data.mul_(pruning_info[name])
set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))
# fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
torch.save(pruning_masks, tmp_mask_path)
pruned_model = no_gate_model
pruned_model.eval()
model_speedup = ModelSpeedup(pruned_model, samples[0:1], tmp_mask_path, samples.device)
model_speedup.speedup_model()
os.remove(tmp_mask_path)
# add feature boosting module
for layer_name, feature_boosting_w in pruning_info.items():
feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))
        # trick: incrementally update by copying weights of shared filters from the last sub-model
        for layer_name in incrementally_updated_layers:
            # walk both saliency vectors in lockstep; cur_filter_i / last_filter_i index the
            # same logical filter inside the (differently sized) pruned convs
            cur_filter_i, last_filter_i = 0, 0
            for i, (w_factor, last_w_factor) in enumerate(zip(pruning_info[layer_name], last_pruning_info[layer_name])):
                if w_factor > 0 and last_w_factor > 0:  # the filter is kept in both sub-models
cur_conv2d, last_conv2d = get_module(pruned_model, layer_name + '.0'), get_module(last_submodel, layer_name + '.0')
cur_conv2d.weight.data[cur_filter_i] = last_conv2d.weight.data[last_filter_i]
cur_bn, last_bn = get_module(pruned_model, layer_name + '.1'), get_module(last_submodel, layer_name + '.1')
cur_bn.weight.data[cur_filter_i] = last_bn.weight.data[last_filter_i]
cur_bn.bias.data[cur_filter_i] = last_bn.bias.data[last_filter_i]
cur_bn.running_mean.data[cur_filter_i] = last_bn.running_mean.data[last_filter_i]
cur_bn.running_var.data[cur_filter_i] = last_bn.running_var.data[last_filter_i]
cur_fw, last_fw = get_module(pruned_model, layer_name + '.2'), get_module(last_submodel, layer_name + '.2')
cur_fw.w.data[0, cur_filter_i] = last_fw.w.data[0, last_filter_i]
if w_factor > 0:
cur_filter_i += 1
if last_w_factor > 0:
last_filter_i += 1
return pruned_model, pruning_info
    def absorb_sub_model(self, fbs_model: nn.Module, sub_model: nn.Module, pruning_info: dict, alpha=1.):
        # EMA-style write-back: blend the adapted sub-model's conv/BN parameters into
        # the corresponding unpruned filters of the FBS super-model
        if alpha == 0.:
            return
for layer_name, feature_boosting_w in pruning_info.items():
unpruned_filters_index = feature_boosting_w.nonzero(as_tuple=True)[0]
fbs_layer = get_module(fbs_model, layer_name)
sub_model_layer = get_module(sub_model, layer_name)
for fi_in_sub_layer, fi_in_fbs_layer in enumerate(unpruned_filters_index):
fbs_layer.raw_conv2d.weight.data[fi_in_fbs_layer] = (1. - alpha) * fbs_layer.raw_conv2d.weight.data[fi_in_fbs_layer] + \
alpha * sub_model_layer[0].weight.data[fi_in_sub_layer]
for k in ['weight', 'bias', 'running_mean', 'running_var']:
getattr(fbs_layer.bn, k).data[fi_in_fbs_layer] = (1. - alpha) * getattr(fbs_layer.bn, k).data[fi_in_fbs_layer] + \
alpha * getattr(sub_model_layer[1], k).data[fi_in_sub_layer]
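
# Sketch of the intended extract / adapt / absorb round trip (the adaptation step
# is elided; `extractor` must be a concrete subclass such as DAFBSSubModelExtractor):
def _demo_extract_adapt_absorb(extractor: FBSSubModelExtractor, fbs_model: nn.Module,
                               samples: torch.Tensor, alpha: float = 0.2):
    sub_model, pruning_info = extractor.extract_submodel_via_samples(fbs_model, samples)
    # ... adapt sub_model on the target domain here (e.g. SHOT / TENT) ...
    # blend the adapted filters back into the FBS super-model with EMA factor alpha
    extractor.absorb_sub_model(fbs_model, sub_model, pruning_info, alpha)
    return sub_model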
class DAFBSSubModelExtractor(FBSSubModelExtractor):
def __init__(self) -> None:
super().__init__()
# self.debug_sample_i = 0
# self.last_final_ws = None
    def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
        # concrete override: one gated forward pass makes each DomainDynamicConv2d
        # cache its per-sample channel saliencies (cached_w) as a side effect
        with torch.no_grad():
            fbs_model.eval()
            self.cur_output = fbs_model(samples)
    def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
        # concrete override: reduce the (num_samples, num_filters) saliency matrix w to one pruning vector
# import matplotlib.pyplot as plt
# plt.imshow(w.cpu().numpy(), cmap='Greys')
# # plt.colorbar()
# plt.xlabel('Filters')
# plt.ylabel('Samples')
# plt.tight_layout()
# plt.savefig(os.path.join(res_save_dir, f'{layer_name}.png'), dpi=300)
# plt.clf()
# w_sum = w.sum(0)
# w_argsort = w_sum.argsort(descending=True)
# return w[self.debug_sample_i]
# x = self.cur_output
# each_sample_entropy = -(x.softmax(1) * x.log_softmax(1)).sum(1)
# hardest_sample_index = w.sum(1).argmax()
# return w[hardest_sample_index]
# [0.0828, 0.1017, 0.0575, 0.3081, 0.1511, 0.3634, 0.3388, 0.3942, 0.2475, 0.3371, 0.5837, 0.145, 0.4428, 0.2159, 0.4028] 0.27815999999999996
        # select the channel saliencies of the highest-entropy (hardest) sample in the batch
        x = self.cur_output
        each_sample_entropy = -(x.logits.softmax(1) * x.logits.log_softmax(1)).sum(1)
        hardest_sample_index = each_sample_entropy.argmax()
        res = w[hardest_sample_index]
        return res
# if self.last_final_ws is not None:
# intersection_ratio = (self.last_final_w == res).sum() / (res > 0).sum()
# print('intersection ratio: ', intersection_ratio)
# self.last_final_ws[layer_name] = res
# indices = (-w).sum(0).topk((w[0] == 0).sum())[1]
# boosting = w.max(0)[0]
# boosting[indices] = 0.
# return boosting
# return w[0]
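
# Minimal usage sketch: one batch of target samples drives both the gated forward
# pass (generate_pruning_strategy) and the hardest-sample selection (get_final_w):
def _demo_da_extraction(fbs_model: nn.Module, target_samples: torch.Tensor):
    extractor = DAFBSSubModelExtractor()
    pruned_model, pruning_info = extractor.extract_submodel_via_samples(fbs_model, target_samples)
    return pruned_model, pruning_info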
def tent_as_detector(model, x, num_iters=1, lr=1e-4, l1_wd=0., strategy='ours'):
    """Run TENT for a few steps on a copy of `model`, then rank each BN layer's
    filters by how much their affine parameters moved; the bottom 80% are marked
    as untrained (to be frozen) for subsequent adaptation."""
model = deepcopy(model)
before_model = deepcopy(model)
from methods.tent import tent
optimizer = torch.optim.SGD(
model.parameters(), lr=lr, weight_decay=l1_wd)
from models.resnet_cifar.model_manager import ResNetCIFARManager
tented_model = tent.Tent(model, optimizer, ResNetCIFARManager, steps=num_iters)
tent.configure_model(model)
tented_model(x)
filters_sen_info = {}
last_conv_name = None
for (name, m1), m2 in zip(model.named_modules(), before_model.modules()):
if isinstance(m1, nn.Conv2d):
last_conv_name = name
if not isinstance(m1, nn.BatchNorm2d):
continue
with torch.no_grad():
features_weight_diff = ((m1.weight.data - m2.weight.data).abs())
features_bias_diff = ((m1.bias.data - m2.bias.data).abs())
features_diff = features_weight_diff + features_bias_diff
features_diff_order = features_diff.argsort(descending=False)
if strategy == 'ours':
untrained_filters_index = features_diff_order[: int(len(features_diff) * 0.8)]
elif strategy == 'random':
untrained_filters_index = torch.randperm(len(features_diff))[: int(len(features_diff) * 0.8)]
elif strategy == 'inversed_ours':
untrained_filters_index = features_diff_order.flip(0)[: int(len(features_diff) * 0.8)]
elif strategy == 'none':
untrained_filters_index = None
filters_sen_info[name] = dict(untrained_filters_index=untrained_filters_index, conv_name=last_conv_name)
return filters_sen_info
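
# Sketch of how the sensitivity maps are consumed (mirrors shot_w_part_filter in the
# experiment code below): the BN-keyed maps are re-keyed by the conv feeding each BN
# so SGDF.step can mask conv gradients as well.
def _demo_filter_sensitivity(model: nn.Module, x: torch.Tensor):
    filters_sen_info = tent_as_detector(model, x, strategy='ours')
    conv_filters_sen_info = {v['conv_name']: v for v in filters_sen_info.values()}
    return filters_sen_info, conv_filters_sen_info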
class SGDF(torch.optim.SGD):
@torch.no_grad()
def step(self, model, conv_filters_sen_info, filters_sen_info, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
# assert len([i for i in model.named_parameters()]) == len([j for j in group['params']])
for (name, _), p in zip(model.named_parameters(), group['params']):
if p.grad is None:
continue
layer_name = '.'.join(name.split('.')[0:-1])
if layer_name in filters_sen_info.keys():
untrained_filters_index = filters_sen_info[layer_name]['untrained_filters_index']
elif layer_name in conv_filters_sen_info.keys():
untrained_filters_index = conv_filters_sen_info[layer_name]['untrained_filters_index']
else:
untrained_filters_index = []
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(p, alpha=weight_decay)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
d_p = d_p.add(buf, alpha=momentum)
else:
d_p = buf
                if untrained_filters_index is not None:
                    d_p[untrained_filters_index] = 0.  # freeze gradients of insensitive filters
p.add_(d_p, alpha=-group['lr'])
return loss
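
# SGDF is a drop-in SGD variant whose step() additionally zeroes the gradient rows of
# insensitive filters. A minimal adaptation-step sketch; loss_fn is a placeholder for
# e.g. the entropy objective used by shot_w_part_filter below:
def _demo_sgdf_step(model: nn.Module, x: torch.Tensor, loss_fn):
    optimizer = SGDF(model.parameters(), lr=6e-4, momentum=0.9)
    filters_sen_info = tent_as_detector(model, x, strategy='ours')
    conv_filters_sen_info = {v['conv_name']: v for v in filters_sen_info.values()}
    loss = loss_fn(model(x))
    optimizer.zero_grad()
    loss.backward()
    # note the non-standard signature: the model and both sensitivity maps are required
    optimizer.step(model, conv_filters_sen_info, filters_sen_info)
    return loss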
if __name__ == '__main__':
    set_random_seed(0)
    import sys
    # experiment-only imports, kept local so the module itself stays importable
    from methods.utils.data import get_source_dataloader
    from models.resnet_cifar.model_manager import ResNetCIFARManager
    tag = sys.argv[1]
    # alpha = 0.4
    alpha = 0.2
    # alpha = float(sys.argv[1])
    fbs_model_path = sys.argv[2]  # the model path is a separate CLI argument from the tag
cur_time_str = get_cur_time_str()
res_save_dir = f'logs/experiments_trial/CIFAR100C/ours_fbs_more_challenging/{cur_time_str[0:8]}/{cur_time_str[8:]}-{tag}'
os.makedirs(res_save_dir)
import shutil
shutil.copytree(os.path.dirname(__file__),
os.path.join(res_save_dir, 'method'), ignore=shutil.ignore_patterns('*.pt', '*.pth', 'log', '__pycache__'))
logger.info(f'res save dir: {res_save_dir}')
# model = torch.load('logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220801/152138-0.6_l1wd=1e-8/best_model_0.80.pt')
# model = torch.load('logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220801/232913-sample_subnetwork/best_model_0.80.pt')
model = torch.load(fbs_model_path)
# model = torch.load('logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220729/002444-0.4/best_model_0.40.pt')
# import sys
# sys.path.append('/data/xgf/legodnn_and_domain_adaptation')
xgf_model = torch.load('logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220731/224212-cifar10_svhn_raw/last_model.pt')
# xgf_model = torch.load('/data/xgf/legodnn_and_domain_adaptation/results_scaling_da/image_classification/CIFAR100C_resnet18/onda/offline_l1/s4/20220607/204211/last_model.pt')
# test_dataloader = get_source_dataloader('CIFAR100', 256, 4, 'test', False, False, False)
# test_dataloader = get_target_dataloaders('CIFAR100C', [7], 128, 4, 'test', False, False, False)[0] # snow, xgf 0.3914
# test_dataloaders = get_target_dataloaders('CIFAR100C', list(range(15)), 128, 4, 'test', False, False, False) # defocus_blur, xgf 0.2836
# test_dataloaders = get_target_dataloaders('RotatedCIFAR100', list(range(18)), 128, 4, 'test', False, False, False)
train_dataloaders = [
get_source_dataloader(dataset_name, 128, 4, 'train', True, None, True) for dataset_name in ['SVHN', 'CIFAR10', 'SVHN']
][::-1] * 10
test_dataloaders = [
get_source_dataloader('USPS', 128, 4, 'test', False, False, False),
get_source_dataloader('STL10-wo-monkey', 128, 4, 'test', False, False, False),
get_source_dataloader('MNIST', 128, 4, 'test', False, False, False),
][::-1] * 10
y_offsets = [10, 0, 10][::-1] * 10
domain_names = ['USPS', 'STL10', 'MNIST'][::-1] * 10
# train_dataloader = get_source_dataloader('CIFAR100', 128, 4, 'train', True, None, True)
# acc = ResNetCIFARManager.get_accuracy(model, test_dataloader, 'cuda')
# print(acc)
# baseline_accs = [0.1012, 0.1156, 0.0529, 0.2836, 0.1731, 0.3765, 0.3445, 0.3914, 0.2672, 0.3289, 0.5991, 0.1486, 0.4519, 0.1907, 0.3929]
# accs = []
baseline_before, baseline_after, ours_before, ours_after = [], [], [], []
last_pruned_model, last_pruning_info = None, None
# y_offset = 0
for ti, (test_dataloader, y_offset) in enumerate(zip(test_dataloaders, y_offsets)):
samples, labels = next(iter(test_dataloader))
samples, labels = samples.cuda(), labels.cuda()
labels += y_offset
    def bn_cal(_model: nn.Module):
        # recalibrate BN running statistics on freshly drawn training batches
        for n, m in _model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()
                m.train()  # train mode so running stats get updated
        for _ in range(100):  # ~one epoch
            x, y = next(train_dataloaders[ti])
            x = x.cuda()
            _model(x)
def shot(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
# print([n for n, p in model.named_parameters()])
_model.requires_grad_(True)
_model.linear.requires_grad_(False)
import torch.optim
optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
device = 'cuda'
for _ in range(100 * num_iters_scale):
x = samples
_model.train()
output = ResNetCIFARManager.forward(_model, x)
def Entropy(input_):
entropy = -input_ * torch.log(input_ + 1e-5)
entropy = torch.sum(entropy, dim=1)
return entropy
softmax_out = nn.Softmax(dim=1)(output)
entropy_loss = torch.mean(Entropy(softmax_out))
msoftmax = softmax_out.mean(dim=0)
entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
loss = entropy_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
def shot_w_part_filter(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
# print([n for n, p in model.named_parameters()])
_model.requires_grad_(True)
_model.linear.requires_grad_(False)
import torch.optim
optimizer = SGDF([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
device = 'cuda'
filters_sen_info = tent_as_detector(_model, samples, strategy='ours')
conv_filters_sen_info = {v['conv_name']: v for _, v in filters_sen_info.items()}
for _ in range(100 * num_iters_scale):
x = samples
_model.train()
output = ResNetCIFARManager.forward(_model, x)
def Entropy(input_):
entropy = -input_ * torch.log(input_ + 1e-5)
entropy = torch.sum(entropy, dim=1)
return entropy
softmax_out = nn.Softmax(dim=1)(output)
entropy_loss = torch.mean(Entropy(softmax_out))
msoftmax = softmax_out.mean(dim=0)
entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
loss = entropy_loss
optimizer.zero_grad()
loss.backward()
optimizer.step(_model, conv_filters_sen_info, filters_sen_info)
def tent(_model: nn.Module):
from methods.tent import tent
_model = tent.configure_model(_model)
params, param_names = tent.collect_params(_model)
optimizer = torch.optim.Adam(params, lr=1e-4)
tent_model = tent.Tent(_model, optimizer, ResNetCIFARManager, steps=1)
tent.configure_model(_model)
tent_model(samples)
def tent_configure_bn(_model):
"""Configure model for use with tent."""
# train mode, because tent optimizes the model to minimize entropy
# _model.train()
# # disable grad, to (re-)enable only what tent updates
# _model.requires_grad_(False)
        # configure norm for tent updates: enable grad + force batch statistics
for m in _model.modules():
if isinstance(m, nn.BatchNorm2d):
m.requires_grad_(True)
# force use of batch stats in train and eval modes
m.track_running_stats = False
m.running_mean = None
m.running_var = None
# m.track_running_stats = True
# m.momentum = 1.0
# # FIXME
# from methods.ours_dynamic_filters.extract_submodel import FeatureBoosting
# # if isinstance(m, FeatureBoosting):
# if m.__class__.__name__ == 'FeatureBoosting':
# m.requires_grad_(True)
        return _model
def sl(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
_model.requires_grad_(True)
_model.linear.requires_grad_(False)
import torch.optim
optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
device = 'cuda'
for _ in range(100 * num_iters_scale):
x = samples
_model.train()
loss = ResNetCIFARManager.forward_to_gen_loss(_model, x, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model_extractor = DAFBSSubModelExtractor()
    model1 = model_extractor.extract_submodel_via_a_sample(model, samples[0:1])
pruned_model, pruning_info = model_extractor.extract_submodel_via_samples_and_last_submodel(model, samples, None, None)
# print(pruned_model)
# print(get_model_size(pruned_model, True))
# bn_cal(pruned_model)
acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
print(acc)
ours_before += [acc]
# tent(pruned_model)
# bn_cal(pruned_model)
shot_w_part_filter(pruned_model, 6e-4, 1, 1e-3)
# sl(pruned_model)
acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
print(acc)
ours_after += [acc]
last_pruned_model, last_pruning_info = deepcopy(pruned_model), deepcopy(pruning_info)
model_extractor.absorb_sub_model(model, pruned_model, pruning_info, alpha)
# xgf_model = torch.load('/data/xgf/legodnn_and_domain_adaptation/results_scaling_da/image_classification/CIFAR100C_resnet18/onda/offline_l1/s8/20220607/212448/last_model.pt')
# xgf_model = torch.load('/data/xgf/legodnn_and_domain_adaptation/results_scaling_da/image_classification/CIFAR100C_resnet18/onda/offline_l1/s4/20220607/204211/last_model.pt')
# print(xgf_model)
# acc = ResNetCIFARManager.get_accuracy(xgf_model, test_dataloader, 'cuda', y_offset)
# print(acc)
# baseline_before += [acc]
# # tent(xgf_model)
# shot(xgf_model)
# # sl(xgf_model)
# acc = ResNetCIFARManager.get_accuracy(xgf_model, test_dataloader, 'cuda', y_offset)
# print(acc)
# baseline_after += [acc]
# print()
# diff = acc - baseline_accs[ti]
# print(f'domain {ti}, model size {get_model_size(pruned_model, True):.3f}MB, diff: {diff:.4f}')
# print(accs, sum(accs) / len(accs))
import matplotlib.pyplot as plt
from visualize.util import *
set_figure_settings(3)
def avg(arr):
return sum(arr) / len(arr)
# plt.plot(list(range(len(test_dataloaders))), baseline_before, lw=2, linestyle='--', color=BLUE, label=f'L1 before DA ({avg(baseline_before):.4f})')
# plt.plot(list(range(len(test_dataloaders))), baseline_after, lw=2, linestyle='-', color=BLUE, label=f'L1 after DA ({avg(baseline_after):.4f})')
plt.plot(list(range(len(test_dataloaders))), ours_before, lw=2, linestyle='--', color=RED, label=f'ours before DA ({avg(ours_before):.4f})')
plt.plot(list(range(len(test_dataloaders))), ours_after, lw=2, linestyle='-', color=RED, label=f'ours after DA ({avg(ours_after):.4f})')
plt.xlabel('domains')
plt.ylabel('accuracy')
plt.xticks(list(range(len(domain_names))), domain_names, rotation=90)
plt.legend(loc=2, bbox_to_anchor=(1.05, 1.0), fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(res_save_dir, 'main.png'), dpi=300)
plt.clf()
torch.save((baseline_before, baseline_after, ours_before, ours_after), os.path.join(res_save_dir, 'main.png.data'))
# with open('./tmp.csv', 'a') as f:
# f.write(f'{alpha:.2f},{avg(baseline_after):.4f},{avg(ours_after):.4f}')
# std: logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220730/161404-submodel/main.png
# accs = []
# for i in tqdm.tqdm(range(100)):
# model_extractor.debug_sample_i = i
# pruned_model = model_extractor.extract_submodel_via_samples(model, samples)
# acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda')
# accs += [acc]
# import matplotlib.pyplot as plt
# plt.plot(list(range(100)), accs)
# plt.savefig('./tmp.png', dpi=300)
# plt.clf()
# ------------------------------
# perf test
# sample, _ = next(iter(test_dataloader))
# sample = sample[0: 1].cuda()
# pruned_model = FBSSubModelExtractor().extract_submodel_via_a_sample(model, sample)
# bs = 1
# def perf_test(model, batch_size, device):
# model = model.to(device)
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# # warmup
# for _ in range(100):
# rand_input = torch.rand((batch_size, 3, 32, 32)).to(device)
# o = model(rand_input)
# forward_latency = 0.
# backward_latency = 0.
# for _ in range(100):
# rand_input = torch.rand((batch_size, 3, 32, 32)).to(device)
# optimizer.zero_grad()
# s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# s.record()
# o = model(rand_input)
# e.record()
# torch.cuda.synchronize()
# forward_latency += s.elapsed_time(e) / 1000.
# loss = ((o - 1) ** 2).sum()
# s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# s.record()
# loss.backward()
# optimizer.step()
# e.record()
# torch.cuda.synchronize()
# backward_latency += s.elapsed_time(e) / 1000.
# forward_latency /= 100
# backward_latency /= 100
# print(forward_latency, backward_latency)
# for bs in [1, 128]:
# for device in ['cuda', 'cpu']:
# for m in [model, pruned_model]:
# print(bs, device)
# perf_test(m, bs, device)