import torch
import torch.nn as nn
import timm
from functools import partial
from typing import Any, Callable

from transformers import PreTrainedModel
from timm.models.vision_transformer import LayerScale

from .configuration_encoder import EncoderConfig

# === Utility Functions for Monkey-Patching ===
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Replacement forward for LayerScale that reads `scale_factor` instead of `gamma`.
    """
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    """
    Rename LayerScale's `gamma` parameter to `scale_factor` and rebind forward.

    HF's `PreTrainedModel` rewrites state-dict keys containing "gamma" (a legacy
    TensorFlow-conversion rule), which would silently break timm's LayerScale
    weights on load; renaming the parameter sidesteps that.
    """
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma
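
# Example sketch (hypothetical, not executed on import): patch a standalone
# LayerScale and check that outputs are unchanged, only the parameter name differs.
#
#   ls = LayerScale(dim=8)
#   x = torch.randn(2, 8)
#   before = ls(x)
#   ls_apply_patch(ls)
#   assert torch.allclose(before, ls(x))
#   assert not hasattr(ls, "gamma") and hasattr(ls, "scale_factor")
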
def unpack_tuple(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Wrap `fn` so the 1-tuple returned by the model's intermediate-layer
    extraction is unpacked to its single element.
    """
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result
    return wrapper
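
# For illustration: get_intermediate_layers returns a tuple with one tensor per
# requested layer, so with a single index the wrapper yields the bare tensor, e.g.
#   unpack_tuple(lambda: (1,))()  ->  1
#   unpack_tuple(lambda: 1)()     ->  1  (non-tuple results pass through)
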
class EncoderModel(PreTrainedModel):
    """
    Custom Vision Transformer encoder based on timm's Vision Transformer.
    """
    config_class = EncoderConfig

    def __init__(self, config):
        """
        Initializes the model. No configurable parameters, as it is tied to
        specific weights (DINOv2 ViT-L/14 with registers).
        """
        super().__init__(config)
        self.encoder = timm.create_model(
            "vit_large_patch14_reg4_dinov2.lvd142m",
            img_size=224,
            num_classes=0,
            pretrained=False,
        )

        # Apply the LayerScale patch so the checkpoint survives HF key rewriting.
        for module in self.encoder.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        # Patch forward to return the output of the second-to-last block:
        # `n` is a set of absolute block indices for get_intermediate_layers.
        self.encoder.forward = unpack_tuple(
            partial(self.encoder.get_intermediate_layers, n={len(self.encoder.blocks) - 2})
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the vision encoder.

        Args:
            pixel_values (torch.Tensor): Input tensor of shape
                (batch_size, channels, height, width).

        Returns:
            torch.Tensor: The output embeddings.
        """
        return self.encoder.forward(pixel_values)
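
# === Usage sketch (hypothetical; the relative import above means this file must
# be imported as part of its package rather than run as a script) ===
#
#   model = EncoderModel(EncoderConfig()).eval()
#   pixel_values = torch.randn(1, 3, 224, 224)  # (batch, channels, H, W)
#   with torch.no_grad():
#       embeddings = model(pixel_values)
#   # ViT-L/14 at 224x224 yields 16x16 = 256 patch tokens of width 1024, and
#   # prefix tokens are stripped by default, so `embeddings` is expected to
#   # have shape (1, 256, 1024).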