mike23415 committed
Commit 37ba7a2 · verified · 1 Parent(s): 9f3e630

Update custom_bitnet.py

Files changed (1)
  1. custom_bitnet.py +79 -7
custom_bitnet.py CHANGED
@@ -1,14 +1,86 @@
-from transformers import PreTrainedModel
-from transformers.models.bitnet.configuration_bitnet import BitNetConfig
+from transformers import PreTrainedModel, PretrainedConfig
+import torch
+import torch.nn as nn
 
+# BitNetConfig (replace with contents of configuration_bitnet.py)
+class BitNetConfig(PretrainedConfig):
+    model_type = "bitnet"
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        max_position_embeddings=512,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        dropout=0.1,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        **kwargs
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.dropout = dropout
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs
+        )
+
+# BitNetForCausalLM (replace with contents of modeling_bitnet.py)
 class BitNetForCausalLM(PreTrainedModel):
     config_class = BitNetConfig
 
     def __init__(self, config):
         super().__init__(config)
-        # Placeholder: Copy implementation from fork's modeling_bitnet.py
-        raise NotImplementedError("Replace with actual BitNetForCausalLM implementation")
+        # Placeholder: Replace with actual implementation
+        # Example structure (based on typical transformer models):
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList([
+            # Add BitNet-specific layers (e.g., BitNetLayer)
+            nn.TransformerEncoderLayer(
+                d_model=config.hidden_size,
+                nhead=config.num_attention_heads,
+                dim_feedforward=config.intermediate_size,
+                dropout=config.dropout
+            ) for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        # Placeholder: Replace with actual forward pass
+        hidden_states = self.embed_tokens(input_ids)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states)
+        hidden_states = self.norm(hidden_states)
+        logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+        return {"logits": logits, "loss": loss} if loss is not None else {"logits": logits}
 
-    def forward(self, *args, **kwargs):
-        # Placeholder: Copy forward pass from fork
-        raise NotImplementedError("Replace with actual forward pass implementation")
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {"input_ids": input_ids, **kwargs}