
feat: test upload - Trendyol DinoV2 Product Similarity and Retrieval Embedding Model
7bc81ce
verified
""" | |
Hugging Face compatible model implementation for Trendyol DinoV2 | |
""" | |
import torch | |
import torch.nn as nn | |
from transformers import PreTrainedModel, PretrainedConfig | |
from transformers.modeling_outputs import BaseModelOutput | |
from typing import Optional, Tuple, Union | |
import torch.nn.functional as F | |
class TrendyolDinoV2Config(PretrainedConfig):
    """
    Hyper-parameter container for the TrendyolDinoV2 embedding model.

    Every constructor argument is stored unchanged as an attribute of the
    same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        embedding_dim: Output size of the projection layer in
            ``TYArcFaceDinoV2``.
        input_size: Stored on the config; not read by the model code in
            this file (presumably the expected image side length).
        hidden_size: Multiplied with ``in_features`` to size the input of
            the projection layer in ``TYArcFaceDinoV2``.
        backbone_name: Entry-point name passed to ``torch.hub.load`` for the
            ``facebookresearch/dinov2`` repository.
        in_features: Channel count used for the 2-D batch norm and for the
            projection layer's input size.
        downscale_size: Stored on the config; not read by the model code in
            this file (presumably a preprocessing setting).
        pad_color: Stored on the config; not read by the model code in this
            file (presumably a preprocessing setting).
        jpeg_quality: Stored on the config; not read by the model code in
            this file (presumably a preprocessing setting).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "trendyol_dinov2"

    def __init__(
        self,
        embedding_dim=256,
        input_size=224,
        hidden_size=256,
        backbone_name="dinov2_vitb14",
        in_features=768,
        downscale_size=332,
        pad_color=255,
        jpeg_quality=90,
        **kwargs
    ):
        # Let the base class consume the generic options first.
        super().__init__(**kwargs)

        # Backbone selection and the feature width it is expected to emit.
        self.backbone_name = backbone_name
        self.in_features = in_features

        # Projection-head sizing.
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim

        # Values kept on the config but not consumed by the model classes
        # visible in this file.
        self.input_size = input_size
        self.downscale_size = downscale_size
        self.pad_color = pad_color
        self.jpeg_quality = jpeg_quality
class TYArcFaceDinoV2(nn.Module):
    """
    Core model architecture: a frozen DinoV2 backbone followed by a
    BatchNorm2d -> flatten -> Linear -> BatchNorm1d head whose output is
    L2-normalised.

    The linear layer expects ``in_features * hidden_size`` inputs, so the
    backbone feature map must have ``in_features`` channels and exactly
    ``hidden_size`` spatial positions (height * width) after flattening.

    Args:
        config: Object exposing ``backbone_name``, ``hidden_size``,
            ``in_features`` and ``embedding_dim``.

    Raises:
        RuntimeError: If the backbone cannot be loaded through
            ``torch.hub``; the original error is attached as ``__cause__``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Load DinoV2 backbone
        try:
            self.backbone = torch.hub.load(
                'facebookresearch/dinov2', config.backbone_name
            )
        except Exception as e:
            # Chain explicitly so the underlying failure stays inspectable
            # via ``__cause__`` instead of only being embedded in the text.
            raise RuntimeError(f"Failed to load DinoV2 backbone: {e}") from e

        self.hidden_size = config.hidden_size
        self.in_features = config.in_features
        self.embedding_dim = config.embedding_dim

        self.bn1 = nn.BatchNorm2d(self.in_features)

        # Freeze backbone
        self.backbone.requires_grad_(False)

        # Projection layers
        self.fc11 = nn.Linear(
            self.in_features * self.hidden_size, self.embedding_dim
        )
        self.bn11 = nn.BatchNorm1d(self.embedding_dim)

    def forward(self, pixel_values):
        """
        Compute L2-normalised embeddings for a batch of images.

        Args:
            pixel_values: Input batch handed to the backbone's
                ``get_intermediate_layers``.

        Returns:
            Tensor of shape ``(batch, embedding_dim)`` normalised along
            dimension 1.

        Raises:
            RuntimeError: If any step of the forward pass fails; the
                original error is attached as ``__cause__``.
        """
        try:
            features = self.backbone.get_intermediate_layers(
                pixel_values, return_class_token=True, reshape=True
            )
            # First returned layer, first element of that entry.
            # NOTE(review): presumably the reshaped (N, C, H, W) patch-token
            # map of a (patch_tokens, class_token) pair - confirm against
            # the DinoV2 implementation.
            features = features[0][0]  # Get the features
            features = self.bn1(features)
            # Collapse channel and spatial axes so the linear layer sees
            # one vector per sample.
            features = features.flatten(start_dim=1)
            features = self.fc11(features)
            features = self.bn11(features)
            features = F.normalize(features)
            return features
        except Exception as e:
            raise RuntimeError(f"Forward pass failed: {e}") from e
