File size: 727 Bytes
09823ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
from torchvision import models

def build_model(num_classes, freeze_backbone=True):
    """
    Build and return a MobileNetV2 model fine-tuned for our custom classes.

    Args:
        num_classes (int): Number of disease classes
        freeze_backbone (bool): If True, freeze feature extractor layers

    Returns:
        model (nn.Module)
    """
    model = models.mobilenet_v2(weights='IMAGENET1K_V1')

    if freeze_backbone:
        for param in model.features.parameters():
            param.requires_grad = False

    # Replace the classifier
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(model.last_channel, num_classes)
    )

    return model