Update modeling_internvl_chat.py
modeling_internvl_chat.py CHANGED (+10 -8)
@@ -20,8 +20,8 @@ from .configuration_internvl_chat import InternVLChatConfig
 from .modeling_intern_vit import InternVisionModel
 from .modeling_phi3 import Phi3ForCausalLM
 
-logger = logging.get_logger(__name__)
 
+logger = logging.get_logger(__name__)
 
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
@@ -31,6 +31,7 @@ class InternVLChatModel(PreTrainedModel):
     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
         super().__init__(config)
 
+        # Initialize components
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -42,10 +43,12 @@
 
         logger.info(f'num_image_token: {self.num_image_token}')
         logger.info(f'ps_version: {self.ps_version}')
+
         if vision_model is not None:
             self.vision_model = vision_model
         else:
             self.vision_model = InternVisionModel(config.vision_config)
+
         if language_model is not None:
             self.language_model = language_model
         else:
@@ -56,6 +59,11 @@
         else:
             raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
 
+        # Wrap models with DataParallel for multi-GPU support
+        if torch.cuda.device_count() > 1:
+            self.vision_model = nn.DataParallel(self.vision_model)
+            self.language_model = nn.DataParallel(self.language_model)
+
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.llm_config.hidden_size
 
@@ -66,13 +74,6 @@
             nn.Linear(llm_hidden_size, llm_hidden_size)
         )
 
-        # if config.force_image_size != config.vision_config.image_size:
-        #     self.vision_model.resize_pos_embeddings(
-        #         old_size=config.vision_config.image_size,
-        #         new_size=config.force_image_size,
-        #         patch_size=config.vision_config.patch_size
-        #     )
-
         self.img_context_token_id = None
         self.neftune_alpha = None
 
@@ -82,6 +83,7 @@
         if config.use_llm_lora:
             self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
 
+
     def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
         lora_config = LoraConfig(
             r=r,
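Note on the substantive change: wrapping self.vision_model and self.language_model in nn.DataParallel means their original attributes are afterwards only reachable through .module, so any downstream code that touches the submodules directly (embeddings, generation config, save_pretrained, the LoRA wrapping later in this same constructor) would see the wrapper instead of the bare model. A minimal sketch of that behavior follows; the stand-in nn.Linear and the variable names are illustrative only, not from the commit:

import torch
import torch.nn as nn

# Stand-in for InternVisionModel / Phi3ForCausalLM; illustrative only.
vision_model = nn.Linear(16, 16)

# Same guard as in the commit: wrap only when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    vision_model = nn.DataParallel(vision_model)

# After wrapping, the underlying module lives at `.module`, so attribute
# access has to unwrap conditionally.
base = vision_model.module if isinstance(vision_model, nn.DataParallel) else vision_model
print(type(base).__name__)  # 'Linear' whether or not the wrap happened

For training workloads, current PyTorch documentation generally recommends torch.nn.parallel.DistributedDataParallel over nn.DataParallel; whether the in-constructor wrap interacts cleanly with the subsequent wrap_backbone_lora / wrap_llm_lora calls is worth verifying.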