import torch | |
import torch.nn as nn | |
import timm | |
from config import MODEL_PATH, DEVICE | |
def load_model(model_path=None, device=None):
    """Build the Xception classifier, load its trained weights, and return it.

    A timm ``'xception'`` backbone is created without pretrained weights and
    its ``fc`` classifier is replaced by an MLP head
    (in_features -> 100 -> 50 -> 1) that ends in ``nn.Sigmoid``. The model
    therefore emits one value in (0, 1) per sample (presumably a binary
    classification score -- confirm against the training code).

    Args:
        model_path: Path to the saved ``state_dict``. When ``None`` (the
            default), ``MODEL_PATH`` from ``config`` is used.
        device: Device to load the weights onto and run the model on (for
            example ``"cpu"``, ``"cuda"``, or a ``torch.device``). When
            ``None`` (the default), ``DEVICE`` from ``config`` is used.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode
        (dropout disabled).

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) if the
            checkpoint's keys or tensor shapes do not match the architecture
            built here.
    """
    # Resolve defaults at call time rather than binding them into the signature.
    if model_path is None:
        model_path = MODEL_PATH
    if device is None:
        device = DEVICE

    # pretrained=False: every parameter is overwritten by the checkpoint
    # loaded below, so pretrained backbone weights would be discarded anyway.
    model = timm.create_model('xception', pretrained=False)

    # The layer order inside this Sequential determines the checkpoint key
    # names (fc.0.*, fc.3.*, fc.6.*). Reordering it breaks loading of
    # existing weights.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 100),
        nn.ReLU(),
        nn.Dropout(0.7),
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Dropout(0.7),
        nn.Linear(50, 1),
        nn.Sigmoid()
    )

    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict contains, instead of allowing the file to
    # run arbitrary pickled code (the keyword requires torch >= 1.13).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    # nn.Module.to moves parameters in place; eval() disables the Dropout layers.
    model.to(device).eval()
    return model