Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
class Encoder(nn.Module):
    """Backbone wrapper that exposes every intermediate activation.

    The backbone named ``tf_efficientnet_b5_ap`` is built through
    ``torch.hub.load`` from the ``efficientnet_repo`` directory that sits next
    to this file (``source='local'``, ``pretrained=False``).  Its
    ``global_pool`` and ``classifier`` attributes are then replaced with
    ``nn.Identity`` so they pass their input through unchanged.
    """

    def __init__(self):
        """Build the backbone and neutralise its pooling/classifier head."""
        super().__init__()
        basemodel_name = 'tf_efficientnet_b5_ap'
        # Resolve the hub repo relative to this file rather than the current
        # working directory.
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        basemodel = torch.hub.load(
            repo_path, basemodel_name, pretrained=False, source='local'
        )
        # Remove last layer
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()
        self.original_model = basemodel

    def forward(self, x: torch.Tensor) -> list:
        """Run ``x`` through the backbone, recording each stage's output.

        Args:
            x: Input fed to the first registered child of the backbone.

        Returns:
            A list whose first element is ``x`` itself, followed by the output
            of every direct child of ``self.original_model`` in registration
            order.  The child registered as ``'blocks'`` is not called as a
            whole: each of its own children is applied in turn and contributes
            one entry.
        """
        features = [x]
        # NOTE(review): consumers presumably index this list by position, so
        # the traversal deliberately walks ``_modules`` exactly as registered
        # to keep the list's length and order stable -- confirm with callers.
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # Unroll the container so every block yields its own entry.
                stages = module._modules.values()
            else:
                stages = (module,)
            for stage in stages:
                # Each stage consumes the most recently produced activation.
                features.append(stage(features[-1]))
        return features