# recast3.1-G16W128H64 / modeling_recast_llama.py
# Uploaded by appledora with huggingface_hub (commit 1057ee5, verified).
# Original development filename: recastmlp_llama_model.py
from .configuration_recast_llama import RECAST8b_llama
from transformers import PreTrainedModel
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
from transformers import AutoConfig
from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRotaryEmbedding,
LlamaRMSNorm,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
class MLPTemplateBank(nn.Module):
    """Shared bank of templates from which MLP up/gate weights are mixed.

    Each template row holds ``hidden_size * intermediate_size // coef_rows``
    values. A ``(coef_rows, coef_columns)`` coefficient matrix multiplied by the
    ``(coef_columns, template_len)`` templates yields ``coef_rows`` chunks, which
    are reshaped into one ``(intermediate_size, hidden_size)`` weight matrix.
    Up and gate projections have independent template sets.

    Args:
        config: provides ``hidden_size`` and ``intermediate_size``.
        coef_rows: number of chunks each weight matrix is split into.
        coef_columns: number of templates in the bank; must not be None.

    Raises:
        AssertionError: if ``coef_columns`` is None or the weight size is not
            divisible by ``coef_rows``.
    """

    def __init__(self, config, coef_rows, coef_columns):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.coef_shape = (coef_rows, coef_columns)
        assert coef_columns is not None, "coef_columns must not be None"

        total = self.hidden_size * self.intermediate_size
        # Every chunk must have the same length for the reshape in forward().
        assert (
            total % coef_rows == 0
        ), f"hidden_size * intermediate_size ({total}) must be divisible by coef_rows ({coef_rows})"

        chunk_len = total // coef_rows
        self.up_templates = nn.Parameter(torch.randn(coef_columns, chunk_len))
        self.gate_templates = nn.Parameter(torch.randn(coef_columns, chunk_len))
        # The randn values are overwritten by Xavier-uniform initialization.
        for templates in (self.up_templates, self.gate_templates):
            nn.init.xavier_uniform_(templates)

    def forward(self, up_coeffs, gate_coeffs):
        """Mix templates into ``(up_weights, gate_weights)``.

        Both returned tensors are shaped ``(intermediate_size, hidden_size)``.
        """
        target_shape = (self.intermediate_size, self.hidden_size)
        up_weights = (up_coeffs @ self.up_templates).reshape(target_shape)
        gate_weights = (gate_coeffs @ self.gate_templates).reshape(target_shape)
        return up_weights, gate_weights
class SharedLlamaMLP(nn.Module):
    """SwiGLU MLP whose up/gate projections are generated from a shared bank.

    This module owns only its coefficient matrices, optional up/gate biases and
    the ``down_proj`` linear layer. The templates live in ``bank``, which is
    called as ``bank(up_coefficients, gate_coefficients)`` and must return
    ``(up_weights, gate_weights)``.

    Args:
        config: provides ``hidden_size``, ``intermediate_size`` and ``mlp_bias``.
        bank: template-bank module exposing ``coef_shape``.
    """

    def __init__(self, config, bank):
        super().__init__()
        self.config = config
        self.bank = bank
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

        coef_shape = bank.coef_shape
        self.up_coefficients = nn.Parameter(torch.randn(coef_shape))
        self.gate_coefficients = nn.Parameter(torch.randn(coef_shape))
        # The randn draws above are replaced by (semi-)orthogonal matrices.
        nn.init.orthogonal_(self.up_coefficients)
        nn.init.orthogonal_(self.gate_coefficients)

        if not config.mlp_bias:
            self.register_parameter("gate_bias", None)
            self.register_parameter("up_bias", None)
        else:
            self.gate_bias = nn.Parameter(torch.zeros(self.intermediate_size))
            self.up_bias = nn.Parameter(torch.zeros(self.intermediate_size))
        self.act_fn = F.silu

    def forward(self, x):
        """Return ``down_proj(silu(x W_gate^T + b_gate) * (x W_up^T + b_up))``."""
        up_w, gate_w = self.bank(self.up_coefficients, self.gate_coefficients)
        gated = self.act_fn(F.linear(x, gate_w, self.gate_bias))
        projected = F.linear(x, up_w, self.up_bias)
        return self.down_proj(gated * projected)
