|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Sarashina2Vision model configuration""" |
|
|
|
from typing import Any, Optional |
|
|
|
from transformers import LlamaConfig, PretrainedConfig |
|
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig |
|
from transformers.utils import logging |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class Sarashina2VisionConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Sarashina2VisionModel`]. It is used to instantiate a
    Sarashina2Vision model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`dict` or `Qwen2VLVisionConfig`, *optional*):
            The config for the visual encoder initialization. A `dict` is expanded into keyword arguments of
            [`Qwen2VLVisionConfig`], an existing [`Qwen2VLVisionConfig`] instance is stored as-is, and `None`
            creates a [`Qwen2VLVisionConfig`] with its default arguments.
        text_config (`dict` or `LlamaConfig`, *optional*):
            The config for the text decoder initialization. A `dict` is expanded into keyword arguments of
            [`LlamaConfig`], an existing [`LlamaConfig`] instance is stored as-is, and `None` creates a
            [`LlamaConfig`] with its default arguments.
        image_token_index (`int`, *optional*, defaults to 14):
            image token id.
        start_image_token_index (`int`, *optional*, defaults to 102397):
            start image token id.
        end_image_token_index (`int`, *optional*, defaults to 102398):
            end image token id.
        **kwargs:
            Additional keyword arguments forwarded unchanged to [`PretrainedConfig`].

    Raises:
        TypeError: If `vision_config` or `text_config` is neither `None`, a `dict`, nor an instance of the
            expected configuration class.
    """

    model_type = "sarashina2_vision"

    def __init__(
        self,
        vision_config: Optional[dict[str, Any]] = None,
        text_config: Optional[dict[str, Any]] = None,
        image_token_index: int = 14,
        start_image_token_index: int = 102397,
        end_image_token_index: int = 102398,
        **kwargs,
    ):
        # Resolve the text config first, then the vision config, keeping the
        # original construction order.
        self.text_config = self._resolve_sub_config(text_config, LlamaConfig, "text_config")
        self.vision_config = self._resolve_sub_config(vision_config, Qwen2VLVisionConfig, "vision_config")

        self.image_token_index = image_token_index
        self.start_image_token_index = start_image_token_index
        self.end_image_token_index = end_image_token_index

        # The base initializer runs last, after all model-specific attributes
        # have been assigned (same ordering as before).
        super().__init__(**kwargs)

    @staticmethod
    def _resolve_sub_config(value, config_cls, name):
        """Return a `config_cls` instance built from `value`.

        Args:
            value: `None`, a `dict` of keyword arguments, or a `config_cls` instance.
            config_cls: The configuration class the result must be an instance of.
            name: Parameter name used in the error message.

        Returns:
            A default `config_cls()` for `None`, `config_cls(**value)` for a `dict`,
            or `value` itself when it is already a `config_cls` instance.

        Raises:
            TypeError: If `value` has any other type. Without this check the
                attribute would silently remain unset on the config object.
        """
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            return config_cls(**value)
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"`{name}` must be None, a dict, or a {config_cls.__name__} instance, "
            f"but got {type(value).__name__}."
        )
|
|