enalis committed
Commit 8583887 · verified · 1 Parent(s): 3c3150b

Adding Models and Inference

Files changed (5)
  1. config.js +11 -0
  2. encoder.py +66 -0
  3. inference.py +33 -0
  4. model.py +37 -0
  5. pytorch_model.bin +3 -0
config.js ADDED
@@ -0,0 +1,11 @@
+ {
+   "architectures": [
+     "SCOLD"
+   ],
+   "model_type": "clip-base",
+   "image_encoder": "swin_base_patch4_window7_224",
+   "text_encoder": "roberta-base",
+   "embedding_dim": 512,
+   "t_init": 0.07,
+   "b_init": 0.0
+ }
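The fields above mirror the code added in this commit: embedding_dim matches EMBEDDING_DIM in encoder.py, and t_init / b_init match the initial values of LVL.t_prime and LVL.b in model.py. Nothing in the commit actually reads config.js yet, so the loader below is only a hedged sketch, and the load_config helper name is an assumption:

```python
import json

# Hypothetical helper (not part of this commit): read config.js, which contains
# plain JSON despite the .js extension, and return it as a dict.
def load_config(path="config.js"):
    with open(path, "r") as f:
        return json.load(f)

# Example usage:
# cfg = load_config()
# print(cfg["image_encoder"], cfg["text_encoder"], cfg["embedding_dim"])
```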
encoder.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ import torch.nn as nn
+ from transformers import RobertaModel
+ from timm import create_model
+ EMBEDDING_DIM = 512
+
+ class ImageEncoder(nn.Module):
+     def __init__(self):
+         super(ImageEncoder, self).__init__()
+         # Load the Swin Transformer backbone with features_only=True
+         self.swin = create_model("swin_base_patch4_window7_224", pretrained=True, features_only=True)
+         for param in self.swin.parameters():
+             param.requires_grad = True
+         # Feature size of the final stage (1024 channels for swin_base)
+         self.swin_output_dim = self.swin.feature_info.channels()[-1]
+
+         # FC layer projecting the flattened final feature map to the shared embedding space
+         self.fc1 = nn.Linear(self.swin_output_dim * 7 * 7, EMBEDDING_DIM)
+         nn.init.xavier_uniform_(self.fc1.weight)
+         nn.init.zeros_(self.fc1.bias)
+
+     def forward(self, x):
+         # Extract the last-stage feature map from Swin (e.g., [B, 1024, 7, 7])
+         swin_features = self.swin(x)[-1]
+
+         # Flatten the feature map to shape (B, 1024 * 7 * 7)
+         swin_features = swin_features.view(swin_features.size(0), -1)
+
+         # Project to the shared embedding space: (B, EMBEDDING_DIM)
+         output = self.fc1(swin_features)
+         return output
+
+
+ class RobertaEncoder(nn.Module):
+     def __init__(self, roberta_model_path="roberta-base"):
+         super(RobertaEncoder, self).__init__()
+         # Load the pre-trained RoBERTa model
+         self.roberta = RobertaModel.from_pretrained(roberta_model_path)
+
+         # Linear projection layer mapping RoBERTa's hidden size down to EMBEDDING_DIM
+         self.projection = nn.Linear(self.roberta.config.hidden_size, EMBEDDING_DIM)
+
+         # Initialize the projection layer weights
+         nn.init.xavier_uniform_(self.projection.weight)
+         nn.init.zeros_(self.projection.bias)
+
+         # Allow fine-tuning of the RoBERTa model
+         for param in self.roberta.parameters():
+             param.requires_grad = True
+
+     def forward(self, input_ids, attention_mask):
+         """
+         Forward pass through RoBERTa.
+         Args:
+             input_ids: Tensor of shape (batch_size, seq_length)
+             attention_mask: Tensor of shape (batch_size, seq_length)
+
+         Returns:
+             Embedding: Tensor of shape (batch_size, EMBEDDING_DIM)
+         """
+         roberta_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
+         cls_token = roberta_output.last_hidden_state[:, 0, :]  # [CLS] token
+         pooled_output = torch.mean(roberta_output.last_hidden_state, dim=1)  # Mean pooling
+
+         # Combine the [CLS] and mean-pooled sentence representations, then project
+         return self.projection(cls_token + pooled_output)
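A hedged shape check for the two encoders (assuming the pretrained swin_base_patch4_window7_224 and roberta-base weights download successfully; the batch size and captions are arbitrary placeholders, not data from this repository):

```python
import torch
from transformers import RobertaTokenizer
from encoder import ImageEncoder, RobertaEncoder

# Instantiate both encoders and verify that they project into the same 512-d space.
image_encoder = ImageEncoder().eval()
text_encoder = RobertaEncoder().eval()

images = torch.randn(2, 3, 224, 224)  # dummy batch of two 224x224 RGB images
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokens = tokenizer(["an example caption", "another example caption"],
                   return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    img_emb = image_encoder(images)                                        # (2, 512)
    txt_emb = text_encoder(tokens["input_ids"], tokens["attention_mask"])  # (2, 512)

print(img_emb.shape, txt_emb.shape)
```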
inference.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ from model import LVL
+ from transformers import RobertaTokenizer
+ from PIL import Image
+ from torchvision import transforms
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the model and the checkpoint weights
+ model = LVL()
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
+ model.to(device)
+ model.eval()
+
+ # Load the tokenizer
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+
+ # Image transform
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor()
+ ])
+
+
+ def predict(image_path, text):
+     image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
+     tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
+
+     with torch.no_grad():
+         img_feat, txt_feat = model(image, tokens["input_ids"], tokens["attention_mask"])
+         similarity = torch.matmul(img_feat, txt_feat.T).squeeze()  # cosine similarity of L2-normalized embeddings
+
+     return similarity.item()
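A minimal usage example for predict; example.jpg and the caption are placeholders, not files shipped in this commit. Because LVL L2-normalizes both embeddings, the returned score is a cosine similarity in [-1, 1]; the learned t_prime and b parameters are not applied at inference time here.

```python
# Placeholder inputs: substitute a real image path and query text.
score = predict("example.jpg", "an example caption")
print(f"image-text similarity: {score:.4f}")
```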
model.py ADDED
@@ -0,0 +1,37 @@
+ import torch.nn as nn
+ import torch
+ import numpy as np
+ from encoder import ImageEncoder, RobertaEncoder
+ import torch.nn.functional as F
+
+ class LVL(nn.Module):
+     def __init__(self):
+         super(LVL, self).__init__()
+         self.image_encoder = ImageEncoder()
+         self.text_encoder = RobertaEncoder()
+         self.t_prime = nn.Parameter(torch.ones([]) * np.log(0.07))  # learnable log-temperature (t_init = 0.07 in config.js)
+         self.b = nn.Parameter(torch.ones([]) * 0)                   # learnable bias (b_init = 0.0 in config.js)
+
+     def get_images_features(self, images):
+         image_embeddings = self.image_encoder(images)  # (batch_size, EMBEDDING_DIM)
+         image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
+         return image_embeddings
+
+     def get_texts_feature(self, input_ids, attention_mask):
+         text_embeddings = self.text_encoder(input_ids, attention_mask)  # (batch_size, EMBEDDING_DIM)
+         text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
+         return text_embeddings
+
+     def forward(self, images, input_ids, attention_mask):
+         """
+         Args:
+             images: Tensor of shape (batch_size, 3, 224, 224)
+             input_ids: Tensor of shape (batch_size, seq_length)
+             attention_mask: Tensor of shape (batch_size, seq_length)
+
+         Returns:
+             Image and text embeddings, L2-normalized for similarity calculation
+         """
+         image_embeddings = self.get_images_features(images)
+         text_embeddings = self.get_texts_feature(input_ids, attention_mask)
+         return image_embeddings, text_embeddings
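LVL carries a learnable log-temperature (t_prime, initialized to log(0.07)) and bias (b, initialized to 0), but this commit does not include a training loss that uses them. The sketch below shows one plausible use, a SigLIP-style pairwise sigmoid loss over a batch of paired embeddings; this is an assumption about the intended training objective, not code from the repository:

```python
import torch
import torch.nn.functional as F

def pairwise_sigmoid_loss(image_embeddings, text_embeddings, t_prime, b):
    """Hedged sketch: SigLIP-style loss over L2-normalized image/text embeddings.

    Assumes matching image/text pairs share the same batch index. How t_prime
    and b are actually used during training is not shown in this commit.
    """
    n = image_embeddings.size(0)
    logits = image_embeddings @ text_embeddings.T * t_prime.exp() + b  # (n, n) scaled similarities
    labels = 2 * torch.eye(n, device=logits.device) - 1                # +1 on the diagonal, -1 elsewhere
    return -F.logsigmoid(labels * logits).sum() / n
```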
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1664be0db36c8a106001016da28c94416ae51671f2c7ae683fe07e90ceaaf352
+ size 950112466