Spaces:
Upload 8 files
- apps/app_sana.py +502 -0
- apps/app_sana_4bit.py +409 -0
- apps/app_sana_4bit_compare_bf16.py +313 -0
- apps/app_sana_controlnet_hed.py +306 -0
- apps/app_sana_multithread.py +565 -0
- apps/safety_check.py +72 -0
- apps/sana_controlnet_pipeline.py +353 -0
- apps/sana_pipeline.py +304 -0
apps/app_sana.py
ADDED
@@ -0,0 +1,502 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import argparse
import os
import random
import socket
import sqlite3
import time
import uuid
from datetime import datetime

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from torchvision.utils import make_grid, save_image
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import safety_check
from app.sana_pipeline import SanaPipeline

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
        "cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
        "majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
        "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
        "disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
        "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
        "ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]

styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"
SCHEDULE_NAME = ["Flow_DPM_Solver"]
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
NUM_IMAGES_PER_PROMPT = 1
INFER_SPEED = 0


def norm_ip(img, low, high):
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))
    return img


def open_db():
    db = sqlite3.connect(COUNTER_DB)
    db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
    db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
    return db


def read_inference_count():
    with open_db() as db:
        cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
        db.commit()
    return cur.fetchone()[0]


def write_inference_count(count):
    count = max(0, int(count))
    with open_db() as db:
        db.execute(f'UPDATE counter SET value=value+{count} WHERE app="Sana"')
        db.commit()


def run_inference(num_imgs=1):
    write_inference_count(num_imgs)
    count = read_inference_count()

    return (
        f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
        f"16px; color:red; font-weight: bold;'>{count}</span>"
    )


def update_inference_count():
    count = read_inference_count()
    return (
        f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
        f"16px; color:red; font-weight: bold;'>{count}</span>"
    )


def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    if not negative:
        negative = ""
    return p.replace("{prompt}", positive), n + negative


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="config")
    parser.add_argument(
        "--model_path",
        nargs="?",
        default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
        type=str,
        help="Path to the model file (positional)",
    )
    parser.add_argument("--output", default="./", type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--image_size", default=1024, type=int)
    parser.add_argument("--cfg_scale", default=5.0, type=float)
    parser.add_argument("--pag_scale", default=2.0, type=float)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--step", default=-1, type=int)
    parser.add_argument("--custom_image_size", default=None, type=int)
    parser.add_argument("--share", action="store_true")
    parser.add_argument(
        "--shield_model_path",
        type=str,
        help="The path to shield model, we employ ShieldGemma-2B by default.",
        default="google/shieldgemma-2b",
    )

    return parser.parse_known_args()[0]


args = get_args()

if torch.cuda.is_available():
    model_path = args.model_path
    pipe = SanaPipeline(args.config)
    pipe.from_pretrained(model_path)
    pipe.register_progress_bar(gr.Progress())

    # safety checker
    safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
    safety_checker_model = AutoModelForCausalLM.from_pretrained(
        args.shield_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).to(device)


def save_image_sana(img, seed="", save_img=False):
    unique_name = f"{str(uuid.uuid4())}_{seed}.png"
    save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
    os.umask(0o000)  # file permission: 666; dir permission: 777
    os.makedirs(save_path, exist_ok=True)
    unique_name = os.path.join(save_path, unique_name)
    if save_img:
        save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))

    return unique_name


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@torch.no_grad()
@torch.inference_mode()
@spaces.GPU(enable_queue=True)
def generate(
    prompt: str = None,
    negative_prompt: str = "",
    style: str = DEFAULT_STYLE_NAME,
    use_negative_prompt: bool = False,
    num_imgs: int = 1,
    seed: int = 0,
    height: int = 1024,
    width: int = 1024,
    flow_dpms_guidance_scale: float = 5.0,
    flow_dpms_pag_guidance_scale: float = 2.0,
    flow_dpms_inference_steps: int = 20,
    randomize_seed: bool = False,
):
    global INFER_SPEED
    # seed = 823753551
    box = run_inference(num_imgs)
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)
    print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
    if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
        prompt = "A red heart."

    print(prompt)

    num_inference_steps = flow_dpms_inference_steps
    guidance_scale = flow_dpms_guidance_scale
    pag_guidance_scale = flow_dpms_pag_guidance_scale

    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    prompt, negative_prompt = apply_style(style, prompt, negative_prompt)

    pipe.progress_fn(0, desc="Sana Start")

    time_start = time.time()
    images = pipe(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        pag_guidance_scale=pag_guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_imgs,
        generator=generator,
    )

    pipe.progress_fn(1.0, desc="Sana End")
    INFER_SPEED = (time.time() - time_start) / num_imgs

    save_img = False
    if save_img:
        img = [save_image_sana(img, seed, save_img=save_img) for img in images]
        print(img)
    else:
        img = [
            Image.fromarray(
                norm_ip(img, -1, 1)
                .mul(255)
                .add_(0.5)
                .clamp_(0, 255)
                .permute(1, 2, 0)
                .to("cpu", torch.uint8)
                .numpy()
                .astype(np.uint8)
            )
            for img in images
        ]

    torch.cuda.empty_cache()

    return (
        img,
        seed,
        f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
        box,
    )


model_size = "1.6" if "1600M" in args.model_path else "0.6"
title = f"""
    <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
        <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
    </div>
"""
DESCRIPTION = f"""
        <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
        <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
        <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
        <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
        <p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
"""
if model_size == "0.6":
    DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

examples = [
    'a cyberpunk cat with a neon sign that says "Sana"',
    "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
    "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
    "portrait photo of a girl, photograph, highly detailed face, depth of field",
    'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
    "🐶 Wearing 🕶 flying on the 🌈",
    "👧 with 🌹 in the ❄️",
    "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
    "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
    "Astronaut in a jungle, cold color palette, muted colors, detailed",
    "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
]

css = """
.gradio-container{max-width: 640px !important}
h1{text-align:center}
"""
with gr.Blocks(css=css, title="Sana") as demo:
    gr.Markdown(title)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    info_box = gr.Markdown(
        value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
    )
    demo.load(fn=update_inference_count, outputs=info_box)  # update the value when re-loading the page
    # with gr.Row(equal_height=False):
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", show_label=False, columns=NUM_IMAGES_PER_PROMPT, format="png")
        speed_box = gr.Markdown(
            value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
        )
        with gr.Accordion("Advanced options", open=False):
            with gr.Group():
                with gr.Row(visible=True):
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=args.image_size,
                    )
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=args.image_size,
                    )
                with gr.Row():
                    flow_dpms_inference_steps = gr.Slider(
                        label="Sampling steps",
                        minimum=5,
                        maximum=40,
                        step=1,
                        value=20,
                    )
                    flow_dpms_guidance_scale = gr.Slider(
                        label="CFG Guidance scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=4.5,
                    )
                    flow_dpms_pag_guidance_scale = gr.Slider(
                        label="PAG Guidance scale",
                        minimum=1,
                        maximum=4,
                        step=0.5,
                        value=1.0,
                    )
            with gr.Row():
                use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
                negative_prompt = gr.Text(
                    label="Negative prompt",
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                    visible=True,
                )
            style_selection = gr.Radio(
                show_label=True,
                container=True,
                interactive=True,
                choices=STYLE_NAMES,
                value=DEFAULT_STYLE_NAME,
                label="Image Style",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row(visible=True):
                schedule = gr.Radio(
                    show_label=True,
                    container=True,
                    interactive=True,
                    choices=SCHEDULE_NAME,
                    value=DEFAULT_SCHEDULE_NAME,
                    label="Sampler Schedule",
                    visible=True,
                )
                num_imgs = gr.Slider(
                    label="Num Images",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=1,
                )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )

    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            style_selection,
            use_negative_prompt,
            num_imgs,
            seed,
            height,
            width,
            flow_dpms_guidance_scale,
            flow_dpms_pag_guidance_scale,
            flow_dpms_inference_steps,
            randomize_seed,
        ],
        outputs=[result, seed, speed_box, info_box],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
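Usage sketch (not part of the uploaded files): the same SanaPipeline call chain used by app_sana.py can be driven headlessly. The config path below is a placeholder and must point at a Sana config matching the checkpoint; the call arguments mirror those in the app.

    # Minimal headless sketch, assuming the repo's SanaPipeline and a valid config/checkpoint pair.
    import torch
    from app.sana_pipeline import SanaPipeline
    from torchvision.utils import save_image

    device = torch.device("cuda:0")
    pipe = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")  # placeholder config path
    pipe.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")

    images = pipe(
        prompt='a cyberpunk cat with a neon sign that says "Sana"',
        height=1024,
        width=1024,
        guidance_scale=5.0,
        pag_guidance_scale=2.0,
        num_inference_steps=20,
        generator=torch.Generator(device=device).manual_seed(42),
    )
    # The pipeline returns tensors in [-1, 1]; save_image re-normalizes them to [0, 1] on disk.
    save_image(images, "sana_sample.png", nrow=1, normalize=True, value_range=(-1, 1))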
apps/app_sana_4bit.py
ADDED
@@ -0,0 +1,409 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import argparse
import os
import random
import time
import uuid
from datetime import datetime

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import SanaPipeline
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
from torchvision.utils import save_image

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
INFER_SPEED = 0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
        "cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
        "majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
        "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
        "disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
        "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
        "ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]

styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"
SCHEDULE_NAME = ["Flow_DPM_Solver"]
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
NUM_IMAGES_PER_PROMPT = 1


def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    if not negative:
        negative = ""
    return p.replace("{prompt}", positive), n + negative


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        nargs="?",
        default="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        type=str,
        help="Path to the model file (positional)",
    )
    parser.add_argument("--share", action="store_true")

    return parser.parse_known_args()[0]


args = get_args()

if torch.cuda.is_available():

    transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
    pipe = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        transformer=transformer,
        variant="bf16",
        torch_dtype=torch.bfloat16,
    ).to(device)

    pipe.text_encoder.to(torch.bfloat16)
    pipe.vae.to(torch.bfloat16)


def save_image_sana(img, seed="", save_img=False):
    unique_name = f"{str(uuid.uuid4())}_{seed}.png"
    save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
    os.umask(0o000)  # file permission: 666; dir permission: 777
    os.makedirs(save_path, exist_ok=True)
    unique_name = os.path.join(save_path, unique_name)
    if save_img:
        save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))

    return unique_name


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@torch.no_grad()
@torch.inference_mode()
@spaces.GPU(enable_queue=True)
def generate(
    prompt: str = None,
    negative_prompt: str = "",
    style: str = DEFAULT_STYLE_NAME,
    use_negative_prompt: bool = False,
    num_imgs: int = 1,
    seed: int = 0,
    height: int = 1024,
    width: int = 1024,
    flow_dpms_guidance_scale: float = 5.0,
    flow_dpms_inference_steps: int = 20,
    randomize_seed: bool = False,
):
    global INFER_SPEED
    # seed = 823753551
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)
    print(f"PORT: {DEMO_PORT}, model_path: {args.model_path}")

    print(prompt)

    num_inference_steps = flow_dpms_inference_steps
    guidance_scale = flow_dpms_guidance_scale

    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    prompt, negative_prompt = apply_style(style, prompt, negative_prompt)

    time_start = time.time()
    images = pipe(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_imgs,
        generator=generator,
    ).images
    INFER_SPEED = (time.time() - time_start) / num_imgs

    save_img = False
    if save_img:
        img = [save_image_sana(img, seed, save_img=save_img) for img in images]
        print(img)
    else:
        img = images

    torch.cuda.empty_cache()

    return (
        img,
        seed,
        f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
    )


model_size = "1.6" if "1600M" in args.model_path else "0.6"
title = f"""
    <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
        <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="30%" alt="logo"/>
    </div>
"""
DESCRIPTION = f"""
        <p style="font-size: 30px; font-weight: bold; text-align: center;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer (4bit version)</p>
"""
if model_size == "0.6":
    DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

examples = [
    'a cyberpunk cat with a neon sign that says "Sana"',
    "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
    "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
    "portrait photo of a girl, photograph, highly detailed face, depth of field",
    'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
    "🐶 Wearing 🕶 flying on the 🌈",
    "👧 with 🌹 in the ❄️",
    "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
    "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
    "Astronaut in a jungle, cold color palette, muted colors, detailed",
    "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
]

css = """
.gradio-container {max-width: 850px !important; height: auto !important;}
h1 {text-align: center;}
"""
theme = gr.themes.Base()
with gr.Blocks(css=css, theme=theme, title="Sana") as demo:
    gr.Markdown(title)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    # with gr.Row(equal_height=False):
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(
            label="Result",
            show_label=False,
            height=750,
            columns=NUM_IMAGES_PER_PROMPT,
            format="jpeg",
        )

        speed_box = gr.Markdown(
            value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
        )
        with gr.Accordion("Advanced options", open=False):
            with gr.Group():
                with gr.Row(visible=True):
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                with gr.Row():
                    flow_dpms_inference_steps = gr.Slider(
                        label="Sampling steps",
                        minimum=5,
                        maximum=40,
                        step=1,
                        value=20,
                    )
                    flow_dpms_guidance_scale = gr.Slider(
                        label="CFG Guidance scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=4.5,
                    )
                with gr.Row():
                    use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
                    negative_prompt = gr.Text(
                        label="Negative prompt",
                        max_lines=1,
                        placeholder="Enter a negative prompt",
                        visible=True,
                    )
            style_selection = gr.Radio(
                show_label=True,
                container=True,
                interactive=True,
                choices=STYLE_NAMES,
                value=DEFAULT_STYLE_NAME,
                label="Image Style",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row(visible=True):
                schedule = gr.Radio(
                    show_label=True,
                    container=True,
                    interactive=True,
                    choices=SCHEDULE_NAME,
                    value=DEFAULT_SCHEDULE_NAME,
                    label="Sampler Schedule",
                    visible=True,
                )
                num_imgs = gr.Slider(
                    label="Num Images",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=1,
                )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )

    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            style_selection,
            use_negative_prompt,
            num_imgs,
            seed,
            height,
            width,
            flow_dpms_guidance_scale,
            flow_dpms_inference_steps,
            randomize_seed,
        ],
        outputs=[result, seed, speed_box],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
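Usage sketch (not part of the uploaded files): the 4-bit path above boils down to loading the SVDQuant int4 transformer from Nunchaku and dropping it into the diffusers SanaPipeline, which is reproduced below with the same model IDs and call arguments as the app; a CUDA GPU is required.

    # Minimal sketch of the int4 + diffusers path used by app_sana_4bit.py.
    import torch
    from diffusers import SanaPipeline
    from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel

    transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
    pipe = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        transformer=transformer,
        variant="bf16",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe.text_encoder.to(torch.bfloat16)
    pipe.vae.to(torch.bfloat16)

    image = pipe(
        prompt="Astronaut in a jungle, cold color palette, muted colors, detailed",
        height=1024,
        width=1024,
        guidance_scale=4.5,
        num_inference_steps=20,
        generator=torch.Generator(device="cuda").manual_seed(0),
    ).images[0]
    image.save("sana_int4_sample.png")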
apps/app_sana_4bit_compare_bf16.py
ADDED
@@ -0,0 +1,313 @@
# Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
import argparse
import os
import random
import time
from datetime import datetime

import GPUtil

# import gradio last to avoid conflicts with other imports
import gradio as gr
import safety_check
import spaces
import torch
from diffusers import SanaPipeline
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_IMAGE_SIZE = 2048
MAX_SEED = 1000000000

DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024

# prompt, height, width, num_inference_steps, guidance_scale, seed
EXAMPLES = [
    [
        "🐶 Wearing 🕶 flying on the 🌈",
        1024,
        1024,
        20,
        5,
        2,
    ],
    [
        "大漠孤烟直, 长河落日圆",
        1024,
        1024,
        20,
        5,
        23,
    ],
    [
        "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, "
        "volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, "
        "art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
        1024,
        1024,
        20,
        5,
        233,
    ],
    [
        "A photo of a Eurasian lynx in a sunlit forest, with tufted ears and a spotted coat. The lynx should be "
        "sharply focused, gazing into the distance, while the background is softly blurred for depth. Use cinematic "
        "lighting with soft rays filtering through the trees, and capture the scene with a shallow depth of field "
        "for a natural, peaceful atmosphere. 8K resolution, highly detailed, photorealistic, "
        "cinematic lighting, ultra-HD.",
        1024,
        1024,
        20,
        5,
        2333,
    ],
    [
        "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. "
        "She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. "
        "She wears sunglasses and red lipstick. She walks confidently and casually. "
        "The street is damp and reflective, creating a mirror effect of the colorful lights. "
        "Many pedestrians walk about.",
        1024,
        1024,
        20,
        5,
        23333,
    ],
    [
        "Cozy bedroom with vintage wooden furniture and a large circular window covered in lush green vines, "
        "opening to a misty forest. Soft, ambient lighting highlights the bed with crumpled blankets, a bookshelf, "
        "and a desk. The atmosphere is serene and natural. 8K resolution, highly detailed, photorealistic, "
        "cinematic lighting, ultra-HD.",
        1024,
        1024,
        20,
        5,
        233333,
    ],
]


def hash_str_to_int(s: str) -> int:
    """Hash a string to an integer."""
    modulus = 10**9 + 7  # Large prime modulus
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int


def get_pipeline(
    precision: str, use_qencoder: bool = False, device: str | torch.device = "cuda", pipeline_init_kwargs: dict = {}
) -> SanaPipeline:
    if precision == "int4":
        assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
        transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")

        pipeline_init_kwargs["transformer"] = transformer
        if use_qencoder:
            raise NotImplementedError("Quantized encoder not supported for Sana for now")
    else:
        assert precision == "bf16"
    pipeline = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        variant="bf16",
        torch_dtype=torch.bfloat16,
        **pipeline_init_kwargs,
    )

    pipeline = pipeline.to(device)
    return pipeline


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p",
        "--precisions",
        type=str,
        default=["int4"],
        nargs="*",
        choices=["int4", "bf16"],
        help="Which precisions to use",
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument(
        "--shield_model_path",
        type=str,
        help="The path to shield model, we employ ShieldGemma-2B by default.",
        default="google/shieldgemma-2b",
    )
    return parser.parse_args()


args = get_args()


pipelines = []
pipeline_init_kwargs = {}
for i, precision in enumerate(args.precisions):

    pipeline = get_pipeline(
        precision=precision,
        use_qencoder=args.use_qencoder,
        device="cuda",
        pipeline_init_kwargs={**pipeline_init_kwargs},
    )
    pipelines.append(pipeline)
    if i == 0:
        pipeline_init_kwargs["vae"] = pipeline.vae
        pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder

# safety checker
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
safety_checker_model = AutoModelForCausalLM.from_pretrained(
    args.shield_model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).to(pipeline.device)


@spaces.GPU(enable_queue=True)
def generate(
    prompt: str = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 0,
    seed: int = 0,
):
    print(f"Prompt: {prompt}")
    is_unsafe_prompt = False
    if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    images, latency_strs = [], []
    for i, pipeline in enumerate(pipelines):
        progress = gr.Progress(track_tqdm=True)
        start_time = time.time()
        image = pipeline(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        end_time = time.time()
        latency = end_time - start_time
        if latency < 1:
            latency = latency * 1000
            latency_str = f"{latency:.2f}ms"
        else:
            latency_str = f"{latency:.2f}s"
        images.append(image)
        latency_strs.append(latency_str)
    if is_unsafe_prompt:
        for i in range(len(latency_strs)):
            latency_strs[i] += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()

    if args.count_use:
        if os.path.exists("use_count.txt"):
            with open("use_count.txt") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")

    return *images, *latency_strs


with open("./assets/description.html") as f:
    DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
    gpu = gpus[0]
    memory = gpu.memoryTotal / 1024
    device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
    device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'

with gr.Blocks(
    css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
    title="SVDQuant SANA-1600M Demo",
) as demo:

    def get_header_str():

        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'> {count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
        return header_str

    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row():
        image_results, latency_results = [], []
        for i, precision in enumerate(args.precisions):
            with gr.Column():
                gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
                with gr.Group():
                    image_result = gr.Image(
                        format="png",
                        image_mode="RGB",
                        label="Result",
                        show_label=False,
                        show_download_button=True,
                        interactive=False,
                    )
                    latency_result = gr.Text(label="Inference Latency", show_label=True)
                image_results.append(image_result)
                latency_results.append(latency_result)
    with gr.Row():
        prompt = gr.Text(
            label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
        )
        run_button = gr.Button("Run", scale=1)

    with gr.Row():
        seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
        randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
    with gr.Accordion("Advanced options", open=False):
        with gr.Group():
            height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024)
            width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024)
        with gr.Group():
            num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5)

    input_args = [prompt, height, width, num_inference_steps, guidance_scale, seed]

    gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate)

    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=input_args,
        outputs=[*image_results, *latency_results],
        api_name="run",
    )
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
    ).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
apps/app_sana_controlnet_hed.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import socket
|
| 6 |
+
import tempfile
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
from app import safety_check
|
| 16 |
+
from app.sana_controlnet_pipeline import SanaControlNetPipeline
|
| 17 |
+
|
| 18 |
+
STYLES = {
|
| 19 |
+
"None": "{prompt}",
|
| 20 |
+
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 21 |
+
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
|
| 22 |
+
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
|
| 23 |
+
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 24 |
+
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 25 |
+
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
|
| 26 |
+
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
| 27 |
+
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
| 28 |
+
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 29 |
+
}
|
| 30 |
+
DEFAULT_STYLE_NAME = "None"
|
| 31 |
+
STYLE_NAMES = list(STYLES.keys())
|
| 32 |
+
|
| 33 |
+
MAX_SEED = 1000000000
|
| 34 |
+
DEFAULT_SKETCH_GUIDANCE = 0.28
|
| 35 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
| 36 |
+
|
| 37 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 38 |
+
|
| 39 |
+
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_args():
|
| 43 |
+
parser = argparse.ArgumentParser()
|
| 44 |
+
parser.add_argument("--config", type=str, help="config")
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--model_path",
|
| 47 |
+
nargs="?",
|
| 48 |
+
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
|
| 49 |
+
type=str,
|
| 50 |
+
help="Path to the model file (positional)",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument("--output", default="./", type=str)
|
| 53 |
+
parser.add_argument("--bs", default=1, type=int)
|
| 54 |
+
parser.add_argument("--image_size", default=1024, type=int)
|
| 55 |
+
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
| 56 |
+
parser.add_argument("--pag_scale", default=2.0, type=float)
|
| 57 |
+
parser.add_argument("--seed", default=42, type=int)
|
| 58 |
+
parser.add_argument("--step", default=-1, type=int)
|
| 59 |
+
parser.add_argument("--custom_image_size", default=None, type=int)
|
| 60 |
+
parser.add_argument("--share", action="store_true")
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--shield_model_path",
|
| 63 |
+
type=str,
|
| 64 |
+
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
| 65 |
+
default="google/shieldgemma-2b",
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
return parser.parse_known_args()[0]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
args = get_args()
|
| 72 |
+
|
| 73 |
+
if torch.cuda.is_available():
|
| 74 |
+
model_path = args.model_path
|
| 75 |
+
pipe = SanaControlNetPipeline(args.config)
|
| 76 |
+
pipe.from_pretrained(model_path)
|
| 77 |
+
pipe.register_progress_bar(gr.Progress())
|
| 78 |
+
|
| 79 |
+
# safety checker
|
| 80 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 81 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 82 |
+
args.shield_model_path,
|
| 83 |
+
device_map="auto",
|
| 84 |
+
torch_dtype=torch.bfloat16,
|
| 85 |
+
).to(device)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def save_image(img):
|
| 89 |
+
if isinstance(img, dict):
|
| 90 |
+
img = img["composite"]
|
| 91 |
+
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 92 |
+
img.save(temp_file.name)
|
| 93 |
+
return temp_file.name
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def norm_ip(img, low, high):
|
| 97 |
+
img.clamp_(min=low, max=high)
|
| 98 |
+
img.sub_(low).div_(max(high - low, 1e-5))
|
| 99 |
+
return img
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
@torch.inference_mode()
|
| 104 |
+
def run(
|
| 105 |
+
image,
|
| 106 |
+
prompt: str,
|
| 107 |
+
prompt_template: str,
|
| 108 |
+
sketch_thickness: int,
|
| 109 |
+
guidance_scale: float,
|
| 110 |
+
inference_steps: int,
|
| 111 |
+
seed: int,
|
| 112 |
+
blend_alpha: float,
|
| 113 |
+
) -> tuple[Image, str]:
|
| 114 |
+
|
| 115 |
+
print(f"Prompt: {prompt}")
|
| 116 |
+
image_numpy = np.array(image["composite"].convert("RGB"))
|
| 117 |
+
|
| 118 |
+
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
|
| 119 |
+
return blank_image, "Please input the prompt or draw something."
|
| 120 |
+
|
| 121 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
| 122 |
+
prompt = "A red heart."
|
| 123 |
+
|
| 124 |
+
prompt = prompt_template.format(prompt=prompt)
|
| 125 |
+
pipe.set_blend_alpha(blend_alpha)
|
| 126 |
+
start_time = time.time()
|
| 127 |
+
images = pipe(
|
| 128 |
+
prompt=prompt,
|
| 129 |
+
ref_image=image["composite"],
|
| 130 |
+
guidance_scale=guidance_scale,
|
| 131 |
+
num_inference_steps=inference_steps,
|
| 132 |
+
num_images_per_prompt=1,
|
| 133 |
+
sketch_thickness=sketch_thickness,
|
| 134 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
latency = time.time() - start_time
|
| 138 |
+
|
| 139 |
+
if latency < 1:
|
| 140 |
+
latency = latency * 1000
|
| 141 |
+
latency_str = f"{latency:.2f}ms"
|
| 142 |
+
else:
|
| 143 |
+
latency_str = f"{latency:.2f}s"
|
| 144 |
+
torch.cuda.empty_cache()
|
| 145 |
+
|
| 146 |
+
img = [
|
| 147 |
+
Image.fromarray(
|
| 148 |
+
norm_ip(img, -1, 1)
|
| 149 |
+
.mul(255)
|
| 150 |
+
.add_(0.5)
|
| 151 |
+
.clamp_(0, 255)
|
| 152 |
+
.permute(1, 2, 0)
|
| 153 |
+
.to("cpu", torch.uint8)
|
| 154 |
+
.numpy()
|
| 155 |
+
.astype(np.uint8)
|
| 156 |
+
)
|
| 157 |
+
for img in images
|
| 158 |
+
]
|
| 159 |
+
img = img[0]
|
| 160 |
+
return img, latency_str
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
model_size = "1.6" if "1600M" in args.model_path else "0.6"
|
| 164 |
+
title = f"""
|
| 165 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
| 166 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
| 167 |
+
</div>
|
| 168 |
+
"""
|
| 169 |
+
DESCRIPTION = f"""
|
| 170 |
+
<p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
| 171 |
+
<p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
| 172 |
+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
| 173 |
+
<p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
|
| 174 |
+
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
| 175 |
+
"""
|
| 176 |
+
if model_size == "0.6":
|
| 177 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
| 178 |
+
if not torch.cuda.is_available():
|
| 179 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title=f"Sana Sketch-to-Image Demo") as demo:
|
| 183 |
+
gr.Markdown(title)
|
| 184 |
+
gr.HTML(DESCRIPTION)
|
| 185 |
+
|
| 186 |
+
with gr.Row(elem_id="main_row"):
|
| 187 |
+
with gr.Column(elem_id="column_input"):
|
| 188 |
+
gr.Markdown("## INPUT", elem_id="input_header")
|
| 189 |
+
with gr.Group():
|
| 190 |
+
canvas = gr.Sketchpad(
|
| 191 |
+
value=blank_image,
|
| 192 |
+
height=640,
|
| 193 |
+
image_mode="RGB",
|
| 194 |
+
sources=["upload", "clipboard"],
|
| 195 |
+
type="pil",
|
| 196 |
+
label="Sketch",
|
| 197 |
+
show_label=False,
|
| 198 |
+
show_download_button=True,
|
| 199 |
+
interactive=True,
|
| 200 |
+
transforms=[],
|
| 201 |
+
canvas_size=(1024, 1024),
|
| 202 |
+
scale=1,
|
| 203 |
+
brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
|
| 204 |
+
format="png",
|
| 205 |
+
layers=False,
|
| 206 |
+
)
|
| 207 |
+
with gr.Row():
|
| 208 |
+
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
|
| 209 |
+
run_button = gr.Button("Run", scale=1, elem_id="run_button")
|
| 210 |
+
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
|
| 211 |
+
with gr.Row():
|
| 212 |
+
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
|
| 213 |
+
prompt_template = gr.Textbox(
|
| 214 |
+
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
with gr.Row():
|
| 218 |
+
sketch_thickness = gr.Slider(
|
| 219 |
+
label="Sketch Thickness",
|
| 220 |
+
minimum=1,
|
| 221 |
+
maximum=4,
|
| 222 |
+
step=1,
|
| 223 |
+
value=2,
|
| 224 |
+
)
|
| 225 |
+
with gr.Row():
|
| 226 |
+
inference_steps = gr.Slider(
|
| 227 |
+
label="Sampling steps",
|
| 228 |
+
minimum=5,
|
| 229 |
+
maximum=40,
|
| 230 |
+
step=1,
|
| 231 |
+
value=20,
|
| 232 |
+
)
|
| 233 |
+
guidance_scale = gr.Slider(
|
| 234 |
+
label="CFG Guidance scale",
|
| 235 |
+
minimum=1,
|
| 236 |
+
maximum=10,
|
| 237 |
+
step=0.1,
|
| 238 |
+
value=4.5,
|
| 239 |
+
)
|
| 240 |
+
blend_alpha = gr.Slider(
|
| 241 |
+
label="Blend Alpha",
|
| 242 |
+
minimum=0,
|
| 243 |
+
maximum=1,
|
| 244 |
+
step=0.1,
|
| 245 |
+
value=0,
|
| 246 |
+
)
|
| 247 |
+
with gr.Row():
|
| 248 |
+
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
|
| 249 |
+
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
|
| 250 |
+
|
| 251 |
+
with gr.Column(elem_id="column_output"):
|
| 252 |
+
gr.Markdown("## OUTPUT", elem_id="output_header")
|
| 253 |
+
with gr.Group():
|
| 254 |
+
result = gr.Image(
|
| 255 |
+
format="png",
|
| 256 |
+
height=640,
|
| 257 |
+
image_mode="RGB",
|
| 258 |
+
type="pil",
|
| 259 |
+
label="Result",
|
| 260 |
+
show_label=False,
|
| 261 |
+
show_download_button=True,
|
| 262 |
+
interactive=False,
|
| 263 |
+
elem_id="output_image",
|
| 264 |
+
)
|
| 265 |
+
latency_result = gr.Text(label="Inference Latency", show_label=True)
|
| 266 |
+
|
| 267 |
+
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
|
| 268 |
+
gr.Markdown("### Instructions")
|
| 269 |
+
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
|
| 270 |
+
gr.Markdown("**2**. Start sketching or upload a reference image")
|
| 271 |
+
gr.Markdown("**3**. Change the image style using a style template")
|
| 272 |
+
gr.Markdown("**4**. Try different seeds to generate different results")
|
| 273 |
+
|
| 274 |
+
run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
|
| 275 |
+
run_outputs = [result, latency_result]
|
| 276 |
+
|
| 277 |
+
randomize_seed.click(
|
| 278 |
+
lambda: random.randint(0, MAX_SEED),
|
| 279 |
+
inputs=[],
|
| 280 |
+
outputs=seed,
|
| 281 |
+
api_name=False,
|
| 282 |
+
queue=False,
|
| 283 |
+
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
| 284 |
+
|
| 285 |
+
style.change(
|
| 286 |
+
lambda x: STYLES[x],
|
| 287 |
+
inputs=[style],
|
| 288 |
+
outputs=[prompt_template],
|
| 289 |
+
api_name=False,
|
| 290 |
+
queue=False,
|
| 291 |
+
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
| 292 |
+
gr.on(
|
| 293 |
+
triggers=[prompt.submit, run_button.click, canvas.change],
|
| 294 |
+
fn=run,
|
| 295 |
+
inputs=run_inputs,
|
| 296 |
+
outputs=run_outputs,
|
| 297 |
+
api_name=False,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
|
| 301 |
+
download_result.click(fn=save_image, inputs=result, outputs=download_result)
|
| 302 |
+
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if __name__ == "__main__":
|
| 306 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
|
apps/app_sana_multithread.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
#
|
| 16 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import uuid
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
|
| 25 |
+
import gradio as gr
|
| 26 |
+
import numpy as np
|
| 27 |
+
import spaces
|
| 28 |
+
import torch
|
| 29 |
+
from diffusers import FluxPipeline
|
| 30 |
+
from PIL import Image
|
| 31 |
+
from torchvision.utils import make_grid, save_image
|
| 32 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 33 |
+
|
| 34 |
+
from app import safety_check
|
| 35 |
+
from app.sana_pipeline import SanaPipeline
|
| 36 |
+
|
| 37 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 38 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 39 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
| 40 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 41 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 42 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
| 43 |
+
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
| 44 |
+
|
| 45 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 46 |
+
|
| 47 |
+
style_list = [
|
| 48 |
+
{
|
| 49 |
+
"name": "(No style)",
|
| 50 |
+
"prompt": "{prompt}",
|
| 51 |
+
"negative_prompt": "",
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"name": "Cinematic",
|
| 55 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
| 56 |
+
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 57 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"name": "Photographic",
|
| 61 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 62 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"name": "Anime",
|
| 66 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
| 67 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"name": "Manga",
|
| 71 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 72 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"name": "Digital Art",
|
| 76 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 77 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"name": "Pixel art",
|
| 81 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
| 82 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"name": "Fantasy art",
|
| 86 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
| 87 |
+
"majestic, magical, fantasy art, cover art, dreamy",
|
| 88 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
| 89 |
+
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
| 90 |
+
"disfigured, sloppy, duplicate, mutated, black and white",
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"name": "Neonpunk",
|
| 94 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
| 95 |
+
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
| 96 |
+
"ultra detailed, intricate, professional",
|
| 97 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"name": "3D Model",
|
| 101 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
| 102 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
| 103 |
+
},
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
| 107 |
+
STYLE_NAMES = list(styles.keys())
|
| 108 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
| 109 |
+
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
| 110 |
+
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
| 111 |
+
NUM_IMAGES_PER_PROMPT = 1
|
| 112 |
+
TEST_TIMES = 0
|
| 113 |
+
FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def set_env(seed=0):
|
| 117 |
+
torch.manual_seed(seed)
|
| 118 |
+
torch.set_grad_enabled(False)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def read_inference_count():
|
| 122 |
+
global TEST_TIMES
|
| 123 |
+
try:
|
| 124 |
+
with open(FILENAME) as f:
|
| 125 |
+
count = int(f.read().strip())
|
| 126 |
+
except FileNotFoundError:
|
| 127 |
+
count = 0
|
| 128 |
+
TEST_TIMES = count
|
| 129 |
+
|
| 130 |
+
return count
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def write_inference_count(count):
|
| 134 |
+
with open(FILENAME, "w") as f:
|
| 135 |
+
f.write(str(count))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def run_inference(num_imgs=1):
|
| 139 |
+
TEST_TIMES = read_inference_count()
|
| 140 |
+
TEST_TIMES += int(num_imgs)
|
| 141 |
+
write_inference_count(TEST_TIMES)
|
| 142 |
+
|
| 143 |
+
return (
|
| 144 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
| 145 |
+
f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def update_inference_count():
|
| 150 |
+
count = read_inference_count()
|
| 151 |
+
return (
|
| 152 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
| 153 |
+
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
| 158 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
| 159 |
+
if not negative:
|
| 160 |
+
negative = ""
|
| 161 |
+
return p.replace("{prompt}", positive), n + negative
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def get_args():
|
| 165 |
+
parser = argparse.ArgumentParser()
|
| 166 |
+
parser.add_argument("--config", type=str, help="config")
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--model_path",
|
| 169 |
+
nargs="?",
|
| 170 |
+
default="output/Sana_D20/SANA.pth",
|
| 171 |
+
type=str,
|
| 172 |
+
help="Path to the model file (positional)",
|
| 173 |
+
)
|
| 174 |
+
parser.add_argument("--output", default="./", type=str)
|
| 175 |
+
parser.add_argument("--bs", default=1, type=int)
|
| 176 |
+
parser.add_argument("--image_size", default=1024, type=int)
|
| 177 |
+
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
| 178 |
+
parser.add_argument("--pag_scale", default=2.0, type=float)
|
| 179 |
+
parser.add_argument("--seed", default=42, type=int)
|
| 180 |
+
parser.add_argument("--step", default=-1, type=int)
|
| 181 |
+
parser.add_argument("--custom_image_size", default=None, type=int)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--shield_model_path",
|
| 184 |
+
type=str,
|
| 185 |
+
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
| 186 |
+
default="google/shieldgemma-2b",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return parser.parse_args()
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
args = get_args()
|
| 193 |
+
|
| 194 |
+
if torch.cuda.is_available():
|
| 195 |
+
weight_dtype = torch.float16
|
| 196 |
+
model_path = args.model_path
|
| 197 |
+
pipe = SanaPipeline(args.config)
|
| 198 |
+
pipe.from_pretrained(model_path)
|
| 199 |
+
pipe.register_progress_bar(gr.Progress())
|
| 200 |
+
|
| 201 |
+
repo_name = "black-forest-labs/FLUX.1-dev"
|
| 202 |
+
pipe2 = FluxPipeline.from_pretrained(repo_name, torch_dtype=torch.float16).to("cuda")
|
| 203 |
+
|
| 204 |
+
# safety checker
|
| 205 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
| 206 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
| 207 |
+
args.shield_model_path,
|
| 208 |
+
device_map="auto",
|
| 209 |
+
torch_dtype=torch.bfloat16,
|
| 210 |
+
).to(device)
|
| 211 |
+
|
| 212 |
+
set_env(42)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def save_image_sana(img, seed="", save_img=False):
|
| 216 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
| 217 |
+
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
| 218 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
| 219 |
+
os.makedirs(save_path, exist_ok=True)
|
| 220 |
+
unique_name = os.path.join(save_path, unique_name)
|
| 221 |
+
if save_img:
|
| 222 |
+
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
| 223 |
+
|
| 224 |
+
return unique_name
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 228 |
+
if randomize_seed:
|
| 229 |
+
seed = random.randint(0, MAX_SEED)
|
| 230 |
+
return seed
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@spaces.GPU(enable_queue=True)
|
| 234 |
+
async def generate_2(
|
| 235 |
+
prompt: str = None,
|
| 236 |
+
negative_prompt: str = "",
|
| 237 |
+
style: str = DEFAULT_STYLE_NAME,
|
| 238 |
+
use_negative_prompt: bool = False,
|
| 239 |
+
num_imgs: int = 1,
|
| 240 |
+
seed: int = 0,
|
| 241 |
+
height: int = 1024,
|
| 242 |
+
width: int = 1024,
|
| 243 |
+
flow_dpms_guidance_scale: float = 5.0,
|
| 244 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
| 245 |
+
flow_dpms_inference_steps: int = 20,
|
| 246 |
+
randomize_seed: bool = False,
|
| 247 |
+
):
|
| 248 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 249 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 250 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
| 251 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
| 252 |
+
prompt = "A red heart."
|
| 253 |
+
|
| 254 |
+
print(prompt)
|
| 255 |
+
|
| 256 |
+
if not use_negative_prompt:
|
| 257 |
+
negative_prompt = None # type: ignore
|
| 258 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 259 |
+
|
| 260 |
+
with torch.no_grad():
|
| 261 |
+
images = pipe2(
|
| 262 |
+
prompt=prompt,
|
| 263 |
+
height=height,
|
| 264 |
+
width=width,
|
| 265 |
+
guidance_scale=3.5,
|
| 266 |
+
num_inference_steps=50,
|
| 267 |
+
num_images_per_prompt=num_imgs,
|
| 268 |
+
max_sequence_length=256,
|
| 269 |
+
generator=generator,
|
| 270 |
+
).images
|
| 271 |
+
|
| 272 |
+
save_img = False
|
| 273 |
+
img = images
|
| 274 |
+
if save_img:
|
| 275 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
| 276 |
+
print(img)
|
| 277 |
+
torch.cuda.empty_cache()
|
| 278 |
+
|
| 279 |
+
return img
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@spaces.GPU(enable_queue=True)
|
| 283 |
+
async def generate(
|
| 284 |
+
prompt: str = None,
|
| 285 |
+
negative_prompt: str = "",
|
| 286 |
+
style: str = DEFAULT_STYLE_NAME,
|
| 287 |
+
use_negative_prompt: bool = False,
|
| 288 |
+
num_imgs: int = 1,
|
| 289 |
+
seed: int = 0,
|
| 290 |
+
height: int = 1024,
|
| 291 |
+
width: int = 1024,
|
| 292 |
+
flow_dpms_guidance_scale: float = 5.0,
|
| 293 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
| 294 |
+
flow_dpms_inference_steps: int = 20,
|
| 295 |
+
randomize_seed: bool = False,
|
| 296 |
+
):
|
| 297 |
+
global TEST_TIMES
|
| 298 |
+
# seed = 823753551
|
| 299 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 300 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 301 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
|
| 302 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
| 303 |
+
prompt = "A red heart."
|
| 304 |
+
|
| 305 |
+
print(prompt)
|
| 306 |
+
|
| 307 |
+
num_inference_steps = flow_dpms_inference_steps
|
| 308 |
+
guidance_scale = flow_dpms_guidance_scale
|
| 309 |
+
pag_guidance_scale = flow_dpms_pag_guidance_scale
|
| 310 |
+
|
| 311 |
+
if not use_negative_prompt:
|
| 312 |
+
negative_prompt = None # type: ignore
|
| 313 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 314 |
+
|
| 315 |
+
pipe.progress_fn(0, desc="Sana Start")
|
| 316 |
+
|
| 317 |
+
with torch.no_grad():
|
| 318 |
+
images = pipe(
|
| 319 |
+
prompt=prompt,
|
| 320 |
+
height=height,
|
| 321 |
+
width=width,
|
| 322 |
+
negative_prompt=negative_prompt,
|
| 323 |
+
guidance_scale=guidance_scale,
|
| 324 |
+
pag_guidance_scale=pag_guidance_scale,
|
| 325 |
+
num_inference_steps=num_inference_steps,
|
| 326 |
+
num_images_per_prompt=num_imgs,
|
| 327 |
+
generator=generator,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
pipe.progress_fn(1.0, desc="Sana End")
|
| 331 |
+
|
| 332 |
+
save_img = False
|
| 333 |
+
if save_img:
|
| 334 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
| 335 |
+
print(img)
|
| 336 |
+
else:
|
| 337 |
+
if num_imgs > 1:
|
| 338 |
+
nrow = 2
|
| 339 |
+
else:
|
| 340 |
+
nrow = 1
|
| 341 |
+
img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
|
| 342 |
+
img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
| 343 |
+
img = [Image.fromarray(img.astype(np.uint8))]
|
| 344 |
+
|
| 345 |
+
torch.cuda.empty_cache()
|
| 346 |
+
|
| 347 |
+
return img
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
TEST_TIMES = read_inference_count()
|
| 351 |
+
model_size = "1.6" if "D20" in args.model_path else "0.6"
|
| 352 |
+
title = f"""
|
| 353 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
| 354 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
| 355 |
+
</div>
|
| 356 |
+
"""
|
| 357 |
+
DESCRIPTION = f"""
|
| 358 |
+
<p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
| 359 |
+
<p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
| 360 |
+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
| 361 |
+
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
|
| 362 |
+
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
| 363 |
+
"""
|
| 364 |
+
if model_size == "0.6":
|
| 365 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
| 366 |
+
if not torch.cuda.is_available():
|
| 367 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
| 368 |
+
|
| 369 |
+
examples = [
|
| 370 |
+
'a cyberpunk cat with a neon sign that says "Sana"',
|
| 371 |
+
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
| 372 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
| 373 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
| 374 |
+
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
| 375 |
+
"🐶 Wearing 🕶 flying on the 🌈",
|
| 376 |
+
# "👧 with 🌹 in the ❄️",
|
| 377 |
+
# "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
| 378 |
+
# "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
| 379 |
+
# "Astronaut in a jungle, cold color palette, muted colors, detailed",
|
| 380 |
+
# "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
| 381 |
+
]
|
| 382 |
+
|
| 383 |
+
css = """
|
| 384 |
+
.gradio-container{max-width: 1024px !important}
|
| 385 |
+
h1{text-align:center}
|
| 386 |
+
"""
|
| 387 |
+
with gr.Blocks(css=css) as demo:
|
| 388 |
+
gr.Markdown(title)
|
| 389 |
+
gr.Markdown(DESCRIPTION)
|
| 390 |
+
gr.DuplicateButton(
|
| 391 |
+
value="Duplicate Space for private use",
|
| 392 |
+
elem_id="duplicate-button",
|
| 393 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
| 394 |
+
)
|
| 395 |
+
info_box = gr.Markdown(
|
| 396 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
|
| 397 |
+
)
|
| 398 |
+
demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
|
| 399 |
+
# with gr.Row(equal_height=False):
|
| 400 |
+
with gr.Group():
|
| 401 |
+
with gr.Row():
|
| 402 |
+
prompt = gr.Text(
|
| 403 |
+
label="Prompt",
|
| 404 |
+
show_label=False,
|
| 405 |
+
max_lines=1,
|
| 406 |
+
placeholder="Enter your prompt",
|
| 407 |
+
container=False,
|
| 408 |
+
)
|
| 409 |
+
run_button = gr.Button("Run-sana", scale=0)
|
| 410 |
+
run_button2 = gr.Button("Run-flux", scale=0)
|
| 411 |
+
|
| 412 |
+
with gr.Row():
|
| 413 |
+
result = gr.Gallery(label="Result from Sana", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp")
|
| 414 |
+
result_2 = gr.Gallery(
|
| 415 |
+
label="Result from FLUX", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp"
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
with gr.Accordion("Advanced options", open=False):
|
| 419 |
+
with gr.Group():
|
| 420 |
+
with gr.Row(visible=True):
|
| 421 |
+
height = gr.Slider(
|
| 422 |
+
label="Height",
|
| 423 |
+
minimum=256,
|
| 424 |
+
maximum=MAX_IMAGE_SIZE,
|
| 425 |
+
step=32,
|
| 426 |
+
value=1024,
|
| 427 |
+
)
|
| 428 |
+
width = gr.Slider(
|
| 429 |
+
label="Width",
|
| 430 |
+
minimum=256,
|
| 431 |
+
maximum=MAX_IMAGE_SIZE,
|
| 432 |
+
step=32,
|
| 433 |
+
value=1024,
|
| 434 |
+
)
|
| 435 |
+
with gr.Row():
|
| 436 |
+
flow_dpms_inference_steps = gr.Slider(
|
| 437 |
+
label="Sampling steps",
|
| 438 |
+
minimum=5,
|
| 439 |
+
maximum=40,
|
| 440 |
+
step=1,
|
| 441 |
+
value=18,
|
| 442 |
+
)
|
| 443 |
+
flow_dpms_guidance_scale = gr.Slider(
|
| 444 |
+
label="CFG Guidance scale",
|
| 445 |
+
minimum=1,
|
| 446 |
+
maximum=10,
|
| 447 |
+
step=0.1,
|
| 448 |
+
value=5.0,
|
| 449 |
+
)
|
| 450 |
+
flow_dpms_pag_guidance_scale = gr.Slider(
|
| 451 |
+
label="PAG Guidance scale",
|
| 452 |
+
minimum=1,
|
| 453 |
+
maximum=4,
|
| 454 |
+
step=0.5,
|
| 455 |
+
value=2.0,
|
| 456 |
+
)
|
| 457 |
+
with gr.Row():
|
| 458 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
| 459 |
+
negative_prompt = gr.Text(
|
| 460 |
+
label="Negative prompt",
|
| 461 |
+
max_lines=1,
|
| 462 |
+
placeholder="Enter a negative prompt",
|
| 463 |
+
visible=True,
|
| 464 |
+
)
|
| 465 |
+
style_selection = gr.Radio(
|
| 466 |
+
show_label=True,
|
| 467 |
+
container=True,
|
| 468 |
+
interactive=True,
|
| 469 |
+
choices=STYLE_NAMES,
|
| 470 |
+
value=DEFAULT_STYLE_NAME,
|
| 471 |
+
label="Image Style",
|
| 472 |
+
)
|
| 473 |
+
seed = gr.Slider(
|
| 474 |
+
label="Seed",
|
| 475 |
+
minimum=0,
|
| 476 |
+
maximum=MAX_SEED,
|
| 477 |
+
step=1,
|
| 478 |
+
value=0,
|
| 479 |
+
)
|
| 480 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 481 |
+
with gr.Row(visible=True):
|
| 482 |
+
schedule = gr.Radio(
|
| 483 |
+
show_label=True,
|
| 484 |
+
container=True,
|
| 485 |
+
interactive=True,
|
| 486 |
+
choices=SCHEDULE_NAME,
|
| 487 |
+
value=DEFAULT_SCHEDULE_NAME,
|
| 488 |
+
label="Sampler Schedule",
|
| 489 |
+
visible=True,
|
| 490 |
+
)
|
| 491 |
+
num_imgs = gr.Slider(
|
| 492 |
+
label="Num Images",
|
| 493 |
+
minimum=1,
|
| 494 |
+
maximum=6,
|
| 495 |
+
step=1,
|
| 496 |
+
value=1,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
|
| 500 |
+
|
| 501 |
+
gr.Examples(
|
| 502 |
+
examples=examples,
|
| 503 |
+
inputs=prompt,
|
| 504 |
+
outputs=[result],
|
| 505 |
+
fn=generate,
|
| 506 |
+
cache_examples=CACHE_EXAMPLES,
|
| 507 |
+
)
|
| 508 |
+
gr.Examples(
|
| 509 |
+
examples=examples,
|
| 510 |
+
inputs=prompt,
|
| 511 |
+
outputs=[result_2],
|
| 512 |
+
fn=generate_2,
|
| 513 |
+
cache_examples=CACHE_EXAMPLES,
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
use_negative_prompt.change(
|
| 517 |
+
fn=lambda x: gr.update(visible=x),
|
| 518 |
+
inputs=use_negative_prompt,
|
| 519 |
+
outputs=negative_prompt,
|
| 520 |
+
api_name=False,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
run_button.click(
|
| 524 |
+
fn=generate,
|
| 525 |
+
inputs=[
|
| 526 |
+
prompt,
|
| 527 |
+
negative_prompt,
|
| 528 |
+
style_selection,
|
| 529 |
+
use_negative_prompt,
|
| 530 |
+
num_imgs,
|
| 531 |
+
seed,
|
| 532 |
+
height,
|
| 533 |
+
width,
|
| 534 |
+
flow_dpms_guidance_scale,
|
| 535 |
+
flow_dpms_pag_guidance_scale,
|
| 536 |
+
flow_dpms_inference_steps,
|
| 537 |
+
randomize_seed,
|
| 538 |
+
],
|
| 539 |
+
outputs=[result],
|
| 540 |
+
queue=True,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
run_button2.click(
|
| 544 |
+
fn=generate_2,
|
| 545 |
+
inputs=[
|
| 546 |
+
prompt,
|
| 547 |
+
negative_prompt,
|
| 548 |
+
style_selection,
|
| 549 |
+
use_negative_prompt,
|
| 550 |
+
num_imgs,
|
| 551 |
+
seed,
|
| 552 |
+
height,
|
| 553 |
+
width,
|
| 554 |
+
flow_dpms_guidance_scale,
|
| 555 |
+
flow_dpms_pag_guidance_scale,
|
| 556 |
+
flow_dpms_inference_steps,
|
| 557 |
+
randomize_seed,
|
| 558 |
+
],
|
| 559 |
+
outputs=[result_2],
|
| 560 |
+
queue=True,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
if __name__ == "__main__":
|
| 565 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
|
apps/safety_check.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# This file is copy from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
|
| 17 |
+
import json
|
| 18 |
+
from functools import lru_cache
|
| 19 |
+
from typing import Dict, Sequence
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import transformers
|
| 23 |
+
from torch.nn.functional import softmax
|
| 24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 25 |
+
|
| 26 |
+
__all__ = ["is_dangerous"]
|
| 27 |
+
|
| 28 |
+
__rank, __local_rank, __world_size, __device = (
|
| 29 |
+
0,
|
| 30 |
+
0,
|
| 31 |
+
1,
|
| 32 |
+
"cuda" if torch.cuda.is_available() else "cpu",
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
safety_policy = """
|
| 37 |
+
The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
|
| 42 |
+
prompt = f"""You are a policy expert trying to help determine whether a user
|
| 43 |
+
prompt is in violation of the defined safety policies.
|
| 44 |
+
|
| 45 |
+
<start_of_turn>
|
| 46 |
+
Human Question: {user_prompt.strip()}
|
| 47 |
+
<end_of_turn>
|
| 48 |
+
|
| 49 |
+
Our safety principle is defined in the below:
|
| 50 |
+
|
| 51 |
+
{safety_policy.strip()}
|
| 52 |
+
|
| 53 |
+
Does the human question violate the above principle? Your answer must start
|
| 54 |
+
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
|
| 55 |
+
correctly.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
logits = model(**inputs).logits
|
| 61 |
+
|
| 62 |
+
# Extract the logits for the Yes and No tokens
|
| 63 |
+
vocab = tokenizer.get_vocab()
|
| 64 |
+
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
|
| 65 |
+
|
| 66 |
+
# Convert these logits to a probability with softmax
|
| 67 |
+
probabilities = softmax(selected_logits, dim=0)
|
| 68 |
+
|
| 69 |
+
# Return probability of 'Yes'
|
| 70 |
+
score = probabilities[0].item()
|
| 71 |
+
|
| 72 |
+
return score > threshold
|
apps/sana_controlnet_pipeline.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
import warnings
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import pyrallis
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from PIL import Image
|
| 25 |
+
|
| 26 |
+
warnings.filterwarnings("ignore") # ignore warning
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
from diffusion import DPMS, FlowEuler
|
| 30 |
+
from diffusion.data.datasets.utils import (
|
| 31 |
+
ASPECT_RATIO_512_TEST,
|
| 32 |
+
ASPECT_RATIO_1024_TEST,
|
| 33 |
+
ASPECT_RATIO_2048_TEST,
|
| 34 |
+
ASPECT_RATIO_4096_TEST,
|
| 35 |
+
)
|
| 36 |
+
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode
|
| 37 |
+
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
|
| 38 |
+
from diffusion.utils.config import SanaConfig, model_init_config
|
| 39 |
+
from diffusion.utils.logger import get_root_logger
|
| 40 |
+
from tools.controlnet.utils import get_scribble_map, transform_control_signal
|
| 41 |
+
from tools.download import find_model
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
|
| 45 |
+
guidance_type = default_guidance_type
|
| 46 |
+
if not (pag_scale > 1.0 and attn_type == "linear"):
|
| 47 |
+
guidance_type = "classifier-free"
|
| 48 |
+
elif pag_scale > 1.0 and attn_type == "linear":
|
| 49 |
+
guidance_type = "classifier-free_PAG"
|
| 50 |
+
return guidance_type
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
| 54 |
+
"""Returns binned height and width."""
|
| 55 |
+
ar = float(height / width)
|
| 56 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
| 57 |
+
default_hw = ratios[closest_ratio]
|
| 58 |
+
return int(default_hw[0]), int(default_hw[1])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_ar_from_ref_image(ref_image):
|
| 62 |
+
def reduce_ratio(h, w):
|
| 63 |
+
def gcd(a, b):
|
| 64 |
+
while b:
|
| 65 |
+
                a, b = b, a % b
            return a

        divisor = gcd(h, w)
        return f"{h // divisor}:{w // divisor}"

    if isinstance(ref_image, str):
        ref_image = Image.open(ref_image)
    w, h = ref_image.size
    return reduce_ratio(h, w)


@dataclass
class SanaControlNetInference(SanaConfig):
    config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml"  # config
    model_path: str = field(
        default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
    )
    output: str = "./output"
    bs: int = 1
    image_size: int = 1024
    cfg_scale: float = 5.0
    pag_scale: float = 2.0
    seed: int = 42
    step: int = -1
    custom_image_size: Optional[int] = None
    shield_model_path: str = field(
        default="google/shieldgemma-2b",
        metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
    )


class SanaControlNetPipeline(nn.Module):
    def __init__(
        self,
        config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
    ):
        super().__init__()
        config = pyrallis.load(SanaControlNetInference, open(config))
        self.args = self.config = config

        # set some hyper-parameters
        self.image_size = self.config.model.image_size

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger = get_root_logger()
        self.logger = logger
        self.progress_fn = lambda progress, desc: None
        self.thickness = 2
        self.blend_alpha = 0.0

        self.latent_size = self.image_size // config.vae.vae_downsample_rate
        self.max_sequence_length = config.text_encoder.model_max_length
        self.flow_shift = config.scheduler.flow_shift
        guidance_type = "classifier-free_PAG"

        weight_dtype = get_weight_dtype(config.model.mixed_precision)
        self.weight_dtype = weight_dtype
        self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)

        self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
        self.vis_sampler = self.config.scheduler.vis_sampler
        logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
        self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
        logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")

        # 1. build vae and text encoder
        self.vae = self.build_vae(config.vae)
        self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)

        # 2. build Sana model
        self.model = self.build_sana_model(config).to(self.device)

        # 3. pre-compute null embedding
        with torch.no_grad():
            null_caption_token = self.tokenizer(
                "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
                0
            ]

    def build_vae(self, config):
        vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
        return vae

    def build_text_encoder(self, config):
        tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
        return tokenizer, text_encoder

    def build_sana_model(self, config):
        # model setting
        model_kwargs = model_init_config(config, latent_size=self.latent_size)
        model = build_model(
            config.model.model,
            use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
            **model_kwargs,
        )
        self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
        self.logger.info(
            f"{model.__class__.__name__}:{config.model.model},"
            f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
        )
        return model

    def from_pretrained(self, model_path):
        state_dict = find_model(model_path)
        state_dict = state_dict.get("state_dict", state_dict)
        if "pos_embed" in state_dict:
            del state_dict["pos_embed"]
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        self.model.eval().to(self.weight_dtype)

        self.logger.info("Generating sample from ckpt: %s" % model_path)
        self.logger.warning(f"Missing keys: {missing}")
        self.logger.warning(f"Unexpected keys: {unexpected}")

    def register_progress_bar(self, progress_fn=None):
        self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn

    def set_blend_alpha(self, blend_alpha):
        self.blend_alpha = blend_alpha

    @torch.inference_mode()
    def forward(
        self,
        prompt=None,
        ref_image=None,
        negative_prompt="",
        num_inference_steps=20,
        guidance_scale=5,
        pag_guidance_scale=2.5,
        num_images_per_prompt=1,
        sketch_thickness=2,
        generator=torch.Generator().manual_seed(42),
        latents=None,
    ):
        self.ori_height, self.ori_width = ref_image.height, ref_image.width
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)

        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
                0
            ]

        if prompt is None:
            prompt = [""]
        prompts = prompt if isinstance(prompt, list) else [prompt]
        samples = []

        for prompt in prompts:
            # data prepare
            prompts, hw, ar = (
                [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
                    num_images_per_prompt, 1
                ),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )

            ar = get_ar_from_ref_image(ref_image)
            prompt += f" --ar {ar}"
            for _ in range(num_images_per_prompt):
                prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(
                    prompt, self.base_ratios, device=self.device, show=False
                )
                prompts.append(prompt_clean.strip())

            self.latent_size_h, self.latent_size_w = (
                int(hw[0, 0] // self.config.vae.vae_downsample_rate),
                int(hw[0, 1] // self.config.vae.vae_downsample_rate),
            )

            with torch.no_grad():
                # prepare text feature
                if not self.config.text_encoder.chi_prompt:
                    max_length_all = self.config.text_encoder.model_max_length
                    prompts_all = prompts
                else:
                    chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
                    prompts_all = [chi_prompt + prompt for prompt in prompts]
                    num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
                    max_length_all = (
                        num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
                    )  # magic number 2: [bos], [_]

                caption_token = self.tokenizer(
                    prompts_all,
                    max_length=max_length_all,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).to(device=self.device)
                select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
                caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
                    :, :, select_index
                ].to(self.weight_dtype)
                emb_masks = caption_token.attention_mask[:, select_index]
                null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)

                n = len(prompts)
                if latents is None:
                    z = torch.randn(
                        n,
                        self.config.vae.vae_latent_dim,
                        self.latent_size_h,
                        self.latent_size_w,
                        generator=generator,
                        device=self.device,
                    )
                else:
                    z = latents.to(self.device)
                model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

                # control signal
                if isinstance(ref_image, str):
                    ref_image = cv2.imread(ref_image)
                elif isinstance(ref_image, Image.Image):
                    ref_image = np.array(ref_image)
                control_signal = get_scribble_map(
                    input_image=ref_image,
                    det="Scribble_HED",
                    detect_resolution=int(hw.min()),
                    thickness=sketch_thickness,
                )

                control_signal = transform_control_signal(control_signal, hw).to(self.device).to(self.weight_dtype)

                control_signal_latent = vae_encode(
                    self.config.vae.vae_type, self.vae, control_signal, self.config.vae.sample_posterior, self.device
                )

                model_kwargs["control_signal"] = control_signal_latent

                if self.vis_sampler == "flow_euler":
                    flow_solver = FlowEuler(
                        self.model,
                        condition=caption_embs,
                        uncondition=null_y,
                        cfg_scale=guidance_scale,
                        model_kwargs=model_kwargs,
                    )
                    sample = flow_solver.sample(
                        z,
                        steps=num_inference_steps,
                    )
                elif self.vis_sampler == "flow_dpm-solver":
                    scheduler = DPMS(
                        self.model.forward_with_dpmsolver,
                        condition=caption_embs,
                        uncondition=null_y,
                        guidance_type=self.guidance_type,
                        cfg_scale=guidance_scale,
                        model_type="flow",
                        model_kwargs=model_kwargs,
                        schedule="FLOW",
                    )
                    scheduler.register_progress_bar(self.progress_fn)
                    sample = scheduler.sample(
                        z,
                        steps=num_inference_steps,
                        order=2,
                        skip_type="time_uniform_flow",
                        method="multistep",
                        flow_shift=self.flow_shift,
                    )

            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

            if self.blend_alpha > 0:
                print(f"blend image and mask with alpha: {self.blend_alpha}")
                sample = sample * (1 - self.blend_alpha) + control_signal * self.blend_alpha

            sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)

            return sample

        return samples
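For readers of this commit, a minimal usage sketch of the ControlNet pipeline above might look like the following. It is not part of the uploaded files: it assumes the surrounding Sana repository (the `diffusion` and `tools` packages) is importable, that a CUDA device is available, and that the checkpoint path and reference image are placeholders (the checkpoint path is simply the default from `SanaControlNetInference`).

# Hypothetical usage sketch; paths and the reference image are placeholders, not verified assets.
import torch
from PIL import Image
from torchvision.utils import save_image

from apps.sana_controlnet_pipeline import SanaControlNetPipeline

pipe = SanaControlNetPipeline(config="configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
pipe.from_pretrained("output/Sana_D20/SANA.pth")  # placeholder checkpoint path

ref = Image.open("scribble_reference.png").convert("RGB")  # hypothetical reference image
image = pipe(
    prompt="a cozy wooden cabin by a lake, watercolor style",
    ref_image=ref,  # an HED scribble map is extracted from this image inside forward()
    num_inference_steps=20,
    guidance_scale=5.0,
    pag_guidance_scale=2.0,
    sketch_thickness=2,
    generator=torch.Generator(device="cuda").manual_seed(42),  # CUDA generator to match self.device
)
# The pipeline returns an image tensor in roughly [-1, 1]; torchvision's save_image is one way to write it out.
save_image(image, "controlnet_sample.png", normalize=True, value_range=(-1, 1))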
apps/sana_pipeline.py
ADDED
@@ -0,0 +1,304 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import warnings
from dataclasses import dataclass, field
from typing import Optional, Tuple

import pyrallis
import torch
import torch.nn as nn

warnings.filterwarnings("ignore")  # ignore warning


from diffusion import DPMS, FlowEuler
from diffusion.data.datasets.utils import (
    ASPECT_RATIO_512_TEST,
    ASPECT_RATIO_1024_TEST,
    ASPECT_RATIO_2048_TEST,
    ASPECT_RATIO_4096_TEST,
)
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger

# from diffusion.utils.misc import read_config
from tools.download import find_model


def guidance_type_select(default_guidance_type, pag_scale, attn_type):
    guidance_type = default_guidance_type
    if not (pag_scale > 1.0 and attn_type == "linear"):
        guidance_type = "classifier-free"
    elif pag_scale > 1.0 and attn_type == "linear":
        guidance_type = "classifier-free_PAG"
    return guidance_type


def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
    """Returns binned height and width."""
    ar = float(height / width)
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
    default_hw = ratios[closest_ratio]
    return int(default_hw[0]), int(default_hw[1])


@dataclass
class SanaInference(SanaConfig):
    config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml"  # config
    model_path: str = field(
        default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
    )
    output: str = "./output"
    bs: int = 1
    image_size: int = 1024
    cfg_scale: float = 5.0
    pag_scale: float = 2.0
    seed: int = 42
    step: int = -1
    custom_image_size: Optional[int] = None
    shield_model_path: str = field(
        default="google/shieldgemma-2b",
        metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
    )


class SanaPipeline(nn.Module):
    def __init__(
        self,
        config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
    ):
        super().__init__()
        config = pyrallis.load(SanaInference, open(config))
        self.args = self.config = config

        # set some hyper-parameters
        self.image_size = self.config.model.image_size

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger = get_root_logger()
        self.logger = logger
        self.progress_fn = lambda progress, desc: None

        self.latent_size = self.image_size // config.vae.vae_downsample_rate
        self.max_sequence_length = config.text_encoder.model_max_length
        self.flow_shift = config.scheduler.flow_shift
        guidance_type = "classifier-free_PAG"

        weight_dtype = get_weight_dtype(config.model.mixed_precision)
        self.weight_dtype = weight_dtype
        self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)

        self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
        self.vis_sampler = self.config.scheduler.vis_sampler
        logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
        self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
        logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")

        # 1. build vae and text encoder
        self.vae = self.build_vae(config.vae)
        self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)

        # 2. build Sana model
        self.model = self.build_sana_model(config).to(self.device)

        # 3. pre-compute null embedding
        with torch.no_grad():
            null_caption_token = self.tokenizer(
                "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
                0
            ]

    def build_vae(self, config):
        vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
        return vae

    def build_text_encoder(self, config):
        tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
        return tokenizer, text_encoder

    def build_sana_model(self, config):
        # model setting
        model_kwargs = model_init_config(config, latent_size=self.latent_size)
        model = build_model(
            config.model.model,
            use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
            **model_kwargs,
        )
        self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
        self.logger.info(
            f"{model.__class__.__name__}:{config.model.model},"
            f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
        )
        return model

    def from_pretrained(self, model_path):
        state_dict = find_model(model_path)
        state_dict = state_dict.get("state_dict", state_dict)
        if "pos_embed" in state_dict:
            del state_dict["pos_embed"]
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        self.model.eval().to(self.weight_dtype)

        self.logger.info("Generating sample from ckpt: %s" % model_path)
        self.logger.warning(f"Missing keys: {missing}")
        self.logger.warning(f"Unexpected keys: {unexpected}")

    def register_progress_bar(self, progress_fn=None):
        self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn

    @torch.inference_mode()
    def forward(
        self,
        prompt=None,
        height=1024,
        width=1024,
        negative_prompt="",
        num_inference_steps=20,
        guidance_scale=5,
        pag_guidance_scale=2.5,
        num_images_per_prompt=1,
        generator=torch.Generator().manual_seed(42),
        latents=None,
    ):
        self.ori_height, self.ori_width = height, width
        self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
        self.latent_size_h, self.latent_size_w = (
            self.height // self.config.vae.vae_downsample_rate,
            self.width // self.config.vae.vae_downsample_rate,
        )
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)

        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
                0
            ]

        if prompt is None:
            prompt = [""]
        prompts = prompt if isinstance(prompt, list) else [prompt]
        samples = []

        for prompt in prompts:
            # data prepare
            prompts, hw, ar = (
                [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
                    num_images_per_prompt, 1
                ),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )

            for _ in range(num_images_per_prompt):
                prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())

            with torch.no_grad():
                # prepare text feature
                if not self.config.text_encoder.chi_prompt:
                    max_length_all = self.config.text_encoder.model_max_length
                    prompts_all = prompts
                else:
                    chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
                    prompts_all = [chi_prompt + prompt for prompt in prompts]
                    num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
                    max_length_all = (
                        num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
                    )  # magic number 2: [bos], [_]

                caption_token = self.tokenizer(
                    prompts_all,
                    max_length=max_length_all,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).to(device=self.device)
                select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
                caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
                    :, :, select_index
                ].to(self.weight_dtype)
                emb_masks = caption_token.attention_mask[:, select_index]
                null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)

                n = len(prompts)
                if latents is None:
                    z = torch.randn(
                        n,
                        self.config.vae.vae_latent_dim,
                        self.latent_size_h,
                        self.latent_size_w,
                        generator=generator,
                        device=self.device,
                    )
                else:
                    z = latents.to(self.device)
                model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
                if self.vis_sampler == "flow_euler":
                    flow_solver = FlowEuler(
                        self.model,
                        condition=caption_embs,
                        uncondition=null_y,
                        cfg_scale=guidance_scale,
                        model_kwargs=model_kwargs,
                    )
                    sample = flow_solver.sample(
                        z,
                        steps=num_inference_steps,
                    )
                elif self.vis_sampler == "flow_dpm-solver":
                    scheduler = DPMS(
                        self.model,
                        condition=caption_embs,
                        uncondition=null_y,
                        guidance_type=self.guidance_type,
                        cfg_scale=guidance_scale,
                        pag_scale=pag_guidance_scale,
                        pag_applied_layers=self.config.model.pag_applied_layers,
                        model_type="flow",
                        model_kwargs=model_kwargs,
                        schedule="FLOW",
                    )
                    scheduler.register_progress_bar(self.progress_fn)
                    sample = scheduler.sample(
                        z,
                        steps=num_inference_steps,
                        order=2,
                        skip_type="time_uniform_flow",
                        method="multistep",
                        flow_shift=self.flow_shift,
                    )

            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

            sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)

            return sample

        return samples
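Likewise, a minimal usage sketch for the text-to-image SanaPipeline defined above might look like the following. This is illustrative rather than part of the commit: the checkpoint path is the placeholder default from `SanaInference`, a CUDA device is assumed so the generator matches the pipeline's device, and torchvision's `save_image` is just one way to write the returned [-1, 1] image tensor to disk.

# Hypothetical usage sketch; the checkpoint path is a placeholder default, not a verified asset.
import torch
from torchvision.utils import save_image

from apps.sana_pipeline import SanaPipeline

pipe = SanaPipeline(config="configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
pipe.from_pretrained("output/Sana_D20/SANA.pth")  # placeholder checkpoint path

image = pipe(
    prompt="a cyberpunk cat holding a neon sign that reads Sana",
    height=1024,
    width=1024,  # height/width are binned to the closest supported aspect ratio internally
    num_inference_steps=20,
    guidance_scale=5.0,
    pag_guidance_scale=2.0,
    generator=torch.Generator(device="cuda").manual_seed(42),  # CUDA generator to match self.device
)
save_image(image, "sana_sample.png", normalize=True, value_range=(-1, 1))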