izhx committed
Commit ea96574 · verified · 1 Parent(s): 46c35bc

Update modeling_gme_qwen2vl.py

Files changed (1)
  1. modeling_gme_qwen2vl.py +40 -16
modeling_gme_qwen2vl.py CHANGED
@@ -12,16 +12,25 @@ import torch
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
-from transformers import (
-    AutoProcessor,
-    PreTrainedModel,
+from transformers import AutoProcessor, PreTrainedModel
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+    Qwen2VisionTransformerPretrainedModel,
     Qwen2VLConfig,
     Qwen2VLForConditionalGeneration,
+    Qwen2VLModel,
+)
+from transformers.utils.versions import require_version
+
+
+require_version(
+    "transformers<4.52.0",
+    "This code has some issues with transformers>=4.52.0, please downgrade: pip install transformers==4.51.3"
 )
-import os


 class GmeQwen2VLConfig(Qwen2VLConfig):
+    # model_type = ''
+
     def __init__(
         self,
         min_image_tokens: int = 256,
@@ -35,14 +44,25 @@ class GmeQwen2VLConfig(Qwen2VLConfig):
         self.max_length = max_length


-class GmeQwen2VLForVision2Seq(PreTrainedModel):
+class GmeQwen2VL(PreTrainedModel):
     config_class = GmeQwen2VLConfig
-    base_model_prefix: str = "base"
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
+    # _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    # _supports_cache_class = True
+    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
+    # _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
         super().__init__(config)
-        self.base = Qwen2VLForConditionalGeneration.from_pretrained(config._name_or_path)
-        self.base.tie_weights()  # It's important to produce same outputs.
+        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
+        self.model = Qwen2VLModel(config)
+        self.vocab_size = config.vocab_size
+        # self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.rope_deltas = None  # cache rope_deltas here

         min_pixels: int = config.min_image_tokens * 28 * 28
         max_pixels: int = config.max_image_tokens * 28 * 28
@@ -55,6 +75,9 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         self.default_instruction: str = "You are a helpful assistant."
         self.sep: str = " "

+        # Initialize weights and apply final processing
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -70,21 +93,21 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         **kwargs
     ) -> torch.Tensor:
         if inputs_embeds is None:
-            inputs_embeds = self.base.model.embed_tokens(input_ids)
+            inputs_embeds = self.model.get_input_embeddings()(input_ids)
         if pixel_values is not None:
-            pixel_values = pixel_values.type(self.base.visual.get_dtype())
-            image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
-            image_mask = input_ids == self.base.config.image_token_id
+            pixel_values = pixel_values.type(self.visual.get_dtype())
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+            image_mask = input_ids == self.config.image_token_id
             inputs_embeds[image_mask] = image_embeds
         # if pixel_values_videos is not None:
-        #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
-        #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
-        #     video_mask = input_ids == self.base.config.video_token_id
+        #     pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+        #     video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
+        #     video_mask = input_ids == self.config.video_token_id
         #     inputs_embeds[video_mask] = video_embeds
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)

-        outputs = self.base.model(
+        outputs = self.model(
             input_ids=None,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -311,3 +334,4 @@ def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Im

     return image
 ###
+
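
Below is a minimal loading sketch, not part of the commit itself: the checkpoint id is an assumption used only for illustration, and AutoModel with trust_remote_code=True is the generic Hugging Face pattern for repositories that ship custom modeling code such as modeling_gme_qwen2vl.py. The only constraint taken directly from the diff is the transformers<4.52.0 pin enforced by require_version.

# Hedged usage sketch -- not part of this commit.
# The version pin mirrors the require_version() call added above, e.g.:
#   pip install "transformers==4.51.3"
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",  # assumed/illustrative repo id
    torch_dtype=torch.float16,
    trust_remote_code=True,  # loads the repo's custom GmeQwen2VL class
)
model.eval()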