Spaces:
Sleeping
Sleeping
File size: 727 Bytes
09823ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
import torch.nn as nn
from torchvision import models
def build_model(
    num_classes: int,
    freeze_backbone: bool = True,
    *,
    weights: "str | None" = "IMAGENET1K_V1",
    dropout: float = 0.2,
) -> nn.Module:
    """Build a MobileNetV2 model prepared for fine-tuning on our custom classes.

    The backbone comes from ``torchvision.models.mobilenet_v2`` (ImageNet
    weights by default). Its classifier is replaced with
    ``Dropout(dropout) -> Linear(model.last_channel, num_classes)``.

    Args:
        num_classes: Number of disease classes; must be >= 1.
        freeze_backbone: If True, set ``requires_grad = False`` on every
            parameter of ``model.features`` (the feature extractor), so only
            the new classifier head receives gradients.
        weights: Weights spec forwarded unchanged to
            ``torchvision.models.mobilenet_v2``. Defaults to
            ``"IMAGENET1K_V1"``; pass ``None`` for random initialisation
            (e.g. offline use or tests).
        dropout: Dropout probability applied before the final linear layer.

    Returns:
        The configured ``nn.Module``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.

    Note:
        Freezing parameters does not stop BatchNorm layers inside
        ``model.features`` from updating their running statistics while the
        model is in ``train()`` mode. Put those layers in ``eval()`` mode
        during training if they must stay fully fixed.
    """
    # Fail fast on a bad class count, before any (possibly network-bound)
    # weight download. Zero would silently yield a head with no outputs.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    model = models.mobilenet_v2(weights=weights)

    if freeze_backbone:
        for param in model.features.parameters():
            param.requires_grad = False

    # Replace the classifier. The new layers are created after freezing, so
    # their parameters stay trainable regardless of freeze_backbone.
    model.classifier = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(model.last_channel, num_classes),
    )
    return model