MosRat committed
Commit 56b1f4f · verified · 1 Parent(s): 0053fac

Upload folder using huggingface_hub

__init__.py ADDED
Empty file
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (151 Bytes)
 
__pycache__/configuration_gex.cpython-312.pyc ADDED
Binary file (412 Bytes)
 
__pycache__/modeling_gex.cpython-312.pyc ADDED
Binary file (21.3 kB)
 
config.json CHANGED
@@ -22,7 +22,7 @@
   "rms_norm_eps": 1e-06,
   "rope_scaling": null,
   "rope_theta": 1000000.0,
-  "sliding_window": null,
+  "sliding_window": 4096,
   "tie_word_embeddings": true,
   "torch_dtype": "bfloat16",
   "transformers_version": "4.50.1",
modeling_gex.py ADDED
@@ -0,0 +1,440 @@
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Type, Union
from functools import partial
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

from torchvision.transforms.functional import InterpolationMode
from transformers import (
    Qwen2Config,
    Qwen2Model,
    Qwen2ForCausalLM,
)

from .configuration_gex import GexConfig


LayerNorm = partial(nn.LayerNorm, eps=1e-6)


class GexImageEvalProcessor:
    """Eval-time preprocessing: bicubic resize to a square image, tensor
    conversion, and normalization with CLIP image statistics."""

    def __init__(self, image_size=1024, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of NCHW feature maps."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.num_channels = num_channels
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        return torch.nn.functional.layer_norm(
            x,
            normalized_shape=(self.num_channels,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        ).permute(0, 3, 1, 2)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


class Attention(nn.Module):
    """Multi-head attention with a decomposed (per-axis) relative position bias."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = 64
        self.scale = 64**-0.5
        self.seq_len = input_size[0] * input_size[1]
        self.input_size = input_size

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

        # The tables are stored pre-indexed as (q, k, head_dim); init_rel_pos()
        # rebuilds that layout from the original (2 * size - 1, head_dim) tables.
        # self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
        # self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
        self.rel_pos_h = nn.Parameter(torch.zeros(input_size[0], input_size[0], self.head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(input_size[1], input_size[1], self.head_dim))

    def init_rel_pos(self):
        q_size, k_size = self.input_size
        q_coords = torch.arange(q_size)[:, None]

        k_coords = torch.arange(k_size)[None, :]
        relative_coords = (q_coords - k_coords) + (k_size - 1)

        self.rel_pos_h = nn.Parameter(self.rel_pos_h.data[relative_coords.long()])
        self.rel_pos_w = nn.Parameter(self.rel_pos_w.data[relative_coords.long()])

    def get_attn_bias(self, q: torch.Tensor):
        # Additive attention bias built from per-height and per-width terms.
        q = q.view(-1, *self.input_size, 64)

        rel_h = torch.einsum("bhwc,hkc->bhwk", q, self.rel_pos_h)
        rel_w = torch.einsum("bhwc,wkc->bhwk", q, self.rel_pos_w)

        return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape(
            -1, self.num_heads, self.seq_len, self.seq_len
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = torch.split(
            self.qkv(x).view(-1, self.seq_len, 3 * 768),
            768,
            dim=2,
        )

        q, k, v = (
            i.unflatten(-1, (self.num_heads, -1)).transpose(1, 2).contiguous()
            for i in qkv
        )

        attn_bias = self.get_attn_bias(q)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_bias, is_causal=False
        )
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        x = self.proj(attn_output)

        return x.view(-1, *self.input_size, 768)


class MLP(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()
        self.lin1 = nn.Linear(768, 4 * 768)
        self.lin2 = nn.Linear(4 * 768, 768)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


class Block(nn.Module):
    """Transformer block with 14x14 windowed attention (window_size > 0)
    or global 64x64 attention (window_size == 0)."""

    def __init__(self, idx: int, window_size: int = 14):
        super().__init__()

        self.idx = idx
        self.window_size = window_size

        self.norm1 = LayerNorm(768)

        self.attn = Attention(
            dim=768,
            num_heads=12,
            input_size=(64, 64) if window_size == 0 else (14, 14),
        )

        self.norm2 = LayerNorm(768)
        self.mlp = MLP()

    @staticmethod
    def window_partition(x: torch.Tensor) -> torch.Tensor:
        # Pad the 64x64 grid to 70x70, then split it into 5x5 = 25 windows of 14x14.
        x = F.pad(x, (0, 0, 0, 6, 0, 6))
        x = (
            x.view(-1, 5, 14, 5, 14, 768)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, 14, 14, 768)
        )
        return x

    @staticmethod
    def window_unpartition(x: torch.Tensor) -> torch.Tensor:
        # Reassemble the 25 windows into a 70x70 grid and crop back to 64x64.
        x = (
            x.view(-1, 5, 5, 14, 14, 768)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, 70, 70, 768)
        )
        return x[:, :64, :64, :].contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            x = self.window_partition(x)

        x = self.attn(x)

        if self.window_size > 0:
            x = self.window_unpartition(x)

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class GexVit(nn.Module):
    """SAM-style ViT encoder: 12 blocks with windowed attention except at
    global_attn_indexes, followed by a convolutional neck and two strided
    convolutions that map (B, 64, 64, 768) features to (B, 1024, 16, 16)."""

    def __init__(self, global_attn_indexes=[2, 5, 8, 11], **kwargs):
        super().__init__()
        self.global_attn_indexes = global_attn_indexes
        self.patch_embed = PatchEmbed()

        self.pos_embed = nn.Parameter(torch.zeros(1, 64, 64, 768))

        self.blocks = nn.ModuleList(
            [
                Block(idx=i, window_size=14 if i not in global_attn_indexes else 0)
                for i in range(12)
            ]
        )

        self.neck = nn.ModuleList(
            [
                nn.Conv2d(
                    768,
                    256,
                    kernel_size=1,
                    bias=False,
                ),
                LayerNorm2d(256),
                nn.Conv2d(
                    256,
                    256,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                LayerNorm2d(256),
            ]
        )

        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(
            512, 1024, kernel_size=3, stride=2, padding=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        x = x.permute(0, 3, 1, 2)

        for m in self.neck:
            x = m(x)

        x = self.net_2(x)
        x = self.net_3(x)

        return x


class GexQwenModel(Qwen2Model):
    config_class = GexConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        # Frozen vision encoder and projection; only the language model is trainable.
        self.vit = GexVit()
        self.vit.eval()
        self.vit_proj = nn.Linear(1024, 1024)
        self.vit_proj.eval()

        for param in self.vit.parameters():
            param.requires_grad = False
        for param in self.vit_proj.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if images is not None:
            # When images are given, projected ViT features replace the token
            # embeddings as the decoder's input embeddings.
            assert input_ids is None, input_ids
            input_ids = None
            attention_mask = None
            kwargs["is_causal"] = True
            with torch.no_grad():
                vit_feature = self.vit_proj(
                    self.vit(images).flatten(2).permute(0, 2, 1)
                )
            inputs_embeds = vit_feature

        # print(input_ids, images)
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)

        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )


class GexQwenForCausalLM(Qwen2ForCausalLM):
    config_class = GexConfig
    # supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.model = GexQwenModel(config)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        self.has_image = False
        self.image_preprocess = GexImageEvalProcessor()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The has_image flag (set in generate()) routes the first forward pass
        # through the vision branch; later decoding steps fall back to text inputs.
        if self.has_image:
            input_ids = None
            self.has_image = False
        else:
            images = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            images=images,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        self.has_image = True
        res = super().generate(*args, **kwargs)
        self.has_image = False
        return res
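
For reference, a minimal sketch (not part of the commit) of how the pieces above fit together. It assumes the repo's auto_map points at GexQwenForCausalLM; the repo id and image path are placeholder assumptions, and only the vision branch defined in this file is exercised:

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

repo_id = "MosRat/gex"  # hypothetical repo id, for illustration only
model = AutoModelForCausalLM.from_pretrained(
    repo_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()

# GexImageEvalProcessor: 1024x1024 bicubic resize + CLIP-statistics normalization.
image = Image.open("example.png").convert("RGB")
pixel_values = model.image_preprocess(image).unsqueeze(0).to(model.device, torch.bfloat16)

# GexVit maps (1, 3, 1024, 1024) to (1, 1024, 16, 16); flattening and projecting
# yields 256 vectors of width 1024 that stand in for token embeddings.
with torch.no_grad():
    vit_feature = model.model.vit_proj(
        model.model.vit(pixel_values).flatten(2).permute(0, 2, 1)
    )
print(vit_feature.shape)  # torch.Size([1, 256, 1024])

During generation, generate() sets has_image so that the first forward pass consumes the images argument and feeds these projected features to the Qwen2 decoder as inputs_embeds; subsequent decoding steps run on text tokens as usual.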