rifatramadhani's picture
wip
34e2c3f
import huggingface_hub
import pretrainedmodels
import torch
import torch.nn as nn
def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
    """
    Build a ``pretrainedmodels`` backbone with a freshly sized output head.

    Args:
        model_name (str): Key of the architecture constructor looked up in
            ``pretrainedmodels.__dict__``.
        num_classes (int): Number of outputs of the replacement linear head.
        pretrained (str or None): Weight-set name forwarded to the constructor
            (e.g. ``"imagenet"``); ``load_model`` passes ``None`` because it
            loads its own checkpoint afterwards.

    Returns:
        torch.nn.Module: The network with ``last_linear`` replaced by a new
        ``nn.Linear`` and ``avg_pool`` replaced by ``nn.AdaptiveAvgPool2d(1)``.
    """
    constructor = pretrainedmodels.__dict__[model_name]
    network = constructor(pretrained=pretrained)
    # Reuse the backbone's feature width; only the output count changes.
    in_features = network.last_linear.in_features
    network.last_linear = nn.Linear(in_features, num_classes)
    # Pooling to 1x1 keeps the pooled feature size independent of the input's
    # spatial resolution.
    network.avg_pool = nn.AdaptiveAvgPool2d(1)
    return network
def load_model(
    device,
    *,
    repo_id="public-data/yu4u-age-estimation-pytorch",
    filename="pretrained.pth",
):
    """
    Load the age estimation model with weights from the Hugging Face Hub.

    Args:
        device (torch.device or str): Device the model is moved onto.
        repo_id (str): Hub repository that holds the checkpoint.
        filename (str): Checkpoint file name inside ``repo_id``.

    Returns:
        torch.nn.Module: The model with checkpoint weights, on ``device``,
        in eval mode.
    """
    # No pre-trained weight download: load_state_dict (strict by default)
    # overwrites every parameter from the checkpoint below.
    model = get_model(model_name="se_resnext50_32x4d", pretrained=None)
    path = huggingface_hub.hf_hub_download(repo_id, filename)
    # map_location="cpu" lets a checkpoint saved from GPU tensors deserialize
    # on a CPU-only host; the whole model is moved to `device` afterwards.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model