Spaces:
Runtime error
Runtime error
File size: 4,606 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing
import torch
from torch import nn
from torch.nn import functional as F
class Unet(nn.Module):
    """U-Net that resizes its input to a fixed square and returns class indices.

    The encoder is four ``unetConv2`` + 2x2 max-pool stages followed by a
    ``center`` block; the decoder is four ``unetUp`` stages that each fuse the
    matching encoder output, and a final 1x1 convolution produces
    ``n_classes`` logit channels.

    Args:
        feature_scale: divisor applied to the base widths (64, 128, 256, 512,
            1024); each result is truncated with ``int``.
        n_classes: number of output channels of the final 1x1 convolution.
        is_deconv: forwarded to every ``unetUp`` decoder stage.
        in_channels: number of channels of the input images.
        is_batchnorm: forwarded to every conv stage.
        image_size: side length that ``forward`` resizes inputs to.
        use_dont_care: stored on the instance; not used by this class.
    """

    def __init__(
        self,
        feature_scale=4,
        n_classes=19,
        is_deconv=True,
        in_channels=3,
        is_batchnorm=True,
        image_size=512,
        use_dont_care=False
    ):
        super().__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        self.image_size = image_size
        self.n_classes = n_classes
        self.use_dont_care = use_dont_care

        w0, w1, w2, w3, w4 = (
            int(base / self.feature_scale) for base in (64, 128, 256, 512, 1024)
        )
        use_bn = self.is_batchnorm
        use_deconv = self.is_deconv

        # Encoder: conv stage then 2x2 max-pool at each of four levels.
        # Construction order is deliberate: it fixes parameter registration
        # order and random-initialisation order.
        self.conv1 = unetConv2(self.in_channels, w0, use_bn)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = unetConv2(w0, w1, use_bn)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = unetConv2(w1, w2, use_bn)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = unetConv2(w2, w3, use_bn)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        self.center = unetConv2(w3, w4, use_bn)

        # Decoder: each stage upsamples the deeper map and fuses a skip map.
        self.up_concat4 = unetUp(w4, w3, use_deconv, use_bn)
        self.up_concat3 = unetUp(w3, w2, use_deconv, use_bn)
        self.up_concat2 = unetUp(w2, w1, use_deconv, use_bn)
        self.up_concat1 = unetUp(w1, w0, use_deconv, use_bn)

        # Final 1x1 conv to class logits (no skip concatenation here).
        self.final = nn.Conv2d(w0, n_classes, 1)

    def forward(self, images, align_corners=True):
        """Resize ``images`` bicubically to ``image_size`` squared and segment.

        Args:
            images: batch passed to ``F.interpolate`` with mode ``'bicubic'``,
                which requires a 4-D tensor (presumably N, C, H, W).
            align_corners: forwarded to ``F.interpolate``.

        Returns:
            ``argmax`` over dim 1 (the ``n_classes`` logit channels) of the
            final convolution's output, i.e. one class index per position.
        """
        side = self.image_size
        x = F.interpolate(
            images, size=(side, side), mode='bicubic',
            align_corners=align_corners
        )

        encoder = (
            (self.conv1, self.maxpool1),
            (self.conv2, self.maxpool2),
            (self.conv3, self.maxpool3),
            (self.conv4, self.maxpool4),
        )
        skips = []
        for conv, pool in encoder:
            x = conv(x)
            skips.append(x)  # pre-pool features feed the matching decoder
            x = pool(x)

        x = self.center(x)

        decoder = (
            self.up_concat4, self.up_concat3, self.up_concat2, self.up_concat1
        )
        # Deepest decoder stage pairs with the deepest encoder output.
        for up, skip in zip(decoder, reversed(skips)):
            x = up(skip, x)

        logits = self.final(x)
        return logits.argmax(dim=1)
class unetConv2(nn.Module):
    """Two stacked 3x3 conv stages, each Conv2d -> [BatchNorm2d] -> ReLU.

    Every convolution uses kernel 3, stride 1 and padding 1, so the spatial
    size of the input is preserved; only the channel count changes.

    Args:
        in_size: input channels of the first convolution.
        out_size: output channels of both convolutions.
        is_batchnorm: if truthy, insert ``BatchNorm2d`` after each convolution.
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super().__init__()
        self.conv1 = self._stage(in_size, out_size, is_batchnorm)
        self.conv2 = self._stage(out_size, out_size, is_batchnorm)

    @staticmethod
    def _stage(c_in, c_out, with_bn):
        """Build one Conv2d(3, 1, 1) -> optional BatchNorm2d -> ReLU stage."""
        layers = [nn.Conv2d(c_in, c_out, 3, 1, 1)]
        if with_bn:
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply both stages in order."""
        return self.conv2(self.conv1(inputs))
class unetUp(nn.Module):
    """Decoder stage: upsample the deeper map, align the skip map, fuse them.

    Args:
        in_size: channels of the deeper feature map passed as ``inputs2``.
        out_size: channels expected on the skip map passed as ``inputs1``, and
            channels produced by this stage.
        is_deconv: if True, upsample with a learned 2x2 stride-2
            ``ConvTranspose2d`` that also maps ``in_size`` -> ``out_size``
            channels; otherwise use parameter-free bilinear x2 upsampling
            (``align_corners=True``), which keeps ``in_size`` channels.
        is_batchnorm: forwarded to the ``unetConv2`` refinement block.
    """

    def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
        super(unetUp, self).__init__()
        # The refinement conv sees cat([skip, upsampled]); size it from that
        # concatenation rather than from in_size, so the bilinear path (which
        # keeps in_size channels) and widths where in_size != 2 * out_size
        # (e.g. feature_scale=3) work. With the default widths and
        # is_deconv=True this equals in_size, so parameter shapes are unchanged.
        up_channels = out_size if is_deconv else in_size
        # self.conv is created before self.up to keep registration and
        # random-initialisation order stable.
        self.conv = unetConv2(out_size + up_channels, out_size, is_batchnorm)
        if is_deconv:
            self.up = nn.ConvTranspose2d(
                in_size, out_size, kernel_size=2, stride=2)
        else:
            # Same operation as the deprecated nn.UpsamplingBilinear2d.
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, inputs1, inputs2):
        """Upsample ``inputs2``, pad/crop ``inputs1`` to match, and refine.

        Args:
            inputs1: skip-connection features; 4-D, with ``out_size`` channels.
            inputs2: deeper features; 4-D, with ``in_size`` channels.

        Returns:
            ``self.conv`` applied to ``cat([aligned inputs1, upsampled
            inputs2], dim=1)``; spatial size equals the upsampled ``inputs2``.
        """
        outputs2 = self.up(inputs2)
        # Height and width are aligned independently. Negative amounts crop.
        # An odd difference d is split as d // 2 before and d - d // 2 after,
        # so the total is exactly d and the sizes always match.
        dh = outputs2.size(2) - inputs1.size(2)
        dw = outputs2.size(3) - inputs1.size(3)
        # F.pad order: (left, right, top, bottom) for the last two dims.
        padding = [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))
|