POINTS-Yi-1-5-9B-Chat / configuration_points_chat.py
YuanLiuuuuuu's picture
Add files using upload-large-folder tool
e28b279 verified
raw
history blame
1.31 kB
import copy
from typing import Any, Dict, Union
from transformers import CLIPVisionConfig, PretrainedConfig
from .configuration_llama import CustomLlamaConfig
class POINTSChatConfig(PretrainedConfig):
    """Configuration class for `POINTS`.

    Composite configuration that bundles a CLIP vision-model configuration
    with a language-model configuration.

    Args:
        vision_config (Union[dict, CLIPVisionConfig, None]):
            Configuration of the vision model. A dict is expanded via
            ``CLIPVisionConfig(**vision_config)``; ``None`` builds
            ``CLIPVisionConfig()`` with its default settings.
        llm_config (Union[dict, CustomLlamaConfig, None]):
            Configuration of the language model. A dict is expanded via
            ``CustomLlamaConfig(**llm_config)``; ``None`` builds
            ``CustomLlamaConfig()`` with its default settings.
        **kwargs:
            Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "points_chat"
    # This config is built from nested sub-configs (vision + language).
    is_composition = True

    def __init__(self,
                 vision_config: Union[dict, CLIPVisionConfig, None] = None,
                 llm_config: Union[dict, CustomLlamaConfig, None] = None,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        # Each sub-config may arrive as a plain dict (the form produced by
        # `to_dict` below) or as an already-built config object. `None`
        # falls back to the sub-config class's defaults.
        if vision_config is None:
            vision_config = {}
        if isinstance(vision_config, dict):
            vision_config = CLIPVisionConfig(**vision_config)
        self.vision_config = vision_config

        if llm_config is None:
            llm_config = {}
        if isinstance(llm_config, dict):
            llm_config = CustomLlamaConfig(**llm_config)
        self.llm_config = llm_config

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this config to a plain dict.

        Starts from ``PretrainedConfig.to_dict`` so that the top-level fields
        the base class manages (e.g. ``model_type``) are included. It then
        replaces both nested sub-config entries with their own ``to_dict()``
        output so the result contains only plain dicts for them.

        Returns:
            Dict[str, Any]: The serialized configuration.
        """
        output = super().to_dict()
        # Always expand the sub-configs explicitly rather than relying on
        # whatever nested handling the installed transformers version does.
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        return output