"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch as th
import torch.nn as nn
import visualize.ca_body.nn.layers as la
from visualize.ca_body.nn.blocks import weights_initializer
from visualize.ca_body.nn.layers import Conv2dWNUB, ConvTranspose2dWNUB, glorot


class UNetWB(nn.Module):
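    """A 5-level UNet built from weight-normalized convolutions with untied
    biases (Conv2dWNUB / ConvTranspose2dWNUB), which require a fixed output
    resolution. Decoder skip connections are additive; the network input is
    concatenated back in before the final 1x1 convolution, and the result is
    scaled by `out_scale`. Assumes a square input of resolution `size`,
    divisible by 32.
    """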
def __init__(
self,
in_channels: int,
out_channels: int,
size: int,
n_init_ftrs: int = 8,
out_scale: float = 0.1,
):
super().__init__()
self.out_scale = out_scale
F = n_init_ftrs
self.size = size
self.down1 = nn.Sequential(
Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down2 = nn.Sequential(
Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down3 = nn.Sequential(
Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down4 = nn.Sequential(
Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down5 = nn.Sequential(
Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.up1 = nn.Sequential(
ConvTranspose2dWNUB(
16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
),
nn.LeakyReLU(0.2),
)
self.up2 = nn.Sequential(
ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.up3 = nn.Sequential(
ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.up4 = nn.Sequential(
ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.up5 = nn.Sequential(
ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2)
)
self.out = Conv2dWNUB(
F + in_channels, out_channels, self.size, self.size, kernel_size=1
)
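        # Glorot-initialize all layers for LeakyReLU (negative slope 0.2),
        # then re-initialize the output layer with a linear gain of 1.0.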
self.apply(lambda x: glorot(x, 0.2))
        glorot(self.out, 1.0)

    def forward(self, x):
x1 = x
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x6 = self.down5(x5)
# TODO: switch to concat?
x = self.up1(x6) + x5
x = self.up2(x) + x4
x = self.up3(x) + x3
x = self.up4(x) + x2
x = self.up5(x)
x = th.cat([x, x1], dim=1)
        return self.out(x) * self.out_scale


class UNetWBConcat(nn.Module):
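    """Variant of UNetWB that concatenates skip connections along the channel
    dimension instead of adding them, hence the doubled input channels of
    up2-up5. No output scaling is applied.
    """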
def __init__(
self,
in_channels: int,
out_channels: int,
size: int,
n_init_ftrs: int = 8,
):
super().__init__()
F = n_init_ftrs
self.size = size
self.down1 = nn.Sequential(
la.Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down2 = nn.Sequential(
la.Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down3 = nn.Sequential(
la.Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down4 = nn.Sequential(
la.Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.down5 = nn.Sequential(
la.Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.up1 = nn.Sequential(
la.ConvTranspose2dWNUB(
16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
),
nn.LeakyReLU(0.2),
)
self.up2 = nn.Sequential(
la.ConvTranspose2dWNUB(
2 * 8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1
),
nn.LeakyReLU(0.2),
)
self.up3 = nn.Sequential(
la.ConvTranspose2dWNUB(
2 * 4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1
),
nn.LeakyReLU(0.2),
)
self.up4 = nn.Sequential(
la.ConvTranspose2dWNUB(
2 * 2 * F, F, self.size // 2, self.size // 2, 4, 2, 1
),
nn.LeakyReLU(0.2),
)
self.up5 = nn.Sequential(
la.ConvTranspose2dWNUB(2 * F, F, self.size, self.size, 4, 2, 1),
nn.LeakyReLU(0.2),
)
self.out = la.Conv2dWNUB(
F + in_channels, out_channels, self.size, self.size, kernel_size=1
)
self.apply(lambda x: la.glorot(x, 0.2))
        la.glorot(self.out, 1.0)

    def forward(self, x):
x1 = x
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x6 = self.down5(x5)
x = th.cat([self.up1(x6), x5], 1)
x = th.cat([self.up2(x), x4], 1)
x = th.cat([self.up3(x), x3], 1)
x = th.cat([self.up4(x), x2], 1)
x = self.up5(x)
x = th.cat([x, x1], dim=1)
        return self.out(x)


class UNetW(nn.Module):
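    """UNet analogous to UNetWB, but built from weight-normalized convolutions
    with a single (tied) bias per channel (Conv2dWN / ConvTranspose2dWN), so no
    fixed spatial size is required at construction time. Skip connections are
    additive; the input is concatenated before the 1x1 output convolution, and
    the result is scaled by `out_scale`.
    """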
def __init__(
self,
in_channels,
out_channels,
n_init_ftrs,
kernel_size=4,
out_scale=1.0,
):
super().__init__()
self.out_scale = out_scale
F = n_init_ftrs
self.down1 = nn.Sequential(
la.Conv2dWN(in_channels, F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.down2 = nn.Sequential(
la.Conv2dWN(F, 2 * F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.down3 = nn.Sequential(
la.Conv2dWN(2 * F, 4 * F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.down4 = nn.Sequential(
la.Conv2dWN(4 * F, 8 * F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.down5 = nn.Sequential(
la.Conv2dWN(8 * F, 16 * F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.up1 = nn.Sequential(
la.ConvTranspose2dWN(16 * F, 8 * F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.up2 = nn.Sequential(
la.ConvTranspose2dWN(8 * F, 4 * F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.up3 = nn.Sequential(
la.ConvTranspose2dWN(4 * F, 2 * F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.up4 = nn.Sequential(
la.ConvTranspose2dWN(2 * F, F, kernel_size, 2, 1),
nn.LeakyReLU(0.2),
)
self.up5 = nn.Sequential(
la.ConvTranspose2dWN(F, F, kernel_size, 2, 1), nn.LeakyReLU(0.2)
)
self.out = la.Conv2dWN(F + in_channels, out_channels, kernel_size=1)
self.apply(weights_initializer(0.2))
        self.out.apply(weights_initializer(1.0))

    def forward(self, x):
x1 = x
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x6 = self.down5(x5)
# TODO: switch to concat?
x = self.up1(x6) + x5
x = self.up2(x) + x4
x = self.up3(x) + x3
x = self.up4(x) + x2
x = self.up5(x)
x = th.cat([x, x1], dim=1)
return self.out(x) * self.out_scale
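

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): run a
    # random input through each variant and check the output shape. The batch
    # size, channel counts, and resolution are arbitrary assumptions; `size`
    # must be divisible by 32 for the untied-bias (WNUB) variants.
    size, in_ch, out_ch = 64, 3, 3
    x = th.randn(2, in_ch, size, size)
    for net in (
        UNetWB(in_ch, out_ch, size),
        UNetWBConcat(in_ch, out_ch, size),
        UNetW(in_ch, out_ch, n_init_ftrs=8),
    ):
        y = net(x)
        assert y.shape == (2, out_ch, size, size), y.shape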