|
import copy |
|
from typing import Any, Dict, Union |
|
|
|
from transformers import CLIPVisionConfig, PretrainedConfig |
|
|
|
from .configuration_llama import CustomLlamaConfig |
|
|
|
|
|
class POINTSChatConfig(PretrainedConfig):
    """Configuration class for `POINTS`.

    Holds the configuration of the vision model and of the language model
    together in a single config object.

    Args:
        vision_config (Union[dict, CLIPVisionConfig]):
            Configuration of the vision model. A dict is expanded into a
            `CLIPVisionConfig`; any other value is stored as given.
        llm_config (Union[dict, CustomLlamaConfig]):
            Configuration of the language model. A dict is expanded into a
            `CustomLlamaConfig`; any other value is stored as given.
        **kwargs:
            Forwarded unchanged to `PretrainedConfig.__init__`.
    """

    model_type = "points_chat"
    # Flag read by the `PretrainedConfig` base class; the transformers docs
    # describe it as marking a config that is composed of sub-configs.
    is_composition = True

    def __init__(self,
                 vision_config: Union[dict, CLIPVisionConfig],
                 llm_config: Union[dict, CustomLlamaConfig],
                 **kwargs) -> None:
        """Store both sub-configs, building them from dicts when needed."""
        super().__init__(**kwargs)

        # Dicts (e.g. parsed from JSON) are turned into config objects so the
        # attributes always expose `to_dict()`, which `to_dict` below relies on.
        if isinstance(vision_config, dict):
            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = vision_config

        if isinstance(llm_config, dict):
            self.llm_config = CustomLlamaConfig(**llm_config)
        else:
            self.llm_config = llm_config

    def to_dict(self) -> Dict[str, Any]:
        """Return this configuration as a dictionary.

        The result is a deep copy of the instance attributes in which the
        `vision_config` and `llm_config` entries are replaced by the output
        of each sub-config's own `to_dict()`.

        Returns:
            Dict[str, Any]: Dictionary built from the instance attributes.
        """
        output = copy.deepcopy(self.__dict__)
        # Replace the copied sub-config objects with their dict form.
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        return output
|
|