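[hotfix] update modeling: add embedding_multiplier_scale and output_multiplier_scale to the config and apply them to the input embeddings and output logits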
Browse files
- configuration_grok1.py +4 -0
- modeling_grok1.py +4 -2
configuration_grok1.py
CHANGED
@@ -16,6 +16,8 @@ class Grok1Config(PretrainedConfig):
|
|
16 |
attn_output_multiplier=1.0,
|
17 |
max_attn_value=1.0,
|
18 |
max_position_embeddings=4096,
|
|
|
|
|
19 |
rms_norm_eps=1e-5,
|
20 |
use_cache=True,
|
21 |
pad_token_id=None,
|
@@ -32,6 +34,8 @@ class Grok1Config(PretrainedConfig):
|
|
32 |
self.attn_output_multiplier = attn_output_multiplier
|
33 |
self.max_attn_value = max_attn_value
|
34 |
self.max_position_embeddings = max_position_embeddings
|
|
|
|
|
35 |
self.hidden_size = hidden_size
|
36 |
self.widening_factor = widening_factor
|
37 |
self.num_hidden_layers = num_hidden_layers
|
|
|
16 |
attn_output_multiplier=1.0,
|
17 |
max_attn_value=1.0,
|
18 |
max_position_embeddings=4096,
|
19 |
+
embedding_multiplier_scale: float = 1.0,
|
20 |
+
output_multiplier_scale: float = 1.0,
|
21 |
rms_norm_eps=1e-5,
|
22 |
use_cache=True,
|
23 |
pad_token_id=None,
|
|
|
34 |
self.attn_output_multiplier = attn_output_multiplier
|
35 |
self.max_attn_value = max_attn_value
|
36 |
self.max_position_embeddings = max_position_embeddings
|
37 |
+
self.embedding_multiplier_scale = embedding_multiplier_scale
|
38 |
+
self.output_multiplier_scale = output_multiplier_scale
|
39 |
self.hidden_size = hidden_size
|
40 |
self.widening_factor = widening_factor
|
41 |
self.num_hidden_layers = num_hidden_layers
|
modeling_grok1.py
CHANGED
@@ -259,8 +259,6 @@ class MultiHeadAttention(nn.Module):
|
|
259 |
|
260 |
past_key_value = (key_states, value_states) if use_cache else None
|
261 |
|
262 |
-
# TODO: repeat kv
|
263 |
-
|
264 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
|
265 |
torch.float
|
266 |
)
|
@@ -536,6 +534,7 @@ class Grok1Model(Grok1PretrainedModel):
|
|
536 |
super().__init__(config)
|
537 |
self.padding_idx = config.pad_token_id
|
538 |
self.vocab_size = config.vocab_size
|
|
|
539 |
|
540 |
self.embed_tokens = nn.Embedding(
|
541 |
config.vocab_size, config.hidden_size, self.padding_idx
|
@@ -654,6 +653,7 @@ class Grok1Model(Grok1PretrainedModel):
|
|
654 |
|
655 |
if inputs_embeds is None:
|
656 |
inputs_embeds = self.embed_tokens(input_ids)
|
|
|
657 |
|
658 |
if HAS_MASK_UTILS:
|
659 |
# 4d mask is passed through the layers
|
@@ -772,6 +772,7 @@ class Grok1ModelForCausalLM(Grok1PretrainedModel):
|
|
772 |
super().__init__(config)
|
773 |
self.model = Grok1Model(config)
|
774 |
self.vocab_size = config.vocab_size
|
|
|
775 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
776 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
777 |
self.num_experts = config.num_experts
|
@@ -846,6 +847,7 @@ class Grok1ModelForCausalLM(Grok1PretrainedModel):
|
|
846 |
|
847 |
hidden_states = outputs[0]
|
848 |
logits = self.lm_head(hidden_states)
|
|
|
849 |
logits = logits.float()
|
850 |
|
851 |
loss = None
|
|
|
259 |
|
260 |
past_key_value = (key_states, value_states) if use_cache else None
|
261 |
|
|
|
|
|
262 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
|
263 |
torch.float
|
264 |
)
|
|
|
534 |
super().__init__(config)
|
535 |
self.padding_idx = config.pad_token_id
|
536 |
self.vocab_size = config.vocab_size
|
537 |
+
self.embedding_multiplier_scale = config.embedding_multiplier_scale
|
538 |
|
539 |
self.embed_tokens = nn.Embedding(
|
540 |
config.vocab_size, config.hidden_size, self.padding_idx
|
|
|
653 |
|
654 |
if inputs_embeds is None:
|
655 |
inputs_embeds = self.embed_tokens(input_ids)
|
656 |
+
inputs_embeds = inputs_embeds * self.embedding_multiplier_scale
|
657 |
|
658 |
if HAS_MASK_UTILS:
|
659 |
# 4d mask is passed through the layers
|
|
|
772 |
super().__init__(config)
|
773 |
self.model = Grok1Model(config)
|
774 |
self.vocab_size = config.vocab_size
|
775 |
+
self.output_multiplier_scale = config.output_multiplier_scale
|
776 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
777 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
778 |
self.num_experts = config.num_experts
|
|
|
847 |
|
848 |
hidden_states = outputs[0]
|
849 |
logits = self.lm_head(hidden_states)
|
850 |
+
logits = logits * self.output_multiplier_scale
|
851 |
logits = logits.float()
|
852 |
|
853 |
loss = None
|