File size: 1,528 Bytes
8583887 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch.nn as nn
import torch
import numpy as np
from model.encoder import ImageEncoder, RobertaEncoder
import torch.nn.functional as F
class LVL(nn.Module):
    """Dual-encoder image/text model producing L2-normalised embeddings.

    Wraps an image encoder and a text encoder, and owns two learnable scalar
    parameters, ``t_prime`` and ``b``. Neither scalar is used by
    :meth:`forward`. NOTE(review): presumably a loss defined elsewhere
    consumes them (the names resemble SigLIP's ``t'`` / ``b``) -- TODO confirm.

    Args:
        image_encoder: Optional module mapping ``images`` to embeddings.
            Defaults to a freshly constructed ``ImageEncoder()``.
        text_encoder: Optional module/callable taking
            ``(input_ids, attention_mask)`` and returning embeddings.
            Defaults to a freshly constructed ``RobertaEncoder()``.
        init_t: Initial value of ``exp(t_prime)``. ``t_prime`` is stored in
            log space, i.e. initialised to ``log(init_t)``. Must be > 0.
        init_b: Initial value of the scalar parameter ``b``.

    Raises:
        ValueError: If ``init_t`` is not strictly positive.
    """

    def __init__(
        self,
        *,
        image_encoder=None,
        text_encoder=None,
        init_t: float = 0.07,
        init_b: float = 0.0,
    ):
        super().__init__()
        # Reject values whose log would be -inf/NaN before building the
        # (potentially large) encoders.
        if not init_t > 0:
            raise ValueError(f"init_t must be > 0 (t_prime = log(init_t)), got {init_t!r}")
        # Default encoders are only built when none are injected.
        self.image_encoder = image_encoder if image_encoder is not None else ImageEncoder()
        self.text_encoder = text_encoder if text_encoder is not None else RobertaEncoder()
        # 0-dim parameters in the default float dtype; names and registration
        # order are unchanged so previously saved state_dicts still load.
        self.t_prime = nn.Parameter(torch.tensor(float(np.log(init_t))))
        self.b = nn.Parameter(torch.tensor(float(init_b)))

    # NOTE(review): the two getters are named inconsistently
    # (images_features vs texts_feature); kept as-is for existing callers.
    def get_images_features(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` and L2-normalise each embedding along the last dim.

        Args:
            images: Batch passed unchanged to ``self.image_encoder``.

        Returns:
            Encoder output divided by its L2 norm over ``dim=-1``
            (``F.normalize`` clamps the norm, so all-zero rows stay zero).
        """
        image_embeddings = self.image_encoder(images)  # expected (batch_size, EMBEDDING_DIM)
        return F.normalize(image_embeddings, p=2, dim=-1)

    def get_texts_feature(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode token ids and L2-normalise each embedding along the last dim.

        Args:
            input_ids: Token ids passed to ``self.text_encoder``.
            attention_mask: Mask passed to ``self.text_encoder`` alongside
                ``input_ids``.

        Returns:
            Encoder output divided by its L2 norm over ``dim=-1``.
        """
        text_embeddings = self.text_encoder(input_ids, attention_mask)  # expected (batch_size, EMBEDDING_DIM)
        return F.normalize(text_embeddings, p=2, dim=-1)

    def forward(self, images, input_ids, attention_mask):
        """
        Args:
            images: Tensor of shape (batch_size, 3, 224, 224)
            input_ids: Tensor of shape (batch_size, seq_length)
            attention_mask: Tensor of shape (batch_size, seq_length)
        Returns:
            Tuple ``(image_embeddings, text_embeddings)``, each L2-normalised
            along the last dimension for similarity calculation.
        """
        image_embeddings = self.get_images_features(images)
        text_embeddings = self.get_texts_feature(input_ids, attention_mask)
        return image_embeddings, text_embeddings
|