import torch
import torch.nn as nn
import torch.nn.functional as F


def exists(val):
    return val is not None


def initialize_weights(module):
    # Xavier-initialize every linear layer and zero its bias
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
"""
Attention Network with Sigmoid Gating (3 fc layers)
args:
L: input feature dimension
D: hidden layer dimension
dropout: whether to use dropout (p = 0.25)
n_classes: number of classes
"""
class Attn_Net_Gated(nn.Module):
    def __init__(self, L=1024, D=256, n_tasks=1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = nn.Sequential(nn.Linear(L, D), nn.Tanh(), nn.Dropout(0.25))
        self.attention_b = nn.Sequential(nn.Linear(L, D), nn.Sigmoid(), nn.Dropout(0.25))
        self.attention_c = nn.Linear(D, n_tasks)

    def forward(self, x):
        a = self.attention_a(x)     # N x D
        b = self.attention_b(x)     # N x D
        A = a.mul(b)                # element-wise sigmoid gating: N x D
        A = self.attention_c(A)     # N x n_tasks (unnormalized attention scores)
        return A, x
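
# Illustrative sketch (not part of the original file; shapes assumed): for a bag of
# N instance embeddings of dimension L, Attn_Net_Gated returns one unnormalized
# attention score per instance and task, alongside the untouched features, e.g.
#
#   attn = Attn_Net_Gated(L=1024, D=256, n_tasks=1)
#   feats = torch.randn(100, 1024)       # N x L bag of patch features
#   scores, feats = attn(feats)          # scores: 100 x 1, feats: 100 x 1024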
"""
Code borrow from: https://github.com/mahmoodlab/TOAD
args:
gate: whether to use gating in attention network
size_args: size config of attention network
dropout: whether to use dropout in attention network
n_classes: number of classes
"""
class DeepAttnMIL(nn.Module):
    def __init__(self, input_dim=1024, size_arg="big", n_classes=2):
        super(DeepAttnMIL, self).__init__()
        self.size_dict = {"small": [input_dim, 512, 256], "big": [input_dim, 512, 384]}
        size = self.size_dict[size_arg]
        self.attention_net = nn.Sequential(
            nn.Linear(size[0], size[1]),
            nn.ReLU(),
            nn.Dropout(0.25),
            Attn_Net_Gated(L=size[1], D=size[2], n_tasks=1))
        self.classifier = nn.Linear(size[1], n_classes)
        initialize_weights(self)

    def forward(self, h, return_features=False, attention_only=False):
        A, h = self.attention_net(h)      # A: N x 1, h: N x 512
        A = torch.transpose(A, 1, 0)      # 1 x N
        if attention_only:
            return A[0]                   # raw (pre-softmax) attention scores
        A = F.softmax(A, dim=1)           # normalize over the N instances
        M = torch.mm(A, h)                # attention-weighted bag representation: 1 x 512
        if return_features:
            return M
        logits = self.classifier(M)       # 1 x n_classes
        return logits
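

if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the original file): push a random
    # bag of 100 patch features through the model and inspect the output shapes.
    model = DeepAttnMIL(input_dim=1024, size_arg="big", n_classes=2)
    bag = torch.randn(100, 1024)                      # N x input_dim instance features

    logits = model(bag)                               # 1 x n_classes slide-level logits
    attention = model(bag, attention_only=True)       # N raw attention scores
    bag_feature = model(bag, return_features=True)    # 1 x 512 pooled bag representation

    print(logits.shape, attention.shape, bag_feature.shape)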