import io
import base64
from typing import List, Tuple

import numpy as np
import gradio as gr
from datasets import load_dataset
from transformers import AutoProcessor, AutoModel
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load example dataset (DALL·E 3 vs. SD 1.5 preference pairs)
dataset = load_dataset("xzuyn/dalle-3_vs_sd-v1-5_dpo", num_proc=4)

# CLIP processor and the PickScore preference model
processor_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
model_name = "yuvalkirstain/PickScore_v1"
processor = AutoProcessor.from_pretrained(processor_name)
model = AutoModel.from_pretrained(model_name, torch_dtype=dtype).to(device)

def decode_image(image: str) -> Image.Image:
    """
    Decodes a base64 string into a PIL image.

    Args:
        image: base64-encoded image string
    Returns:
        PIL image
    """
    img_bytes = base64.b64decode(image)
    return Image.open(io.BytesIO(img_bytes))

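# Illustrative sketch (not used by the demo): the inverse of decode_image,
# assuming the dataset stores images as base64-encoded JPEG bytes. The helper
# name is hypothetical.
def encode_image(img: Image.Image) -> str:
    """Encodes a PIL image as a base64 string."""
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
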
def get_preference(img_1: Image.Image, img_2: Image.Image, caption: str) -> Image.Image:
    """
    Returns the image the model prefers for the given caption.

    Args:
        img_1: first PIL image
        img_2: second PIL image
        caption: caption used to generate both images
    Returns:
        the preferred PIL image
    """
    imgs = [img_1, img_2]
    logits = get_logits(caption, imgs)
    preference = logits.argmax().item()
    return imgs[preference]

def sample_example() -> Tuple[Image.Image, Image.Image, Image.Image, str]:
    """
    Samples a random example from the dataset and scores it.

    Returns:
        img_1: first PIL image
        img_2: second PIL image
        preference: the preferred PIL image
        caption: caption for the example
    """
    example = dataset["train"][np.random.randint(0, len(dataset["train"]))]
    img_1 = decode_image(example["jpg_0"])
    img_2 = decode_image(example["jpg_1"])
    caption = example["caption"]
    imgs = [img_1, img_2]
    logits = get_logits(caption, imgs)
    preference = logits.argmax().item()
    return img_1, img_2, imgs[preference], caption

def get_logits(caption: str, imgs: List[Image.Image]) -> torch.Tensor:
    """
    Returns the image-text alignment logits for the caption and images.

    Args:
        caption: caption string
        imgs: list of PIL images
    Returns:
        logits_per_image: torch.Tensor of shape (num_images, num_texts)
    """
    inputs = processor(
        text=caption,
        images=imgs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    ).to(device)
    # Match the pixel values to the model dtype (float16 on GPU, float32 on CPU)
    inputs["pixel_values"] = inputs["pixel_values"].to(dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits_per_image

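# Illustrative sketch (not wired into the UI): turn the raw logits into
# preference probabilities with a softmax over the candidate images. The
# helper name is hypothetical; it only post-processes get_logits' output.
def get_preference_probs(caption: str, imgs: List[Image.Image]) -> torch.Tensor:
    """Returns, for each image, the probability that it is the preferred one."""
    logits = get_logits(caption, imgs)  # shape: (num_images, 1) for a single caption
    return torch.softmax(logits.squeeze(-1), dim=0)
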
### Description
title = r"""
<h1 align="center">Aesthetic Scorer: CLIP fine-tuned for DPO scoring</h1>
"""
description = r"""
<b>This is a demo for the paper <a href="https://arxiv.org/abs/2305.01569">Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation</a></b> <br>
How to use this demo: <br>
1. Upload two images generated from the same caption.
2. Enter the caption used to generate the images.
3. Click the "Get Preference" button to see which image scores higher on user preference according to the model. <br>
<b>OR</b> <br>
1. Click the "Random Example" button to load a random example from the <a href="https://huggingface.co/datasets/xzuyn/dalle-3_vs_sd-v1-5_dpo">DALL·E 3 vs SD 1.5 DPO dataset</a>.<br>
This Space shows how this CLIP variant can be used for DPO scoring; the resulting scores can then be used for DPO fine-tuning with these <a href="https://github.com/huggingface/diffusers/tree/main/examples/research_projects/diffusion_dpo">scripts</a>.<br>
Accuracy on the <a href="https://huggingface.co/datasets/xzuyn/dalle-3_vs_sd-v1-5_dpo">DALL·E 3 vs SD 1.5 DPO dataset</a>:<br>
<a href="https://huggingface.co/yuvalkirstain/PickScore_v1">PickScore_v1</a> - 97.3 <br>
<a href="https://huggingface.co/CIDAS/clipseg-rd64-refined">CLIPSeg</a> - 70.9 <br>
<a href="https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K">CLIP-ViT-H-14-laion2B-s32B-b79K</a> - 82.3 <br>
"""
citation = r"""
**Citation**
```bibtex
@inproceedings{Kirstain2023PickaPicAO,
  title={Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation},
  author={Yuval Kirstain and Adam Polyak and Uriel Singer and Shahbuland Matiana and Joe Penna and Omer Levy},
  year={2023}
}
```
"""
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # type="pil" ensures the handlers receive PIL images, matching their signatures
        first_image = gr.Image(height=400, width=400, label="First Image", type="pil")
        second_image = gr.Image(height=400, width=400, label="Second Image", type="pil")
    caption_box = gr.Textbox(lines=1, label="Caption")
    with gr.Row():
        image_button = gr.Button("Get Preference")
        random_example = gr.Button("Random Example")
    image_output = gr.Image(height=400, width=400, label="Preference")
    image_button.click(
        get_preference,
        inputs=[first_image, second_image, caption_box],
        outputs=image_output,
    )
    random_example.click(
        sample_example, outputs=[first_image, second_image, image_output, caption_box]
    )
    gr.Markdown(citation)

if __name__ == "__main__":
    demo.launch()