# Spaces: Configuration error
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
class ResClassifier(nn.Module):
    """
    Three-layer MLP classification head.

    Two hidden blocks (Linear -> BatchNorm1d -> ReLU -> Dropout) are followed
    by a final linear layer that emits raw class logits; no softmax is applied.

    Args:
        num_classes: Number of output logits.
        in_features: Size of the incoming feature vector.
        hidden_features: Width of both hidden layers.
        dropout: Drop probability used in both hidden blocks (0.5 is the
            ``nn.Dropout`` default the layers were originally built with).
    """

    def __init__(
        self,
        num_classes: int = 14,
        *,
        in_features: int = 128,
        hidden_features: int = 64,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        # Attribute names and the layer order inside each Sequential are kept
        # stable so that state_dict keys ("fc1.0.weight", ...) do not change.
        # First hidden block: in_features -> hidden_features
        self.fc1 = self._hidden_block(in_features, hidden_features, dropout)
        # Second hidden block: keeps the hidden width
        self.fc2 = self._hidden_block(hidden_features, hidden_features, dropout)
        # Final classification layer mapping hidden features to class logits
        self.fc3 = nn.Linear(hidden_features, num_classes)

    @staticmethod
    def _hidden_block(n_in: int, n_out: int, dropout: float) -> nn.Sequential:
        """Build one Linear -> BatchNorm1d -> ReLU -> Dropout block."""
        return nn.Sequential(
            nn.Linear(n_in, n_out),
            nn.BatchNorm1d(n_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        """
        Run the features through both hidden blocks and the output layer.

        Args:
            x: Feature tensor; assumed to be (batch, in_features) -- TODO confirm
                against callers.

        Returns:
            Class logits with ``num_classes`` entries per sample.
        """
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)
class CC_model(nn.Module):
    """
    Clothing classification model built on a ResNet-50 feature extractor.

    The ResNet-50 backbone (its final ``fc`` replaced by ``nn.Identity``) feeds
    a linear layer that reduces the features to 128 dimensions. Two
    ``ResClassifier`` heads with separate parameters consume that embedding and
    their logits are averaged.

    Args:
        num_classes1: Number of classes predicted by the first head.
        num_classes2: Number of classes predicted by the second head. ``None``
            means "same as ``num_classes1``"; any other value must equal
            ``num_classes1`` because the two outputs are averaged element-wise.

    Raises:
        ValueError: If ``num_classes1`` and ``num_classes2`` differ.
    """

    def __init__(self, num_classes1=14, num_classes2=None):
        super().__init__()

        # Only a missing value falls back to num_classes1; an explicit 0 is
        # validated like any other number instead of being silently replaced.
        if num_classes2 is None:
            num_classes2 = num_classes1
        # Raise instead of assert so the check survives `python -O`.
        if num_classes1 != num_classes2:
            raise ValueError(
                "num_classes1 and num_classes2 must be equal, got "
                f"{num_classes1} and {num_classes2}"
            )
        self.num_classes = num_classes1

        # Pretrained ResNet-50 used as the feature extractor
        self.model_resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT')

        # Swap the original classification layer for an identity so the
        # backbone outputs its pooled features unchanged
        num_ftrs = self.model_resnet.fc.in_features
        self.model_resnet.fc = nn.Identity()

        # Projection from backbone features down to a 128D embedding
        self.dr = nn.Linear(num_ftrs, 128)

        # Two independent classifier heads on top of the shared embedding
        self.fc1 = ResClassifier(num_classes1)
        self.fc2 = ResClassifier(num_classes2)

    def forward(self, x, detach_feature=False):
        """
        Compute the 128D embedding and the averaged logits of both heads.

        Args:
            x: Input batch passed straight to the ResNet-50 backbone.
            detach_feature: When true, the embedding is detached before it is
                given to the heads, so gradients from the heads do not reach
                ``self.dr``. The returned embedding is the detached tensor.

        Returns:
            Tuple ``(dr_feature, output_mean)``: the 128D embedding and the
            element-wise mean of the two heads' logits.
        """
        # The backbone runs without gradient tracking, so it is never updated
        # by backpropagation.
        # NOTE(review): no_grad does not stop BatchNorm running statistics from
        # updating while the module is in train mode -- confirm whether the
        # backbone should also be kept in eval().
        with torch.no_grad():
            feature = self.model_resnet(x)

        # Project backbone features to the 128D embedding (trainable)
        dr_feature = self.dr(feature)

        if detach_feature:
            dr_feature = dr_feature.detach()

        # Both heads see the same embedding
        out1 = self.fc1(dr_feature)
        out2 = self.fc2(dr_feature)

        # Final prediction is the mean of the two heads' logits
        output_mean = (out1 + out2) / 2

        return dr_feature, output_mean