SVECTOR-OFFICIAL committed
Commit 5e5aa43 · verified · 1 Parent(s): 0ae23f5

Update modeling_spect1.py

Files changed (1)
  1. modeling_spect1.py +97 -51
modeling_spect1.py CHANGED
@@ -2,95 +2,141 @@ from typing import Optional, Tuple
 
 import torch
 from torch import nn
+from transformers import PreTrainedModel
 from transformers.cache_utils import Cache
-from transformers.models.auto.modeling_auto import AutoConfig, AutoModel
 
-from .configuration_spect1 import SpecT1Config
+from configuration_spect1 import SpecT1Config
 
 
 class SpecT1MTPLayers(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: SpecT1Config):
         super().__init__()
-        # Layer normalization with RMSNorm, adjusted for SpecT1 config
         self.input_layernorm = nn.LayerNorm(config.hidden_size)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
         self.token_layernorm = nn.LayerNorm(config.hidden_size)
         self.hidden_layernorm = nn.LayerNorm(config.hidden_size)
-
-        # Linear projection layer for input embeddings
-        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
-
-        # Final layer normalization
         self.final_layernorm = nn.LayerNorm(config.hidden_size)
-
-        # Custom attention mechanism (specifically adjusted for SpecT1)
-        self.self_attn = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads)
-
-        # MLP layer (adapted for SpecT1)
+        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
+        self.self_attn = nn.MultiheadAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            dropout=config.attention_dropout,
+            batch_first=True
+        )
         self.mlp = nn.Sequential(
             nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
             nn.Linear(config.intermediate_size, config.hidden_size)
         )
 
-    def forward(self, input_embeds,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_values: Optional[Cache] = None,
-                output_attentions: Optional[bool] = False,
-                use_cache: Optional[bool] = False,
-                position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-                cache_position=None,
-                **kwargs):
-
+    def forward(
+        self,
+        input_embeds: torch.Tensor,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        cache_position=None,
+        **kwargs
+    ) -> torch.Tensor:
         input_embeds = self.token_layernorm(input_embeds)
         previous_hidden_states = self.hidden_layernorm(hidden_states)
         hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-
-        # Apply self-attention
-        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attn_mask=attention_mask)
-
-        hidden_states = residual + hidden_states
+        attn_output, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attn_mask=attention_mask)
+        hidden_states = residual + attn_output
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
-        # Apply MLP
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        # Apply final layer normalization
+        mlp_output = self.mlp(hidden_states)
+        hidden_states = residual + mlp_output
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
 
-
 class SpecT1Model(nn.Module):
     config_class = SpecT1Config
-
     def __init__(self, config: SpecT1Config):
         super().__init__()
-        self.mtp_layers = nn.ModuleList([SpecT1MTPLayers(config) for _ in range(config.num_nextn_predict_layers)])
-
-    def forward(self, input_embeds, attention_mask, position_ids, **kwargs):
+        self.config = config
+        self.mtp_layers = nn.ModuleList([
+            SpecT1MTPLayers(config) for _ in range(config.num_nextn_predict_layers)
+        ])
+
+    def forward(
+        self,
+        input_embeds: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> torch.Tensor:
         hidden_states = input_embeds
         for layer in self.mtp_layers:
-            hidden_states = layer(hidden_states, hidden_states, attention_mask, position_ids, **kwargs)
+            hidden_states = layer(
+                input_embeds=input_embeds,
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                **kwargs
+            )
         return hidden_states
 
-
-class SpecT1ForCausalLM(nn.Module):
+class SpecT1ForCausalLM(PreTrainedModel):
     config_class = SpecT1Config
-
     def __init__(self, config: SpecT1Config):
-        super(SpecT1ForCausalLM, self).__init__()
+        super().__init__(config)
+        self.config = config
         self.model = SpecT1Model(config)
-        self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.post_init()
 
-    def forward(self, input_embeds, attention_mask, position_ids, **kwargs):
-        hidden_states = self.model(input_embeds, attention_mask, position_ids, **kwargs)
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+        **kwargs
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            raise ValueError("inputs_embeds must be provided for SpecT1ForCausalLM")
+        hidden_states = self.model(
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            **kwargs
+        )
         logits = self.lm_head(hidden_states)
-        return logits
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+        if not return_dict:
+            return (logits,) + (loss,) if loss is not None else (logits,)
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            hidden_states=None,
+            attentions=None,
+            past_key_values=None
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if inputs_embeds is None:
+            raise ValueError("SpecT1ForCausalLM requires inputs_embeds for generation")
+        return {
+            "inputs_embeds": inputs_embeds,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache", True)
+        }
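
For context, a minimal smoke-test sketch of the updated interface is shown below. It assumes configuration_spect1.py and modeling_spect1.py are importable from the working directory (matching the absolute import used in this commit), that SpecT1Config subclasses transformers.PretrainedConfig, and that it accepts the fields the model reads (hidden_size, intermediate_size, num_attention_heads, attention_dropout, num_nextn_predict_layers, vocab_size) as keyword arguments; the small values below are placeholders, not the shipped SpecT1 configuration.

import torch
from configuration_spect1 import SpecT1Config   # assumed to subclass PretrainedConfig
from modeling_spect1 import SpecT1ForCausalLM

# Hypothetical, deliberately small config values for a quick shape check.
config = SpecT1Config(
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    attention_dropout=0.0,
    num_nextn_predict_layers=2,
    vocab_size=1000,
)

model = SpecT1ForCausalLM(config).eval()

batch_size, seq_len = 2, 8
inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size)
labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))

with torch.no_grad():
    # attention_mask is left as None: whatever is passed here is forwarded
    # unchanged to nn.MultiheadAttention's attn_mask, which expects an (L, S) or
    # (N * num_heads, L, S) mask, not the (batch, seq) padding mask that
    # transformers tokenizers produce.
    outputs = model(inputs_embeds=inputs_embeds, labels=labels)

print(outputs.logits.shape)  # torch.Size([2, 8, 1000])
print(outputs.loss)          # cross-entropy over all positions (this forward does not shift labels)

If masked attention is needed, a (seq_len, seq_len) boolean mask with True at disallowed positions, e.g. torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1) for a causal mask, can be passed as attention_mask, since that is the layout nn.MultiheadAttention accepts.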