import math
from typing import Optional, Union

import torch
import torch.nn as nn

from .config import InitFnType, ModelConfig
from .util import StrEnum

__all__ = ["init_weights", "ModuleType"]

class ModuleType(StrEnum):
    """
    The role a module plays in the network. ``init_weights`` uses this to pick the
    standard deviation under the ``full_megatron`` initialization scheme.
    """

    in_module = "in"
    out_module = "out"
    emb = "emb"
    final_out = "final_out"

def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialize the weights of a linear or embedding module.

    :param config: The model config.
    :param module: The linear or embedding submodule to initialize.
    :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
        for fused layers.
    :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
        ``1 / sqrt(2 * (layer_id + 1))``.
    :param std_factor: An extra multiplier applied to the standard deviation under the "normal", "mitchell",
        and "fan_in" methods.
    :param type_of_module: The role the module plays in the network. Required by the "full_megatron" method,
        which uses a different standard deviation for each module type.
    """
    d = d if d is not None else config.d_model
    if config.init_fn == InitFnType.normal:
        std = config.init_std * std_factor
        if config.init_cutoff_factor is not None:
            cutoff_value = config.init_cutoff_factor * std
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
        else:
            nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.mitchell:
        std = std_factor / math.sqrt(d)
        if layer_id is not None:
            std = std / math.sqrt(2 * (layer_id + 1))
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    elif config.init_fn == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
    elif config.init_fn == InitFnType.fan_in:
        std = std_factor / math.sqrt(d)
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")

        cutoff_factor = config.init_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        if type_of_module == ModuleType.in_module:
            # Input projections use the configured init std as-is.
            std = config.init_std
        elif type_of_module == ModuleType.out_module:
            # Output projections are scaled down with depth.
            std = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == ModuleType.emb:
            # Embeddings use the configured init std as-is.
            std = config.init_std
        elif type_of_module == ModuleType.final_out:
            # The final output layer is scaled by 1 / sqrt(d_model).
            std = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")
        nn.init.trunc_normal_(
            module.weight,
            mean=0.0,
            std=std,
            a=-cutoff_factor * std,
            b=cutoff_factor * std,
        )
    else:
        raise NotImplementedError(config.init_fn)

    if isinstance(module, nn.Linear):
        if module.bias is not None:
            nn.init.zeros_(module.bias)

        # Under the "normal" scheme, residual (output) projections are additionally
        # scaled down by sqrt(2 * n_layers).
        if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(math.sqrt(2 * config.n_layers))
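

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by the library). It assumes
# ``ModelConfig`` can be constructed directly with keyword arguments for the
# fields read above (``d_model``, ``n_layers``, ``init_fn``, ``init_std``);
# the concrete sizes are arbitrary. Run it as a module (``python -m ...``) so
# the relative imports resolve.
if __name__ == "__main__":
    config = ModelConfig(d_model=1024, n_layers=16, init_fn=InitFnType.mitchell, init_std=0.02)

    # A fused QKV-style projection: pass the effective fan-in via ``d`` and let
    # ``layer_id`` shrink the standard deviation with depth.
    att_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=False)
    init_weights(config, att_proj, d=config.d_model, layer_id=0, type_of_module=ModuleType.in_module)

    # An embedding table initialized with the same config.
    wte = nn.Embedding(50_304, config.d_model)
    init_weights(config, wte, type_of_module=ModuleType.emb)

    print(att_proj.weight.std().item(), wte.weight.std().item())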