File size: 990 Bytes
fc8b3f4
 
 
 
 
 
 
 
 
 
 
 
 
 
68a574e
 
fc8b3f4
 
 
68a574e
fc8b3f4
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torchvision.models import resnet50, ResNet50_Weights
from transformers import PreTrainedModel
from .config import ResnetConfig
import torch.nn as nn

class ResNet50(nn.Module):
    """ResNet-50 image encoder that projects pooled features to 768 dims.

    Wraps torchvision's ``resnet50``. Its children except the last two form
    ``backbone``. The resulting feature map is averaged over its spatial
    dimensions, flattened, and passed through a ``2048 -> 768`` linear layer
    (``fc_1``).

    Args:
        pretrained: If True (the default, matching the previous behaviour),
            initialise the torchvision model with
            ``ResNet50_Weights.IMAGENET1K_V1``. If False, the model is
            randomly initialised. This avoids fetching ImageNet weights when
            they will be overwritten from a checkpoint anyway.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.cnn = resnet50(weights=weights)
        # nn.Sequential stores references to the same child modules (no
        # copy), so deleting ``cnn`` later leaves ``backbone`` intact.
        self.backbone = nn.Sequential(*list(self.cnn.children())[:-2])
        # Adaptive pooling reduces any HxW map to 1x1. For a 7x7 map it is
        # identical to the former AvgPool2d(kernel_size=7). For other sizes
        # it no longer drops border cells, and it no longer yields a
        # flattened size that mismatches fc_1. The attribute name "flaten"
        # is kept so existing references keep working.
        self.flaten = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc_1 = nn.Linear(2048, 768)

    def forward(self, x):
        """Encode images into 768-dimensional feature vectors.

        Args:
            x: A single unbatched image ``(C, H, W)`` or a batch
                ``(N, C, H, W)``.

        Returns:
            A tensor of shape ``(768,)`` for an unbatched image, or
            ``(N, 768)`` for a batch (including ``N == 1``).
        """
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        x = self.backbone(x)
        x = self.flaten(x)
        x = self.fc_1(x)
        # Only remove the batch dimension we added ourselves. Squeezing
        # unconditionally would collapse a real batch of size 1 to (768,).
        if unbatched:
            x = x.squeeze(0)
        return x
    
class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
    """Expose the ``ResNet50`` encoder through the ``PreTrainedModel`` API.

    Uses ``ResnetConfig`` as its configuration class. The wrapped encoder is
    stored as ``model``. The encoder's full torchvision model (``cnn``) is
    removed during construction, leaving ``backbone``, ``flaten`` and
    ``fc_1``.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        encoder = ResNet50()
        # ``backbone`` already holds references to the layers it needs from
        # ``cnn``, so the full torchvision model can be dropped.
        del encoder.cnn
        self.model = encoder

    def forward(self, tensor):
        """Run ``tensor`` through the wrapped encoder and return its output."""
        features = self.model(tensor)
        return features