# clothing1M / ResNet_for_CC.py
# Uploaded by Moditha24 ("Upload 4 files", commit 6ec35ea, verified).
import torch
import torch.nn as nn
import torchvision.models as models
class ResClassifier(nn.Module):
    """
    Three-layer MLP classification head.

    Two hidden blocks of Linear -> BatchNorm1d -> ReLU -> Dropout are followed
    by a final Linear layer that produces class logits. The defaults reproduce
    the original fixed 128 -> 64 -> 64 -> num_classes architecture, including
    its submodule names, so existing checkpoints load unchanged.

    Args:
        num_classes: Number of output logits.
        in_features: Width of the input feature vectors (default 128).
        hidden_dim: Width of both hidden layers (default 64).
        dropout: Dropout probability in each hidden block (default 0.5, the
            ``nn.Dropout()`` default the original layers used).

    Note:
        ``BatchNorm1d`` in training mode needs more than one sample per batch.
    """

    def __init__(self, num_classes=14, in_features=128, hidden_dim=64, dropout=0.5):
        super().__init__()
        # First hidden block: in_features -> hidden_dim.
        self.fc1 = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim, affine=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )
        # Second hidden block keeps the width at hidden_dim.
        self.fc2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, affine=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )
        # Output layer: hidden_dim -> class logits.
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """
        Map a 2-D feature batch ``(N, in_features)`` to logits ``(N, num_classes)``.
        """
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)
class CC_model(nn.Module):
    """
    Clothing classification model built on a ResNet-50 backbone.

    The ResNet-50 (with its original ``fc`` head replaced by ``nn.Identity``)
    runs without gradient tracking, so it acts as a fixed feature extractor.
    Its features are projected to 128-D by the trainable ``self.dr`` layer
    and fed to two independent ``ResClassifier`` heads, whose logits are
    averaged.

    Args:
        num_classes1: Number of classes predicted by the first head.
        num_classes2: Number of classes predicted by the second head. Defaults
            to ``num_classes1`` when ``None``; must equal ``num_classes1``.
        weights: Forwarded to ``torchvision.models.resnet50``. The default
            selects the pretrained ``ResNet50_Weights.DEFAULT`` weights, as
            before. Pass ``None`` for a randomly initialised backbone, e.g.
            when a checkpoint will be loaded afterwards.

    Raises:
        ValueError: If ``num_classes1`` and ``num_classes2`` differ.
    """

    def __init__(
        self,
        num_classes1=14,
        num_classes2=None,
        weights="ResNet50_Weights.DEFAULT",
    ):
        super().__init__()
        # Substitute the default only when the argument is really missing;
        # a truthiness test would also replace an explicit 0.
        if num_classes2 is None:
            num_classes2 = num_classes1
        # The two heads are averaged in forward(), so they must share a label
        # set. Raise rather than assert: asserts disappear under `python -O`.
        if num_classes1 != num_classes2:
            raise ValueError(
                f"num_classes1 ({num_classes1}) and num_classes2 "
                f"({num_classes2}) must be equal"
            )
        self.num_classes = num_classes1

        # ResNet-50 backbone. Swapping `fc` for Identity makes it return the
        # `fc.in_features`-wide pooled features instead of ImageNet logits.
        self.model_resnet = models.resnet50(weights=weights)
        num_ftrs = self.model_resnet.fc.in_features
        self.model_resnet.fc = nn.Identity()

        # Trainable projection of the backbone features down to 128-D.
        self.dr = nn.Linear(num_ftrs, 128)

        # Two independent classification heads over the 128-D features.
        self.fc1 = ResClassifier(num_classes1)
        self.fc2 = ResClassifier(num_classes2)

    def forward(self, x, detach_feature=False):
        """
        Run the backbone, the 128-D projection and both classifier heads.

        Args:
            x: Input batch for the backbone (presumably images shaped
               (N, 3, H, W) -- TODO confirm against the data pipeline).
            detach_feature: If True, the projected features are detached, so
                losses on the heads' outputs do not back-propagate into
                ``self.dr``.

        Returns:
            Tuple ``(dr_feature, output_mean)``: the 128-D projected features
            and the element-wise mean of the two heads' logits.
        """
        # Only the backbone runs without autograd (frozen feature extractor).
        # Wrapping the rest of forward() too would cut gradients to
        # `dr`/`fc1`/`fc2` and make `detach_feature` a no-op.
        with torch.no_grad():
            feature = self.model_resnet(x)
        # NOTE(review): no_grad does not stop the backbone's BatchNorm layers
        # from updating running statistics in train() mode; call
        # `self.model_resnet.eval()` if a completely frozen backbone is wanted.

        dr_feature = self.dr(feature)
        if detach_feature:
            dr_feature = dr_feature.detach()

        out1 = self.fc1(dr_feature)
        out2 = self.fc2(dr_feature)
        output_mean = (out1 + out2) / 2
        return dr_feature, output_mean