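"""Custom Llama variant whose decoder layers multiply their output hidden
states by fixed per-layer scale factors taken from the config's `scales`
list."""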
from transformers.models.llama.modeling_llama import LlamaModel, LlamaDecoderLayer
from transformers import LlamaForCausalLM
import torch.nn as nn

from .configuration_customllama import CustomLlamaConfig


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    """A standard Llama decoder layer whose output hidden states are
    multiplied by a fixed per-layer scale factor."""

    def __init__(self, config, layer_idx, scale=1.0):
        super().__init__(config, layer_idx)
        self.scale = scale

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        # Run the stock decoder layer, forwarding any extra keyword arguments.
        outputs = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # Scale only the hidden states; pass attention weights and cache
        # entries through unchanged.
        hidden_states = outputs[0] * self.scale
        return (hidden_states, *outputs[1:])


class CustomLlama(LlamaModel):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.scales = config.scales
        if len(self.scales) != config.num_hidden_layers:
            raise ValueError(
                f"Expected {config.num_hidden_layers} scales, got {len(self.scales)}"
            )

        # Replace the stock decoder layers with scaled ones, one scale per layer.
        self.layers = nn.ModuleList([
            CustomLlamaDecoderLayer(config, layer_idx=i, scale=self.scales[i])
            for i in range(config.num_hidden_layers)
        ])

        # Re-run weight initialization so the freshly built layers are
        # initialized the same way as the originals.
        self.post_init()


class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap in the custom backbone, then re-run post_init so weight
        # initialization and input/output embedding tying use the new model.
        self.model = CustomLlama(config)
        self.post_init()
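

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes
# CustomLlamaConfig (defined in configuration_customllama.py, not shown here)
# accepts the usual LlamaConfig keyword arguments plus a `scales` list with
# one entry per hidden layer; adjust to the real signature if it differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    config = CustomLlamaConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        scales=[1.0, 0.5],  # hypothetical values, one per decoder layer
    )
    model = CustomLlamaForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        logits = model(input_ids).logits
    print(logits.shape)  # expected: torch.Size([1, 8, 1000])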