|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.generation import GenerationMixin |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
from transformers.modeling_layers import GradientCheckpointingLayer |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
) |
|
from transformers.modeling_rope_utils import dynamic_rope_update |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.processing_utils import Unpack |
|
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging |
|
|
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm, apply_rotary_pos_emb, LlamaMLP |
|
|
|
if is_torch_flex_attn_available(): |
|
from torch.nn.attention.flex_attention import BlockMask |
|
|
|
from transformers.integrations.flex_attention import make_flex_block_causal_mask |
|
|
|
from .config import LlamaMlaConfig |
|
|
|
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(name=__name__)
|
|
|
def _compute_llama_mla_parameters( |
|
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, head_dim: int = None, **rope_kwargs |
|
) -> tuple["torch.Tensor", float]: |
|
""" |
|
Computes the inverse frequencies for llama 3.1. |
|
|
|
Args: |
|
config ([`~transformers.PretrainedConfig`]): |
|
The model configuration. |
|
device (`torch.device`): |
|
The device to use for initialization of the inverse frequencies. |
|
seq_len (`int`, *optional*): |
|
The current sequence length. Unused for this type of RoPE. |
|
rope_kwargs (`Dict`, *optional*): |
|
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. |
|
Returns: |
|
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the |
|
post-processing scaling factor applied to the computed cos/sin. |
|
""" |
|
|
|
if config is not None and len(rope_kwargs) > 0: |
|
raise ValueError( |
|
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " |
|
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" |
|
) |
|
if len(rope_kwargs) > 0: |
|
base = rope_kwargs["base"] |
|
dim = rope_kwargs["dim"] |
|
elif config is not None: |
|
base = config.rope_theta |
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 |
|
head_dim = head_dim or getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads |
|
dim = int(head_dim * partial_rotary_factor) |
|
|
|
attention_factor = 1.0 |
|
|
|
|
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) |
|
|
|
factor = config.rope_scaling["factor"] |
|
low_freq_factor = config.rope_scaling["low_freq_factor"] |
|
high_freq_factor = config.rope_scaling["high_freq_factor"] |
|
old_context_len = config.rope_scaling["original_max_position_embeddings"] |
|
|
|
low_freq_wavelen = old_context_len / low_freq_factor |
|
high_freq_wavelen = old_context_len / high_freq_factor |
|
|
|
wavelen = 2 * math.pi / inv_freq |
|
|
|
|
|
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) |
|
|
|
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) |
|
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama |
|
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) |
|
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) |
|
|
|
return inv_freq_llama, attention_factor |
|
|
|
class LlamaMlaRotaryEmbedding(nn.Module):
    """Rotary position embedding whose inverse frequencies come from `_compute_llama_mla_parameters`.

    Args:
        config: model configuration (reads `rope_scaling` and `max_position_embeddings`, plus the fields
            read by `_compute_llama_mla_parameters`).
        device: device on which the inverse-frequency buffer is built.
        head_dim: rotary dimension; when None, the config-level head dim is used.
    """

    def __init__(self, config: LlamaMlaConfig, device=None, head_dim: Optional[int] = None):
        super().__init__()
        import functools  # local import keeps the module-level import block untouched

        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.head_dim = head_dim
        # Bind head_dim into the init function: `dynamic_rope_update` may re-invoke `self.rope_init_fn`
        # with only (config, device, seq_len=...), and without this binding a recomputation would fall
        # back to the config-level head dim instead of the rotary dim this module was built for.
        self.rope_init_fn = functools.partial(_compute_llama_mla_parameters, head_dim=head_dim)

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return `(cos, sin)` of shape (batch, seq_len, 2 * len(inv_freq)) in `x.dtype`.

        `position_ids` is indexed as a 2D (batch, seq_len) tensor; `x` only supplies device and dtype.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute in full precision regardless of any surrounding autocast context.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
