File size: 11,531 Bytes
073ed96 979aca3 073ed96 979aca3 073ed96 979aca3 073ed96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
# coding=utf-8
# Copyright 2025 the SB Intuitions.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationMixin,
LlamaForCausalLM,
PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
from transformers.utils import logging, replace_return_docstrings
from .configuration_sarashina2_vision import Sarashina2VisionConfig
# Logger namespaced by this module's import path (transformers' logging wrapper).
logger = logging.get_logger(name=__name__)
# Config class name fed to `replace_return_docstrings` on `forward` below.
_CONFIG_FOR_DOC = "Sarashina2VisionConfig"
class Sarashina2VisionPreTrainedModel(PreTrainedModel):
    """Common base for Sarashina2-Vision models.

    Binds the model family to ``Sarashina2VisionConfig``, declares the
    capability flags exposed to transformers, and supplies the per-module
    weight initializer hook ``_init_weights``.
    """

    config_class = Sarashina2VisionConfig
    # NOTE(review): the visible subclass stores its sub-models as `visual`,
    # `norm` and `llm`; no `model` attribute is defined here — confirm this
    # prefix is intended for checkpoint key handling.
    base_model_prefix = "model"
    # Capability flags (presumably consulted by transformers when choosing the
    # attention implementation and cache type).
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialize the parameters of ``module`` in place.

        The standard deviation is ``config.initializer_range`` when the
        top-level config defines it, otherwise
        ``config.text_config.initializer_range``.

        - A ``class_embedding`` attribute, if present, is drawn from N(0, std).
        - ``nn.Linear`` / ``nn.Conv3d`` weights are drawn from N(0, std) and
          their biases (if any) are zeroed.
        - ``nn.Embedding`` weights are drawn from N(0, std) and the row at
          ``padding_idx`` (if set) is zeroed.
        Any other module type is left unchanged.
        """
        cfg = self.config
        if hasattr(cfg, "initializer_range"):
            std = cfg.initializer_range
        else:
            std = cfg.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding vector at zero.
                module.weight.data[module.padding_idx].zero_()
class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin):
    """Vision-language causal LM: a Qwen2-VL vision tower feeding a Llama decoder.

    Image patches are encoded by ``self.visual`` and projected by its merger.
    They are normalized by ``self.norm``, a LayerNorm over the text hidden
    size. The result is scattered into the text embedding sequence at every
    position whose token id equals ``config.image_token_index``. The combined
    embeddings are then decoded by ``self.llm``.
    """

    def __init__(self, config: Sarashina2VisionConfig):
        super().__init__(config)
        # Vision encoder; its patch_embed, blocks and merger are driven
        # explicitly in `get_image_embeds`.
        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
        # Normalizes merged image features before they are mixed into the LLM input.
        self.norm = nn.LayerNorm(config.text_config.hidden_size)
        self.llm = LlamaForCausalLM._from_config(config.text_config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped language model."""
        return self.llm.get_input_embeddings()

    def get_image_embeds(
        self,
        hidden_states: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened image patches into LLM-ready image embeddings.

        Args:
            hidden_states: Pixel patches from the image processor; presumably
                (total_patches, patch_dim) — TODO confirm against the processor.
            grid_thw: One row per image; column 0 is the temporal size and
                columns 1 and 2 are the height and width of its patch grid.

        Returns:
            The vision-merger output passed through ``self.norm``. ``forward``
            requires its first dimension to equal the number of image tokens.
        """
        rotary_pos_emb = self.visual.rot_pos_emb(grid_thw)
        hidden_states = self.visual.patch_embed(hidden_states)
        # Each image contributes t frames of h*w patches. The cumulative sums,
        # with a leading 0, give the patch offsets handed to every block as
        # `cu_seqlens` (presumably per-frame attention boundaries).
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        for blk in self.visual.blocks:
            hidden_states = blk(
                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            )
        return self.norm(self.visual.merger(hidden_states))

    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            input_ids (torch.LongTensor, optional): Indices of input sequence tokens in the vocabulary. Required whenever `pixel_values` is given, because image-token positions are located from it. Defaults to None.
            attention_mask (Optional[torch.Tensor], optional): Mask to avoid performing attention on padding token indices. Defaults to None.
            position_ids (Optional[torch.LongTensor], optional): Indices of positions of each input sequence tokens in the position embeddings. Defaults to None.
            past_key_values (Optional[List[torch.FloatTensor]], optional): Cached key/value states from previous decoding steps, forwarded unchanged to the language model. Defaults to None.
            inputs_embeds (Optional[torch.FloatTensor], optional): Instead of passing `input_ids` you can choose to directly pass an embedded representation. Defaults to None.
            labels (Optional[torch.LongTensor], optional): Labels for computing the causal language modeling loss; they are shifted by one position internally. Defaults to None.
            use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding. Defaults to None.
            output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers. Defaults to None.
            output_hidden_states (Optional[bool], optional): Whether or not to return the hidden states of all layers. Defaults to None.
            return_dict (Optional[bool], optional): Whether or not to return a `CausalLMOutputWithPast` instead of a plain tuple. Defaults to None.
            pixel_values (torch.FloatTensor, optional): The tensors corresponding to the input images. Defaults to None.
            image_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width of feature shape of each image in LLM. Required whenever `pixel_values` is given. Defaults to None.
            cache_position (Optional[torch.LongTensor], optional): Indices depicting the position of the input sequence tokens in the sequence. Defaults to None.
            logits_to_keep (Union[int, torch.Tensor]): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).
        Raises:
            ValueError: If `pixel_values` is given without `input_ids` or `image_grid_thw`, or if the number of
                image tokens in `input_ids` differs from the number of image features produced.
        Returns:
            CausalLMOutputWithPast: The output of the model.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            # Fail with a clear message instead of an opaque error further down.
            if input_ids is None:
                raise ValueError(
                    "`input_ids` must be provided together with `pixel_values` to locate the image tokens"
                )
            if image_grid_thw is None:
                raise ValueError("`image_grid_thw` must be provided together with `pixel_values`")

            image_token_mask = input_ids == self.config.image_token_index
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.get_image_embeds(pixel_values, image_grid_thw)

            # masked_scatter consumes image_embeds sequentially, so the counts must agree exactly.
            n_image_tokens = image_token_mask.sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_mask = (
                image_token_mask.unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        outputs = self.llm(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )
        # No labels are passed to self.llm, so the first output is the logits.
        logits = outputs[0]

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten using the logits' own vocabulary dimension. `self.config.vocab_size`
            # is not guaranteed to exist on this composite config or to match the LM head.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            # Enable model parallelism
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        image_grid_thw=None,
        **kwargs,
    ):
        """Build per-step model inputs for generation.

        Text-side inputs are prepared by the wrapped language model. Image
        inputs are attached only on the step whose first cache position is 0.
        """
        model_inputs = self.llm.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # Only the initial (prefill) step sees the image tokens in input_ids, so
        # pixel values are passed there. Later cached decoding steps feed only
        # new tokens, and the image inputs are left out.
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_grid_thw"] = image_grid_thw

        return model_inputs
# Register the "sarashina2_vision" model type with the Auto* factories at import
# time, so AutoConfig / AutoModelForCausalLM can resolve it in this process.
AutoConfig.register("sarashina2_vision", Sarashina2VisionConfig)
AutoModelForCausalLM.register(Sarashina2VisionConfig, Sarashina2VisionForCausalLM)
# Mark both classes for custom-code auto-class registration (presumably so the
# saved config records them in its `auto_map` — confirm for the transformers
# version in use).
Sarashina2VisionConfig.register_for_auto_class()
Sarashina2VisionForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|