# japanese-clip-stair / modeling_japanese_clip.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoModel, PreTrainedModel, PretrainedConfig
from typing import Optional, Dict
class JapaneseCLIPConfig(PretrainedConfig):
"""Japanese CLIP モデル設定クラス"""
model_type = "japanese-clip"
def __init__(
self,
text_model_name="cl-tohoku/bert-base-japanese-v3",
image_embed_dim=512,
text_embed_dim=512,
temperature=0.07,
**kwargs
):
super().__init__(**kwargs)
self.text_model_name = text_model_name
self.image_embed_dim = image_embed_dim
self.text_embed_dim = text_embed_dim
self.temperature = temperature
class JapaneseCLIPModel(PreTrainedModel):
"""Hugging Face互換のJapaneseCLIPモデル"""
config_class = JapaneseCLIPConfig
def __init__(self, config):
super().__init__(config)
        # Import torchvision lazily so a missing dependency raises an actionable error
        try:
            from torchvision.models import resnet50
        except ImportError:
            raise ImportError(
                "torchvision is required for this model. Install it with: pip install torchvision"
            )
        # Image encoder (ResNet50 backbone, ImageNet-pretrained); its classifier head is
        # replaced by a linear layer that emits image_embed_dim features
        self.image_encoder = resnet50(weights="IMAGENET1K_V1")
self.image_encoder.fc = nn.Linear(
self.image_encoder.fc.in_features,
config.image_embed_dim
)
        # Text encoder (Japanese BERT)
self.text_encoder = AutoModel.from_pretrained(config.text_model_name)
        # Projection layers into the shared embedding space
self.text_projection = nn.Linear(
self.text_encoder.config.hidden_size,
config.text_embed_dim
)
self.image_projection = nn.Linear(
config.image_embed_dim,
config.text_embed_dim
)
        # Normalization layers
self.image_norm = nn.LayerNorm(config.text_embed_dim)
self.text_norm = nn.LayerNorm(config.text_embed_dim)
        # Learnable temperature, stored as log(1 / temperature) as in CLIP;
        # exp() of this parameter is the logit scale applied to the similarities
        self.temperature = nn.Parameter(
            torch.ones([]) * np.log(1 / config.temperature)
        )
def encode_image(self, pixel_values):
"""画像をエンコード"""
image_features = self.image_encoder(pixel_values)
image_features = self.image_projection(image_features)
image_features = self.image_norm(image_features)
return F.normalize(image_features, dim=-1)
def encode_text(self, input_ids, attention_mask):
"""テキストをエンコード"""
text_outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
        # Use the [CLS] token representation as the sentence embedding
        text_features = text_outputs.last_hidden_state[:, 0, :]
text_features = self.text_projection(text_features)
text_features = self.text_norm(text_features)
return F.normalize(text_features, dim=-1)
def get_image_features(self, pixel_values):
"""画像特徴量を取得"""
return self.encode_image(pixel_values)
def get_text_features(self, input_ids, attention_mask):
"""テキスト特徴量を取得"""
return self.encode_text(input_ids, attention_mask)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
) -> Dict[str, torch.Tensor]:
"""順伝播"""
outputs = {}
if pixel_values is not None:
outputs['image_features'] = self.encode_image(pixel_values)
if input_ids is not None and attention_mask is not None:
outputs['text_features'] = self.encode_text(input_ids, attention_mask)
if 'image_features' in outputs and 'text_features' in outputs:
            # Cosine similarities (the features are already L2-normalized)
            similarity = torch.matmul(
                outputs['image_features'],
                outputs['text_features'].T
            )
            # exp(log(1/T)) = 1/T, i.e. the logit scale
            logit_scale = self.temperature.exp()
            outputs['logits_per_image'] = similarity * logit_scale
            outputs['logits_per_text'] = outputs['logits_per_image'].T
            # Kept under the original 'temperature' key for compatibility; this is the logit scale (1/T)
            outputs['temperature'] = logit_scale
return outputs
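
# The model returns scaled similarity logits but no training objective. Below is a minimal
# sketch of the standard symmetric CLIP/InfoNCE loss; clip_contrastive_loss is a helper
# added here for illustration (it is not part of the original checkpoint) and assumes a
# batch in which the i-th image matches the i-th text.
def clip_contrastive_loss(logits_per_image: torch.Tensor) -> torch.Tensor:
    """Symmetric cross-entropy over image->text and text->image logits."""
    batch_size = logits_per_image.size(0)
    targets = torch.arange(batch_size, device=logits_per_image.device)
    loss_i2t = F.cross_entropy(logits_per_image, targets)
    loss_t2i = F.cross_entropy(logits_per_image.T, targets)
    return (loss_i2t + loss_t2i) / 2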
# Register the custom config/model with transformers' Auto classes
from transformers import AutoConfig

AutoConfig.register("japanese-clip", JapaneseCLIPConfig)
AutoModel.register(JapaneseCLIPConfig, JapaneseCLIPModel)
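
# Minimal usage sketch (illustrative, not part of the model's public API). It assumes the
# ResNet50 backbone expects 224x224 ImageNet-normalized inputs and that the tokenizer for
# cl-tohoku/bert-base-japanese-v3 is installed (it additionally requires fugashi and
# unidic-lite).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = JapaneseCLIPConfig()
    model = JapaneseCLIPModel(config).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)

    # Dummy image batch standing in for preprocessed pixels (in a real pipeline: resize to
    # 224x224, convert to tensor, normalize with ImageNet mean/std)
    pixel_values = torch.randn(2, 3, 224, 224)
    # Example Japanese captions ("a dog running on the lawn", "a photo of a city at night")
    texts = ["犬が芝生の上を走っている", "夜の街並みの写真"]
    tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values,
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
        )

    # logits_per_image has shape (num_images, num_texts); softmax over texts gives
    # image-to-text matching probabilities
    probs = outputs["logits_per_image"].softmax(dim=-1)
    print(probs)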