class AttTemplateBank(nn.Module):
    """Shared bank of templates from which Q/K/V projection weights are mixed.

    ``head_dim`` is ``hidden_size // num_attention_heads``; K/V weights have
    ``num_key_value_heads * head_dim`` output rows (``num_key_value_heads``
    falls back to ``num_attention_heads`` when the config lacks it).

    Args:
        config: provides ``hidden_size``, ``num_attention_heads`` and
            optionally ``num_key_value_heads``.
        coef_rows: number of chunks each weight matrix is split into.
        coef_columns: number of templates in the bank.

    Raises:
        AssertionError: if a projection size is not divisible by ``coef_rows``.
    """

    def __init__(self, config, coef_rows, coef_columns):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        self.kv_dim = self.num_key_value_heads * self.head_dim
        self.coef_shape = (coef_rows, coef_columns)

        q_numel = self.hidden_size * self.hidden_size
        kv_numel = self.kv_dim * self.hidden_size
        assert q_numel % coef_rows == 0, "Q projection size must be divisible by coef_rows"
        assert kv_numel % coef_rows == 0, "K/V projection size must be divisible by coef_rows"

        def _templates(numel):
            # One row per template; each row is one chunk of the weight matrix.
            return nn.Parameter(torch.randn(coef_columns, numel // coef_rows))

        self.q_templates = _templates(q_numel)
        self.k_templates = _templates(kv_numel)
        self.v_templates = _templates(kv_numel)
        # The randn values are overwritten by Xavier-uniform initialization.
        for templates in (self.q_templates, self.k_templates, self.v_templates):
            nn.init.xavier_uniform_(templates)

    def forward(self, q_coeffs, k_coeffs, v_coeffs):
        """Return ``(q_weights, k_weights, v_weights)``.

        Shapes are ``(hidden_size, hidden_size)``, ``(kv_dim, hidden_size)`` and
        ``(kv_dim, hidden_size)``.
        """
        q_weights = (q_coeffs @ self.q_templates).reshape(self.hidden_size, self.hidden_size)
        k_weights = (k_coeffs @ self.k_templates).reshape(self.kv_dim, self.hidden_size)
        v_weights = (v_coeffs @ self.v_templates).reshape(self.kv_dim, self.hidden_size)
        return q_weights, k_weights, v_weights
class SharedLlamaAttention(nn.Module):
    """Eager (grouped-query) attention whose Q/K/V weights come from a shared bank.

    The bank is called as ``bank(q_coefficients, k_coefficients, v_coefficients)``
    and must return the three projection weight matrices. Attention is computed
    with explicit matmuls; ``attention_mask`` (if given) is sliced to the key
    length and added to the scores, so it must be an additive 4D mask.

    ``forward`` returns the 3-tuple ``(attn_output, attn_weights or None,
    past_key_value)``.
    """

    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        bank: Optional[AttTemplateBank] = None,
    ):
        super().__init__()
        self.config = config
        self.bank = bank
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = getattr(config, "rope_theta", 10000.0)
        self.is_causal = True
        self.o_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias=getattr(config, "attention_bias", False),
        )
        # Fallback rotary embedding, only used when forward() receives no
        # precomputed position_embeddings.
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

        coef_shape = bank.coef_shape
        self.q_coefficients = nn.Parameter(torch.randn(coef_shape))
        self.k_coefficients = nn.Parameter(torch.randn(coef_shape))
        self.v_coefficients = nn.Parameter(torch.randn(coef_shape))
        # The randn draws above are replaced by (semi-)orthogonal matrices.
        for coefficients in (self.q_coefficients, self.k_coefficients, self.v_coefficients):
            nn.init.orthogonal_(coefficients)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        past_key_value=None,
        cache_position=None,
        position_embeddings=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
        **kwargs,
    ):
        """Attend over ``hidden_states`` (unpacked as ``(bsz, q_len, _)``).

        Raises:
            ValueError: if the attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        q_w, k_w, v_w = self.bank(self.q_coefficients, self.k_coefficients, self.v_coefficients)

        def _split_heads(projected, n_heads):
            # (bsz, q_len, n_heads * head_dim) -> (bsz, n_heads, q_len, head_dim)
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = _split_heads(F.linear(hidden_states, q_w), self.num_heads)
        key_states = _split_heads(F.linear(hidden_states, k_w), self.num_key_value_heads)
        value_states = _split_heads(F.linear(hidden_states, v_w), self.num_key_value_heads)

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # The cache returns the full (past + current) key/value tensors.
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        # Expand K/V heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            kv_len = key_states.shape[-2]
            scores = scores + attention_mask[:, :, :, :kv_len]

        # Softmax in float32 for stability, then cast back to the activation dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        probs = F.dropout(probs, p=self.attention_dropout, training=self.training)
        context = torch.matmul(probs, value_states)

        expected = (bsz, self.num_heads, q_len, self.head_dim)
        if context.size() != expected:
            raise ValueError(
                f"`attn_output` should be of size {expected}, but is"
                f" {context.size()}"
            )

        merged = context.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        attn_output = self.o_proj(merged)
        return attn_output, (probs if output_attentions else None), past_key_value
def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: int = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Cross-entropy with an optional explicit normalizer.

    Args:
        source: logits passed to ``nn.functional.cross_entropy``.
        target: class indices; entries equal to ``ignore_index`` are skipped.
        num_items_in_batch: if None, return the mean over non-ignored targets;
            otherwise return the summed loss divided by this value.
        ignore_index: target value excluded from the loss.
        **kwargs: accepted and ignored.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(
            source, target, ignore_index=ignore_index, reduction="mean"
        )
    summed = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction="sum"
    )
    return summed / num_items_in_batch
class RECAST8b_llamaModel(PreTrainedModel):
    """Llama-style decoder whose attention/MLP weights come from shared template banks.

    The ``config.num_hidden_layers`` layers are split into ``config.num_groups``
    contiguous, equally sized groups. Every layer in a group draws its Q/K/V and
    up/gate projections from that group's ``AttTemplateBank`` /
    ``MLPTemplateBank`` through its own coefficient matrices.
    """

    config_class = RECAST8b_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config):
        """Build embeddings, template banks and decoder layers.

        Raises:
            ValueError: if ``num_groups`` is not positive or does not evenly
                divide ``num_hidden_layers``.
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # NOTE(review): this fetches the Llama-3.1-8B config from the Hub at
        # construction time (network access / credentials may be required). The
        # shared rotary embedding follows that config's RoPE settings, not
        # ``config``'s; kept as-is to preserve numerics of existing checkpoints.
        original_config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8b", trust_remote_code=True)
        self.rotary_emb = LlamaRotaryEmbedding(
            config=original_config,
        )

        num_groups = config.num_groups
        num_layers = config.num_hidden_layers
        # Each group must own the same number of consecutive layers; otherwise the
        # layer -> group mapping below would index past the last bank.
        if num_groups <= 0 or num_layers % num_groups != 0:
            raise ValueError(
                f"num_hidden_layers ({num_layers}) must be divisible by num_groups "
                f"({num_groups}), and num_groups must be positive"
            )
        layers_per_group = num_layers // num_groups

        coef_width = getattr(config, "coef_width", None)
        if coef_width is None:
            # Default width scales with the number of layers sharing one bank.
            coef_width = config.coef_height * layers_per_group
            # Stored on the config so it is saved alongside the model.
            config.coef_width = coef_width
        logger.info("Model config: num_groups=%s, layers_per_group=%s", num_groups, layers_per_group)
        logger.info("Coefficient shape: (%s, %s)", config.coef_height, config.coef_width)

        # All MLP banks are created before all attention banks (keeps RNG order).
        self.mlp_banks = nn.ModuleList(
            [
                MLPTemplateBank(config=config, coef_rows=config.coef_height, coef_columns=coef_width)
                for _ in range(num_groups)
            ]
        )
        self.attn_banks = nn.ModuleList(
            [
                AttTemplateBank(config=config, coef_rows=config.coef_height, coef_columns=coef_width)
                for _ in range(num_groups)
            ]
        )

        # Standard LlamaDecoderLayers with their MLP and attention swapped for the
        # bank-backed versions of the layer's group.
        self.layers = nn.ModuleList()
        for layer_idx in range(num_layers):
            decoder_layer = LlamaDecoderLayer(config, layer_idx)
            group_idx = layer_idx // layers_per_group
            decoder_layer.mlp = SharedLlamaMLP(config, self.mlp_banks[group_idx])
            decoder_layer.self_attn = SharedLlamaAttention(config, layer_idx, self.attn_banks[group_idx])
            self.layers.append(decoder_layer)
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Returns a ``BaseModelOutputWithPast`` (or a tuple of its non-None fields
        when ``return_dict`` is false). The returned ``past_key_values`` is the
        cache object reported by the last layer; no cache is created here when
        none is passed in.

        Raises:
            ValueError: unless exactly one of ``input_ids`` / ``inputs_embeds`` is given.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            if past_key_values is None:
                past_seen_tokens = 0
            elif isinstance(past_key_values, Cache):
                past_seen_tokens = past_key_values.get_seq_length()
            else:
                # Tuple-style cache: length read from the first layer's key tensor.
                past_seen_tokens = past_key_values[0][0].size(-2) if past_key_values else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds
        # Rotary cos/sin are computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Arguments are positional, so their order must match
                # LlamaDecoderLayer.forward: (hidden_states, attention_mask,
                # position_ids, past_key_value, output_attentions, use_cache,
                # cache_position, position_embeddings). NOTE(review): confirm
                # against the installed transformers version.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )
            hidden_states = layer_outputs[0]
            # Layer outputs are (hidden, [attn_weights], [cache]).
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load from a local ``.pt`` training checkpoint or defer to the HF loader.

        For a ``.pt`` path, a model is built from ``config`` (taken from kwargs or,
        failing that, ``AutoConfig``) and the checkpoint's ``"model_state_dict"``
        is loaded non-strictly; missing/unexpected keys are logged as warnings.
        Any other argument is passed to ``PreTrainedModel.from_pretrained``.
        """
        if isinstance(pretrained_model_name_or_path, str) and pretrained_model_name_or_path.endswith(".pt"):
            logger.info("Loading from local checkpoint")
            config = kwargs.get("config", None)
            if config is None:
                # NOTE(review): AutoConfig is pointed at the .pt file itself, which is
                # unlikely to be a config; callers should pass ``config=`` explicitly.
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
            model = cls(config)
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]
            logger.info(
                f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
            )
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            if missing_keys:
                logger.warning(f"Missing keys: {missing_keys}")
            if unexpected_keys:
                logger.warning(f"Unexpected keys: {unexpected_keys}")
            return model
        logger.info("Loading from hub")
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the additive 4D attention mask for the decoder layers.

        The attention modules in this model compute eager matmul attention and
        only *add* the mask to the scores (sliced as ``mask[:, :, :, :kv_len]``).
        They never use SDPA's ``is_causal`` or flash-attention's 2D padding mask.
        A 4D causal mask is therefore built for every
        ``config._attn_implementation``. Returning ``None`` or a 2D mask would
        silently disable causal masking or break the 4D slicing.
        ``output_attentions`` is accepted for interface compatibility and unused.
        """
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if isinstance(past_key_values, StaticCache):
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(attention_mask, torch.Tensor):
            target_length = attention_mask.shape[-1]
        else:
            target_length = past_seen_tokens + sequence_length + 1
        return self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """Return an additive mask of shape ``(batch_size, 1, sequence_length, target_length)``.

        Allowed positions hold 0 and blocked positions hold ``torch.finfo(dtype).min``.
        Key ``j`` is blocked for query ``i`` when ``j > cache_position[i]`` (and,
        for ``sequence_length > 1``, only above the main diagonal). Keys where a 2D
        ``attention_mask`` is 0 are blocked as well. A 4D ``attention_mask`` is
        returned unchanged.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # Assumed to be already in inverted (additive) form.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                # Copy to contiguous memory for the in-place edit below.
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                # Zero only where the causal mask allows AND attention_mask is 0.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
class RECAST8b_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language-model head (``lm_head``) on top of ``RECAST8b_llamaModel``."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = RECAST8b_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config):
        super().__init__(config)
        self.model = RECAST8b_llamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def loss_function(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: int = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        """Next-token cross-entropy: logits at position t are scored against labels at t+1.

        With ``num_items_in_batch`` None the result is the mean over non-ignored
        target tokens. Otherwise it is the summed loss divided by that value.
        """
        # Upcast to float to avoid precision issues in the loss.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Labels may live on a different device under model parallelism.
        shift_labels = shift_labels.to(shift_logits.device)
        loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
        return loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in
                `[0, ..., config.vocab_size]` or -100 (masked tokens).
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
            kwargs:
                May contain `num_items_in_batch`, the divisor for the summed token
                loss. When absent, the loss is the mean over non-ignored tokens.
                Remaining kwargs are forwarded to the base model.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # A loss-normalization argument, not a decoder input. Popping it keeps it
        # out of the layers and avoids passing it twice to loss_function.
        num_items_in_batch = kwargs.pop("num_items_in_batch", None)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]
        # Only compute necessary logits (slice [-0:] keeps every position).
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Do NOT substitute the batch size here: with reduction="sum" that
            # would yield a per-sequence sum instead of a per-token mean.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble model inputs for one generation step.

        With a (truthy) cache only the last token is fed. Position ids are
        derived from the attention mask when not supplied.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation (left padding aware)
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a whole pickled model from a ``.pt`` file, or defer to the HF loader.

        For a ``.pt`` path the object returned by ``torch.load`` is returned
        directly, so no config is needed or used. Any other argument is passed
        to ``PreTrainedModel.from_pretrained``.
        """
        if isinstance(pretrained_model_name_or_path, str) and pretrained_model_name_or_path.endswith(".pt"):
            logger.info("Loading from local checkpoint")
            # SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
            # NOTE(review): newer torch releases default torch.load to
            # weights_only=True, which may reject a full pickled model; confirm
            # against the torch version in use.
            return torch.load(pretrained_model_name_or_path, map_location="cpu")
        logger.info("Loading from hub")
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)