from transformers.models.llama.configuration_llama import LlamaConfig


class LlamaMlaConfig(LlamaConfig):
    """Configuration for a Llama model that uses Multi-head Latent Attention (MLA).

    Extends `LlamaConfig` with the low-rank query/KV compression ranks and the
    decoupled RoPE head dimensions introduced by MLA (DeepSeek-V2). The
    effective query/key head dimension is `qk_nope_head_dim + qk_rope_head_dim`.
    """

    model_type = "llama_mla"

    # MLA replaces the standard q/k/v projections, so the parallelism plans
    # inherited from `LlamaConfig` no longer match the module names; reset them.
    base_model_pp_plan = None
    base_model_tp_plan = None
    def __init__(
        self,
        kv_lora_rank: int = 512,
        q_lora_rank: int = 1536,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
        qk_nope_head_dim: int = 128,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.kv_lora_rank = kv_lora_rank  # rank of the latent KV compression
        self.q_lora_rank = q_lora_rank  # rank of the query compression
        self.qk_rope_head_dim = qk_rope_head_dim  # per-head dim carrying RoPE
        self.v_head_dim = v_head_dim  # per-head value dimension
        self.qk_nope_head_dim = qk_nope_head_dim  # per-head dim without RoPE


__all__ = [
    "LlamaMlaConfig",
]
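

# A minimal usage sketch, assuming this file lives in a package alongside a
# matching modeling class. `AutoConfig.register` is the standard transformers
# hook for resolving a custom `model_type`; the kwargs passed below are
# ordinary `LlamaConfig` fields and are purely illustrative.
if __name__ == "__main__":
    from transformers import AutoConfig

    # Let AutoConfig.from_pretrained resolve checkpoints whose config declares
    # `"model_type": "llama_mla"`.
    AutoConfig.register("llama_mla", LlamaMlaConfig)

    config = LlamaMlaConfig(hidden_size=2048, num_hidden_layers=16)
    print(config.model_type, config.kv_lora_rank, config.qk_rope_head_dim)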