from functools import partial
from typing import Any, Callable, Tuple

import timm
import torch
import torch.nn as nn
from timm.models.vision_transformer import LayerScale
from transformers import PreTrainedModel

from .configuration_encoder import EncoderConfig
|
|
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Patched LayerScale forward that reads `scale_factor` instead of `gamma`.
    """
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale) -> None:
    """
    Rename LayerScale's `gamma` parameter to `scale_factor` and rebind the
    patched forward. HF's checkpoint loading rewrites state-dict keys that
    contain "gamma", so the parameter must be renamed for the weights to
    survive a save/load round trip.
    """
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma
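
# A minimal sketch of the intended effect: once patched, `state_dict()` exposes
# `scale_factor` rather than `gamma`, so the "gamma" -> "weight" key remapping
# performed by HF's loading code no longer touches the parameter:
#
#     ls = LayerScale(1024)
#     ls_apply_patch(ls)
#     assert "scale_factor" in ls.state_dict() and "gamma" not in ls.state_dict()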
|
|
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """
    Wrap `fn` so that a tuple result (as returned by the model's
    `get_intermediate_layers`) is unpacked to its first element.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper
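
# Behavior sketch: a tuple result is replaced by its first element; anything
# else passes through unchanged.
#
#     assert unpack_tuple(lambda: ("x",))() == "x"
#     assert unpack_tuple(lambda: 3)() == 3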
|
|
class EncoderModel(PreTrainedModel):
    """
    Vision encoder wrapping timm's DINOv2 ViT-L/14 (with registers) as a
    HF `PreTrainedModel`.
    """

    config_class = EncoderConfig

    def __init__(self, config):
        """
        Initializes the model. There are no configurable parameters, as the
        architecture is tied to the specific pretrained weights.
        """
        super().__init__(config)
        self.encoder = timm.create_model(
            "vit_large_patch14_reg4_dinov2.lvd142m",
            img_size=224,
            num_classes=0,
            pretrained=False,
        )

        # Patch every LayerScale module so its parameter is not named `gamma`,
        # which HF checkpoint loading would otherwise remap.
        for module in self.encoder.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        # Expose features from the second-to-last block: the set literal is a
        # collection of block indices, which timm's `get_intermediate_layers`
        # accepts for `n`, and `unpack_tuple` strips the single-element tuple
        # it returns.
        self.encoder.forward = unpack_tuple(
            partial(self.encoder.get_intermediate_layers, n={len(self.encoder.blocks) - 2})
        )
|
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the vision encoder.

        Args:
            pixel_values (torch.Tensor): Input tensor of shape
                (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Patch token embeddings from the second-to-last
                transformer block, of shape (batch_size, num_patches, hidden_dim).
        """
        return self.encoder.forward(pixel_values)
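

# Minimal usage sketch (assumes `EncoderConfig` can be constructed without
# arguments). With random weights and one 224x224 RGB image, the patched
# encoder should return patch tokens of shape (1, 256, 1024) for ViT-L/14,
# since `get_intermediate_layers` strips the class and register tokens by
# default.
if __name__ == "__main__":
    config = EncoderConfig()
    model = EncoderModel(config).eval()
    with torch.no_grad():
        embeddings = model(torch.randn(1, 3, 224, 224))
    print(embeddings.shape)  # expected: torch.Size([1, 256, 1024])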
|
|