# EdgeTA / new_impl/cv/resnet/model_fbs.py
# (Hugging Face upload header: LINC-BIT, "Upload 1912 files", commit b84549f verified)
from turtle import forward
from typing import Optional
import torch
import copy
from torch import nn
#from methods.utils.data import get_source_dataloader
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module
from utils.common.log import logger
"""
No real speedup.
But it's ok because our big model just forward for one time to find the best sub-model.
The sub-model doesn't contain filter selection modules. It's just a normal model.
"""
class KTakesAll(nn.Module):
    """Zero out the fraction-k smallest entries of each per-sample saliency row.

    Given per-channel saliency scores, the ``int(C * k)`` channels with the
    smallest scores are set to 0; the remaining scores pass through unchanged
    (no renormalization is applied).
    """

    def __init__(self, k):
        """k: fraction of channels to prune, in [0, 1]; 0. keeps everything."""
        super(KTakesAll, self).__init__()
        self.k = k

    def forward(self, g: torch.Tensor):
        """Map (batch, C) saliencies to a (batch, C, 1, 1) gating mask.

        The trailing singleton dims make the mask broadcastable over a
        (batch, C, H, W) feature map.
        """
        num_pruned = int(g.size(1) * self.k)
        # Indices of the num_pruned smallest saliencies per sample
        # (topk of -g picks the largest negated values, i.e. smallest g).
        pruned_idx = (-g).topk(num_pruned, 1)[1]
        # Scatter zeros into those positions; all other entries keep their value.
        gated = g.scatter(1, pruned_idx, 0)
        return gated.unsqueeze(2).unsqueeze(3)
# class NoiseAdd(nn.Module):
# def __init__(self):
# super(NoiseAdd, self).__init__()
# self.training = True
# def forward(self, x):
# if self.training:
# return x + torch.randn_like(x, device=x.device)
# else:
# return x
class Abs(nn.Module):
    """Element-wise absolute value, wrapped as a module so it can sit in an nn.Sequential."""

    def __init__(self):
        super(Abs, self).__init__()

    def forward(self, x):
        """Return |x| element-wise."""
        return torch.abs(x)
class DomainDynamicConv2d(nn.Module):
    """Conv2d + BatchNorm2d augmented with a per-sample filter-selection gate.

    A small gate network (|x| -> global avg pool -> FC -> ReLU) produces a
    non-negative saliency per output filter; KTakesAll zeroes the fraction-k
    smallest so the gated output emulates filter pruning. The wrapped BN is
    applied here, so the original BN in the network must be cleared
    (see boost_raw_model_with_filter_selection).

    Side effects of forward(): caches the raw gate output (cached_raw_w), its
    L1 norm (l1_reg_of_raw_w) and the binarized-ish mask (cached_w); when
    pruning_ratios is a list, appends the kept-filter fraction of the batch.
    """

    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, k: float, bn_after_fc=False):
        super(DomainDynamicConv2d, self).__init__()
        assert not bn_after_fc  # BN after the gate FC is not supported

        # Gate: per-channel mean |activation| -> linear map to per-filter saliency.
        self.filter_selection_module = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels),
            nn.ReLU(),
        )
        self.k_takes_all = KTakesAll(k)

        self.raw_conv2d = raw_conv2d
        self.bn = raw_bn  # remember clear the original BNs in the network

        # Bias init at 1 so every filter starts "on"; kaiming for the weights.
        nn.init.constant_(self.filter_selection_module[3].bias, 1.)
        nn.init.kaiming_normal_(self.filter_selection_module[3].weight)

        self.cached_raw_w = None       # last raw gate output, (batch, out_channels)
        self.l1_reg_of_raw_w = None    # L1 norm of the raw gate output (sparsity loss term)
        self.cached_w = None           # last applied mask, (batch, out_channels, 1, 1)
        self.static_w = None           # frozen per-filter mask (set by to_static)
        self.pruning_ratios = None     # list accumulator, enabled by start_accmu_flops
        self.bn_after_fc = bn_after_fc

    def forward(self, x):
        raw_x = self.bn(self.raw_conv2d(x))

        if self.static_w is None:
            # Dynamic mode: compute a fresh per-sample mask.
            raw_w = self.filter_selection_module(x)
            self.cached_raw_w = raw_w
            self.l1_reg_of_raw_w = raw_w.norm(1)
            w = self.k_takes_all(raw_w)
            self.cached_w = w
        else:
            # Static mode: reuse the frozen (out_channels,) mask for every sample.
            w = self.static_w.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        if self.pruning_ratios is not None:
            # NOTE(review): despite the name, this records the fraction of KEPT
            # (non-zero) filters, matching the original line 123 behavior.
            self.pruning_ratios += [torch.sum(w > 0.) / w.numel()]

        return raw_x * w

    def to_static(self):
        """Freeze the gate into a single per-filter mask averaged over the last batch.

        NOTE(review): reconstructed from the commented-out sketch (which used an
        invalid float topk); here we average the cached raw saliencies over the
        batch and re-apply KTakesAll — confirm against training code.
        """
        assert self.cached_raw_w is not None, 'run a forward pass before to_static()'
        global_w = self.cached_raw_w.detach().mean(0, keepdim=True)
        self.static_w = self.k_takes_all(global_w).flatten()

    def to_dynamic(self):
        """Return to per-sample dynamic gating."""
        self.static_w = None
