File size: 3,742 Bytes
b6bb35e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
# -----------------------------------------------------------------------------
# Blocks
# -----------------------------------------------------------------------------
class Conv2d(nn.Module):
    """Convolution block: a 2D convolution, optionally followed by a PReLU.

    Inputs are laid out as [b, c, h, w] where
        b is the batch size
        c is the number of channels
        h is the height
        w is the width

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the convolution kernel.
        padding: padding handed to the convolution.
        do_activation: when True (default) a ``nn.PReLU`` is applied after
            the convolution; when False the block is the convolution alone.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
        do_activation: bool = True,
    ):
        super().__init__()
        # The convolution always comes first; the activation is appended
        # only on request, so index 0 of the Sequential is the convolution.
        stages = (
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            ),
        )
        if do_activation:
            stages = stages + (nn.PReLU(),)
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the block on ``x`` ([B, C, H, W]) and return the result."""
        out = self.conv(x)
        return out
# -----------------------------------------------------------------------------
# Network
# -----------------------------------------------------------------------------
class _UNet(nn.Module):
    """U-Net whose convolution block can be switched out.

    Layout, as built in ``__init__`` and executed in ``forward``:

    * encoder: one ``conv`` block per entry of ``features``, each followed
      by a 2x2 max-pool; the pre-pool activation is kept as a skip
      connection;
    * bottleneck: one ``conv`` block at the deepest resolution;
    * decoder: for each encoder stage in reverse order, a x2 bilinear
      upsampling, concatenation with the matching skip connection along the
      channel dimension, and a ``conv`` block;
    * head: a ``conv`` block with ``kernel_size=1`` and
      ``do_activation=False`` mapping to ``out_channels``.

    Inputs are expected as [B, C, H, W] (skips are concatenated on dim 1).
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        features: Optional[List[int]] = None,
        conv_kernel_size: int = 3,
        conv: Optional[Callable[..., nn.Module]] = None,
        conv_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        UNet (but can switch out the Conv)

        Args:
            in_channels: channels of the input tensor.
            out_channels: channels of the output tensor.
            features: output width of each encoder stage, shallowest first.
                Defaults to ``[64, 64, 64, 64, 64]``. Widths may differ
                between stages.
            conv_kernel_size: kernel size of every block except the head;
                padding is ``(conv_kernel_size - 1) // 2``.
            conv: block factory, called as
                ``conv(in_ch, out_ch, kernel_size=..., padding=...,
                **conv_kwargs)``; it must also accept ``do_activation``
                (used by the head). Required despite the ``None`` default.
            conv_kwargs: extra keyword arguments forwarded to every ``conv``
                call. Defaults to no extra arguments.

        Raises:
            TypeError: if ``conv`` is not provided.
            ValueError: if ``features`` is empty.
        """
        super(_UNet, self).__init__()

        if conv is None:
            raise TypeError(
                "_UNet requires a `conv` block factory (e.g. a class taking "
                "in_channels, out_channels, kernel_size, padding)"
            )
        # Fresh containers per instance: never share mutable defaults.
        features = [64, 64, 64, 64, 64] if features is None else list(features)
        if not features:
            raise ValueError("`features` must contain at least one stage width")
        conv_kwargs = {} if conv_kwargs is None else dict(conv_kwargs)

        self.in_channels = in_channels
        padding = (conv_kernel_size - 1) // 2

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of U-Net
        for feat in features:
            self.downs.append(
                conv(
                    in_channels, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
                )
            )
            in_channels = feat

        # Up part of U-Net
        # Bilinear upsampling keeps the channel count, so each decoder block
        # sees the skip's channels plus the channels coming from the stage
        # below. For uniform widths this equals feat * 2.
        incoming = features[-1]  # bottleneck output width
        for feat in reversed(features):
            self.ups.append(nn.UpsamplingBilinear2d(scale_factor=2))
            self.ups.append(
                conv(
                    feat + incoming, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
                )
            )
            incoming = feat

        self.bottleneck = conv(
            features[-1], features[-1], kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
        )
        self.final_conv = conv(
            features[0], out_channels, kernel_size=1, padding=0, do_activation=False, **conv_kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` ([B, C, H, W]) and return the head output."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Deepest skip first, matching the order of the decoder stages.
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Max-pooling floors odd sizes, so x2 upsampling can come out
            # smaller than the skip; resize to the skip's spatial size
            # (same interpolation as UpsamplingBilinear2d).
            if x.shape[-2:] != skip_connection.shape[-2:]:
                x = nn.functional.interpolate(
                    x,
                    size=skip_connection.shape[-2:],
                    mode="bilinear",
                    align_corners=True,
                )
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
class UNet(_UNet):
    """
    U-Net assembled from the plain ``Conv2d`` blocks.

    input shape:  B x C x H x W
    output shape: B x C x H x W

    All keyword arguments are forwarded to ``_UNet``. The ``conv`` argument
    is fixed to ``Conv2d`` here, so supplying it again raises ``TypeError``.
    """

    def __init__(self, **kwargs) -> None:
        # Only the convolution block is pinned; the rest stays configurable.
        super(UNet, self).__init__(conv=Conv2d, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure delegation to the parent implementation.
        prediction = super(UNet, self).forward(x)
        return prediction
|