|
import os

import torch
import torch.nn as nn
from torchvision import models
|
'''
torchsummary output for the full Danbooru2018 ResNet50 tagger (the
convolutional body plus the fastai-style head built by `create_head`
below; input size 3 x 224 x 224):

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256, 56, 56]             512
             ReLU-15          [-1, 256, 56, 56]               0
       Bottleneck-16          [-1, 256, 56, 56]               0
           Conv2d-17           [-1, 64, 56, 56]          16,384
      BatchNorm2d-18           [-1, 64, 56, 56]             128
             ReLU-19           [-1, 64, 56, 56]               0
           Conv2d-20           [-1, 64, 56, 56]          36,864
      BatchNorm2d-21           [-1, 64, 56, 56]             128
             ReLU-22           [-1, 64, 56, 56]               0
           Conv2d-23          [-1, 256, 56, 56]          16,384
      BatchNorm2d-24          [-1, 256, 56, 56]             512
             ReLU-25          [-1, 256, 56, 56]               0
       Bottleneck-26          [-1, 256, 56, 56]               0
           Conv2d-27           [-1, 64, 56, 56]          16,384
      BatchNorm2d-28           [-1, 64, 56, 56]             128
             ReLU-29           [-1, 64, 56, 56]               0
           Conv2d-30           [-1, 64, 56, 56]          36,864
      BatchNorm2d-31           [-1, 64, 56, 56]             128
             ReLU-32           [-1, 64, 56, 56]               0
           Conv2d-33          [-1, 256, 56, 56]          16,384
      BatchNorm2d-34          [-1, 256, 56, 56]             512
             ReLU-35          [-1, 256, 56, 56]               0
       Bottleneck-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          32,768
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40          [-1, 128, 28, 28]         147,456
      BatchNorm2d-41          [-1, 128, 28, 28]             256
             ReLU-42          [-1, 128, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]          65,536
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
             ReLU-47          [-1, 512, 28, 28]               0
       Bottleneck-48          [-1, 512, 28, 28]               0
           Conv2d-49          [-1, 128, 28, 28]          65,536
      BatchNorm2d-50          [-1, 128, 28, 28]             256
             ReLU-51          [-1, 128, 28, 28]               0
           Conv2d-52          [-1, 128, 28, 28]         147,456
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 512, 28, 28]          65,536
      BatchNorm2d-56          [-1, 512, 28, 28]           1,024
             ReLU-57          [-1, 512, 28, 28]               0
       Bottleneck-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          65,536
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]         147,456
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]          65,536
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
             ReLU-67          [-1, 512, 28, 28]               0
       Bottleneck-68          [-1, 512, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          65,536
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
           Conv2d-72          [-1, 128, 28, 28]         147,456
      BatchNorm2d-73          [-1, 128, 28, 28]             256
             ReLU-74          [-1, 128, 28, 28]               0
           Conv2d-75          [-1, 512, 28, 28]          65,536
      BatchNorm2d-76          [-1, 512, 28, 28]           1,024
             ReLU-77          [-1, 512, 28, 28]               0
       Bottleneck-78          [-1, 512, 28, 28]               0
           Conv2d-79          [-1, 256, 28, 28]         131,072
      BatchNorm2d-80          [-1, 256, 28, 28]             512
             ReLU-81          [-1, 256, 28, 28]               0
           Conv2d-82          [-1, 256, 14, 14]         589,824
      BatchNorm2d-83          [-1, 256, 14, 14]             512
             ReLU-84          [-1, 256, 14, 14]               0
           Conv2d-85         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-86         [-1, 1024, 14, 14]           2,048
           Conv2d-87         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
             ReLU-89         [-1, 1024, 14, 14]               0
       Bottleneck-90         [-1, 1024, 14, 14]               0
           Conv2d-91          [-1, 256, 14, 14]         262,144
      BatchNorm2d-92          [-1, 256, 14, 14]             512
             ReLU-93          [-1, 256, 14, 14]               0
           Conv2d-94          [-1, 256, 14, 14]         589,824
      BatchNorm2d-95          [-1, 256, 14, 14]             512
             ReLU-96          [-1, 256, 14, 14]               0
           Conv2d-97         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-98         [-1, 1024, 14, 14]           2,048
             ReLU-99         [-1, 1024, 14, 14]               0
      Bottleneck-100         [-1, 1024, 14, 14]               0
          Conv2d-101          [-1, 256, 14, 14]         262,144
     BatchNorm2d-102          [-1, 256, 14, 14]             512
            ReLU-103          [-1, 256, 14, 14]               0
          Conv2d-104          [-1, 256, 14, 14]         589,824
     BatchNorm2d-105          [-1, 256, 14, 14]             512
            ReLU-106          [-1, 256, 14, 14]               0
          Conv2d-107         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-108         [-1, 1024, 14, 14]           2,048
            ReLU-109         [-1, 1024, 14, 14]               0
      Bottleneck-110         [-1, 1024, 14, 14]               0
          Conv2d-111          [-1, 256, 14, 14]         262,144
     BatchNorm2d-112          [-1, 256, 14, 14]             512
            ReLU-113          [-1, 256, 14, 14]               0
          Conv2d-114          [-1, 256, 14, 14]         589,824
     BatchNorm2d-115          [-1, 256, 14, 14]             512
            ReLU-116          [-1, 256, 14, 14]               0
          Conv2d-117         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-118         [-1, 1024, 14, 14]           2,048
            ReLU-119         [-1, 1024, 14, 14]               0
      Bottleneck-120         [-1, 1024, 14, 14]               0
          Conv2d-121          [-1, 256, 14, 14]         262,144
     BatchNorm2d-122          [-1, 256, 14, 14]             512
            ReLU-123          [-1, 256, 14, 14]               0
          Conv2d-124          [-1, 256, 14, 14]         589,824
     BatchNorm2d-125          [-1, 256, 14, 14]             512
            ReLU-126          [-1, 256, 14, 14]               0
          Conv2d-127         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-128         [-1, 1024, 14, 14]           2,048
            ReLU-129         [-1, 1024, 14, 14]               0
      Bottleneck-130         [-1, 1024, 14, 14]               0
          Conv2d-131          [-1, 256, 14, 14]         262,144
     BatchNorm2d-132          [-1, 256, 14, 14]             512
            ReLU-133          [-1, 256, 14, 14]               0
          Conv2d-134          [-1, 256, 14, 14]         589,824
     BatchNorm2d-135          [-1, 256, 14, 14]             512
            ReLU-136          [-1, 256, 14, 14]               0
          Conv2d-137         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-138         [-1, 1024, 14, 14]           2,048
            ReLU-139         [-1, 1024, 14, 14]               0
      Bottleneck-140         [-1, 1024, 14, 14]               0
          Conv2d-141          [-1, 512, 14, 14]         524,288
     BatchNorm2d-142          [-1, 512, 14, 14]           1,024
            ReLU-143          [-1, 512, 14, 14]               0
          Conv2d-144            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-145            [-1, 512, 7, 7]           1,024
            ReLU-146            [-1, 512, 7, 7]               0
          Conv2d-147           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-148           [-1, 2048, 7, 7]           4,096
          Conv2d-149           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-150           [-1, 2048, 7, 7]           4,096
            ReLU-151           [-1, 2048, 7, 7]               0
      Bottleneck-152           [-1, 2048, 7, 7]               0
          Conv2d-153            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-154            [-1, 512, 7, 7]           1,024
            ReLU-155            [-1, 512, 7, 7]               0
          Conv2d-156            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-157            [-1, 512, 7, 7]           1,024
            ReLU-158            [-1, 512, 7, 7]               0
          Conv2d-159           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-160           [-1, 2048, 7, 7]           4,096
            ReLU-161           [-1, 2048, 7, 7]               0
      Bottleneck-162           [-1, 2048, 7, 7]               0
          Conv2d-163            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-164            [-1, 512, 7, 7]           1,024
            ReLU-165            [-1, 512, 7, 7]               0
          Conv2d-166            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-167            [-1, 512, 7, 7]           1,024
            ReLU-168            [-1, 512, 7, 7]               0
          Conv2d-169           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-170           [-1, 2048, 7, 7]           4,096
            ReLU-171           [-1, 2048, 7, 7]               0
      Bottleneck-172           [-1, 2048, 7, 7]               0
AdaptiveMaxPool2d-173           [-1, 2048, 1, 1]               0
AdaptiveAvgPool2d-174           [-1, 2048, 1, 1]               0
AdaptiveConcatPool2d-175        [-1, 4096, 1, 1]               0
         Flatten-176                 [-1, 4096]               0
     BatchNorm1d-177                 [-1, 4096]           8,192
         Dropout-178                 [-1, 4096]               0
          Linear-179                  [-1, 512]       2,097,664
            ReLU-180                  [-1, 512]               0
     BatchNorm1d-181                  [-1, 512]           1,024
         Dropout-182                  [-1, 512]               0
          Linear-183                 [-1, 6000]       3,078,000
================================================================
Total params: 28,692,912
Trainable params: 28,692,912
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.75
Params size (MB): 109.45
Estimated Total Size (MB): 396.78
----------------------------------------------------------------
'''
|
|
class AdaptiveConcatPool2d(nn.Module):
    """
    Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
    Taken from the fastai library:
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
    """
    def __init__(self, sz=None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        self.output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
|
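# Illustrative shape check (not part of the module): the concat pool doubles
# the channel count, matching rows 173-175 of the summary above:
#   pool = AdaptiveConcatPool2d()
#   pool(torch.randn(1, 2048, 7, 7)).shape  # -> torch.Size([1, 4096, 1, 1])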
|
class Flatten(nn.Module):
    """
    Flatten `x` to a single dimension. Adapted from fastai's Flatten() layer,
    at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L25
    """
    def __init__(self): super().__init__()
    def forward(self, x): return x.view(x.size(0), -1)
|
|
def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn=None):
    """
    Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`.
    Adapted from fastai at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L44
    """
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0: layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None: layers.append(actn)
    return layers
|
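# For reference, bn_drop_lin(4096, 512, True, 0.25, nn.ReLU(inplace=True)) returns
# [BatchNorm1d(4096), Dropout(0.25), Linear(4096, 512), ReLU(inplace=True)] --
# one BN-dropout-linear-activation group of the head (rows 177-180 above).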
|
def create_head(top_n_tags, nf, ps=0.5):
    """Build the fastai-style classifier head shown in the summary above:
    concat pooling, flatten, then two bn-dropout-linear groups."""
    nc = top_n_tags
    lin_ftrs = [nf, 512, nc]
    # fastai convention: the intermediate layer gets half the final dropout
    p1, p2 = ps / 2, ps

    layers = [AdaptiveConcatPool2d(), Flatten()]
    layers += [
        *bn_drop_lin(lin_ftrs[0], lin_ftrs[1], True, p1, nn.ReLU(inplace=True)),
        *bn_drop_lin(lin_ftrs[1], lin_ftrs[2], True, p2)
    ]

    return nn.Sequential(*layers)
|
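# A hypothetical call matching the summary above: create_head(6000, 4096) yields
#   AdaptiveConcatPool2d -> Flatten -> BN(4096) -> Dropout(0.25) -> Linear(4096, 512)
#   -> ReLU -> BN(512) -> Dropout(0.5) -> Linear(512, 6000).
# The head is kept for documentation; `_resnet` below returns only the body.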
|
def _resnet(base_arch, top_n, **kwargs):
    # Cut off the final avgpool + fc, keeping only the convolutional body;
    # the matching head weights are dropped from the checkpoint in `resnet50`.
    cut = -2
    s = base_arch(pretrained=False, **kwargs)
    body = nn.Sequential(*list(s.children())[:cut])
    return body
|
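# Index map of the returned body (torchvision ResNet children, after the cut):
#   0: conv1  1: bn1  2: relu  3: maxpool  4: layer1  5: layer2  6: layer3  7: layer4
# The layer labels used below ("4_2_conv3" etc.) index into this Sequential.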
|
def resnet50(pretrained=True, progress=True, top_n=6000, **kwargs):
    r"""
    ResNet50 body of the model trained on the full Danbooru2018 dataset's top 6000 tags.

    Args:
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show a download progress bar.
        top_n (int): load the model for predicting the top `n` tags;
            currently only top_n=6000 is supported.
    """
    model = _resnet(models.resnet50, top_n, **kwargs)

    if pretrained:
        if top_n == 6000:
            state = torch.hub.load_state_dict_from_url(
                "https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
                progress=progress)
            # The checkpoint stores the body under prefix '0.' and the head
            # under prefix '1.'; strip the body prefix and drop the head.
            old_keys = [key for key in state]
            for old_key in old_keys:
                if old_key[0] == '0':
                    new_key = old_key[2:]
                    state[new_key] = state[old_key]
                    del state[old_key]
                elif old_key[0] == '1':
                    del state[old_key]

            model.load_state_dict(state)
        else:
            raise ValueError("Sorry, the resnet50 model only supports the "
                             "top-6000 tags at the moment")

    return model
|
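# Minimal usage sketch (assumes the release URL above is reachable):
#   backbone = resnet50(pretrained=True)   # nn.Sequential body, no classifier head
#   with torch.no_grad():
#       feats = backbone(torch.randn(1, 3, 224, 224))
#   feats.shape  # -> torch.Size([1, 2048, 7, 7])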
|
class resnet50_Extractor(nn.Module):
    """ResNet50 network for feature extraction.

    Registers a forward hook on every layer named in `layer_labels` and
    exposes the hooked activations through `forward`.
    """
    def get_activation(self, name):
        def hook(model, input, output):
            self.activation[name] = output.detach()
        return hook

    def __init__(self,
                 model,
                 layer_labels,
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False):
        super(resnet50_Extractor, self).__init__()

        self.model = model
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm
        self.layer_labels = layer_labels
        self.activation = {}

        # A label is either a top-level index into the body (e.g. "0") or an
        # underscore-separated path "block_bottleneck_attr" (e.g. "4_2_conv3").
        for layer_label in layer_labels:
            elements = layer_label.split('_')
            if len(elements) == 1:
                getattr(self.model, elements[0]).register_forward_hook(
                    self.get_activation(layer_label))
            else:
                body_layer = self.model
                for element in elements[:-1]:
                    assert element.isdigit(), f"expected an integer index, got {element!r}"
                    body_layer = body_layer[int(element)]
                getattr(body_layer, elements[-1]).register_forward_hook(
                    self.get_activation(layer_label))

        if not requires_grad:
            self.model.eval()
            for param in self.parameters():
                param.requires_grad = False

        if self.use_input_norm:
            # ImageNet statistics; inputs are expected in [0, 1], RGB order
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Hooked activations, keyed by layer label.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        self.model(x)  # populates self.activation through the forward hooks

        return {layer_label: self.activation[layer_label]
                for layer_label in self.layer_labels}
|
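# Label format recap (see __init__): "0" hooks body[0] (the stem conv), while
# "4_2_conv3" hooks body[4][2].conv3 (layer1, 3rd bottleneck, conv3). Sketch:
#   extractor = resnet50_Extractor(resnet50(), ["0", "4_2_conv3"])
#   acts = extractor(torch.rand(1, 3, 224, 224))  # dict: label -> activation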
|
class Anime_PerceptualLoss(nn.Module):
    """Anime perceptual loss.

    Args:
        layer_weights (dict): The weight for each hooked layer of the
            ResNet50 feature extractor. For example, {'4_2_conv3': 20.}
            means the conv3 output of the 3rd bottleneck in layer1 is
            extracted with weight 20.0 when calculating the loss.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and multiplied by this weight.
            Default: 1.0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 perceptual_weight=1.0,
                 criterion='l1'):
        super(Anime_PerceptualLoss, self).__init__()

        model = resnet50()
        self.perceptual_weight = perceptual_weight
        self.layer_weights = layer_weights
        self.layer_labels = layer_weights.keys()
        self.resnet50 = resnet50_Extractor(model, self.layer_labels).cuda()

        if criterion == 'l1':
            self.criterion = torch.nn.L1Loss()
        else:
            raise NotImplementedError(f"criterion '{criterion}' is not supported in perceptual loss")
|
    def forward(self, gen, gt, return_layer_losses=False):
        """Forward function.

        Args:
            gen (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
            return_layer_losses (bool): If True, also return the raw
                (unweighted) per-layer losses, as used by the diagnostic
                script under `__main__` below. Default: False.

        Returns:
            Tensor: Weighted perceptual loss, optionally preceded by the
            list of raw per-layer losses.
        """
        gen_features = self.resnet50(gen)
        gt_features = self.resnet50(gt.detach())

        layer_losses = []

        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in gen_features.keys():
                raw_comparison = self.criterion(gen_features[k], gt_features[k])
                layer_losses.append(raw_comparison.detach())
                percep_loss += raw_comparison * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        if return_layer_losses:
            return layer_losses, percep_loss
        return percep_loss
|
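# Training-time usage sketch (gen_batch/gt_batch are placeholder names):
#   percep = Anime_PerceptualLoss({"0": 0.5, "4_2_conv3": 20.0}).cuda()
#   l_percep = percep(gen_batch, gt_batch)  # scalar tensor; gt is detached inside
#   l_percep.backward()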
|
if __name__ == "__main__":
    import collections

    import cv2
    import torchvision.transforms as transforms

    loss = Anime_PerceptualLoss({"0": 0.5, "4_2_conv3": 20, "5_3_conv3": 30, "6_5_conv3": 1, "7_2_conv3": 1}).cuda()

    # Probe the average raw (unweighted) per-layer loss over a folder of
    # generated / ground-truth image pairs. Caveat: cv2 loads images as BGR
    # while the normalization statistics above are in RGB order.
    store = collections.defaultdict(list)
    for img_name in sorted(os.listdir('datasets/train_gen/')):
        gen = transforms.ToTensor()(cv2.imread('datasets/train_gen/' + img_name))
        gt = transforms.ToTensor()(cv2.imread('datasets/train_hr_anime_usm/' + img_name))
        # add the batch dimension expected by the conv layers
        gen = gen.unsqueeze(0).cuda()
        gt = gt.unsqueeze(0).cuda()
        temp_store, _ = loss(gen, gt, return_layer_losses=True)

        for idx in range(len(temp_store)):
            store[idx].append(temp_store[idx].item())

    for idx in range(len(store)):
        print("Average layer" + str(idx) + " has loss " + str(sum(store[idx]) / len(store[idx])))