|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
|
|
|
|
|
|
class VGGBlock(nn.Module): |
|
''' |
|
Defines a modified block in the VGG architecture, |
|
which includes batch normalization between convolutional layers and ReLU activations. |
|
|
|
Reference: https://poloclub.github.io/cnn-explainer/ |
|
Reference: https://d2l.ai/chapter_convolutional-modern/vgg.html |
|
|
|
Args: |
|
num_convs (int): Number of consecutive convolutional layers + ReLU activations. |
|
in_channels (int): Number of channels in the input. |
|
hidden_channels (int): Number of hidden channels between convolutional layers. |
|
out_channels (int): Number of channels in the output. |
|
''' |
|
def __init__(self, |
|
num_convs: int, |
|
in_channels: int, |
|
hidden_channels: int, |
|
out_channels: int): |
|
super().__init__() |
|
|
|
self.layers = [] |
|
|
|
for i in range(num_convs): |
|
conv_in = in_channels if i == 0 else hidden_channels |
|
conv_out = out_channels if i == num_convs-1 else hidden_channels |
|
|
|
self.layers += [ |
|
nn.Conv2d(conv_in, conv_out, kernel_size = 3, stride = 1, padding = 1), |
|
nn.BatchNorm2d(conv_out), |
|
nn.ReLU() |
|
] |
|
|
|
self.layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2)) |
|
|
|
self.vgg_blk = nn.Sequential(*self.layers) |
|
|
|
def forward(self, X: torch.Tensor) -> torch.Tensor: |
|
''' |
|
Forward pass of VGG block. |
|
|
|
Args: |
|
X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width) |
|
Returns: |
|
torch.Tensor: Output tensor of shape (batch_size, out_channels, new_height, new_width) |
|
''' |
|
|
|
return self.vgg_blk(X) |
|
|
|
class TinyVGG(nn.Module): |
|
''' |
|
Creates a simplified version of a VGG model, adapted from |
|
https://github.com/poloclub/cnn-explainer/blob/master/tiny-vgg/tiny-vgg.py. |
|
The main difference is that the hidden dimensions and number of convolutional layers |
|
remain the same across VGG blocks and the classifier's linear layers has output fewer features. |
|
|
|
Args: |
|
num_blks (int): Number of VGG blocks to put in the model |
|
num_convs (int): Number of consecutive convolutional layers + ReLU activations in each VGG block. |
|
in_channels (int): Number of channels in the input. |
|
hidden_channels (int): Number of hidden channels between convolutional layers. |
|
fc_hidden_dim (int): Number of output (hidden) features for the first linear layer of the classifer. |
|
num_classes (int): Number of class labels. |
|
|
|
''' |
|
def __init__(self, |
|
num_blks: int, |
|
num_convs: int, |
|
in_channels: int, |
|
hidden_channels: int, |
|
fc_hidden_dim: int, |
|
num_classes: int): |
|
super().__init__() |
|
|
|
self.all_blks = [] |
|
for i in range(num_blks): |
|
conv_in = in_channels if i == 0 else hidden_channels |
|
self.all_blks.append( |
|
VGGBlock(num_convs, conv_in, hidden_channels, hidden_channels) |
|
) |
|
|
|
self.vgg_body = nn.Sequential(*self.all_blks) |
|
self.classifier = nn.Sequential( |
|
nn.Flatten(), |
|
nn.LazyLinear(fc_hidden_dim), nn.ReLU(), nn.Dropout(0.5), |
|
nn.LazyLinear(num_classes) |
|
) |
|
|
|
self.vgg_body.apply(self._custom_init) |
|
self.classifier.apply(self._custom_init) |
|
|
|
def _custom_init(self, module): |
|
''' |
|
Initializes convolutional layer weights with Xavier initialization method. |
|
Initializes convolutional layer biases to zero. |
|
''' |
|
if isinstance(module, (nn.Conv2d)): |
|
nn.init.xavier_uniform_(module.weight) |
|
nn.init.zeros_(module.bias) |
|
|
|
def forward(self, X: torch.Tensor) -> torch.Tensor: |
|
''' |
|
Forward pass of the TinyVGG model. |
|
|
|
Args: |
|
X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width). |
|
|
|
Returns: |
|
torch.Tensor: Logits of shape (batch_size, num_classes). |
|
''' |
|
|
|
X = self.vgg_body(X) |
|
return self.classifier(X) |
|
|