File size: 4,876 Bytes
7786bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# PyTorch for deep learning operations
import torch
import torch.nn as nn

# PyTorch multiprocessing utilities
import torch.multiprocessing

# Hugging Face Transformers models, tokenizers and image processors
from transformers import BertModel, BertTokenizer, AutoModel, AutoImageProcessor

from configs import CFG
from text_image import OneEncoder as TextImageEncoder



class AlignmentLayer(nn.Module):
    """Projection head mapping encoder features into the shared embedding space.

    Applies ``Linear(input_dim -> projection_dim) -> GELU ->
    Linear(projection_dim -> projection_dim) -> Dropout -> LayerNorm`` over the
    last dimension of its input.

    Call the module instance (``layer(x)``): ``nn.Module.__call__`` dispatches to
    :meth:`forward` and also runs any registered forward hooks.

    :param input_dim: size of the last dimension of the incoming features.
    :param projection_dim: size of the output embedding; the default is read from
        ``CFG.projection_dim`` once, when the class is defined.
    :param dropout_rate: dropout probability applied before normalization; the
        default is read from ``CFG.dropout_rate`` once, when the class is defined.
    """

    def __init__(self, input_dim=768, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args,
                 **kwargs):
        super(AlignmentLayer, self).__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        """Project ``inputs`` (last dimension ``input_dim``) to ``projection_dim``.

        :param inputs: tensor whose last dimension equals ``input_dim``.
        :return: tensor with the same leading dimensions and last dimension
            ``projection_dim``, layer-normalized over that last dimension.
        """
        x = self.linear_layer1(inputs)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        return self.normalization_layer(x)


