from dataclasses import dataclass
from typing import Callable, Optional

import torch
import torch.nn.functional as F

from utils import init_method_normal, scaled_init_method_normal


@dataclass
class MambaConfig:
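    """Configuration options for a Mamba-based model.

    `__post_init__` fills in derived values (ffn_hidden_size and the
    weight-initialization callables) whenever they are left unset.
    """
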
    base_model_type: str = "mamba"
    num_layers: int = 0
    hidden_size: int = 0
    state_size: int = 0
    vocab_size: int = 50000
    expansion_factor: int = 2
    conv_dimension: int = 0
    conv_bias: bool = True
    bias: bool = True
    use_fast_path: bool = True
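    # Parameters of the Mamba dt (discretization step) projection and its initialization.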
    dt_rank: str = "auto"
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random"
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-4
    rms_norm: bool = True
    fused_add_norm: bool = False
    residual_in_fp32: bool = True
    hidden_dropout: float = 0.0
    ffn_hidden_size: Optional[int] = None
    gated_linear_unit: bool = False
    mamba_moe_layers: str = ""
    routing_mode: str = "sinkhorn"
    device: str = "cuda"
    fp32_residual_connection: bool = False
    layernorm_epsilon: float = 1e-5
    layernorm_zero_centered_gamma: bool = False
    add_bias_linear: bool = True
    activation_func: Callable = F.gelu
    num_moe_experts: Optional[int] = None

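    # Weight-initialization settings; the callables are derived in __post_init__
    # when left as None.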
    init_method: Optional[Callable] = None
    output_layer_init_method: Optional[Callable] = None
    init_method_std: float = 0.02

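    # Attention softmax numerics (fp32 softmax is forced when query/key layer
    # scaling is applied; see __post_init__).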
    apply_query_key_layer_scaling: bool = True
    attention_softmax_in_fp32: bool = True

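    # Fused-kernel options.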
    bias_gelu_fusion: bool = False
    persist_layer_norm: bool = False
    bias_dropout_fusion: bool = False


    def __post_init__(self):
        """Python dataclass method used to modify attributes after initialization.

        See https://docs.python.org/3/library/dataclasses.html#post-init-processing
        for more details.
        """
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True

        if self.ffn_hidden_size is None:
            self.ffn_hidden_size = 4 * self.hidden_size

        if self.bias_gelu_fusion:
            if not self.add_bias_linear:
                raise ValueError(
                    "When bias_gelu_fusion is True, add_bias_linear must also be True."
                )

            if self.activation_func != F.gelu:
                raise ValueError("When bias_gelu_fusion is True, activation_func must be F.gelu.")

        if self.init_method is None:
            self.init_method = init_method_normal(self.init_method_std)

        if self.output_layer_init_method is None:
            self.output_layer_init_method = scaled_init_method_normal(
                self.init_method_std, self.num_layers
            )
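

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the sizes below are arbitrary
    # placeholder values, not settings taken from any particular model.
    config = MambaConfig(
        num_layers=24,
        hidden_size=1024,
        state_size=16,
        conv_dimension=4,
    )
    # __post_init__ derives ffn_hidden_size (4 * hidden_size) and the
    # weight-initialization callables because they were left unset.
    assert config.ffn_hidden_size == 4 * config.hidden_size
    print(config)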