""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import torch as th | |
import torch.nn as nn | |
import visualize.ca_body.nn.layers as la | |
from visualize.ca_body.nn.blocks import weights_initializer | |
from visualize.ca_body.nn.layers import Conv2dWNUB, ConvTranspose2dWNUB, glorot | |


class UNetWB(nn.Module):
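    """U-Net with weight-normalized, untied-bias convolutions (Conv2dWNUB).

    A 5-level encoder/decoder operating on square inputs of side `size`.
    Skip connections are additive; the input is concatenated to the last
    decoder feature map before the final 1x1 output convolution, and the
    result is scaled by `out_scale`.
    """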
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        size: int,
        n_init_ftrs: int = 8,
        out_scale: float = 0.1,
    ):
        super().__init__()
        self.out_scale = out_scale
        F = n_init_ftrs
        self.size = size
        # Encoder: five stride-2 convolutions, each halving the spatial size.
        self.down1 = nn.Sequential(
            Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down2 = nn.Sequential(
            Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down3 = nn.Sequential(
            Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down4 = nn.Sequential(
            Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down5 = nn.Sequential(
            Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        # Decoder: five stride-2 transposed convolutions, each doubling the
        # spatial size back up to `size`.
        self.up1 = nn.Sequential(
            ConvTranspose2dWNUB(
                16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
            ),
            nn.LeakyReLU(0.2),
        )
        self.up2 = nn.Sequential(
            ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up3 = nn.Sequential(
            ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up4 = nn.Sequential(
            ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up5 = nn.Sequential(
            ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2)
        )
        # 1x1 output conv over the final features concatenated with the input.
        self.out = Conv2dWNUB(
            F + in_channels, out_channels, self.size, self.size, kernel_size=1
        )
        self.apply(lambda x: glorot(x, 0.2))
        glorot(self.out, 1.0)

    def forward(self, x):
        x1 = x
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        # Additive skip connections from the encoder.
        # TODO: switch to concat?
        x = self.up1(x6) + x5
        x = self.up2(x) + x4
        x = self.up3(x) + x3
        x = self.up4(x) + x2
        x = self.up5(x)
        x = th.cat([x, x1], dim=1)
        return self.out(x) * self.out_scale


class UNetWBConcat(nn.Module):
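    """Variant of UNetWB with concatenated (rather than additive) skips.

    Decoder blocks after the first take 2x the input channels to account
    for the concatenated encoder features. No output scaling is applied.
    """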
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        size: int,
        n_init_ftrs: int = 8,
    ):
        super().__init__()
        F = n_init_ftrs
        self.size = size
        # Encoder: five stride-2 convolutions, each halving the spatial size.
        self.down1 = nn.Sequential(
            la.Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down2 = nn.Sequential(
            la.Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down3 = nn.Sequential(
            la.Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down4 = nn.Sequential(
            la.Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down5 = nn.Sequential(
            la.Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        # Decoder: blocks after the first take doubled input channels because
        # encoder features are concatenated to the skips rather than added.
        self.up1 = nn.Sequential(
            la.ConvTranspose2dWNUB(
                16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
            ),
            nn.LeakyReLU(0.2),
        )
        self.up2 = nn.Sequential(
            la.ConvTranspose2dWNUB(
                2 * 8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1
            ),
            nn.LeakyReLU(0.2),
        )
        self.up3 = nn.Sequential(
            la.ConvTranspose2dWNUB(
                2 * 4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1
            ),
            nn.LeakyReLU(0.2),
        )
        self.up4 = nn.Sequential(
            la.ConvTranspose2dWNUB(
                2 * 2 * F, F, self.size // 2, self.size // 2, 4, 2, 1
            ),
            nn.LeakyReLU(0.2),
        )
        self.up5 = nn.Sequential(
            la.ConvTranspose2dWNUB(2 * F, F, self.size, self.size, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        # 1x1 output conv over the final features concatenated with the input.
        self.out = la.Conv2dWNUB(
            F + in_channels, out_channels, self.size, self.size, kernel_size=1
        )
        self.apply(lambda x: la.glorot(x, 0.2))
        la.glorot(self.out, 1.0)

    def forward(self, x):
        x1 = x
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        # Concatenated skip connections from the encoder.
        x = th.cat([self.up1(x6), x5], 1)
        x = th.cat([self.up2(x), x4], 1)
        x = th.cat([self.up3(x), x3], 1)
        x = th.cat([self.up4(x), x2], 1)
        x = self.up5(x)
        x = th.cat([x, x1], dim=1)
        return self.out(x)


class UNetW(nn.Module):
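    """U-Net with weight-normalized convolutions (Conv2dWN, no untied bias).

    Size-agnostic variant of UNetWB: blocks are defined purely by channel
    counts, so any square input whose side is divisible by 32 works. Skip
    connections are additive and the output is scaled by `out_scale`.
    """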
    def __init__(
        self,
        in_channels,
        out_channels,
        n_init_ftrs,
        kernel_size=4,
        out_scale=1.0,
    ):
        super().__init__()
        self.out_scale = out_scale
        F = n_init_ftrs
        # Encoder: five stride-2 convolutions, each halving the spatial size.
        self.down1 = nn.Sequential(
            la.Conv2dWN(in_channels, F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down2 = nn.Sequential(
            la.Conv2dWN(F, 2 * F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down3 = nn.Sequential(
            la.Conv2dWN(2 * F, 4 * F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down4 = nn.Sequential(
            la.Conv2dWN(4 * F, 8 * F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down5 = nn.Sequential(
            la.Conv2dWN(8 * F, 16 * F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        # Decoder: five stride-2 transposed convolutions, each doubling the
        # spatial size.
        self.up1 = nn.Sequential(
            la.ConvTranspose2dWN(16 * F, 8 * F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up2 = nn.Sequential(
            la.ConvTranspose2dWN(8 * F, 4 * F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up3 = nn.Sequential(
            la.ConvTranspose2dWN(4 * F, 2 * F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up4 = nn.Sequential(
            la.ConvTranspose2dWN(2 * F, F, kernel_size, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up5 = nn.Sequential(
            la.ConvTranspose2dWN(F, F, kernel_size, 2, 1), nn.LeakyReLU(0.2)
        )
        # 1x1 output conv over the final features concatenated with the input.
        self.out = la.Conv2dWN(F + in_channels, out_channels, kernel_size=1)
        self.apply(weights_initializer(0.2))
        self.out.apply(weights_initializer(1.0))

    def forward(self, x):
        x1 = x
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        # Additive skip connections from the encoder.
        # TODO: switch to concat?
        x = self.up1(x6) + x5
        x = self.up2(x) + x4
        x = self.up3(x) + x3
        x = self.up4(x) + x2
        x = self.up5(x)
        x = th.cat([x, x1], dim=1)
        return self.out(x) * self.out_scale
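

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): build
    # each UNet variant and check that output shapes match the input. The
    # channel counts and `size` below are arbitrary illustrative choices.
    x = th.randn(1, 3, 64, 64)

    net = UNetWB(in_channels=3, out_channels=3, size=64)
    print(net(x).shape)  # expected: torch.Size([1, 3, 64, 64])

    net = UNetWBConcat(in_channels=3, out_channels=3, size=64)
    print(net(x).shape)  # expected: torch.Size([1, 3, 64, 64])

    net = UNetW(in_channels=3, out_channels=3, n_init_ftrs=8)
    print(net(x).shape)  # expected: torch.Size([1, 3, 64, 64])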