def boost_raw_model_with_filter_selection(model: nn.Module, init_k: float, bn_after_fc=False, ignore_layers=None, perf_test=True, model_input_size: Optional[tuple]=None):
    """Wrap each (Conv2d, following BatchNorm2d) pair in a DomainDynamicConv2d.

    The original BN modules are replaced with nn.Identity since the wrapper
    applies BN itself. Works on a deep copy; the input model is untouched.

    Args:
        model: network to boost.
        init_k: initial pruning fraction handed to KTakesAll.
        bn_after_fc: forwarded to DomainDynamicConv2d (must be False).
        ignore_layers: iterable of conv module names to leave untouched;
            None means "ignore nothing". (BUGFIX: the original condition
            `ignore_layers is not None and name not in ignore_layers` made the
            whole function a no-op when ignore_layers was None, its default.)
        perf_test: measure size/latency before and after; requires model_input_size.
        model_input_size: input shape tuple for the latency test.

    Returns:
        (boosted model copy, dict mapping conv name -> its BN name).
    """
    model = copy.deepcopy(model)
    device = get_model_device(model)
    # Normalize: None means no layer is ignored.
    ignore_layers = [] if ignore_layers is None else ignore_layers

    if perf_test:
        assert model_input_size is not None, 'perf_test requires model_input_size'
        before_model_size = get_model_size(model, True)
        before_model_latency = get_model_latency(
            model, model_input_size, 50, device, 50)

    # Pair each conv with the BN that follows it (named_modules order).
    # NOTE(review): assumes every non-ignored conv is immediately followed by
    # its own BatchNorm2d — the num_conv == num_original_bns assert checks this.
    num_original_bns = 0
    last_conv_name = None
    conv_bn_map = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            last_conv_name = name
        if isinstance(module, nn.BatchNorm2d) and last_conv_name not in ignore_layers:
            num_original_bns += 1
            conv_bn_map[last_conv_name] = name

    num_conv = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) and name not in ignore_layers:
            set_module(model, name, DomainDynamicConv2d(module, get_module(model, conv_bn_map[name]), init_k, bn_after_fc))
            num_conv += 1
    assert num_conv == num_original_bns

    # BN now lives inside the wrappers; neutralize the originals.
    for bn_layer in conv_bn_map.values():
        set_module(model, bn_layer, nn.Identity())

    if perf_test:
        after_model_size = get_model_size(model, True)
        after_model_latency = get_model_latency(
            model, model_input_size, 50, device, 50)
        logger.info(f'raw model -> raw model w/ filter selection:\n'
                    f'model size: {before_model_size:.3f}MB -> {after_model_size:.3f}MB '
                    f'latency: {before_model_latency:.6f}s -> {after_model_latency:.6f}s')

    return model, conv_bn_map
