CropGuard / src /model /architecture.py
mitraarka27's picture
πŸš€ Initial full clean push to Hugging Face
09823ea
raw
history blame contribute delete
727 Bytes
import torch
import torch.nn as nn
from torchvision import models
def build_model(num_classes, freeze_backbone=True):
"""
Build and return a MobileNetV2 model fine-tuned for our custom classes.
Args:
num_classes (int): Number of disease classes
freeze_backbone (bool): If True, freeze feature extractor layers
Returns:
model (nn.Module)
"""
model = models.mobilenet_v2(weights='IMAGENET1K_V1')
if freeze_backbone:
for param in model.features.parameters():
param.requires_grad = False
# Replace the classifier
model.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(model.last_channel, num_classes)
)
return model