import json
import logging
import math
import os
from functools import partial
from typing import Literal, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file

from hf_utils import *
from mamba_block import MambaBlock, MambaDecoder
from mamba_config import MambaConfig


def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,
):
    """Initialize module weights (GPT-2 / Mamba convention)."""
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            # Skip biases that were deliberately initialized elsewhere.
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Rescale the output projections of residual branches so the variance of
        # the residual stream does not grow with depth: reinitialize, then divide
        # by sqrt(number of residual additions), as in GPT-2.
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MambaModel(nn.Module):
    def __init__(
        self,
        config: MambaConfig,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = True,
        initializer_cfg=None,
    ) -> None:
        super().__init__()

        self.config: MambaConfig = config
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

        # Token embedding (only on the first pipeline stage).
        if self.pre_process:
            self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)

        self.decoder = MambaDecoder(
            config=self.config,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        # Language-model head (only on the last pipeline stage), optionally tied
        # to the input embedding.
        if post_process:
            self.output_layer = nn.Linear(
                self.config.hidden_size, self.config.vocab_size, bias=self.config.add_bias_linear
            )
            if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
                self.initialize_last_stage_with_word_embeddings()

        self.apply(
            partial(
                _init_weights,
                n_layer=self.config.num_layers,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def initialize_last_stage_with_word_embeddings(self):
        # Tie the output projection to the word embedding matrix.
        with torch.no_grad():
            self.output_layer.weight = self.embedding.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        decoder_input: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        inference_params=None,
    ) -> Tensor:
        """Run the decoder and, on the last stage, project to vocabulary logits."""
        # Prefer an explicitly provided decoder input; otherwise embed the tokens
        # on the first stage, or pass None through on intermediate stages.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids)
        else:
            decoder_input = None

        hidden_states = self.decoder(
            hidden_states=decoder_input,
            residual=None,
            inference_params=inference_params,
        )

        if not self.post_process:
            return hidden_states

        logits = self.output_layer(hidden_states)
        return logits.contiguous()
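
    # Illustrative call (sketch): token ids in, per-token vocabulary logits out
    # when both pre_process and post_process are enabled. The shapes below are
    # assumptions for illustration, not guarantees from MambaDecoder.
    #
    #   ids = torch.randint(0, config.vocab_size, (batch, seq_len))
    #   logits = model(ids)   # ~ (batch, seq_len, config.vocab_size)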

    @classmethod
    def from_pretrained(cls, pretrained_model_name=None, checkpoint_name=None, config_name=None, **kwargs):
        """Build a MambaModel from a Hugging Face hub name or from a local
        checkpoint/config pair."""
        if pretrained_model_name is not None:
            json_config = load_config_hf(pretrained_model_name)
            loaded = load_state_dict_hf(pretrained_model_name)
        elif checkpoint_name is not None and config_name is not None:
            with open(config_name, 'r') as f:
                json_config = json.load(f)
            loaded = torch.load(checkpoint_name, map_location='cpu')
        else:
            raise ValueError(
                "Provide either pretrained_model_name or both checkpoint_name and config_name."
            )
        model_state_dict = loaded["model"]

        config = MambaConfig(
            num_layers=json_config['num_layers'],
            hidden_size=json_config['hidden_size'],
            state_size=json_config['state_size'],
            conv_dimension=json_config['conv_dimension'],
            vocab_size=json_config['vocab_size'],
            expansion_factor=json_config['expansion_factor'],
            mamba_moe_layers=json_config['mamba_moe_layers'],
            ffn_hidden_size=json_config['ffn_hidden_size'],
            bias=json_config['add_bias_linear'],
            add_bias_linear=json_config['add_bias_linear'],
            gated_linear_unit=json_config['swiglu'],
        )

        model = MambaModel(config=config, max_sequence_length=json_config['max_sequence_length'], **kwargs)

        # The checkpoint stores the tied word embeddings under a single key;
        # duplicate it for the embedding and output layers, then drop the
        # original key before loading.
        model_state_dict["embedding.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
        model_state_dict["output_layer.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
        model_state_dict.pop("embedding.word_embeddings.weight")
        model.load_state_dict(model_state_dict)
        return model
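
    # Usage sketch (hypothetical identifiers and paths, for illustration only):
    #
    #   model = MambaModel.from_pretrained("org/mamba-checkpoint")           # HF hub name
    #   model = MambaModel.from_pretrained(checkpoint_name="checkpoint.pt",  # local files
    #                                      config_name="config.json")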

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaModel.
        Saves the model weights and the configuration file to a directory.
        """
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save the model weights.
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration as JSON.
        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f)
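

# Rough end-to-end sketch (commented out): load a local checkpoint, run one
# forward pass, and export it again. The paths below are hypothetical, and
# running the decoder assumes the Mamba kernels required by
# MambaBlock/MambaDecoder are installed.
#
#   model = MambaModel.from_pretrained(
#       checkpoint_name="/path/to/checkpoint.pt",   # hypothetical local checkpoint
#       config_name="/path/to/config.json",         # hypothetical matching config
#   )
#   model.eval()
#   with torch.no_grad():
#       ids = torch.randint(0, model.config.vocab_size, (1, 16))
#       logits = model(ids)
#   model.save_pretrained("/path/to/export_dir")    # hypothetical output directory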