import math

import torch
import torch.nn as nn

from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

from .configuration_qwen_rdx_linear import Qwen2_5_VLLinearConfig


class Qwen2_5_VLLinearForCausalLM(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL with a trainable linear prefix projection.

    An extra feature vector (``prefix_feats``) is projected to the hidden size
    and added to the input embeddings at every position holding the RDX token
    (``config.rdx_token_id``). All base-model parameters are frozen; only the
    projection is trained.
    """

    config_class = Qwen2_5_VLLinearConfig

    def __init__(self, config):
        super().__init__(config)
        self.prefix_proj = nn.Sequential(
            nn.Linear(config.extra_feat_dim, config.hidden_size, bias=True),
            nn.SiLU(),
        )
        nn.init.kaiming_uniform_(self.prefix_proj[0].weight, a=math.sqrt(5))
        nn.init.zeros_(self.prefix_proj[0].bias)
        # Freeze everything except the prefix projection.
        for n, p in self.named_parameters():
            if not n.startswith("prefix_proj"):
                p.requires_grad_(False)

    # ------------------------------------------------------------------ #
    # Forward
    # ------------------------------------------------------------------ #
    def forward(
        self,
        input_ids=None,
        prefix_feats=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        pixel_values=None,
        image_grid_thw=None,
        rope_deltas=None,
        **kwargs,
    ):
        # -------- Incremental decoding steps (after first token) ----------
        # A populated KV cache means the prefill step already ran. Newer
        # `transformers` versions pass a fresh (empty) Cache object on the
        # prefill call, which must still go through the injection path below,
        # so we check the cached length rather than `is not None`.
        past_len = 0
        if past_key_values is not None:
            past_len = (
                past_key_values.get_seq_length()
                if hasattr(past_key_values, "get_seq_length")
                else past_key_values[0][0].shape[2]
            )
        if past_len > 0:
            # No prefix injection here; the cache already contains the effect
            # from the initial full-context step.
            return super().forward(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                labels=labels,
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                rope_deltas=rope_deltas,
                **kwargs,
            )

        # -------- Full-context step (step 0) ------------------------------
        # Optional: if the caller forgets prefix_feats, just skip injection.
        rdx_id = self.config.rdx_token_id
        do_inject = (
            prefix_feats is not None
            and input_ids is not None
            and (input_ids == rdx_id).any()
        )

        if do_inject:
            emb_layer = self.get_input_embeddings()
            dtype, device = emb_layer.weight.dtype, emb_layer.weight.device
            delta = self.prefix_proj(prefix_feats.to(device=device, dtype=dtype))  # (bs, hidden)
            rdx_mask = input_ids == rdx_id

            def hook(_module, _inputs, out):
                # Returning a tensor from a forward hook replaces the module
                # output, so the modified embeddings flow into the model.
                out = out.clone()
                # Add delta to every RDX-token position, per batch element.
                for b in range(out.size(0)):
                    idxs = torch.nonzero(rdx_mask[b], as_tuple=False).view(-1)
                    if idxs.numel():
                        out[b, idxs, :] += delta[b]
                return out

            handle = emb_layer.register_forward_hook(hook)
            try:
                outputs = super().forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,  # let the parent compute the loss
                    past_key_values=past_key_values,
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thw,
                    rope_deltas=rope_deltas,
                    **kwargs,
                )
            finally:
                handle.remove()
            return outputs

        # -------- No RDX token present, or no prefix_feats -----------------
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            rope_deltas=rope_deltas,
            **kwargs,
        )

    # ------------------------------------------------------------------ #
    # Generation helper
    # ------------------------------------------------------------------ #
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        prefix_feats=None,
        pixel_values=None,
        image_grid_thw=None,
        rope_deltas=None,
        **kwargs,
    ):
        # Let the parent build its dict, then tack on prefix_feats if provided.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            rope_deltas=rope_deltas,
            **kwargs,
        )
        if prefix_feats is not None:
            model_inputs["prefix_feats"] = prefix_feats
        return model_inputs
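

# --------------------------------------------------------------------- #
# Usage sketch (illustration only)
# --------------------------------------------------------------------- #
# A minimal, hedged example of how generation with `prefix_feats` is meant to
# be wired up. The checkpoint name, the "<RDX>" token string, the feature
# dimension, and the plain-text prompt below are assumptions for illustration;
# adapt them to the project's actual config (`extra_feat_dim`, `rdx_token_id`),
# tokenizer, and chat template.
if __name__ == "__main__":
    from transformers import AutoProcessor

    ckpt = "Qwen/Qwen2.5-VL-3B-Instruct"  # assumed base checkpoint

    processor = AutoProcessor.from_pretrained(ckpt)
    # Assume the sentinel is registered as an added special token "<RDX>".
    processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<RDX>"]})

    config = Qwen2_5_VLLinearConfig.from_pretrained(ckpt)
    config.extra_feat_dim = 128  # assumed size of the extra feature vector
    config.rdx_token_id = processor.tokenizer.convert_tokens_to_ids("<RDX>")

    model = Qwen2_5_VLLinearForCausalLM.from_pretrained(ckpt, config=config)
    model.resize_token_embeddings(len(processor.tokenizer))  # room for "<RDX>"
    model.eval()

    # Text-only prompt containing the RDX token whose embedding gets shifted.
    prompt = "<RDX> Describe the scene."
    inputs = processor(text=[prompt], return_tensors="pt")
    prefix_feats = torch.randn(1, config.extra_feat_dim)  # stand-in features

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            prefix_feats=prefix_feats,  # routed via prepare_inputs_for_generation
            max_new_tokens=32,
        )
    print(processor.batch_decode(out_ids, skip_special_tokens=True)[0])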