class TrendyolDinoV2Model(PreTrainedModel):
    """
    Hugging Face compatible wrapper for TrendyolDinoV2.

    Wraps ``TYArcFaceDinoV2`` (stored as ``self.model``) so it can be used
    through the ``PreTrainedModel`` API. The forward pass returns one
    embedding vector per input image.
    """

    config_class = TrendyolDinoV2Config
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        self.model = TYArcFaceDinoV2(config)
        # Initialize weights
        self.init_weights()

    def _init_weights(self, module):
        """Initialize weights (required by PreTrainedModel).

        Linear layers get normal(0, 0.02) weights and zero bias; batch-norm
        layers get unit weight and zero bias. Other module types are left
        untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def init_weights(self):
        """Initialize the projection-head weights.

        The backbone is deliberately skipped: it is loaded through
        ``torch.hub`` in ``TYArcFaceDinoV2.__init__``, and applying
        ``_init_weights`` to it would overwrite every ``nn.Linear`` inside
        it with random values.
        """
        for child_name, child in self.model.named_children():
            if child_name == "backbone":
                continue
            child.apply(self._init_weights)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """
        Compute embeddings for a batch of images.

        Args:
            pixel_values: Input batch passed to the wrapped model. Required
                despite the ``None`` default.
            output_hidden_states: Accepted for API compatibility; not used.
            return_dict: When falsy, return a 1-tuple instead of a
                ``BaseModelOutput``. Defaults to the config's
                ``use_return_dict`` (or ``True`` if that is absent).
            **kwargs: Accepted and ignored.

        Returns:
            ``(embeddings,)`` or a ``BaseModelOutput`` whose
            ``last_hidden_state`` holds the embeddings; ``hidden_states``
            and ``attentions`` are always ``None``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        return_dict = (
            return_dict
            if return_dict is not None
            else getattr(self.config, 'use_return_dict', True)
        )

        if pixel_values is None:
            raise ValueError("pixel_values cannot be None")

        # Get embeddings from the model
        embeddings = self.model(pixel_values)

        if not return_dict:
            return (embeddings,)

        return BaseModelOutput(
            last_hidden_state=embeddings,
            hidden_states=None,
            attentions=None
        )

    def get_embeddings(self, pixel_values):
        """Convenience method to get embeddings directly.

        Runs the model under ``torch.no_grad()`` and returns the embedding
        tensor. This only disables gradient tracking; it does not switch
        the module to eval mode, so the batch-norm layers follow whatever
        train/eval state the model is currently in.
        """
        with torch.no_grad():
            # Call the module rather than ``forward`` so registered
            # forward hooks are honoured.
            outputs = self(pixel_values, return_dict=True)
        return outputs.last_hidden_state
# Register both classes with the Auto* machinery: the config for
# AutoConfig (the default target) and the model for AutoModel.
TrendyolDinoV2Config.register_for_auto_class()
TrendyolDinoV2Model.register_for_auto_class("AutoModel")