File size: 1,251 Bytes
e70400c
 
 
 
 
34e2c3f
e70400c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import huggingface_hub
import pretrainedmodels
import torch
import torch.nn as nn


def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
    """
    Build a ``pretrainedmodels`` backbone with a fresh classification head.

    The model's ``last_linear`` layer is replaced by a new ``nn.Linear`` with
    ``num_classes`` outputs (default PyTorch initialisation, i.e. untrained),
    and its ``avg_pool`` layer is replaced by global adaptive average pooling.

    Args:
        model_name (str): Name of a model constructor exported by
            ``pretrainedmodels`` (e.g. ``"se_resnext50_32x4d"``).
        num_classes (int): Number of outputs of the new final linear layer.
        pretrained (str | None): Name of the pre-trained weight set passed to
            the ``pretrainedmodels`` constructor (e.g. ``"imagenet"``), or
            ``None`` to build the architecture without loading those weights.

    Returns:
        torch.nn.Module: The constructed model.

    Raises:
        KeyError: If ``model_name`` is not a name exported by
            ``pretrainedmodels``.
    """
    try:
        model_fn = pretrainedmodels.__dict__[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown pretrainedmodels model name: {model_name!r}"
        ) from None
    model = model_fn(pretrained=pretrained)
    # Replace the classifier with one sized for this task; its input width
    # must match the backbone's existing feature dimension.
    dim_feats = model.last_linear.in_features
    model.last_linear = nn.Linear(dim_feats, num_classes)
    # Output size 1 = global pooling, so the head receives one feature vector
    # per sample regardless of the spatial size of the last feature map.
    # NOTE(review): presumably this replaces a fixed-kernel pool so inputs at
    # resolutions other than the architecture's default work — confirm.
    model.avg_pool = nn.AdaptiveAvgPool2d(1)
    return model


def load_model(device):
    """
    Load the age estimation model from the Hugging Face Hub.

    Builds an ``se_resnext50_32x4d`` network via ``get_model`` without
    ``pretrainedmodels`` weights, fetches ``pretrained.pth`` from the
    ``public-data/yu4u-age-estimation-pytorch`` Hub repository (the Hub client
    may reuse a locally cached copy), loads it as the model's state dict,
    moves the model to ``device`` and switches it to evaluation mode.

    Args:
        device (torch.device | str): The device to load the model onto.

    Returns:
        torch.nn.Module: The loaded model, in eval mode, on ``device``.
    """
    # Architecture only: the real weights come from the checkpoint below.
    model = get_model(model_name="se_resnext50_32x4d", pretrained=None)
    path = huggingface_hub.hf_hub_download(
        "public-data/yu4u-age-estimation-pytorch", "pretrained.pth"
    )
    # Deserialize onto the CPU so loading also works on machines without the
    # device the checkpoint was saved from (e.g. a CUDA-saved file on a
    # CPU-only host); the whole model is moved to ``device`` afterwards.
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model