class RadioEncoder(nn.Module):
    """Pretrained image backbone followed by an :class:`AlignmentLayer` projection.

    The backbone is loaded with ``AutoModel.from_pretrained(model_name)``. Its
    ``last_hidden_state`` is projected to ``projection_dim``. The backbone's
    parameters are frozen unless ``trainable=True``. The alignment layer is not
    affected by that flag and stays trainable.

    :param model_name: Hugging Face model identifier (default ``CFG.radio_name``).
    :param projection_dim: output size of the alignment layer (default ``CFG.projection_dim``).
    :param trainable: whether the pretrained backbone's parameters receive gradients.
    :param dropout_rate: dropout used inside the alignment layer (default ``CFG.dropout_rate``).
    """

    def __init__(self, model_name=CFG.radio_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(RadioEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_encoder = AutoModel.from_pretrained(self.model_name)
        self.alignment_layer = AlignmentLayer(
            input_dim=self.pretrained_encoder.config.hidden_size,
            projection_dim=self.projection_dim,
            dropout_rate=self.dropout_rate)
        # Freeze (or unfreeze) only the pretrained image backbone; the alignment
        # layer keeps requires_grad=True.
        self.pretrained_encoder.requires_grad_(self.trainable)

    def forward(self, inputs):
        """Encode images and project every backbone token to ``projection_dim``.

        :param inputs: pixel tensor accepted by the backbone — presumably the
            ``pixel_values`` produced by the matching image processor; confirm.
        :return: the alignment layer applied to ``last_hidden_state`` (per-token
            embeddings whose last dimension is ``projection_dim``).
        """
        hidden_states = self.pretrained_encoder(inputs).last_hidden_state
        return self.alignment_layer(hidden_states)


class ModalityTokenEncoder(nn.Module):
    """Holds a learnable "radio" modality token of shape ``(token_size, projection_dim)``.

    The token is initialised from a zero-mean normal distribution. Its standard
    deviation is itself drawn uniformly from ``[0.1, 0.6)`` at construction time.

    :param projection_dim: embedding width of the token (default ``CFG.projection_dim``).
    :param token_size: number of token rows (default ``CFG.token_size``).
    :param device: device the token tensor is created on before being wrapped
        as a parameter.
    """

    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
        super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        # Random standard deviation in [0.1, 0.6): rand() is U[0, 1), scaled and shifted.
        radio_std = torch.rand(1) * 0.5 + 0.1
        self.radio_token = nn.Parameter(torch.normal(mean=0, std=radio_std.item(),
                                                     size=(self.token_size, self.projection_dim)).to(self.device))

    def forward(self):
        """Return the learnable modality token parameter itself (not a copy)."""
        return self.radio_token


class OneEncoder(nn.Module):
    """Radiology-image branch attached to a frozen, pretrained text-image OneEncoder.

    Images are encoded by ``radio_encoder``. The resulting features, together
    with a learnable modality token, are passed through the text-image model's
    ``universal_projection_encoder``. The text-image model is loaded from
    ``checkpoint`` and all of its parameters are frozen. The radio-side modules
    and ``temperature`` are not touched by that freeze.

    :param device: device the sub-modules, parameters and input tensors are moved to.
    :param modality_token_encoder: module returning the modality token; a fresh
        ``ModalityTokenEncoder`` is built per instance when ``None``.
    :param checkpoint: identifier passed to ``TextImageEncoder.from_pretrained``.
    :param radio_processor: image processor applied to PIL inputs in
        :meth:`encode_radio`; ``AutoImageProcessor.from_pretrained("microsoft/rad-dino")``
        is loaded per instance when ``None``.
    :param sample_rate: stored as ``self.sample_rate``; not used within this class.
    :param radio_encoder: image backbone + projection; a fresh ``RadioEncoder()`` is
        built per instance when ``None``.
    """

    def __init__(self, device='cpu', modality_token_encoder=None,
                 checkpoint="bilalfaye/OneEncoder-text-image",
                 radio_processor=None,
                 sample_rate=CFG.sample_rate, radio_encoder=None, *args, **kwargs):
        super(OneEncoder, self).__init__(*args, **kwargs)

        self.device = device
        self.checkpoint = checkpoint

        # Build defaults per instance. Constructing them in the signature would run
        # (and download models) at import time, and every OneEncoder would share
        # the same trainable parameters.
        if modality_token_encoder is None:
            modality_token_encoder = ModalityTokenEncoder(device=self.device)
        if radio_processor is None:
            radio_processor = AutoImageProcessor.from_pretrained("microsoft/rad-dino")
        if radio_encoder is None:
            radio_encoder = RadioEncoder()

        # Assigning `.device` alone does not move parameters; `.to()` does, so the
        # modality token lives on the same device that encode_radio sends inputs to.
        self.modality_token_encoder = modality_token_encoder.to(self.device)
        self.modality_token_encoder.device = self.device
        self.text_image_encoder = TextImageEncoder(device=self.device)
        # NOTE(review): return value is ignored, so this presumably loads the
        # pretrained weights in place — confirm against TextImageEncoder.
        self.text_image_encoder.from_pretrained(self.checkpoint)
        self.radio_processor = radio_processor
        self.sample_rate = sample_rate
        # Keep the radio encoder on the same device as the inputs it receives.
        self.radio_encoder = radio_encoder.to(self.device)
        # Learnable scalar initialised to 0.07.
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))

        # Freeze the pretrained text-image model.
        for parameter in self.text_image_encoder.parameters():
            parameter.requires_grad = False

    def encode_radio(self, pil_radios=None, radios=None):
        """
        Encode radiology images through the radio encoder and the shared projection encoder.

        :param pil_radios: list of pillow images, preprocessed with ``self.radio_processor``.
            Takes precedence over ``radios`` when both are given.
        :param radios: already-preprocessed image tensor, used when ``pil_radios`` is None.
        :return: ``last_hidden_state`` of ``universal_projection_encoder`` (tensor).
        :raises ValueError: if neither ``pil_radios`` nor ``radios`` is provided.
        """
        if pil_radios is not None:
            tensors = self.radio_processor(pil_radios, return_tensors="pt")["pixel_values"].to(self.device)
        elif radios is not None:
            tensors = radios.to(self.device)
        else:
            raise ValueError("encode_radio() requires either `pil_radios` or `radios`.")
        features = self.radio_encoder(tensors)
        radio_token = self.modality_token_encoder()
        # The shared projection encoder receives the image features together with the
        # learnable radio modality token (exact combination is defined in TextImageEncoder).
        outputs = self.text_image_encoder.universal_projection_encoder([features, radio_token]).last_hidden_state
        return outputs