radna committed
Commit 05b2c8b
1 Parent(s): e065dd1

Update modeling_internvl_chat.py

Files changed (1)
  1. modeling_internvl_chat.py +10 -8
modeling_internvl_chat.py CHANGED
@@ -20,8 +20,8 @@ from .configuration_internvl_chat import InternVLChatConfig
 from .modeling_intern_vit import InternVisionModel
 from .modeling_phi3 import Phi3ForCausalLM
 
-logger = logging.get_logger(__name__)
 
+logger = logging.get_logger(__name__)
 
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
@@ -31,6 +31,7 @@ class InternVLChatModel(PreTrainedModel):
     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
         super().__init__(config)
 
+        # Initialize components
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -42,10 +43,12 @@ class InternVLChatModel(PreTrainedModel):
 
         logger.info(f'num_image_token: {self.num_image_token}')
         logger.info(f'ps_version: {self.ps_version}')
+
         if vision_model is not None:
             self.vision_model = vision_model
         else:
             self.vision_model = InternVisionModel(config.vision_config)
+
         if language_model is not None:
             self.language_model = language_model
         else:
@@ -56,6 +59,11 @@ class InternVLChatModel(PreTrainedModel):
         else:
             raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
 
+        # Wrap models with DataParallel for multi-GPU support
+        if torch.cuda.device_count() > 1:
+            self.vision_model = nn.DataParallel(self.vision_model)
+            self.language_model = nn.DataParallel(self.language_model)
+
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.llm_config.hidden_size
 
@@ -66,13 +74,6 @@ class InternVLChatModel(PreTrainedModel):
             nn.Linear(llm_hidden_size, llm_hidden_size)
         )
 
-        # if config.force_image_size != config.vision_config.image_size:
-        #     self.vision_model.resize_pos_embeddings(
-        #         old_size=config.vision_config.image_size,
-        #         new_size=config.force_image_size,
-        #         patch_size=config.vision_config.patch_size
-        #     )
-
         self.img_context_token_id = None
         self.neftune_alpha = None
 
@@ -82,6 +83,7 @@ class InternVLChatModel(PreTrainedModel):
         if config.use_llm_lora:
             self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
 
+
     def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
         lora_config = LoraConfig(
             r=r,
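
The main functional change in this commit is the DataParallel wrapping, which only takes effect when more than one CUDA device is visible. Below is a minimal sketch of what that wrapping implies, assuming standard torch.nn.DataParallel semantics; the ToyEncoder module is an illustrative stand-in and not part of this repository. DataParallel replicates the wrapped module across GPUs on each forward call, and the original module is afterwards reachable only through the .module attribute, which matters for any code that reads attributes directly off self.vision_model or self.language_model.

import torch
import torch.nn as nn

# Illustrative stand-in for vision_model / language_model; not InternVL code.
class ToyEncoder(nn.Module):
    def __init__(self, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)

encoder = ToyEncoder()
if torch.cuda.device_count() > 1:
    # Same guard as the commit: wrap only when several GPUs are available.
    encoder = nn.DataParallel(encoder)

# DataParallel exposes the wrapped module via .module, so attribute reads
# must unwrap first when the guard above has fired.
inner = encoder.module if isinstance(encoder, nn.DataParallel) else encoder
print(inner.hidden_size)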