File size: 4,644 Bytes
a15fec5 7bc81ce a15fec5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
"""
Hugging Face compatible model implementation for Trendyol DinoV2
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple, Union
import torch.nn.functional as F
class TrendyolDinoV2Config(PretrainedConfig):
    """
    Hyper-parameter container for the TrendyolDinoV2 model.

    Each constructor argument is stored unchanged as an attribute of the
    same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig``.

    Within this file, ``backbone_name`` selects the torch.hub DinoV2 entry
    point, ``in_features`` sizes the 2D batch norm, and
    ``in_features * hidden_size`` / ``embedding_dim`` size the projection
    layer. ``input_size``, ``downscale_size``, ``pad_color`` and
    ``jpeg_quality`` are only stored here -- presumably they are consumed
    by preprocessing code outside this file (TODO confirm).
    """

    model_type = "trendyol_dinov2"

    def __init__(
        self,
        embedding_dim=256,
        input_size=224,
        hidden_size=256,
        backbone_name="dinov2_vitb14",
        in_features=768,
        downscale_size=332,
        pad_color=255,
        jpeg_quality=90,
        **kwargs
    ):
        # The base class is initialised first so the model-specific
        # attributes below are assigned after it has processed ``kwargs``.
        super().__init__(**kwargs)

        # Ordered mapping: attributes are set in the same sequence as the
        # constructor arguments are declared.
        model_settings = {
            "embedding_dim": embedding_dim,
            "input_size": input_size,
            "hidden_size": hidden_size,
            "backbone_name": backbone_name,
            "in_features": in_features,
            "downscale_size": downscale_size,
            "pad_color": pad_color,
            "jpeg_quality": jpeg_quality,
        }
        for attr_name, attr_value in model_settings.items():
            setattr(self, attr_name, attr_value)
class TYArcFaceDinoV2(nn.Module):
    """Core model architecture: frozen DinoV2 backbone plus a projection head.

    The head applies ``BatchNorm2d`` to the backbone feature map, flattens
    it, projects it with a single linear layer, applies ``BatchNorm1d`` and
    finally normalises each embedding with ``F.normalize``.
    """

    def __init__(self, config):
        """Build the backbone and the projection head.

        Args:
            config: Object exposing ``backbone_name``, ``hidden_size``,
                ``in_features`` and ``embedding_dim``.

        Raises:
            RuntimeError: If ``torch.hub.load`` fails for any reason; the
                original exception is attached as ``__cause__``.
        """
        super().__init__()
        self.config = config

        # Load DinoV2 backbone
        try:
            self.backbone = torch.hub.load('facebookresearch/dinov2', config.backbone_name)
        except Exception as e:
            # Chain explicitly so the root cause stays on the traceback.
            raise RuntimeError(f"Failed to load DinoV2 backbone: {e}") from e

        self.hidden_size = config.hidden_size
        self.in_features = config.in_features
        self.embedding_dim = config.embedding_dim

        self.bn1 = nn.BatchNorm2d(self.in_features)

        # Freeze backbone (disables gradients only; it does not switch the
        # backbone to eval mode).
        self.backbone.requires_grad_(False)

        # Projection layers. The linear layer consumes the flattened feature
        # map, so each sample must flatten to in_features * hidden_size
        # values.
        self.fc11 = nn.Linear(self.in_features * self.hidden_size, self.embedding_dim)
        self.bn11 = nn.BatchNorm1d(self.embedding_dim)

    def forward(self, pixel_values):
        """Compute normalised embeddings for a batch of images.

        Args:
            pixel_values: Input batch passed straight to the backbone's
                ``get_intermediate_layers``.

        Returns:
            Tensor with one row per sample and ``embedding_dim`` columns,
            normalised by ``F.normalize`` with its default arguments.

        Raises:
            RuntimeError: If any step fails; the original exception is
                attached as ``__cause__``.
        """
        try:
            features = self.backbone.get_intermediate_layers(
                pixel_values, return_class_token=True, reshape=True
            )
            # First returned layer, first element of its tuple -- presumably
            # the reshaped patch-token map (TODO confirm against DinoV2).
            features = features[0][0]
            features = self.bn1(features)
            features = features.flatten(start_dim=1)
            features = self.fc11(features)
            features = self.bn11(features)
            features = F.normalize(features)
            return features
        except Exception as e:
            # Exception type and message are kept for existing callers.
            raise RuntimeError(f"Forward pass failed: {e}") from e
class TrendyolDinoV2Model(PreTrainedModel):
    """
    Hugging Face compatible wrapper for TrendyolDinoV2.

    Wraps ``TYArcFaceDinoV2`` (stored as ``self.model``) so that it can be
    used through the ``PreTrainedModel`` API.
    """

    config_class = TrendyolDinoV2Config
    base_model_prefix = "model"

    def __init__(self, config):
        """Build the wrapped model and initialise its projection head.

        Args:
            config: A ``TrendyolDinoV2Config`` instance.
        """
        super().__init__(config)
        self.model = TYArcFaceDinoV2(config)
        # Initialize weights
        self.init_weights()

    def _init_weights(self, module):
        """Initialize weights (required by PreTrainedModel).

        Linear layers get N(0, 0.02) weights and zero bias; 1D/2D batch
        norm layers get unit weight and zero bias. Other module types are
        left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def init_weights(self):
        """Initialize every module except those inside the backbone.

        The backbone is loaded through ``torch.hub.load``; applying
        ``_init_weights`` to the whole tree would overwrite each of its
        ``nn.Linear`` layers with random values, so its subtree is skipped.
        """
        backbone = getattr(self.model, "backbone", None)
        if backbone is None:
            backbone_ids = set()
        else:
            backbone_ids = {id(sub) for sub in backbone.modules()}

        for module in self.modules():
            if id(module) not in backbone_ids:
                self._init_weights(module)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Compute embeddings for ``pixel_values``.

        Args:
            pixel_values: Input batch forwarded to the wrapped model.
            output_hidden_states: Accepted for API compatibility; not used.
            return_dict: Whether to return a ``BaseModelOutput``. When
                ``None``, falls back to ``config.use_return_dict`` (``True``
                if the config lacks that attribute).
            **kwargs: Accepted and ignored.

        Returns:
            ``(embeddings,)`` when ``return_dict`` is false, otherwise a
            ``BaseModelOutput`` whose ``last_hidden_state`` holds the
            embeddings and whose other fields are ``None``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)

        if pixel_values is None:
            raise ValueError("pixel_values cannot be None")

        # Get embeddings from the model
        embeddings = self.model(pixel_values)

        if not return_dict:
            return (embeddings,)

        return BaseModelOutput(
            last_hidden_state=embeddings,
            hidden_states=None,
            attentions=None
        )

    def get_embeddings(self, pixel_values):
        """Convenience method to get embeddings directly, without gradients.

        Note: this does not change the train/eval mode of the model, so the
        batch-norm layers behave according to the current mode; call
        ``.eval()`` first for inference.
        """
        with torch.no_grad():
            # Call the module itself rather than ``forward`` so that any
            # registered forward hooks run.
            outputs = self(pixel_values, return_dict=True)
        return outputs.last_hidden_state
# Register both classes for the Auto* API: the configuration under
# "AutoConfig" and the model under "AutoModel".
TrendyolDinoV2Config.register_for_auto_class("AutoConfig")
TrendyolDinoV2Model.register_for_auto_class("AutoModel")
|