| """MPT Blocks used for the MPT Model.""" | |
| import logging | |
| from copy import deepcopy | |
| from functools import partial | |
| from typing import Any, Callable, Optional, Union | |
| import torch | |
| import torch.nn as nn | |
| from .fc import FC_CLASS_REGISTRY | |
| try: | |
| import transformer_engine.pytorch as te | |
| except: | |
| te = None | |
| log = logging.getLogger(__name__) | |
| _FFN_ACT_FN_DEFAULT = {'name': 'gelu', 'approximate': 'none'} | |


def resolve_ffn_act_fn(config: Optional[dict] = None) -> Callable[[torch.Tensor], torch.Tensor]:
    """Resolve the activation function for the feed-forward network.

    Args:
        config (Optional[dict]): The configuration dictionary for the activation function.
            The dict config must specify the 'name' of a torch.nn.functional activation
            function. All other key-value pairs are bound to the function as a partial.

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: The activation function.
    """
    if config is None:
        config = _FFN_ACT_FN_DEFAULT
    config = deepcopy(config)
    name = config.pop('name')
    if not hasattr(torch.nn.functional, name):
        raise ValueError(f'Unrecognised activation function name ({name}).')
    act = getattr(torch.nn.functional, name)
    return partial(act, **config)


_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
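
# Illustrative usage sketch (comments only, not executed): resolve_ffn_act_fn turns a
# config dict into a torch.nn.functional callable, binding the remaining keys as
# keyword arguments via functools.partial, e.g.
#
#   act = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})
#   y = act(torch.randn(2, 8))  # same as torch.nn.functional.gelu(x, approximate='tanh')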


def resolve_ffn_hidden_size(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int] = None) -> int:
    """Resolve the hidden size of the feed-forward network.

    Args:
        d_model (int): The dimension of the input and output of the feed-forward network.
        expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
        ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
            If specified, `expansion_ratio` is ignored.

    Returns:
        int: The hidden size of the feed-forward network.
    """
    if ffn_hidden_size is not None:
        log.info(f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.')
    else:
        ffn_hidden_size = int(d_model * expansion_ratio)
        if ffn_hidden_size != d_model * expansion_ratio:
            raise ValueError(f'`d_model * expansion_ratio` must be an integer (d_model={d_model!r}; expansion_ratio={expansion_ratio!r}; d_model * expansion_ratio={d_model * expansion_ratio!r}).')
    return ffn_hidden_size
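
# Worked examples (illustrative, not executed):
#   resolve_ffn_hidden_size(768, 4)          -> 3072
#   resolve_ffn_hidden_size(768, 4, 2048)    -> 2048  (expansion_ratio ignored; an info message is logged)
#   resolve_ffn_hidden_size(10, 2.25)        -> raises ValueError (10 * 2.25 = 22.5 is not an integer)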


class MPTMLP(nn.Module):

    def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: str = 'torch', ffn_hidden_size: Optional[int] = None, act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN, device: Optional[str] = None, bias: bool = True):
        super().__init__()
        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
        self.fc_kwargs: dict[str, Any] = {'bias': bias}
        if fc_type != 'te':
            self.fc_kwargs['device'] = device
        self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, ffn_hidden_size, **self.fc_kwargs)
        self.act = act_fn
        self.down_proj = FC_CLASS_REGISTRY[fc_type](ffn_hidden_size, d_model, **self.fc_kwargs)
        # Marks the output projection as a residual projection for downstream (e.g., init) logic.
        self.down_proj._is_residual = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))
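
# Minimal usage sketch (comments only). Assumes FC_CLASS_REGISTRY['torch'] maps to a
# torch.nn.Linear-style layer, as the default `fc_type='torch'` suggests:
#
#   mlp = MPTMLP(d_model=256, expansion_ratio=4)   # up_proj: 256 -> 1024, down_proj: 1024 -> 256
#   out = mlp(torch.randn(2, 16, 256))             # output shape preserved: (2, 16, 256)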


class MPTGLU(MPTMLP):

    def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: str = 'torch', ffn_hidden_size: Optional[int] = None, act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN, device: Optional[str] = None, bias: bool = True):
        super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
        self.gate_proj = FC_CLASS_REGISTRY[fc_type](d_model, self.up_proj.out_features, **self.fc_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
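
# MPTGLU is the gated variant: down_proj(act(gate_proj(x)) * up_proj(x)). With a SiLU
# activation this is the SwiGLU-style FFN. Illustrative sketch (comments only):
#
#   glu = MPTGLU(d_model=256, expansion_ratio=4, act_fn=resolve_ffn_act_fn({'name': 'silu'}))
#   out = glu(torch.randn(2, 16, 256))             # -> (2, 16, 256)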


FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP, 'mptglu': MPTGLU}

if te is not None:
    te.LayerNormMLP._has_norm = True
    FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP


def build_ffn(d_model: int, expansion_ratio: Union[int, float], fc_type: str = 'torch', ffn_hidden_size: Optional[int] = None, ffn_act_fn: Optional[dict] = None, device: Optional[str] = None, bias: bool = True, **kwargs: Any) -> nn.Module:
    # `ffn_type` is required; it selects the FFN implementation from FFN_CLASS_REGISTRY.
    ffn_type = kwargs.pop('ffn_type')
    if ffn_type in ['mptmlp', 'mptglu']:
        if len(kwargs) > 0:
            raise ValueError(f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}')
        return FFN_CLASS_REGISTRY[ffn_type](d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, act_fn=resolve_ffn_act_fn(ffn_act_fn), ffn_hidden_size=ffn_hidden_size, device=device, bias=bias)
    elif ffn_type == 'te_ln_mlp':
        assert te is not None
        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
        if ffn_act_fn is not None:
            raise ValueError('Transformer Engine block does not support custom activation functions.')
        return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs)
    raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
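
# Illustrative call (comments only): `ffn_type` arrives through **kwargs and is required.
# The activation config follows the same format accepted by resolve_ffn_act_fn.
#
#   ffn = build_ffn(
#       d_model=512,
#       expansion_ratio=4,
#       ffn_type='mptglu',
#       ffn_act_fn={'name': 'silu'},
#   )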