import torch import torch.nn as nn import torchvision.models as models class ResClassifier(nn.Module): """ A classifier with two fully connected layers followed by a final linear layer. Uses BatchNorm, ReLU activations, and Dropout for better generalization. """ def __init__(self, num_classes=14): super(ResClassifier, self).__init__() # First fully connected layer: reduces 128D features to 64D self.fc1 = nn.Sequential( nn.Linear(128, 64), nn.BatchNorm1d(64, affine=True), nn.ReLU(inplace=True), nn.Dropout() ) # Second fully connected layer: retains 64D features self.fc2 = nn.Sequential( nn.Linear(64, 64), nn.BatchNorm1d(64, affine=True), nn.ReLU(inplace=True), nn.Dropout() ) # Final classification layer mapping 64D features to class logits self.fc3 = nn.Linear(64, num_classes) def forward(self, x): """ Forward pass through the classifier. Returns class logits after two hidden layers. """ x = self.fc1(x) # First FC layer x = self.fc2(x) # Second FC layer output = self.fc3(x) # Final classification layer return output class CC_model(nn.Module): """ Clothing Classification Model based on ResNet50. Extracts deep features and uses two independent classifiers for predictions. """ def __init__(self, num_classes1=14, num_classes2=None): super(CC_model, self).__init__() # If num_classes2 is not specified, default to num_classes1 num_classes2 = num_classes2 if num_classes2 else num_classes1 assert num_classes1 == num_classes2 # Ensure both classifiers predict the same categories self.num_classes = num_classes1 # Load a pretrained ResNet-50 model as the feature extractor self.model_resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT') # Remove ResNet's original classification layer to use as a feature extractor num_ftrs = self.model_resnet.fc.in_features self.model_resnet.fc = nn.Identity() # Identity layer keeps feature dimensions # Additional transformation layer reducing feature size to 128D self.dr = nn.Linear(num_ftrs, 128) # Two independent classifiers self.fc1 = ResClassifier(num_classes1) self.fc2 = ResClassifier(num_classes1) def forward(self, x, detach_feature=False): """ Forward pass through the model. Extracts deep features from ResNet and processes them through classifiers. """ with torch.no_grad(): # Extract deep features using ResNet-50 (without its original classification head) feature = self.model_resnet(x) # Generate transformed features (128D) using the custom linear layer dr_feature = self.dr(feature) if detach_feature: dr_feature = dr_feature.detach() # Detach feature for non-trainable forward pass # Pass features through two independent classifiers out1 = self.fc1(dr_feature) out2 = self.fc2(dr_feature) # Compute the mean prediction from both classifiers output_mean = (out1 + out2) / 2 return dr_feature, output_mean # Returning feature embeddings and final prediction