import os

import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Multi-scale feature extractor built on a locally vendored EfficientNet.

    The backbone is loaded with ``torch.hub.load(..., source='local')`` from the
    ``efficientnet_repo`` directory that sits next to this file. Its
    ``global_pool`` and ``classifier`` children are replaced with
    ``nn.Identity`` so they no longer transform their input.

    Args:
        basemodel_name: hub entry point to load from the local repo. Defaults
            to ``'tf_efficientnet_b5_ap'``, the previously hard-coded value.
        pretrained: forwarded to the hub entry point. Defaults to ``False``,
            the previously hard-coded value.
    """

    def __init__(self, *, basemodel_name='tf_efficientnet_b5_ap', pretrained=False):
        super().__init__()
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        basemodel = torch.hub.load(
            repo_path, basemodel_name, pretrained=pretrained, source='local'
        )
        # Neutralise the pooling/classification head; only intermediate
        # features are wanted from this backbone.
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()
        self.original_model = basemodel

    def forward(self, x):
        """Run ``x`` through the backbone and collect every intermediate output.

        Returns:
            A list whose first element is ``x`` itself. It is followed by the
            output of each top-level child of the backbone, in registration
            order. The child named ``'blocks'`` is expanded, so the output of
            each of its own children is recorded individually. Because the
            head children are ``nn.Identity``, the trailing entries repeat the
            preceding feature map.
        """
        features = [x]
        # Iterate ``_modules`` directly (as nn.Sequential.forward does) so every
        # registered child is applied, in registration order, to the previous
        # output.
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # NOTE(review): assumes the hub model groups its stages under a
                # child called 'blocks' — confirm for non-default basemodel_name.
                for block in module._modules.values():
                    features.append(block(features[-1]))
            else:
                features.append(module(features[-1]))
        return features