|
import torch
|
|
from torch import nn
|
|
from torchvision import models
|
|
from einops import rearrange
|
|
from torchvision.models._utils import IntermediateLayerGetter
|
|
|
|
|
|
class Vgg(nn.Module):
    """VGG-BN backbone that converts an image batch into a feature sequence.

    The ``features`` stack of a torchvision VGG model is reused with every
    ``MaxPool2d`` swapped for an ``AvgPool2d`` whose kernel size and stride
    are taken from ``ks`` and ``ss``. Dropout followed by a 1x1 convolution
    then projects the 512-channel backbone output to ``hidden`` channels.

    Args:
        name: Backbone variant, ``'vgg11_bn'`` or ``'vgg19_bn'``.
        ss: Strides for the pooling layers, one entry per pooling layer in
            the order the layers appear in ``features``.
        ks: Kernel sizes for the pooling layers, same order as ``ss``.
        hidden: Output channels of the final 1x1 convolution.
        pretrained: If True, build the backbone with ``weights='DEFAULT'``
            (torchvision's default pretrained weights); if False, build it
            with ``weights=None``.
        dropout: Dropout probability applied before the 1x1 projection.

    Raises:
        ValueError: If ``name`` is not a supported variant.
    """

    def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
        super().__init__()

        # Honour ``pretrained``: hard-coding weights='DEFAULT' would always
        # load the pretrained weights, even when the caller opted out.
        weights = 'DEFAULT' if pretrained else None
        if name == 'vgg11_bn':
            cnn = models.vgg11_bn(weights=weights)
        elif name == 'vgg19_bn':
            cnn = models.vgg19_bn(weights=weights)
        else:
            # Fail loudly instead of hitting an UnboundLocalError on ``cnn``.
            raise ValueError(
                f"unsupported VGG variant {name!r}; "
                "expected 'vgg11_bn' or 'vgg19_bn'"
            )

        # Replace each max-pool with an average-pool configured per stage.
        # Iterate over a snapshot so the container is not modified while it
        # is being iterated.
        pool_idx = 0
        for i, layer in enumerate(list(cnn.features)):
            if isinstance(layer, torch.nn.MaxPool2d):
                cnn.features[i] = torch.nn.AvgPool2d(
                    kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0
                )
                pool_idx += 1

        self.features = cnn.features
        self.dropout = nn.Dropout(dropout)
        # Expects the 512-channel output of the VGG ``features`` stack.
        self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)

    def forward(self, x):
        """Encode an image batch as a sequence of feature vectors.

        Shape:
            - x: (N, C, H, W)
            - output: (W' * H', N, hidden), where H' and W' are the spatial
              size after the backbone. Sequence positions are ordered
              width-major (position ``w * H' + h``), so when H' == 1 the
              output is (W', N, hidden).
        """
        conv = self.features(x)
        conv = self.dropout(conv)
        conv = self.last_conv_1x1(conv)    # (N, hidden, H', W')

        conv = conv.transpose(-1, -2)      # (N, hidden, W', H')
        conv = conv.flatten(2)             # (N, hidden, W' * H')
        conv = conv.permute(-1, 0, 1)      # (W' * H', N, hidden)
        return conv
|
|
|
|
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` backbone using the ``'vgg11_bn'`` variant.

    Every argument is forwarded unchanged to the ``Vgg`` constructor.
    """
    return Vgg(
        name='vgg11_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )
|
|
|
|
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` backbone using the ``'vgg19_bn'`` variant.

    Every argument is forwarded unchanged to the ``Vgg`` constructor.
    """
    return Vgg(
        name='vgg19_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )
|
|
|
|
|