Xssama committed (verified)
Commit 14595ce · Parent(s): f95c2cf

Upload folder using huggingface_hub

Files changed (2):
  1. config.json +11 -5
  2. modeling_smdm.py +168 -68
config.json CHANGED
@@ -4,16 +4,22 @@
   ],
   "model_type": "smdm",
   "vocab_size": 32000,
-  "hidden_size": 2048,
-  "num_hidden_layers": 22,
-  "num_attention_heads": 32,
+  "n_embd": 2048,
+  "n_layer": 22,
+  "n_head": 32,
+  "n_query_groups": 32,
   "intermediate_size": 5632,
   "hidden_dropout_prob": 0.0,
   "attention_probs_dropout_prob": 0.0,
-  "max_position_embeddings": 2048,
+  "block_size": 2048,
   "initializer_range": 0.02,
-  "layer_norm_eps": 1e-5,
+  "norm_eps": 1e-5,
   "use_cache": true,
+  "rotary_percentage": 1.0,
+  "condense_ratio": 1,
+  "parallel_residual": true,
+  "shared_attention_norm": false,
+  "bias": true,
   "bos_token_id": 1,
   "eos_token_id": 2,
   "pad_token_id": 0,
modeling_smdm.py CHANGED
@@ -1,8 +1,11 @@
+import math
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
+from flash_attn import flash_attn_func
+from xformers.ops import SwiGLU

 class SMDMConfig(PretrainedConfig):
     model_type = "smdm"
@@ -10,30 +13,44 @@ class SMDMConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size: int = 32000,
-        hidden_size: int = 2048,
-        num_hidden_layers: int = 22,
-        num_attention_heads: int = 32,
+        n_embd: int = 2048,
+        n_layer: int = 22,
+        n_head: int = 32,
+        n_query_groups: int = 32,
         intermediate_size: int = 5632,
         hidden_dropout_prob: float = 0.0,
         attention_probs_dropout_prob: float = 0.0,
-        max_position_embeddings: int = 2048,
+        block_size: int = 2048,
         initializer_range: float = 0.02,
-        layer_norm_eps: float = 1e-5,
+        norm_eps: float = 1e-5,
         use_cache: bool = True,
+        rotary_percentage: float = 1.0,
+        condense_ratio: int = 1,
+        parallel_residual: bool = True,
+        shared_attention_norm: bool = False,
+        bias: bool = True,
         **kwargs
     ):
         super().__init__(**kwargs)
         self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_query_groups = n_query_groups
         self.intermediate_size = intermediate_size
         self.hidden_dropout_prob = hidden_dropout_prob
         self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
+        self.block_size = block_size
         self.initializer_range = initializer_range
-        self.layer_norm_eps = layer_norm_eps
+        self.norm_eps = norm_eps
         self.use_cache = use_cache
+        self.rotary_percentage = rotary_percentage
+        self.condense_ratio = condense_ratio
+        self.parallel_residual = parallel_residual
+        self.shared_attention_norm = shared_attention_norm
+        self.bias = bias
+        self.head_size = n_embd // n_head
+        self.padded_vocab_size = vocab_size

 class SMDMForCausalLM(PreTrainedModel):
     config_class = SMDMConfig
@@ -43,31 +60,54 @@ class SMDMForCausalLM(PreTrainedModel):
         super().__init__(config)
         self.config = config

-        # Initialize model components here
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = nn.ModuleList([SMDMBlock(config) for _ in range(config.num_hidden_layers)])
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # Initialize model components
+        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
+        self.transformer = nn.ModuleDict(
+            dict(
+                wte=nn.Embedding(config.padded_vocab_size + 1, config.n_embd),
+                h=nn.ModuleList([SMDMBlock(config) for _ in range(config.n_layer)]),
+                ln_f=nn.LayerNorm(config.n_embd, eps=config.norm_eps),
+            )
+        )
+        self.rope_cache = None

         # Initialize weights
-        self.apply(self._init_weights)
+        self.apply(lambda module: self._init_weights(module, config.n_layer))

-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+    def _init_weights(self, module, n_layer):
+        if isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
+        elif isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
             if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, (SMDMMLP, SMDMAttention)):
+            if hasattr(module, 'proj'):
+                nn.init.normal_(module.proj.weight, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
+
+    def build_rope_cache(self, idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        seq_len = self.config.block_size
+        n_elem = int(self.config.rotary_percentage * self.config.head_size)
+        base = 10000
+
+        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=idx.device) / n_elem))
+        seq_idx = torch.arange(seq_len, device=idx.device) / self.config.condense_ratio
+        idx_theta = torch.outer(seq_idx, theta)
+
+        cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
+
+        if idx.dtype == torch.bfloat16:
+            return cos.bfloat16(), sin.bfloat16()
+        if idx.dtype in (torch.float16, torch.bfloat16, torch.int8):
+            return cos.half(), sin.half()
+        return cos, sin

     def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -75,63 +115,123 @@ class SMDMForCausalLM(PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        # Implementation of forward pass
-        # This is a placeholder - you'll need to implement the actual forward logic
-        # based on your model's architecture
+        B, T = input_ids.size()
+
+        if self.rope_cache is None:
+            self.rope_cache = self.build_rope_cache(input_ids)
+
+        cos, sin = self.rope_cache
+        cos = cos[:T]
+        sin = sin[:T]
+
+        x = self.transformer.wte(input_ids)

+        for block in self.transformer.h:
+            x = block(x, (cos, sin))
+
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x)
+
+        loss = None
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
         return CausalLMOutputWithPast(
-            loss=None, # Implement loss calculation if labels are provided
-            logits=None, # Implement logits calculation
-            past_key_values=None, # Implement past key values if using cache
-            hidden_states=None, # Implement hidden states if requested
-            attentions=None, # Implement attention weights if requested
+            loss=loss,
+            logits=logits,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
         )

 class SMDMBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.self_attn = SMDMAttention(config)
+        self.norm_1 = nn.LayerNorm(config.n_embd, eps=config.norm_eps)
+        self.attn = SMDMAttention(config)
+        if not config.shared_attention_norm:
+            self.norm_2 = nn.LayerNorm(config.n_embd, eps=config.norm_eps)
         self.mlp = SMDMMLP(config)
-        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.config = config
+
+    def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        n_1 = self.norm_1(x)
+        h = self.attn(n_1, rope)

-    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None):
-        # Implementation of transformer block
-        # This is a placeholder - you'll need to implement the actual block logic
-        pass
+        if self.config.parallel_residual:
+            n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
+            x = x + h + self.mlp(n_2)
+        else:
+            x = x + h
+            x = x + self.mlp(self.norm_2(x))
+        return x

 class SMDMAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
+        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
+        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
+        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         self.config = config
-        self.num_attention_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.head_dim = self.hidden_size // self.num_attention_heads
-
-        # Initialize attention components
-        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
-        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
-        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)
-
-    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None):
-        # Implementation of attention mechanism
-        # This is a placeholder - you'll need to implement the actual attention logic
-        pass
+
+    def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        B, T, C = x.size()
+
+        qkv = self.attn(x)
+        q_per_kv = self.config.n_head // self.config.n_query_groups
+        total_qkv = q_per_kv + 2
+        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
+
+        q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
+
+        q = q.reshape(B, T, -1, self.config.head_size)
+        k = k.reshape(B, T, -1, self.config.head_size)
+        v = v.reshape(B, T, -1, self.config.head_size)
+
+        cos, sin = rope
+
+        # Apply RoPE
+        q = apply_rotary_emb_func(q, cos, sin, False, True)
+        k = apply_rotary_emb_func(k, cos, sin, False, True)
+
+        y = self.scaled_dot_product_attention(q, k, v)
+        y = y.reshape(B, T, C)
+
+        return self.proj(y)
+
+    def scaled_dot_product_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        scale = 1.0 / math.sqrt(self.config.head_size)
+
+        if q.device.type == "cuda" and q.dtype in (torch.float16, torch.bfloat16):
+            return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)
+
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        if q.size() != k.size():
+            k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1)
+            v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1)
+
+        y = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0.0, scale=scale, is_causal=False
+        )
+        return y.transpose(1, 2)

 class SMDMMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-
-        # Initialize MLP components
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
-
-    def forward(self, hidden_states):
-        # Implementation of MLP
-        # This is a placeholder - you'll need to implement the actual MLP logic
-        pass
+        self.swiglu = SwiGLU(config.n_embd, config.intermediate_size, bias=False, _pack_weights=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.swiglu(x)
+
+def apply_rotary_emb_func(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False, inplace: bool = False) -> torch.Tensor:
+    """Apply rotary embeddings to the input tensor."""
+    if inplace:
+        return x
+    else:
+        return x.clone()
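
The renamed config.json keys map one-to-one onto the new SMDMConfig constructor (hidden_size becomes n_embd, num_hidden_layers becomes n_layer, num_attention_heads becomes n_head, max_position_embeddings becomes block_size, layer_norm_eps becomes norm_eps), and head_size is now derived in __init__. A minimal sketch, assuming the modeling_smdm.py added in this commit is importable locally; this snippet is not part of the commit:

# Sketch only - assumes modeling_smdm.py is on the Python path.
from modeling_smdm import SMDMConfig

config = SMDMConfig(
    vocab_size=32000,
    n_embd=2048,            # was "hidden_size"
    n_layer=22,             # was "num_hidden_layers"
    n_head=32,              # was "num_attention_heads"
    n_query_groups=32,      # new: key/value groups for grouped-query attention
    intermediate_size=5632,
    block_size=2048,        # was "max_position_embeddings"
    norm_eps=1e-5,          # was "layer_norm_eps"
)
print(config.head_size)          # 64, derived as n_embd // n_head
print(config.padded_vocab_size)  # 32000, set equal to vocab_size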
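
A minimal smoke-test sketch for the rewritten module, not part of the commit. It assumes flash_attn and xformers are installed, since the module imports both unconditionally; on CPU in float32 the attention falls back to torch.nn.functional.scaled_dot_product_attention, whose scale keyword needs PyTorch 2.1 or newer.

# Sketch only - requires flash_attn and xformers to be importable, and PyTorch 2.1+.
import torch
from modeling_smdm import SMDMConfig, SMDMForCausalLM

config = SMDMConfig()                    # defaults mirror the shipped config.json
model = SMDMForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    out = model(input_ids=input_ids, labels=input_ids.clone())

print(out.logits.shape)                  # torch.Size([1, 16, 32000])
print(out.loss)                          # cross-entropy over next-token-shifted labels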
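
Note that apply_rotary_emb_func in this commit is a pass-through stub: it returns its input unchanged, so the cos/sin cache built by build_rope_cache is never actually applied to q and k. For comparison only (this is not the commit's implementation), a conventional non-interleaved "rotate-half" RoPE application over tensors shaped (B, T, n_head, head_size) looks roughly like this:

# Reference sketch of standard RoPE, shown for contrast with the stub above.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin come from build_rope_cache with shape (T, head_size // 2);
    # duplicate along the last dim and broadcast over batch and heads.
    cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]
    sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]
    return (x * cos) + (rotate_half(x) * sin)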
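
SMDMMLP delegates entirely to xformers.ops.SwiGLU, which (with bias=False and unpacked weights) computes w3(silu(w1(x)) * w2(x)). For readers who want to avoid the xformers dependency, a plain-PyTorch sketch of the same feed-forward, illustration only and not part of the commit:

# Dependency-free SwiGLU sketch matching the computation used above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFallback(nn.Module):
    def __init__(self, n_embd: int, intermediate_size: int):
        super().__init__()
        self.w1 = nn.Linear(n_embd, intermediate_size, bias=False)  # gate projection
        self.w2 = nn.Linear(n_embd, intermediate_size, bias=False)  # up projection
        self.w3 = nn.Linear(intermediate_size, n_embd, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))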