from copy import deepcopy
from typing import Optional, Union

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from raw_vit import ViT, Attention, FeedForward
from utils.dl.common.model import get_model_size, set_module


class KTakesAll(nn.Module):
    """Gate that keeps only the most salient channels.

    `k` is the sparsity ratio: the smallest `k` fraction of the input channels is
    set to zero and the rest is passed through unchanged.
    """

    def __init__(self, k):
        super(KTakesAll, self).__init__()
        self.k = k

    def forward(self, g: torch.Tensor):
        # number of channels to suppress
        k = int(g.size(1) * self.k)

        # indices of the k smallest saliency values along the channel dimension
        i = (-g).topk(k, 1)[1]
        # zero out the suppressed channels, keep the others as-is
        t = g.scatter(1, i, 0)

        return t
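
# Illustrative sketch (standalone toy usage, not part of the original pipeline):
#
#   gate = KTakesAll(k=0.5)
#   gate(torch.tensor([[1., 2., 3., 4.]]))
#   # -> tensor([[0., 0., 3., 4.]]): the smallest 50% of channels are zeroed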


class Abs(nn.Module):
    """Module wrapper around `torch.abs()` so it can be used inside `nn.Sequential`."""

    def __init__(self):
        super(Abs, self).__init__()

    def forward(self, x):
        return x.abs()


class SqueezeLast(nn.Module):
    """Module wrapper that squeezes the last dimension, for use inside `nn.Sequential`."""

    def __init__(self):
        super(SqueezeLast, self).__init__()

    def forward(self, x):
        return x.squeeze(-1)


class Linear_WrappedWithFBS(nn.Module):
    """A linear layer augmented with Feature Boosting and Suppression (FBS).

    A small side branch (`self.fbs`) predicts a per-output-channel saliency from the
    input, the `KTakesAll` gate zeroes the least salient channels, and the resulting
    channel attention is multiplied onto the wrapped layer's output.
    """

    def __init__(self, linear: nn.Linear, r, k):
        super(Linear_WrappedWithFBS, self).__init__()

        self.linear = linear

        # saliency predictor: pool each feature over the token dimension, then a
        # squeeze-and-excite style bottleneck (reduction ratio r) produces one
        # attention value per output channel
        self.fbs = nn.Sequential(
            Rearrange('b n d -> b d n'),
            Abs(),
            nn.AdaptiveAvgPool1d(1),
            SqueezeLast(),
            nn.Linear(linear.in_features, linear.out_features // r),
            nn.ReLU(),
            nn.Linear(linear.out_features // r, linear.out_features),
            nn.ReLU(),
            KTakesAll(k)
        )
        self.k = k

        self.cached_channel_attention = None
        self.use_cached_channel_attention = False

    def forward(self, x):
        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            channel_attention = self.fbs(x)
            self.cached_channel_attention = channel_attention

        raw_res = self.linear(x)
        return channel_attention.unsqueeze(1) * raw_res
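
# Illustrative sketch (hypothetical sizes, mirroring the MLP of a ViT-L block):
#
#   wrapped = Linear_WrappedWithFBS(nn.Linear(1024, 4096), r=32, k=0.8)
#   x = torch.rand(2, 197, 1024)      # (batch, tokens, dim)
#   y = wrapped(x)                    # (2, 197, 4096); ~80% of the 4096 channels are zeroed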


class ToQKV_WrappedWithFBS(nn.Module):
    """The fused to_qkv projection augmented with FBS.

    The fused to_q/to_k/to_v linear layer is treated as a whole (it actually contains
    multiple heads), and three independent FBS branches gate its q, k and v output
    slices. Which channels of which heads are pruned therefore depends on the input;
    this differs from removing whole heads or removing the same channels in every head.
    """

    def __init__(self, to_qkv: nn.Linear, r, k):
        super(ToQKV_WrappedWithFBS, self).__init__()

        self.to_qkv = to_qkv
        # one FBS branch per q/k/v slice (each slice has out_features // 3 channels)
        self.fbses = nn.ModuleList([nn.Sequential(
            Rearrange('b n d -> b d n'),
            Abs(),
            nn.AdaptiveAvgPool1d(1),
            SqueezeLast(),
            nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 // r),
            nn.ReLU(),
            nn.Linear(to_qkv.out_features // 3 // r, to_qkv.out_features // 3),
            nn.ReLU(),
            KTakesAll(k)
        ) for _ in range(3)])
        self.k = k

        self.cached_channel_attention = None
        self.use_cached_channel_attention = False

    def forward(self, x):
        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            # concatenate the per-slice gates so they align with to_qkv's output layout
            channel_attention = torch.cat([fbs(x) for fbs in self.fbses], dim=1)
            self.cached_channel_attention = channel_attention

        raw_res = self.to_qkv(x)
        return channel_attention.unsqueeze(1) * raw_res
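
# Illustrative sketch (hypothetical sizes): the concatenated gate lines up with the
# q / k / v slices of to_qkv's output.
#
#   wrapped_qkv = ToQKV_WrappedWithFBS(nn.Linear(1024, 3 * 1024, bias=False), r=32, k=0.8)
#   y = wrapped_qkv(torch.rand(2, 197, 1024))        # (2, 197, 3072)
#   attn = wrapped_qkv.cached_channel_attention      # (2, 3072): [q gate | k gate | v gate]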


def boost_raw_vit_by_fbs(raw_vit: ViT, r, k):
    raw_vit = deepcopy(raw_vit)

    raw_vit_model_size = get_model_size(raw_vit, True)

    # wrap the qkv projection of every attention block and the first linear layer
    # of every feed-forward block with FBS
    for attn, ff in raw_vit.transformer.layers:
        attn = attn.fn
        ff = ff.fn

        set_module(attn, 'to_qkv', ToQKV_WrappedWithFBS(attn.to_qkv, r, k))
        set_module(ff.net, '0', Linear_WrappedWithFBS(ff.net[0], r, k))

    boosted_vit_model_size = get_model_size(raw_vit, True)

    print(f'boost_raw_vit_by_fbs() | model size from {raw_vit_model_size:.3f}MB to {boosted_vit_model_size:.3f}MB '
          f'(↑ {((boosted_vit_model_size - raw_vit_model_size) / raw_vit_model_size * 100):.2f}%)')

    return raw_vit


def set_boosted_vit_sparsity(boosted_vit: ViT, sparsity: float):
    for attn, ff in boosted_vit.transformer.layers:
        attn = attn.fn
        ff = ff.fn

        q_features = attn.to_qkv.to_qkv.out_features // 3

        # the number of kept q/k/v channels must be divisible by the number of heads;
        # if it is not, round it down to a multiple of `heads` and adjust the sparsity
        if (q_features - int(q_features * sparsity)) % attn.heads != 0:
            tuned_sparsity = 1. - int((q_features - int(q_features * sparsity)) / attn.heads) * attn.heads / q_features
            print(f'set_boosted_vit_sparsity() | tune sparsity from {sparsity} to {tuned_sparsity}')
            sparsity = tuned_sparsity

        attn.to_qkv.k = sparsity
        for fbs in attn.to_qkv.fbses:
            fbs[-1].k = sparsity
        ff.net[0].k = sparsity
        ff.net[0].fbs[-1].k = sparsity
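
# Worked example of the tuning above (assuming the ViT-L/16 numbers used in the demo
# below, i.e. q_features = 1024 and heads = 16): a requested sparsity of 0.98 keeps
# 1024 - int(1024 * 0.98) = 21 channels, which is not a multiple of 16, so it is rounded
# down to 16 kept channels, giving a tuned sparsity of 1 - 16 / 1024 = 0.984375.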


def set_boosted_vit_inference_via_cached_channel_attentions(boosted_vit: ViT):
    """Freeze every FBS gate to its cached channel attention (static inference)."""
    for attn, ff in boosted_vit.transformer.layers:
        attn = attn.fn
        ff = ff.fn

        assert attn.to_qkv.cached_channel_attention is not None
        assert ff.net[0].cached_channel_attention is not None

        attn.to_qkv.use_cached_channel_attention = True
        ff.net[0].use_cached_channel_attention = True


def set_boosted_vit_dynamic_inference(boosted_vit: ViT):
    """Re-enable input-dependent (dynamic) gating in every FBS module."""
    for attn, ff in boosted_vit.transformer.layers:
        attn = attn.fn
        ff = ff.fn

        attn.to_qkv.use_cached_channel_attention = False
        ff.net[0].use_cached_channel_attention = False


class StaticFBS(nn.Module):
    """Multiplies the input by a fixed, pre-computed channel attention of shape (1, C)."""

    def __init__(self, static_channel_attention):
        super(StaticFBS, self).__init__()
        assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1
        self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False)

    def forward(self, x):
        return x * self.static_channel_attention.unsqueeze(1)
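
# Illustrative sketch (hypothetical sizes): a (1, C) attention broadcasts over batch and tokens.
#
#   scale = StaticFBS(torch.rand(1, 4096))
#   scale(torch.rand(2, 197, 4096)).shape    # -> torch.Size([2, 197, 4096])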


def extract_surrogate_vit_via_cached_channel_attn(boosted_vit: ViT):
    """Materialize the cached channel attentions into a smaller, statically pruned ViT."""
    boosted_vit = deepcopy(boosted_vit)
    raw_vit_model_size = get_model_size(boosted_vit, True)

    def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k):
        assert channel_attn.size(0) == 1, 'use a representative sample to generate channel attentions'

        res = channel_attn[0].nonzero(as_tuple=True)[0]
        return res

    for attn, ff in boosted_vit.transformer.layers:
        attn = attn.fn
        ff_w_norm = ff
        ff = ff_w_norm.fn

        # shrink to_qkv to the unpruned q/k/v channels; the cached (non-binary) attention
        # values of the kept channels are baked in as a StaticFBS scaling
        to_qkv = attn.to_qkv
        to_q_unpruned_indexes = get_unpruned_indexes_from_channel_attn(
            to_qkv.cached_channel_attention[:, 0: to_qkv.cached_channel_attention.size(1) // 3],
            to_qkv.k
        )
        to_q_unpruned_indexes_w_offset = to_q_unpruned_indexes
        to_k_unpruned_indexes = get_unpruned_indexes_from_channel_attn(
            to_qkv.cached_channel_attention[:, to_qkv.cached_channel_attention.size(1) // 3: to_qkv.cached_channel_attention.size(1) // 3 * 2],
            to_qkv.k
        )
        to_k_unpruned_indexes_w_offset = to_k_unpruned_indexes + to_qkv.cached_channel_attention.size(1) // 3
        to_v_unpruned_indexes = get_unpruned_indexes_from_channel_attn(
            to_qkv.cached_channel_attention[:, to_qkv.cached_channel_attention.size(1) // 3 * 2:],
            to_qkv.k
        )
        to_v_unpruned_indexes_w_offset = to_v_unpruned_indexes + to_qkv.cached_channel_attention.size(1) // 3 * 2
        assert to_q_unpruned_indexes.size(0) == to_k_unpruned_indexes.size(0) == to_v_unpruned_indexes.size(0)
        to_qkv_unpruned_indexes = torch.cat([to_q_unpruned_indexes_w_offset, to_k_unpruned_indexes_w_offset, to_v_unpruned_indexes_w_offset])
        new_to_qkv = nn.Linear(to_qkv.to_qkv.in_features, to_qkv_unpruned_indexes.size(0), to_qkv.to_qkv.bias is not None)
        new_to_qkv.weight.data.copy_(to_qkv.to_qkv.weight.data[to_qkv_unpruned_indexes])
        if to_qkv.to_qkv.bias is not None:
            new_to_qkv.bias.data.copy_(to_qkv.to_qkv.bias.data[to_qkv_unpruned_indexes])
        set_module(attn, 'to_qkv', nn.Sequential(new_to_qkv, StaticFBS(to_qkv.cached_channel_attention[:, to_qkv_unpruned_indexes])))

        # shrink to_out's input dimension to match the kept v channels
        # (note: this replaces the whole to_out container, e.g. dropping a trailing Dropout,
        # which is equivalent at inference time)
        to_out = attn.to_out[0]
        new_to_out = nn.Linear(to_v_unpruned_indexes.size(0), to_out.out_features, to_out.bias is not None)
        new_to_out.weight.data.copy_(to_out.weight.data[:, to_v_unpruned_indexes])
        if to_out.bias is not None:
            new_to_out.bias.data.copy_(to_out.bias.data)
        set_module(attn, 'to_out', new_to_out)

        # shrink the first feed-forward linear to its unpruned output channels
        ff_0 = ff.net[0]
        ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, ff_0.k)
        new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None)
        new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes])
        if ff_0.linear.bias is not None:
            new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes])
        set_module(ff.net, '0', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes])))

        # and shrink the second feed-forward linear's input dimension accordingly
        ff_1 = ff.net[3]
        new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None)
        new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes])
        if ff_1.bias is not None:
            new_ff_1.bias.data.copy_(ff_1.bias.data)
        set_module(ff.net, '3', new_ff_1)

    pruned_vit_model_size = get_model_size(boosted_vit, True)

    print(f'extract_surrogate_vit_via_cached_channel_attn() | model size from {raw_vit_model_size:.3f}MB to {pruned_vit_model_size:.3f}MB '
          f'({(pruned_vit_model_size / raw_vit_model_size * 100):.2f}%)')

    return boosted_vit
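
# Typical end-to-end workflow, as a sketch based on the functions above (names and
# arguments are hypothetical):
#
#   boosted = boost_raw_vit_by_fbs(vit, r=32, k=0.9)
#   set_boosted_vit_sparsity(boosted, 0.9)
#   boosted(representative_sample)      # a batch of size 1; caches the channel attentions
#   static_vit = extract_surrogate_vit_via_cached_channel_attn(boosted)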


if __name__ == '__main__':
    from utils.dl.common.env import set_random_seed
    set_random_seed(1)

    def verify(vit, sparsity=0.8):
        vit.eval()

        # 1) sanity check of the raw ViT
        with torch.no_grad():
            r = torch.rand((1, 3, 224, 224))
            print(vit(r).size())

        # 2) wrap it with FBS and run it with dynamic, input-dependent gating
        boosted_vit = boost_raw_vit_by_fbs(vit, r=32, k=sparsity)
        set_boosted_vit_sparsity(boosted_vit, sparsity)

        with torch.no_grad():
            r = torch.rand((1, 3, 224, 224))
            print(boosted_vit(r).size())

        # 3) extract a statically pruned surrogate from the attentions cached by the last
        #    forward pass and check that it matches the boosted model on the same input
        r = torch.rand((1, 3, 224, 224))
        boosted_vit.eval()
        with torch.no_grad():
            o1 = boosted_vit(r)

        pruned_vit = extract_surrogate_vit_via_cached_channel_attn(boosted_vit)
        pruned_vit.eval()
        with torch.no_grad():
            o2 = pruned_vit(r)
        print('output diff (should be tiny): ', ((o1 - o2) ** 2).sum())

    vit_l_16 = ViT(
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=1024,
        depth=24,
        heads=16,
        dim_head=64,
        mlp_dim=4096,
        dropout=0.,
        emb_dropout=0.
    )
    verify(vit_l_16, 0.98)