import huggingface_hub
import pretrainedmodels
import torch
import torch.nn as nn


def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
    """Build a ``pretrainedmodels`` backbone with a freshly sized classifier head.

    Args:
        model_name (str): Name of a model constructor exported by
            ``pretrainedmodels`` (e.g. ``"se_resnext50_32x4d"``).
        num_classes (int): Number of outputs of the replacement
            ``last_linear`` layer.
        pretrained (str | None): Weight-set name forwarded to the
            ``pretrainedmodels`` constructor (e.g. ``"imagenet"``), or
            ``None`` to build the architecture without pretrained weights.

    Returns:
        torch.nn.Module: The model with ``last_linear`` replaced by a new
        ``nn.Linear(in_features, num_classes)`` and ``avg_pool`` replaced by
        ``nn.AdaptiveAvgPool2d(1)``.

    Raises:
        KeyError: If ``model_name`` is not a name in ``pretrainedmodels``.
    """
    model = pretrainedmodels.__dict__[model_name](pretrained=pretrained)
    # Replace the classifier head so it emits ``num_classes`` outputs; the new
    # layer's weights are freshly initialised, not taken from ``pretrained``.
    dim_feats = model.last_linear.in_features
    model.last_linear = nn.Linear(dim_feats, num_classes)
    # Adaptive pooling to 1x1 keeps the head's input size fixed regardless of
    # the input image resolution (presumably replacing a fixed-kernel pool in
    # the stock backbone — confirm against pretrainedmodels).
    model.avg_pool = nn.AdaptiveAvgPool2d(1)
    return model


def load_model(device):
    """Load the age-estimation model with weights from the Hugging Face Hub.

    Args:
        device (torch.device | str): Device to place the model on. It is also
            used as ``map_location`` when deserialising the checkpoint, so a
            checkpoint saved on a GPU can be loaded on a CPU-only host.

    Returns:
        torch.nn.Module: The model on ``device``, switched to eval mode.
    """
    # Build the bare architecture (``pretrained=None`` skips the ImageNet
    # download); the Hub checkpoint provides the weights, loaded strictly below.
    model = get_model(model_name="se_resnext50_32x4d", pretrained=None)
    path = huggingface_hub.hf_hub_download(
        "public-data/yu4u-age-estimation-pytorch", "pretrained.pth"
    )
    # map_location remaps stored tensors onto ``device``. Without it, a
    # checkpoint whose tensors were saved on CUDA fails to load on machines
    # lacking that device. weights_only=True restricts unpickling to tensor
    # data instead of arbitrary Python objects.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model