Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,251 Bytes
e70400c 34e2c3f e70400c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import huggingface_hub
import pretrainedmodels
import torch
import torch.nn as nn
def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
    """
    Build a ``pretrainedmodels`` network with a resized classification head.

    The constructor is looked up by name in the ``pretrainedmodels`` module
    namespace and called with ``pretrained`` forwarded unchanged. Two
    attributes of the resulting network are then replaced: ``avg_pool``
    becomes an adaptive average pool with output size 1, and ``last_linear``
    becomes a freshly initialised linear layer that keeps the original input
    width but emits ``num_classes`` outputs.

    Args:
        model_name (str): Name of a constructor exposed by ``pretrainedmodels``.
        num_classes (int): Output size of the new ``last_linear`` layer.
        pretrained (str): Value passed as the constructor's ``pretrained``
            keyword argument.

    Returns:
        torch.nn.Module: The network with the replaced pooling and head layers.

    Raises:
        KeyError: If ``model_name`` is not present in the ``pretrainedmodels``
            module namespace.
    """
    # vars(module) is the module's __dict__, so unknown names still raise KeyError.
    constructor = vars(pretrainedmodels)[model_name]
    network = constructor(pretrained=pretrained)

    # Swap in an adaptive pool with a fixed 1x1 output.
    network.avg_pool = nn.AdaptiveAvgPool2d(1)

    # Keep the head's input width; only the number of outputs changes.
    head_in_features = network.last_linear.in_features
    network.last_linear = nn.Linear(head_in_features, num_classes)

    return network
def load_model(device):
    """
    Loads the age estimation model from Hugging Face Hub.

    Builds an ``se_resnext50_32x4d`` network via ``get_model`` with
    ``pretrained=None``, downloads ``pretrained.pth`` from the
    ``public-data/yu4u-age-estimation-pytorch`` repository, loads it as the
    network's state dict, moves the network to ``device`` and switches it to
    evaluation mode.

    Args:
        device (torch.device): The device to load the model onto. It is
            passed straight to ``Module.to``.

    Returns:
        torch.nn.Module: The loaded model, on ``device`` and in eval mode.
    """
    model = get_model(model_name="se_resnext50_32x4d", pretrained=None)
    path = huggingface_hub.hf_hub_download(
        "public-data/yu4u-age-estimation-pytorch", "pretrained.pth"
    )
    # Deserialize onto CPU explicitly. Without map_location, torch.load tries
    # to restore each tensor to the device it was saved from, which fails on
    # CPU-only hosts for CUDA-saved checkpoints. The model is still on CPU
    # here, and the .to(device) call below does the final placement.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
|