""" Adapted from ImageReward (https://github.com/THUDM/ImageReward) """ import os import torch import torch.nn as nn from PIL import Image # from .config import cyclereward_args from blip.blip_pretrain import blip_pretrain from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from huggingface_hub import PyTorchModelHubMixin try: from torchvision.transforms import InterpolationMode BICUBIC = InterpolationMode.BICUBIC except ImportError: BICUBIC = Image.BICUBIC cyclereward_args = { 'blip_path': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth', 'vit': 'large', 'image_size': 224, 'mlp_dim': 768 } def _convert_image_to_rgb(image): return image.convert("RGB") def _transform(n_px): return Compose([ Resize(n_px, interpolation=BICUBIC), CenterCrop(n_px), _convert_image_to_rgb, ToTensor(), Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) class CycleReward(nn.Module, PyTorchModelHubMixin): def __init__(self, device='cpu', model_type='CycleReward-Combo', max_length=128, fix_rate=0.7, med_config=None, ): super().__init__() self.device = device self.model_type = model_type self.max_length = max_length self.blip = blip_pretrain( pretrained=cyclereward_args['blip_path'], med_config=med_config, image_size=cyclereward_args['image_size'], vit=cyclereward_args['vit'] ) self.preprocess = _transform(cyclereward_args['image_size']) self.mlp = MLP(cyclereward_args['mlp_dim']) for name, parms in self.blip.named_parameters(): if '_proj' in name: parms.requires_grad_(False) # fix certain ratio of layers (setting from ImageReward) self.image_layer_num = 24 if cyclereward_args['vit'] == 'large' else 12 if fix_rate > 0: text_fix_num = "layer.{}".format(int(12 * fix_rate)) image_fix_num = "blocks.{}".format(int(self.image_layer_num * fix_rate)) for name, parms in self.blip.text_encoder.named_parameters(): parms.requires_grad_(False) if text_fix_num in name: break for name, parms in self.blip.visual_encoder.named_parameters(): parms.requires_grad_(False) if image_fix_num in name: break def forward(self, batch): if 'Combo' in self.model_type: text_reward = self.text_reward(batch) image_reward = self.image_reward(batch) elif 'I2T' in self.model_type: text_reward = self.text_reward(batch) image_reward = None elif 'T2I' in self.model_type: text_reward = None image_reward = self.image_reward(batch) return text_reward, image_reward def text_reward(self, batch): images, preferred_ids, preferred_mask, rejected_ids, rejected_mask = batch["images"], batch["preferred_ids"], batch["preferred_mask"], batch["rejected_ids"], batch["rejected_mask"] images = images.to(self.device) preferred_ids = preferred_ids.to(self.device) preferred_mask = preferred_mask.to(self.device) rejected_ids = rejected_ids.to(self.device) rejected_mask = rejected_mask.to(self.device) # encode image image_embeds = self.blip.visual_encoder(images) image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(self.device) # encode preferred preferred_embeds = self.blip.text_encoder( preferred_ids, attention_mask=preferred_mask, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ).last_hidden_state preferred_embeds = preferred_embeds[:,0,:].float() # encode rejected rejected_embeds = self.blip.text_encoder( rejected_ids, attention_mask=rejected_mask, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ).last_hidden_state rejected_embeds = rejected_embeds[:,0,:].float() 
        preferred_reward = self.mlp(preferred_embeds)
        rejected_reward = self.mlp(rejected_embeds)
        reward = torch.concat((preferred_reward, rejected_reward), dim=1)

        return reward

    def image_reward(self, batch):
        prompt_ids, prompt_mask, image_preferred, image_rejected = batch["prompt_ids"], batch["prompt_mask"], batch["image_preferred"], batch["image_rejected"]

        image_preferred = image_preferred.to(self.device)
        image_rejected = image_rejected.to(self.device)
        prompt_ids = prompt_ids.view(prompt_ids.shape[0], -1).to(self.device)
        prompt_mask = prompt_mask.view(prompt_mask.shape[0], -1).to(self.device)

        # encode image
        image_embeds_preferred = self.blip.visual_encoder(image_preferred)
        image_atts_preferred = torch.ones(image_embeds_preferred.size()[:-1], dtype=torch.long).to(self.device)
        image_embeds_rejected = self.blip.visual_encoder(image_rejected)
        image_atts_rejected = torch.ones(image_embeds_rejected.size()[:-1], dtype=torch.long).to(self.device)

        # encode preferred
        preferred_embeds = self.blip.text_encoder(
            prompt_ids,
            attention_mask=prompt_mask,
            encoder_hidden_states=image_embeds_preferred,
            encoder_attention_mask=image_atts_preferred,
            return_dict=True,
        ).last_hidden_state
        preferred_embeds = preferred_embeds[:, 0, :].float()

        # encode rejected
        rejected_embeds = self.blip.text_encoder(
            prompt_ids,
            attention_mask=prompt_mask,
            encoder_hidden_states=image_embeds_rejected,
            encoder_attention_mask=image_atts_rejected,
            return_dict=True,
        ).last_hidden_state
        rejected_embeds = rejected_embeds[:, 0, :].float()

        preferred_reward = self.mlp(preferred_embeds)
        rejected_reward = self.mlp(rejected_embeds)
        reward = torch.concat((preferred_reward, rejected_reward), dim=1)

        return reward

    @torch.no_grad()
    def score(self, image, prompt):
        text_input = self.blip.tokenizer(
            prompt, padding='max_length', truncation=True,
            max_length=self.max_length, return_tensors="pt"
        ).to(self.device)

        image_embeds = self.blip.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)

        text_embeds = self.blip.text_encoder(
            text_input.input_ids,
            attention_mask=text_input.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        ).last_hidden_state
        text_embeds = text_embeds[:, 0, :].float()

        rewards = self.mlp(text_embeds)
        return rewards


class MLP(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.GELU(),
            nn.Linear(1024, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 16),
            nn.GELU(),
            nn.Linear(16, 1)
        )

        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        self.layers.apply(init_weights)

    def forward(self, input):
        return self.layers(input)
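

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only. The checkpoint identifier
# "path/to/cyclereward-checkpoint", the image file "example.jpg", and the
# prompt below are placeholders, not names confirmed by this repository.
# It assumes a trained checkpoint is loadable through
# PyTorchModelHubMixin.from_pretrained and that `score` receives an
# already-preprocessed image tensor, as the method above expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Load a trained CycleReward model (placeholder checkpoint id).
    model = CycleReward.from_pretrained("path/to/cyclereward-checkpoint")
    model.eval()

    # Preprocess one image into a (1, 3, 224, 224) tensor on the model's device.
    image = model.preprocess(Image.open("example.jpg")).unsqueeze(0).to(model.device)

    # Score an image-text pair; higher values indicate a better match.
    print(model.score(image, ["a photo of a dog playing in the snow"]))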