carolineec committed
Commit c3b4829 · verified · 1 Parent(s): 2ce6036

Upload folder using huggingface_hub

blip/__init__.py ADDED
File without changes
blip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (249 Bytes)
blip/__pycache__/blip.cpython-310.pyc ADDED
Binary file (2.39 kB)
blip/__pycache__/blip_pretrain.cpython-310.pyc ADDED
Binary file (2.25 kB)
blip/__pycache__/med.cpython-310.pyc ADDED
Binary file (27.4 kB)
blip/__pycache__/vit.cpython-310.pyc ADDED
Binary file (12.1 kB)
blip/blip.py ADDED
@@ -0,0 +1,71 @@
"""
Adapted from BLIP (https://github.com/salesforce/BLIP)
"""

import warnings
warnings.filterwarnings("ignore")

from .vit import VisionTransformer, interpolate_pos_embed
from transformers import BertTokenizer

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file


def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

    assert vit in ['base', 'large'], "vit parameter must be base or large"

    if vit=='base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                           )
    elif vit=='large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1 or drop_path_rate
                                           )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    return model, msg
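A minimal usage sketch of the helpers above (not part of the committed files; it assumes the uploaded blip/ folder is importable as a package, and the checkpoint path is illustrative):

# Build the tokenizer and a ViT-B/16 visual encoder, then (optionally) load a BLIP
# checkpoint into a wrapping nn.Module that exposes `visual_encoder`, as BLIP_Pretrain does.
from blip.blip import init_tokenizer, create_vit, load_checkpoint

tokenizer = init_tokenizer()                                 # BERT tokenizer with extra [DEC]/[ENC] tokens
visual_encoder, vision_width = create_vit('base', image_size=224)
# model, msg = load_checkpoint(model, 'checkpoints/model_base.pth')  # path is an assumption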
blip/blip_pretrain.py ADDED
@@ -0,0 +1,58 @@
"""
Adapted from BLIP (https://github.com/salesforce/BLIP)
"""
import torch.distributed as dist
from torch import nn
import transformers
transformers.logging.set_verbosity_error()

from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer, load_checkpoint


class BLIP_Pretrain(nn.Module):
    def __init__(self,
                 med_config = 'med_config.json',
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.tokenizer = init_tokenizer()
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def blip_pretrain(pretrained='', **kwargs):
    model = BLIP_Pretrain(**kwargs)
    if pretrained and get_rank() == 0:
        model, msg = load_checkpoint(model, pretrained)
    return model
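A minimal sketch of instantiating the pretraining wrapper (not part of the committed files; it assumes a local med_config.json that defines the BERT text encoder):

import torch
from blip.blip_pretrain import blip_pretrain

model = blip_pretrain(pretrained='', med_config='med_config.json', image_size=224, vit='base')
image_embeds = model.visual_encoder(torch.randn(1, 3, 224, 224))   # (1, num_patches + 1, 768)
text = model.tokenizer(['a photo of a cat'], return_tensors='pt')  # tokenized caption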
blip/med.py ADDED
@@ -0,0 +1,938 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''

import math
from typing import Tuple

import torch
from torch import Tensor, device, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        if mode == 'multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    mode=mode,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = encoder_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
                                                                                 device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            embedding_output = encoder_embeds

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction == 'none':
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
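A minimal sketch of running BertModel in 'multimodal' mode, where each layer cross-attends to image features; the only non-standard config fields this file relies on are add_cross_attention and encoder_width (the visual feature dimension). Values below are illustrative and the base config is assumed to be downloadable or cached:

import torch
from blip.med import BertConfig, BertModel

config = BertConfig.from_pretrained('bert-base-uncased')
config.add_cross_attention = True
config.encoder_width = 768                                  # width of the ViT features

text_encoder = BertModel(config=config, add_pooling_layer=False)
image_embeds = torch.randn(1, 197, 768)                     # e.g. ViT-B/16 output for a 224x224 image
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)
tokens = torch.randint(0, config.vocab_size, (1, 12))

out = text_encoder(tokens,
                   attention_mask=torch.ones_like(tokens),
                   encoder_hidden_states=image_embeds,
                   encoder_attention_mask=image_atts,
                   return_dict=True,
                   mode='multimodal')
print(out.last_hidden_state.shape)                          # torch.Size([1, 12, 768])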
blip/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,11 @@
import os
from .clip_encoder import CLIPVisionTower


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
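A minimal sketch of calling the builder; the config object only needs the attributes read here and in CLIPVisionTower, and the attribute values are illustrative:

from types import SimpleNamespace
from blip.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(mm_vision_tower='openai/clip-vit-large-patch14',
                      mm_vision_select_layer=-2,
                      mm_vision_select_feature='patch')
vision_tower = build_vision_tower(cfg)   # downloads and loads the CLIP vision weights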
blip/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(
            self.vision_tower_name,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16
        )
        #self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
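A minimal sketch of encoding images with a loaded tower; `vision_tower` is the CLIPVisionTower built above and `pil_images` stands for any list of PIL images (both names are placeholders):

pixel_values = vision_tower.image_processor(images=pil_images, return_tensors='pt')['pixel_values']
features = vision_tower(pixel_values)
# with the default 'patch' selection: (batch, num_patches, hidden_size)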
blip/multimodal_projector/builder.py ADDED
@@ -0,0 +1,50 @@
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
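A minimal sketch of building an 'mlp2x_gelu' projector; the mm_hidden_size and hidden_size values are illustrative:

from types import SimpleNamespace
from blip.multimodal_projector.builder import build_vision_projector

proj_cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(proj_cfg)
# -> nn.Sequential(Linear(1024, 4096), GELU(), Linear(4096, 4096))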
blip/vit.py ADDED
@@ -0,0 +1,299 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on timm code base
4
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
+ '''
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from functools import partial
10
+
11
+ from timm.models.vision_transformer import PatchEmbed
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from timm.models.helpers import adapt_input_conv
14
+
15
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
16
+
17
+ class Mlp(nn.Module):
18
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
19
+ """
20
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
21
+ super().__init__()
22
+ out_features = out_features or in_features
23
+ hidden_features = hidden_features or in_features
24
+ self.fc1 = nn.Linear(in_features, hidden_features)
25
+ self.act = act_layer()
26
+ self.fc2 = nn.Linear(hidden_features, out_features)
27
+ self.drop = nn.Dropout(drop)
28
+
29
+ def forward(self, x):
30
+ x = self.fc1(x)
31
+ x = self.act(x)
32
+ x = self.drop(x)
33
+ x = self.fc2(x)
34
+ x = self.drop(x)
35
+ return x
36
+
37
+
38
+ class Attention(nn.Module):
39
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
40
+ super().__init__()
41
+ self.num_heads = num_heads
42
+ head_dim = dim // num_heads
43
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
44
+ self.scale = qk_scale or head_dim ** -0.5
45
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
46
+ self.attn_drop = nn.Dropout(attn_drop)
47
+ self.proj = nn.Linear(dim, dim)
48
+ self.proj_drop = nn.Dropout(proj_drop)
49
+ self.attn_gradients = None
50
+ self.attention_map = None
51
+
52
+ def save_attn_gradients(self, attn_gradients):
53
+ self.attn_gradients = attn_gradients
54
+
55
+ def get_attn_gradients(self):
56
+ return self.attn_gradients
57
+
58
+ def save_attention_map(self, attention_map):
59
+ self.attention_map = attention_map
60
+
61
+ def get_attention_map(self):
62
+ return self.attention_map
63
+
64
+ def forward(self, x, register_hook=False):
65
+ B, N, C = x.shape
66
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
67
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
68
+
69
+ attn = (q @ k.transpose(-2, -1)) * self.scale
70
+ attn = attn.softmax(dim=-1)
71
+ attn = self.attn_drop(attn)
72
+
73
+ if register_hook:
74
+ self.save_attention_map(attn)
75
+ attn.register_hook(self.save_attn_gradients)
76
+
77
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
78
+ x = self.proj(x)
79
+ x = self.proj_drop(x)
80
+ return x
81
+
82
+
83
+ class Block(nn.Module):
84
+
85
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
86
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
87
+ super().__init__()
88
+ self.norm1 = norm_layer(dim)
89
+ self.attn = Attention(
90
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
91
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
92
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
93
+ self.norm2 = norm_layer(dim)
94
+ mlp_hidden_dim = int(dim * mlp_ratio)
95
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
96
+
97
+ if use_grad_checkpointing:
98
+ self.attn = checkpoint_wrapper(self.attn)
99
+ self.mlp = checkpoint_wrapper(self.mlp)
100
+
101
+ def forward(self, x, register_hook=False):
102
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
103
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
104
+ return x
105
+
106
+
+ class VisionTransformer(nn.Module):
+     """ Vision Transformer
+     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+         https://arxiv.org/abs/2010.11929
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
+                  use_grad_checkpointing=False, ckpt_layer=0):
+         """
+         Args:
+             img_size (int, tuple): input image size
+             patch_size (int, tuple): patch size
+             in_chans (int): number of input channels
+             num_classes (int): number of classes for classification head
+             embed_dim (int): embedding dimension
+             depth (int): depth of transformer
+             num_heads (int): number of attention heads
+             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+             qkv_bias (bool): enable bias for qkv if True
+             qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+             drop_rate (float): dropout rate
+             attn_drop_rate (float): attention dropout rate
+             drop_path_rate (float): stochastic depth rate
+             norm_layer: (nn.Module): normalization layer
+         """
+         super().__init__()
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                 use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
+             )
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+
+         trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token'}
+
+     def forward(self, x, register_blk=-1):
+         B = x.shape[0]
+         x = self.patch_embed(x)
+
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         x = x + self.pos_embed[:,:x.size(1),:]
+         x = self.pos_drop(x)
+
+         for i,blk in enumerate(self.blocks):
+             x = blk(x, register_blk==i)
+         x = self.norm(x)
+
+         return x
+
+     @torch.jit.ignore()
+     def load_pretrained(self, checkpoint_path, prefix=''):
+         _load_weights(self, checkpoint_path, prefix)
+
+
+ @torch.no_grad()
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+     """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+     """
+     import numpy as np
+
+     def _n2p(w, t=True):
+         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+             w = w.flatten()
+         if t:
+             if w.ndim == 4:
+                 w = w.transpose([3, 2, 0, 1])
+             elif w.ndim == 3:
+                 w = w.transpose([2, 0, 1])
+             elif w.ndim == 2:
+                 w = w.transpose([1, 0])
+         return torch.from_numpy(w)
+
+     w = np.load(checkpoint_path)
+     if not prefix and 'opt/target/embedding/kernel' in w:
+         prefix = 'opt/target/'
+
+     if hasattr(model.patch_embed, 'backbone'):
+         # hybrid
+         backbone = model.patch_embed.backbone
+         stem_only = not hasattr(backbone, 'stem')
+         stem = backbone if stem_only else backbone.stem
+         stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+         stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+         stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+         if not stem_only:
+             for i, stage in enumerate(backbone.stages):
+                 for j, block in enumerate(stage.blocks):
+                     bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+                     for r in range(3):
+                         getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+                         getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+                         getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+                     if block.downsample is not None:
+                         block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+                         block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+                         block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+         embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+     else:
+         embed_conv_w = adapt_input_conv(
+             model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+     model.patch_embed.proj.weight.copy_(embed_conv_w)
+     model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+     model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+     pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+     if pos_embed_w.shape != model.pos_embed.shape:
+         pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
+             pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+     model.pos_embed.copy_(pos_embed_w)
+     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+     # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+     #     model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+     #     model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+     # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+     #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+     #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+     for i, block in enumerate(model.blocks.children()):
+         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+         block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+         block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+         block.attn.qkv.weight.copy_(torch.cat([
+             _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+         block.attn.qkv.bias.copy_(torch.cat([
+             _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+         for r in range(2):
+             getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+         block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+         block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+     # interpolate position embedding
+     embedding_size = pos_embed_checkpoint.shape[-1]
+     num_patches = visual_encoder.patch_embed.num_patches
+     num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+     # height (== width) for the checkpoint position embedding
+     orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+     # height (== width) for the new position embedding
+     new_size = int(num_patches ** 0.5)
+
+     if orig_size!=new_size:
+         # class_token and dist_token are kept unchanged
+         extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+         # only the position tokens are interpolated
+         pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+         pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+         pos_tokens = torch.nn.functional.interpolate(
+             pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+         pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+         new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+         print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
+
+         return new_pos_embed
+     else:
+         return pos_embed_checkpoint
med_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30524,
+   "encoder_width": 768,
+   "add_cross_attention": true
+ }
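A short sketch of how a BERT-style config like this one is typically consumed; the assumption (not shown in this commit) is that blip_pretrain reads it via transformers.BertConfig, with the extra encoder_width/add_cross_attention fields configuring the cross-attention text encoder:

# Sketch only: load med_config.json as a BertConfig and inspect a few fields.
from transformers import BertConfig

med_config = BertConfig.from_json_file("med_config.json")
print(med_config.hidden_size)           # 768
print(med_config.vocab_size)            # 30524 (BERT vocab plus the extra [DEC]/[ENC] tokens)
print(med_config.add_cross_attention)   # True: the text encoder attends to ViT image features
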
model.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ Adapted from ImageReward (https://github.com/THUDM/ImageReward)
+ """
+
+ import os
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+
+ # from .config import cyclereward_args
+ from blip.blip_pretrain import blip_pretrain
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ from huggingface_hub import PyTorchModelHubMixin
+
+ try:
+     from torchvision.transforms import InterpolationMode
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+
+ cyclereward_args = {
+     'blip_path': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
+     'vit': 'large',
+     'image_size': 224,
+     'mlp_dim': 768
+ }
+
+ def _convert_image_to_rgb(image):
+     return image.convert("RGB")
+
+ def _transform(n_px):
+     return Compose([
+         Resize(n_px, interpolation=BICUBIC),
+         CenterCrop(n_px),
+         _convert_image_to_rgb,
+         ToTensor(),
+         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+ class CycleReward(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, device='cpu',
+                  model_type='CycleReward-Combo',
+                  max_length=128,
+                  fix_rate=0.7,
+                  med_config=None,
+                  ):
+         super().__init__()
+         self.device = device
+         self.model_type = model_type
+         self.max_length = max_length
+
+         self.blip = blip_pretrain(
+             pretrained=cyclereward_args['blip_path'],
+             med_config=med_config,
+             image_size=cyclereward_args['image_size'],
+             vit=cyclereward_args['vit']
+         )
+         self.preprocess = _transform(cyclereward_args['image_size'])
+         self.mlp = MLP(cyclereward_args['mlp_dim'])
+
+         for name, parms in self.blip.named_parameters():
+             if '_proj' in name:
+                 parms.requires_grad_(False)
+
+         # fix certain ratio of layers (setting from ImageReward)
+         self.image_layer_num = 24 if cyclereward_args['vit'] == 'large' else 12
+         if fix_rate > 0:
+             text_fix_num = "layer.{}".format(int(12 * fix_rate))
+             image_fix_num = "blocks.{}".format(int(self.image_layer_num * fix_rate))
+             for name, parms in self.blip.text_encoder.named_parameters():
+                 parms.requires_grad_(False)
+                 if text_fix_num in name:
+                     break
+             for name, parms in self.blip.visual_encoder.named_parameters():
+                 parms.requires_grad_(False)
+                 if image_fix_num in name:
+                     break
+
+     def forward(self, batch):
+         if 'Combo' in self.model_type:
+             text_reward = self.text_reward(batch)
+             image_reward = self.image_reward(batch)
+
+         elif 'I2T' in self.model_type:
+             text_reward = self.text_reward(batch)
+             image_reward = None
+
+         elif 'T2I' in self.model_type:
+             text_reward = None
+             image_reward = self.image_reward(batch)
+
+         return text_reward, image_reward
+
+     def text_reward(self, batch):
+         images, preferred_ids, preferred_mask, rejected_ids, rejected_mask = batch["images"], batch["preferred_ids"], batch["preferred_mask"], batch["rejected_ids"], batch["rejected_mask"]
+         images = images.to(self.device)
+         preferred_ids = preferred_ids.to(self.device)
+         preferred_mask = preferred_mask.to(self.device)
+         rejected_ids = rejected_ids.to(self.device)
+         rejected_mask = rejected_mask.to(self.device)
+
+         # encode image
+         image_embeds = self.blip.visual_encoder(images)
+         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device)
+
+         # encode preferred
+         preferred_embeds = self.blip.text_encoder(
+             preferred_ids,
+             attention_mask=preferred_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         ).last_hidden_state
+         preferred_embeds = preferred_embeds[:,0,:].float()
+
+         # encode rejected
+         rejected_embeds = self.blip.text_encoder(
+             rejected_ids,
+             attention_mask=rejected_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         ).last_hidden_state
+         rejected_embeds = rejected_embeds[:,0,:].float()
+
+         preferred_reward = self.mlp(preferred_embeds)
+         rejected_reward = self.mlp(rejected_embeds)
+         reward = torch.concat((preferred_reward, rejected_reward), dim=1)
+
+         return reward
+
+     def image_reward(self, batch):
+         prompt_ids, prompt_mask, image_preferred, image_rejected = batch["prompt_ids"], batch["prompt_mask"], batch["image_preferred"], batch["image_rejected"]
+         image_preferred = image_preferred.to(self.device)
+         image_rejected = image_rejected.to(self.device)
+         prompt_ids = prompt_ids.view(prompt_ids.shape[0], -1).to(self.device)
+         prompt_mask = prompt_mask.view(prompt_mask.shape[0], -1).to(self.device)
+
+         # encode image
+         image_embeds_preferred = self.blip.visual_encoder(image_preferred)
+         image_atts_preferred = torch.ones(image_embeds_preferred.size()[:-1],dtype=torch.long).to(self.device)
+
+         image_embeds_rejected = self.blip.visual_encoder(image_rejected)
+         image_atts_rejected = torch.ones(image_embeds_rejected.size()[:-1],dtype=torch.long).to(self.device)
+
+         # encode preferred
+         preferred_embeds = self.blip.text_encoder(
+             prompt_ids,
+             attention_mask=prompt_mask,
+             encoder_hidden_states=image_embeds_preferred,
+             encoder_attention_mask=image_atts_preferred,
+             return_dict=True,
+         ).last_hidden_state
+         preferred_embeds = preferred_embeds[:,0,:].float()
+
+         # encode rejected
+         rejected_embeds = self.blip.text_encoder(
+             prompt_ids,
+             attention_mask=prompt_mask,
+             encoder_hidden_states=image_embeds_rejected,
+             encoder_attention_mask=image_atts_rejected,
+             return_dict=True,
+         ).last_hidden_state
+         rejected_embeds = rejected_embeds[:,0,:].float()
+
+         preferred_reward = self.mlp(preferred_embeds)
+         rejected_reward = self.mlp(rejected_embeds)
+         reward = torch.concat((preferred_reward, rejected_reward), dim=1)
+
+         return reward
+
+     @torch.no_grad()
+     def score(self, image, prompt):
+         text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt").to(self.device)
+
+         image_embeds = self.blip.visual_encoder(image)
+         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device)
+
+         text_embeds = self.blip.text_encoder(
+             text_input.input_ids,
+             attention_mask=text_input.attention_mask,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         ).last_hidden_state
+         text_embeds = text_embeds[:,0,:].float()
+
+         rewards = self.mlp(text_embeds)
+         return rewards
+
+ class MLP(nn.Module):
+     def __init__(self, input_size):
+         super().__init__()
+         self.layers = nn.Sequential(
+             nn.Linear(input_size, 1024),
+             nn.GELU(),
+
+             nn.Linear(1024, 128),
+             nn.GELU(),
+
+             nn.Linear(128, 64),
+             nn.GELU(),
+
+             nn.Linear(64, 16),
+             nn.GELU(),
+
+             nn.Linear(16, 1)
+         )
+
+         def init_weights(m):
+             if isinstance(m, nn.Linear):
+                 nn.init.xavier_uniform_(m.weight)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+
+         self.layers.apply(init_weights)
+
+     def forward(self, input):
+         return self.layers(input)
+
+
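A usage sketch for the reward model above; the hub repository id is a placeholder assumption, from_pretrained is supplied by PyTorchModelHubMixin, and everything is kept on CPU (the constructor's default device) for simplicity:

# Sketch only: load a published CycleReward checkpoint and score an image-caption pair.
import torch
from PIL import Image
from model import CycleReward

model = CycleReward.from_pretrained("<hub-repo-id>")   # placeholder hub id
model.eval()

image = model.preprocess(Image.open("example.jpg")).unsqueeze(0)    # (1, 3, 224, 224)
reward = model.score(image, "a photo of a corgi on the beach")      # higher = better image-text match
print(reward.item())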