Commit 9ed598f
Parent(s): 1076fcf

Add files for release

Adding all model and tokenizer code required for the ReplitLM release, along with the model weights.
- README.md +101 -0
- attention.py +409 -0
- config.json +46 -0
- configuration_replit_lm.py +168 -0
- generation_config.json +5 -0
- gpt_blocks.py +90 -0
- low_precision_layernorm.py +35 -0
- param_init_fns.py +464 -0
- pytorch_model.bin +3 -0
- replit_lm.py +453 -0
- replit_lm_tokenizer.py +161 -0
- special_tokens_map.json +5 -0
- spiece.model +3 -0
- tokenizer_config.json +18 -0
    	
        README.md
    CHANGED
    
@@ -1,3 +1,104 @@
 ---
 license: cc-by-sa-4.0
+datasets:
+- bigcode/the-stack-dedup
 ---
+
+
+# replit-code-v1-3b
+
+`replit-code-v1-3b` is a 2.7B-parameter model trained on the Stack Dedup v1.2 dataset.
+
+
+## Model
+
+```python
+from transformers import AutoModelForCausalLM
+
+# load model
+model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+```
+
+To use the optimized Triton implementation of FlashAttention on GPUs with BF16 precision, move the model to `bfloat16` and use it as follows:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM
+
+# load model
+model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True, attn_impl='triton')
+model.to(device='cuda:0', dtype=torch.bfloat16)
+
+# forward pass
+x = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
+x = x.to(device='cuda:0')  # token ids stay integer; only the model weights run in bfloat16
+y = model(x)
+```
+
+Note that `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the
+[Transformers](https://huggingface.co/docs/transformers/index) library.
+
+## Tokenizer
+
+We have trained a custom SentencePiece Unigram tokenizer with a 32768-token vocabulary optimized specifically for code.
+
+Note that using the tokenizer requires the `sentencepiece` library to be installed.
+
+The tokenizer can be used as follows:
+
+```python
+from transformers import AutoTokenizer
+
+# load tokenizer
+tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+
+# single input encoding + generation
+x = tokenizer.encode('def hello():\n  print("hello world")\n', return_tensors='pt')
+y = model.generate(x)  # uses the model loaded in the previous section
+
+# decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
+generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+print(generated_code)
+```
+
+Note that:
+- `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the [Transformers](https://huggingface.co/docs/transformers/index) library.
+- `clean_up_tokenization_spaces=False` avoids removing spaces from the output, since that would affect the syntactic correctness of the generated code.
+
+## Generation
+
+You can generate code using the `transformers` library as follows:
+
+```python
+import transformers
+
+tokenizer = transformers.AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+model = transformers.AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+
+x = tokenizer.encode('def fibonacci(n): ', return_tensors='pt')
+y = model.generate(x, max_length=100, do_sample=True, top_p=0.95, top_k=4, temperature=0.2, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
+
+# decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
+generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+print(generated_code)
+```
+
+Experiment with different decoding methods and parameters to get the best results for your use case.
+
+## Post Processing
+
+Note that, as with all code generation models, post-processing of the generated code is important. In particular, the following post-processing steps are recommended:
+- stop generation when the EOS token is encountered
+- remove trailing whitespace
+- set `max_tokens` to a reasonable value based on your completion use case
+- truncate the generation at stop words such as `return`, `def`, "```", or "\n\n\n" to avoid generating incomplete code when `max_tokens` is larger than the length of the expected generated code
+
+## Inference
+Coming soon.
+
+## Evaluation
+Coming soon.
+
+## Model Hash
+5bc28ce32c6f9aec935ead7b60ea1c46
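The post-processing recommendations in the README above lend themselves to a small helper. Below is a minimal sketch, assuming plain string handling on the decoded output; the helper name and the exact stop-word handling are illustrative and not part of the released code (stopping at the EOS token is already covered by the `eos_token_id` argument shown in the Generation section).

```python
# Illustrative post-processing sketch based on the README's recommendations above.
STOP_WORDS = ['return', 'def', '```', '\n\n\n']

def postprocess(completion: str) -> str:
    # truncate at the first stop word that appears, then drop trailing whitespace
    cut = min((i for i in (completion.find(s) for s in STOP_WORDS) if i != -1),
              default=len(completion))
    return completion[:cut].rstrip()

print(postprocess('print("hello")\n\n\n\n# unrelated trailing text'))  # -> print("hello")
```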
    	
        attention.py
    ADDED
    
@@ -0,0 +1,409 @@
+# Copyright 2022 MosaicML Examples authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Attention layers."""
+
+import math
+import warnings
+from typing import Optional
+
+import torch
+from einops import rearrange
+from torch import nn
+
+from .low_precision_layernorm import LPLayerNorm
+
+
+def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
+                     original_is_causal: bool):
+    if original_is_causal and num_query_tokens != num_key_tokens:
+        if num_query_tokens != 1:
+            raise NotImplementedError(
+                'ReplitLM does not support query and key with different number of tokens, unless number of query tokens is 1.'
+            )
+        else:
+            return False
+    return original_is_causal
+
+
+def scaled_multihead_dot_product_attention(
+    query,
+    key,
+    value,
+    n_heads,
+    softmax_scale=None,
+    attn_bias=None,
+    key_padding_mask=None,
+    is_causal=False,
+    dropout_p=0.0,
+    training=False,
+    needs_weights=False,
+):
+
+    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
+    k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads)  # includes key.t()
+    v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
+
+    min_val = torch.finfo(q.dtype).min
+
+    b, _, s_q, d = q.shape
+    s_k = k.size(-1)
+
+    if softmax_scale is None:
+        softmax_scale = 1 / math.sqrt(d)
+
+    attn_weight = q.matmul(k) * softmax_scale
+
+    if attn_bias is not None:
+        if (attn_bias.size(-1) != 1 and
+                attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
+                                               attn_bias.size(-2) != s_q):
+            raise RuntimeError(
+                f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
+            )
+        attn_weight = attn_weight + attn_bias
+
+    if key_padding_mask is not None:
+        if attn_bias is not None:
+            warnings.warn(
+                'Propagating key_padding_mask to the attention module ' +
+                'and applying it within the attention module can cause ' +
+                'unnecessary computation/memory usage. Consider integrating ' +
+                'into attn_bias once and passing that to each attention ' +
+                'module instead.'
+            )
+        attn_weight = attn_weight.masked_fill(
+            ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
+
+    if is_causal:
+        s = max(s_q, s_k)
+        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
+        causal_mask = causal_mask.tril()
+        causal_mask = causal_mask.to(torch.bool)
+        causal_mask = ~causal_mask
+        causal_mask = causal_mask[-s_q:, -s_k:]
+        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
+                                              min_val)
+
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+
+    if dropout_p:
+        attn_weight = torch.nn.functional.dropout(attn_weight,
+                                                  p=dropout_p,
+                                                  training=training,
+                                                  inplace=True)
+
+    out = attn_weight.matmul(v)
+    out = rearrange(out, 'b h s d -> b s (h d)')
+
+    if needs_weights:
+        return out, attn_weight
+    return out, None
+
+
+def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
+    for tensor in tensors:
+        if tensor.dtype not in valid_dtypes:
+            raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
+        if not tensor.is_cuda:
+            raise TypeError(
+                f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
+
+
+def flash_attn_fn(
+    query,
+    key,
+    value,
+    n_heads,
+    softmax_scale=None,
+    attn_bias=None,
+    key_padding_mask=None,
+    is_causal=False,
+    dropout_p=0.0,
+    training=False,
+    needs_weights=False,
+):
+    try:
+        from flash_attn import bert_padding, flash_attn_interface
+    except:
+        raise RuntimeError('Please install flash_attn==0.2.8')
+
+    check_valid_inputs(query, key, value)
+
+    if attn_bias is not None:
+        raise NotImplementedError(f'attn_bias not implemented for flash attn.')
+
+    batch_size, seqlen = query.shape[:2]
+
+    if key_padding_mask is None:
+        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
+    query_padding_mask = key_padding_mask[:, -query.size(1):]
+
+    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
+        query, query_padding_mask)
+    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+
+    key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
+        key, key_padding_mask)
+    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+
+    value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
+    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+
+    dropout_p = dropout_p if training else 0.0
+
+    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
+
+    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
+        query_unpad,
+        key_unpad,
+        value_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale=softmax_scale,
+        causal=reset_is_causal,
+        return_attn_probs=needs_weights)
+
+    output = bert_padding.pad_input(
+        rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
+        seqlen)
+    return output, None
+
+
+def triton_flash_attn_fn(
+    query,
+    key,
+    value,
+    n_heads,
+    softmax_scale=None,
+    attn_bias=None,
+    key_padding_mask=None,
+    is_causal=False,
+    dropout_p=0.0,
+    training=False,
+    needs_weights=False,
+):
+    try:
+        from flash_attn import flash_attn_triton  # type: ignore
+    except:
+        raise RuntimeError(
+            'Please install flash_attn==0.2.8 and triton==2.0.0.dev20221202.')
+
+    check_valid_inputs(query, key, value)
+
+    if dropout_p:
+        raise NotImplementedError(
+            f'Dropout not implemented for attn_impl: triton.')
+
+    if needs_weights:
+        raise NotImplementedError(
+            f'attn_impl: triton cannot return attn weights.')
+
+    if key_padding_mask is not None:
+        warnings.warn(
+            'Propagating key_padding_mask to the attention module ' +
+            'and applying it within the attention module can cause ' +
+            'unnecessary computation/memory usage. Consider integrating ' +
+            'into attn_bias once and passing that to each attention ' +
+            'module instead.'
+        )
+        b_size, s_k = key_padding_mask.shape[:2]
+
+        if attn_bias is None:
+            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
+
+        attn_bias = attn_bias.masked_fill(
+            ~key_padding_mask.view((b_size, 1, 1, s_k)),
+            torch.finfo(query.dtype).min)
+
+    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
+    key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
+    value = rearrange(value, 'b s (h d) -> b s h d', h=n_heads)
+
+    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
+    attn_output = flash_attn_triton.flash_attn_func(query, key, value,
+                                                    attn_bias, reset_is_causal,
+                                                    softmax_scale)
+
+    output = attn_output.view(*attn_output.shape[:2], -1)
+
+    return output, None
+
+
+class MultiheadAttention(nn.Module):
+    """Multi-head self attention.
+
+    Using torch or triton attention implementation enables the user to also use
+    additive bias.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        attn_impl: str = 'triton',
+        attn_clip_qkv: Optional[float] = None,
+        attn_qk_ln: bool = False,
+        softmax_scale: Optional[float] = None,
+        attn_pdrop: float = 0.0,
+        low_precision_layernorm: bool = False,
+        device: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.attn_impl = attn_impl
+        self.clip_qkv = attn_clip_qkv
+        self.attn_qk_ln = attn_qk_ln
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.softmax_scale = softmax_scale
+        if self.softmax_scale is None:
+            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
+        self.attn_dropout_p = attn_pdrop
+
+        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
+        # for param init fn; enables shape based init of fused layers
+        fuse_splits = (d_model, 2 * d_model)
+        self.Wqkv._fused = (0, fuse_splits)  # type: ignore
+
+        if self.attn_qk_ln:
+            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
+            self.q_ln = layernorm_class(self.d_model, device=device)
+            self.k_ln = layernorm_class(self.d_model, device=device)
+
+        if self.attn_impl == 'flash':
+            self.attn_fn = flash_attn_fn
+        elif self.attn_impl == 'triton':
+            self.attn_fn = triton_flash_attn_fn
+            warnings.warn(
+                'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +
+                'it uses more memory. When training larger models this can trigger ' +
+                'alloc retries which hurts performance. If encountered, we recommend ' +
+                'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
+        elif self.attn_impl == 'torch':
+            self.attn_fn = scaled_multihead_dot_product_attention
+            if torch.cuda.is_available():
+                warnings.warn(
+                    'Using `attn_impl: torch`. If your model does not use `alibi` or ' +
+                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +
+                    'we recommend using `attn_impl: triton`.'
+                )
+        else:
+            raise ValueError(f'{attn_impl=} is an invalid setting.')
+
+        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
+        self.out_proj._is_residual = True  # type: ignore
+
+    def forward(self,
+                x,
+                past_key_value=None,
+                attn_bias=None,
+                attention_mask=None,
+                is_causal=True,
+                needs_weights=False):
+        qkv = self.Wqkv(x)
+
+        if self.clip_qkv:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+
+        query, key, value = qkv.chunk(3, dim=2)
+
+        key_padding_mask = attention_mask
+
+        if self.attn_qk_ln:
+            # Applying layernorm to qk
+            dtype = query.dtype
+            query = self.q_ln(query).to(dtype)
+            key = self.k_ln(key).to(dtype)
+
+        if past_key_value is not None:
+            if len(past_key_value) != 0:
+                key = torch.cat([past_key_value[0], key], dim=1)
+                value = torch.cat([past_key_value[1], value], dim=1)
+
+            past_key_value = (key, value)
+
+        if attn_bias is not None:
+            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
+
+        context, attn_weights = self.attn_fn(
+            query,
+            key,
+            value,
+            self.n_heads,
+            softmax_scale=self.softmax_scale,
+            attn_bias=attn_bias,
+            key_padding_mask=key_padding_mask,
+            is_causal=is_causal,
+            dropout_p=self.attn_dropout_p,
+            training=self.training,
+            needs_weights=needs_weights,
+        )
+
+        return self.out_proj(context), attn_weights, past_key_value
+
+
+def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
+                    use_sequence_id):
+    if attn_impl == 'flash':
+        return None
+    elif attn_impl in ['torch', 'triton']:
+        if alibi:
+            if (prefix_lm or not causal) or use_sequence_id:
+                return (1, n_heads, seq_len, seq_len)
+            return (1, n_heads, 1, seq_len)
+        elif prefix_lm or use_sequence_id:
+            return (1, 1, seq_len, seq_len)
+        return None
+    else:
+        raise ValueError(f'{attn_impl=} is an invalid setting.')
+
+
+def attn_bias(attn_impl,
+              attn_bias,
+              n_heads,
+              seq_len,
+              causal=False,
+              alibi=False,
+              alibi_bias_max=8):
+    if attn_impl == 'flash':
+        return None
+    elif attn_impl in ['torch', 'triton']:
+        if alibi:
+            # in place add alibi to attn bias
+            device, dtype = attn_bias.device, attn_bias.dtype
+            attn_bias = attn_bias.add(
+                alibi_bias(n_heads,
+                           seq_len,
+                           full=not causal,
+                           alibi_bias_max=alibi_bias_max,
+                           device=device,
+                           dtype=dtype))
+        return attn_bias
+    else:
+        raise ValueError(f'{attn_impl=} is an invalid setting.')
+
+
+def alibi_bias(n_heads,
+               seq_len,
+               full=False,
+               alibi_bias_max=8,
+               device=None,
+               dtype=None):
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=dtype,
+                              device=device).view(1, 1, 1, seq_len)
+    if full:
+        # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
+        # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
+        alibi_bias = alibi_bias - torch.arange(
+            1 - seq_len, 1, dtype=dtype, device=device).view(1, 1, seq_len, 1)
+        alibi_bias = alibi_bias.abs().mul(-1)
+
+    m = torch.arange(1, n_heads + 1, dtype=dtype, device=device)
+    m = m.mul(alibi_bias_max / n_heads)
+    alibi_bias = alibi_bias * (1. / (2**m.view(1, n_heads, 1, 1)))
+    return alibi_bias
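For orientation, here is a minimal sketch of how the pieces above fit together: `attn_bias_shape` and `attn_bias` build the ALiBi bias that `MultiheadAttention` then consumes through the `torch` implementation. It assumes the repo files (`attention.py` and `low_precision_layernorm.py`) have been placed in a local package, called `replit_code` here purely for illustration; the shapes and hyperparameters are arbitrary.

```python
import torch

# Hypothetical package name; attention.py uses a relative import of
# low_precision_layernorm, so both files need to live inside a package.
from replit_code.attention import MultiheadAttention, attn_bias, attn_bias_shape

d_model, n_heads, seq_len = 256, 8, 16
mha = MultiheadAttention(d_model, n_heads, attn_impl='torch')  # CPU-friendly path

# Query the bias shape for a causal ALiBi model, then fill it with the ALiBi slopes.
shape = attn_bias_shape('torch', n_heads, seq_len, alibi=True,
                        prefix_lm=False, causal=True, use_sequence_id=False)  # (1, 8, 1, 16)
bias = attn_bias('torch', torch.zeros(shape), n_heads, seq_len,
                 causal=True, alibi=True, alibi_bias_max=8)

x = torch.randn(2, seq_len, d_model)
out, attn_weights, past_kv = mha(x, attn_bias=bias, is_causal=True)
print(out.shape)  # torch.Size([2, 16, 256])
```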
    	
        config.json
    ADDED
    
@@ -0,0 +1,46 @@
+{
+  "_name_or_path": "replit/replit-code-v1-3b",
+  "alibi": true,
+  "alibi_bias_max": 8,
+  "architectures": [
+    "ReplitLM"
+  ],
+  "attn_clip_qkv": null,
+  "attn_impl": "torch",
+  "attn_pdrop": 0,
+  "attn_qk_ln": false,
+  "attn_uses_sequence_id": false,
+  "auto_map": {
+    "AutoConfig": "configuration_replit_lm.ReplitLMConfig",
+    "AutoModelForCausalLM": "replit_lm.ReplitLM"
+  },
+  "d_model": 2560,
+  "emb_init_std": null,
+  "emb_init_uniform_lim": null,
+  "emb_pdrop": 0,
+  "embedding_fraction": 1.0,
+  "fan_mode": "fan_in",
+  "init_device": "cpu",
+  "init_div_is_residual": true,
+  "init_gain": 0,
+  "init_nonlinearity": "relu",
+  "init_std": 0.02,
+  "logit_scale": null,
+  "low_precision_layernorm": true,
+  "max_seq_len": 2048,
+  "mlp_ratio": 4,
+  "model_type": "replit_lm",
+  "n_heads": 32,
+  "n_layers": 32,
+  "no_bias": true,
+  "param_init_fn": "kaiming_normal_",
+  "prefix_lm": false,
+  "resid_pdrop": 0,
+  "softmax_scale": null,
+  "tokenizer_name": "replit/replit-code-v1-3b",
+  "torch_dtype": "float32",
+  "transformers_version": "4.26.1",
+  "use_cache": false,
+  "verbose": 0,
+  "vocab_size": 32768
+}
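Because of the `auto_map` entries above, this config resolves to `configuration_replit_lm.ReplitLMConfig` when loaded through `transformers`. A minimal sketch of loading and inspecting it, assuming access to the `replit/replit-code-v1-3b` repo; `trust_remote_code=True` is required because the config class ships with the repo rather than with the library:

```python
from transformers import AutoConfig

# Resolves to ReplitLMConfig via the auto_map entry in config.json.
config = AutoConfig.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)

print(config.model_type)                                    # replit_lm
print(config.d_model, config.n_heads, config.n_layers)      # 2560 32 32
print(config.alibi, config.attn_impl, config.max_seq_len)   # True torch 2048
```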
    	
        configuration_replit_lm.py
    ADDED
    
@@ -0,0 +1,168 @@
+# Copyright 2022 MosaicML Examples authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Forked for ReplitLM"""
+
+"""A HuggingFace-style model configuration."""
+
+
+from typing import Optional, Tuple, Union
+from transformers import PretrainedConfig
+class ReplitLMConfig(PretrainedConfig):
+    model_type = 'replit_lm'
+
+    def __init__(
+        self,
+        d_model: int = 2048,
+        n_heads: int = 16,
+        n_layers: int = 24,
+        mlp_ratio: int = 4,
+        max_seq_len: int = 2048,
+        vocab_size: int = 50368,
+        attn_pdrop: float = 0.0,
+        resid_pdrop: float = 0.0,
+        emb_pdrop: float = 0.0,
+        attn_impl: str = 'triton',
+        attn_qk_ln: bool = False,
+        attn_clip_qkv: Optional[float] = None,
+        softmax_scale: Optional[float] = None,
+        prefix_lm: Optional[bool] = False,
+        attn_uses_sequence_id: Optional[bool] = False,
+        alibi: bool = False,
+        alibi_bias_max: int = 8,
+        init_device: str = 'cpu',
+        logit_scale: Optional[Union[float, str]] = None,
+        no_bias: bool = False,
+        verbose: int = 0,
+        param_init_fn: str = 'kaiming_normal_',
+        init_div_is_residual: Union[int, float, str, bool] = True,
+        init_std: float = 0.02,
+        emb_init_std: Optional[float] = None,
+        emb_init_uniform_lim: Optional[Union[Tuple[float, float],
+                                             float]] = None,
+        init_gain: float = 0,
+        fan_mode: str = 'fan_in',
+        init_nonlinearity: str = 'relu',
+        embedding_fraction: float = 1.0,
+        low_precision_layernorm: bool = True,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        """The ReplitLM configuration class.
+
+        Args:
+            d_model (int): The size of the embedding dimension of the model.
+            n_heads (int): The number of attention heads.
+            n_layers (int): The number of layers in the model.
+            mlp_ratio (int): The ratio of the up/down scale in the MLP.
+            max_seq_len (int): The maximum sequence length of the model.
+            vocab_size (int): The size of the vocabulary.
+            attn_pdrop (float): The dropout probability for the attention layers.
+            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
+            emb_pdrop (float): The dropout probability for the embedding layer.
+            attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
+            attn_qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
+            attn_clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
+                this value.
+            softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
+                use the default scale of ``1/sqrt(d_keys)``.
+            prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
+                extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
+                can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
+            attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
+                When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
+                which sub-sequence each token belongs to.
+                Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
+            alibi (bool): Whether to use the alibi bias instead of position embeddings.
+            alibi_bias_max (int): The maximum value of the alibi bias.
+            init_device (str): The device to use for parameter initialization.
+            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
+            no_bias (bool): Whether to use bias in all layers.
+            verbose (int): The verbosity level. 0 is silent.
+            param_init_fn (str): The parameter initialization scheme to use. One of 'default_', 'baseline_', 'kaiming_uniform_',
+                'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 'xavier_normal_'.
+            init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
+            init_std (float): The standard deviation of the normal distribution used to initialize the model,
+                if using the baseline_ parameter initialization scheme.
+            emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
+            emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
+                used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
+            init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
+            fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
+            init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
+            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
+            low_precision_layernorm (bool): Whether to use low precision layer normalization.
+            use_cache (bool): Whether or not the model should return the last key/values attentions
+        """
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.mlp_ratio = mlp_ratio
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.attn_pdrop = attn_pdrop
| 104 | 
            +
                    self.resid_pdrop = resid_pdrop
         | 
| 105 | 
            +
                    self.emb_pdrop = emb_pdrop
         | 
| 106 | 
            +
                    self.attn_impl = attn_impl
         | 
| 107 | 
            +
                    self.attn_qk_ln = attn_qk_ln
         | 
| 108 | 
            +
                    self.attn_clip_qkv = attn_clip_qkv
         | 
| 109 | 
            +
                    self.softmax_scale = softmax_scale
         | 
| 110 | 
            +
                    self.prefix_lm = prefix_lm
         | 
| 111 | 
            +
                    self.attn_uses_sequence_id = attn_uses_sequence_id
         | 
| 112 | 
            +
                    self.alibi = alibi
         | 
| 113 | 
            +
                    self.alibi_bias_max = alibi_bias_max
         | 
| 114 | 
            +
                    self.init_device = init_device
         | 
| 115 | 
            +
                    self.logit_scale = logit_scale
         | 
| 116 | 
            +
                    self.no_bias = no_bias
         | 
| 117 | 
            +
                    self.verbose = verbose
         | 
| 118 | 
            +
                    self.param_init_fn = param_init_fn
         | 
| 119 | 
            +
                    self.init_div_is_residual = init_div_is_residual
         | 
| 120 | 
            +
                    self.init_std = init_std
         | 
| 121 | 
            +
                    self.emb_init_std = emb_init_std
         | 
| 122 | 
            +
                    self.emb_init_uniform_lim = emb_init_uniform_lim
         | 
| 123 | 
            +
                    self.init_std = init_std
         | 
| 124 | 
            +
                    self.init_gain = init_gain
         | 
| 125 | 
            +
                    self.fan_mode = fan_mode
         | 
| 126 | 
            +
                    self.init_nonlinearity = init_nonlinearity
         | 
| 127 | 
            +
                    self.embedding_fraction = embedding_fraction
         | 
| 128 | 
            +
                    self.low_precision_layernorm = low_precision_layernorm
         | 
| 129 | 
            +
                    self.use_cache = use_cache
         | 
| 130 | 
            +
                    if 'name' in kwargs:
         | 
| 131 | 
            +
                        del kwargs['name']
         | 
| 132 | 
            +
                    if 'loss_fn' in kwargs:
         | 
| 133 | 
            +
                        del kwargs['loss_fn']
         | 
| 134 | 
            +
                    super().__init__(**kwargs)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    self._validate_config()
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                def _validate_config(self):
         | 
| 139 | 
            +
                    if self.d_model % self.n_heads != 0:
         | 
| 140 | 
            +
                        raise ValueError('d_model must be divisible by n_heads')
         | 
| 141 | 
            +
                    if any(prob < 0 or prob > 1
         | 
| 142 | 
            +
                           for prob in [self.attn_pdrop, self.resid_pdrop, self.emb_pdrop]):
         | 
| 143 | 
            +
                        raise ValueError(
         | 
| 144 | 
            +
                            'attn_pdrop, resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1'
         | 
| 145 | 
            +
                        )
         | 
| 146 | 
            +
                    if self.attn_impl not in ['torch', 'flash', 'triton']:
         | 
| 147 | 
            +
                        raise ValueError(f'Unknown attn_impl={self.attn_impl}')
         | 
| 148 | 
            +
                    if self.prefix_lm and self.attn_impl not in ['torch', 'triton']:
         | 
| 149 | 
            +
                        raise NotImplementedError(
         | 
| 150 | 
            +
                            'prefix_lm only implemented with torch and triton attention.')
         | 
| 151 | 
            +
                    if self.alibi and self.attn_impl not in ['torch', 'triton']:
         | 
| 152 | 
            +
                        raise NotImplementedError(
         | 
| 153 | 
            +
                            'alibi only implemented with torch and triton attention.')
         | 
| 154 | 
            +
                    if self.attn_uses_sequence_id and self.attn_impl not in [
         | 
| 155 | 
            +
                            'torch', 'triton'
         | 
| 156 | 
            +
                    ]:
         | 
| 157 | 
            +
                        raise NotImplementedError(
         | 
| 158 | 
            +
                            'attn_uses_sequence_id only implemented with torch and triton attention.'
         | 
| 159 | 
            +
                        )
         | 
| 160 | 
            +
                    if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
         | 
| 161 | 
            +
                        raise ValueError(
         | 
| 162 | 
            +
                            'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
         | 
| 163 | 
            +
                        )
         | 
| 164 | 
            +
                    if isinstance(self.logit_scale,
         | 
| 165 | 
            +
                                  str) and self.logit_scale != 'inv_sqrt_d_model':
         | 
| 166 | 
            +
                        raise ValueError(
         | 
| 167 | 
            +
                            f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
         | 
| 168 | 
            +
                        )
         | 
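
Because `_validate_config()` runs at the end of `__init__`, a mis-configured model fails at construction time rather than at the first forward pass. A minimal sketch of that behavior, assuming the configuration class defined earlier in this file (outside this hunk) is named `ReplitLMConfig` and is imported directly rather than through `trust_remote_code`:

```python
from configuration_replit_lm import ReplitLMConfig  # assumed local import path

# d_model must be divisible by n_heads (checked in _validate_config above)
cfg = ReplitLMConfig(d_model=2048, n_heads=16, n_layers=24, attn_impl='torch')
print(cfg.d_model // cfg.n_heads)  # per-head dimension: 128

try:
    ReplitLMConfig(d_model=2048, n_heads=15)  # 2048 % 15 != 0
except ValueError as err:
    print(err)  # 'd_model must be divisible by n_heads'
```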
    	
        generation_config.json
    ADDED
    
@@ -0,0 +1,5 @@
{
  "_from_model_config": true,
  "transformers_version": "4.26.1",
  "use_cache": false
}
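
These are the generation defaults `transformers` picks up when the model is loaded from the Hub. A quick way to inspect them, using the standard `GenerationConfig` API rather than anything specific to this repo (shown as a hedged sketch):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained('replit/replit-code-v1-3b')
print(gen_cfg.use_cache)  # False, matching the JSON above
```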
    	
        gpt_blocks.py
    ADDED
    
@@ -0,0 +1,90 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

"""GPT Blocks used for the GPT Model."""

from typing import Optional, Tuple

import torch
import torch.nn as nn

from .attention import MultiheadAttention
from .low_precision_layernorm import LPLayerNorm


class GPTMLP(nn.Module):

    def __init__(self,
                 d_model: int,
                 mlp_ratio: int,
                 device: Optional[str] = None):
        super().__init__()
        self.mlp_up = nn.Linear(d_model, mlp_ratio * d_model, device=device)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down = nn.Linear(mlp_ratio * d_model, d_model, device=device)
        self.mlp_down._is_residual = True  # type: ignore

    def forward(self, x):
        return self.mlp_down(self.mlp_act(self.mlp_up(x)))


class GPTBlock(nn.Module):

    def __init__(self,
                 attn_impl: str,
                 d_model: int,
                 n_heads: int,
                 mlp_ratio: int,
                 attn_clip_qkv: Optional[float] = None,
                 attn_qk_ln: bool = False,
                 softmax_scale: Optional[float] = None,
                 attn_pdrop: float = 0.0,
                 alibi: bool = False,
                 resid_pdrop: float = 0.0,
                 low_precision_layernorm: bool = False,
                 device: Optional[str] = None,
                 **kwargs):
        del kwargs  # unused, just to capture any extra args from the config
        super().__init__()

        layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm

        self.ln_1 = layernorm_class(d_model, device=device)
        self.attn = MultiheadAttention(
            attn_impl=attn_impl,
            attn_clip_qkv=attn_clip_qkv,
            attn_qk_ln=attn_qk_ln,
            softmax_scale=softmax_scale,
            attn_pdrop=attn_pdrop,
            d_model=d_model,
            n_heads=n_heads,
            device=device,
        )
        self.ln_2 = layernorm_class(d_model, device=device)
        self.mlp = GPTMLP(
            d_model=d_model,
            mlp_ratio=mlp_ratio,
            device=device,
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        a = self.ln_1(x)
        b, _, past_key_value = self.attn(a,
                                         past_key_value=past_key_value,
                                         attn_bias=attn_bias,
                                         attention_mask=attention_mask,
                                         is_causal=is_causal)
        x = x + self.resid_attn_dropout(b)
        m = self.ln_2(x)
        n = self.mlp(m)
        x = x + self.resid_mlp_dropout(n)
        return x, past_key_value
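
`GPTBlock.forward` is a standard pre-LayerNorm residual block: normalize, attend, add the residual, then normalize, run the MLP, and add again. Below is a self-contained sketch of that wiring which uses torch's built-in `nn.MultiheadAttention` in place of this repo's `MultiheadAttention`, so it runs without the rest of the package (it omits the causal mask, ALiBi bias, and dropout that the real block handles):

```python
import torch
import torch.nn as nn

d_model, n_heads, mlp_ratio = 64, 4, 4
ln_1, ln_2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
mlp = nn.Sequential(nn.Linear(d_model, mlp_ratio * d_model),
                    nn.GELU(approximate='none'),
                    nn.Linear(mlp_ratio * d_model, d_model))

x = torch.randn(2, 16, d_model)           # (batch, seq, d_model)
a = ln_1(x)                                # pre-norm
b, _ = attn(a, a, a, need_weights=False)   # self-attention on the normalized input
x = x + b                                  # residual add
x = x + mlp(ln_2(x))                       # same pattern for the MLP
print(x.shape)                             # torch.Size([2, 16, 64])
```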
    	
        low_precision_layernorm.py
    ADDED
    
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F


class LPLayerNorm(torch.nn.LayerNorm):
    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = _cast_if_autocast_enabled(
            self.weight) if self.weight is not None else self.weight
        downcast_bias = _cast_if_autocast_enabled(
            self.bias) if self.bias is not None else self.bias
        with torch.autocast(enabled=False, device_type=module_device.type):
            return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)


def _cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == 'cuda':
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == 'cpu':
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor
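
The point of `LPLayerNorm` is that under autocast the input, weight, and bias are explicitly downcast and `F.layer_norm` is then run with autocast disabled, so the normalization executes in the autocast dtype instead of being promoted to fp32 the way a stock `nn.LayerNorm` is. A hedged usage sketch (requires a CUDA device; the import path assumes the file is used as a plain module rather than loaded via `trust_remote_code`):

```python
import torch
from low_precision_layernorm import LPLayerNorm  # assumed local import path

ln = LPLayerNorm(512).to('cuda')
x = torch.randn(4, 128, 512, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    y = ln(x)                                        # layer norm runs in bf16
    y_ref = torch.nn.LayerNorm(512).to('cuda')(x)    # stock LayerNorm is autocast to fp32

print(y.dtype, y_ref.dtype)  # torch.bfloat16 torch.float32
```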
    	
        param_init_fns.py
    ADDED
    
@@ -0,0 +1,464 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
import math
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn


def torch_default_param_init_fn_(
    module: nn.Module,
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config
    if verbose > 1:
        warnings.warn(
            f"Initializing network using module's reset_parameters attribute")

    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()  # type: ignore


def fused_init_helper_(module: nn.Module, init_fn_):
    # parameter initialization is often based on the parameters shape.
    # If a layer is fused, initialization should be based on the shapes
    # of the original tensor instead of the shape of the fused tensor.
    # Layers which are fused should have the _fused attribute defined.
    # The first element of _fused is the dimension along which the tensor is fused.
    # This is followed by an iterable of split indices.

    _fused = getattr(module, '_fused', None)

    if _fused is None:
        raise RuntimeError(f'Internal logic error')

    dim, splits = _fused
    splits = (0, *splits, module.weight.size(dim))  # type: ignore
    for s, e in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim  # type: ignore
        slice_indices[dim] = slice(s, e)
        init_fn_(module.weight[slice_indices])  # type: ignore


def generic_param_init_fn_(
    module: nn.Module,
    init_fn_,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config
    if verbose > 1:
        warnings.warn(
            f'If model has bias parameters they are initialized to 0.')

    # enable user to divide _is_residual weights by
    # a value which defaults to math.sqrt(2 * cfg.n_layers)
    init_div_is_residual = init_div_is_residual

    if init_div_is_residual is False:
        # not used, for pyright
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, float) or isinstance(
            init_div_is_residual, int):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual,
                    str) and init_div_is_residual.isnumeric():
        # do not trust YAML parsing to always convert numbers to numbers
        div_is_residual = float(init_div_is_residual)
    else:
        # not used, for pyright
        div_is_residual = 1.0
        raise ValueError(
            f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
        )

    if init_div_is_residual is not False:
        if verbose > 1:
            warnings.warn(
                f'Initializing _is_residual layers then dividing them by {div_is_residual}.' +
                f'set `init_div_is_residual: false` in model config to disable this.'
            )

    if isinstance(module, nn.Linear):
        # Linear
        if hasattr(module, '_fused'):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

        if init_div_is_residual is not False and getattr(
                module, '_is_residual', False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)

    elif isinstance(module, nn.Embedding):
        # Embedding
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn(f'Embedding layer initialized to 0.')
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(
                    f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
                )
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                if len(lim) > 2:
                    raise ValueError(
                        f'Uniform init requires a min and a max limit. User input: {lim}.'
                    )
                if lim[0] == lim[1]:
                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')
            else:
                if lim == 0:
                    warnings.warn(f'Embedding layer initialized to 0.')
                lim = [-lim, lim]
            a, b = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(
                    f'Embedding layer initialized using uniform distribution in range {lim}.'
                )
        else:
            emb_init_fn_ = init_fn_

        emb_init_fn_(module.weight)

    elif isinstance(module, nn.LayerNorm):
        # LayerNorm
        if verbose > 1:
            warnings.warn(
                f'LayerNorm gamma weights are set to 1. If the layer has a bias it is initialized to 0.'
            )
        torch.nn.init.ones_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    elif isinstance(module, nn.MultiheadAttention):
        # torch's MultiheadAttention
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
            assert d_model is not None
            # in_proj_weight is actually 3 layers and should be split up for width based init
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)

        # bias
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)

        # out proj
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(
                module.out_proj, '_is_residual', False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)

    else:
        for _ in module.parameters(recurse=False):
            # raise error if uninitialized module has any parameters
            raise NotImplementedError(
                f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
            )


def _normal_init_(std, mean=0.0):
    return partial(torch.nn.init.normal_, mean=mean, std=std)


def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config
    init_fn_ = _normal_init_(std=std)

    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')

    generic_param_init_fn_(
        module=module,
        init_fn_=init_fn_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def baseline_param_init_fn_(
    module: nn.Module,
    init_std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config
    if init_std is None:
        raise ValueError(
            'You must set model.init_std to a float value to use the default initialization scheme.'
        )
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config
    # very close to kaiming normal
    # from Transformers without Tears (2019) - Nguyen & Salazar
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """From section 2.3.1 of GPT-NeoX-20B:

    An Open-Source Autoregressive Language Model - Black et al. (2022)
    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs  # unused, just to capture any extra args from the config
    residual_div = n_layers / math.sqrt(10)  # small std / wang std

    if verbose > 1:
        warnings.warn(f'setting init_div_is_residual to {residual_div}')

    small_param_init_fn_(
        module=module,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config

    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_uniform_ init fn with parameters: ' +
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )

    kaiming_uniform_ = partial(nn.init.kaiming_uniform_,
                               a=init_gain,
                               mode=fan_mode,
                               nonlinearity=init_nonlinearity)

    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config

    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_normal_ init fn with parameters: ' +
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )

    kaiming_normal_ = partial(torch.nn.init.kaiming_normal_,
                              a=init_gain,
                              mode=fan_mode,
                              nonlinearity=init_nonlinearity)

    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    del kwargs  # unused, just to capture any extra args from the config
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)

    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' +
            f'gain={init_gain}'
        )

    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)

    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' +
            f'gain={init_gain}'
        )

    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
             | 
| 454 | 
            +
             | 
| 455 | 
            +
            MODEL_INIT_REGISTRY = {
         | 
| 456 | 
            +
                'default_': torch_default_param_init_fn_,
         | 
| 457 | 
            +
                'baseline_': baseline_param_init_fn_,
         | 
| 458 | 
            +
                'kaiming_uniform_': kaiming_uniform_param_init_fn_,
         | 
| 459 | 
            +
                'kaiming_normal_': kaiming_normal_param_init_fn_,
         | 
| 460 | 
            +
                'neox_init_': neox_param_init_fn_,
         | 
| 461 | 
            +
                'small_init_': small_param_init_fn_,
         | 
| 462 | 
            +
                'xavier_uniform_': xavier_uniform_param_init_fn_,
         | 
| 463 | 
            +
                'xavier_normal_': xavier_normal_param_init_fn_,
         | 
| 464 | 
            +
            }
         | 
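For orientation (not part of the release files): a minimal sketch of how an entry from `MODEL_INIT_REGISTRY` could be applied by hand. The toy module, its dimensions, and the top-level import path are illustrative assumptions; in the model itself, `ReplitLM.param_init_fn` looks up `config.param_init_fn` in this registry and forwards the whole config dict.

```python
# Illustrative only: assumes param_init_fns.py is importable as a top-level module
# and uses a toy nn.Sequential; d_model/n_layers values are made up for the demo.
import torch.nn as nn
from param_init_fns import MODEL_INIT_REGISTRY

toy = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))

init_fn = MODEL_INIT_REGISTRY['kaiming_normal_']
# Each registered fn initializes one module at a time; nn.Module.apply walks the
# tree, which is how ReplitLM.param_init_fn drives the registry.
toy.apply(lambda module: init_fn(module=module, d_model=256, n_layers=2))
```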
    	
pytorch_model.bin
ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:6516d02ef00bc903aad7d05dc35607cff7e4c7335d4f1bf424cdcb6695cd3e86
size 10402658381
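The file above is a Git LFS pointer rather than the weights themselves: `oid` is the SHA-256 of the actual `pytorch_model.bin` and `size` is its byte count (~10.4 GB). A small sketch, with an assumed local path, of checking a downloaded copy against these fields:

```python
# Verify a downloaded pytorch_model.bin against the LFS pointer fields above.
# The path passed to verify() is an assumption; point it at your local copy.
import hashlib

EXPECTED_OID = "6516d02ef00bc903aad7d05dc35607cff7e4c7335d4f1bf424cdcb6695cd3e86"
EXPECTED_SIZE = 10402658381

def verify(path: str) -> None:
    digest = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
            digest.update(chunk)
            size += len(chunk)
    assert size == EXPECTED_SIZE, f"size mismatch: {size}"
    assert digest.hexdigest() == EXPECTED_OID, "sha256 mismatch"

# verify("pytorch_model.bin")
```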
    	
replit_lm.py
ADDED
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

"""Forked from the MosaicGPT model class from the Mosaic Examples codebase of date May 1st, 2023.
Permalink: https://github.com/mosaicml/examples/blob/52cd4fef69497f225a034fcd10692f8613732d10/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py
"""

"""A simple, flexible implementation of a GPT model.

Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import List, Optional, Tuple

from .attention import attn_bias as module_attn_bias, attn_bias_shape as module_attn_bias_shape
from .gpt_blocks import GPTBlock
from .configuration_replit_lm import ReplitLMConfig
from .param_init_fns import MODEL_INIT_REGISTRY
from .low_precision_layernorm import LPLayerNorm


class ReplitLM(PreTrainedModel):
    config_class = ReplitLMConfig
    base_model_prefix = 'replit_lm'

    def __init__(self, config: ReplitLMConfig):
        super().__init__(config)

        if config.attn_impl == 'flash' and config.alibi:
            raise RuntimeError("ALiBi is not supported with flash attention. Please use triton or torch.")

        self.attn_impl = config.attn_impl
        self.prefix_lm = config.prefix_lm
        self.attn_uses_sequence_id = config.attn_uses_sequence_id
        self.alibi = config.alibi
        self.alibi_bias_max = config.alibi_bias_max

        layernorm_class = LPLayerNorm if config.low_precision_layernorm else nn.LayerNorm

        # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
        # both report this helping with stabilizing training
        self.embedding_fraction = config.embedding_fraction

        self.transformer = nn.ModuleDict({
            'wte':
                nn.Embedding(config.vocab_size,
                             config.d_model,
                             device=config.init_device)
        })
        if not self.alibi:
            self.transformer.update({
                'wpe':
                    nn.Embedding(config.max_seq_len,
                                 config.d_model,
                                 device=config.init_device)
            })
        self.transformer.update({'emb_drop': nn.Dropout(config.emb_pdrop)})
        self.transformer.update({
            'blocks':
                nn.ModuleList([
                    GPTBlock(device=config.init_device, **config.to_dict())
                    for _ in range(config.n_layers)
                ])
        })
        self.transformer.update({
            'ln_f': layernorm_class(config.d_model, device=config.init_device)
        })

        # enables scaling output logits; similar to a softmax "temperature"
        # PaLM paper uses scale 1/sqrt(config.d_model)
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == 'inv_sqrt_d_model':
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(
                        f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                    )
            self.logit_scale = logit_scale

        if config.init_device != 'meta':
            print(
                f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
            )
            self.apply(self.param_init_fn)

        self.is_causal = not self.prefix_lm

        # define attn mask
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = module_attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id)

        if config.no_bias:
            for module in self.modules():
                if hasattr(module, 'bias') and isinstance(
                        module.bias, nn.Parameter):
                    if config.verbose:
                        print(f'Removing bias ({module.bias}) from {module}.')
                    module.register_parameter('bias', None)

        if config.verbose and config.verbose > 2:
            print(self)

    @torch.no_grad()
    def _attn_bias(self,
                   device,
                   dtype,
                   attention_mask: Optional[torch.ByteTensor] = None,
                   prefix_mask: Optional[torch.ByteTensor] = None,
                   sequence_id: Optional[torch.LongTensor] = None):
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(self.attn_bias_shape,
                                             device=device,
                                             dtype=dtype)
                self.attn_bias = module_attn_bias(
                    self.attn_impl,
                    self.attn_bias,
                    self.config.n_heads,
                    self.config.max_seq_len,
                    causal=self.is_causal,
                    alibi=self.alibi,
                    alibi_bias_max=self.alibi_bias_max)
            self._attn_bias_initialized = True

        # flash does not support prefix_lm and will incorporate any
        # attention_mask inside the attention module
        if self.attn_impl == 'flash':
            return self.attn_bias, attention_mask

        attn_bias = self.attn_bias

        # If using torch or triton, we incorporate the prefix_mask (if appropriate)
        if self.prefix_lm:
            assert isinstance(attn_bias, torch.Tensor)  # pyright
            assert isinstance(prefix_mask, torch.Tensor)  # pyright
            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)

        # If using torch or triton, we incorporate sequence_id (if appropriate)
        if self.attn_uses_sequence_id and sequence_id is not None:
            assert isinstance(attn_bias, torch.Tensor)  # pyright
            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)

        # If using torch or triton, we incorporate attention_mask. This will output
        # None in place of attention_mask since it will not be further needed in the
        # attention modules.
        if attention_mask is not None:
            s_k = attention_mask.shape[-1]
            if attn_bias is None:
                attn_bias = torch.zeros((1, 1, 1, s_k),
                                        device=device,
                                        dtype=dtype)
            else:
                attn_bias = attn_bias[:, :, :, -s_k:]
            if prefix_mask is not None and (attention_mask.shape !=
                                            prefix_mask.shape):
                raise ValueError(
                    f'attention_mask shape={attention_mask.shape} ' +\
                    f'and prefix_mask shape={prefix_mask.shape} are not equal.'
                )
            min_val = torch.finfo(attn_bias.dtype).min
            attn_bias = attn_bias.masked_fill(
                ~attention_mask.view(-1, 1, 1, s_k), min_val)

        return attn_bias, None

    def _apply_prefix_mask(self, attn_bias: torch.Tensor,
                           prefix_mask: torch.Tensor):
        s_k, s_q = attn_bias.shape[-2:]
        if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
            raise ValueError(
                'attn_bias does not match the expected shape. ' +\
                f'The last two dimensions should both be {self.config.max_seq_len} ' +\
                f'but are {s_k} and {s_q}.'
            )
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
            )

        # select seq_len subset of attn mask
        attn_bias = attn_bias[..., :seq_len, :seq_len]

        # Mix the causal max and the bidirectional mask to get the full
        # allowable attention (i.e. full = not accounting for padding yet)
        causal = torch.tril(
            torch.ones((seq_len, seq_len),
                       dtype=torch.bool,
                       device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())

        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)

        return attn_bias

    def _apply_sequence_id(self, attn_bias: torch.Tensor,
                           sequence_id: torch.LongTensor):
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
            )

        # select seq_len subset of attn mask
        attn_bias = attn_bias[..., :seq_len, :seq_len]

        # Restrict attention to tokens that share the same value
        # in sequence_id
        cannot_attend = torch.logical_not(
            torch.eq(sequence_id.view(-1, seq_len, 1),
                     sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)

        return attn_bias

    def forward(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
            attention_mask: Optional[torch.ByteTensor] = None,
            prefix_mask: Optional[torch.ByteTensor] = None,
            sequence_id: Optional[torch.LongTensor] = None,
            return_dict: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            use_cache: Optional[bool] = None):
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # These args are passed in by keyword in huggingface's generate function
        # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
        # but have not yet been fully implemented in ReplitLM
        if not return_dict:
            raise NotImplementedError(
                'return_dict False is not implemented yet for ReplitLM')
        if output_attentions:
            raise NotImplementedError(
                'output_attentions is not implemented yet for ReplitLM')

        if attention_mask is not None and attention_mask[:, 0].sum(
        ) != attention_mask.shape[0] and self.training:
            raise NotImplementedError(
                'ReplitLM does not support training with left padding.')

        if self.prefix_lm and prefix_mask is None:
            raise ValueError(
                'prefix_mask is a required argument when ReplitLM is configured with prefix_lm=True.'
            )

        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError(
                    'sequence_id is a required argument when ReplitLM is configured with attn_uses_sequence_id=True ' +\
                    'and the model is in train mode.'
                )
            elif (self.attn_uses_sequence_id is False) and (sequence_id
                                                            is not None):
                warnings.warn(
                    'ReplitLM received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' +\
                    'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
                )

        S = input_ids.size(1)

        assert (
            S <= self.config.max_seq_len
        ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

        tok_emb = self.transformer.wte(input_ids)  # type: ignore
        if self.alibi:
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(
                        f'past_key_values must provide a past_key_value for each attention ' +\
                        f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
                    )
                # get the key tensor whose spec should be (batch, seq, dim), and
                # collect the `seq`, so that the position embedding is shifted
                past_position = past_key_values[0][0].size(1)

            if S + past_position > self.config.max_seq_len:
                raise ValueError(
                    f'Cannot forward input with past sequence length {past_position} and current sequence length '
                    f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.'
                )
            pos = torch.arange(past_position,
                               S + past_position,
                               dtype=torch.long,
                               device=input_ids.device).unsqueeze(0)
            if attention_mask is not None:
                # adjust the position indices to account for padding tokens
                pos = torch.clamp(
                    pos - torch.cumsum((~attention_mask).to(torch.int32),
                                       dim=1)[:, past_position:],
                    min=0)

            pos_emb = self.transformer.wpe(pos)  # type: ignore
            x = tok_emb + pos_emb

        if self.embedding_fraction == 1:
            x = self.transformer.emb_drop(x)  # type: ignore
        else:
            # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
            x_shrunk = (x * self.embedding_fraction) + (
                x.detach() * (1 - self.embedding_fraction))
            assert isinstance(self.transformer.emb_drop, nn.Module)  # pyright
            x = self.transformer.emb_drop(x_shrunk)

        attn_bias, attention_mask = self._attn_bias(
            device=x.device,
            dtype=x.dtype,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id)

        # initialize the past key values cache if it should be used
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)]  # type: ignore

        all_hidden_states = () if output_hidden_states else None
        for b_idx, block in enumerate(self.transformer.blocks):  # type: ignore
            if output_hidden_states:
                assert all_hidden_states is not None  # pyright
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = past_key_values[
                b_idx] if past_key_values is not None else None
            x, past_key_value = block(x,
                                      past_key_value=past_key_value,
                                      attn_bias=attn_bias,
                                      attention_mask=attention_mask,
                                      is_causal=self.is_causal)
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value

        x = self.transformer.ln_f(x)  # type: ignore

        # output embedding weight tied to input embedding
        assert isinstance(self.transformer.wte, nn.Module)  # pyright
        assert isinstance(self.transformer.wte.weight, torch.Tensor)  # pyright
        logits = F.linear(x, self.transformer.wte.weight, None)

        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
                )
            logits *= self.logit_scale

        return CausalLMOutputWithPast(logits=logits,
                                      past_key_values=past_key_values,
                                      hidden_states=all_hidden_states)

    # Param Initialization, needed for device='meta' fast initialization
    def param_init_fn(self, module):
        init_fn_name = self.config.param_init_fn
        if self.config.verbose > 1:
            warnings.warn(f'Using {init_fn_name} initialization.')
        MODEL_INIT_REGISTRY[init_fn_name](module=module,
                                          **self.config.to_dict())

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, GPTBlock)

    # Activation Checkpointing
    def activation_checkpointing_fn(self, module):
        return isinstance(module, GPTBlock)

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      inputs_embeds=None,
                                      **kwargs):
        if inputs_embeds is not None:
            raise NotImplementedError(
                'inputs_embeds is not implemented for ReplitLM yet')

        attention_mask = kwargs['attention_mask'].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError(
                'ReplitLM does not support generation with right padding.')

        if self.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None

        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        if self.prefix_lm:
            # Leverage a convenience of sequential generation!
            prefix_mask = torch.ones_like(attention_mask)
            # This requires that we're using the cache
            if kwargs.get('use_cache') == False:
                raise NotImplementedError(
                    'ReplitLM with prefix_lm=True does not support use_cache=False.'
                )
        else:
            prefix_mask = None

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'prefix_mask': prefix_mask,
            'sequence_id': sequence_id,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache', True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        reordered_past = []
        for layer_past in past_key_values:
            reordered_past += [
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past)
            ]
        return reordered_past
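To make the masking logic in `_attn_bias` and `forward` above concrete, here is a small standalone sketch (toy shapes, not release code) of the two tricks used: padded key positions become an additive bias of `torch.finfo(dtype).min`, and position ids are shifted down by the running count of padding tokens so left padding does not consume learned positions.

```python
import torch

# Toy batch: one sequence of length 6 with two left-padding tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1]], dtype=torch.bool)
s_k = attention_mask.shape[-1]
dtype = torch.float32

# 1) Additive bias: padded keys get the most negative finite value, so softmax
#    assigns them ~0 probability (mirrors the masked_fill in _attn_bias).
attn_bias = torch.zeros((1, 1, 1, s_k), dtype=dtype)
min_val = torch.finfo(dtype).min
attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)

# 2) Position ids: subtract the cumulative number of padding tokens so the first
#    real token still gets position 0 (mirrors the clamp/cumsum in forward()).
pos = torch.arange(s_k).unsqueeze(0)
pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), min=0)
print(pos)  # tensor([[0, 0, 0, 1, 2, 3]])
```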
    	
replit_lm_tokenizer.py
ADDED
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Forked from the file src/transformers/models/bert_generation/tokenization_bert_generation.py from the HuggingFace Transformers library.
Permalink: https://github.com/huggingface/transformers/blob/04ab5605fbb4ef207b10bf2772d88c53fc242e83/src/transformers/models/bert_generation/tokenization_bert_generation.py

Class is modified for compatibility with custom vocabulary and to achieve desired encode/decode behavior for Replit Code v1.3b model.
"""

"""Tokenizer class for ReplitLM"""


import os
import sentencepiece as spm
from shutil import copyfile
from transformers import PreTrainedTokenizer
from typing import Any, Dict, List, Optional, Tuple

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}


class ReplitLMTokenizer(PreTrainedTokenizer):
    """
    Construct a ReplitLMTokenizer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        bos_token (`str`, *optional*, defaults to `None`):
            The beginning of sequence token.
        unk_token (`str`, *optional*, defaults to `"<|unk|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<|pad|>"`):
            The token used for padding, for example when batching sequences of different lengths.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:
            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assumes nbest_size is infinite and samples from all hypotheses (lattice)
                using the forward-filtering-and-backward-sampling algorithm.
            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    prefix_tokens: List[int] = []
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token=None,
        eos_token="<|endoftext|>",
        unk_token="<|unk|>",
        pad_token="<|pad|>",
        sep_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # Add extra_ids to the special token list
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            sep_token=sep_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

        self.vocab_file = vocab_file

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.load(self.vocab_file)

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        return self.sp_model.encode(text, out_type=str)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                def _convert_token_to_id(self, token):
         | 
| 128 | 
            +
                    """Converts a token (str) in an id using the vocab."""
         | 
| 129 | 
            +
                    return self.sp_model.piece_to_id(token)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                def _convert_id_to_token(self, index):
         | 
| 132 | 
            +
                    """Converts an index (integer) in a token (str) using the vocab."""
         | 
| 133 | 
            +
                    token = self.sp_model.id_to_piece(index)
         | 
| 134 | 
            +
                    return token
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def convert_tokens_to_string(self, tokens):
         | 
| 137 | 
            +
                    """Converts a sequence of tokens (string) in a single string."""
         | 
| 138 | 
            +
                    return self.sp_model.decode(tokens)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                def save_vocabulary(self,
         | 
| 141 | 
            +
                                    save_directory: str,
         | 
| 142 | 
            +
                                    filename_prefix: Optional[str] = None) -> Tuple[str]:
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    if not os.path.isdir(save_directory):
         | 
| 145 | 
            +
                        raise ValueError(
         | 
| 146 | 
            +
                            f"Vocabulary path ({save_directory}) should be a directory")
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    out_vocab_file = os.path.join(
         | 
| 149 | 
            +
                        save_directory, (filename_prefix + "-" if filename_prefix else "") +
         | 
| 150 | 
            +
                        VOCAB_FILES_NAMES["vocab_file"])
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    if os.path.abspath(
         | 
| 153 | 
            +
                            self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
         | 
| 154 | 
            +
                            self.vocab_file):
         | 
| 155 | 
            +
                        copyfile(self.vocab_file, out_vocab_file)
         | 
| 156 | 
            +
                    elif not os.path.isfile(self.vocab_file):
         | 
| 157 | 
            +
                        with open(out_vocab_file, "wb") as fi:
         | 
| 158 | 
            +
                            content_spiece_model = self.sp_model.serialized_model_proto()
         | 
| 159 | 
            +
                            fi.write(content_spiece_model)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    return (out_vocab_file, )
         | 
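
The tokenizer is a thin wrapper around a SentencePiece model. As a minimal usage sketch (assuming `spiece.model` is available locally next to `replit_lm_tokenizer.py`; the snippet and its inputs are illustrative, not part of the release):

```python
from replit_lm_tokenizer import ReplitLMTokenizer

# Instantiate directly from the SentencePiece model file shipped in this repo.
tokenizer = ReplitLMTokenizer(vocab_file="spiece.model")

# Encode a small code snippet and round-trip it back to text.
encoded = tokenizer("def add(a, b):\n    return a + b")
print(encoded["input_ids"])
print(tokenizer.decode(encoded["input_ids"]))
```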
    	
        special_tokens_map.json
    ADDED
    
@@ -0,0 +1,5 @@
{
    "eos_token": "<|endoftext|>",
    "pad_token": "<|pad|>",
    "unk_token": "<|unk|>"
}
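
These entries mirror the defaults in `ReplitLMTokenizer.__init__` above; `bos_token` and `sep_token` are left unset and so resolve to `None`. A quick sanity check, assuming a `tokenizer` instance created as in the earlier sketch:

```python
# Special tokens resolve to the values from special_tokens_map.json ...
print(tokenizer.eos_token, tokenizer.pad_token, tokenizer.unk_token)
# ... and to their corresponding vocabulary ids.
print(tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id)
```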
    	
        spiece.model
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e1ba8b7df0701723d2d901c7a42182fe77bf0045173f2cdb474ca6ea3eb1c02
size 707660
    	
        tokenizer_config.json
    ADDED
    
@@ -0,0 +1,18 @@
{
    "auto_map": {
      "AutoTokenizer": [
        "replit_lm_tokenizer.ReplitLMTokenizer",
        null
      ]
    },
    "bos_token": null,
    "clean_up_tokenization_spaces": false,
    "eos_token": "<|endoftext|>",
    "model_max_length": 2048,
    "pad_token": "<|pad|>",
    "padding_side": "right",
    "sep_token": null,
    "sp_model_kwargs": {},
    "tokenizer_class": "ReplitLMTokenizer",
    "unk_token": "<|unk|>"
}
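
The `auto_map` entry is what lets `AutoTokenizer` resolve the custom `ReplitLMTokenizer` class shipped with this repo, which is why loading it requires `trust_remote_code=True`. A sketch of the typical loading path (the batch contents are illustrative):

```python
from transformers import AutoTokenizer

# auto_map routes AutoTokenizer to replit_lm_tokenizer.ReplitLMTokenizer,
# so the custom code has to be trusted explicitly.
tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)

batch = tokenizer(
    ["print('hello')", "def fib(n):\n    ..."],
    padding=True,        # pads on the right, per "padding_side": "right"
    truncation=True,     # respects "model_max_length": 2048
    return_tensors="pt",
)
print(batch["input_ids"].shape)
```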
