import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class ResNet(nn.Module):
    """ResNet backbone with a small fully connected classification head."""

    def __init__(
        self, resnet_type="resnet18", trainable_layers=3, num_output_neurons=2
    ):
        super(ResNet, self).__init__()
        # Map resnet_type to the corresponding torchvision constructor and weights
        resnet_dict = {
            "resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
            "resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
            "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2),
            "resnet101": (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V2),
            "resnet152": (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V2),
        }
        # Ensure the provided resnet_type is valid
        if resnet_type not in resnet_dict:
            raise ValueError(
                f"Invalid resnet_type. Expected one of: {list(resnet_dict.keys())}"
            )
        # Load the specified ResNet model with pre-trained weights
        model_func, weights = resnet_dict[resnet_type]
        self.resnet = model_func(weights=weights)
        # Drop the final average-pooling and fully connected layers ([:-2]),
        # keeping only the convolutional backbone
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])
        # Global average pooling reduces each feature map to a single value
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Backbone output channels: 512 for the BasicBlock variants
        # (resnet18/34), 2048 for the Bottleneck variants (resnet50/101/152)
        if resnet_type in ["resnet18", "resnet34"]:
            fc_in_features = 512
        else:
            fc_in_features = 2048
        # Fully connected head with Batch Normalization and Dropout
        self.fc1 = nn.Linear(fc_in_features, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.dropout1 = nn.Dropout(0.5)  # Helps prevent overfitting
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout2 = nn.Dropout(0.5)  # Helps prevent overfitting
        # Output layer: num_output_neurons=2 by default for binary classification
        self.fc3 = nn.Linear(64, num_output_neurons)
        # Freeze all but the last `trainable_layers` backbone modules
        self.set_trainable_layers(trainable_layers)

    def set_trainable_layers(self, trainable_layers):
        # The backbone's top-level children are conv1, bn1, relu, maxpool,
        # layer1..layer4 (8 modules); only the last `trainable_layers` of
        # these keep requires_grad=True.
        if trainable_layers == 0:
            # Freeze the entire backbone
            for param in self.resnet.parameters():
                param.requires_grad = False
        else:
            total_layers = len(list(self.resnet.children()))
            # Freeze everything except the last `trainable_layers` modules
            for i, layer in enumerate(self.resnet.children()):
                if i < total_layers - trainable_layers:
                    for param in layer.parameters():
                        param.requires_grad = False
                else:
                    for param in layer.parameters():
                        param.requires_grad = True

    def forward(self, x):
        # Convolutional backbone: (N, 3, H, W) -> (N, C, H', W')
        x = self.resnet(x)
        # Global average pooling: (N, C, H', W') -> (N, C, 1, 1)
        x = self.pool(x)
        # Flatten to (N, C) for the fully connected head
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.bn1(x)
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.bn2(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        return x
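

# Minimal usage sketch, assuming 224x224 RGB inputs (the adaptive pooling
# makes other sizes work too): build the default resnet18 variant, run a
# dummy batch through it, and report which backbone modules were left
# trainable by set_trainable_layers.
if __name__ == "__main__":
    import torch

    model = ResNet(resnet_type="resnet18", trainable_layers=3)
    model.eval()  # BatchNorm1d requires batch size > 1 in training mode
    with torch.no_grad():
        dummy = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 2])
    for i, child in enumerate(model.resnet.children()):
        trainable = any(p.requires_grad for p in child.parameters())
        # Parameter-free modules (ReLU, MaxPool) always report as frozen
        print(f"{i}: {type(child).__name__} -> {'trainable' if trainable else 'frozen'}")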