File size: 545 Bytes
0077a91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
import torch.nn as nn
import timm
from config import MODEL_PATH, DEVICE
def _build_head(in_features):
    """Return the MLP head that replaces the backbone's ``fc`` layer.

    Layout: Linear(in_features, 100) -> ReLU -> Dropout(0.7)
            -> Linear(100, 50) -> ReLU -> Dropout(0.7)
            -> Linear(50, 1) -> Sigmoid.
    It must match the head the checkpoint was trained with, or
    ``load_state_dict`` will reject the weights.
    """
    return nn.Sequential(
        nn.Linear(in_features, 100),
        nn.ReLU(),
        nn.Dropout(0.7),
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Dropout(0.7),
        nn.Linear(50, 1),
        nn.Sigmoid(),
    )


def load_model(model_path=None, device=None):
    """Build the Xception-based classifier and load its trained weights.

    A timm ``xception`` backbone is created without pretrained weights. Its
    ``fc`` layer is replaced by a small MLP head ending in a single sigmoid
    unit. The assumption (confirm against the timm version in use) is that
    the backbone's forward pass routes pooled features through ``fc``, so
    each input yields one value in (0, 1).

    Args:
        model_path: Path to a checkpoint containing the model ``state_dict``.
            Defaults to ``config.MODEL_PATH``, looked up at call time.
        device: Device used both as ``map_location`` for the checkpoint and
            as the target of ``model.to``. Defaults to ``config.DEVICE``,
            looked up at call time.

    Returns:
        The model on ``device``, in eval mode (dropout disabled).
    """
    # Resolve defaults at call time rather than in the signature, so the
    # module-level config names are read exactly when they used to be.
    if model_path is None:
        model_path = MODEL_PATH
    if device is None:
        device = DEVICE

    model = timm.create_model('xception', pretrained=False)
    model.fc = _build_head(model.fc.in_features)

    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints; on torch >= 1.13 consider weights_only=True.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model
|