import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PretrainedConfig,
)
from typing import Optional, Dict


class JapaneseCLIPConfig(PretrainedConfig):
    """Configuration class for the Japanese CLIP model."""

    model_type = "japanese-clip"

    def __init__(
        self,
        text_model_name="cl-tohoku/bert-base-japanese-v3",
        image_embed_dim=512,
        text_embed_dim=512,
        temperature=0.07,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.text_model_name = text_model_name
        self.image_embed_dim = image_embed_dim
        self.text_embed_dim = text_embed_dim
        self.temperature = temperature


class JapaneseCLIPModel(PreTrainedModel):
    """Hugging Face-compatible Japanese CLIP model."""

    config_class = JapaneseCLIPConfig

    def __init__(self, config):
        super().__init__(config)

        # Guarded import so a missing torchvision dependency fails with an
        # actionable message instead of a bare ImportError at module load.
        try:
            from torchvision.models import resnet50, ResNet50_Weights
        except ImportError:
            raise ImportError(
                "torchvision is required for this model. "
                "Install it with: pip install torchvision"
            )

        # Image encoder: ResNet-50 with its classification head replaced by
        # a linear projection into the image embedding space.
        self.image_encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.image_encoder.fc = nn.Linear(
            self.image_encoder.fc.in_features,
            config.image_embed_dim,
        )

        # Text encoder: a pretrained Japanese BERT.
        self.text_encoder = AutoModel.from_pretrained(config.text_model_name)

        # Projection heads mapping both modalities into a shared embedding space.
        self.text_projection = nn.Linear(
            self.text_encoder.config.hidden_size,
            config.text_embed_dim,
        )
        self.image_projection = nn.Linear(
            config.image_embed_dim,
            config.text_embed_dim,
        )

        self.image_norm = nn.LayerNorm(config.text_embed_dim)
        self.text_norm = nn.LayerNorm(config.text_embed_dim)

        # Learnable inverse temperature (logit scale), stored in log space
        # as in the original CLIP; exp() of this scales the similarity logits.
        self.temperature = nn.Parameter(
            torch.ones([]) * np.log(1 / config.temperature)
        )

    def encode_image(self, pixel_values):
        """Encode images into L2-normalized embeddings."""
        image_features = self.image_encoder(pixel_values)
        image_features = self.image_projection(image_features)
        image_features = self.image_norm(image_features)
        return F.normalize(image_features, dim=-1)

    def encode_text(self, input_ids, attention_mask):
        """Encode text into L2-normalized embeddings."""
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Use the [CLS] token representation as the sentence embedding.
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_features = self.text_projection(text_features)
        text_features = self.text_norm(text_features)
        return F.normalize(text_features, dim=-1)

    def get_image_features(self, pixel_values):
        """Return image features (CLIP-style convenience alias)."""
        return self.encode_image(pixel_values)

    def get_text_features(self, input_ids, attention_mask):
        """Return text features (CLIP-style convenience alias)."""
        return self.encode_text(input_ids, attention_mask)

|
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Forward pass: encodes whichever modalities are provided and,
        when both are present, computes scaled similarity logits."""
        outputs = {}

        if pixel_values is not None:
            outputs['image_features'] = self.encode_image(pixel_values)

        if input_ids is not None and attention_mask is not None:
            outputs['text_features'] = self.encode_text(input_ids, attention_mask)

        if 'image_features' in outputs and 'text_features' in outputs:
            # Features are already L2-normalized, so this matmul yields a
            # cosine-similarity matrix, scaled by the learned logit scale.
            similarity = torch.matmul(
                outputs['image_features'],
                outputs['text_features'].T,
            )
            temperature = self.temperature.exp()
            outputs['logits_per_image'] = similarity * temperature
            outputs['logits_per_text'] = outputs['logits_per_image'].T
            outputs['temperature'] = temperature

        return outputs

# Register the custom classes so that AutoConfig / AutoModel can
# resolve the "japanese-clip" model type.
AutoConfig.register("japanese-clip", JapaneseCLIPConfig)
AutoModel.register(JapaneseCLIPConfig, JapaneseCLIPModel)
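
# --- Minimal usage sketch (not from the original source) ---
# Assumptions: the tokenizer matching config.text_model_name is available
# (cl-tohoku/bert-base-japanese-v3 additionally requires fugashi/unidic-lite),
# and pixel_values are 224x224 RGB batches as expected by ResNet-50; random
# tensors stand in for real, ImageNet-normalized images here.
if __name__ == "__main__":
    config = JapaneseCLIPConfig()
    model = JapaneseCLIPModel(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)
    texts = ["猫の写真", "犬の写真"]  # "a photo of a cat", "a photo of a dog"
    batch = tokenizer(texts, padding=True, return_tensors="pt")
    pixel_values = torch.randn(2, 3, 224, 224)  # stand-in for real images

    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values,
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    # logits_per_image has shape (num_images, num_texts); softmax over the
    # text axis gives each image's distribution over the candidate captions.
    print(outputs["logits_per_image"].softmax(dim=-1))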