File size: 1,528 Bytes
8583887
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch.nn as nn
import torch
import numpy as np
from model.encoder import ImageEncoder, RobertaEncoder
import torch.nn.functional as F
class LVL(nn.Module):
    """Dual-encoder image/text model that emits unit-length embeddings.

    Wraps an ``ImageEncoder`` and a ``RobertaEncoder`` and L2-normalizes each
    encoder's output along its last dimension, so the dot product of an image
    embedding with a text embedding is their cosine similarity.

    Attributes:
        image_encoder: encoder applied to image batches.
        text_encoder: encoder applied to ``(input_ids, attention_mask)``.
        t_prime: learnable scalar parameter, initialized to ``log(0.07)``.
        b: learnable scalar parameter, initialized to ``0``.

    ``t_prime`` and ``b`` are not read anywhere in this class; presumably a
    loss defined elsewhere uses them as a logit scale (``exp(t_prime)``) and
    bias -- confirm against the training code.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order: it determines the RNG consumption
        # order of the encoders' initialization and the key order of
        # state_dict() / parameters().
        self.image_encoder = ImageEncoder()
        self.text_encoder = RobertaEncoder()
        # NOTE(review): exp(t_prime) starts at 0.07. If the loss *multiplies*
        # similarities by exp(t_prime), common inits are log(1 / 0.07) (CLIP)
        # or log(10) together with b = -10 (SigLIP); if it *divides*, this is
        # a temperature of 0.07. Verify against the caller before changing.
        initial_log_scale = np.log(0.07)
        self.t_prime = nn.Parameter(torch.ones([]) * initial_log_scale)
        self.b = nn.Parameter(torch.zeros([]))

    @staticmethod
    def _unit_length(embeddings):
        """Rescale every vector along the last dimension to unit L2 norm."""
        return F.normalize(embeddings, p=2, dim=-1)

    def get_images_features(self, images):
        """Encode ``images`` and return L2-normalized embeddings.

        The encoder output is expected to be (batch_size, EMBEDDING_DIM);
        normalization is applied over the last dimension.
        """
        return self._unit_length(self.image_encoder(images))

    def get_texts_feature(self, input_ids, attention_mask):
        """Encode tokenized text and return L2-normalized embeddings.

        The encoder output is expected to be (batch_size, EMBEDDING_DIM);
        normalization is applied over the last dimension.
        """
        return self._unit_length(self.text_encoder(input_ids, attention_mask))

    def forward(self, images, input_ids, attention_mask):
        """Embed a batch of images and a batch of texts.

        Args:
            images: image tensor, expected shape (batch_size, 3, 224, 224).
            input_ids: token id tensor, shape (batch_size, seq_length).
            attention_mask: attention mask tensor, shape
                (batch_size, seq_length).

        Returns:
            Tuple ``(image_embeddings, text_embeddings)``, each L2-normalized
            along its last dimension, ready for similarity computation.
        """
        # Images are encoded before texts, as in the tuple order below.
        return (
            self.get_images_features(images),
            self.get_texts_feature(input_ids, attention_mask),
        )