import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.encoder import ImageEncoder, RobertaEncoder


class LVL(nn.Module):
    """Dual-encoder model producing L2-normalized image and text embeddings.

    Images go through ``ImageEncoder`` and tokenized text through
    ``RobertaEncoder``. Each output is normalized to unit L2 norm along its
    last dimension, so a dot product between the two equals cosine similarity.
    """

    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = RobertaEncoder()
        # Learnable scalars that are not used in this file.
        # NOTE(review): ``t_prime`` / ``b`` look like a SigLIP-style log-temperature
        # and bias consumed by the training loss elsewhere; confirm that the
        # log(0.07) initial value matches how that loss uses ``t_prime``.
        self.t_prime = nn.Parameter(torch.ones([]) * np.log(0.07))
        self.b = nn.Parameter(torch.zeros([]))

    def get_images_features(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` and return unit-L2-norm embeddings.

        The encoder output is presumably (batch_size, EMBEDDING_DIM);
        normalization is applied along the last dimension.
        """
        return F.normalize(self.image_encoder(images), p=2, dim=-1)

    def get_texts_feature(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encode tokenized text and return unit-L2-norm embeddings.

        ``input_ids`` and ``attention_mask`` are forwarded unchanged to the
        text encoder. Its output is presumably (batch_size, EMBEDDING_DIM)
        and is normalized along the last dimension.
        """
        return F.normalize(
            self.text_encoder(input_ids, attention_mask), p=2, dim=-1
        )

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed a batch of images and a batch of texts.

        Args:
            images: Tensor of shape (batch_size, 3, 224, 224).
            input_ids: Tensor of shape (batch_size, seq_length).
            attention_mask: Tensor of shape (batch_size, seq_length).

        Returns:
            ``(image_embeddings, text_embeddings)``, both L2-normalized along
            the last dimension and ready for similarity computation.
        """
        return (
            self.get_images_features(images),
            self.get_texts_feature(input_ids, attention_mask),
        )