"""MLX implementation of the Phi-3-Vision decoder: rotary embeddings, attention, decoder layers and LM head."""

from dataclasses import dataclass
import glob
import json
import logging
import math
import warnings
from pathlib import Path
from typing import Dict, Optional, List, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

# from llms.mlx_lm.models.base import BaseModelArgs
from configuration_phi3_v import Phi3VConfig
from utils import BaseModelOutputWithPast, FloatTensor, LongTensor, Cache, DynamicCache, CausalLMOutputWithPast
from image_embedding_phi3_v import Phi3ImageEmbedding
from attn_mask import _prepare_4d_causal_attention_mask
from huggingface_hub import snapshot_download


class Phi3RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) that produces per-position cos/sin tables.

    Args:
        dim: Rotary dimension (the attention head size); expected to be even.
        max_position_embeddings: Context length the model was configured for.
        base: RoPE frequency base (``rope_theta``).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Computed once up front. The leading underscore keeps the frequencies out of
        # MLX ``parameters()``, so they are neither trained nor expected in checkpoints.
        self._inv_freq = self._inverse_frequencies()

    def _inverse_frequencies(self, ext_factors: Optional[mx.array] = None) -> mx.array:
        """Return ``1 / (ext_factors * base ** (2i / dim))`` for ``i`` in ``[0, dim / 2)``."""
        exponent = mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim
        freq = mx.power(float(self.base), exponent)
        if ext_factors is not None:
            freq = ext_factors * freq
        return 1.0 / freq

    @staticmethod
    def _cos_sin(x, position_ids, inv_freq, scaling_factor=1.0):
        """Build ``(cos, sin)`` tables of shape ``(batch, seq, dim)`` cast to ``x.dtype``.

        ``position_ids`` is indexed as ``(batch, seq)``. Both tables are multiplied by
        ``scaling_factor``.
        """
        positions = position_ids.astype(mx.float32)
        # Outer product of positions and frequencies: (batch, seq, dim / 2).
        freqs = positions[:, :, None] * inv_freq[None, None, :]
        emb = mx.concatenate((freqs, freqs), axis=-1)
        cos = mx.cos(emb) * scaling_factor
        sin = mx.sin(emb) * scaling_factor
        return cos.astype(x.dtype), sin.astype(x.dtype)

    def __call__(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for ``position_ids``.

        ``x`` only supplies the output dtype. ``seq_len`` is accepted for interface
        compatibility and is unused.
        """
        return self._cos_sin(x, position_ids, self._inv_freq)


class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
    """RoPE with per-dimension "su" rescaling factors taken from ``config.rope_scaling``.

    ``short_factor`` is used while every requested position fits within
    ``original_max_position_embeddings``; otherwise ``long_factor`` is used.
    """

    def __init__(self, dim, config):
        super().__init__(dim, config.max_position_embeddings, config.rope_theta)
        self.short_factor = config.rope_scaling["short_factor"]
        self.long_factor = config.rope_scaling["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings

    def __call__(self, x, position_ids, seq_len=None):
        """Return scaled ``(cos, sin)`` tables.

        The factor set is chosen from the largest requested position; the ``seq_len``
        argument is ignored.
        """
        seq_len = int(mx.max(position_ids).item()) + 1
        if seq_len > self.original_max_position_embeddings:
            factors = self.long_factor
        else:
            factors = self.short_factor
        inv_freq = self._inverse_frequencies(mx.array(factors, dtype=mx.float32))

        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
        return self._cos_sin(x, position_ids, inv_freq, scaling_factor)


class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
    """RoPE with short/long rescaling factors and a YaRN-style magnitude scaling ``0.1 * ln(scale) + 1``."""

    def __init__(self, dim, config):
        super().__init__(dim, config.max_position_embeddings, config.rope_theta)
        self.short_factor = config.rope_scaling["short_factor"]
        self.long_factor = config.rope_scaling["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings

    def __call__(self, x, position_ids, seq_len=None):
        """Return scaled ``(cos, sin)`` tables.

        The factor set is chosen from the largest requested position; the ``seq_len``
        argument is ignored.
        """
        seq_len = int(mx.max(position_ids).item()) + 1
        if seq_len > self.original_max_position_embeddings:
            factors = self.long_factor
        else:
            factors = self.short_factor
        inv_freq = self._inverse_frequencies(mx.array(factors, dtype=mx.float32))

        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = 0.1 * math.log(scale) + 1.0
        return self._cos_sin(x, position_ids, inv_freq, scaling_factor)


def rotate_half(x):
    """Map the halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return mx.concatenate((-x2, x1), axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary embeddings to queries and keys.

    ``cos`` and ``sin`` get a new axis at ``unsqueeze_dim`` so they broadcast against
    ``q``/``k``. ``position_ids`` is unused; it is kept for signature compatibility.
    """
    cos = mx.expand_dims(cos, unsqueeze_dim)
    sin = mx.expand_dims(sin, unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Phi3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate) * up)`` with a fused gate/up projection."""

    def __init__(self, config: Phi3VConfig):
        super().__init__()
        self.config = config
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def __call__(self, x) -> mx.array:
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)


def repeat_kv(hidden_states: mx.array, n_rep: int) -> mx.array:
    """Repeat each K/V head ``n_rep`` times.

    Maps ``(batch, kv_heads, seq, head_dim)`` to ``(batch, kv_heads * n_rep, seq, head_dim)``.
    Copies of the same head are adjacent.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = mx.broadcast_to(
        hidden_states[:, :, None, :, :], (batch, num_key_value_heads, n_rep, slen, head_dim)
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def _causal_additive_mask(q_len: int, kv_len: int, dtype) -> mx.array:
    """Additive causal mask of shape ``(1, 1, q_len, kv_len)``.

    Queries are assumed to be the last ``q_len`` positions of the ``kv_len`` keys.
    Future keys get ``-inf``; all other keys get ``0``.
    """
    offset = kv_len - q_len
    query_pos = mx.arange(offset, offset + q_len)[:, None]
    key_pos = mx.arange(kv_len)[None, :]
    mask = mx.where(key_pos > query_pos, -math.inf, 0.0).astype(dtype)
    return mask[None, None, :, :]


class Phi3Attention(nn.Module):
    """Eager multi-head attention with a fused QKV projection and grouped K/V heads."""

    # Process-wide flag: log the "not flash-attention" notice once, not on every forward call.
    _warned_no_flash = False

    def __init__(self, config: Phi3VConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logging.warning(
                "Instantiating %s without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class.",
                self.__class__.__name__,
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        self.rope_scaling = config.rope_scaling
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Fused projection output: [queries | keys | values].
        op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
        # Parameter-free dropout on attention probabilities. It honours train()/eval() and is
        # the identity when attention_dropout == 0.
        self.attn_dropout = nn.Dropout(self.attention_dropout)
        self._init_rope()

    def _init_rope(self):
        """Select the rotary embedding variant from ``config.rope_scaling``."""
        if self.rope_scaling is None:
            self.rotary_emb = Phi3RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            if scaling_type == "su":
                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
            elif scaling_type == "yarn":
                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _project_qkv(self, hidden_states, position_ids, past_key_value):
        """Project, rotate, cache-update and head-repeat Q/K/V.

        Returns:
            ``(query, key, value, bsz, q_len, kv_seq_len)``. Query/key/value are laid out
            as ``(bsz, heads, seq, head_dim)``, and key/value already have
            ``num_heads`` heads.

        Raises:
            ValueError: if a cache is supplied but ``layer_idx`` was not set.
        """
        bsz, q_len, _ = hidden_states.shape
        qkv = self.qkv_proj(hidden_states)
        query_pos = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + kv_width]
        value_states = qkv[..., query_pos + kv_width :]

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(0, 2, 1, 3)
        value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        return query_states, key_states, value_states, bsz, q_len, kv_seq_len

    @staticmethod
    def _check_attention_mask(attention_mask, bsz, q_len, kv_seq_len):
        """Raise ``ValueError`` unless ``attention_mask`` has shape ``(bsz, 1, q_len, kv_seq_len)``."""
        if tuple(attention_mask.shape) != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {tuple(attention_mask.shape)}"
            )

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[LongTensor] = None,
        past_key_value: Optional[Tuple[mx.array, mx.array]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[mx.array, Optional[mx.array], Optional[Tuple[mx.array]]]:
        """Compute attention over ``hidden_states`` of shape ``(bsz, q_len, hidden_size)``.

        Args:
            hidden_states: Layer input.
            attention_mask: Optional additive mask of shape ``(bsz, 1, q_len, kv_seq_len)``.
            position_ids: Positions fed to the rotary embedding.
            past_key_value: Optional cache object. Its ``update`` method is called, and the
                same object is returned.
            output_attentions: If True, also return the attention probabilities.
            use_cache: Unused here; the cache is updated whenever ``past_key_value`` is given.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: on shape mismatches, or when caching without ``layer_idx``.
        """
        if not Phi3Attention._warned_no_flash:
            logging.warning("You are not running the flash-attention implementation, expect numerical differences.")
            Phi3Attention._warned_no_flash = True

        query_states, key_states, value_states, bsz, q_len, kv_seq_len = self._project_qkv(
            hidden_states, position_ids, past_key_value
        )

        attn_weights = (query_states @ key_states.transpose(0, 1, 3, 2)) / math.sqrt(self.head_dim)
        if tuple(attn_weights.shape) != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {tuple(attn_weights.shape)}"
            )

        if attention_mask is not None:
            self._check_attention_mask(attention_mask, bsz, q_len, kv_seq_len)
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = mx.softmax(attn_weights.astype(mx.float32), axis=-1).astype(value_states.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = attn_weights @ value_states
        if tuple(attn_output.shape) != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {tuple(attn_output.shape)}"
            )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class Phi3SdpaAttention(Phi3Attention):
    """Attention via ``mx.fast.scaled_dot_product_attention``.

    Falls back to the eager path when attention weights are requested, or when
    attention dropout is active during training (the fused kernel has no dropout).
    """

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[LongTensor] = None,
        past_key_value: Optional[Tuple[mx.array, mx.array]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[mx.array, Optional[mx.array], Optional[Tuple[mx.array]]]:
        """Same contract as ``Phi3Attention.__call__``; attention weights are always ``None`` here."""
        if output_attentions:
            logging.warning(
                "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        if output_attentions or (self.training and self.attention_dropout > 0.0):
            return super().__call__(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        query_states, key_states, value_states, bsz, q_len, kv_seq_len = self._project_qkv(
            hidden_states, position_ids, past_key_value
        )

        if attention_mask is not None:
            self._check_attention_mask(attention_mask, bsz, q_len, kv_seq_len)
            mask = attention_mask.astype(query_states.dtype)
        elif self.is_causal and q_len > 1:
            # No mask supplied: apply causal masking explicitly (the fused kernel has no is_causal flag).
            mask = _causal_additive_mask(q_len, kv_seq_len, query_states.dtype)
        else:
            mask = None

        attn_output = mx.fast.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            scale=1.0 / math.sqrt(self.head_dim),
            mask=mask,
        )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


PHI3_ATTENTION_CLASSES = {
    "eager": Phi3Attention,
    "sdpa": Phi3SdpaAttention,
}


class Phi3DecoderLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, config: Phi3VConfig, layer_idx: int):
        super().__init__()

        self.config = config
        # Config objects without `_attn_implementation` get the eager implementation.
        attn_implementation = getattr(config, "_attn_implementation", "eager")
        self.self_attn = PHI3_ATTENTION_CLASSES[attn_implementation](config, layer_idx=layer_idx)

        self.mlp = Phi3MLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[LongTensor] = None,
        past_key_value: Optional[Tuple[mx.array]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[mx.array, Optional[Tuple[FloatTensor, FloatTensor]]]:
        """Run the block.

        Returns ``(hidden_states,)``, followed by the attention weights if
        ``output_attentions`` and then the cache if ``use_cache``.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_outputs, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + self.resid_attn_dropout(attn_outputs)

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.resid_mlp_dropout(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs


class Phi3VPreTrainedModel(nn.Module):
    """Shared base: holds the config and the weight-initialization rule."""

    config_class = Phi3VConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Phi3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = False
    _supports_sdpa = True
    _supports_cache_class = True

    _version = "0.0.5"

    def __init__(self, config):
        super(Phi3VPreTrainedModel, self).__init__()
        self.config = config

    def _init_weights(self, module):
        """Re-initialize one module in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from ``N(0, initializer_range^2)``,
        and Linear biases (when present) are zeroed. All other modules are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight = std * mx.random.normal(module.weight.shape, dtype=module.weight.dtype)
            # MLX modules are dicts; a Linear built with bias=False has no "bias" entry.
            if isinstance(module, nn.Linear) and "bias" in module:
                module.bias = mx.zeros_like(module.bias)


class Phi3VModel(Phi3VPreTrainedModel):
    """Phi-3-V decoder stack: token/image embeddings -> decoder layers -> final RMSNorm."""

    def __init__(self, config: Phi3VConfig):
        super(Phi3VModel, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # Read by __call__ to choose between 2D (flash) and 4D attention masks.
        self._attn_implementation = getattr(config, "_attn_implementation", "eager")

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_dropout = nn.Dropout(config.embd_pdrop)

        # Vision embedding integration
        if isinstance(config.embd_layer, dict) and config.embd_layer.get('embedding_cls') == 'image':
            self.vision_embed_tokens = Phi3ImageEmbedding(config)
        else:
            self.vision_embed_tokens = None

        self.layers = [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights module-by-module. MLX `Module.apply` maps over parameter arrays,
        # not modules, so `apply_to_modules` is the correct traversal here.
        self.apply_to_modules(lambda _name, module: self._init_weights(module))

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def __call__(
        self,
        input_ids: LongTensor = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[LongTensor] = None,
        past_key_values: Optional[List[FloatTensor]] = None,
        inputs_embeds: Optional[FloatTensor] = None,
        pixel_values: Optional[FloatTensor] = None,
        image_sizes: Optional[LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. When
        ``pixel_values`` and ``image_sizes`` are both given, embeddings come from the
        vision embedding layer instead of ``embed_tokens``. Unset flags default to the
        config values.

        Raises:
            ValueError: on invalid input combinations, or right-padded batches under flash attention.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logging.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        use_legacy_cache = False
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            # New tokens continue from the end of the cached sequence.
            position_ids = mx.arange(past_key_values_length, seq_length + past_key_values_length)
            position_ids = position_ids[None, :].reshape(-1, seq_length)
        else:
            position_ids = position_ids.reshape(-1, seq_length).astype(mx.int32)

        if inputs_embeds is None:
            if pixel_values is not None and image_sizes is not None:
                if self.vision_embed_tokens is None:
                    raise ValueError("Vision embedding layer is not defined")
                inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
            else:
                inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers; drop it entirely when nothing is padded.
            has_padding = attention_mask is not None and bool((attention_mask == 0).any().item())
            attention_mask = attention_mask if has_padding else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): `_gradient_checkpointing_func` is not defined in this file; it must be
                # assigned externally before enabling gradient_checkpointing.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @staticmethod
    def from_pretrained(path_or_hf_repo: str):
        """Build a ``Phi3VModel`` from a local directory or a Hugging Face Hub repo id.

        The directory must contain ``config.json`` and at least one ``*.safetensors`` file.
        The returned model is in eval mode.

        Raises:
            FileNotFoundError: if no safetensors weight files are found.
        """
        path = Path(path_or_hf_repo)
        if not path.exists():
            path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                    ],
                )
            )

        with open(path / "config.json", "r", encoding="utf-8") as f:
            model_config = json.load(f)

        model = Phi3VModel(Phi3VConfig.from_dict(model_config))

        weight_files = sorted(glob.glob(str(path / "*.safetensors")))
        if not weight_files:
            raise FileNotFoundError(f"No safetensors weight files found in {path}")

        # Load weights from all files
        weights = {}
        for wf in weight_files:
            weights.update(mx.load(wf))

        # load_weights expects (name, mx.array) pairs exactly as mx.load returns them.
        # NOTE(review): load_weights is strict by default; confirm the checkpoint key names
        # match this module's parameter names.
        model.load_weights(list(weights.items()))
        # Pretrained models are returned ready for inference (dropout disabled).
        model.eval()
        return model


class Phi3VForCausalLM(Phi3VPreTrainedModel):
    """``Phi3VModel`` followed by a linear language-modeling head over the vocabulary."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Phi3VModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # `self.model` already initialized its own submodules; only the head remains.
        self._init_weights(self.lm_head)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def __call__(
        self,
        input_ids: LongTensor = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[LongTensor] = None,
        past_key_values: Optional[List[FloatTensor]] = None,
        inputs_embeds: Optional[FloatTensor] = None,
        pixel_values: Optional[FloatTensor] = None,
        image_sizes: Optional[LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and project to float32 vocabulary logits. ``loss`` is always ``None``."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states).astype(mx.float32)

        loss = None

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        **kwargs,
    ):
        """Trim inputs to the not-yet-cached tokens and derive position ids from the attention mask."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the tokens that are not yet in the cache.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Positions count only attended tokens; padded slots get a dummy position of 1.
            position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
            position_ids = mx.where(attention_mask == 0, 1, position_ids)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "image_sizes": image_sizes,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select the batch rows in ``beam_idx`` (axis 0) from every cached state tensor."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(mx.take(past_state, beam_idx, axis=0) for past_state in layer_past),)
        return reordered_past