# scold / model.py
# Hugging Face file-viewer residue (not part of the module):
#   uploader avatar: enalis
#   commit message: "Adding Models and Inference"
#   commit: 8583887 (verified)
#   file size: 1.53 kB
import torch.nn as nn
import torch
import numpy as np
from model.encoder import ImageEncoder, RobertaEncoder
import torch.nn.functional as F
class LVL(nn.Module):
    """Dual-encoder model pairing an image encoder with a text encoder.

    Each encoder's output is L2-normalized along its last dimension before
    being returned, so embeddings from the two towers can be compared
    directly (e.g. via dot product).

    Two learnable 0-dim parameters are also registered:

    * ``t_prime`` -- initialised to ``log(0.07)``.
    * ``b`` -- initialised to ``0``.

    NOTE(review): neither ``t_prime`` nor ``b`` is read by any method of
    this class; presumably they are consumed by an external loss or
    inference routine -- confirm against the callers.
    """

    def __init__(self):
        super().__init__()
        # Encoder towers; their constructors live in model.encoder.
        self.image_encoder = ImageEncoder()
        self.text_encoder = RobertaEncoder()
        # Scalar (0-dim) learnable parameters.
        self.t_prime = nn.Parameter(torch.ones([]) * np.log(0.07))
        self.b = nn.Parameter(torch.zeros([]))

    @staticmethod
    def _unit_norm(features):
        """Return *features* scaled to unit L2 norm along the last dim."""
        return F.normalize(features, p=2, dim=-1)

    def get_images_features(self, images):
        """Encode *images* and return L2-normalized embeddings.

        The encoder output is assumed to be (batch_size, EMBEDDING_DIM),
        per the original author's note -- TODO confirm.
        """
        return self._unit_norm(self.image_encoder(images))

    def get_texts_feature(self, input_ids, attention_mask):
        """Encode tokenized text and return L2-normalized embeddings.

        The encoder output is assumed to be (batch_size, EMBEDDING_DIM),
        per the original author's note -- TODO confirm.
        """
        return self._unit_norm(self.text_encoder(input_ids, attention_mask))

    def forward(self, images, input_ids, attention_mask):
        """Embed a batch of images and a batch of texts.

        Args:
            images: Image batch; documented by the original author as a
                tensor of shape (batch_size, 3, 224, 224).
            input_ids: Token ids; documented as (batch_size, seq_length).
            attention_mask: Attention mask; documented as
                (batch_size, seq_length).

        Returns:
            A tuple ``(image_embeddings, text_embeddings)``, each
            L2-normalized along the last dimension for similarity
            calculation.
        """
        img_vecs = self.get_images_features(images)
        txt_vecs = self.get_texts_feature(input_ids, attention_mask)
        return img_vecs, txt_vecs