from turtle import forward | |
from typing import Optional | |
import torch | |
import copy | |
from torch import nn | |
#from methods.utils.data import get_source_dataloader | |
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module | |
from utils.common.log import logger | |
""" | |
No real speedup. | |
But it's ok because our big model just forward for one time to find the best sub-model. | |
The sub-model doesn't contain filter selection modules. It's just a normal model. | |
""" | |
class KTakesAll(nn.Module):
    """Zero the smallest gate scores of every row.

    ``k`` is the fraction of entries that gets *zeroed* (not kept): for an
    input with ``C`` columns, the ``int(C * k)`` smallest entries of each row
    are set to 0 and the remaining entries pass through unchanged.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, g: torch.Tensor):
        """Mask ``g`` along dim 1 and append two singleton dims.

        Args:
            g: score tensor; the selection runs along dim 1.

        Returns:
            A copy of ``g`` with the lowest-scoring entries of each row zeroed
            and two size-1 dims inserted at positions 2 and 3, so a 2-D input
            of shape (N, C) comes back as (N, C, 1, 1).
        """
        num_zeroed = int(g.size(1) * self.k)
        # topk on the negated scores returns the indices of the smallest ones.
        lowest_idx = torch.neg(g).topk(num_zeroed, 1).indices
        # Out-of-place scatter: the caller's tensor is left untouched.
        masked = g.scatter(dim=1, index=lowest_idx, value=0)
        return masked[:, :, None, None]
# g = g.mean(0).unsqueeze(0) | |
# k = int(g.size(1) * self.k) | |
# i = (-g).topk(k, 1)[1] | |
# t = g.scatter(1, i, 0) | |
# t = t / torch.sum(t, dim=1).unsqueeze(1) * t.size(1) | |
# return t.unsqueeze(2).unsqueeze(3) | |
# class NoiseAdd(nn.Module): | |
# def __init__(self): | |
# super(NoiseAdd, self).__init__() | |
# self.training = True | |
# def forward(self, x): | |
# if self.training: | |
# return x + torch.randn_like(x, device=x.device) | |
# else: | |
# return x | |
class Abs(nn.Module):
    """Module wrapper around the element-wise absolute value, usable inside ``nn.Sequential``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``|x|`` element-wise."""
        return torch.abs(x)
class DomainDynamicConv2d(nn.Module):
    """Conv2d + BatchNorm2d pair whose output channels are gated per input sample.

    A gate branch (abs -> global average pool -> flatten -> linear -> ReLU)
    maps the layer input to one non-negative score per output channel.
    ``KTakesAll`` zeroes part of those scores and the result multiplies
    ``bn(conv(x))`` channel-wise. The convolution is always computed in full;
    channels are masked, not skipped.

    State kept on the instance:
        cached_raw_w: gate scores before ``KTakesAll`` from the last dynamic forward.
        l1_reg_of_raw_w: 1-norm of ``cached_raw_w`` taken over the whole tensor.
        cached_w: gate actually applied in the last dynamic forward.
        static_w: ``None`` selects the dynamic gate; any other value is used
            as a fixed gate instead (presumably a 1-D tensor with one entry
            per output channel -- TODO confirm).
        pruning_ratios: ``None`` disables recording; when it is a list, each
            forward appends ``sum(gate > 0) / gate.numel()``. NOTE(review):
            despite the name this is the fraction of *positive* gate entries.
    """

    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, k: float, bn_after_fc=False):
        """Wrap ``raw_conv2d`` and ``raw_bn`` and attach a gate branch.

        Args:
            raw_conv2d: convolution to wrap (stored as-is, not copied).
            raw_bn: batch norm applied right after the convolution; the caller
                is expected to clear the original BN in the network.
            k: fraction handed to ``KTakesAll``.
            bn_after_fc: must be falsy; the variant is not supported.
        """
        super().__init__()
        assert not bn_after_fc

        # The linear layer must stay at index 3 of the Sequential so that
        # parameter names (and therefore checkpoints) keep their keys.
        gate_fc = nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels)
        self.filter_selection_module = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            gate_fc,
            nn.ReLU(),
        )
        self.k_takes_all = KTakesAll(k)

        self.raw_conv2d = raw_conv2d
        self.bn = raw_bn  # remember to clear the original BNs in the network

        nn.init.constant_(gate_fc.bias, 1.)
        nn.init.kaiming_normal_(gate_fc.weight)

        self.cached_raw_w = None
        self.l1_reg_of_raw_w = None
        self.cached_w = None
        self.static_w = None
        self.pruning_ratios = None

    def forward(self, x):
        """Return ``bn(conv(x))`` multiplied channel-wise by the gate."""
        conv_out = self.bn(self.raw_conv2d(x))

        if self.static_w is not None:
            # Fixed gate: add singleton dims at positions 0, 2 and 3.
            gate = self.static_w.unsqueeze(0).unsqueeze(2).unsqueeze(3)
        else:
            scores = self.filter_selection_module(x)
            self.cached_raw_w = scores
            self.l1_reg_of_raw_w = scores.norm(1)
            gate = self.k_takes_all(scores)
            self.cached_w = gate

        if self.pruning_ratios is not None:
            self.pruning_ratios.append(torch.sum(gate > 0.) / gate.numel())

        return conv_out * gate
# def to_static(self): | |
# global_w = self.cached_raw_w.detach().topk(0.25, 1)[0].mean(0).unsqueeze(0) | |
# global_w = self.k_takes_all(global_w).squeeze(0) | |
# self.static_w = global_w | |
# def to_dynamic(self): | |
# self.static_w = None | |
def boost_raw_model_with_filter_selection(model: nn.Module, init_k: float, bn_after_fc=False,
                                          ignore_layers=None, perf_test=True,
                                          model_input_size: Optional[tuple]=None):
    """Return a deep copy of ``model`` with its Conv2d/BatchNorm2d pairs wrapped in ``DomainDynamicConv2d``.

    Each selected ``nn.Conv2d`` is paired with the last ``nn.BatchNorm2d`` that
    follows it in ``named_modules()`` order (before the next conv). The pair is
    replaced by a ``DomainDynamicConv2d`` at the conv's position and the
    original BN slot is replaced by ``nn.Identity``.
    NOTE(review): this assumes module registration order matches execution
    order -- confirm for the architectures in use.

    Args:
        model: network to convert; it is not modified.
        init_k: initial fraction passed to every ``DomainDynamicConv2d``.
        bn_after_fc: forwarded to ``DomainDynamicConv2d``.
        ignore_layers: names of Conv2d modules to leave untouched. ``None``
            (the default) means no layer is ignored.
        perf_test: when true, model size and latency are measured before and
            after the conversion and logged.
        model_input_size: forwarded to ``get_model_latency``.
            NOTE(review): presumably required when ``perf_test`` is true.

    Returns:
        ``(converted_model, conv_bn_map)`` where ``conv_bn_map`` maps each
        wrapped conv name to the name of the BN that was folded into it.

    Raises:
        KeyError: a selected Conv2d is not followed by a BatchNorm2d.
        AssertionError: the number of wrapped convs differs from the number of
            BatchNorm2d layers that were matched.
    """
    model = copy.deepcopy(model)
    device = get_model_device(model)

    if perf_test:
        before_model_size = get_model_size(model, True)
        before_model_latency = get_model_latency(
            model, model_input_size, 50, device, 50)

    def _is_selected(conv_name):
        # Previously `ignore_layers is not None and ...`, which selected
        # nothing at all for the default ignore_layers=None.
        return ignore_layers is None or conv_name not in ignore_layers

    # Snapshot the module list once so replacing modules below cannot
    # interfere with the traversal.
    named_modules = list(model.named_modules())

    # Pass 1: pair every selected conv with the BN that follows it.
    num_original_bns = 0
    last_conv_name = None
    conv_bn_map = {}
    for name, module in named_modules:
        if isinstance(module, nn.Conv2d):
            last_conv_name = name
        elif isinstance(module, nn.BatchNorm2d):
            # A BN that precedes every conv has nothing to be folded into.
            if last_conv_name is None or not _is_selected(last_conv_name):
                continue
            num_original_bns += 1
            conv_bn_map[last_conv_name] = name

    # Pass 2: swap each selected conv for the gated wrapper.
    num_conv = 0
    for name, module in named_modules:
        if not (isinstance(module, nn.Conv2d) and _is_selected(name)):
            continue
        if name not in conv_bn_map:
            raise KeyError(
                f"Conv2d '{name}' is not followed by a BatchNorm2d; "
                f"add it to ignore_layers")
        raw_bn = get_module(model, conv_bn_map[name])
        set_module(model, name, DomainDynamicConv2d(module, raw_bn, init_k, bn_after_fc))
        num_conv += 1

    assert num_conv == num_original_bns, \
        f'{num_conv} Conv2d wrapped but {num_original_bns} BatchNorm2d matched'

    # The BNs now live inside the wrappers; clear the original slots.
    for bn_layer in conv_bn_map.values():
        set_module(model, bn_layer, nn.Identity())

    if perf_test:
        after_model_size = get_model_size(model, True)
        after_model_latency = get_model_latency(
            model, model_input_size, 50, device, 50)

        logger.info(f'raw model -> raw model w/ filter selection:\n'
                    f'model size: {before_model_size:.3f}MB -> {after_model_size:.3f}MB '
                    f'latency: {before_model_latency:.6f}s -> {after_model_latency:.6f}s')

    return model, conv_bn_map
def get_l1_reg_in_model(boosted_model):
    """Sum ``l1_reg_of_raw_w`` over every ``DomainDynamicConv2d`` in ``boosted_model``.

    Returns ``0.`` when the model contains no such layer. The per-layer values
    are whatever the layers cached during their last forward pass.
    """
    gated_layers = (
        layer for layer in boosted_model.modules()
        if isinstance(layer, DomainDynamicConv2d)
    )
    return sum((layer.l1_reg_of_raw_w for layer in gated_layers), 0.)
def get_cached_w(model):
    """Concatenate the ``cached_w`` of all ``DomainDynamicConv2d`` layers along dim 1.

    Layers are visited in ``model.modules()`` order.
    """
    applied_gates = [
        layer.cached_w
        for layer in model.modules()
        if isinstance(layer, DomainDynamicConv2d)
    ]
    return torch.cat(applied_gates, dim=1)
def set_pruning_rate(model, k):
    """Assign ``k`` to the ``k`` attribute of every ``KTakesAll`` module in ``model``."""
    selectors = [layer for layer in model.modules() if isinstance(layer, KTakesAll)]
    for selector in selectors:
        selector.k = k
def get_cached_raw_w(model):
    """Concatenate the ``cached_raw_w`` of all ``DomainDynamicConv2d`` layers along dim 1.

    Layers are visited in ``model.modules()`` order.
    """
    raw_scores = [
        layer.cached_raw_w
        for layer in model.modules()
        if isinstance(layer, DomainDynamicConv2d)
    ]
    return torch.cat(raw_scores, dim=1)
def start_accmu_flops(model):
    """Start recording: give every ``DomainDynamicConv2d`` a fresh, empty ``pruning_ratios`` list."""
    for layer in model.modules():
        if not isinstance(layer, DomainDynamicConv2d):
            continue
        # One new list per layer; they must not be shared.
        layer.pruning_ratios = []
def get_accmu_flops(model):
    """Collect the recorded ``pruning_ratios`` and stop recording.

    Every ``DomainDynamicConv2d`` has its ``pruning_ratios`` list read and
    then reset to ``None``.

    Returns:
        ``(layer_res, total_res, avg_pruning_ratio)``: the per-layer lists
        keyed by module name, all entries in one flat list, and the mean of
        that flat list.

    Raises:
        ZeroDivisionError: nothing was recorded.
        TypeError: a layer's ``pruning_ratios`` is still ``None``.
    """
    layer_res, total_res = {}, []
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, DomainDynamicConv2d):
            continue
        recorded = layer.pruning_ratios
        layer_res[layer_name] = recorded
        total_res.extend(recorded)
        layer.pruning_ratios = None

    avg_pruning_ratio = sum(total_res) / len(total_res)
    return layer_res, total_res, avg_pruning_ratio
def convert_boosted_model_to_static(boosted_model, a_few_data):
    """Freeze the gate of every ``DomainDynamicConv2d`` from one forward pass over ``a_few_data``.

    The model is run once so each gated layer caches its raw gate scores, then
    every layer is switched to a fixed gate (``static_w``).

    A layer that provides its own ``to_static()`` is converted through it.
    Otherwise the gate is derived here, because ``DomainDynamicConv2d`` does
    not define that method and the previous unconditional call raised
    ``AttributeError``: the cached raw scores are averaged over the batch,
    passed through the layer's ``k_takes_all`` and flattened to one value per
    channel, which is the layout the static branch of ``forward`` unsqueezes.
    NOTE(review): the batch-mean aggregation is reconstructed from the
    commented-out ``to_static`` reference -- confirm it is the intended rule.

    Layers that are already static do not refresh their cache during the
    forward pass; reset their ``static_w`` to ``None`` beforehand if a fresh
    estimate is wanted.

    Args:
        boosted_model: model containing ``DomainDynamicConv2d`` layers.
        a_few_data: input batch fed to ``boosted_model``.

    Raises:
        RuntimeError: a gated layer has no cached scores after the forward pass.
    """
    boosted_model(a_few_data)

    for name, module in boosted_model.named_modules():
        if not isinstance(module, DomainDynamicConv2d):
            continue

        to_static = getattr(module, 'to_static', None)
        if callable(to_static):
            to_static()
            continue

        if module.cached_raw_w is None:
            raise RuntimeError(
                f"layer '{name}' cached no gate scores during the forward pass")

        batch_mean = module.cached_raw_w.detach().mean(0, keepdim=True)
        # k_takes_all returns (1, C, 1, 1); forward expects a flat (C,) gate.
        module.static_w = module.k_takes_all(batch_mean).flatten()

    # TODO: use fn3 techniques
def ensure_boosted_model_to_dynamic(boosted_model):
    """Switch every ``DomainDynamicConv2d`` in ``boosted_model`` back to its input-dependent gate.

    A layer that provides its own ``to_dynamic()`` is reset through it.
    Otherwise ``static_w`` is set to ``None`` directly, which is the switch
    ``DomainDynamicConv2d.forward`` tests. The class does not define
    ``to_dynamic``, so the previous unconditional call raised
    ``AttributeError``.
    """
    for module in boosted_model.modules():
        if not isinstance(module, DomainDynamicConv2d):
            continue

        to_dynamic = getattr(module, 'to_dynamic', None)
        if callable(to_dynamic):
            to_dynamic()
        else:
            module.static_w = None
def train_only_gate(model):
    """Freeze everything except the gate branches and return the gate parameters.

    A parameter counts as a gate parameter when its qualified name contains
    ``'filter_selection_module'``. All other parameters get
    ``requires_grad = False``; the gate parameters' own ``requires_grad`` flag
    is left as it is.

    Returns:
        List of gate parameters in ``named_parameters()`` order.
    """
    trainable = []
    for param_name, param in model.named_parameters():
        if 'filter_selection_module' not in param_name:
            param.requires_grad = False
            continue
        trainable.append(param)
    return trainable
if __name__ == '__main__':
    # Smoke test: apply KTakesAll with k=0.6 to a random (1, 5) score row and
    # print the masked result.
    selector = KTakesAll(0.6)
    scores = torch.rand((1, 5))
    print(selector(scores))