"""
Adapted from ImageReward (https://github.com/THUDM/ImageReward)
"""

import os

import torch
import torch.nn as nn
from PIL import Image

from blip.blip_pretrain import blip_pretrain
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from huggingface_hub import PyTorchModelHubMixin

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


cyclereward_args = {
    'blip_path': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
    'vit': 'large',
    'image_size': 224,
    'mlp_dim': 768
}


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


class CycleReward(nn.Module, PyTorchModelHubMixin):
    def __init__(self, device='cpu',
                 model_type='CycleReward-Combo',
                 max_length=128,
                 fix_rate=0.7,
                 med_config=None,
                 ):
        super().__init__()
        self.device = device
        self.model_type = model_type
        self.max_length = max_length

        self.blip = blip_pretrain(
            pretrained=cyclereward_args['blip_path'],
            med_config=med_config,
            image_size=cyclereward_args['image_size'],
            vit=cyclereward_args['vit']
        )
        self.preprocess = _transform(cyclereward_args['image_size'])
        self.mlp = MLP(cyclereward_args['mlp_dim'])

        # Freeze the projection layers of the BLIP backbone.
        for name, params in self.blip.named_parameters():
            if '_proj' in name:
                params.requires_grad_(False)

        # Freeze the bottom `fix_rate` fraction of the text and visual encoder layers.
        self.image_layer_num = 24 if cyclereward_args['vit'] == 'large' else 12
        if fix_rate > 0:
            text_fix_num = "layer.{}".format(int(12 * fix_rate))
            image_fix_num = "blocks.{}".format(int(self.image_layer_num * fix_rate))
            for name, params in self.blip.text_encoder.named_parameters():
                params.requires_grad_(False)
                if text_fix_num in name:
                    break
            for name, params in self.blip.visual_encoder.named_parameters():
                params.requires_grad_(False)
                if image_fix_num in name:
                    break

    def forward(self, batch):
        if 'Combo' in self.model_type:
            text_reward = self.text_reward(batch)
            image_reward = self.image_reward(batch)

        elif 'I2T' in self.model_type:
            text_reward = self.text_reward(batch)
            image_reward = None

        elif 'T2I' in self.model_type:
            text_reward = None
            image_reward = self.image_reward(batch)

        else:
            raise ValueError(f"Unknown model_type: {self.model_type}")

        return text_reward, image_reward

    def text_reward(self, batch):
        # Reward for preferred vs. rejected captions of the same image.
        images, preferred_ids, preferred_mask, rejected_ids, rejected_mask = (
            batch["images"], batch["preferred_ids"], batch["preferred_mask"],
            batch["rejected_ids"], batch["rejected_mask"]
        )
        images = images.to(self.device)
        preferred_ids = preferred_ids.to(self.device)
        preferred_mask = preferred_mask.to(self.device)
        rejected_ids = rejected_ids.to(self.device)
        rejected_mask = rejected_mask.to(self.device)

        image_embeds = self.blip.visual_encoder(images)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)

        preferred_embeds = self.blip.text_encoder(
            preferred_ids,
            attention_mask=preferred_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        ).last_hidden_state
        preferred_embeds = preferred_embeds[:, 0, :].float()

        rejected_embeds = self.blip.text_encoder(
            rejected_ids,
            attention_mask=rejected_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        ).last_hidden_state
        rejected_embeds = rejected_embeds[:, 0, :].float()

        preferred_reward = self.mlp(preferred_embeds)
        rejected_reward = self.mlp(rejected_embeds)
        reward = torch.concat((preferred_reward, rejected_reward), dim=1)

        return reward

    def image_reward(self, batch):
        # Reward for preferred vs. rejected images generated from the same prompt.
        prompt_ids, prompt_mask, image_preferred, image_rejected = (
            batch["prompt_ids"], batch["prompt_mask"],
            batch["image_preferred"], batch["image_rejected"]
        )
        image_preferred = image_preferred.to(self.device)
        image_rejected = image_rejected.to(self.device)
        prompt_ids = prompt_ids.view(prompt_ids.shape[0], -1).to(self.device)
        prompt_mask = prompt_mask.view(prompt_mask.shape[0], -1).to(self.device)

        image_embeds_preferred = self.blip.visual_encoder(image_preferred)
        image_atts_preferred = torch.ones(image_embeds_preferred.size()[:-1], dtype=torch.long).to(self.device)

        image_embeds_rejected = self.blip.visual_encoder(image_rejected)
        image_atts_rejected = torch.ones(image_embeds_rejected.size()[:-1], dtype=torch.long).to(self.device)

        preferred_embeds = self.blip.text_encoder(
            prompt_ids,
            attention_mask=prompt_mask,
            encoder_hidden_states=image_embeds_preferred,
            encoder_attention_mask=image_atts_preferred,
            return_dict=True,
        ).last_hidden_state
        preferred_embeds = preferred_embeds[:, 0, :].float()

        rejected_embeds = self.blip.text_encoder(
            prompt_ids,
            attention_mask=prompt_mask,
            encoder_hidden_states=image_embeds_rejected,
            encoder_attention_mask=image_atts_rejected,
            return_dict=True,
        ).last_hidden_state
        rejected_embeds = rejected_embeds[:, 0, :].float()

        preferred_reward = self.mlp(preferred_embeds)
        rejected_reward = self.mlp(rejected_embeds)
        reward = torch.concat((preferred_reward, rejected_reward), dim=1)

        return reward

    @torch.no_grad()
    def score(self, image, prompt):
        # Score a preprocessed image tensor against a text prompt.
        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt").to(self.device)

        image_embeds = self.blip.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)

        text_embeds = self.blip.text_encoder(
            text_input.input_ids,
            attention_mask=text_input.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        ).last_hidden_state
        text_embeds = text_embeds[:, 0, :].float()

        rewards = self.mlp(text_embeds)
        return rewards


class MLP(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.GELU(),
            nn.Linear(1024, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 16),
            nn.GELU(),
            nn.Linear(16, 1)
        )

        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        self.layers.apply(init_weights)

    def forward(self, input):
        return self.layers(input)
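

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The image path and
# prompt below are illustrative assumptions; constructing CycleReward() will
# attempt to download the BLIP backbone from cyclereward_args['blip_path'].
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CycleReward(device=device, model_type="CycleReward-Combo").to(device)
    model.eval()

    # "example.jpg" is a placeholder path; any RGB image works.
    image = model.preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
    reward = model.score(image, "a photo of a cat sitting on a sofa")
    print(reward.item())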