from typing import List, Union

import torch
from PIL import Image
from transformers import (
    CLIPProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers import StableDiffusionPipeline

from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path

import os
import glob
import math

EXAMPLE_PROMPTS = [
    "<obj> swimming in a pool",
    "<obj> at a beach with a view of seashore",
    "<obj> in times square",
    "<obj> wearing sunglasses",
    "<obj> in a construction outfit",
    "<obj> playing with a ball",
    "<obj> wearing headphones",
    "<obj> oil painting ghibli inspired",
    "<obj> working on the laptop",
    "<obj> with mountains and sunset in background",
    "Painting of <obj> at a beach by artist claude monet",
    "<obj> digital painting 3d render geometric style",
    "A screaming <obj>",
    "A depressed <obj>",
    "A sleeping <obj>",
    "A sad <obj>",
    "A joyous <obj>",
    "A frowning <obj>",
    "A sculpture of <obj>",
    "<obj> near a pool",
    "<obj> at a beach with a view of seashore",
    "<obj> in a garden",
    "<obj> in grand canyon",
    "<obj> floating in ocean",
    "<obj> and an armchair",
    "A maple tree on the side of <obj>",
    "<obj> and an orange sofa",
    "<obj> with chocolate cake on it",
    "<obj> with a vase of rose flowers on it",
    "A digital illustration of <obj>",
    "Georgia O'Keeffe style <obj> painting",
    "A watercolor painting of <obj> on a beach",
]


def image_grid(_imgs, rows=None, cols=None):
    """Paste a list of same-size PIL images into a single rows x cols grid image."""

    if rows is None and cols is None:
        rows = cols = math.ceil(len(_imgs) ** 0.5)

    if rows is None:
        rows = math.ceil(len(_imgs) / cols)
    if cols is None:
        cols = math.ceil(len(_imgs) / rows)

    w, h = _imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(_imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
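
# A minimal usage sketch (the ./samples path is hypothetical; any list of
# same-size PIL.Image objects works):
#
#   imgs = [Image.open(p) for p in sorted(glob.glob("./samples/*.png"))]
#   image_grid(imgs).save("samples_grid.png")
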

def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
    """Compute CLIP text-alignment and image-alignment scores.

    `img_embeds` and `text_embeds` are (N, D) CLIP projection embeddings of the
    N generated images and their prompts; `target_img_embeds` is (M, D) for the
    M reference images of the learnt concept.
    """
    # Text alignment: cosine similarity between each generated image and its prompt.
    assert img_embeds.shape[0] == text_embeds.shape[0]
    text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
        img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
    )

    # Image alignment: cosine similarity between each generated image and the
    # mean of the normalized target-image embeddings.
    img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)

    avg_target_img_embed = (
        (target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
        .mean(dim=0)
        .unsqueeze(0)
        .repeat(img_embeds.shape[0], 1)
    )

    img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)

    return {
        "text_alignment_avg": text_img_sim.mean().item(),
        "image_alignment_avg": img_img_sim.mean().item(),
        "text_alignment_all": text_img_sim.tolist(),
        "image_alignment_all": img_img_sim.tolist(),
    }
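
# A minimal sanity-check sketch with random embeddings (768 assumed as the
# projection dimension of openai/clip-vit-large-patch14):
#
#   fake_img = torch.randn(4, 768)
#   fake_txt = torch.randn(4, 768)
#   fake_target = torch.randn(3, 768)
#   print(text_img_alignment(fake_img, fake_txt, fake_target))
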

def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
    text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
    tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
    vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
    processor = CLIPProcessor.from_pretrained(eval_clip_id)

    return text_model, tokenizer, vis_model, processor


def evaluate_pipe(
    pipe,
    target_images: List[Image.Image],
    class_token: str = "",
    learnt_token: str = "",
    guidance_scale: float = 5.0,
    seed=0,
    clip_model_sets=None,
    eval_clip_id: str = "openai/clip-vit-large-patch14",
    n_test: int = 10,
    n_step: int = 50,
):
    """Generate images for the first `n_test` EXAMPLE_PROMPTS with `pipe` and
    return CLIP text/image alignment scores against `target_images`."""
    if clip_model_sets is not None:
        text_model, tokenizer, vis_model, processor = clip_model_sets
    else:
        text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
            eval_clip_id
        )

    images = []
    img_embeds = []
    text_embeds = []
    for prompt in EXAMPLE_PROMPTS[:n_test]:
        prompt = prompt.replace("<obj>", learnt_token)
        torch.manual_seed(seed)
        with torch.autocast("cuda"):
            img = pipe(
                prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
            ).images[0]
        images.append(img)

        # Embed the generated image with the CLIP vision tower.
        inputs = processor(images=img, return_tensors="pt")
        img_embed = vis_model(**inputs).image_embeds
        img_embeds.append(img_embed)

        # For text alignment, swap the learnt token back to the class token
        # before embedding the prompt.
        prompt = prompt.replace(learnt_token, class_token)

        inputs = tokenizer([prompt], padding=True, return_tensors="pt")
        outputs = text_model(**inputs)
        text_embed = outputs.text_embeds
        text_embeds.append(text_embed)

    # Embed the reference (target) images of the concept.
    inputs = processor(images=target_images, return_tensors="pt")
    target_img_embeds = vis_model(**inputs).image_embeds

    img_embeds = torch.cat(img_embeds, dim=0)
    text_embeds = torch.cat(text_embeds, dim=0)

    return text_img_alignment(img_embeds, text_embeds, target_img_embeds)
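
# A minimal usage sketch. The checkpoint path, tokens, and image glob below are
# hypothetical; `patch_pipe` (from .lora) loads a trained LoRA/TI checkpoint
# into the pipeline, as done in `visualize_progress` below.
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#   ).to("cuda")
#   patch_pipe(pipe, "./exps/output/final_lora.safetensors")
#   targets = [Image.open(p).convert("RGB") for p in glob.glob("./data/*.jpg")]
#   scores = evaluate_pipe(
#       pipe, targets, class_token="dog", learnt_token="<s1>", n_test=5
#   )
#   print(scores["text_alignment_avg"], scores["image_alignment_avg"])
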

def visualize_progress(
    path_alls: Union[str, List[str]],
    prompt: str,
    model_id: str = "runwayml/stable-diffusion-v1-5",
    device="cuda:0",
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    unet_scale=1.0,
    text_scale=1.0,
    num_inference_steps=50,
    guidance_scale=5.0,
    offset: int = 0,
    limit: int = 10,
    seed: int = 0,
):
    """Render `prompt` once per LoRA checkpoint matched by `path_alls` and
    return the resulting images in checkpoint order."""
    imgs = []
    if isinstance(path_alls, str):
        # A glob pattern: collect matching checkpoints, oldest first.
        alls = list(set(glob.glob(path_alls)))
        alls.sort(key=os.path.getmtime)
    else:
        alls = path_alls

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    print(f"Found {len(alls)} checkpoints")
    for path in alls[offset:limit]:
        print(path)

        patch_pipe(
            pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
        )

        tune_lora_scale(pipe.unet, unet_scale)
        tune_lora_scale(pipe.text_encoder, text_scale)

        torch.manual_seed(seed)
        image = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        imgs.append(image)

    return imgs
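
# A minimal usage sketch (the glob pattern and prompt are hypothetical):
#
#   imgs = visualize_progress(
#       "./exps/output/lora_weight_e*.safetensors",
#       prompt="<s1> in times square",
#       limit=8,
#   )
#   grid = image_grid(imgs, rows=2, cols=4)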