File size: 2,540 Bytes
0c77726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import timm
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from timm.models.vision_transformer import LayerScale
from functools import partial
from typing import Any, Callable, Tuple
from transformers import PreTrainedModel
from .configuration_encoder import EncoderConfig



# === Utility Functions for Monkey-Patching ===
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Patch LayerScale forward method to use scale_factor instead of gamma.
    """
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    """
    Apply the LayerScale patch to replace gamma with scale_factor.
    """
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma

def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """
    Utility function to unpack tuple results from the model's intermediate layers.
    """
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper

class EncoderModel(PreTrainedModel):
    """
    Custom Vision Transformer Encoder based on timm's Vision Transformer.
    """
    config_class = EncoderConfig
    def __init__(self, config):
        """
        Initializes the model. No configurable parameters as it's tied to specific weights.
        """
        super().__init__(config)
        self.encoder = timm.create_model(
            "vit_large_patch14_reg4_dinov2.lvd142m",
            img_size=224,
            num_classes=0,
            pretrained=False
        )

        # Apply LayerScale patch
        for module in self.encoder.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        # Patch forward method to return specific intermediate layers
        self.encoder.forward = unpack_tuple(
            partial( self.encoder.get_intermediate_layers, n={len(self.encoder.blocks) - 2})
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the vision encoder.
        Args:
            pixel_values (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
        Returns:
            torch.Tensor: The output embeddings.
        """
        return self.encoder.forward(pixel_values)