import torch
import torch.nn as nn
import timm
from functools import partial
from typing import Any, Callable, Tuple

from timm.models.vision_transformer import LayerScale
from transformers import PreTrainedModel

from .configuration_encoder import EncoderConfig


# === Utility Functions for Monkey-Patching ===

def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Patched LayerScale forward method that uses `scale_factor` instead of `gamma`.
    """
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    """
    Apply the LayerScale patch, replacing `gamma` with `scale_factor`.
    """
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma


def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """
    Utility function to unpack tuple results from the model's intermediate layers.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper


class EncoderModel(PreTrainedModel):
    """
    Custom Vision Transformer encoder based on timm's Vision Transformer.
    """

    config_class = EncoderConfig

    def __init__(self, config):
        """
        Initializes the model. No configurable parameters, as it is tied to specific weights.
        """
        super().__init__(config)
        self.encoder = timm.create_model(
            "vit_large_patch14_reg4_dinov2.lvd142m",
            img_size=224,
            num_classes=0,
            pretrained=False,
        )

        # Apply the LayerScale patch to every LayerScale module in the backbone.
        for module in self.encoder.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        # Patch the forward method to return the output of the second-to-last block.
        self.encoder.forward = unpack_tuple(
            partial(
                self.encoder.get_intermediate_layers,
                n={len(self.encoder.blocks) - 2},
            )
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the vision encoder.

        Args:
            pixel_values (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: The output embeddings.
        """
        return self.encoder.forward(pixel_values)
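

# --- Usage sketch (illustrative only, not part of the public API) ---
# A minimal smoke test of the encoder, assuming `EncoderConfig` can be constructed
# with default arguments and that the file is executed as a module (e.g.
# `python -m <package>.modeling_encoder`) so the relative import above resolves.
# The backbone is created with pretrained=False, so the output below comes from
# randomly initialized weights.
if __name__ == "__main__":
    config = EncoderConfig()
    model = EncoderModel(config).eval()

    # Dummy batch containing one 224x224 RGB image.
    pixel_values = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        embeddings = model(pixel_values)

    # timm's get_intermediate_layers returns patch tokens only (class/register
    # tokens are dropped by default), so for ViT-L/14 at 224px this is expected
    # to be (1, 256, 1024).
    print(embeddings.shape)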