from abc import abstractmethod
from copy import deepcopy
import enum
import torch
from torch import nn
import os

from .model_fbs import DomainDynamicConv2d

from utils.common.others import get_cur_time_str
from utils.dl.common.env import set_random_seed
from utils.dl.common.model import get_model_latency, get_model_size, get_module, set_module
from utils.common.log import logger
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
from utils.third_party.nni_new.compression.pytorch.utils.mask_conflict import GroupMaskConflict, ChannelMaskConflict, CatMaskPadding


def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None, fix_group=False, fix_channel=True, fix_padding=False):
    """Resolve conflicts in pruning masks (group / channel / cat-padding) before model speedup.

    `masks` is either an in-memory mask dict or the path of a mask file saved with `torch.save`.
    If no traced graph is provided, the model is traced from `dummy_input` in eval mode.
    """
    if isinstance(masks, str):
        # the input is the path of a mask file
        assert os.path.exists(masks)
        masks = torch.load(masks)
    assert len(masks) > 0, 'Mask tensor cannot be empty'

    if traced is None:
        assert model is not None and dummy_input is not None
        training = model.training
        model.eval()
        traced = torch.jit.trace(model, dummy_input)
        model.train(training)

    if fix_group:
        fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
        masks = fix_group_mask.fix_mask()
    if fix_channel:
        fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
        masks = fix_channel_mask.fix_mask()
    if fix_padding:
        padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
        masks = padding_cat_mask.fix_mask()
    return masks
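
# Hedged usage sketch (the model variable and input shape below are illustrative, not taken
# from this repository): a typical call before running ModelSpeedup.
#
#   fixed_masks = fix_mask_conflict('tmp_mask.pth', model=some_cnn,
#                                   dummy_input=torch.rand(1, 3, 32, 32))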


class FeatureBoosting(nn.Module):
    """Rescale each output channel by a fixed, non-trainable boosting factor."""

    def __init__(self, w: torch.Tensor):
        super(FeatureBoosting, self).__init__()
        assert w.dim() == 1
        # store w as (1, C, 1, 1) so it broadcasts over (N, C, H, W) feature maps
        self.w = nn.Parameter(w.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)

    def forward(self, x):
        return x * self.w
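
# Illustrative sketch of the broadcasting semantics (values are made up):
#
#   fb = FeatureBoosting(torch.tensor([0.5, 2.0]))
#   y = fb(torch.randn(1, 2, 4, 4))   # channel 0 is halved, channel 1 is doubled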


class FBSSubModelExtractor:
    """Extract static sub-models from a model whose convolutions are wrapped in DomainDynamicConv2d."""

    def extract_submodel_via_a_sample(self, fbs_model: nn.Module, sample: torch.Tensor):
        assert sample.dim() == 4 and sample.size(0) == 1

        # run the gated model once so each DomainDynamicConv2d caches its channel saliency w
        fbs_model.eval()
        o1 = fbs_model(sample)

        pruning_info = {}
        pruning_masks = {}

        for layer_name, layer in fbs_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue

            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)

            w = get_module(fbs_model, layer_name).cached_w.squeeze()
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w

            # keep only the filters with non-zero saliency
            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask

        # replace each gated conv with (raw conv, BN, placeholder) so the masks can be applied
        no_gate_model = deepcopy(fbs_model)
        for name, layer in no_gate_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))

        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        pruned_model = no_gate_model
        pruned_model.eval()
        model_speedup = ModelSpeedup(pruned_model, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        # re-attach the cached saliencies as fixed per-channel scaling layers
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))

        pruned_model_size = get_model_size(pruned_model, True)
        pruned_model.eval()
        o2 = pruned_model(sample)
        diff = ((o1 - o2) ** 2).sum()
        logger.info(f'pruned model size: {pruned_model_size:.3f}MB, diff: {diff}')

        return pruned_model
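
    # After extraction, every original DomainDynamicConv2d layer is reachable in the sub-model as
    # an nn.Sequential: '<layer_name>.0' (conv), '.1' (BN), '.2' (FeatureBoosting).
    # Hedged usage sketch (variable names are illustrative):
    #
    #   extractor = DAFBSSubModelExtractor()
    #   sub_model = extractor.extract_submodel_via_a_sample(fbs_model, x[0:1])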

    @abstractmethod
    def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
        pass

    @abstractmethod
    def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
        pass
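
    # Template-method pattern: the two hooks above are implemented by subclasses
    # (see DAFBSSubModelExtractor below) and are called by extract_submodel_via_samples and
    # extract_submodel_via_samples_and_last_submodel to turn the cached per-sample
    # channel saliencies into a single pruning decision.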

    def extract_submodel_via_samples(self, fbs_model: nn.Module, samples: torch.Tensor):
        """Like extract_submodel_via_a_sample, but the pruning decision is derived from a batch of samples."""
        assert samples.dim() == 4
        fbs_model = deepcopy(fbs_model)

        # let the subclass cache whatever it needs (e.g. the model output on `samples`)
        self.generate_pruning_strategy(fbs_model, samples)

        pruning_info = {}
        pruning_masks = {}

        for layer_name, layer in fbs_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue

            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)

            w = get_module(fbs_model, layer_name).cached_w.squeeze()
            w = self.get_final_w(fbs_model, samples, layer_name, w)

            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w

            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask

        no_gate_model = deepcopy(fbs_model)
        for name, layer in no_gate_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))

        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        pruned_model = no_gate_model
        pruned_model.eval()
        model_speedup = ModelSpeedup(pruned_model, samples[0:1], tmp_mask_path, samples.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))

        return pruned_model, pruning_info

    def extract_submodel_via_samples_and_last_submodel(self, fbs_model: nn.Module, samples: torch.Tensor,
                                                       last_submodel: nn.Module, last_pruning_info: dict):
        """Extract a sub-model from `samples` and, where its kept filters overlap with the previous
        sub-model's, copy the previously adapted conv/BN/FeatureBoosting parameters into it."""
        assert samples.dim() == 4
        fbs_model = deepcopy(fbs_model)

        self.generate_pruning_strategy(fbs_model, samples)

        pruning_info = {}
        pruning_masks = {}

        incrementally_updated_layers = []

        for layer_name, layer in fbs_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue

            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)

            w = get_module(fbs_model, layer_name).cached_w.squeeze()
            w = self.get_final_w(fbs_model, samples, layer_name, w)

            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w

            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask

            # mark layers whose kept filters overlap with the previous sub-model's
            if last_pruning_info is not None:
                last_w = last_pruning_info[layer_name]
                intersection_ratio = ((w > 0) * (last_w > 0)).sum() / (last_w > 0).sum()
                if intersection_ratio > 0.:
                    incrementally_updated_layers += [layer_name]

        no_gate_model = deepcopy(fbs_model)
        for name, layer in no_gate_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))

        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        pruned_model = no_gate_model
        pruned_model.eval()
        model_speedup = ModelSpeedup(pruned_model, samples[0:1], tmp_mask_path, samples.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))

        # copy the parameters of filters shared with the last sub-model; cur_filter_i / last_filter_i
        # walk the compact (pruned) filter indices of the current and previous sub-models in lockstep
        for layer_name in incrementally_updated_layers:
            cur_filter_i, last_filter_i = 0, 0
            for i, (w_factor, last_w_factor) in enumerate(zip(pruning_info[layer_name], last_pruning_info[layer_name])):
                if w_factor > 0 and last_w_factor > 0:
                    cur_conv2d, last_conv2d = get_module(pruned_model, layer_name + '.0'), get_module(last_submodel, layer_name + '.0')
                    cur_conv2d.weight.data[cur_filter_i] = last_conv2d.weight.data[last_filter_i]

                    cur_bn, last_bn = get_module(pruned_model, layer_name + '.1'), get_module(last_submodel, layer_name + '.1')
                    cur_bn.weight.data[cur_filter_i] = last_bn.weight.data[last_filter_i]
                    cur_bn.bias.data[cur_filter_i] = last_bn.bias.data[last_filter_i]
                    cur_bn.running_mean.data[cur_filter_i] = last_bn.running_mean.data[last_filter_i]
                    cur_bn.running_var.data[cur_filter_i] = last_bn.running_var.data[last_filter_i]

                    cur_fw, last_fw = get_module(pruned_model, layer_name + '.2'), get_module(last_submodel, layer_name + '.2')
                    cur_fw.w.data[0, cur_filter_i] = last_fw.w.data[0, last_filter_i]

                if w_factor > 0:
                    cur_filter_i += 1
                if last_w_factor > 0:
                    last_filter_i += 1

        return pruned_model, pruning_info
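
    # Worked example of the filter-index walk above (illustrative values): if the current mask
    # keeps filters {0, 2, 3} and the last mask kept {0, 1, 3}, the shared filters are {0, 3};
    # filter 0 sits at compact position 0 in both sub-models, while filter 3 sits at compact
    # position 2 in both, so only those positions are copied from the last sub-model.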

    def absorb_sub_model(self, fbs_model: nn.Module, sub_model: nn.Module, pruning_info: dict, alpha=1.):
        """Write the (adapted) sub-model's conv/BN parameters back into the matching filters of the
        full FBS model, blending with factor `alpha` (alpha=1. fully overwrites them)."""
        if alpha == 0.:
            return
        for layer_name, feature_boosting_w in pruning_info.items():
            unpruned_filters_index = feature_boosting_w.nonzero(as_tuple=True)[0]

            fbs_layer = get_module(fbs_model, layer_name)
            sub_model_layer = get_module(sub_model, layer_name)

            for fi_in_sub_layer, fi_in_fbs_layer in enumerate(unpruned_filters_index):
                fbs_layer.raw_conv2d.weight.data[fi_in_fbs_layer] = (1. - alpha) * fbs_layer.raw_conv2d.weight.data[fi_in_fbs_layer] + \
                    alpha * sub_model_layer[0].weight.data[fi_in_sub_layer]
                for k in ['weight', 'bias', 'running_mean', 'running_var']:
                    getattr(fbs_layer.bn, k).data[fi_in_fbs_layer] = (1. - alpha) * getattr(fbs_layer.bn, k).data[fi_in_fbs_layer] + \
                        alpha * getattr(sub_model_layer[1], k).data[fi_in_sub_layer]
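
    # The write-back above is an exponential-moving-average-style update per kept filter:
    #   fbs_param[filter] = (1 - alpha) * fbs_param[filter] + alpha * sub_param[filter]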


class DAFBSSubModelExtractor(FBSSubModelExtractor):
    """Sub-model extractor that derives the pruning decision from the 'hardest'
    (highest prediction entropy) sample in the batch."""

    def __init__(self) -> None:
        super().__init__()

    def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
        # cache the batch output; get_final_w uses it to pick the hardest sample
        with torch.no_grad():
            fbs_model.eval()
            self.cur_output = fbs_model(samples)

    def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
        # pick the sample with the highest prediction entropy and use its channel saliency
        x = self.cur_output
        each_sample_entropy = -(x.logits.softmax(1) * x.logits.log_softmax(1)).sum(1)
        hardest_sample_index = each_sample_entropy.argmax()
        res = w[hardest_sample_index]
        return res
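
# Per-sample prediction entropy used above: H(p) = -sum_c p_c * log(p_c), computed as
# -(softmax * log_softmax).sum(1); the sample with the largest entropy ("hardest") decides
# which row of the cached channel-saliency matrix w is used for pruning.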


def tent_as_detector(model, x, num_iters=1, lr=1e-4, l1_wd=0., strategy='ours'):
    """Run a few TENT adaptation steps on a copy of `model` and use the resulting BN parameter
    drift to decide, per BN layer, which filters should stay frozen during adaptation."""
    model = deepcopy(model)
    before_model = deepcopy(model)

    from methods.tent import tent

    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, weight_decay=l1_wd)
    from models.resnet_cifar.model_manager import ResNetCIFARManager
    tented_model = tent.Tent(model, optimizer, ResNetCIFARManager, steps=num_iters)

    tent.configure_model(model)
    tented_model(x)

    filters_sen_info = {}

    last_conv_name = None
    for (name, m1), m2 in zip(model.named_modules(), before_model.modules()):
        if isinstance(m1, nn.Conv2d):
            last_conv_name = name

        if not isinstance(m1, nn.BatchNorm2d):
            continue

        with torch.no_grad():
            # how much TENT moved this BN layer's affine parameters
            features_weight_diff = ((m1.weight.data - m2.weight.data).abs())
            features_bias_diff = ((m1.bias.data - m2.bias.data).abs())

            features_diff = features_weight_diff + features_bias_diff

            features_diff_order = features_diff.argsort(descending=False)

            # freeze the 80% of filters that moved the least (or a random / inverted / empty selection)
            if strategy == 'ours':
                untrained_filters_index = features_diff_order[: int(len(features_diff) * 0.8)]
            elif strategy == 'random':
                untrained_filters_index = torch.randperm(len(features_diff))[: int(len(features_diff) * 0.8)]
            elif strategy == 'inversed_ours':
                untrained_filters_index = features_diff_order.flip(0)[: int(len(features_diff) * 0.8)]
            elif strategy == 'none':
                untrained_filters_index = None

        filters_sen_info[name] = dict(untrained_filters_index=untrained_filters_index, conv_name=last_conv_name)

    return filters_sen_info
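
# filters_sen_info maps each BatchNorm2d layer name to
#   {'untrained_filters_index': indices of filters to freeze, 'conv_name': the preceding Conv2d's name}.
# Hedged sketch of how it is consumed together with SGDF below:
#
#   filters_sen_info = tent_as_detector(model, x, strategy='ours')
#   conv_filters_sen_info = {v['conv_name']: v for v in filters_sen_info.values()}
#   optimizer.step(model, conv_filters_sen_info, filters_sen_info)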


class SGDF(torch.optim.SGD):
    """SGD variant whose step() zeroes the update of selected ("untrained") filters."""

    @torch.no_grad()
    def step(self, model, conv_filters_sen_info, filters_sen_info, closure=None):
        """Performs a single optimization step.

        Arguments:
            model: the model being optimized; its named_parameters() order is assumed to
                match the order of the parameters in each param group.
            conv_filters_sen_info / filters_sen_info: per-layer dicts holding the filter
                indices whose updates should be suppressed (from tent_as_detector).
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for (name, _), p in zip(model.named_parameters(), group['params']):
                if p.grad is None:
                    continue

                layer_name = '.'.join(name.split('.')[0:-1])
                if layer_name in filters_sen_info.keys():
                    untrained_filters_index = filters_sen_info[layer_name]['untrained_filters_index']
                elif layer_name in conv_filters_sen_info.keys():
                    untrained_filters_index = conv_filters_sen_info[layer_name]['untrained_filters_index']
                else:
                    untrained_filters_index = []

                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # suppress the update of frozen filters, then apply the SGD step
                d_p[untrained_filters_index] = 0.
                p.add_(d_p, alpha=-group['lr'])

        return loss
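
# The script below runs the domain-adaptation experiment: for each target domain it extracts a
# sub-model from a batch of test samples, evaluates it, adapts it with the SHOT-style objective
# (updating only the filters selected by tent_as_detector), evaluates it again, and absorbs the
# adapted filters back into the full FBS model before moving to the next domain.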


if __name__ == '__main__':
    set_random_seed(0)

    import sys
    tag = sys.argv[1]

    alpha = 0.2

    fbs_model_path = sys.argv[1]

    cur_time_str = get_cur_time_str()
    res_save_dir = f'logs/experiments_trial/CIFAR100C/ours_fbs_more_challenging/{cur_time_str[0:8]}/{cur_time_str[8:]}-{tag}'
    os.makedirs(res_save_dir)

    import shutil
    shutil.copytree(os.path.dirname(__file__),
                    os.path.join(res_save_dir, 'method'), ignore=shutil.ignore_patterns('*.pt', '*.pth', 'log', '__pycache__'))
    logger.info(f'res save dir: {res_save_dir}')

    # forward / accuracy helpers used throughout the experiment below
    from models.resnet_cifar.model_manager import ResNetCIFARManager

    model = torch.load(fbs_model_path)

    xgf_model = torch.load('logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220731/224212-cifar10_svhn_raw/last_model.pt')

    train_dataloaders = [
        get_source_dataloader(dataset_name, 128, 4, 'train', True, None, True) for dataset_name in ['SVHN', 'CIFAR10', 'SVHN']
    ][::-1] * 10
    test_dataloaders = [
        get_source_dataloader('USPS', 128, 4, 'test', False, False, False),
        get_source_dataloader('STL10-wo-monkey', 128, 4, 'test', False, False, False),
        get_source_dataloader('MNIST', 128, 4, 'test', False, False, False),
    ][::-1] * 10
    y_offsets = [10, 0, 10][::-1] * 10
    domain_names = ['USPS', 'STL10', 'MNIST'][::-1] * 10

    baseline_before, baseline_after, ours_before, ours_after = [], [], [], []
    last_pruned_model, last_pruning_info = None, None

    for ti, (test_dataloader, y_offset) in enumerate(zip(test_dataloaders, y_offsets)):
        samples, labels = next(iter(test_dataloader))
        samples, labels = samples.cuda(), labels.cuda()
        labels += y_offset

        def bn_cal(_model: nn.Module):
            # re-estimate BN running statistics from a few training batches
            for n, m in _model.named_modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.reset_running_stats()
                    m.training = True
                    m.train()
            for _ in range(100):
                x, y = next(iter(train_dataloaders[ti]))
                x = x.cuda()
                _model(x)

        def shot(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
            _model.requires_grad_(True)
            _model.linear.requires_grad_(False)
            import torch.optim
            optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
            device = 'cuda'

            for _ in range(100 * num_iters_scale):
                x = samples
                _model.train()
                output = ResNetCIFARManager.forward(_model, x)

                def Entropy(input_):
                    entropy = -input_ * torch.log(input_ + 1e-5)
                    entropy = torch.sum(entropy, dim=1)
                    return entropy

                softmax_out = nn.Softmax(dim=1)(output)
                entropy_loss = torch.mean(Entropy(softmax_out))
                msoftmax = softmax_out.mean(dim=0)
                entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
                loss = entropy_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
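
        # The loss above is the SHOT-style information-maximization objective
        #   L = mean_i H(p_i) - H(mean_i p_i),
        # i.e. minimize per-sample prediction entropy while keeping the batch-averaged prediction
        # distribution diverse. shot_w_part_filter below optimizes the same objective but routes
        # updates through SGDF so only the filters selected by tent_as_detector are changed.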

        def shot_w_part_filter(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
            _model.requires_grad_(True)
            _model.linear.requires_grad_(False)
            import torch.optim
            optimizer = SGDF([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
            device = 'cuda'

            filters_sen_info = tent_as_detector(_model, samples, strategy='ours')
            conv_filters_sen_info = {v['conv_name']: v for _, v in filters_sen_info.items()}

            for _ in range(100 * num_iters_scale):
                x = samples
                _model.train()
                output = ResNetCIFARManager.forward(_model, x)

                def Entropy(input_):
                    entropy = -input_ * torch.log(input_ + 1e-5)
                    entropy = torch.sum(entropy, dim=1)
                    return entropy

                softmax_out = nn.Softmax(dim=1)(output)
                entropy_loss = torch.mean(Entropy(softmax_out))
                msoftmax = softmax_out.mean(dim=0)
                entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
                loss = entropy_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step(_model, conv_filters_sen_info, filters_sen_info)

        def tent(_model: nn.Module):
            from methods.tent import tent
            _model = tent.configure_model(_model)
            params, param_names = tent.collect_params(_model)
            optimizer = torch.optim.Adam(params, lr=1e-4)
            tent_model = tent.Tent(_model, optimizer, ResNetCIFARManager, steps=1)

            tent.configure_model(_model)
            tent_model(samples)

        def tent_configure_bn(_model):
            """Configure the model's BN layers for use with tent: make them trainable and
            force them to use batch statistics instead of running statistics."""
            for m in _model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.requires_grad_(True)
                    # force the use of batch statistics in train and eval mode
                    m.track_running_stats = False
                    m.running_mean = None
                    m.running_var = None
            return _model

        def sl(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
            _model.requires_grad_(True)
            _model.linear.requires_grad_(False)
            import torch.optim
            optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
            device = 'cuda'

            for _ in range(100 * num_iters_scale):
                x = samples
                _model.train()
                loss = ResNetCIFARManager.forward_to_gen_loss(_model, x, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        model_extractor = DAFBSSubModelExtractor()
        model1 = model_extractor.extract_submodel_via_a_sample(model, samples[0:1])
        pruned_model, pruning_info = model_extractor.extract_submodel_via_samples_and_last_submodel(model, samples, None, None)

        acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
        print(acc)
        ours_before += [acc]

        shot_w_part_filter(pruned_model, 6e-4, 1, 1e-3)

        acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
        print(acc)
        ours_after += [acc]

        last_pruned_model, last_pruning_info = deepcopy(pruned_model), deepcopy(pruning_info)
        model_extractor.absorb_sub_model(model, pruned_model, pruning_info, alpha)

    # after all domains: plot per-domain accuracy before/after adaptation
    import matplotlib.pyplot as plt
    from visualize.util import *
    set_figure_settings(3)

    def avg(arr):
        return sum(arr) / len(arr)

    plt.plot(list(range(len(test_dataloaders))), ours_before, lw=2, linestyle='--', color=RED, label=f'ours before DA ({avg(ours_before):.4f})')
    plt.plot(list(range(len(test_dataloaders))), ours_after, lw=2, linestyle='-', color=RED, label=f'ours after DA ({avg(ours_after):.4f})')
    plt.xlabel('domains')
    plt.ylabel('accuracy')
    plt.xticks(list(range(len(domain_names))), domain_names, rotation=90)
    plt.legend(loc=2, bbox_to_anchor=(1.05, 1.0), fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(res_save_dir, 'main.png'), dpi=300)
    plt.clf()

    torch.save((baseline_before, baseline_after, ours_before, ours_after), os.path.join(res_save_dir, 'main.png.data'))