File size: 545 Bytes
0077a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn
import timm
from config import MODEL_PATH, DEVICE

def load_model(model_path=None, device=None):
    """Build the Xception binary classifier and load its trained weights.

    The timm ``xception`` backbone is created without pretrained weights and
    its ``fc`` classifier is replaced by a small MLP head that ends in a
    single sigmoid unit. The saved state dict is then loaded into this
    architecture. The head defined here must therefore match the one that
    was used when the checkpoint was saved.

    Args:
        model_path: Path to the saved ``state_dict``. When ``None``,
            ``config.MODEL_PATH`` is used (the original behaviour).
        device: Device the checkpoint tensors are mapped to and the model is
            moved to. When ``None``, ``config.DEVICE`` is used.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the architecture built here (``load_state_dict`` is strict).
    """
    # Resolve defaults at call time so zero-argument callers keep the
    # config-driven behaviour.
    model_path = MODEL_PATH if model_path is None else model_path
    device = DEVICE if device is None else device

    model = timm.create_model('xception', pretrained=False)
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 100),
        nn.ReLU(),
        nn.Dropout(0.7),  # active only in train mode; eval() below disables it
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Dropout(0.7),
        nn.Linear(50, 1),
        nn.Sigmoid(),  # one output in [0, 1] per sample
    )

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True if the torch version allows).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()