import torch
import torch.nn as nn
import timm

from config import MODEL_PATH, DEVICE


def load_model() -> nn.Module:
    """Build the Xception classifier, load its saved weights, and return it.

    The backbone is created through ``timm`` without pretrained weights; its
    ``fc`` layer is replaced by a small fully connected head ending in a
    single sigmoid unit. The weights are then read from ``MODEL_PATH``, and
    the model is moved to ``DEVICE`` and switched to evaluation mode.

    Returns:
        The model on ``DEVICE`` in eval mode.
    """
    # pretrained=False: every weight comes from the checkpoint loaded below.
    model = timm.create_model('xception', pretrained=False)

    # Replace the stock classifier with the custom head. The layer layout
    # must match the checkpoint's state-dict keys, so it is kept exactly.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 100),
        nn.ReLU(),
        nn.Dropout(0.7),
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Dropout(0.7),
        nn.Linear(50, 1),
        nn.Sigmoid(),
    )

    # NOTE(review): torch.load unpickles the file, so MODEL_PATH must point to
    # a trusted checkpoint. Consider weights_only=True if the minimum
    # supported torch version allows it.
    # map_location lets a checkpoint saved on another device load onto DEVICE.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)

    # eval() puts the Dropout layers into inference behaviour.
    model.to(DEVICE).eval()
    return model