class LlamaMlaAttention(nn.Module):
    """Multi-headed Latent attention from 'DeepSeek-V2'

    Queries are projected (directly, or through a `q_lora_rank` bottleneck) to
    `num_heads * (qk_nope_head_dim + qk_rope_head_dim)`. Keys/values are expanded per head from a
    compressed latent of size `kv_lora_rank`, plus one rotary key of size `qk_rope_head_dim` shared by all
    heads. Attention is computed eagerly (matmul + additive mask + softmax).
    """

    def __init__(self, config: LlamaMlaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = LlamaRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)

        # Produces the compressed KV latent and the shared rotary key in one projection.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = LlamaRMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        # Rotary embedding sized for the rope sub-dimension only.
        self.rotary_emb = LlamaMlaRotaryEmbedding(config=config, head_dim=self.qk_rope_head_dim)

        self.softmax_scale = self.q_head_dim ** (-0.5)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()

    @staticmethod
    def _make_causal_mask(q_len: int, kv_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """Build an additive causal mask of shape (1, 1, q_len, kv_len).

        The `q_len` queries are taken to be the last `q_len` positions of the `kv_len` keys, so query `i`
        sits at absolute position `kv_len - q_len + i`; keys after that position get the dtype minimum.
        """
        query_pos = torch.arange(kv_len - q_len, kv_len, device=device)
        key_pos = torch.arange(kv_len, device=device)
        mask = torch.zeros((q_len, kv_len), dtype=dtype, device=device)
        mask = mask.masked_fill(key_pos[None, :] > query_pos[:, None], torch.finfo(dtype).min)
        return mask[None, None, :, :]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager MLA self-attention.

        Args:
            hidden_states: (bsz, q_len, hidden_size) input.
            attention_mask: optional additive mask. A 4D mask is sliced to the key length and must then be
                (bsz, 1, q_len, kv_len). When None, a causal mask is built here, because the SDPA / flash
                paths of the model's mask preparation may legitimately return None.
            position_ids: positions fed to the rotary embedding.
            past_key_value: optional cache; updated in place via `update(...)` when given.
            output_attentions: when False, the returned attention weights are None.
            **kwargs: accepted and ignored (e.g. `cache_position`, `position_embeddings`).

        Returns:
            `(attn_output, attn_weights_or_None)`.

        Raises:
            ValueError: on a cache without `layer_idx`, a mis-shaped mask, or a mis-shaped output.
        """
        if "padding_mask" in kwargs:
            logger.warning_once(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        # Queries: (bsz, num_heads, q_len, q_head_dim), split into non-rotary and rotary parts.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Keys/values: compressed latent plus a single-head rotary key.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        if past_key_value is not None and self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )

        cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        # k_pe has one head; the slice assignment broadcasts it to every head.
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Number of keys actually attended to (cached + current), taken from the keys themselves.
        kv_seq_len = key_states.shape[-2]

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale

        if attention_mask is None:
            attn_weights = attn_weights + self._make_causal_mask(
                q_len, kv_seq_len, attn_weights.dtype, attn_weights.device
            )
        else:
            if attention_mask.dim() == 4:
                # Prepared 4D masks may extend past the key length; only the first kv_seq_len columns apply.
                attention_mask = attention_mask[:, :, :, :kv_seq_len]
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the compute dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
|
|
|
|
|
class LlamaMlaDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder block.

    The block applies `x + attn(norm(x))`, then `x + mlp(norm(x))`, using MLA self-attention and a Llama MLP.
    """

    def __init__(self, config: LlamaMlaConfig, layer_idx: int):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width

        # Submodules are created in the same order as before (affects parameter order / init RNG use).
        self.self_attn = LlamaMlaAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(width, eps=eps)
        self.post_attention_layernorm = LlamaRMSNorm(width, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
@auto_docstring
class LlamaMlaPreTrainedModel(PreTrainedModel):
    config_class = LlamaMlaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaMlaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # `LlamaMlaAttention` only implements eager attention and never dispatches to FlashAttention-2, SDPA,
    # flex attention or a registered attention backend. Advertising those backends routes the model's
    # mask preparation down paths that hand the attention layer None, a 2D padding mask or a `BlockMask`,
    # none of which the eager layer's 4D mask check accepts. So only eager is declared.
    _supports_flash_attn_2 = False
    _supports_sdpa = False
    _supports_flex_attn = False
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = False

    def _init_weights(self, module):
        """Initialize a submodule in place.

        Linear and embedding weights are initialized from N(0, initializer_range), biases and the padding
        embedding row are set to zero, and RMSNorm weights are set to one.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)
|
|
|
|
|
@auto_docstring
class LlamaMlaModel(LlamaMlaPreTrainedModel):
    """Decoder-only backbone: token embeddings, a stack of `LlamaMlaDecoderLayer`s and a final RMSNorm."""

    def __init__(self, config: LlamaMlaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaMlaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # NOTE(review): these position embeddings are forwarded to every layer, but `LlamaMlaAttention`
        # computes its own rotary embedding (sized `qk_rope_head_dim`) and does not read them.
        self.rotary_emb = LlamaMlaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Turn the user-facing `attention_mask` into what the configured attention implementation expects.

        Returns None, the original mask, a `BlockMask` (flex attention), or a 4D additive mask of shape
        (batch, 1, query_length, target_length).
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When possible, let SDPA infer causality from `is_causal` instead of materializing a mask.
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            # `LlamaMlaAttention` checks that the mask's last dim equals the number of keys exactly, so the
            # fallback length is past + current tokens (no extra trailing column).
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length
            )

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Fully-masked rows (e.g. left padding) would otherwise make SDPA's memory-efficient path
            # produce NaNs; unmask them so those rows attend everywhere.
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the mask value only for key positions after each query's absolute position.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Positions that are both causally visible (0) and padded (0 in the 2D mask) become masked.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
|
|
|
|
|
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    """Extra keyword arguments accepted by `LlamaMlaForCausalLM.forward` (flash-attention + loss kwargs)."""
|
|
|
|
|
@auto_docstring
class LlamaMlaForCausalLM(LlamaMlaPreTrainedModel, GenerationMixin):
    """`LlamaMlaModel` with a linear language-modeling head on top."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaMlaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> model = LlamaMlaForCausalLM.from_pretrained("path/to/llama-mla-checkpoint")
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/llama-mla-checkpoint")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor selects positions explicitly.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, keep, :])

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=cfg.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
# Public API of this module.
__all__ = ["LlamaMlaForCausalLM", "LlamaMlaModel", "LlamaMlaPreTrainedModel"]
|
|