appledora committed (verified)
Commit 17e5b07 · 1 Parent(s): 9a3a9ac

Update recast1B_llama/modeling_recast_llama.py
recast1B_llama/modeling_recast_llama.py CHANGED
@@ -32,35 +32,33 @@ class MLPTemplateBank(nn.Module):
         self.coef_shape = (coef_rows, coef_columns)
 
         assert coef_columns is not None, "coef_columns must not be None"
-
+
         # Ensure divisibility for proper reshaping
-        assert (self.hidden_size * self.intermediate_size) % coef_rows == 0, \
-            f"hidden_size * intermediate_size ({self.hidden_size * self.intermediate_size}) must be divisible by coef_rows ({coef_rows})"
-
+        assert (
+            self.hidden_size * self.intermediate_size
+        ) % coef_rows == 0, f"hidden_size * intermediate_size ({self.hidden_size * self.intermediate_size}) must be divisible by coef_rows ({coef_rows})"
+
         template_size = self.hidden_size * self.intermediate_size // coef_rows
-
-        self.up_templates = nn.Parameter(
-            torch.randn(coef_columns, template_size)
-        )
-        self.gate_templates = nn.Parameter(
-            torch.randn(coef_columns, template_size)
-        )
-
+
+        self.up_templates = nn.Parameter(torch.randn(coef_columns, template_size))
+        self.gate_templates = nn.Parameter(torch.randn(coef_columns, template_size))
+
         # Better initialization
         nn.init.xavier_uniform_(self.up_templates)
         nn.init.xavier_uniform_(self.gate_templates)
 
     def forward(self, up_coeffs, gate_coeffs):
         # Compute chunked weights
-        up_chunks = torch.matmul(up_coeffs, self.up_templates)
+        up_chunks = torch.matmul(up_coeffs, self.up_templates)
         gate_chunks = torch.matmul(gate_coeffs, self.gate_templates)
-
+
         # Reshape to final weight matrices
         up_weights = up_chunks.reshape(self.intermediate_size, self.hidden_size)
         gate_weights = gate_chunks.reshape(self.intermediate_size, self.hidden_size)
-
+
         return up_weights, gate_weights
 
+
 class SharedLlamaMLP(nn.Module):
     def __init__(self, config, bank):
         super().__init__()
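The template-bank idea above reduces to a small standalone sketch: each layer keeps only a coefficient matrix, and the shared templates are mixed and reshaped into a full Llama MLP projection. The sizes below (hidden_size=16, intermediate_size=32, coef_rows=8, coef_columns=4) are illustrative, not values from the RECAST1B config.

import torch

# Illustrative sizes only; the real model uses its config values.
hidden_size, intermediate_size = 16, 32
coef_rows, coef_columns = 8, 4
template_size = hidden_size * intermediate_size // coef_rows  # 64

# Shared across layers: the template bank.
up_templates = torch.randn(coef_columns, template_size)

# Per layer: a small coefficient matrix.
up_coefficients = torch.randn(coef_rows, coef_columns)

# (coef_rows, coef_columns) @ (coef_columns, template_size) -> (coef_rows, template_size)
up_chunks = up_coefficients @ up_templates

# The flattened chunks reassemble into the usual projection weight.
up_weights = up_chunks.reshape(intermediate_size, hidden_size)
print(up_weights.shape)  # torch.Size([32, 16])

Only the coefficients are per layer, so the layer-specific parameter count for this projection drops from hidden_size * intermediate_size to coef_rows * coef_columns, plus the shared templates paid once.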
@@ -68,7 +66,9 @@ class SharedLlamaMLP(nn.Module):
         self.bank = bank
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.down_proj = nn.Linear(
+            config.intermediate_size, config.hidden_size, bias=False
+        )
 
         # Initialize coefficients with proper shapes
         self.up_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
@@ -90,31 +90,37 @@ class SharedLlamaMLP(nn.Module):
     def forward(self, x):
         # Generate weights using template bank
         up_weights, gate_weights = self.bank(
-            self.up_coefficients,
-            self.gate_coefficients  # Fixed order
+            self.up_coefficients, self.gate_coefficients  # Fixed order
         )
-
+
         # Apply SwiGLU: SiLU(gate * x) * up * x
-        hidden_states = self.act_fn(F.linear(x, gate_weights, self.gate_bias)) * F.linear(x, up_weights, self.up_bias)
+        hidden_states = self.act_fn(
+            F.linear(x, gate_weights, self.gate_bias)
+        ) * F.linear(x, up_weights, self.up_bias)
         output = self.down_proj(hidden_states)
 
         return output
 
+
 class AttTemplateBank(nn.Module):
     def __init__(self, config, coef_rows, coef_columns):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
-        self.num_key_value_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
+        self.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
+        )
         self.kv_dim = self.num_key_value_heads * self.head_dim
         self.coef_shape = (coef_rows, coef_columns)
 
         # Ensure divisibility
-        assert (self.hidden_size * self.hidden_size) % coef_rows == 0, \
-            "Q projection size must be divisible by coef_rows"
-        assert (self.kv_dim * self.hidden_size) % coef_rows == 0, \
-            "K/V projection size must be divisible by coef_rows"
+        assert (
+            self.hidden_size * self.hidden_size
+        ) % coef_rows == 0, "Q projection size must be divisible by coef_rows"
+        assert (
+            self.kv_dim * self.hidden_size
+        ) % coef_rows == 0, "K/V projection size must be divisible by coef_rows"
 
         # Create templates for Q, K, V
         self.q_templates = nn.Parameter(
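The reformatted SharedLlamaMLP.forward is a SwiGLU block applied with the bank-generated weights. A minimal sketch, assuming act_fn is SiLU (the standard Llama choice) and ignoring the optional gate_bias/up_bias terms; all shapes are toy values:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size, batch, seq = 16, 32, 2, 5
x = torch.randn(batch, seq, hidden_size)
up_weights = torch.randn(intermediate_size, hidden_size)
gate_weights = torch.randn(intermediate_size, hidden_size)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

# SwiGLU with generated weights: SiLU(x @ W_gate^T) * (x @ W_up^T), then down-projected.
hidden_states = F.silu(F.linear(x, gate_weights)) * F.linear(x, up_weights)
output = down_proj(hidden_states)
print(output.shape)  # torch.Size([2, 5, 16])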
@@ -144,9 +150,15 @@ class AttTemplateBank(nn.Module):
         v_weights = v_chunks.reshape(self.kv_dim, self.hidden_size)
 
         return q_weights, k_weights, v_weights
-
+
+
 class SharedLlamaAttention(nn.Module):
-    def __init__(self, config, layer_idx: Optional[int] = None, bank: Optional[AttTemplateBank] = None):
+    def __init__(
+        self,
+        config,
+        layer_idx: Optional[int] = None,
+        bank: Optional[AttTemplateBank] = None,
+    ):
         super().__init__()
         self.config = config
         self.bank = bank
@@ -155,15 +167,21 @@ class SharedLlamaAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
+        self.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
+        )
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = getattr(config, 'rope_theta', 10000.0)
+        self.rope_theta = getattr(config, "rope_theta", 10000.0)
         self.is_causal = True
-
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=getattr(config, 'attention_bias', False))
+
+        self.o_proj = nn.Linear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=getattr(config, "attention_bias", False),
+        )
         self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
-
+
         # Initialize coefficients with proper shapes
         self.q_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
         self.k_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
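The head bookkeeping that SharedLlamaAttention.__init__ relies on is plain integer arithmetic; the numbers below (hidden_size=2048, 32 query heads, 8 key/value heads) are illustrative, not read from the checkpoint:

# Illustrative grouped-query attention (GQA) bookkeeping;
# the real values come from the model config.
hidden_size = 2048
num_attention_heads = 32
num_key_value_heads = 8  # falls back to num_attention_heads if the config omits it

head_dim = hidden_size // num_attention_heads                        # 64
kv_dim = num_key_value_heads * head_dim                              # 512
num_key_value_groups = num_attention_heads // num_key_value_heads    # 4

# Q projects to hidden_size, K/V project to the narrower kv_dim,
# and each K/V head is later repeated num_key_value_groups times.
print(head_dim, kv_dim, num_key_value_groups)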
@@ -187,50 +205,64 @@ class SharedLlamaAttention(nn.Module):
         **kwargs,
     ):
         bsz, q_len, _ = hidden_states.size()
-
+
         # Generate weights using template bank
-        q_weights, k_weights, v_weights = self.bank(
-            self.q_coefficients,
-            self.k_coefficients,
-            self.v_coefficients
+        q_weights, k_weights, v_weights = self.bank(
+            self.q_coefficients, self.k_coefficients, self.v_coefficients
         )
 
         # Apply projections
         query_states = F.linear(hidden_states, q_weights)
         key_states = F.linear(hidden_states, k_weights)
         value_states = F.linear(hidden_states, v_weights)
-
+
         # Reshape for multi-head attention
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
         # Apply rotary embeddings
         if position_embeddings is None:
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
             cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
         # Handle past key values
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
         # Repeat key/value for grouped query attention
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         # Compute attention
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
 
         # Apply softmax and dropout
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.attention_dropout, training=self.training
+        )
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -242,10 +274,10 @@ class SharedLlamaAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
-
+
         if not output_attentions:
             attn_weights = None
-
+
         return attn_output, attn_weights, past_key_value
 
 
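After the reshapes, rotary embedding, and repeat_kv, the forward above reduces to eager scaled-dot-product attention with the softmax computed in float32. A toy-sized sketch of that core, omitting the attention mask, KV cache, and dropout; the shapes are illustrative:

import math
import torch

# Stand-ins for the projected, RoPE-rotated states.
bsz, num_heads, q_len, head_dim = 1, 4, 6, 8
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Scaled dot-product attention, softmax in float32 as in the forward above.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

# Back to (batch, seq, hidden) before o_proj.
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 6, 32])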
@@ -269,6 +301,8 @@ class RECAST1B_llamaModel(PreTrainedModel):
     config_class = RECAST1B_llama
     base_model_prefix = "llama"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LlamaDecoderLayer"]  # Add this
+    _skip_keys_device_placement = "past_key_values"  # Add this
 
     def __init__(self, config):
         super().__init__(config)
@@ -641,6 +675,8 @@ class RECAST1B_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
     config_class = RECAST1B_llama
     base_model_prefix = "llama"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LlamaDecoderLayer"]  # Add this
+    _skip_keys_device_placement = "past_key_values"  # Add this
 
     def __init__(self, config):
         super().__init__(config)
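_no_split_modules and _skip_keys_device_placement are standard PreTrainedModel class attributes: the first tells accelerate's device-map planner not to split the listed module class across devices, and the second leaves past_key_values where they are during dispatch. A usage sketch, with the repo id assumed for illustration only:

from transformers import AutoModelForCausalLM

# Hypothetical repo id; with the attributes above set, device_map="auto" can
# shard the model without cutting a decoder layer across devices, and cached
# past_key_values are not moved between devices.
model = AutoModelForCausalLM.from_pretrained(
    "appledora/recast1B_llama",  # assumed checkpoint name
    trust_remote_code=True,      # the custom modeling code lives in the repo
    device_map="auto",
)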
 