File size: 4,644 Bytes
a15fec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc81ce
a15fec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Hugging Face compatible model implementation for Trendyol DinoV2
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple, Union
import torch.nn.functional as F


class TrendyolDinoV2Config(PretrainedConfig):
    """
    Configuration for the TrendyolDinoV2 embedding model.

    Holds the backbone selection, the dimensions of the projection head,
    and preprocessing constants (downscale size, pad color, JPEG quality).
    """
    model_type = "trendyol_dinov2"

    def __init__(
        self,
        embedding_dim=256,
        input_size=224,
        hidden_size=256,
        backbone_name="dinov2_vitb14",
        in_features=768,
        downscale_size=332,
        pad_color=255,
        jpeg_quality=90,
        **kwargs
    ):
        # Let PretrainedConfig consume any standard HF kwargs first.
        super().__init__(**kwargs)
        # Backbone / projection-head dimensions.
        self.backbone_name = backbone_name
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        # Preprocessing parameters.
        self.input_size = input_size
        self.downscale_size = downscale_size
        self.pad_color = pad_color
        self.jpeg_quality = jpeg_quality


class TYArcFaceDinoV2(nn.Module):
    """Core embedding network.

    A frozen DinoV2 backbone followed by a BatchNorm + Linear projection
    head that produces L2-normalized embeddings of size
    ``config.embedding_dim``.
    """

    def __init__(self, config):
        # Py3 zero-argument super() instead of super(TYArcFaceDinoV2, self).
        super().__init__()
        self.config = config

        # Load the pretrained DinoV2 backbone from torch hub (network I/O).
        try:
            self.backbone = torch.hub.load('facebookresearch/dinov2', config.backbone_name)
        except Exception as e:
            # Chain the original exception so the real cause stays in the traceback.
            raise RuntimeError(f"Failed to load DinoV2 backbone: {e}") from e

        self.hidden_size = config.hidden_size
        self.in_features = config.in_features
        self.embedding_dim = config.embedding_dim

        # BatchNorm over the backbone's reshaped (B, C, H, W) patch-feature map.
        self.bn1 = nn.BatchNorm2d(self.in_features)
        # Freeze the backbone; only the projection head is trainable.
        self.backbone.requires_grad_(False)

        # Projection head: flattened feature map -> embedding_dim.
        # NOTE(review): assumes the flattened spatial extent (H*W) of the
        # feature map equals hidden_size — confirm against input_size and
        # the backbone's patch size.
        self.fc11 = nn.Linear(self.in_features * self.hidden_size, self.embedding_dim)
        self.bn11 = nn.BatchNorm1d(self.embedding_dim)

    def forward(self, pixel_values):
        """Return L2-normalized embeddings of shape (batch, embedding_dim).

        Args:
            pixel_values: Input image batch accepted by the DinoV2 backbone.

        Raises:
            RuntimeError: If any step of the forward pass fails (original
                exception is chained as the cause).
        """
        try:
            # reshape=True returns spatially-reshaped patch features;
            # [0][0] selects the feature map of the single requested layer.
            features = self.backbone.get_intermediate_layers(
                pixel_values, return_class_token=True, reshape=True
            )
            features = features[0][0]
            features = self.bn1(features)
            features = features.flatten(start_dim=1)
            features = self.fc11(features)
            features = self.bn11(features)
            # Unit-norm output so embeddings are cosine-similarity ready.
            features = F.normalize(features)
            return features
        except Exception as e:
            # Preserve the underlying error via explicit chaining.
            raise RuntimeError(f"Forward pass failed: {e}") from e


class TrendyolDinoV2Model(PreTrainedModel):
    """
    Hugging Face compatible wrapper around TYArcFaceDinoV2.

    Exposes the embedding network through the standard PreTrainedModel
    interface: config-driven construction and a forward pass returning a
    BaseModelOutput whose ``last_hidden_state`` holds the embeddings.
    """
    config_class = TrendyolDinoV2Config
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        self.model = TYArcFaceDinoV2(config)

        # Initialize the projection head; the pretrained backbone is
        # deliberately excluded (see init_weights below).
        self.init_weights()

    def _init_weights(self, module):
        """Initialize a single module (required by PreTrainedModel)."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # Guard: with affine=False a norm layer has no learnable params.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def init_weights(self):
        """Initialize all weights EXCEPT the pretrained DinoV2 backbone.

        BUGFIX: the previous ``self.apply(self._init_weights)`` also visited
        every Linear layer inside the pretrained backbone, silently wiping
        its pretrained weights. Backbone modules are now skipped.
        """
        backbone = getattr(getattr(self, "model", None), "backbone", None)
        skip_ids = {id(m) for m in backbone.modules()} if backbone is not None else set()
        for module in self.modules():
            if id(module) not in skip_ids:
                self._init_weights(module)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute embeddings for a batch of images.

        Args:
            pixel_values: Preprocessed image batch (required).
            output_hidden_states: Accepted for HF API compatibility; unused.
            return_dict: If False, return a plain tuple instead of a
                BaseModelOutput. Defaults to ``config.use_return_dict``.

        Raises:
            ValueError: If ``pixel_values`` is None.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)

        if pixel_values is None:
            raise ValueError("pixel_values cannot be None")

        # Get embeddings from the core network.
        embeddings = self.model(pixel_values)

        if not return_dict:
            return (embeddings,)

        return BaseModelOutput(
            last_hidden_state=embeddings,
            hidden_states=None,
            attentions=None
        )

    def get_embeddings(self, pixel_values):
        """Convenience method: embeddings without gradient tracking."""
        with torch.no_grad():
            outputs = self.forward(pixel_values, return_dict=True)
            return outputs.last_hidden_state


# Register the config and model with the HF auto classes so that
# AutoConfig / AutoModel.from_pretrained(..., trust_remote_code=True)
# can resolve them from a saved repository.
TrendyolDinoV2Config.register_for_auto_class()
TrendyolDinoV2Model.register_for_auto_class("AutoModel")