mike23415 committed
Commit 6a177ad · verified · 1 Parent(s): 137edd0

Update custom_bitnet.py

Files changed (1)
  1. custom_bitnet.py +440 -46
custom_bitnet.py CHANGED
@@ -1,76 +1,470 @@
- from transformers import PreTrainedModel, PretrainedConfig
  import torch
- import torch.nn as nn

  class BitNetConfig(PretrainedConfig):
      model_type = "bitnet"
      def __init__(
          self,
-         vocab_size=32000,
-         hidden_size=768,
-         num_hidden_layers=12,
-         num_attention_heads=12,
-         intermediate_size=3072,
-         hidden_act="gelu",
-         max_position_embeddings=512,
          initializer_range=0.02,
-         layer_norm_eps=1e-12,
-         dropout=0.1,
-         pad_token_id=0,
-         bos_token_id=1,
-         eos_token_id=2,
-         **kwargs
      ):
          self.vocab_size = vocab_size
          self.hidden_size = hidden_size
          self.num_hidden_layers = num_hidden_layers
          self.num_attention_heads = num_attention_heads
-         self.intermediate_size = intermediate_size
          self.hidden_act = hidden_act
-         self.max_position_embeddings = max_position_embeddings
          self.initializer_range = initializer_range
-         self.layer_norm_eps = layer_norm_eps
-         self.dropout = dropout
          super().__init__(
              pad_token_id=pad_token_id,
              bos_token_id=bos_token_id,
              eos_token_id=eos_token_id,
-             **kwargs
          )

  class BitNetForCausalLM(PreTrainedModel):
      config_class = BitNetConfig
      def __init__(self, config):
          super().__init__(config)
-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-         self.layers = nn.ModuleList([
-             nn.TransformerEncoderLayer(
-                 d_model=config.hidden_size,
-                 nhead=config.num_attention_heads,
-                 dim_feedforward=config.intermediate_size,
-                 dropout=config.dropout
-             ) for _ in range(config.num_hidden_layers)
-         ])
-         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-         self.apply(self._init_weights)
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
-     def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
-         hidden_states = self.embed_tokens(input_ids)
-         for layer in self.layers:
-             hidden_states = layer(hidden_states)
-         hidden_states = self.norm(hidden_states)
          logits = self.lm_head(hidden_states)
          loss = None
          if labels is not None:
              loss_fct = nn.CrossEntropyLoss()
              loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-         return {"logits": logits, "loss": loss} if loss is not None else {"logits": logits}
-     def prepare_inputs_for_generation(self, input_ids, **kwargs):
-         return {"input_ids": input_ids, **kwargs}
+ # coding=utf-8
+ # Copyright 2025 Microsoft, EleutherAI, and the HuggingFace Inc. team. All rights reserved.
+ # Licensed under the Apache License, Version 2.0.
+
+ from typing import Optional, Tuple, Union
  import torch
+ from torch import nn
+ from transformers import PreTrainedModel, PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.cache_utils import DynamicCache
+ from transformers.activations import ACT2FN

+ # BitNetConfig
  class BitNetConfig(PretrainedConfig):
      model_type = "bitnet"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
      def __init__(
          self,
+         vocab_size=128256,
+         hidden_size=2560,
+         intermediate_size=6912,
+         num_hidden_layers=30,
+         num_attention_heads=20,
+         num_key_value_heads=5,
+         hidden_act="relu2",
+         max_position_embeddings=2048,
          initializer_range=0.02,
+         rms_norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=128000,
+         eos_token_id=128001,
+         tie_word_embeddings=False,
+         rope_theta=500000.0,
+         attention_bias=False,
+         attention_dropout=0.0,
+         **kwargs,
      ):
          self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
          self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
          self.num_hidden_layers = num_hidden_layers
          self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads or num_attention_heads
          self.hidden_act = hidden_act
          self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
          super().__init__(
              pad_token_id=pad_token_id,
              bos_token_id=bos_token_id,
              eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+ # BitNetRMSNorm
+ class BitNetRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+ # BitNetMLP
+ class BitNetMLP(nn.Module):
+     def __init__(self, config: BitNetConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+         self.ffn_sub_norm = BitNetRMSNorm(config.intermediate_size, eps=config.rms_norm_eps)
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.ffn_sub_norm(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
+         return down_proj
+
+ # Utility Functions
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     return attn_output, attn_weights
+
+ # BitNetAttention
+ class BitNetAttention(nn.Module):
+     def __init__(self, config: BitNetConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = config.hidden_size // config.num_attention_heads
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+         self.q_proj = nn.Linear(
+             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.o_proj = nn.Linear(
+             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+         )
+         self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_value: Optional[DynamicCache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+         if past_key_value is not None:
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+         attn_output, attn_weights = eager_attention_forward(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+         )
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.attn_sub_norm(attn_output)
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+ # BitNetDecoderLayer
+ class BitNetDecoderLayer(nn.Module):
+     def __init__(self, config: BitNetConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = BitNetAttention(config=config, layer_idx=layer_idx)
+         self.mlp = BitNetMLP(config)
+         self.input_layernorm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[DynamicCache] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states, self_attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             position_embeddings=position_embeddings,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         outputs = (hidden_states,)
+         if output_attentions:
+             outputs += (self_attn_weights,)
+         return outputs
+
+ # BitNetRotaryEmbedding
+ class BitNetRotaryEmbedding(nn.Module):
+     def __init__(self, config: BitNetConfig, device=None):
+         super().__init__()
+         self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.config = config
+         dim = config.hidden_size // config.num_attention_heads
+         inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos()
+             sin = emb.sin()
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+ # BitNetModel
+ class BitNetModel(PreTrainedModel):
+     config_class = BitNetConfig
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["BitNetDecoderLayer"]
+
+     def __init__(self, config: BitNetConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList(
+             [BitNetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = BitNetRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[DynamicCache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> BaseModelOutputWithPast:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache()
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+         hidden_states = inputs_embeds
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+             layer_outputs = decoder_layer(
+                 hidden_states,
+                 attention_mask=causal_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 cache_position=cache_position,
+                 position_embeddings=position_embeddings,
+             )
+             hidden_states = layer_outputs[0]
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+         hidden_states = self.norm(hidden_states)
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     def _update_causal_mask(
+         self,
+         attention_mask: Optional[torch.Tensor],
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Optional[DynamicCache],
+     ):
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         target_length = past_key_values.get_seq_length() + sequence_length + 1 if past_key_values else sequence_length + 1
+         min_dtype = torch.finfo(dtype).min
+         causal_mask = torch.full(
+             (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
          )
+         if sequence_length != 1:
+             causal_mask = torch.triu(causal_mask, diagonal=1)
+         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+         if attention_mask is not None:
+             causal_mask = causal_mask.clone()
+             mask_length = attention_mask.shape[-1]
+             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+             padding_mask = padding_mask == 0
+             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
+         return causal_mask

+ # BitNetForCausalLM
  class BitNetForCausalLM(PreTrainedModel):
      config_class = BitNetConfig
+     _tied_weights_keys = ["lm_head.weight"]
+
      def __init__(self, config):
          super().__init__(config)
+         self.model = BitNetModel(config)
+         self.vocab_size = config.vocab_size
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[DynamicCache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             cache_position=cache_position,
+             **kwargs,
+         )
+         hidden_states = outputs.last_hidden_state
          logits = self.lm_head(hidden_states)
          loss = None
          if labels is not None:
              loss_fct = nn.CrossEntropyLoss()
              loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
+         if past_key_values is None:
+             past_key_values = DynamicCache()
+         cache_position = kwargs.get("cache_position", None)
+         if cache_position is None:
+             past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(past_length, past_length + input_ids.shape[-1], device=input_ids.device)
+         position_ids = cache_position.unsqueeze(0)
+         if attention_mask is not None and attention_mask.shape[1] != input_ids.shape[1]:
+             attention_mask = self._update_causal_mask(
+                 attention_mask,
+                 input_ids,
+                 cache_position,
+                 past_key_values
+             )
+         return {
+             "input_ids": input_ids,
+             "position_ids": position_ids,
+             "attention_mask": attention_mask,
+             "past_key_values": past_key_values,
+             "cache_position": cache_position,
+             "use_cache": kwargs.get("use_cache", True),
+         }
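
For reference, a minimal smoke test of the updated module might look like the sketch below. It is not part of the commit: the tiny hyperparameters are illustrative placeholders so the forward pass runs quickly on CPU, while the defaults in BitNetConfig describe the full-size model. Loading the same classes through the Auto classes with trust_remote_code=True would additionally require the repository's config.json to map to this file, which is outside this diff.

# Hypothetical usage sketch for custom_bitnet.py (not part of the commit).
import torch
from custom_bitnet import BitNetConfig, BitNetForCausalLM

# Deliberately small configuration; real checkpoints use the defaults above.
config = BitNetConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    max_position_embeddings=64,
)
model = BitNetForCausalLM(config)

# Run one forward pass with labels to exercise the loss path.
input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.logits.shape)  # torch.Size([1, 16, 128])
print(outputs.loss)          # scalar cross-entropy loss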