# NOTE: the hosting page reported "Spaces: Runtime error" when this file was captured; that banner is not part of the module.
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
# https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Unet(nn.Module):
    """U-Net style network that returns one class index per pixel.

    The encoder is four ``unetConv2`` stages, each followed by a 2x2 max
    pool, then a ``center`` stage. The decoder is four ``unetUp`` stages,
    each fed the matching encoder output, followed by a 1x1 convolution
    with ``n_classes`` output channels.

    Args:
        feature_scale: Divisor applied to the base channel widths
            (64, 128, 256, 512, 1024); each quotient is truncated with
            ``int``.
        n_classes: Output channel count of the final 1x1 convolution.
        is_deconv: Forwarded unchanged to every ``unetUp`` stage.
        in_channels: Input channel count of the first encoder stage.
        is_batchnorm: Forwarded unchanged to every ``unetConv2`` and
            ``unetUp`` stage.
        image_size: Side length that ``forward`` resizes its input to.
        use_dont_care: Stored on the instance; not read anywhere else in
            this class.
    """

    def __init__(
        self,
        feature_scale=4,
        n_classes=19,
        is_deconv=True,
        in_channels=3,
        is_batchnorm=True,
        image_size=512,
        use_dont_care=False
    ):
        super().__init__()
        self.feature_scale = feature_scale
        self.n_classes = n_classes
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.image_size = image_size
        self.use_dont_care = use_dont_care

        widths = [int(w / feature_scale) for w in (64, 128, 256, 512, 1024)]

        # Encoder: conv1/maxpool1 ... conv4/maxpool4, registered in the same
        # order as a hand-written sequence of assignments would produce.
        encoder_inputs = [in_channels] + widths[:3]
        for stage, (c_in, c_out) in enumerate(zip(encoder_inputs, widths), 1):
            setattr(self, f"conv{stage}",
                    unetConv2(c_in, c_out, is_batchnorm))
            setattr(self, f"maxpool{stage}", nn.MaxPool2d(kernel_size=2))

        # Bottleneck between encoder and decoder.
        self.center = unetConv2(widths[3], widths[4], is_batchnorm)

        # Decoder: up_concat4 down to up_concat1, deepest stage first.
        for stage in (4, 3, 2, 1):
            setattr(
                self,
                f"up_concat{stage}",
                unetUp(widths[stage], widths[stage - 1],
                       is_deconv, is_batchnorm),
            )

        # Final 1x1 convolution (no skip connection involved).
        self.final = nn.Conv2d(widths[0], n_classes, 1)

    def forward(self, images, align_corners=True):
        """Resize ``images``, run the network, and return per-pixel labels.

        Args:
            images: Input batch. It is resized to
                ``(image_size, image_size)`` with bicubic interpolation
                before entering the encoder.
            align_corners: Passed through to ``F.interpolate``.

        Returns:
            The argmax over dimension 1 of the final convolution's output.
        """
        target_size = (self.image_size, self.image_size)
        x = F.interpolate(
            images, size=target_size, mode='bicubic',
            align_corners=align_corners
        )

        # Encoder pass; keep each stage output for the decoder skips.
        skip1 = self.conv1(x)
        skip2 = self.conv2(self.maxpool1(skip1))
        skip3 = self.conv3(self.maxpool2(skip2))
        skip4 = self.conv4(self.maxpool3(skip3))
        x = self.center(self.maxpool4(skip4))

        # Decoder pass, merging skips from deepest to shallowest.
        x = self.up_concat4(skip4, x)
        x = self.up_concat3(skip3, x)
        x = self.up_concat2(skip2, x)
        x = self.up_concat1(skip1, x)

        logits = self.final(x)
        return logits.argmax(dim=1)
class unetConv2(nn.Module):
    """Two consecutive 3x3 convolution stages, each ending in ReLU.

    Both convolutions use stride 1 and padding 1. When ``is_batchnorm`` is
    truthy, a ``BatchNorm2d`` layer sits between each convolution and its
    ReLU.

    Args:
        in_size: Input channel count of the first convolution.
        out_size: Output channel count of both convolutions.
        is_batchnorm: Whether to insert batch normalisation layers.
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super().__init__()

        def build_stage(channels_in):
            # Conv -> (BatchNorm) -> ReLU, in that positional order so the
            # Sequential indices stay 0, 1, 2 (or 0, 1 without batch norm).
            layers = [
                nn.Conv2d(channels_in, out_size,
                          kernel_size=3, stride=1, padding=1)
            ]
            if is_batchnorm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        self.conv1 = build_stage(in_size)
        self.conv2 = build_stage(out_size)

    def forward(self, inputs):
        """Apply the first stage, then the second, and return the result."""
        return self.conv2(self.conv1(inputs))
class unetUp(nn.Module):
    """Decoder stage: upsample the deeper features, merge the skip, convolve.

    Args:
        in_size: Input channel count of the transposed convolution (when
            ``is_deconv`` is truthy) and of the ``unetConv2`` block applied
            after concatenation.
        out_size: Output channel count of the transposed convolution and
            of the ``unetConv2`` block.
        is_deconv: If truthy, upsample with a stride-2 ``ConvTranspose2d``;
            otherwise use bilinear upsampling with scale factor 2.
        is_batchnorm: Forwarded to the ``unetConv2`` block.
    """

    def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
        super().__init__()
        self.conv = unetConv2(in_size, out_size, is_batchnorm)
        if is_deconv:
            self.up = nn.ConvTranspose2d(
                in_size, out_size, kernel_size=2, stride=2)
        else:
            # NOTE(review): bilinear upsampling leaves the channel count of
            # ``inputs2`` untouched, so the concatenation in ``forward`` only
            # matches ``self.conv`` if the channels of ``inputs1`` plus those
            # of ``inputs2`` equal ``in_size`` -- confirm against callers.
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        """Upsample ``inputs2``, fit ``inputs1`` to it, concatenate, convolve.

        Args:
            inputs1: Skip-connection tensor; padded (or cropped, when it is
                the larger one) to the spatial size of the upsampled tensor.
            inputs2: Tensor from the deeper stage; upsampled by ``self.up``.

        Returns:
            ``self.conv`` applied to ``[inputs1, upsampled inputs2]``
            concatenated along dimension 1.
        """
        outputs2 = self.up(inputs2)

        # Measure the mismatch per spatial axis. Using the height offset for
        # both axes breaks on non-square maps, and padding ``offset // 2`` on
        # both sides loses one pixel whenever the offset is odd.
        diff_h = outputs2.size(2) - inputs1.size(2)
        diff_w = outputs2.size(3) - inputs1.size(3)
        top = diff_h // 2
        left = diff_w // 2

        # F.pad takes (left, right, top, bottom) for the last two dims; the
        # remainder of an odd offset goes to the right/bottom. Negative
        # entries crop instead of pad.
        padding = [left, diff_w - left, top, diff_h - top]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))