Spaces:
Running
on
Zero
Running
on
Zero
import huggingface_hub | |
import pretrainedmodels | |
import torch | |
import torch.nn as nn | |
def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
    """
    Build a ``pretrainedmodels`` backbone with a freshly sized output layer.

    The model's ``last_linear`` layer is replaced by a new ``nn.Linear`` with
    ``num_classes`` outputs, and its ``avg_pool`` attribute is replaced by
    ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        model_name (str): Name of a model factory exported by
            ``pretrainedmodels`` (e.g. ``"se_resnext50_32x4d"``).
        num_classes (int): Number of outputs of the new final linear layer.
        pretrained (str | None): Name of the pretrained weight set passed to
            the factory (e.g. ``"imagenet"``), or ``None`` to build the model
            without pretrained weights (the ``pretrainedmodels`` convention).

    Returns:
        torch.nn.Module: The model with its final layer and pooling replaced.

    Raises:
        KeyError: If ``model_name`` is not exported by ``pretrainedmodels``.
    """
    try:
        factory = pretrainedmodels.__dict__[model_name]
    except KeyError:
        # Same exception type as before, but say what went wrong and what
        # names would have worked instead of just echoing the bad key.
        available = ", ".join(getattr(pretrainedmodels, "model_names", []))
        raise KeyError(
            f"Unknown pretrainedmodels model {model_name!r}; available: {available}"
        ) from None
    model = factory(pretrained=pretrained)

    # Replace the original classifier with one sized for this task; its input
    # width must match the feature dimension the backbone already produces.
    dim_feats = model.last_linear.in_features
    model.last_linear = nn.Linear(dim_feats, num_classes)

    # An adaptive pool reduces any spatial size to 1x1, presumably so the
    # network is not tied to the input resolution the original pool was sized
    # for. NOTE(review): ``avg_pool`` matches the SENet family in
    # pretrainedmodels; other backbones may name their pooling layer
    # differently, in which case this assignment has no effect — confirm
    # before using another model_name.
    model.avg_pool = nn.AdaptiveAvgPool2d(1)
    return model
def load_model(device):
    """
    Build the age-estimation network and load its weights from the Hugging Face Hub.

    The model is an ``se_resnext50_32x4d`` built by ``get_model`` with its
    default ``num_classes``. The state dict is read from ``pretrained.pth`` in
    the ``public-data/yu4u-age-estimation-pytorch`` repository. The model is
    then moved to ``device`` and switched to evaluation mode.

    Args:
        device (torch.device | str): The device to place the model on.

    Returns:
        torch.nn.Module: The model, in eval mode, on ``device``.
    """
    # Build the architecture only: load_state_dict (strict by default)
    # overwrites every parameter, so fetching ImageNet weights would be wasted.
    model = get_model(model_name="se_resnext50_32x4d", pretrained=None)
    path = huggingface_hub.hf_hub_download(
        "public-data/yu4u-age-estimation-pytorch", "pretrained.pth"
    )
    # Deserialize onto the CPU. The model still lives on the CPU here, and
    # without map_location torch.load restores tensors onto the device they
    # were saved from, which fails on hosts lacking that device (e.g.
    # CUDA-saved weights on a CPU-only machine). weights_only=True restricts
    # unpickling to tensors and plain containers.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model