| from transformers import PretrainedConfig | |
| import torch | |
| class XLMRobertaFlashConfig(PretrainedConfig): | |
| def __init__( | |
| self, | |
| vocab_size=30522, | |
| hidden_size=768, | |
| num_hidden_layers=12, | |
| num_attention_heads=12, | |
| intermediate_size=3072, | |
| hidden_act="gelu", | |
| hidden_dropout_prob=0.1, | |
| attention_probs_dropout_prob=0.1, | |
| max_position_embeddings=512, | |
| type_vocab_size=2, | |
| initializer_range=0.02, | |
| layer_norm_eps=1e-12, | |
| pad_token_id=1, | |
| bos_token_id=0, | |
| eos_token_id=2, | |
| position_embedding_type="absolute", | |
| use_cache=True, | |
| classifier_dropout=None, | |
| lora_adaptations=None, | |
| lora_rank=4, | |
| lora_dropout_p=0.0, | |
| lora_alpha=1, | |
| lora_main_params_trainable=False, | |
| load_trained_adapters=False, | |
| use_flash_attn=True, | |
| torch_dtype=None, | |
| emb_pooler=None, | |
| matryoshka_dimensions=None, | |
| truncate_dim=None, | |
| **kwargs, | |
| ): | |
| super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) | |
| self.vocab_size = vocab_size | |
| self.hidden_size = hidden_size | |
| self.num_hidden_layers = num_hidden_layers | |
| self.num_attention_heads = num_attention_heads | |
| self.hidden_act = hidden_act | |
| self.intermediate_size = intermediate_size | |
| self.hidden_dropout_prob = hidden_dropout_prob | |
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | |
| self.max_position_embeddings = max_position_embeddings | |
| self.type_vocab_size = type_vocab_size | |
| self.initializer_range = initializer_range | |
| self.layer_norm_eps = layer_norm_eps | |
| self.position_embedding_type = position_embedding_type | |
| self.use_cache = use_cache | |
| self.classifier_dropout = classifier_dropout | |
| self.load_trained_adapters = load_trained_adapters | |
| self.lora_adaptations = lora_adaptations | |
| self.lora_rank = lora_rank | |
| self.lora_dropout_p = lora_dropout_p | |
| self.lora_alpha = lora_alpha | |
| self.lora_main_params_trainable = lora_main_params_trainable | |
| self.use_flash_attn = use_flash_attn | |
| self.emb_pooler = emb_pooler | |
| self.matryoshka_dimensions = matryoshka_dimensions | |
| self.truncate_dim = truncate_dim | |
| if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype: | |
| self.torch_dtype = getattr(torch, torch_dtype) | |
| else: | |
| self.torch_dtype = torch_dtype | |