def get_l1_reg_in_model(boosted_model):
    """Sum the cached L1 sparsity terms of every DomainDynamicConv2d gate."""
    return sum(
        (m.l1_reg_of_raw_w for _, m in boosted_model.named_modules()
         if isinstance(m, DomainDynamicConv2d)),
        0.,
    )
def get_cached_w(model):
    """Concatenate the cached gate masks of all DomainDynamicConv2d layers along dim 1."""
    masks = [m.cached_w for _, m in model.named_modules()
             if isinstance(m, DomainDynamicConv2d)]
    return torch.cat(masks, dim=1)
def set_pruning_rate(model, k):
    """Set the pruning fraction k on every KTakesAll module in the model (in place)."""
    for _, module in model.named_modules():
        if not isinstance(module, KTakesAll):
            continue
        module.k = k
def get_cached_raw_w(model):
    """Concatenate the cached raw gate saliencies of all DomainDynamicConv2d layers along dim 1."""
    saliencies = [m.cached_raw_w for _, m in model.named_modules()
                  if isinstance(m, DomainDynamicConv2d)]
    return torch.cat(saliencies, dim=1)
def start_accmu_flops(model):
    """Enable keep-ratio accumulation on every DomainDynamicConv2d (reset to an empty list)."""
    for _, module in model.named_modules():
        if not isinstance(module, DomainDynamicConv2d):
            continue
        module.pruning_ratios = []
def get_accmu_flops(model):
    """Collect and reset the keep-ratio statistics accumulated since start_accmu_flops.

    Returns:
        (per-layer dict name -> list of ratios, flat list of all ratios,
        average ratio over all layers/forwards).

    Robustness fixes: tolerates layers whose accumulator was never enabled
    (pruning_ratios is None — original raised TypeError) and returns 0. as the
    average when nothing was accumulated (original raised ZeroDivisionError).
    """
    layer_res = {}
    total_res = []
    for name, module in model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            layer_res[name] = module.pruning_ratios
            if module.pruning_ratios is not None:
                total_res += module.pruning_ratios
            # Disable accumulation until start_accmu_flops is called again.
            module.pruning_ratios = None
    avg_pruning_ratio = sum(total_res) / len(total_res) if total_res else 0.
    return layer_res, total_res, avg_pruning_ratio
def convert_boosted_model_to_static(boosted_model, a_few_data):
    """Run one forward pass to populate the gate caches, then freeze every gate.

    NOTE(review): requires DomainDynamicConv2d.to_static to exist.
    """
    boosted_model(a_few_data)
    for _, module in boosted_model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            module.to_static()

# TODO: use fn3 techniques
def ensure_boosted_model_to_dynamic(boosted_model):
    """Switch every DomainDynamicConv2d back to per-sample dynamic gating.

    NOTE(review): requires DomainDynamicConv2d.to_dynamic to exist.
    """
    for _, module in boosted_model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            module.to_dynamic()
def train_only_gate(model):
    """Freeze every parameter except the gate (filter_selection_module) ones.

    Side effect: sets requires_grad=False on all non-gate parameters.
    Returns the list of gate parameters (e.g. to hand to an optimizer).
    """
    gate_params = []
    for name, param in model.named_parameters():
        if 'filter_selection_module' not in name:
            param.requires_grad = False
        else:
            gate_params.append(param)
    return gate_params
if __name__ == '__main__':
    # Quick sanity check: with k = 0.6 on a 5-element row, the 3 smallest
    # entries are zeroed and the remaining 2 pass through.
    gate = KTakesAll(0.6)
    demo_input = torch.rand((1, 5))
    print(gate(demo_input))