import glob
import math
import os
from typing import List, Union

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import (
    CLIPProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path

EXAMPLE_PROMPTS = [
    "<obj> swimming in a pool",
    "<obj> at a beach with a view of seashore",
    "<obj> in times square",
    "<obj> wearing sunglasses",
    "<obj> in a construction outfit",
    "<obj> playing with a ball",
    "<obj> wearing headphones",
    "<obj> oil painting ghibli inspired",
    "<obj> working on the laptop",
    "<obj> with mountains and sunset in background",
    "Painting of <obj> at a beach by artist claude monet",
    "<obj> digital painting 3d render geometric style",
    "A screaming <obj>",
    "A depressed <obj>",
    "A sleeping <obj>",
    "A sad <obj>",
    "A joyous <obj>",
    "A frowning <obj>",
    "A sculpture of <obj>",
    "<obj> near a pool",
    "<obj> at a beach with a view of seashore",
    "<obj> in a garden",
    "<obj> in grand canyon",
    "<obj> floating in ocean",
    "<obj> and an armchair",
    "A maple tree on the side of <obj>",
    "<obj> and an orange sofa",
    "<obj> with chocolate cake on it",
    "<obj> with a vase of rose flowers on it",
    "A digital illustration of <obj>",
    "Georgia O'Keeffe style <obj> painting",
    "A watercolor painting of <obj> on a beach",
]


def image_grid(_imgs, rows=None, cols=None):
    # Tile a list of equally sized PIL images into a single rows x cols image,
    # pasting them in row-major order. Missing rows/cols are inferred from the
    # number of images.
    if rows is None and cols is None:
        rows = cols = math.ceil(len(_imgs) ** 0.5)

    if rows is None:
        rows = math.ceil(len(_imgs) / cols)
    if cols is None:
        cols = math.ceil(len(_imgs) / rows)

    w, h = _imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(_imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
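
# Usage sketch (illustrative only; the solid-color placeholders and output
# filename below are made up for the example):
#
#   _demo = [Image.new("RGB", (64, 64), c) for c in ("red", "green", "blue", "gray")]
#   image_grid(_demo, rows=2, cols=2).save("grid_demo.png")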


def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
    # evaluation inspired by the textual inversion paper
    # https://arxiv.org/abs/2208.01618

    # text alignment
    assert img_embeds.shape[0] == text_embeds.shape[0]
    text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
        img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
    )

    # image alignment
    img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)

    avg_target_img_embed = (
        (target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
        .mean(dim=0)
        .unsqueeze(0)
        .repeat(img_embeds.shape[0], 1)
    )

    img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)

    return {
        "text_alignment_avg": text_img_sim.mean().item(),
        "image_alignment_avg": img_img_sim.mean().item(),
        "text_alignment_all": text_img_sim.tolist(),
        "image_alignment_all": img_img_sim.tolist(),
    }
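
# Shape sketch (illustrative; random tensors stand in for real CLIP outputs).
# All three arguments are CLIP projection embeddings of shape (N, D); for
# openai/clip-vit-large-patch14 the projection dimension D is 768. The number
# of target embeddings may differ from the number of generated/text pairs,
# since targets are averaged before comparison:
#
#   _gen = torch.randn(4, 768)
#   _txt = torch.randn(4, 768)
#   _tgt = torch.randn(10, 768)
#   text_img_alignment(_gen, _txt, _tgt)
#   # -> {"text_alignment_avg": ..., "image_alignment_avg": ..., ...}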


def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
    text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
    tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
    vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
    processor = CLIPProcessor.from_pretrained(eval_clip_id)

    return text_model, tokenizer, vis_model, processor


def evaluate_pipe(
    pipe,
    target_images: List[Image.Image],
    class_token: str = "",
    learnt_token: str = "",
    guidance_scale: float = 5.0,
    seed=0,
    clip_model_sets=None,
    eval_clip_id: str = "openai/clip-vit-large-patch14",
    n_test: int = 10,
    n_step: int = 50,
):

    if clip_model_sets is not None:
        text_model, tokenizer, vis_model, processor = clip_model_sets
    else:
        text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
            eval_clip_id
        )

    images = []
    img_embeds = []
    text_embeds = []

    for prompt in EXAMPLE_PROMPTS[:n_test]:
        prompt = prompt.replace("<obj>", learnt_token)

        torch.manual_seed(seed)
        with torch.autocast("cuda"):
            img = pipe(
                prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
            ).images[0]
        images.append(img)

        # image
        inputs = processor(images=img, return_tensors="pt")
        img_embed = vis_model(**inputs).image_embeds
        img_embeds.append(img_embed)

        prompt = prompt.replace(learnt_token, class_token)
        # prompts
        inputs = tokenizer([prompt], padding=True, return_tensors="pt")
        outputs = text_model(**inputs)
        text_embed = outputs.text_embeds
        text_embeds.append(text_embed)

    # target images
    inputs = processor(images=target_images, return_tensors="pt")
    target_img_embeds = vis_model(**inputs).image_embeds

    img_embeds = torch.cat(img_embeds, dim=0)
    text_embeds = torch.cat(text_embeds, dim=0)

    return text_img_alignment(img_embeds, text_embeds, target_img_embeds)
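
# Usage sketch (illustrative; the checkpoint path, image folder, and tokens
# below are hypothetical). The pipeline is assumed to already carry the learnt
# concept, e.g. after patch_pipe:
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#   ).to("cuda:0")
#   patch_pipe(pipe, "./example_lora.safetensors")
#   targets = [Image.open(p).convert("RGB") for p in glob.glob("./instance_data/*.jpg")]
#   scores = evaluate_pipe(
#       pipe, targets, class_token="dog", learnt_token="<s1>", n_test=5
#   )
#   print(scores["text_alignment_avg"], scores["image_alignment_avg"])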


def visualize_progress(
    path_alls: Union[str, List[str]],
    prompt: str,
    model_id: str = "runwayml/stable-diffusion-v1-5",
    device="cuda:0",
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    unet_scale=1.0,
    text_scale=1.0,
    num_inference_steps=50,
    guidance_scale=5.0,
    offset: int = 0,
    limit: int = 10,
    seed: int = 0,
):

    imgs = []
    if isinstance(path_alls, str):
        # expand a glob pattern and order checkpoints by modification time
        alls = list(set(glob.glob(path_alls)))
        alls.sort(key=os.path.getmtime)
    else:
        alls = path_alls

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    print(f"Found {len(alls)} checkpoints")
    for path in alls[offset:limit]:
        print(path)

        patch_pipe(
            pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
        )

        tune_lora_scale(pipe.unet, unet_scale)
        tune_lora_scale(pipe.text_encoder, text_scale)

        torch.manual_seed(seed)
        image = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        imgs.append(image)

    return imgs
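
# Usage sketch (illustrative; the checkpoint glob pattern and output filename
# are hypothetical). Each matched checkpoint is rendered with the same prompt
# and seed, so tiling the results with image_grid gives a quick view of how
# the concept evolves over training:
#
#   snapshots = visualize_progress(
#       "./output_dir/step_*.safetensors",
#       prompt="a photo of <s1> in times square",
#       limit=8,
#   )
#   image_grid(snapshots, rows=2, cols=4).save("progress.png")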