# trendyol-dino-v2-ecommerce-256d / modeling_trendyol_dinov2.py
# yusufcakmak's picture
# feat: test upload - Trendyol DinoV2 Product Similarity and Retrieval Embedding Model
# 7bc81ce verified
"""
Hugging Face compatible model implementation for Trendyol DinoV2
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple, Union
import torch.nn.functional as F
class TrendyolDinoV2Config(PretrainedConfig):
    """
    Hyper-parameter container for the TrendyolDinoV2 model.

    Each constructor argument is stored unchanged as an attribute of the
    same name; any extra keyword arguments are handed to ``PretrainedConfig``.

    Args:
        embedding_dim: Output width of the projection layer (``fc11``) built
            by ``TYArcFaceDinoV2``.
        input_size: Not read in this file; presumably the square input
            resolution expected by the preprocessing pipeline - TODO confirm.
        hidden_size: Multiplied with ``in_features`` to size the input of the
            projection layer in ``TYArcFaceDinoV2``.
        backbone_name: Entry-point name passed to ``torch.hub.load`` for the
            ``facebookresearch/dinov2`` repository.
        in_features: Channel count used for the ``BatchNorm2d`` applied to the
            backbone feature map in ``TYArcFaceDinoV2``.
        downscale_size: Not read in this file; presumably a preprocessing
            setting - TODO confirm.
        pad_color: Not read in this file; presumably a preprocessing
            setting - TODO confirm.
        jpeg_quality: Not read in this file; presumably a preprocessing
            setting - TODO confirm.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "trendyol_dinov2"

    def __init__(
        self,
        embedding_dim=256,
        input_size=224,
        hidden_size=256,
        backbone_name="dinov2_vitb14",
        in_features=768,
        downscale_size=332,
        pad_color=255,
        jpeg_quality=90,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Model-specific fields, stored in declaration order after the base
        # class has consumed the generic keyword arguments.
        own_fields = (
            ("embedding_dim", embedding_dim),
            ("input_size", input_size),
            ("hidden_size", hidden_size),
            ("backbone_name", backbone_name),
            ("in_features", in_features),
            ("downscale_size", downscale_size),
            ("pad_color", pad_color),
            ("jpeg_quality", jpeg_quality),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)
class TYArcFaceDinoV2(nn.Module):
    """Core model architecture.

    A frozen DinoV2 backbone followed by a small projection head:
    ``BatchNorm2d -> flatten -> Linear -> BatchNorm1d -> L2 normalize``.

    The linear layer expects ``config.in_features * config.hidden_size``
    inputs, so the backbone feature map must have ``in_features`` channels
    and ``hidden_size`` spatial positions in total.
    NOTE(review): with the default config (768 x 256) this presumably matches
    a 224x224 input to a patch-14 ViT-B (16 x 16 patches) - confirm.
    """

    def __init__(self, config):
        """Build the backbone and the projection head.

        Args:
            config: Object exposing ``backbone_name``, ``hidden_size``,
                ``in_features`` and ``embedding_dim``.

        Raises:
            RuntimeError: If ``torch.hub.load`` fails for any reason; the
                original exception is attached as ``__cause__``.
        """
        super().__init__()
        self.config = config

        # Load DinoV2 backbone from the facebookresearch/dinov2 hub repo.
        try:
            self.backbone = torch.hub.load('facebookresearch/dinov2', config.backbone_name)
        except Exception as e:
            # Chain explicitly so the underlying failure stays attached.
            raise RuntimeError(f"Failed to load DinoV2 backbone: {e}") from e

        self.hidden_size = config.hidden_size
        self.in_features = config.in_features
        self.embedding_dim = config.embedding_dim

        self.bn1 = nn.BatchNorm2d(self.in_features)

        # Freeze backbone: its parameters receive no gradients.
        self.backbone.requires_grad_(False)

        # Projection layers
        self.fc11 = nn.Linear(self.in_features * self.hidden_size, self.embedding_dim)
        self.bn11 = nn.BatchNorm1d(self.embedding_dim)

    def forward(self, pixel_values):
        """Compute L2-normalized embeddings for a batch of images.

        Args:
            pixel_values: Image batch passed straight to the backbone's
                ``get_intermediate_layers``.

        Returns:
            Tensor with ``embedding_dim`` columns, normalized with
            ``F.normalize`` defaults (L2 norm along dim 1).

        Raises:
            RuntimeError: If any step of the forward pass fails; the original
                exception is attached as ``__cause__``.
        """
        try:
            features = self.backbone.get_intermediate_layers(
                pixel_values, return_class_token=True, reshape=True
            )
            # First returned layer, first element of its tuple. Per the DINOv2
            # API this is presumably the reshaped patch-token map, with the
            # class token as the second element - TODO confirm.
            features = features[0][0]
            features = self.bn1(features)
            features = features.flatten(start_dim=1)
            features = self.fc11(features)
            features = self.bn11(features)
            features = F.normalize(features)
            return features
        except Exception as e:
            # Same exception type and message as before, now with the cause chained.
            raise RuntimeError(f"Forward pass failed: {e}") from e
class TrendyolDinoV2Model(PreTrainedModel):
    """
    Hugging Face compatible wrapper for TrendyolDinoV2.

    Holds a ``TYArcFaceDinoV2`` instance under ``self.model`` and exposes its
    embeddings through the standard ``PreTrainedModel`` interface.
    """

    config_class = TrendyolDinoV2Config
    base_model_prefix = "model"

    def __init__(self, config):
        """Build the wrapped model and initialize its projection head.

        Args:
            config: A ``TrendyolDinoV2Config`` (or compatible) instance.
        """
        super().__init__(config)
        self.model = TYArcFaceDinoV2(config)
        # Initialize weights (projection head only, see ``init_weights``).
        self.init_weights()

    def _init_weights(self, module):
        """Initialize weights of a single module (required by PreTrainedModel).

        Linear layers get N(0, 0.02) weights and zero bias; 1-D/2-D batch-norm
        layers get unit weight and zero bias. Other module types are left as is.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def init_weights(self):
        """Initialize the projection head, leaving the backbone untouched.

        Applying ``_init_weights`` to the whole module tree would also
        overwrite every ``nn.Linear`` inside the backbone loaded (and frozen)
        by ``TYArcFaceDinoV2``, replacing its loaded weights with random
        values. Only the non-backbone children of ``self.model`` are
        initialized here.
        """
        for child_name, child in self.model.named_children():
            if child_name == "backbone":
                continue
            child.apply(self._init_weights)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute embeddings for ``pixel_values``.

        Args:
            pixel_values: Image batch forwarded to the wrapped model. Required.
            output_hidden_states: Accepted for interface compatibility; not
                used by this implementation.
            return_dict: If ``None``, falls back to ``config.use_return_dict``
                (``True`` when the config lacks that attribute).
            **kwargs: Accepted and ignored.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` holds the
            embeddings (``hidden_states`` and ``attentions`` are ``None``),
            or a 1-tuple ``(embeddings,)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)

        if pixel_values is None:
            raise ValueError("pixel_values cannot be None")

        # Get embeddings from the model
        embeddings = self.model(pixel_values)

        if not return_dict:
            return (embeddings,)

        return BaseModelOutput(
            last_hidden_state=embeddings,
            hidden_states=None,
            attentions=None
        )

    def get_embeddings(self, pixel_values):
        """Convenience method: return the embeddings tensor with gradients disabled."""
        with torch.no_grad():
            # Call the module itself (not ``.forward``) so registered hooks run.
            outputs = self(pixel_values, return_dict=True)
            return outputs.last_hidden_state
# Register both custom classes with their transformers auto classes:
# the config with AutoConfig and the model with AutoModel.
TrendyolDinoV2Config.register_for_auto_class("AutoConfig")
TrendyolDinoV2Model.register_for_auto_class("AutoModel")