Commit b0e5659
Parent(s): cac27b0

add int8 model inference

Files changed:
- README.md +2 -3
- attention.py +61 -37
- blocks.py +4 -4
- configuration_mpt.py +1 -1
- modeling_mpt.py +31 -14
README.md CHANGED

@@ -44,7 +44,7 @@ The following hyperparameters were used during training:
 ```shell
 import transformers
 model = transformers.AutoModelForCausalLM.from_pretrained(
-  'Intel/neural-chat-7b-v1',
+  'Intel/neural-chat-7b-v1-1',
   trust_remote_code=True
 )
 ```
@@ -54,8 +54,7 @@ Follow the instructions [link](https://github.com/intel/intel-extension-for-trans

 ```shell
 python run_generation.py \
-    --model Intel/neural-chat-7b-v1 \
-    --revision c8d4750ac8421303665d6ecc253950c69b56d324 \
+    --model Intel/neural-chat-7b-v1-1 \
     --quantize \
     --sq \
     --alpha 0.95 \
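For orientation only (not part of the diff): a minimal generation sketch against the renamed checkpoint. It assumes the repo also ships a tokenizer under the same id and uses the standard transformers API; the prompt and generation settings are placeholders.

```python
import transformers

# Assumptions: tokenizer is available under the same repo id; prompt is arbitrary.
name = 'Intel/neural-chat-7b-v1-1'
tokenizer = transformers.AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = transformers.AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)

inputs = tokenizer('Once upon a time', return_tensors='pt')
# use_cache=True exercises the past_key_value handling reworked in attention.py below.
outputs = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```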
attention.py CHANGED

@@ -5,6 +5,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange
+from packaging import version
 from torch import nn
 from .norm import LPLayerNorm

@@ -16,25 +17,34 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
             return False
     return original_is_causal

-def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
-    k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
-    v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
-    min_val = torch.finfo(q.dtype).min
+    kv_n_heads = 1 if multiquery else n_heads
+    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
+    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            k = torch.cat([past_key_value[0], k], dim=3)
+            v = torch.cat([past_key_value[1], v], dim=2)
+        past_key_value = (k, v)
     (b, _, s_q, d) = q.shape
     s_k = k.size(-1)
     if softmax_scale is None:
         softmax_scale = 1 / math.sqrt(d)
     attn_weight = q.matmul(k) * softmax_scale
     if attn_bias is not None:
+        _s_q = max(0, attn_bias.size(2) - s_q)
+        _s_k = max(0, attn_bias.size(3) - s_k)
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
         if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
             raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
         attn_weight = attn_weight + attn_bias
+    min_val = torch.finfo(q.dtype).min
     if key_padding_mask is not None:
         if attn_bias is not None:
             warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
         attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
-    if is_causal:
+    if is_causal and (not q.size(2) == 1):
         s = max(s_q, s_k)
         causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
         causal_mask = causal_mask.tril()
@@ -48,8 +58,8 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
     out = attn_weight.matmul(v)
     out = rearrange(out, 'b h s d -> b s (h d)')
     if needs_weights:
-        return (out, attn_weight)
-    return (out, None)
+        return (out, attn_weight, past_key_value)
+    return (out, None, past_key_value)

 def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
     for tensor in tensors:
@@ -58,12 +68,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
         if not tensor.is_cuda:
             raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')

-def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
         from flash_attn import bert_padding, flash_attn_interface
     except:
         raise RuntimeError('Please install flash-attn==1.0.3.post0')
     check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+        past_key_value = (key, value)
+    if attn_bias is not None:
+        _s_q = max(0, attn_bias.size(2) - query.size(1))
+        _s_k = max(0, attn_bias.size(3) - key.size(1))
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
     if attn_bias is not None:
         raise NotImplementedError(f'attn_bias not implemented for flash attn.')
     (batch_size, seqlen) = query.shape[:2]
@@ -83,14 +102,31 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
     output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
-    return (output, None)
+    return (output, None, past_key_value)

-def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
-        from flash_attn import flash_attn_triton
+        from .flash_attn_triton import flash_attn_func
     except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
+        _installed = False
+        if version.parse(torch.__version__) < version.parse('2.0.0'):
+            _installed = True
+            try:
+                from flash_attn.flash_attn_triton import flash_attn_func
+            except:
+                _installed = False
+        if not _installed:
+            raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
     check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+        past_key_value = (key, value)
+    if attn_bias is not None:
+        _s_q = max(0, attn_bias.size(2) - query.size(1))
+        _s_k = max(0, attn_bias.size(3) - key.size(1))
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
     if dropout_p:
         raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
     if needs_weights:
@@ -108,9 +144,9 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
         key = key.expand(*key.shape[:2], n_heads, key.size(-1))
         value = value.expand(*value.shape[:2], n_heads, value.size(-1))
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
+    attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
     output = attn_output.view(*attn_output.shape[:2], -1)
-    return (output, None)
+    return (output, None, past_key_value)

 class MultiheadAttention(nn.Module):
     """Multi-head self attention.
@@ -119,7 +155,7 @@ class MultiheadAttention(nn.Module):
     additive bias.
     """

-    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
+    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
         super().__init__()
         self.attn_impl = attn_impl
         self.clip_qkv = clip_qkv
@@ -141,10 +177,11 @@ class MultiheadAttention(nn.Module):
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == 'triton':
             self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
+            if verbose:
+                warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
         elif self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
-            if torch.cuda.is_available():
+            if torch.cuda.is_available() and verbose:
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
         else:
             raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -161,14 +198,7 @@ class MultiheadAttention(nn.Module):
             dtype = query.dtype
             query = self.q_ln(query).to(dtype)
             key = self.k_ln(key).to(dtype)
-        if past_key_value is not None:
-            if len(past_key_value) != 0:
-                key = torch.cat([past_key_value[0], key], dim=1)
-                value = torch.cat([past_key_value[1], value], dim=1)
-            past_key_value = (key, value)
-        if attn_bias is not None:
-            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-        (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
+        (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
         return (self.out_proj(context), attn_weights, past_key_value)

 class MultiQueryAttention(nn.Module):
@@ -178,7 +208,7 @@ class MultiQueryAttention(nn.Module):
     additive bias.
     """

-    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
+    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
         super().__init__()
         self.attn_impl = attn_impl
         self.clip_qkv = clip_qkv
@@ -201,10 +231,11 @@ class MultiQueryAttention(nn.Module):
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == 'triton':
             self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
+            if verbose:
+                warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
         elif self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
-            if torch.cuda.is_available():
+            if torch.cuda.is_available() and verbose:
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
         else:
             raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -221,14 +252,7 @@ class MultiQueryAttention(nn.Module):
             dtype = query.dtype
             query = self.q_ln(query).to(dtype)
             key = self.k_ln(key).to(dtype)
-        if past_key_value is not None:
-            if len(past_key_value) != 0:
-                key = torch.cat([past_key_value[0], key], dim=1)
-                value = torch.cat([past_key_value[1], value], dim=1)
-            past_key_value = (key, value)
-        if attn_bias is not None:
-            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-        (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
+        (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
         return (self.out_proj(context), attn_weights, past_key_value)

 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
@@ -273,4 +297,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None
     slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
     alibi_bias = alibi_bias * slopes
     return alibi_bias.to(dtype=dtype)
-ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
+ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
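To make the new cache layout above concrete, here is a small shape sketch (not from the commit) of how scaled_multihead_dot_product_attention now builds and extends past_key_value: keys are kept as (b, h, d, s) and grown on dim=3, values as (b, h, s, d) and grown on dim=2. The concrete sizes are arbitrary.

```python
import torch
from einops import rearrange

# Arbitrary sizes for illustration only.
b, n_heads, d_head, s_past, s_new = 2, 4, 8, 5, 1

# One new token's key/value, laid out the way the torch attention path expects.
key = torch.randn(b, s_new, n_heads * d_head)
value = torch.randn(b, s_new, n_heads * d_head)
k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads)
v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)

# A cache produced by an earlier step, in the same layout.
past_k = torch.randn(b, n_heads, d_head, s_past)
past_v = torch.randn(b, n_heads, s_past, d_head)

# Exactly the concatenation the new code performs: keys on dim=3, values on dim=2.
k = torch.cat([past_k, k], dim=3)
v = torch.cat([past_v, v], dim=2)

# modeling_mpt.py now reads the cached length from size(3) when attn_impl == 'torch'.
print(k.size(3), v.size(2))  # 6 6
```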
blocks.py CHANGED

@@ -19,13 +19,13 @@ class MPTMLP(nn.Module):

 class MPTBlock(nn.Module):

-    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
+    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
         del kwargs
         super().__init__()
         norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
         attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
         self.norm_1 = norm_class(d_model, device=device)
-        self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
+        self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
         self.norm_2 = norm_class(d_model, device=device)
         self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@@ -33,9 +33,9 @@ class MPTBlock(nn.Module):

     def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
         a = self.norm_1(x)
-        (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
+        (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
         x = x + self.resid_attn_dropout(b)
         m = self.norm_2(x)
         n = self.ffn(m)
         x = x + self.resid_ffn_dropout(n)
-        return (x, past_key_value)
+        return (x, attn_weights, past_key_value)
configuration_mpt.py CHANGED

@@ -2,7 +2,7 @@
 from typing import Dict, Optional, Union
 from transformers import PretrainedConfig
 attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
-init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
+init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}

 class MPTConfig(PretrainedConfig):
     model_type = 'mpt'
        modeling_mpt.py
    CHANGED
    
    | @@ -18,11 +18,16 @@ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising | |
| 18 | 
             
            from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
         | 
| 19 | 
             
            from .meta_init_context import init_empty_weights
         | 
| 20 | 
             
            from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
         | 
|  | |
|  | |
|  | |
|  | |
| 21 | 
             
            Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
         | 
| 22 |  | 
| 23 | 
             
            class MPTPreTrainedModel(PreTrainedModel):
         | 
| 24 | 
             
                config_class = MPTConfig
         | 
| 25 | 
             
                base_model_prefix = 'model'
         | 
|  | |
| 26 |  | 
| 27 | 
             
            class MPTModel(MPTPreTrainedModel):
         | 
| 28 |  | 
| @@ -46,6 +51,7 @@ class MPTModel(MPTPreTrainedModel): | |
| 46 | 
             
                    self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
         | 
| 47 | 
             
                    self.norm_f = norm_class(config.d_model, device=config.init_device)
         | 
| 48 | 
             
                    if config.init_device != 'meta':
         | 
|  | |
| 49 | 
             
                        self.apply(self.param_init_fn)
         | 
| 50 | 
             
                    self.is_causal = not self.prefix_lm
         | 
| 51 | 
             
                    self._attn_bias_initialized = False
         | 
| @@ -95,7 +101,8 @@ class MPTModel(MPTPreTrainedModel): | |
| 95 | 
             
                        if attn_bias is None:
         | 
| 96 | 
             
                            attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
         | 
| 97 | 
             
                        else:
         | 
| 98 | 
            -
                             | 
|  | |
| 99 | 
             
                        if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
         | 
| 100 | 
             
                            raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
         | 
| 101 | 
             
                        min_val = torch.finfo(attn_bias.dtype).min
         | 
| @@ -134,10 +141,11 @@ class MPTModel(MPTPreTrainedModel): | |
| 134 | 
             
                        attention_mask = attention_mask.bool()
         | 
| 135 | 
             
                    if prefix_mask is not None:
         | 
| 136 | 
             
                        prefix_mask = prefix_mask.bool()
         | 
| 137 | 
            -
                     | 
| 138 | 
            -
             | 
| 139 | 
             
                    if output_attentions:
         | 
| 140 | 
            -
                         | 
|  | |
| 141 | 
             
                    if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
         | 
| 142 | 
             
                        raise NotImplementedError('MPT does not support training with left padding.')
         | 
| 143 | 
             
                    if self.prefix_lm and prefix_mask is None:
         | 
| @@ -158,6 +166,8 @@ class MPTModel(MPTPreTrainedModel): | |
| 158 | 
             
                            if len(past_key_values) != self.config.n_layers:
         | 
| 159 | 
             
                                raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
         | 
| 160 | 
             
                            past_position = past_key_values[0][0].size(1)
         | 
|  | |
|  | |
| 161 | 
             
                        if S + past_position > self.config.max_seq_len:
         | 
| 162 | 
             
                            raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
         | 
| 163 | 
             
                        pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
         | 
| @@ -175,19 +185,26 @@ class MPTModel(MPTPreTrainedModel): | |
| 175 | 
             
                    if use_cache and past_key_values is None:
         | 
| 176 | 
             
                        past_key_values = [() for _ in range(self.config.n_layers)]
         | 
| 177 | 
             
                    all_hidden_states = () if output_hidden_states else None
         | 
|  | |
| 178 | 
             
                    for (b_idx, block) in enumerate(self.blocks):
         | 
| 179 | 
             
                        if output_hidden_states:
         | 
| 180 | 
             
                            assert all_hidden_states is not None
         | 
| 181 | 
             
                            all_hidden_states = all_hidden_states + (x,)
         | 
| 182 | 
             
                        past_key_value = past_key_values[b_idx] if past_key_values is not None else None
         | 
| 183 | 
            -
                        (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
         | 
| 184 | 
             
                        if past_key_values is not None:
         | 
| 185 | 
             
                            past_key_values[b_idx] = past_key_value
         | 
|  | |
|  | |
|  | |
| 186 | 
             
                    x = self.norm_f(x)
         | 
| 187 | 
            -
                    if  | 
|  | |
|  | |
|  | |
| 188 | 
             
                        output = (x,) + (tuple(past_key_values),)
         | 
| 189 | 
             
                        return output
         | 
| 190 | 
            -
                    return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
         | 
| 191 |  | 
| 192 | 
             
                def param_init_fn(self, module):
         | 
| 193 | 
             
                    init_fn_name = self.config.init_config['name']
         | 
| @@ -237,11 +254,12 @@ class MPTForCausalLM(MPTPreTrainedModel): | |
| 237 | 
             
                def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
         | 
| 238 | 
             
                    return_dict = return_dict if return_dict is not None else self.config.return_dict
         | 
| 239 | 
             
                    use_cache = use_cache if use_cache is not None else self.config.use_cache
         | 
| 240 | 
            -
                    
         | 
| 241 | 
             
                    past_key_values = list(past_key_values) if past_key_values is not None else None
         | 
| 242 | 
            -
                    
         | 
| 243 | 
             
                    outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
         | 
| 244 | 
            -
                     | 
|  | |
|  | |
|  | |
| 245 | 
             
                    if self.logit_scale is not None:
         | 
| 246 | 
             
                        if self.logit_scale == 0:
         | 
| 247 | 
             
                            warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
         | 
| @@ -251,11 +269,10 @@ class MPTForCausalLM(MPTPreTrainedModel): | |
| 251 | 
             
                        labels = torch.roll(labels, shifts=-1)
         | 
| 252 | 
             
                        labels[:, -1] = -100
         | 
| 253 | 
             
                        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
         | 
| 254 | 
            -
             | 
| 255 | 
            -
                    if not return_dict:
         | 
| 256 | 
             
                        output = (logits,) + (tuple(outputs[1]),)
         | 
| 257 | 
             
                        return (loss,) + output if loss is not None else output
         | 
| 258 | 
            -
                    return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
         | 
| 259 |  | 
| 260 | 
             
                def param_init_fn(self, module):
         | 
| 261 | 
             
                    init_fn_name = self.config.init_config['name']
         | 
| @@ -297,4 +314,4 @@ class MPTForCausalLM(MPTPreTrainedModel): | |
| 297 | 
             
                    reordered_past = []
         | 
| 298 | 
             
                    for layer_past in past_key_values:
         | 
| 299 | 
             
                        reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
         | 
| 300 | 
            -
                    return reordered_past
         | 
|  | |
| 18 | 
             
            from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
         | 
| 19 | 
             
            from .meta_init_context import init_empty_weights
         | 
| 20 | 
             
            from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
         | 
| 21 | 
            +
            try:
         | 
| 22 | 
            +
                from .flash_attn_triton import flash_attn_func
         | 
| 23 | 
            +
            except:
         | 
| 24 | 
            +
                pass
         | 
| 25 | 
             
            Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
         | 
| 26 |  | 
| 27 | 
             
            class MPTPreTrainedModel(PreTrainedModel):
         | 
| 28 | 
             
                config_class = MPTConfig
         | 
| 29 | 
             
                base_model_prefix = 'model'
         | 
| 30 | 
            +
                _no_split_modules = ['MPTBlock']
         | 
| 31 |  | 
| 32 | 
             
            class MPTModel(MPTPreTrainedModel):
         | 
| 33 |  | 
|  | |
| 51 | 
             
                    self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
         | 
| 52 | 
             
                    self.norm_f = norm_class(config.d_model, device=config.init_device)
         | 
| 53 | 
             
                    if config.init_device != 'meta':
         | 
| 54 | 
            +
                        print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
         | 
| 55 | 
             
                        self.apply(self.param_init_fn)
         | 
| 56 | 
             
                    self.is_causal = not self.prefix_lm
         | 
| 57 | 
             
                    self._attn_bias_initialized = False
         | 
|  | |
| 101 | 
             
                        if attn_bias is None:
         | 
| 102 | 
             
                            attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
         | 
| 103 | 
             
                        else:
         | 
| 104 | 
            +
                            _s_k = max(0, attn_bias.size(-1) - s_k)
         | 
| 105 | 
            +
                            attn_bias = attn_bias[:, :, :, _s_k:]
         | 
| 106 | 
             
                        if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
         | 
| 107 | 
             
                            raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
         | 
| 108 | 
             
                        min_val = torch.finfo(attn_bias.dtype).min
         | 
|  | |
| 141 | 
             
                        attention_mask = attention_mask.bool()
         | 
| 142 | 
             
                    if prefix_mask is not None:
         | 
| 143 | 
             
                        prefix_mask = prefix_mask.bool()
         | 
| 144 | 
            +
                    if not return_dict:
         | 
| 145 | 
            +
                        raise NotImplementedError('return_dict False is not implemented yet for MPT')
         | 
| 146 | 
             
                    if output_attentions:
         | 
| 147 | 
            +
                        if self.attn_impl != 'torch':
         | 
| 148 | 
            +
                            raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
         | 
| 149 | 
             
                    if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
         | 
| 150 | 
             
                        raise NotImplementedError('MPT does not support training with left padding.')
         | 
| 151 | 
             
                    if self.prefix_lm and prefix_mask is None:
         | 
|  | |
| 166 |                   if len(past_key_values) != self.config.n_layers:
| 167 |                       raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
| 168 |                   past_position = past_key_values[0][0].size(1)
| 169 | +                 if self.attn_impl == 'torch':
| 170 | +                     past_position = past_key_values[0][0].size(3)
| 171 |               if S + past_position > self.config.max_seq_len:
| 172 |                   raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
| 173 |               pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
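The added branch reflects that the per-layer key cache is laid out differently per attention implementation: the `'flash'`/`'triton'` paths keep the cached key as `(batch, seq, d_model)`, while the `'torch'` path stores it already transposed with the sequence length last, so the past length has to be read from `size(3)`. A toy illustration of reading the past length from each layout (shapes are illustrative, based on this repo's attention code):

```python
# Toy shapes only: where the past sequence length lives in each cache layout.
import torch

b, s_past, n_heads, d_head = 2, 6, 4, 8

k_flash = torch.zeros(b, s_past, n_heads * d_head)   # (batch, seq, d_model)
print(k_flash.size(1))                               # 6 -> past_position ('flash'/'triton')

k_torch = torch.zeros(b, n_heads, d_head, s_past)    # (batch, heads, head_dim, seq)
print(k_torch.size(3))                               # 6 -> past_position ('torch')
```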
|  | |
| 185 |           if use_cache and past_key_values is None:
| 186 |               past_key_values = [() for _ in range(self.config.n_layers)]
| 187 |           all_hidden_states = () if output_hidden_states else None
| 188 | +         all_self_attns = () if output_attentions else None
| 189 |           for (b_idx, block) in enumerate(self.blocks):
| 190 |               if output_hidden_states:
| 191 |                   assert all_hidden_states is not None
| 192 |                   all_hidden_states = all_hidden_states + (x,)
| 193 |               past_key_value = past_key_values[b_idx] if past_key_values is not None else None
| 194 | +             (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
| 195 |               if past_key_values is not None:
| 196 |                   past_key_values[b_idx] = past_key_value
| 197 | +             if output_attentions:
| 198 | +                 assert all_self_attns is not None
| 199 | +                 all_self_attns = all_self_attns + (attn_weights,)
| 200 |           x = self.norm_f(x)
| 201 | +         if output_hidden_states:
| 202 | +             assert all_hidden_states is not None
| 203 | +             all_hidden_states = all_hidden_states + (x,)
| 204 | +         if self.config.torchscript:
| 205 |               output = (x,) + (tuple(past_key_values),)
| 206 |               return output
| 207 | +         return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
| 208 |
| 209 |       def param_init_fn(self, module):
| 210 |           init_fn_name = self.config.init_config['name']
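With `config.torchscript` set, the transformer now returns a plain `(hidden_states, past_key_values)` tuple instead of a `BaseModelOutputWithPast`, which is what graph capture for TorchScript/INT8 inference expects. A hedged sketch of handling both return forms (`model` and `input_ids` are assumed to exist):

```python
# Sketch only: unpack the transformer output in either configuration.
outputs = model.transformer(input_ids=input_ids)
if model.config.torchscript:
    last_hidden_state, past_key_values = outputs       # plain tuple
else:
    last_hidden_state = outputs.last_hidden_state      # BaseModelOutputWithPast
    past_key_values = outputs.past_key_values
```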
|  | |
| 254 |       def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
| 255 |           return_dict = return_dict if return_dict is not None else self.config.return_dict
| 256 |           use_cache = use_cache if use_cache is not None else self.config.use_cache
|  | |
| 257 |           past_key_values = list(past_key_values) if past_key_values is not None else None
|  | |
| 258 |           outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
| 259 | +         if self.config.torchscript:
| 260 | +             logits = F.linear(outputs[0].to(self.transformer.wte.weight.device), self.transformer.wte.weight)
| 261 | +         else:
| 262 | +             logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
| 263 |           if self.logit_scale is not None:
| 264 |               if self.logit_scale == 0:
| 265 |                   warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
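In both new branches the logits come from the tied input embedding used as the output projection: `F.linear(hidden, wte.weight)` computes `hidden @ wte.weight.T`, giving one score per vocabulary entry. A tiny self-contained illustration:

```python
# Tied-embedding LM head in miniature: logits = hidden @ wte.weight.T
import torch
import torch.nn.functional as F

vocab_size, d_model = 10, 4
wte = torch.nn.Embedding(vocab_size, d_model)
hidden = torch.randn(2, 3, d_model)        # (batch, seq, d_model)
logits = F.linear(hidden, wte.weight)      # (batch, seq, vocab_size)
print(logits.shape)                        # torch.Size([2, 3, 10])
```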
|  | |
| 269 |               labels = torch.roll(labels, shifts=-1)
| 270 |               labels[:, -1] = -100
| 271 |               loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
| 272 | +         if self.config.torchscript:
|  | |
| 273 |               output = (logits,) + (tuple(outputs[1]),)
| 274 |               return (loss,) + output if loss is not None else output
| 275 | +         return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
| 276 |
| 277 |       def param_init_fn(self, module):
| 278 |           init_fn_name = self.config.init_config['name']
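The loss a few lines above uses the usual next-token shift: labels are rolled left by one position and the now-meaningless final slot is set to `-100` so `F.cross_entropy` ignores it. A toy example of that shift:

```python
# Toy illustration of the label shift used for the causal LM loss.
import torch

labels = torch.tensor([[11, 12, 13, 14]])
shifted = torch.roll(labels, shifts=-1)
shifted[:, -1] = -100                      # ignore_index for F.cross_entropy
print(shifted)                             # [[12, 13, 14, -100]]
```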
|  | |
| 314 |           reordered_past = []
| 315 |           for layer_past in past_key_values:
| 316 |               reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
| 317 | +         return reordered_past
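The added `return` completes `_reorder_cache`, which beam search calls after each step to permute the batch dimension of every cached tensor so it matches the surviving beams. A toy view of the `index_select` it relies on:

```python
# Toy view of the per-tensor reordering done for beam search.
import torch

past_state = torch.arange(6).view(3, 2)        # 3 beams, toy cached tensor
beam_idx = torch.tensor([2, 0, 0])             # beams kept at this step
print(past_state.index_select(0, beam_idx))    # rows now ordered [2, 0, 0]
```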
