# -*- coding: utf-8 -*-
import torch
from torch import nn
import torch.nn.functional as F
# from timm.models.layers.cbam import CbamModule  # needed only when attention_type == 'cbam'
import numpy as np
from einops import rearrange, repeat
import math

class ConvBn2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(ConvBn2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
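
# sSE, cSE and scSEBlock implement concurrent spatial and channel
# squeeze-and-excitation: sSE produces a per-pixel gate from a 1x1 conv,
# cSE produces a per-channel gate from globally pooled features, and
# scSEBlock applies both gates to the input and sums the results.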
class sSE(nn.Module):
    def __init__(self, out_channels):
        super(sSE, self).__init__()
        self.conv = ConvBn2d(in_channels=out_channels, out_channels=1, kernel_size=1, padding=0)

    def forward(self, x):
        x = self.conv(x)
        # print('spatial', x.size())
        x = torch.sigmoid(x)
        return x

class cSE(nn.Module):
    def __init__(self, out_channels):
        super(cSE, self).__init__()
        self.conv1 = ConvBn2d(in_channels=out_channels, out_channels=int(out_channels / 2), kernel_size=1, padding=0)
        self.conv2 = ConvBn2d(in_channels=int(out_channels / 2), out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, 1)  # global average pooling, equivalent to nn.AvgPool2d(x.size()[2:])
        # print('channel', x.size())
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = torch.sigmoid(x)
        return x

class scSEBlock(nn.Module):
    def __init__(self, out_channels):
        super(scSEBlock, self).__init__()
        self.spatial_gate = sSE(out_channels)
        self.channel_gate = cSE(out_channels)

    def forward(self, x):
        g1 = self.spatial_gate(x)
        g2 = self.channel_gate(x)
        x = g1 * x + g2 * x
        return x
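
# SaveFeatures registers a forward hook on an encoder layer and keeps its
# output for use as a skip connection. Token outputs of shape (B, L, C) from
# transformer stages are reshaped to (B, C, sqrt(L), sqrt(L)) so the decoder
# can treat them like feature maps.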
class SaveFeatures():
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # print('input', input)
        # print('output', output.size())
        if len(output.shape) == 3:
            B, L, C = output.shape
            h = int(L ** 0.5)
            output = output.view(B, h, h, C)
            output = output.permute(0, 3, 1, 2).contiguous()
        if len(output.shape) == 4 and output.shape[2] != output.shape[3]:
            output = output.permute(0, 3, 1, 2).contiguous()
        # print(module)
        self.features = output

    def remove(self):
        self.hook.remove()
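
# DBlock is a U-Net style decoder block: the incoming feature map is
# (optionally) upsampled, passed through attention1 and a 3x3 conv, then
# concatenated with the skip connection and refined by two more 3x3 convs
# followed by attention2.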
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
        super(DBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if attention_type == 'scse':
            self.attention1 = scSEBlock(in_channels)
        elif attention_type == 'cbam':
            self.attention1 = nn.Identity()
        elif attention_type == 'transformer':
            self.attention1 = nn.Identity()
        else:
            self.attention1 = nn.Identity()

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if attention_type == 'scse':
            self.attention2 = scSEBlock(out_channels)
        elif attention_type == 'cbam':
            self.attention2 = CbamModule(channels=out_channels)
        elif attention_type == 'transformer':
            self.attention2 = nn.Identity()
        else:
            self.attention2 = nn.Identity()

    def forward(self, x, skip):
        if x.shape[1] != skip.shape[1]:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        # print(x.shape, skip.shape)
        x = self.attention1(x)
        x = self.conv1(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.attention2(x)
        return x
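
# DBlock_res and DBlock_att are variants of DBlock: DBlock_res always
# upsamples by 2 (and its convs keep their bias terms), while DBlock_att
# defaults to attention_type='transformer' and, like DBlock, upsamples only
# when the channel counts of x and skip differ.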
class DBlock_res(nn.Module):
    def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
        super(DBlock_res, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if attention_type == 'scse':
            self.attention1 = scSEBlock(in_channels)
        elif attention_type == 'cbam':
            self.attention1 = CbamModule(channels=in_channels)
        elif attention_type == 'transformer':
            self.attention1 = nn.Identity()
        else:
            self.attention1 = nn.Identity()

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if attention_type == 'scse':
            self.attention2 = scSEBlock(out_channels)
        elif attention_type == 'cbam':
            self.attention2 = CbamModule(channels=out_channels)
        elif attention_type == 'transformer':
            self.attention2 = nn.Identity()
        else:
            self.attention2 = nn.Identity()

    def forward(self, x, skip):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        # print(x.shape, skip.shape)
        x = self.attention1(x)
        x = self.conv1(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.attention2(x)
        return x

class DBlock_att(nn.Module):
    def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type='transformer'):
        super(DBlock_att, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if attention_type == 'scse':
            self.attention1 = scSEBlock(in_channels)
        elif attention_type == 'cbam':
            self.attention1 = CbamModule(channels=in_channels)
        elif attention_type == 'transformer':
            self.attention1 = nn.Identity()
        else:
            self.attention1 = nn.Identity()

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if attention_type == 'scse':
            self.attention2 = scSEBlock(out_channels)
        elif attention_type == 'cbam':
            self.attention2 = CbamModule(channels=out_channels)
        elif attention_type == 'transformer':
            self.attention2 = nn.Identity()
        else:
            self.attention2 = nn.Identity()

    def forward(self, x, skip):
        if x.shape[1] != skip.shape[1]:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        # print(x.shape, skip.shape)
        x = self.attention1(x)
        x = self.conv1(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.attention2(x)
        return x

class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_class, kernel_size=3, upsample=4):
        super(SegmentationHead, self).__init__()
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=upsample) if upsample > 1 else nn.Identity()
        self.conv = nn.Conv2d(in_channels, num_class, kernel_size=kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        return x
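
# AV_Cross applies cross-channel attention to a 3-channel prediction map
# (presumably artery / vessel / vein, given the a, ve, v naming below). In
# each block, the artery-vessel, vein-vessel and artery-vessel-vein pairings
# produce sigmoid gates that re-weight (residually, by default) the
# corresponding channels before a final 1x1 conv.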
class AV_Cross(nn.Module):
    def __init__(self, channels=2, r=2, residual=True, block=4, kernel_size=1):
        super(AV_Cross, self).__init__()
        out_channels = int(channels // r)
        self.residual = residual
        self.block = block
        self.bn = nn.BatchNorm2d(3)
        self.relu = False
        self.kernel_size = kernel_size
        self.a_ve_att = nn.ModuleList()
        self.v_ve_att = nn.ModuleList()
        self.ve_att = nn.ModuleList()
        for i in range(self.block):
            self.a_ve_att.append(nn.Sequential(
                nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
                          padding=(self.kernel_size - 1) // 2),
                nn.BatchNorm2d(out_channels),
            ))
            self.v_ve_att.append(nn.Sequential(
                nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
                          padding=(self.kernel_size - 1) // 2),
                nn.BatchNorm2d(out_channels),
            ))
            self.ve_att.append(nn.Sequential(
                nn.Conv2d(3, out_channels, kernel_size=self.kernel_size, stride=1,
                          padding=(self.kernel_size - 1) // 2),
                nn.BatchNorm2d(out_channels),
            ))
        self.sigmoid = nn.Sigmoid()
        self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :]
        for i in range(self.block):
            # x = self.relu(self.bn(x))
            a_ve = torch.cat([a, ve], dim=1)
            v_ve = torch.cat([v, ve], dim=1)
            a_v_ve = torch.cat([a, ve, v], dim=1)
            x_a = self.a_ve_att[i](a_ve)
            x_v = self.v_ve_att[i](v_ve)
            x_a_v = self.ve_att[i](a_v_ve)
            a_weight = self.sigmoid(x_a)
            v_weight = self.sigmoid(x_v)
            ve_weight = self.sigmoid(x_a_v)
            if self.residual:
                a = a + a * a_weight
                v = v + v * v_weight
                ve = ve + ve * ve_weight
            else:
                a = a * a_weight
                v = v * v_weight
                ve = ve * ve_weight
        out = torch.cat([a, ve, v], dim=1)
        if self.relu:
            out = F.relu(out)
        out = self.final(out)
        return out
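
# AV_Cross_v2 is a variant of AV_Cross that, instead of attending over the
# raw channel pairings, first reduces each pairing to its channel-wise max
# and mean (CBAM-style descriptors) before computing the sigmoid gates.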
class AV_Cross_v2(nn.Module):
    def __init__(self, channels=2, r=2, residual=True, block=1, relu=False, kernel_size=1):
        super(AV_Cross_v2, self).__init__()
        out_channels = int(channels // r)
        self.residual = residual
        self.block = block
        self.relu = relu
        self.kernel_size = kernel_size
        self.a_ve_att = nn.ModuleList()
        self.v_ve_att = nn.ModuleList()
        self.ve_att = nn.ModuleList()
        for i in range(self.block):
            self.a_ve_att.append(nn.Sequential(
                nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
                          padding=(self.kernel_size - 1) // 2),
                nn.BatchNorm2d(out_channels)
            ))
            self.v_ve_att.append(nn.Sequential(
                nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
                          padding=(self.kernel_size - 1) // 2),
                nn.BatchNorm2d(out_channels)
            ))
            self.ve_att.append(nn.Sequential(
                nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
                          padding=(self.kernel_size - 1) // 2),
                nn.BatchNorm2d(out_channels)
            ))
        self.sigmoid = nn.Sigmoid()
        self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :]
        for i in range(self.block):
            tmp = torch.cat([a, ve, v], dim=1)
            a_ve = torch.cat([a, ve], dim=1)
            a_ve = torch.cat([torch.max(a_ve, dim=1, keepdim=True)[0], torch.mean(a_ve, dim=1, keepdim=True)], dim=1)
            v_ve = torch.cat([v, ve], dim=1)
            v_ve = torch.cat([torch.max(v_ve, dim=1, keepdim=True)[0], torch.mean(v_ve, dim=1, keepdim=True)], dim=1)
            a_v_ve = torch.cat([torch.max(tmp, dim=1, keepdim=True)[0], torch.mean(tmp, dim=1, keepdim=True)], dim=1)
            a_ve = self.a_ve_att[i](a_ve)
            v_ve = self.v_ve_att[i](v_ve)
            a_v_ve = self.ve_att[i](a_v_ve)
            a_weight = self.sigmoid(a_ve)
            v_weight = self.sigmoid(v_ve)
            ve_weight = self.sigmoid(a_v_ve)
            if self.residual:
                a = a + a * a_weight
                v = v + v * v_weight
                ve = ve + ve * ve_weight
            else:
                a = a * a_weight
                v = v * v_weight
                ve = ve * ve_weight
        out = torch.cat([a, ve, v], dim=1)
        if self.relu:
            out = F.relu(out)
        out = self.final(out)
        return out
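
# MultiHeadAttention, MLP, TransformerEncoderBlock and TransformerEncoder
# form a standard post-norm transformer encoder (MHA + MLP with residual
# connections) used by TransformerBottleNeck below.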
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, head_num):
        super().__init__()
        self.head_num = head_num
        # Scaled dot-product attention divides the logits by sqrt(d_k) per head.
        self.scale = (embedding_dim // head_num) ** (-1 / 2)
        self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
        self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x, mask=None):
        qkv = self.qkv_layer(x)
        query, key, value = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.head_num))
        energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.scale
        if mask is not None:
            energy = energy.masked_fill(mask, -np.inf)
        attention = torch.softmax(energy, dim=-1)
        x = torch.einsum("... i j , ... j d -> ... i d", attention, value)
        x = rearrange(x, "b h t d -> b t (h d)")
        x = self.out_attention(x)
        return x

class MLP(nn.Module):
    def __init__(self, embedding_dim, mlp_dim):
        super().__init__()
        self.mlp_layers = nn.Sequential(
            nn.Linear(embedding_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mlp_dim, embedding_dim),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        x = self.mlp_layers(x)
        return x

class TransformerEncoderBlock(nn.Module):
    def __init__(self, embedding_dim, head_num, mlp_dim):
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
        self.mlp = MLP(embedding_dim, mlp_dim)
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        _x = self.multi_head_attention(x)
        _x = self.dropout(_x)
        x = x + _x
        x = self.layer_norm1(x)
        _x = self.mlp(x)
        x = x + _x
        x = self.layer_norm2(x)
        return x

class TransformerEncoder(nn.Module):
    """
    embedding_dim: length of each token vector
    head_num: number of self-attention heads
    block_num: number of transformer encoder blocks
    """

    def __init__(self, embedding_dim, head_num, block_num=2):
        super().__init__()
        self.layer_blocks = nn.ModuleList(
            [TransformerEncoderBlock(embedding_dim, head_num, 2 * embedding_dim) for _ in range(block_num)])

    def forward(self, x):
        for layer_block in self.layer_blocks:
            x = layer_block(x)
        return x

class PathEmbedding(nn.Module):
    """
    Patch embedding for the transformer bottleneck.
    img_dim: size of the input feature map
    in_channels: number of input channels
    embedding_dim: length of each token vector
    patch_size: patch size used when tokenizing the input
    """

    def __init__(self, img_dim, in_channels, embedding_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.num_tokens = (img_dim // patch_size) ** 2
        self.token_dim = in_channels * (patch_size ** 2)
        # 1. projection
        self.projection = nn.Linear(self.token_dim, embedding_dim)
        # 2. position embedding
        self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))
        # 3. cls token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def forward(self, x):
        img_patches = rearrange(x,
                                'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.patch_size, patch_y=self.patch_size)
        batch_size, tokens_num, _ = img_patches.shape
        patch_token = self.projection(img_patches)
        cls_token = repeat(self.cls_token, 'b ... -> (b batch_size) ...',
                           batch_size=batch_size)
        patches = torch.cat([cls_token, patch_token], dim=1)
        # add position embedding
        patches += self.embedding[:tokens_num + 1, :]
        # B, tokens_num + 1, embedding_dim
        return patches
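
# TransformerBottleNeck tokenizes a feature map, runs it through the
# transformer encoder and either returns the per-patch tokens (dropping the
# cls token) or, when classification=True, a prediction from the cls token.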
class TransformerBottleNeck(nn.Module):
    def __init__(self, img_dim, in_channels, embedding_dim, head_num,
                 block_num, patch_size=1, classification=False, dropout=0.1, num_classes=1):
        super().__init__()
        self.patch_embedding = PathEmbedding(img_dim, in_channels, embedding_dim, patch_size)
        self.transformer = TransformerEncoder(embedding_dim, head_num, block_num)
        self.dropout = nn.Dropout(dropout)
        self.classification = classification
        if self.classification:
            self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = self.dropout(x)
        x = self.transformer(x)
        x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]
        return x
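
# PGFusion fuses a patch-level representation with a global representation:
# a channel-wise self-attention branch over the patch features and a
# patch-to-global cross-attention branch are combined through a learned
# 1x1-conv gate, and the gated sum is added back to the patch input.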
class PGFusion(nn.Module):
    def __init__(self, in_channel=384, out_channel=384):
        super(PGFusion, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.patch_query = nn.Conv2d(in_channel, in_channel, kernel_size=1)
        self.patch_key = nn.Conv2d(in_channel, in_channel, kernel_size=1)
        self.patch_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False)
        self.patch_global_query = nn.Conv2d(in_channel, in_channel, kernel_size=1)
        self.global_key = nn.Conv2d(in_channel, in_channel, kernel_size=1)
        self.global_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False)
        self.fusion = nn.Conv2d(in_channel * 2, in_channel * 2, kernel_size=1)
        self.out_patch = nn.Conv2d(in_channel, out_channel, kernel_size=1)
        self.out_global = nn.Conv2d(in_channel, out_channel, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)
        self.softmax_concat = nn.Softmax(dim=0)
        self.gamma_patch_self = nn.Parameter(torch.ones(1))
        self.gamma_patch_global = nn.Parameter(torch.ones(1))
        self.init_parameters()

    def init_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                nn.init.normal_(m.weight, 0, 0.01)
                # nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
                    # nn.init.constant_(m.bias, 0)
                m.inited = True

    def forward(self, patch_rep, global_rep):
        patch_rep_ = patch_rep.clone()
        patch_value = self.patch_value(patch_rep)
        patch_value = patch_value.view(patch_value.size(0), patch_value.size(1), -1)
        patch_key = self.patch_key(patch_rep)
        patch_key = patch_key.view(patch_key.size(0), patch_key.size(1), -1)
        dim_k = patch_key.shape[-1]
        patch_query = self.patch_query(patch_rep)
        patch_query = patch_query.view(patch_query.size(0), patch_query.size(1), -1)
        patch_global_query = self.patch_global_query(patch_rep)
        patch_global_query = patch_global_query.view(patch_global_query.size(0), patch_global_query.size(1), -1)
        global_value = self.global_value(global_rep)
        global_value = global_value.view(global_value.size(0), global_value.size(1), -1)
        global_key = self.global_key(global_rep)
        global_key = global_key.view(global_key.size(0), global_key.size(1), -1)

        ### patch self attention
        patch_self_sim_map = patch_query @ patch_key.transpose(-2, -1) / math.sqrt(dim_k)
        patch_self_sim_map = self.softmax(patch_self_sim_map)
        patch_self_sim_map = patch_self_sim_map @ patch_value
        patch_self_sim_map = patch_self_sim_map.view(patch_self_sim_map.size(0), patch_self_sim_map.size(1),
                                                     *patch_rep.size()[2:])
        patch_self_sim_map = self.gamma_patch_self * patch_self_sim_map
        # patch_self_sim_map = 1 * patch_self_sim_map

        ### patch global attention
        patch_global_sim_map = patch_global_query @ global_key.transpose(-2, -1) / math.sqrt(dim_k)
        patch_global_sim_map = self.softmax(patch_global_sim_map)
        patch_global_sim_map = patch_global_sim_map @ global_value
        patch_global_sim_map = patch_global_sim_map.view(patch_global_sim_map.size(0), patch_global_sim_map.size(1),
                                                         *patch_rep.size()[2:])
        patch_global_sim_map = self.gamma_patch_global * patch_global_sim_map
        # patch_global_sim_map = 1 * patch_global_sim_map

        fusion_sim_weight_map = torch.cat((patch_self_sim_map, patch_global_sim_map), dim=1)
        fusion_sim_weight_map = self.fusion(fusion_sim_weight_map)
        patch_self_sim_weight_map = torch.split(fusion_sim_weight_map, dim=1, split_size_or_sections=self.in_channel)[0]
        patch_self_sim_weight_map = torch.sigmoid(patch_self_sim_weight_map)  # 0-1
        patch_global_sim_weight_map = torch.split(fusion_sim_weight_map, dim=1, split_size_or_sections=self.in_channel)[1]
        patch_global_sim_weight_map = torch.sigmoid(patch_global_sim_weight_map)  # 0-1
        patch_self_sim_weight_map = torch.unsqueeze(patch_self_sim_weight_map, 0)
        patch_global_sim_weight_map = torch.unsqueeze(patch_global_sim_weight_map, 0)
        ct = torch.cat((patch_self_sim_weight_map, patch_global_sim_weight_map), 0)
        ct = self.softmax_concat(ct)
        out = patch_rep_ + patch_self_sim_map * ct[0] + patch_global_sim_map * (1 - ct[0])
        return out

if __name__ == '__main__':
    x = torch.randn((2, 384, 16, 16))
    m = PGFusion()
    print(m)
    # y = TransformerBottleNeck(x.shape[2], x.shape[1], x.shape[1], 8, 4)
    print(m(x, x).shape)
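
    # Additional shape checks (illustrative only; the feature and skip sizes
    # below are arbitrary assumptions, not values used elsewhere in this file).
    bottleneck = TransformerBottleNeck(img_dim=16, in_channels=384, embedding_dim=384,
                                       head_num=8, block_num=4)
    print(bottleneck(x).shape)  # expected: torch.Size([2, 256, 384]), i.e. 16*16 patch tokens
    dec = DBlock(in_channels=384, out_channels=64, attention_type='scse')
    skip = torch.randn((2, 64, 32, 32))
    print(dec(x, skip).shape)  # expected: torch.Size([2, 64, 32, 32])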