from __future__ import annotations import pathlib import random import sys from typing import Any import cv2 import numpy as np import PIL.Image import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T import tqdm.auto from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel from huggingface_hub import hf_hub_download from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel repo_dir = pathlib.Path(__file__).parent submodule_dir = repo_dir / 'ELITE' sys.path.insert(0, submodule_dir.as_posix()) from train_local import (Mapper, MapperLocal, inj_forward_crossattention, inj_forward_text, th2image, value_local_list) def get_tensor_clip(normalize=True, toTensor=True): transform_list = [] if toTensor: transform_list += [T.ToTensor()] if normalize: transform_list += [ T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) ] return T.Compose(transform_list) def process(image: np.ndarray, size: int = 512) -> torch.Tensor: image = cv2.resize(image, (size, size), interpolation=cv2.INTER_CUBIC) image = np.array(image).astype(np.float32) image = image / 127.5 - 1.0 return torch.from_numpy(image).permute(2, 0, 1) class Model: def __init__(self): self.device = torch.device( 'cuda:0' if torch.cuda.is_available() else 'cpu') (self.vae, self.unet, self.text_encoder, self.tokenizer, self.image_encoder, self.mapper, self.mapper_local, self.scheduler) = self.load_model() def download_mappers(self) -> tuple[str, str]: global_mapper_path = hf_hub_download('ELITE-library/ELITE', 'global_mapper.pt', subfolder='checkpoints', repo_type='model') local_mapper_path = hf_hub_download('ELITE-library/ELITE', 'local_mapper.pt', subfolder='checkpoints', repo_type='model') return global_mapper_path, local_mapper_path def load_model( self, scheduler_type=LMSDiscreteScheduler ) -> tuple[UNet2DConditionModel, CLIPTextModel, CLIPTokenizer, AutoencoderKL, CLIPVisionModel, Mapper, MapperLocal, LMSDiscreteScheduler, ]: diffusion_model_id = 'CompVis/stable-diffusion-v1-4' vae = AutoencoderKL.from_pretrained( diffusion_model_id, subfolder='vae', torch_dtype=torch.float16, ) tokenizer = CLIPTokenizer.from_pretrained( 'openai/clip-vit-large-patch14', torch_dtype=torch.float16, ) text_encoder = CLIPTextModel.from_pretrained( 'openai/clip-vit-large-patch14', torch_dtype=torch.float16, ) image_encoder = CLIPVisionModel.from_pretrained( 'openai/clip-vit-large-patch14', torch_dtype=torch.float16, ) # Load models and create wrapper for stable diffusion for _module in text_encoder.modules(): if _module.__class__.__name__ == 'CLIPTextTransformer': _module.__class__.__call__ = inj_forward_text unet = UNet2DConditionModel.from_pretrained( diffusion_model_id, subfolder='unet', torch_dtype=torch.float16, ) inj_forward_crossattention mapper = Mapper(input_dim=1024, output_dim=768) mapper_local = MapperLocal(input_dim=1024, output_dim=768) for _name, _module in unet.named_modules(): if _module.__class__.__name__ == 'CrossAttention': if 'attn1' in _name: continue _module.__class__.__call__ = inj_forward_crossattention shape = _module.to_k.weight.shape to_k_global = nn.Linear(shape[1], shape[0], bias=False) mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global) shape = _module.to_v.weight.shape to_v_global = nn.Linear(shape[1], shape[0], bias=False) mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global) to_v_local = nn.Linear(shape[1], shape[0], bias=False) mapper_local.add_module(f'{_name.replace(".", "_")}_to_v', to_v_local) to_k_local = nn.Linear(shape[1], shape[0], bias=False) mapper_local.add_module(f'{_name.replace(".", "_")}_to_k', to_k_local) global_mapper_path, local_mapper_path = self.download_mappers() mapper.load_state_dict( torch.load(global_mapper_path, map_location='cpu')) mapper.half() mapper_local.load_state_dict( torch.load(local_mapper_path, map_location='cpu')) mapper_local.half() for _name, _module in unet.named_modules(): if 'attn1' in _name: continue if _module.__class__.__name__ == 'CrossAttention': _module.add_module( 'to_k_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_k')) _module.add_module( 'to_v_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_v')) _module.add_module( 'to_v_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_v')) _module.add_module( 'to_k_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_k')) vae.eval().to(self.device) unet.eval().to(self.device) text_encoder.eval().to(self.device) image_encoder.eval().to(self.device) mapper.eval().to(self.device) mapper_local.eval().to(self.device) scheduler = scheduler_type( beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear', num_train_timesteps=1000, ) return (vae, unet, text_encoder, tokenizer, image_encoder, mapper, mapper_local, scheduler) def prepare_data(self, image: PIL.Image.Image, mask: PIL.Image.Image, text: str, placeholder_string: str = 'S') -> dict[str, Any]: data: dict[str, Any] = {} data['text'] = text placeholder_index = 0 words = text.strip().split(' ') for idx, word in enumerate(words): if word == placeholder_string: placeholder_index = idx + 1 data['index'] = torch.tensor(placeholder_index) data['input_ids'] = self.tokenizer( text, padding='max_length', truncation=True, max_length=self.tokenizer.model_max_length, return_tensors='pt', ).input_ids[0] image = image.convert('RGB') mask = mask.convert('RGB') mask = np.array(mask) / 255.0 image_np = np.array(image) object_tensor = image_np * mask data['pixel_values'] = process(image_np) ref_object_tensor = PIL.Image.fromarray( object_tensor.astype('uint8')).resize( (224, 224), resample=PIL.Image.Resampling.BICUBIC) ref_image_tenser = PIL.Image.fromarray( image_np.astype('uint8')).resize( (224, 224), resample=PIL.Image.Resampling.BICUBIC) data['pixel_values_obj'] = get_tensor_clip()(ref_object_tensor) data['pixel_values_clip'] = get_tensor_clip()(ref_image_tenser) ref_seg_tensor = PIL.Image.fromarray(mask.astype('uint8') * 255) ref_seg_tensor = get_tensor_clip(normalize=False)(ref_seg_tensor) data['pixel_values_seg'] = F.interpolate(ref_seg_tensor.unsqueeze(0), size=(128, 128), mode='nearest').squeeze(0) device = torch.device('cuda:0') data['pixel_values'] = data['pixel_values'].to(device) data['pixel_values_clip'] = data['pixel_values_clip'].to(device).half() data['pixel_values_obj'] = data['pixel_values_obj'].to(device).half() data['pixel_values_seg'] = data['pixel_values_seg'].to(device).half() data['input_ids'] = data['input_ids'].to(device) data['index'] = data['index'].to(device).long() for key, value in list(data.items()): if isinstance(value, torch.Tensor): data[key] = value.unsqueeze(0) return data @torch.inference_mode() def run( self, image: dict[str, PIL.Image.Image], text: str, seed: int, guidance_scale: float, lambda_: float, num_steps: int, ) -> PIL.Image.Image: data = self.prepare_data(image['image'], image['mask'], text) uncond_input = self.tokenizer( [''] * data['pixel_values'].shape[0], padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt', ) uncond_embeddings = self.text_encoder( {'input_ids': uncond_input.input_ids.to(self.device)})[0] if seed == -1: seed = random.randint(0, 1000000) generator = torch.Generator().manual_seed(seed) latents = torch.randn( (data['pixel_values'].shape[0], self.unet.in_channels, 64, 64), generator=generator, ) latents = latents.to(data['pixel_values_clip']) self.scheduler.set_timesteps(num_steps) latents = latents * self.scheduler.init_noise_sigma placeholder_idx = data['index'] image = F.interpolate(data['pixel_values_clip'], (224, 224), mode='bilinear') image_features = self.image_encoder(image, output_hidden_states=True) image_embeddings = [ image_features[0], image_features[2][4], image_features[2][8], image_features[2][12], image_features[2][16], ] image_embeddings = [emb.detach() for emb in image_embeddings] inj_embedding = self.mapper(image_embeddings) inj_embedding = inj_embedding[:, 0:1, :] encoder_hidden_states = self.text_encoder({ 'input_ids': data['input_ids'], 'inj_embedding': inj_embedding, 'inj_index': placeholder_idx, })[0] image_obj = F.interpolate(data['pixel_values_obj'], (224, 224), mode='bilinear') image_features_obj = self.image_encoder(image_obj, output_hidden_states=True) image_embeddings_obj = [ image_features_obj[0], image_features_obj[2][4], image_features_obj[2][8], image_features_obj[2][12], image_features_obj[2][16], ] image_embeddings_obj = [emb.detach() for emb in image_embeddings_obj] inj_embedding_local = self.mapper_local(image_embeddings_obj) mask = F.interpolate(data['pixel_values_seg'], (16, 16), mode='nearest') mask = mask[:, 0].reshape(mask.shape[0], -1, 1) inj_embedding_local = inj_embedding_local * mask for t in tqdm.auto.tqdm(self.scheduler.timesteps): latent_model_input = self.scheduler.scale_model_input(latents, t) noise_pred_text = self.unet(latent_model_input, t, encoder_hidden_states={ 'CONTEXT_TENSOR': encoder_hidden_states, 'LOCAL': inj_embedding_local, 'LOCAL_INDEX': placeholder_idx.detach(), 'LAMBDA': lambda_ }).sample value_local_list.clear() latent_model_input = self.scheduler.scale_model_input(latents, t) noise_pred_uncond = self.unet(latent_model_input, t, encoder_hidden_states={ 'CONTEXT_TENSOR': uncond_embeddings, }).sample value_local_list.clear() noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents).prev_sample _latents = 1 / 0.18215 * latents.clone() images = self.vae.decode(_latents).sample return th2image(images[0])