from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import torch.nn as nn

from .configuration_customllama import CustomLlamaConfig


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    """Standard Llama decoder layer whose output hidden states are multiplied by a fixed per-layer scale."""

    def __init__(self, config, layer_idx, scale=1.0):
        super().__init__(config, layer_idx)
        self.scale = scale

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        outputs = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,  # pass any extra kwargs (e.g. attention-backend args) through to the parent layer
        )
        # Scale the layer's output hidden states; leave the remaining outputs
        # (attention weights, cache) untouched.
        hidden_states = outputs[0] * self.scale
        return (hidden_states, *outputs[1:])


class CustomLlama(LlamaModel):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.scales = config.scales
        assert len(self.scales) == config.num_hidden_layers, (
            "config.scales must provide exactly one scale per decoder layer"
        )
        # Replace the stock decoder layers with the scaled variant, one scale per layer.
        self.layers = nn.ModuleList([
            CustomLlamaDecoderLayer(config, layer_idx=i, scale=self.scales[i])
            for i in range(config.num_hidden_layers)
        ])
        # Re-run weight initialization so the replaced layers are initialized like the originals.
        self.post_init()

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap in the custom backbone, then re-run post-init so weight
        # initialization and embedding tying apply to the new submodule.
        self.model = CustomLlama(config)
        self.post_init()
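

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the model code. It assumes
    # CustomLlamaConfig subclasses LlamaConfig, accepts a `scales` kwarg, and
    # that this file sits in a package next to configuration_customllama.py
    # (run with `python -m <package>.<this_module>`). The tiny config values
    # below are arbitrary and chosen only to keep the test fast.
    import torch

    config = CustomLlamaConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        scales=[1.0, 0.9, 0.9, 0.8],  # one output scale per decoder layer
    )
    model = CustomLlamaForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print(out.logits.shape)  # expected: torch.Size([1, 16, 1000])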