# Spaces: Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> optional in-place ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        use_act: apply ``nn.ReLU(inplace=True)`` after normalization when
            True; otherwise the normalized tensor is returned as-is.
        **kwargs: forwarded to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``). The convolution always uses
            ``bias=False`` and reflect padding.
    """

    def __init__(self, in_channels, out_channels, use_act=True, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(
            in_channels,
            out_channels,
            # The following BatchNorm2d has its own learnable shift,
            # so a conv bias would be redundant.
            bias=False,
            padding_mode="reflect",
            **kwargs,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if use_act:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.cnn(x)
        y = self.bn(y)
        return self.act(y)
class ResidualBlock(nn.Module):
    """Two same-width 3x3 ConvBlocks with an identity skip and stochastic depth.

    During training, each sample's residual branch is kept with probability
    ``survival_prob``. Kept branches are scaled by ``1 / survival_prob`` and
    dropped branches are zeroed. In eval mode the branch is always kept,
    unscaled.

    Args:
        in_channels: channel count of the input (the output has the same shape).
        survival_prob: probability of keeping a sample's residual branch
            during training. Must lie in (0, 1]; defaults to 0.8.

    Raises:
        ValueError: if ``survival_prob`` is outside (0, 1].
    """

    def __init__(self, in_channels, survival_prob=0.8):
        super().__init__()
        if not 0.0 < survival_prob <= 1.0:
            # 0 would divide by zero in stochastic_depth; >1 is not a probability.
            raise ValueError(
                f"survival_prob must be in (0, 1], got {survival_prob!r}"
            )
        self.survival_prob = survival_prob
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # NOTE(review): the second conv also ends in ReLU (use_act=True), so the
        # residual branch can only add non-negative values. Many residual
        # blocks leave the last conv un-activated. Kept as-is because changing
        # it would alter the outputs of already-trained weights. Confirm intent.
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=True,
        )

    def stochastic_depth(self, x):
        """Randomly zero whole samples of ``x`` during training (inverted scaling).

        The keep/drop mask has shape ``(N, 1, ..., 1)``, so it broadcasts
        over every non-batch dimension of ``x`` whatever its rank. Returns
        ``x`` unchanged in eval mode.
        """
        if not self.training:
            return x
        # One Bernoulli draw per sample; trailing singleton dims for broadcasting.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        binary_tensor = torch.rand(mask_shape, device=x.device) < self.survival_prob
        return torch.div(x, self.survival_prob) * binary_tensor

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return self.stochastic_depth(out) + x
class Block(nn.Module):
    """3x3 Conv2d (reflect padding, no bias) -> BatchNorm2d -> activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride; the default 2 halves the spatial size.
        act: ``"relu"`` selects in-place ReLU; any other value selects
            in-place LeakyReLU with negative slope 0.2.
    """

    def __init__(self, in_channels, out_channels, stride=2, act="relu"):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
                padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_channels),
        ]
        if act == "relu":
            layers.append(nn.ReLU(inplace=True))
        else:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class Generator(nn.Module):
    """U-Net-style image-to-image generator with a residual bottleneck.

    Encoder: a 7x7 stem, then four stride-2 ``Block``s taking the width from
    ``features`` to ``features * 16``. Bottleneck: ``num_residuals``
    ``ResidualBlock``s plus a global skip. Decoder: nearest-neighbour
    upsampling and stride-1 ``Block``s. At each level the decoder
    concatenates the matching encoder feature map, then finishes with two
    ``Block``s, a 7x7 conv back to ``in_channels``, and ``Tanh``.

    Args:
        in_channels: channel count of the input image; the output has the same count.
        features: base channel width.
        num_residuals: number of ResidualBlocks in the bottleneck.

    The output has the same ``(N, in_channels, H, W)`` shape as the input,
    with values in (-1, 1). H and W need not be multiples of 16, because each
    upsampling step resizes to the exact size of its skip connection. Very
    small inputs still fail, since reflect padding needs every feature map to
    be larger than its padding.
    """

    def __init__(self, in_channels=3, features=64, num_residuals=9):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 7, 1, 3, bias=True, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )
        self.down1 = Block(features, features * 2, act="relu")
        self.down2 = Block(features * 2, features * 4, act="relu")
        self.down3 = Block(features * 4, features * 8, act="relu")
        self.down4 = Block(features * 8, features * 16, act="relu")
        self.residuals = nn.Sequential(
            *[ResidualBlock(features * 16) for _ in range(num_residuals)]
        )
        # Decoder inputs after up1 are concatenations (skip + previous level),
        # hence the doubled input widths.
        self.up1 = Block(features * 16, features * 8, stride=1, act="relu")
        self.up2 = Block(features * 8 * 2, features * 4, stride=1, act="relu")
        self.up3 = Block(features * 4 * 2, features * 2, stride=1, act="relu")
        self.up4 = Block(features * 2 * 2, features, stride=1, act="relu")
        self.final_conv = nn.Sequential(
            Block(features * 2, features, stride=1, act="relu"),
            Block(features, features, stride=1, act="relu"),
            nn.Conv2d(features, in_channels, 7, 1, 3, padding_mode="reflect"),
            nn.Tanh(),
        )

    @staticmethod
    def _upsample_like(t, ref):
        """Nearest-neighbour resize of ``t`` to the spatial size of ``ref``.

        For an exact 2x step this equals ``scale_factor=2``. Unlike a fixed
        factor, it also restores sizes that stride-2 rounding made odd, so
        the skip concatenations line up for any input height/width.
        """
        return F.interpolate(t, size=ref.shape[-2:], mode="nearest")

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        bottleneck = self.residuals(d5) + d5
        u1 = self.up1(self._upsample_like(bottleneck, d4))
        u2 = self.up2(self._upsample_like(torch.cat([u1, d4], dim=1), d3))
        u3 = self.up3(self._upsample_like(torch.cat([u2, d3], dim=1), d2))
        u4 = self.up4(self._upsample_like(torch.cat([u3, d2], dim=1), d1))
        return self.final_conv(torch.cat([u4, d1], dim=1))
def test():
    """Smoke test: run one random 3x256x256 image through a default-width
    Generator and print the resulting tensor shape."""
    sample = torch.randn(1, 3, 256, 256)
    generator = Generator(in_channels=3, features=64)
    output = generator(sample)
    print(output.shape)
#test() | |