| """GPT Blocks used for the GPT Model.""" | |
| from typing import Any, Dict, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| from .attention import ATTN_CLASS_REGISTRY | |
| from .ffn import FFN_CLASS_REGISTRY, build_ffn | |
| from .norm import NORM_CLASS_REGISTRY | |
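
# flash-attn's padding helpers are optional; they are only required in forward()
# when use_pad_tok_in_ffn is False.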
try:
    from flash_attn.bert_padding import unpad_input, pad_input
except ImportError:
    unpad_input, pad_input = None, None
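
# Default attention settings used when MPTBlock is constructed without an
# explicit attn_config.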
attn_config_defaults: Dict = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "flash",
    "qk_ln": True,
    "qk_gn": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "sliding_window_size": -1,
    "alibi": False,
    "alibi_bias_max": 8,
    "rope": False,
    "rope_theta": 10000,
    "rope_impl": "dail",
    "rope_dail_config": {
        "type": "original",
        "pos_idx_in_fp32": True,
        "xpos_scale_base": 512,
    },
    "rope_hf_config": {"type": "no_scaling", "factor": 1.0},
}


class MPTBlock(nn.Module):
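    """A single pre-norm transformer block used by the MPT GPT model.

    The block applies layer norm -> attention -> residual dropout -> residual
    add, followed by (optional) layer norm -> FFN -> residual dropout ->
    residual add. ``norm_2`` is skipped when the configured FFN class applies
    its own normalization.

    Illustrative sketch (assumes the registries expose "multihead_attention",
    "mptmlp" and "low_precision_layernorm", and that a non-flash ``attn_impl``
    such as "torch" is available so no flash-attn padding info is needed)::

        block = MPTBlock(
            d_model=512,
            n_heads=8,
            expansion_ratio=4,
            attn_config={**attn_config_defaults, "attn_impl": "torch"},
        )
        x = torch.randn(2, 16, 512)
        y, attn_weights, past_kv = block(x)
    """
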
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        ffn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        fc_type: str = "torch",
        device: Optional[str] = None,
        no_bias: bool = False,
        use_pad_tok_in_ffn: bool = True,
        **kwargs: Any,
    ):
        if attn_config is None:
            attn_config = attn_config_defaults
        if ffn_config is None:
            ffn_config = {"ffn_type": "mptmlp"}
        del kwargs
        super().__init__()
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        assert isinstance(attn_config["attn_type"], str)
        attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
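        # These keys are handled at the block / model level (position-embedding
        # and masking choices), so they are not forwarded to the attention class.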
        args_to_exclude_in_attn_class = {
            "attn_type",
            "prefix_lm",
            "alibi",
            "attn_uses_sequence_id",
            "alibi_bias_max",
            "rope",
            "rope_theta",
            "rope_impl",
            "rope_dail_config",
            "rope_hf_config",
        }
        attn_config_subset_for_attn_class = {
            k: v
            for k, v in attn_config.items()
            if k not in args_to_exclude_in_attn_class
        }
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            d_model=d_model,
            n_heads=n_heads,
            fc_type=fc_type,
            device=device,
            **attn_config_subset_for_attn_class,
            bias=not no_bias,
        )
        # FFN classes that apply their own normalization skip the block-level norm_2.
        self.norm_2 = None
        if not getattr(FFN_CLASS_REGISTRY[ffn_config["ffn_type"]], "_has_norm", False):
            self.norm_2 = norm_class(d_model, device=device)
        self.ffn = build_ffn(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
            bias=not no_bias,
            **ffn_config,
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        rotary_emb_w_meta_info: Optional[Dict] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        output_attentions: bool = False,
        alibi_slopes: Optional[torch.Tensor] = None,
        flash_attn_padding_info: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[Tuple[torch.Tensor, torch.Tensor]],
    ]:
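        """Apply the block: pre-norm attention with a residual connection,
        followed by the (optionally pre-normed) feed-forward network with its
        own residual connection. Returns the hidden states, optional attention
        weights, and the updated key/value cache."""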
        a = self.norm_1(x)
        b, attn_weights, past_key_value = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            rotary_emb_w_meta_info=rotary_emb_w_meta_info,
            attention_mask=attention_mask,
            is_causal=is_causal,
            needs_weights=output_attentions,
            alibi_slopes=alibi_slopes,
            flash_attn_padding_info=flash_attn_padding_info,
        )
        x = x + self.resid_attn_dropout(b)
        m = x
        if self.norm_2 is not None:
            m = self.norm_2(x)
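        # Optionally drop padded tokens before the FFN (to skip wasted compute)
        # and re-pad afterwards so the residual addition sees the original shape.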
        batch_size, seq_len = m.size()[:2]
        indices = None
        if not self.use_pad_tok_in_ffn:
            assert unpad_input is not None
            m, indices, _, _ = unpad_input(m, attention_mask)
        n = self.ffn(m)
        if not self.use_pad_tok_in_ffn:
            assert pad_input is not None
            n = pad_input(n, indices, batch_size, seq_len)
        x = x + self.resid_ffn_dropout(n)
        return x, attn_weights, past_key_value