# AV/models/layers.py
# -*- coding: utf-8 -*-
import torch
from torch import nn
import torch.nn.functional as F
# CBAM attention is optional and only needed when attention_type == 'cbam';
# fall back gracefully when timm is not installed.
try:
    from timm.models.layers.cbam import CbamModule
except ImportError:
    CbamModule = None
import numpy as np
from einops import rearrange, repeat
import math
class ConvBn2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding):
super(ConvBn2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class sSE(nn.Module):
def __init__(self, out_channels):
super(sSE, self).__init__()
self.conv = ConvBn2d(in_channels=out_channels, out_channels=1, kernel_size=1, padding=0)
def forward(self, x):
x = self.conv(x)
# print('spatial',x.size())
        x = torch.sigmoid(x)  # F.sigmoid is deprecated
return x
class cSE(nn.Module):
def __init__(self, out_channels):
super(cSE, self).__init__()
self.conv1 = ConvBn2d(in_channels=out_channels, out_channels=int(out_channels / 2), kernel_size=1, padding=0)
self.conv2 = ConvBn2d(in_channels=int(out_channels / 2), out_channels=out_channels, kernel_size=1, padding=0)
def forward(self, x):
        # global average pooling over the spatial dimensions -> (B, C, 1, 1)
        x = F.adaptive_avg_pool2d(x, 1)
# print('channel',x.size())
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
        x = torch.sigmoid(x)  # F.sigmoid is deprecated
return x
class scSEBlock(nn.Module):
def __init__(self, out_channels):
super(scSEBlock, self).__init__()
self.spatial_gate = sSE(out_channels)
self.channel_gate = cSE(out_channels)
def forward(self, x):
g1 = self.spatial_gate(x)
g2 = self.channel_gate(x)
x = g1 * x + g2 * x
return x
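# A minimal, illustrative smoke test for scSEBlock (not used by the module);
# the channel count and spatial size below are arbitrary assumptions.
def _demo_scse_block():
    feat = torch.randn(2, 64, 32, 32)
    gate = scSEBlock(64)
    out = gate(feat)
    # spatial and channel gating are applied in parallel and summed,
    # so the output keeps the input shape
    assert out.shape == feat.shape
    return out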
class SaveFeatures():
features = None
def __init__(self, m):
self.hook = m.register_forward_hook(self.hook_fn)
def hook_fn(self, module, input, output):
# print('input',input)
# print('output',output.size())
        if len(output.shape) == 3:
            # transformer output (B, L, C): fold the token sequence back into a
            # square feature map and move channels first
            B, L, C = output.shape
            h = int(L ** 0.5)
            output = output.view(B, h, h, C)
            output = output.permute(0, 3, 1, 2).contiguous()
        if len(output.shape) == 4 and output.shape[2] != output.shape[3]:
            # 4D channels-last output: move channels first
            output = output.permute(0, 3, 1, 2).contiguous()
# print(module)
self.features = output
def remove(self):
self.hook.remove()
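# Illustrative use of SaveFeatures (not called anywhere in this module):
# register the hook on a layer, run a forward pass, then read .features.
# The conv layer and input shape here are arbitrary assumptions.
def _demo_save_features():
    layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    hook = SaveFeatures(layer)
    layer(torch.randn(1, 3, 32, 32))
    feats = hook.features  # (1, 8, 32, 32), captured by the forward hook
    hook.remove()
    return feats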
class DBlock(nn.Module):
def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
super(DBlock, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
if attention_type == 'scse':
self.attention1 = scSEBlock(in_channels)
elif attention_type == 'cbam':
self.attention1 = nn.Identity()
elif attention_type == 'transformer':
self.attention1 = nn.Identity()
else:
self.attention1 = nn.Identity()
self.conv2 = \
nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
self.conv3 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
if attention_type == 'scse':
self.attention2 = scSEBlock(out_channels)
elif attention_type == 'cbam':
self.attention2 = CbamModule(channels=out_channels)
elif attention_type == 'transformer':
self.attention2 = nn.Identity()
else:
self.attention2 = nn.Identity()
def forward(self, x, skip):
        if x.shape[1] != skip.shape[1]:
            # when the channel counts differ, x is assumed to come from the
            # deeper stage and is upsampled 2x to match the skip's resolution
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
# print(x.shape,skip.shape)
x = self.attention1(x)
x = self.conv1(x)
x = torch.cat([x, skip], dim=1)
x = self.conv2(x)
x = self.conv3(x)
x = self.attention2(x)
return x
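# Illustrative decoder step with DBlock (all shapes are assumptions):
# x comes from the deeper stage, skip from the encoder at twice the resolution.
def _demo_dblock():
    block = DBlock(in_channels=128, out_channels=64)
    x = torch.randn(2, 128, 16, 16)
    skip = torch.randn(2, 64, 32, 32)
    out = block(x, skip)  # x is upsampled 2x and fused with skip -> (2, 64, 32, 32)
    return out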
class DBlock_res(nn.Module):
def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
super(DBlock_res, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
if attention_type == 'scse':
self.attention1 = scSEBlock(in_channels)
elif attention_type == 'cbam':
self.attention1 = CbamModule(channels=in_channels)
elif attention_type == 'transformer':
self.attention1 = nn.Identity()
else:
self.attention1 = nn.Identity()
self.conv2 = \
nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
self.conv3 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
if attention_type == 'scse':
self.attention2 = scSEBlock(out_channels)
elif attention_type == 'cbam':
self.attention2 = CbamModule(channels=out_channels)
elif attention_type == 'transformer':
self.attention2 = nn.Identity()
else:
self.attention2 = nn.Identity()
def forward(self, x, skip):
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
# print(x.shape,skip.shape)
x = self.attention1(x)
x = self.conv1(x)
x = torch.cat([x, skip], dim=1)
x = self.conv2(x)
x = self.conv3(x)
x = self.attention2(x)
return x
class DBlock_att(nn.Module):
def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type='transformer'):
super(DBlock_att, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
if attention_type == 'scse':
self.attention1 = scSEBlock(in_channels)
elif attention_type == 'cbam':
self.attention1 = CbamModule(channels=in_channels)
elif attention_type == 'transformer':
self.attention1 = nn.Identity()
else:
self.attention1 = nn.Identity()
self.conv2 = \
nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
self.conv3 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
if attention_type == 'scse':
self.attention2 = scSEBlock(out_channels)
elif attention_type == 'cbam':
self.attention2 = CbamModule(channels=out_channels)
elif attention_type == 'transformer':
self.attention2 = nn.Identity()
else:
self.attention2 = nn.Identity()
def forward(self, x, skip):
if x.shape[1] != skip.shape[1]:
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
# print(x.shape,skip.shape)
x = self.attention1(x)
x = self.conv1(x)
x = torch.cat([x, skip], dim=1)
x = self.conv2(x)
x = self.conv3(x)
x = self.attention2(x)
return x
class SegmentationHead(nn.Module):
def __init__(self, in_channels, num_class, kernel_size=3, upsample=4):
super(SegmentationHead, self).__init__()
self.upsample = nn.UpsamplingBilinear2d(scale_factor=upsample) if upsample > 1 else nn.Identity()
self.conv = nn.Conv2d(in_channels, num_class, kernel_size=kernel_size, padding=kernel_size // 2)
def forward(self, x):
x = self.upsample(x)
x = self.conv(x)
return x
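# Illustrative use of SegmentationHead (channel/class counts are assumptions):
# upsample decoder features 4x and project them to per-class logits.
def _demo_segmentation_head():
    head = SegmentationHead(in_channels=64, num_class=3, kernel_size=3, upsample=4)
    feats = torch.randn(2, 64, 64, 64)
    logits = head(feats)  # (2, 3, 256, 256)
    return logits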
class AV_Cross(nn.Module):
def __init__(self, channels=2, r=2, residual=True, block=4, kernel_size=1):
super(AV_Cross, self).__init__()
out_channels = int(channels // r)
self.residual = residual
self.block = block
self.bn = nn.BatchNorm2d(3)
self.relu = False
self.kernel_size = kernel_size
self.a_ve_att = nn.ModuleList()
self.v_ve_att = nn.ModuleList()
self.ve_att = nn.ModuleList()
for i in range(self.block):
self.a_ve_att.append(nn.Sequential(
nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
padding=(self.kernel_size - 1) // 2),
nn.BatchNorm2d(out_channels),
))
self.v_ve_att.append(nn.Sequential(
nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
padding=(self.kernel_size - 1) // 2),
nn.BatchNorm2d(out_channels),
))
self.ve_att.append(nn.Sequential(
nn.Conv2d(3, out_channels, kernel_size=self.kernel_size, stride=1, padding=(self.kernel_size - 1) // 2),
nn.BatchNorm2d(out_channels),
))
self.sigmoid = nn.Sigmoid()
self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
def forward(self, x):
a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :]
for i in range(self.block):
# x = self.relu(self.bn(x))
a_ve = torch.concat([a, ve], dim=1)
v_ve = torch.concat([v, ve], dim=1)
a_v_ve = torch.concat([a, ve, v], dim=1)
x_a = self.a_ve_att[i](a_ve)
x_v = self.v_ve_att[i](v_ve)
x_a_v = self.ve_att[i](a_v_ve)
a_weight = self.sigmoid(x_a)
v_weight = self.sigmoid(x_v)
ve_weight = self.sigmoid(x_a_v)
if self.residual:
a = a + a * a_weight
v = v + v * v_weight
ve = ve + ve * ve_weight
else:
a = a * a_weight
v = v * v_weight
ve = ve * ve_weight
out = torch.concat([a, ve, v], dim=1)
if self.relu:
out = F.relu(out)
out = self.final(out)
return out
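# Illustrative use of AV_Cross; the input layout is an assumption inferred from
# the slicing in forward: channel 0 = artery, channel 1 = vessel, channel 2 = vein.
def _demo_av_cross():
    fusion = AV_Cross()
    logits = torch.randn(2, 3, 64, 64)
    refined = fusion(logits)  # (2, 3, 64, 64)
    return refined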
class AV_Cross_v2(nn.Module):
def __init__(self, channels=2, r=2, residual=True, block=1, relu=False, kernel_size=1):
super(AV_Cross_v2, self).__init__()
out_channels = int(channels // r)
self.residual = residual
self.block = block
self.relu = relu
self.kernel_size = kernel_size
self.a_ve_att = nn.ModuleList()
self.v_ve_att = nn.ModuleList()
self.ve_att = nn.ModuleList()
for i in range(self.block):
self.a_ve_att.append(nn.Sequential(
nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
padding=(self.kernel_size - 1) // 2),
nn.BatchNorm2d(out_channels)
))
self.v_ve_att.append(nn.Sequential(
nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
padding=(self.kernel_size - 1) // 2),
nn.BatchNorm2d(out_channels)
))
self.ve_att.append(nn.Sequential(
nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
padding=(self.kernel_size - 1) // 2),
nn.BatchNorm2d(out_channels)
))
self.sigmoid = nn.Sigmoid()
self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
def forward(self, x):
a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :]
for i in range(self.block):
tmp = torch.cat([a, ve, v], dim=1)
a_ve = torch.concat([a, ve], dim=1)
a_ve = torch.cat([torch.max(a_ve, dim=1, keepdim=True)[0], torch.mean(a_ve, dim=1, keepdim=True)], dim=1)
v_ve = torch.concat([v, ve], dim=1)
v_ve = torch.cat([torch.max(v_ve, dim=1, keepdim=True)[0], torch.mean(v_ve, dim=1, keepdim=True)], dim=1)
a_v_ve = torch.concat([torch.max(tmp, dim=1, keepdim=True)[0], torch.mean(tmp, dim=1, keepdim=True)], dim=1)
a_ve = self.a_ve_att[i](a_ve)
v_ve = self.v_ve_att[i](v_ve)
a_v_ve = self.ve_att[i](a_v_ve)
a_weight = self.sigmoid(a_ve)
v_weight = self.sigmoid(v_ve)
ve_weight = self.sigmoid(a_v_ve)
if self.residual:
a = a + a * a_weight
v = v + v * v_weight
ve = ve + ve * ve_weight
else:
a = a * a_weight
v = v * v_weight
ve = ve * ve_weight
out = torch.concat([a, ve, v], dim=1)
if self.relu:
out = F.relu(out)
out = self.final(out)
return out
class MultiHeadAttention(nn.Module):
def __init__(self, embedding_dim, head_num):
super().__init__()
self.head_num = head_num
        # scale attention logits by 1/sqrt(d_k) (scaled dot-product attention)
        self.dk = (embedding_dim // head_num) ** (-1 / 2)
self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)
def forward(self, x, mask=None):
qkv = self.qkv_layer(x)
query, key, value = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.head_num))
energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk
if mask is not None:
energy = energy.masked_fill(mask, -np.inf)
attention = torch.softmax(energy, dim=-1)
x = torch.einsum("... i j , ... j d -> ... i d", attention, value)
x = rearrange(x, "b h t d -> b t (h d)")
x = self.out_attention(x)
return x
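# Illustrative use of MultiHeadAttention (token count and width are assumptions).
def _demo_multi_head_attention():
    attn = MultiHeadAttention(embedding_dim=64, head_num=8)
    tokens = torch.randn(2, 17, 64)  # (batch, tokens, embedding_dim)
    out = attn(tokens)               # same shape: (2, 17, 64)
    return out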
class MLP(nn.Module):
def __init__(self, embedding_dim, mlp_dim):
super().__init__()
self.mlp_layers = nn.Sequential(
nn.Linear(embedding_dim, mlp_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(mlp_dim, embedding_dim),
nn.Dropout(0.1)
)
def forward(self, x):
x = self.mlp_layers(x)
return x
class TransformerEncoderBlock(nn.Module):
def __init__(self, embedding_dim, head_num, mlp_dim):
super().__init__()
self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
self.mlp = MLP(embedding_dim, mlp_dim)
self.layer_norm1 = nn.LayerNorm(embedding_dim)
self.layer_norm2 = nn.LayerNorm(embedding_dim)
self.dropout = nn.Dropout(0.1)
def forward(self, x):
_x = self.multi_head_attention(x)
_x = self.dropout(_x)
x = x + _x
x = self.layer_norm1(x)
_x = self.mlp(x)
x = x + _x
x = self.layer_norm2(x)
return x
class TransformerEncoder(nn.Module):
"""
embedding_dim: token 向量长度
head_num: 自注意力头
block_num: transformer个数
"""
def __init__(self, embedding_dim, head_num, block_num=2):
super().__init__()
self.layer_blocks = nn.ModuleList(
[TransformerEncoderBlock(embedding_dim, head_num, 2 * embedding_dim) for _ in range(block_num)])
def forward(self, x):
for layer_block in self.layer_blocks:
x = layer_block(x)
return x
class PathEmbedding(nn.Module):
"""
img_dim: 输入图的大小
in_channels: 输入的通道数
embedding_dim: 每个token的向量长度
patch_size:输入图token化,token的大小
"""
def __init__(self, img_dim, in_channels, embedding_dim, patch_size):
super().__init__()
self.patch_size = patch_size
self.num_tokens = (img_dim // patch_size) ** 2
self.token_dim = in_channels * (patch_size ** 2)
# 1. projection
self.projection = nn.Linear(self.token_dim, embedding_dim)
# 2. position embedding
self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))
# 3. cls token
self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
def forward(self, x):
img_patches = rearrange(x,
'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
patch_x=self.patch_size, patch_y=self.patch_size)
batch_size, tokens_num, _ = img_patches.shape
patch_token = self.projection(img_patches)
cls_token = repeat(self.cls_token, 'b ... -> (b batch_size) ...',
batch_size=batch_size)
patches = torch.cat([cls_token, patch_token], dim=1)
        # add position embedding
patches += self.embedding[:tokens_num + 1, :]
# B,tokens_num+1,embedding_dim
return patches
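# Illustrative use of PathEmbedding (all sizes below are assumptions): a 32x32,
# 3-channel input split into 4x4 patches yields 64 patch tokens plus one cls token.
def _demo_path_embedding():
    embed = PathEmbedding(img_dim=32, in_channels=3, embedding_dim=96, patch_size=4)
    img = torch.randn(2, 3, 32, 32)
    tokens = embed(img)  # (2, 65, 96)
    return tokens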
class TransformerBottleNeck(nn.Module):
def __init__(self, img_dim, in_channels, embedding_dim, head_num,
block_num, patch_size=1, classification=False, dropout=0.1, num_classes=1):
super().__init__()
self.patch_embedding = PathEmbedding(img_dim, in_channels, embedding_dim, patch_size)
self.transformer = TransformerEncoder(embedding_dim, head_num, block_num)
self.dropout = nn.Dropout(dropout)
self.classification = classification
if self.classification:
self.mlp_head = nn.Linear(embedding_dim, num_classes)
def forward(self, x):
x = self.patch_embedding(x)
x = self.dropout(x)
x = self.transformer(x)
x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]
return x
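# Illustrative use of TransformerBottleNeck (sizes are assumptions): with
# classification=False the cls token is dropped and the patch tokens are
# returned, ready to be reshaped back into a feature map.
def _demo_transformer_bottleneck():
    bottleneck = TransformerBottleNeck(img_dim=16, in_channels=384,
                                       embedding_dim=384, head_num=8, block_num=2)
    feats = torch.randn(2, 384, 16, 16)
    tokens = bottleneck(feats)  # (2, 256, 384)
    return tokens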
class PGFusion(nn.Module):
def __init__(self, in_channel=384, out_channel=384):
super(PGFusion, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.patch_query = nn.Conv2d(in_channel, in_channel, kernel_size=1)
self.patch_key = nn.Conv2d(in_channel, in_channel, kernel_size=1)
self.patch_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False)
self.patch_global_query = nn.Conv2d(in_channel, in_channel, kernel_size=1)
self.global_key = nn.Conv2d(in_channel, in_channel, kernel_size=1)
self.global_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False)
self.fusion = nn.Conv2d(in_channel * 2, in_channel * 2, kernel_size=1)
self.out_patch = nn.Conv2d(in_channel, out_channel, kernel_size=1)
self.out_global = nn.Conv2d(in_channel, out_channel, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
self.softmax_concat = nn.Softmax(dim=0)
self.gamma_patch_self = nn.Parameter(torch.ones(1))
self.gamma_patch_global = nn.Parameter(torch.ones(1))
self.init_parameters()
def init_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
nn.init.normal_(m.weight, 0, 0.01)
# nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
# nn.init.constant_(m.bias, 0)
m.inited = True
def forward(self, patch_rep, global_rep):
patch_rep_ = patch_rep.clone()
patch_value = self.patch_value(patch_rep)
patch_value = patch_value.view(patch_value.size(0), patch_value.size(1), -1)
patch_key = self.patch_key(patch_rep)
patch_key = patch_key.view(patch_key.size(0), patch_key.size(1), -1)
dim_k = patch_key.shape[-1]
patch_query = self.patch_query(patch_rep)
patch_query = patch_query.view(patch_query.size(0), patch_query.size(1), -1)
patch_global_query = self.patch_global_query(patch_rep)
patch_global_query = patch_global_query.view(patch_global_query.size(0), patch_global_query.size(1), -1)
global_value = self.global_value(global_rep)
global_value = global_value.view(global_value.size(0), global_value.size(1), -1)
global_key = self.global_key(global_rep)
global_key = global_key.view(global_key.size(0), global_key.size(1), -1)
### patch self attention
patch_self_sim_map = patch_query @ patch_key.transpose(-2, -1) / math.sqrt(dim_k)
patch_self_sim_map = self.softmax(patch_self_sim_map)
patch_self_sim_map = patch_self_sim_map @ patch_value
patch_self_sim_map = patch_self_sim_map.view(patch_self_sim_map.size(0), patch_self_sim_map.size(1),
*patch_rep.size()[2:])
patch_self_sim_map = self.gamma_patch_self * patch_self_sim_map
# patch_self_sim_map = 1 * patch_self_sim_map
### patch global attention
patch_global_sim_map = patch_global_query @ global_key.transpose(-2, -1) / math.sqrt(dim_k)
patch_global_sim_map = self.softmax(patch_global_sim_map)
patch_global_sim_map = patch_global_sim_map @ global_value
patch_global_sim_map = patch_global_sim_map.view(patch_global_sim_map.size(0), patch_global_sim_map.size(1),
*patch_rep.size()[2:])
patch_global_sim_map = self.gamma_patch_global * patch_global_sim_map
# patch_global_sim_map = 1 * patch_global_sim_map
fusion_sim_weight_map = torch.cat((patch_self_sim_map, patch_global_sim_map), dim=1)
fusion_sim_weight_map = self.fusion(fusion_sim_weight_map)
fusion_sim_weight_map = 1 * fusion_sim_weight_map
        # split the fused weight map back into its patch-self and patch-global
        # halves and squash each into a (0, 1) gate
        patch_self_sim_weight_map, patch_global_sim_weight_map = torch.split(
            fusion_sim_weight_map, split_size_or_sections=self.in_channel, dim=1)
        patch_self_sim_weight_map = torch.sigmoid(patch_self_sim_weight_map)
        patch_global_sim_weight_map = torch.sigmoid(patch_global_sim_weight_map)
patch_self_sim_weight_map = torch.unsqueeze(patch_self_sim_weight_map, 0)
patch_global_sim_weight_map = torch.unsqueeze(patch_global_sim_weight_map, 0)
ct = torch.concat((patch_self_sim_weight_map, patch_global_sim_weight_map), 0)
ct = self.softmax_concat(ct)
out = patch_rep_ + patch_self_sim_map * ct[0] + patch_global_sim_map * (1 - ct[0])
return out
if __name__ == '__main__':
x = torch.randn((2, 384, 16, 16))
m = PGFusion()
print(m)
# y = TransformerBottleNeck(x.shape[2],x.shape[1],x.shape[1],8,4)
print(m(x, x).shape)