from typing import Literal

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.mistral import MistralConfig

NVEMBED_TYPE = "nvembed"
LATENT_ATTENTION_TYPE = "latent_attention"
BIDIR_MISTRAL_TYPE = "bidir_mistral"

class NVEmbedConfig(PretrainedConfig):
    """Composite configuration for NV-Embed: a text backbone config plus a latent-attention pooling config."""

    model_type = NVEMBED_TYPE
    is_composition = False

    def __init__(
        self,
        latent_attention_config=None,
        text_config=None,
        padding_side: Literal["right", "left"] = "right",
        add_pad_token: bool = True,
        is_mask_instruction: bool = True,
        add_eos: bool = True,
        mask_type: str = "b",
        **kwargs,
    ):
        # Sub-configs may arrive as plain dicts (e.g. when loaded from config.json);
        # resolve them to config objects through CONFIG_MAPPING.
        if isinstance(latent_attention_config, dict):
            latent_attention_config.setdefault("model_type", LATENT_ATTENTION_TYPE)
            latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
        elif latent_attention_config is None:
            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
        self.latent_attention_config = latent_attention_config

        if isinstance(text_config, dict):
            text_config.setdefault("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        self.text_config = text_config

        self.padding_side = padding_side
        self.is_mask_instruction = is_mask_instruction
        self.add_pad_token = add_pad_token
        self.add_eos = add_eos
        self.mask_type = mask_type
        self.hidden_size = kwargs.get("hidden_size", 4096)

        super().__init__(**kwargs)
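
# Hedged example (added for illustration, not part of the upstream file): a sketch of
# how NVEmbedConfig resolves nested dicts into config objects. The dicts below use this
# module's registered model types, so call this only after the AutoConfig registrations
# at the bottom of the module have run (i.e. after a normal import).
def _example_nvembed_config() -> NVEmbedConfig:
    return NVEmbedConfig(
        # Resolved to BidirectionalMistralConfig via CONFIG_MAPPING["bidir_mistral"].
        text_config={"model_type": BIDIR_MISTRAL_TYPE},
        # Resolved to LatentAttentionConfig; omitted fields fall back to its defaults.
        latent_attention_config={"model_type": LATENT_ATTENTION_TYPE, "num_latents_value": 512},
        padding_side="right",
        add_eos=True,
    )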

class LatentAttentionConfig(PretrainedConfig):
    """Configuration for the latent-attention pooling block that turns token hidden states into an embedding."""

    model_type = LATENT_ATTENTION_TYPE
    is_composition = False
    _name_or_path = "latent_attention"

    def __init__(
        self,
        num_latents_value: int = 512,
        num_cross_heads: int = 8,
        output_normalize: bool = True,
        hidden_dim: int = 4096,
        latent_dim: int = 4096,
        cross_dim_head: int = 4096,
        **kwargs,
    ):
        self.num_latents_value = num_latents_value
        self.num_cross_heads = num_cross_heads
        self.output_normalize = output_normalize
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cross_dim_head = cross_dim_head

        super().__init__(**kwargs)
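
# Hedged sketch (added for illustration, not part of the upstream file): spells out what
# the fields above describe for a latent-attention pooler, i.e. a learned latent array
# cross-attending with the backbone's token hidden states. Reading cross_dim_head as the
# per-head width and output_normalize as L2-normalizing the pooled embedding are
# assumptions based on the field names; the layer itself lives in the modeling code.
def _example_latent_attention_shapes(config: LatentAttentionConfig) -> dict:
    return {
        "latent_array": (config.num_latents_value, config.latent_dim),
        "token_hidden_states": ("seq_len", config.hidden_dim),
        "num_cross_heads": config.num_cross_heads,
        "per_head_dim": config.cross_dim_head,          # assumption based on the field name
        "l2_normalize_output": config.output_normalize,  # assumption based on the field name
    }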

class BidirectionalMistralConfig(MistralConfig):
    """MistralConfig subclass tagged with the bidirectional-Mistral model type used as the NV-Embed backbone."""

    model_type = BIDIR_MISTRAL_TYPE
    keys_to_ignore_at_inference = ["past_key_values"]

# Register the custom configs with the Auto* machinery so the model_type strings in
# config.json resolve to these classes, and record auto-class references so
# save_pretrained exports the mapping for remote-code loading.
AutoConfig.register(NVEMBED_TYPE, NVEmbedConfig)
AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
AutoConfig.register(BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig)

NVEmbedConfig.register_for_auto_class()
LatentAttentionConfig.register_for_auto_class()
BidirectionalMistralConfig.register_for_auto_class()
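
# Hedged usage sketch (added for illustration, not part of the upstream file): a minimal
# serialization round trip, assuming a transformers version that serializes nested
# sub-configs inside config.json (the published NV-Embed config.json stores its
# sub-configs as nested dicts). No model weights are involved; this only exercises the
# config classes and the registrations above.
if __name__ == "__main__":
    import tempfile

    config = NVEmbedConfig(
        text_config={"model_type": BIDIR_MISTRAL_TYPE},
        latent_attention_config={"model_type": LATENT_ATTENTION_TYPE},
    )
    with tempfile.TemporaryDirectory() as tmp:
        config.save_pretrained(tmp)                 # writes config.json
        reloaded = AutoConfig.from_pretrained(tmp)  # resolves via the registered "nvembed" model_type
    print(type(reloaded).__name__, reloaded.latent_attention_config.num_latents_value)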