Update modeling_internvl_chat.py
modeling_internvl_chat.py CHANGED (+10 -8)
@@ -20,8 +20,8 @@ from .configuration_internvl_chat import InternVLChatConfig
 from .modeling_intern_vit import InternVisionModel
 from .modeling_phi3 import Phi3ForCausalLM
 
-logger = logging.get_logger(__name__)
 
+logger = logging.get_logger(__name__)
 
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
@@ -31,6 +31,7 @@ class InternVLChatModel(PreTrainedModel):
     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
         super().__init__(config)
 
+        # Initialize components
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -42,10 +43,12 @@ class InternVLChatModel(PreTrainedModel):
 
         logger.info(f'num_image_token: {self.num_image_token}')
         logger.info(f'ps_version: {self.ps_version}')
+
         if vision_model is not None:
             self.vision_model = vision_model
         else:
             self.vision_model = InternVisionModel(config.vision_config)
+
         if language_model is not None:
             self.language_model = language_model
         else:
@@ -56,6 +59,11 @@ class InternVLChatModel(PreTrainedModel):
         else:
             raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
 
+        # Wrap models with DataParallel for multi-GPU support
+        if torch.cuda.device_count() > 1:
+            self.vision_model = nn.DataParallel(self.vision_model)
+            self.language_model = nn.DataParallel(self.language_model)
+
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.llm_config.hidden_size
 
@@ -66,13 +74,6 @@ class InternVLChatModel(PreTrainedModel):
             nn.Linear(llm_hidden_size, llm_hidden_size)
         )
 
-        # if config.force_image_size != config.vision_config.image_size:
-        #     self.vision_model.resize_pos_embeddings(
-        #         old_size=config.vision_config.image_size,
-        #         new_size=config.force_image_size,
-        #         patch_size=config.vision_config.patch_size
-        #     )
-
         self.img_context_token_id = None
         self.neftune_alpha = None
 
@@ -82,6 +83,7 @@ class InternVLChatModel(PreTrainedModel):
         if config.use_llm_lora:
             self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
 
+
     def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
         lora_config = LoraConfig(
             r=r,
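The substantive change in this commit is the nn.DataParallel wrapping added to __init__. Below is a minimal, self-contained sketch of what that wrapping does, using toy nn.Linear stand-ins rather than the actual InternVisionModel and Phi3ForCausalLM classes; the layer sizes are placeholders.

import torch
import torch.nn as nn

# Toy stand-ins for the vision and language models; any nn.Module behaves
# the same way under nn.DataParallel. The sizes are arbitrary placeholders.
vision_model = nn.Linear(1024, 4096)
language_model = nn.Linear(4096, 4096)

# Mirrors the added __init__ logic: on a multi-GPU host, nn.DataParallel
# replicates the module onto every visible device, scatters each input batch
# along dim 0 (one shard per GPU), and gathers the outputs afterwards.
# With zero or one GPUs, the modules are left untouched.
if torch.cuda.device_count() > 1:
    vision_model = nn.DataParallel(vision_model)
    language_model = nn.DataParallel(language_model)

One caveat worth noting: nn.DataParallel keeps the wrapped module under .module and does not forward arbitrary attributes, so any later code that calls methods not defined on nn.Module itself (for example self.language_model.generate(...)) would need to go through self.language_model.module once the wrapper is in place.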
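The context lines also show wrap_backbone_lora, which the hunk truncates after r=r,. As a hedged sketch only: a wrapper like this typically builds a peft LoraConfig and applies it with get_peft_model. The target_modules value and the get_peft_model call below are assumptions for illustration, not taken from this file.

from peft import LoraConfig, get_peft_model
import torch.nn as nn

def wrap_backbone_lora_sketch(model: nn.Module, r=128, lora_alpha=256, lora_dropout=0.05):
    # Hypothetical completion: which layers to target (e.g. an attention
    # 'qkv' projection) is an assumption, not taken from the diff.
    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=['qkv'],  # assumption
    )
    return get_peft_model(model, lora_config)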