|
from torchvision.models import resnet50, ResNet50_Weights |
|
from transformers import PreTrainedModel |
|
from .config import ResnetConfig |
|
import torch.nn as nn |
|
|
|
class ResNet50(nn.Module):
    """ResNet-50 feature extractor that projects images to 768-d vectors.

    The network is a torchvision ``resnet50`` initialised with the
    ``IMAGENET1K_V1`` weights, with its last two child modules removed,
    followed by a 7x7 average pool, a flatten and a linear layer mapping
    2048 features to 768.
    """

    def __init__(self):
        super().__init__()
        # Full torchvision network. It is kept as an attribute (callers may
        # delete it later) and shares its layers with ``self.backbone``.
        self.cnn = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Everything except the last two children of the torchvision model
        # (in torchvision's ResNet these are the global pool and the
        # classification layer), so the output is a spatial feature map.
        self.backbone = nn.Sequential(*list(self.cnn.children())[:-2])
        # NOTE: attribute name (including its spelling) is kept as-is so
        # existing checkpoints and external attribute access stay valid.
        self.flaten = nn.Sequential(nn.AvgPool2d(kernel_size=7), nn.Flatten())
        self.fc_1 = nn.Linear(2048, 768)

    def forward(self, x):
        """Compute the projected features for ``x``.

        Args:
            x: Either a rank-3 tensor (treated as a single sample; a batch
                dimension is added internally) or a rank-4 batched tensor.
                Presumably laid out as ``(C, H, W)`` / ``(N, C, H, W)``;
                the 7x7 pooling followed by ``Linear(2048, 768)`` implies
                the backbone output must pool down to 2048 values.

        Returns:
            A tensor of shape ``(768,)`` for a rank-3 input, otherwise
            ``(N, 768)``. The batch dimension is preserved for batched
            inputs even when ``N == 1``.
        """
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)

        x = self.backbone(x)
        x = self.flaten(x)
        x = self.fc_1(x)

        # Only drop the batch dimension that this method added itself; an
        # unconditional squeeze would also collapse genuine batches of one.
        if unbatched:
            x = x.squeeze(0)
        return x
|
|
|
class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the ``ResNet50`` feature extractor.

    Uses ``ResnetConfig`` as its configuration class and delegates the
    forward pass to the wrapped network stored in ``self.model``.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the wrapped network.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
        """
        super().__init__(config)

        extractor = ResNet50()
        # Drop the attribute holding the complete torchvision network; the
        # layers the extractor actually uses remain reachable through its
        # ``backbone`` attribute.
        delattr(extractor, "cnn")
        self.model = extractor

    def forward(self, tensor):
        """Return the wrapped network's output for ``tensor``."""
        features = self.model(tensor)
        return features