gptailor-llama3-8b / modeling_customllama.py
"""Custom Llama variant whose decoder layers apply a per-layer scale to their outputs."""

import torch.nn as nn
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from .configuration_customllama import CustomLlamaConfig


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    """Standard Llama decoder layer whose output hidden states are multiplied by a fixed scale."""

    def __init__(self, config, layer_idx, scale=1.0):
        super().__init__(config, layer_idx)
        self.scale = scale

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        outputs = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        # Scale only the hidden-state output; any remaining outputs
        # (attention weights, cache) are passed through unchanged.
        hidden_states = outputs[0] * self.scale
        return (hidden_states, *outputs[1:])


class CustomLlama(LlamaModel):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        # One multiplicative scale per decoder layer, taken from the config.
        self.scales = config.scales
        assert len(self.scales) == config.num_hidden_layers, (
            "config.scales must provide one value per decoder layer"
        )
        # Rebuild the layer stack with the scale-aware decoder layers.
        self.layers = nn.ModuleList([
            CustomLlamaDecoderLayer(config, layer_idx=i, scale=self.scales[i])
            for i in range(config.num_hidden_layers)
        ])

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap in the backbone that applies the per-layer scales.
        self.model = CustomLlama(config)
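

# --- Usage sketch (illustrative only, not part of the original file) ---
# Assuming the Hub checkpoint's config.json maps AutoModelForCausalLM to
# CustomLlamaForCausalLM via `auto_map`, the model can typically be loaded
# with remote code enabled:
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "guinansu/gptailor-llama3-8b",  # assumed repo id
#       trust_remote_code=True,
#   )
#
# Alternatively, the classes above can be instantiated directly from a
# CustomLlamaConfig whose (assumed) `scales` field lists one multiplier per
# decoder layer, e.g. scales=[1.0] * num_hidden_layers.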