enalis committed
Commit df511be · verified · 1 Parent(s): c5c6e3d

Update encoder.py

Files changed (1)
  1. encoder.py +101 -66
encoder.py CHANGED
@@ -1,66 +1,101 @@
- import torch
- import torch.nn as nn
- from transformers import CLIPTextModel, RobertaModel, CLIPVisionModel
- from timm import create_model
- EMBEDDING_DIM = 512
- class ImageEncoder(nn.Module):
-     def __init__(self):
-         super(ImageEncoder, self).__init__()
-         # Load the Swin Transformer with features_only=True
-         self.swin = create_model("swin_base_patch4_window7_224", pretrained=True, features_only=True)
-         for param in self.swin.parameters():
-             param.requires_grad = True
-         # Get the feature size of the final stage
-         self.swin_output_dim = self.swin.feature_info.channels()[-1]  # Last stage: 1024 channels
-
-         # Define FC layer
-         self.fc1 = nn.Linear(self.swin_output_dim * 7 * 7, EMBEDDING_DIM)  # Flattened input size
-         nn.init.xavier_uniform_(self.fc1.weight)
-         nn.init.zeros_(self.fc1.bias)
-
-
-     def forward(self, x):
-         # Extract features from Swin
-         swin_features = self.swin(x)[-1]  # Use the last stage feature map (e.g., [B, 1024, 7, 7])
-
-         # Flatten feature map
-         swin_features = swin_features.view(swin_features.size(0), -1)  # Shape: (B, 1024*7*7)
-
-         # Pass through FC layer
-         output = self.fc1(swin_features)  # Shape: (B, embedding_dim)
-         return output
-
- from transformers import RobertaModel
-
- class RobertaEncoder(nn.Module):
-     def __init__(self, roberta_model_path="roberta-base"):
-         super(RobertaEncoder, self).__init__()
-         # Load pre-trained RoBERTa model
-         self.roberta = RobertaModel.from_pretrained(roberta_model_path)
-
-         # Add a linear projection layer to reduce dimensionality
-         self.projection = nn.Linear(self.roberta.config.hidden_size, EMBEDDING_DIM)
-
-         # Initialize the projection layer weights
-         nn.init.xavier_uniform_(self.projection.weight)
-         nn.init.zeros_(self.projection.bias)
-
-         # Allow fine-tuning of the RoBERTa model
-         for param in self.roberta.parameters():
-             param.requires_grad = True
-
-     def forward(self, input_ids, attention_mask):
-         """
-         Forward pass through RoBERTa.
-         Args:
-             input_ids: Tensor of shape (batch_size, seq_length)
-             attention_mask: Tensor of shape (batch_size, seq_length)
-
-         Returns:
-             Embedding: Tensor of shape (batch_size, EMBEDDING_DIM)
-         """
-         roberta_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
-         cls_token = roberta_output.last_hidden_state[:, 0, :]  # Use CLS token
-         pooled_output = torch.mean(roberta_output.last_hidden_state, dim=1)  # Mean pooling
-
-         return self.projection(cls_token+pooled_output)
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPTextModel, RobertaModel, CLIPVisionModel
+ from timm import create_model
+ EMBEDDING_DIM = 512
+ class ImageEncoder(nn.Module):
+     def __init__(self):
+         super(ImageEncoder, self).__init__()
+         # Load the Swin Transformer with features_only=True
+         self.swin = create_model("swin_tiny_patch4_window7_224", pretrained=True, features_only=True)
+         for param in self.swin.parameters():
+             param.requires_grad = True
+
+         # Get the feature size of the final stage
+         self.swin_output_dim = self.swin.feature_info.channels()[-1]  # Last stage: 768 channels for swin_tiny
+
+         # Define FC layer
+         self.fc1 = nn.Linear(self.swin_output_dim * 7 * 7, EMBEDDING_DIM)  # Flattened input size
+         nn.init.xavier_uniform_(self.fc1.weight)
+         nn.init.zeros_(self.fc1.bias)
+         for param in self.fc1.parameters():
+             param.requires_grad = True
+
+
+     def forward(self, x):
+         # Extract features from Swin
+         swin_features = self.swin(x)[-1]  # Last stage feature map (for swin_tiny at 224 px: 7x7 spatial, 768 channels)
+
+         # Flatten feature map
+         swin_features = swin_features.view(swin_features.size(0), -1)  # Shape: (B, 768*7*7)
+
+         # Pass through FC layer
+         output = self.fc1(swin_features)  # Shape: (B, EMBEDDING_DIM)
+         return output
+
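As a quick sanity check, a minimal usage sketch for the updated ImageEncoder, assuming a 224x224 RGB batch (shapes follow the swin_tiny comments above):

# Sketch: embed a batch of four images with the encoder defined above.
encoder = ImageEncoder()
images = torch.randn(4, 3, 224, 224)  # (B, C, H, W) at the 224 px input size
embeddings = encoder(images)          # -> (4, 512), i.e. (B, EMBEDDING_DIM)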
+ from transformers import RobertaModel
+
+ class RobertaEncoder(nn.Module):
+     def __init__(self, roberta_model_path="roberta-base"):
+         super(RobertaEncoder, self).__init__()
+         # Load pre-trained RoBERTa model
+         self.roberta = RobertaModel.from_pretrained(roberta_model_path)
+
+         # Add a linear projection layer to reduce dimensionality
+         self.projection = nn.Linear(self.roberta.config.hidden_size, EMBEDDING_DIM)
+
+         # Initialize the projection layer weights
+         nn.init.xavier_uniform_(self.projection.weight)
+         nn.init.zeros_(self.projection.bias)
+
+         # Allow fine-tuning of the RoBERTa model
+         for param in self.roberta.parameters():
+             param.requires_grad = True
+
+     def forward(self, input_ids, attention_mask):
+         """
+         Forward pass through RoBERTa.
+         Args:
+             input_ids: Tensor of shape (batch_size, seq_length)
+             attention_mask: Tensor of shape (batch_size, seq_length)
+
+         Returns:
+             Embedding: Tensor of shape (batch_size, EMBEDDING_DIM)
+         """
+         roberta_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
+         cls_token = roberta_output.last_hidden_state[:, 0, :]  # Use CLS token
+         pooled_output = torch.mean(roberta_output.last_hidden_state, dim=1)  # Mean pooling
+
+         return self.projection(cls_token + pooled_output)
+
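Likewise for RobertaEncoder, a short sketch pairing it with the matching Hugging Face tokenizer (the sample captions are placeholders):

# Sketch: encode two captions with the RoBERTa encoder defined above.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
batch = tokenizer(["a photo of a dog", "a photo of a cat"],
                  padding=True, truncation=True, return_tensors="pt")
text_encoder = RobertaEncoder()
text_emb = text_encoder(batch["input_ids"], batch["attention_mask"])  # -> (2, 512)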
+ from transformers import AutoTokenizer, Siglip2TextModel, AutoModel
+
+
+ class SigLIP2TextEncoder(nn.Module):
+     def __init__(self, embedding_dim=512):
+         super(SigLIP2TextEncoder, self).__init__()
+         model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
+         self.text_encoder = model.text_model
+         hidden_size = self.text_encoder.config.hidden_size
+         self.projection = nn.Linear(hidden_size, embedding_dim)
+
+         nn.init.xavier_uniform_(self.projection.weight)
+         nn.init.zeros_(self.projection.bias)
+
+         for param in self.text_encoder.parameters():
+             param.requires_grad = True
+         for param in self.projection.parameters():
+             param.requires_grad = True
+
+     def forward(self, tokens):
+         """
+         Args:
+             tokens: tokenizer output (input_ids, plus attention_mask if present),
+                     unpacked into the SigLIP2 text encoder
+
+         Returns:
+             Tensor of shape (batch_size, embedding_dim)
+         """
+         outputs = self.text_encoder(**tokens)
+         return self.projection(outputs.pooler_output)
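And a sketch for SigLIP2TextEncoder, assuming a transformers version with SigLIP2 support; the tokenizer call mirrors the **tokens unpacking in forward, and SigLIP-style checkpoints expect fixed-length padding (here to 64 tokens):

# Sketch: encode one caption with the SigLIP2 text encoder defined above.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
tokens = tokenizer(["a photo of a dog"], padding="max_length",
                   max_length=64, return_tensors="pt")
siglip_encoder = SigLIP2TextEncoder()
emb = siglip_encoder(tokens)  # -> (1, 512)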