FlashAttention support
configuration_gptbert.py  (+1, -80)
@@ -14,55 +14,7 @@ class GptBertConfig(PretrainedConfig):
         **kwargs
     ):
         super().__init__(**kwargs)
-
-        self.model: str
-
-        # General information
-        self.model = "base"
-
-        # Vocabulary
-        self.vocab_size = 16384
-        self.max_sequence_length = 512
-
-        # Model dimensions
-        self.hidden_size = 768
-        self.intermediate_size = 2048
-        self.num_attention_heads = 12
-        self.num_layers = 12
-        self.d_qk = 64
-
-        # Dropout probabilities
-        self.embedding_dropout_p = 0.1
-        self.attention_probabilities_dropout_p = 0.1
-        self.attention_output_dropout_p = 0.1
-        self.feed_forward_dropout_p = 0.1
-        self.attention_dropout = 0.1
-        self.hidden_dropout_prob = 0.2
-
-        # Position Emebedding
-        self.rope_theta = 160_000
-
-        # Norms
-        self.word_norm_eps = 1e-7
-        self.word_norm_affine = False
-
-        self.attention_pre_norm_eps = 1e-7
-        self.attention_pre_norm_affine = False
-
-        self.attention_inter_norm_eps = 1e-7
-        self.attention_inter_norm_affine = True
-
-        self.feed_forward_pre_norm_eps = 1e-7
-        self.feed_forward_pre_norm_affine = False
-
-        self.feed_forward_inter_norm_eps = 1e-7
-        self.feed_forward_inter_norm_affine = False
-
-        self.classifier_pre_norm_eps = 1e-7
-        self.classifier_pre_norm_affine = False
-
-        self.classifier_post_norm_eps = 1e-7
-        self.classifier_post_norm_affine = False
+        self.model = "norbert4"

         if config_file is not None:
             if type(config_file) is str:
@@ -80,34 +32,3 @@ class GptBertConfig(PretrainedConfig):
             if isinstance(value, str):
                 value = value.lower()
             setattr(self, attr, value)
-
-    def __repr__(self) -> str:
-        return str(self.to_json_string())
-
-    def to_dict(self) -> dict:
-        """Serializes this instance to a Python dictionary."""
-        output: dict
-
-        output = copy.deepcopy(self.__dict__)
-        return output
-
-    def to_json_string(self) -> str:
-        """Serializes this instance to a JSON string."""
-        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
-
-    def to_json_file(self, json_file_path: Path | str) -> None:
-        """Save this instance to a json file."""
-        if isinstance(json_file_path, str):
-            json_file_path: Path = Path(json_file_path)
-        with json_file_path.open("w", encoding='utf-8') as writer:
-            writer.write(self.to_json_string())
-
-    @classmethod
-    def create_base_config(cls, json_file_path: Path | str | None = None) -> GptBertConfig:
-        config: GptBertConfig
-
-        config = GptBertConfig()
-        if json_file_path is not None:
-            config.to_json_file(json_file_path)
-
-        return config
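With the hard-coded defaults and the JSON serialization helpers removed, the configuration is presumably populated from the file passed as config_file, via the setattr loop that survives the diff. A minimal sketch of the intended usage, assuming config_file is accepted as a keyword argument and that a JSON file with the relevant keys exists (the path and keys below are assumptions, not part of this commit):

from configuration_gptbert import GptBertConfig

# "config.json" is a hypothetical path; its contents are not shown in this diff.
config = GptBertConfig(config_file="config.json")
print(config.model)  # "norbert4" unless the file overrides it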
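The modeling-side changes that actually enable FlashAttention are not part of this excerpt. A common dispatch pattern is sketched below, purely as an assumption about what such support looks like (the attention helper and its fallback are not code from this commit):

import torch

try:
    from flash_attn import flash_attn_func  # provided by the flash-attn package
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

def attention(q, k, v, dropout_p=0.0):
    # q, k, v: (batch, seq_len, num_heads, head_dim); FlashAttention requires
    # fp16/bf16 tensors on a CUDA device.
    if HAS_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_func(q, k, v, dropout_p=dropout_p)
    # Fallback: PyTorch's SDPA expects (batch, num_heads, seq_len, head_dim).
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p
    )
    return out.transpose(1, 2)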
