Woleek's picture
Upload feature extractor
68a574e verified
raw
history blame contribute delete
990 Bytes
from torchvision.models import resnet50, ResNet50_Weights
from transformers import PreTrainedModel
from .config import ResnetConfig
import torch.nn as nn
class ResNet50(nn.Module):
    """ResNet-50 convolutional trunk + global average pooling + 2048 -> 768 projection.

    A 3-D input is treated as a single unbatched sample: it is given a batch
    dimension for the forward pass and that dimension is removed again on
    output, giving shape ``(768,)``. A 4-D (batched) input of batch size N
    yields ``(N, 768)``, including when N == 1.

    Args:
        pretrained: If True (default), initialize the backbone from
            torchvision's ``ResNet50_Weights.IMAGENET1K_V1``. If False, use
            random initialization and skip the weight download (useful when
            the weights will be overwritten from a checkpoint anyway).
    """

    def __init__(self, *, pretrained=True):
        super().__init__()
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        # Kept as an attribute because callers (e.g. wrappers) may delete or
        # reference ``self.cnn`` after construction.
        self.cnn = resnet50(weights=weights)
        # Drop torchvision's final avgpool and fc layers, keeping only the
        # convolutional trunk (2048 output channels, matching fc_1 below).
        self.backbone = nn.Sequential(*list(self.cnn.children())[:-2])
        # Adaptive pooling collapses any spatial size to 1x1. For a 7x7 map
        # (224x224 input) this equals AvgPool2d(kernel_size=7), but it also
        # works for other input sizes instead of failing or dropping border
        # cells. The attribute name ``flaten`` is kept for compatibility.
        self.flaten = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc_1 = nn.Linear(2048, 768)

    def forward(self, x):
        """Return the 768-d embedding for ``x``.

        Args:
            x: A single sample (3-D tensor) or a batch (4-D tensor).

        Returns:
            ``(768,)`` for a 3-D input, ``(N, 768)`` for a 4-D input.
        """
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        x = self.backbone(x)
        x = self.flaten(x)
        x = self.fc_1(x)
        # Only undo the batch dimension we added ourselves; a real batch of
        # size 1 must keep its leading dimension.
        if unbatched:
            x = x.squeeze(0)
        return x
class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``ResNet50`` embedder.

    The wrapped network is stored as ``self.model``; ``forward`` delegates to it.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        network = ResNet50()
        # ``network.backbone`` already holds the torchvision model's layers, so
        # the full ``cnn`` module is removed to avoid registering those layers
        # twice (and thus duplicate parameters/state_dict entries).
        del network.cnn
        self.model = network

    def forward(self, tensor):
        """Run the wrapped ``ResNet50`` on ``tensor`` and return its output."""
        embedding = self.model(tensor)
        return embedding