import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class ResNet(nn.Module):
    def __init__(
        self, resnet_type="resnet18", trainable_layers=3, num_output_neurons=2
    ):
        super().__init__()

        # Dictionary to map resnet_type to the corresponding torchvision model and weights
        resnet_dict = {
            "resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
            "resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
            "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2),
            "resnet101": (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V2),
            "resnet152": (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V2),
        }

        # Ensure the provided resnet_type is valid
        if resnet_type not in resnet_dict:
            raise ValueError(
                f"Invalid resnet_type. Expected one of: {list(resnet_dict.keys())}"
            )

        # Load the specified ResNet model with pre-trained weights
        model_func, weights = resnet_dict[resnet_type]
        self.resnet = model_func(weights=weights)

        # Drop the backbone's final average-pooling and fully connected
        # layers, keeping only the convolutional feature extractor
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])

        # Re-apply global average pooling to collapse the spatial dimensions
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Number of input features to the first fully connected layer
        if resnet_type in ["resnet18", "resnet34"]:
            fc_in_features = 512
        else:
            fc_in_features = 2048

        # Simplified fully connected layers with Batch Normalization and Dropout
        self.fc1 = nn.Linear(
            fc_in_features, 128
        )  # Input features depend on the resnet type
        self.bn1 = nn.BatchNorm1d(128)  # Batch Normalization
        self.dropout1 = nn.Dropout(0.5)  # Helps prevent overfitting

        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)  # Batch Normalization
        self.dropout2 = nn.Dropout(0.5)  # Helps prevent overfitting

        self.fc3 = nn.Linear(
            64, num_output_neurons
        )  # Output layer (defaults to 2 neurons for binary classification)

        # Set the requires_grad attribute based on the number of trainable layers
        self.set_trainable_layers(trainable_layers)

    def set_trainable_layers(self, trainable_layers):
        # If trainable_layers is 0, freeze all layers
        if trainable_layers == 0:
            for param in self.resnet.parameters():
                param.requires_grad = False
        else:
            # Count the top-level children of the backbone (conv stem,
            # pooling, and the four residual stages)
            total_layers = len(list(self.resnet.children()))
            # Freeze everything except the last `trainable_layers` children
            for i, layer in enumerate(self.resnet.children()):
                if i < total_layers - trainable_layers:
                    for param in layer.parameters():
                        param.requires_grad = False
                else:
                    for param in layer.parameters():
                        param.requires_grad = True

    def forward(self, x):
        # Use the ResNet backbone
        x = self.resnet(x)

        # Global average pooling
        x = self.pool(x)

        # Flatten to (batch_size, num_features); -1 infers the feature count
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.bn1(x)
        x = self.dropout1(x)

        x = F.relu(self.fc2(x))
        x = self.bn2(x)
        x = self.dropout2(x)

        x = self.fc3(x)

        return x
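

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module): instantiate the classifier, run a dummy batch through it, and set
# up an optimizer over only the unfrozen parameters. The 224x224 resolution
# matches the ImageNet preprocessing the pretrained weights expect; the batch
# size of 4 is an arbitrary choice (BatchNorm1d needs a batch size > 1 in
# training mode). Constructing the model downloads the pretrained weights on
# first use.
if __name__ == "__main__":
    import torch

    model = ResNet(resnet_type="resnet18", trainable_layers=3, num_output_neurons=2)
    model.eval()  # eval mode keeps BatchNorm/Dropout deterministic

    dummy = torch.randn(4, 3, 224, 224)  # (batch, channels, height, width)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([4, 2])

    # Fine-tuning sketch: only parameters left trainable by
    # set_trainable_layers (plus the new fc head) receive gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=1e-4)  # lr is an assumption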