removed src folder
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
- src/__init__.py +0 -0
- src/app.py +0 -554
- src/app_settings.py +0 -124
- src/backend/__init__.py +0 -0
- src/backend/annotators/canny_control.py +0 -15
- src/backend/annotators/control_interface.py +0 -12
- src/backend/annotators/depth_control.py +0 -15
- src/backend/annotators/image_control_factory.py +0 -31
- src/backend/annotators/lineart_control.py +0 -11
- src/backend/annotators/mlsd_control.py +0 -10
- src/backend/annotators/normal_control.py +0 -10
- src/backend/annotators/pose_control.py +0 -10
- src/backend/annotators/shuffle_control.py +0 -10
- src/backend/annotators/softedge_control.py +0 -10
- src/backend/api/mcp_server.py +0 -97
- src/backend/api/models/response.py +0 -16
- src/backend/api/web.py +0 -112
- src/backend/base64_image.py +0 -21
- src/backend/controlnet.py +0 -90
- src/backend/device.py +0 -23
- src/backend/gguf/gguf_diffusion.py +0 -319
- src/backend/gguf/sdcpp_types.py +0 -104
- src/backend/image_saver.py +0 -75
- src/backend/lcm_text_to_image.py +0 -577
- src/backend/lora.py +0 -136
- src/backend/models/device.py +0 -9
- src/backend/models/gen_images.py +0 -17
- src/backend/models/lcmdiffusion_setting.py +0 -76
- src/backend/models/upscale.py +0 -9
- src/backend/openvino/custom_ov_model_vae_decoder.py +0 -21
- src/backend/openvino/flux_pipeline.py +0 -36
- src/backend/openvino/ov_hc_stablediffusion_pipeline.py +0 -93
- src/backend/openvino/ovflux.py +0 -675
- src/backend/openvino/pipelines.py +0 -75
- src/backend/openvino/stable_diffusion_engine.py +0 -1817
- src/backend/pipelines/lcm.py +0 -122
- src/backend/pipelines/lcm_lora.py +0 -81
- src/backend/tiny_decoder.py +0 -32
- src/backend/upscale/aura_sr.py +0 -1004
- src/backend/upscale/aura_sr_upscale.py +0 -9
- src/backend/upscale/edsr_upscale_onnx.py +0 -37
- src/backend/upscale/tiled_upscale.py +0 -237
- src/backend/upscale/upscaler.py +0 -52
- src/constants.py +0 -25
- src/context.py +0 -85
- src/frontend/cli_interactive.py +0 -661
- src/frontend/gui/app_window.py +0 -595
- src/frontend/gui/base_widget.py +0 -199
- src/frontend/gui/image_generator_worker.py +0 -37
- src/frontend/gui/image_variations_widget.py +0 -35
src/__init__.py
DELETED
File without changes
src/app.py
DELETED
@@ -1,554 +0,0 @@
import json
from argparse import ArgumentParser

from PIL import Image

import constants
from backend.controlnet import controlnet_settings_from_dict
from backend.device import get_device_name
from backend.models.gen_images import ImageFormat
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.upscale.tiled_upscale import generate_upscaled_image
from constants import APP_VERSION, DEVICE
from frontend.webui.image_variations_ui import generate_image_variations
from models.interface_types import InterfaceType
from paths import FastStableDiffusionPaths, ensure_path
from state import get_context, get_settings
from utils import show_system_info

parser = ArgumentParser(description=f"FAST SD CPU {constants.APP_VERSION}")
parser.add_argument(
    "-s",
    "--share",
    action="store_true",
    help="Create sharable link(Web UI)",
    required=False,
)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
    "-g",
    "--gui",
    action="store_true",
    help="Start desktop GUI",
)
group.add_argument(
    "-w",
    "--webui",
    action="store_true",
    help="Start Web UI",
)
group.add_argument(
    "-a",
    "--api",
    action="store_true",
    help="Start Web API server",
)
group.add_argument(
    "-m",
    "--mcp",
    action="store_true",
    help="Start MCP(Model Context Protocol) server",
)
group.add_argument(
    "-r",
    "--realtime",
    action="store_true",
    help="Start realtime inference UI(experimental)",
)
group.add_argument(
    "-v",
    "--version",
    action="store_true",
    help="Version",
)

parser.add_argument(
    "-b",
    "--benchmark",
    action="store_true",
    help="Run inference benchmark on the selected device",
)
parser.add_argument(
    "--lcm_model_id",
    type=str,
    help="Model ID or path,Default stabilityai/sd-turbo",
    default="stabilityai/sd-turbo",
)
parser.add_argument(
    "--openvino_lcm_model_id",
    type=str,
    help="OpenVINO Model ID or path,Default rupeshs/sd-turbo-openvino",
    default="rupeshs/sd-turbo-openvino",
)
parser.add_argument(
    "--prompt",
    type=str,
    help="Describe the image you want to generate",
    default="",
)
parser.add_argument(
    "--negative_prompt",
    type=str,
    help="Describe what you want to exclude from the generation",
    default="",
)
parser.add_argument(
    "--image_height",
    type=int,
    help="Height of the image",
    default=512,
)
parser.add_argument(
    "--image_width",
    type=int,
    help="Width of the image",
    default=512,
)
parser.add_argument(
    "--inference_steps",
    type=int,
    help="Number of steps,default : 1",
    default=1,
)
parser.add_argument(
    "--guidance_scale",
    type=float,
    help="Guidance scale,default : 1.0",
    default=1.0,
)

parser.add_argument(
    "--number_of_images",
    type=int,
    help="Number of images to generate ,default : 1",
    default=1,
)
parser.add_argument(
    "--seed",
    type=int,
    help="Seed,default : -1 (disabled) ",
    default=-1,
)
parser.add_argument(
    "--use_openvino",
    action="store_true",
    help="Use OpenVINO model",
)

parser.add_argument(
    "--use_offline_model",
    action="store_true",
    help="Use offline model",
)
parser.add_argument(
    "--clip_skip",
    type=int,
    help="CLIP Skip (1-12), default : 1 (disabled) ",
    default=1,
)
parser.add_argument(
    "--token_merging",
    type=float,
    help="Token merging scale, 0.0 - 1.0, default : 0.0",
    default=0.0,
)

parser.add_argument(
    "--use_safety_checker",
    action="store_true",
    help="Use safety checker",
)
parser.add_argument(
    "--use_lcm_lora",
    action="store_true",
    help="Use LCM-LoRA",
)
parser.add_argument(
    "--base_model_id",
    type=str,
    help="LCM LoRA base model ID,Default Lykon/dreamshaper-8",
    default="Lykon/dreamshaper-8",
)
parser.add_argument(
    "--lcm_lora_id",
    type=str,
    help="LCM LoRA model ID,Default latent-consistency/lcm-lora-sdv1-5",
    default="latent-consistency/lcm-lora-sdv1-5",
)
parser.add_argument(
    "-i",
    "--interactive",
    action="store_true",
    help="Interactive CLI mode",
)
parser.add_argument(
    "-t",
    "--use_tiny_auto_encoder",
    action="store_true",
    help="Use tiny auto encoder for SD (TAESD)",
)
parser.add_argument(
    "-f",
    "--file",
    type=str,
    help="Input image for img2img mode",
    default="",
)
parser.add_argument(
    "--img2img",
    action="store_true",
    help="img2img mode; requires input file via -f argument",
)
parser.add_argument(
    "--batch_count",
    type=int,
    help="Number of sequential generations",
    default=1,
)
parser.add_argument(
    "--strength",
    type=float,
    help="Denoising strength for img2img and Image variations",
    default=0.3,
)
parser.add_argument(
    "--sdupscale",
    action="store_true",
    help="Tiled SD upscale,works only for the resolution 512x512,(2x upscale)",
)
parser.add_argument(
    "--upscale",
    action="store_true",
    help="EDSR SD upscale ",
)
parser.add_argument(
    "--custom_settings",
    type=str,
    help="JSON file containing custom generation settings",
    default=None,
)
parser.add_argument(
    "--usejpeg",
    action="store_true",
    help="Images will be saved as JPEG format",
)
parser.add_argument(
    "--noimagesave",
    action="store_true",
    help="Disable image saving",
)
parser.add_argument(
    "--imagequality", type=int, help="Output image quality [0 to 100]", default=90
)
parser.add_argument(
    "--lora",
    type=str,
    help="LoRA model full path e.g D:\lora_models\CuteCartoon15V-LiberteRedmodModel-Cartoon-CuteCartoonAF.safetensors",
    default=None,
)
parser.add_argument(
    "--lora_weight",
    type=float,
    help="LoRA adapter weight [0 to 1.0]",
    default=0.5,
)
parser.add_argument(
    "--port",
    type=int,
    help="Web server port",
    default=8000,
)

args = parser.parse_args()

if args.version:
    print(APP_VERSION)
    exit()

# parser.print_help()
print("FastSD CPU - ", APP_VERSION)
show_system_info()
print(f"Using device : {constants.DEVICE}")


if args.webui:
    app_settings = get_settings()
else:
    app_settings = get_settings()

print(f"Output path : {app_settings.settings.generated_images.path}")
ensure_path(app_settings.settings.generated_images.path)

print(f"Found {len(app_settings.lcm_models)} LCM models in config/lcm-models.txt")
print(
    f"Found {len(app_settings.stable_diffsuion_models)} stable diffusion models in config/stable-diffusion-models.txt"
)
print(
    f"Found {len(app_settings.lcm_lora_models)} LCM-LoRA models in config/lcm-lora-models.txt"
)
print(
    f"Found {len(app_settings.openvino_lcm_models)} OpenVINO LCM models in config/openvino-lcm-models.txt"
)

if args.noimagesave:
    app_settings.settings.generated_images.save_image = False
else:
    app_settings.settings.generated_images.save_image = True

app_settings.settings.generated_images.save_image_quality = args.imagequality

if not args.realtime:
    # To minimize realtime mode dependencies
    from backend.upscale.upscaler import upscale_image
    from frontend.cli_interactive import interactive_mode

if args.gui:
    from frontend.gui.ui import start_gui

    print("Starting desktop GUI mode(Qt)")
    start_gui(
        [],
        app_settings,
    )
elif args.webui:
    from frontend.webui.ui import start_webui

    print("Starting web UI mode")
    start_webui(
        args.share,
    )
elif args.realtime:
    from frontend.webui.realtime_ui import start_realtime_text_to_image

    print("Starting realtime text to image(EXPERIMENTAL)")
    start_realtime_text_to_image(args.share)
elif args.api:
    from backend.api.web import start_web_server

    start_web_server(args.port)
elif args.mcp:
    from backend.api.mcp_server import start_mcp_server

    start_mcp_server(args.port)
else:
    context = get_context(InterfaceType.CLI)
    config = app_settings.settings

    if args.use_openvino:
        config.lcm_diffusion_setting.openvino_lcm_model_id = args.openvino_lcm_model_id
    else:
        config.lcm_diffusion_setting.lcm_model_id = args.lcm_model_id

    config.lcm_diffusion_setting.prompt = args.prompt
    config.lcm_diffusion_setting.negative_prompt = args.negative_prompt
    config.lcm_diffusion_setting.image_height = args.image_height
    config.lcm_diffusion_setting.image_width = args.image_width
    config.lcm_diffusion_setting.guidance_scale = args.guidance_scale
    config.lcm_diffusion_setting.number_of_images = args.number_of_images
    config.lcm_diffusion_setting.inference_steps = args.inference_steps
    config.lcm_diffusion_setting.strength = args.strength
    config.lcm_diffusion_setting.seed = args.seed
    config.lcm_diffusion_setting.use_openvino = args.use_openvino
    config.lcm_diffusion_setting.use_tiny_auto_encoder = args.use_tiny_auto_encoder
    config.lcm_diffusion_setting.use_lcm_lora = args.use_lcm_lora
    config.lcm_diffusion_setting.lcm_lora.base_model_id = args.base_model_id
    config.lcm_diffusion_setting.lcm_lora.lcm_lora_id = args.lcm_lora_id
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
    config.lcm_diffusion_setting.lora.enabled = False
    config.lcm_diffusion_setting.lora.path = args.lora
    config.lcm_diffusion_setting.lora.weight = args.lora_weight
    config.lcm_diffusion_setting.lora.fuse = True
    if config.lcm_diffusion_setting.lora.path:
        config.lcm_diffusion_setting.lora.enabled = True
    if args.usejpeg:
        config.generated_images.format = ImageFormat.JPEG.value.upper()
    if args.seed > -1:
        config.lcm_diffusion_setting.use_seed = True
    else:
        config.lcm_diffusion_setting.use_seed = False
    config.lcm_diffusion_setting.use_offline_model = args.use_offline_model
    config.lcm_diffusion_setting.clip_skip = args.clip_skip
    config.lcm_diffusion_setting.token_merging = args.token_merging
    config.lcm_diffusion_setting.use_safety_checker = args.use_safety_checker

    # Read custom settings from JSON file
    custom_settings = {}
    if args.custom_settings:
        with open(args.custom_settings) as f:
            custom_settings = json.load(f)

    # Basic ControlNet settings; if ControlNet is enabled, an image is
    # required even in txt2img mode
    config.lcm_diffusion_setting.controlnet = None
    controlnet_settings_from_dict(
        config.lcm_diffusion_setting,
        custom_settings,
    )

    # Interactive mode
    if args.interactive:
        # wrapper(interactive_mode, config, context)
        config.lcm_diffusion_setting.lora.fuse = False
        interactive_mode(config, context)

    # Start of non-interactive CLI image generation
    if args.img2img and args.file != "":
        config.lcm_diffusion_setting.init_image = Image.open(args.file)
        config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
    elif args.img2img and args.file == "":
        print("Error : You need to specify a file in img2img mode")
        exit()
    elif args.upscale and args.file == "" and args.custom_settings == None:
        print("Error : You need to specify a file in SD upscale mode")
        exit()
    elif (
        args.prompt == ""
        and args.file == ""
        and args.custom_settings == None
        and not args.benchmark
    ):
        print("Error : You need to provide a prompt")
        exit()

    if args.upscale:
        # image = Image.open(args.file)
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            args.file,
            2,
            config.generated_images.format,
        )
        result = upscale_image(
            context,
            args.file,
            output_path,
            2,
        )
    # Perform Tiled SD upscale (EXPERIMENTAL)
    elif args.sdupscale:
        if args.use_openvino:
            config.lcm_diffusion_setting.strength = 0.3
        upscale_settings = None
        if custom_settings != {}:
            upscale_settings = custom_settings
        filepath = args.file
        output_format = config.generated_images.format
        if upscale_settings:
            filepath = upscale_settings["source_file"]
            output_format = upscale_settings["output_format"].upper()
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            filepath,
            2,
            output_format,
        )

        generate_upscaled_image(
            config,
            filepath,
            config.lcm_diffusion_setting.strength,
            upscale_settings=upscale_settings,
            context=context,
            tile_overlap=32 if config.lcm_diffusion_setting.use_openvino else 16,
            output_path=output_path,
            image_format=output_format,
        )
        exit()
    # If img2img argument is set and prompt is empty, use image variations mode
    elif args.img2img and args.prompt == "":
        for i in range(0, args.batch_count):
            generate_image_variations(
                config.lcm_diffusion_setting.init_image, args.strength
            )
    else:
        if args.benchmark:
            print("Initializing benchmark...")
            bench_lcm_setting = config.lcm_diffusion_setting
            bench_lcm_setting.prompt = "a cat"
            bench_lcm_setting.use_tiny_auto_encoder = False
            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )

            latencies = []

            print("Starting benchmark please wait...")
            for _ in range(3):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
                latencies.append(context.latency)

            avg_latency = sum(latencies) / 3

            bench_lcm_setting.use_tiny_auto_encoder = True

            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
            latencies = []
            for _ in range(3):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
                latencies.append(context.latency)

            avg_latency_taesd = sum(latencies) / 3

            benchmark_name = ""

            if config.lcm_diffusion_setting.use_openvino:
                benchmark_name = "OpenVINO"
            else:
                benchmark_name = "PyTorch"

            bench_model_id = ""
            if bench_lcm_setting.use_openvino:
                bench_model_id = bench_lcm_setting.openvino_lcm_model_id
            elif bench_lcm_setting.use_lcm_lora:
                bench_model_id = bench_lcm_setting.lcm_lora.base_model_id
            else:
                bench_model_id = bench_lcm_setting.lcm_model_id

            benchmark_result = [
                ["Device", f"{DEVICE.upper()},{get_device_name()}"],
                ["Stable Diffusion Model", bench_model_id],
                [
                    "Image Size ",
                    f"{bench_lcm_setting.image_width}x{bench_lcm_setting.image_height}",
                ],
                [
                    "Inference Steps",
                    f"{bench_lcm_setting.inference_steps}",
                ],
                [
                    "Benchmark Passes",
                    3,
                ],
                [
                    "Average Latency",
                    f"{round(avg_latency, 3)} sec",
                ],
                [
                    "Average Latency(TAESD* enabled)",
                    f"{round(avg_latency_taesd, 3)} sec",
                ],
            ]
            print()
            print(
                f" FastSD Benchmark - {benchmark_name:8} "
            )
            print(f"-" * 80)
            for benchmark in benchmark_result:
                print(f"{benchmark[0]:35} - {benchmark[1]}")
            print(f"-" * 80)
            print("*TAESD - Tiny AutoEncoder for Stable Diffusion")

        else:
            for i in range(0, args.batch_count):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
src/app_settings.py
DELETED
@@ -1,124 +0,0 @@
from copy import deepcopy
from os import makedirs, path

import yaml
from constants import (
    LCM_LORA_MODELS_FILE,
    LCM_MODELS_FILE,
    OPENVINO_LCM_MODELS_FILE,
    SD_MODELS_FILE,
)
from paths import FastStableDiffusionPaths, join_paths
from utils import get_files_in_dir, get_models_from_text_file

from models.settings import Settings


class AppSettings:
    def __init__(self):
        self.config_path = FastStableDiffusionPaths().get_app_settings_path()
        self._stable_diffsuion_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(SD_MODELS_FILE)
        )
        self._lcm_lora_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(LCM_LORA_MODELS_FILE)
        )
        self._openvino_lcm_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(OPENVINO_LCM_MODELS_FILE)
        )
        self._lcm_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(LCM_MODELS_FILE)
        )
        self._gguf_diffusion_models = get_files_in_dir(
            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "diffusion")
        )
        self._gguf_clip_models = get_files_in_dir(
            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "clip")
        )
        self._gguf_vae_models = get_files_in_dir(
            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "vae")
        )
        self._gguf_t5xxl_models = get_files_in_dir(
            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "t5xxl")
        )
        self._config = None

    @property
    def settings(self):
        return self._config

    @property
    def stable_diffsuion_models(self):
        return self._stable_diffsuion_models

    @property
    def openvino_lcm_models(self):
        return self._openvino_lcm_models

    @property
    def lcm_models(self):
        return self._lcm_models

    @property
    def lcm_lora_models(self):
        return self._lcm_lora_models

    @property
    def gguf_diffusion_models(self):
        return self._gguf_diffusion_models

    @property
    def gguf_clip_models(self):
        return self._gguf_clip_models

    @property
    def gguf_vae_models(self):
        return self._gguf_vae_models

    @property
    def gguf_t5xxl_models(self):
        return self._gguf_t5xxl_models

    def load(self, skip_file=False):
        if skip_file:
            print("Skipping config file")
            settings_dict = self._load_default()
            self._config = Settings.model_validate(settings_dict)
        else:
            if not path.exists(self.config_path):
                base_dir = path.dirname(self.config_path)
                if not path.exists(base_dir):
                    makedirs(base_dir)
                try:
                    print("Settings not found creating default settings")
                    with open(self.config_path, "w") as file:
                        yaml.dump(
                            self._load_default(),
                            file,
                        )
                except Exception as ex:
                    print(f"Error in creating settings : {ex}")
                    exit()
            try:
                with open(self.config_path) as file:
                    settings_dict = yaml.safe_load(file)
                    self._config = Settings.model_validate(settings_dict)
            except Exception as ex:
                print(f"Error in loading settings : {ex}")

    def save(self):
        try:
            with open(self.config_path, "w") as file:
                tmp_cfg = deepcopy(self._config)
                tmp_cfg.lcm_diffusion_setting.init_image = None
                configurations = tmp_cfg.model_dump(
                    exclude=["init_image"],
                )
                if configurations:
                    yaml.dump(configurations, file)
        except Exception as ex:
            print(f"Error in saving settings : {ex}")

    def _load_default(self) -> dict:
        default_config = Settings()
        return default_config.model_dump()
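Not part of this diff: a minimal usage sketch of the deleted AppSettings class, assuming the surrounding FastSD CPU modules are still importable. The inference_steps field is taken from how src/app.py configured the settings; everything else uses the class's own methods.

# Illustrative sketch only, not code from the repository.
from app_settings import AppSettings

app_settings = AppSettings()
app_settings.load()  # writes a default YAML settings file if none exists yet
print(f"{len(app_settings.lcm_models)} LCM models listed in config/lcm-models.txt")
app_settings.settings.lcm_diffusion_setting.inference_steps = 4  # field used by src/app.py
app_settings.save()  # init_image is cleared before the config is written back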
src/backend/__init__.py
DELETED
File without changes
src/backend/annotators/canny_control.py
DELETED
@@ -1,15 +0,0 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from cv2 import Canny
from PIL import Image


class CannyControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        low_threshold = 100
        high_threshold = 200
        image = np.array(image)
        image = Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        return Image.fromarray(image)
src/backend/annotators/control_interface.py
DELETED
@@ -1,12 +0,0 @@
from abc import ABC, abstractmethod

from PIL import Image


class ControlInterface(ABC):
    @abstractmethod
    def get_control_image(
        self,
        image: Image,
    ) -> Image:
        pass
src/backend/annotators/depth_control.py
DELETED
@@ -1,15 +0,0 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from PIL import Image
from transformers import pipeline


class DepthControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        depth_estimator = pipeline("depth-estimation")
        image = depth_estimator(image)["depth"]
        image = np.array(image)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)
        return image
src/backend/annotators/image_control_factory.py
DELETED
@@ -1,31 +0,0 @@
from backend.annotators.canny_control import CannyControl
from backend.annotators.depth_control import DepthControl
from backend.annotators.lineart_control import LineArtControl
from backend.annotators.mlsd_control import MlsdControl
from backend.annotators.normal_control import NormalControl
from backend.annotators.pose_control import PoseControl
from backend.annotators.shuffle_control import ShuffleControl
from backend.annotators.softedge_control import SoftEdgeControl


class ImageControlFactory:
    def create_control(self, controlnet_type: str):
        if controlnet_type == "Canny":
            return CannyControl()
        elif controlnet_type == "Pose":
            return PoseControl()
        elif controlnet_type == "MLSD":
            return MlsdControl()
        elif controlnet_type == "Depth":
            return DepthControl()
        elif controlnet_type == "LineArt":
            return LineArtControl()
        elif controlnet_type == "Shuffle":
            return ShuffleControl()
        elif controlnet_type == "NormalBAE":
            return NormalControl()
        elif controlnet_type == "SoftEdge":
            return SoftEdgeControl()
        else:
            print("Error: Control type not implemented!")
            raise Exception("Error: Control type not implemented!")
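Not part of this diff: an illustrative sketch of how the factory and the annotator classes above fit together; "input.png" is a placeholder path.

# Illustrative sketch only, not code from the repository.
from PIL import Image
from backend.annotators.image_control_factory import ImageControlFactory

control = ImageControlFactory().create_control("Canny")  # any type name handled above
control_image = control.get_control_image(Image.open("input.png"))
control_image.save("canny_control.png")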
src/backend/annotators/lineart_control.py
DELETED
@@ -1,11 +0,0 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import LineartDetector
from PIL import Image


class LineArtControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
src/backend/annotators/mlsd_control.py
DELETED
@@ -1,10 +0,0 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import MLSDdetector
from PIL import Image


class MlsdControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        mlsd = MLSDdetector.from_pretrained("lllyasviel/ControlNet")
        image = mlsd(image)
        return image
src/backend/annotators/normal_control.py
DELETED
@@ -1,10 +0,0 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import NormalBaeDetector
from PIL import Image


class NormalControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
src/backend/annotators/pose_control.py
DELETED
@@ -1,10 +0,0 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import OpenposeDetector
from PIL import Image


class PoseControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = openpose(image)
        return image
src/backend/annotators/shuffle_control.py
DELETED
@@ -1,10 +0,0 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import ContentShuffleDetector
from PIL import Image


class ShuffleControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        shuffle_processor = ContentShuffleDetector()
        image = shuffle_processor(image)
        return image
src/backend/annotators/softedge_control.py
DELETED
@@ -1,10 +0,0 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import PidiNetDetector
from PIL import Image


class SoftEdgeControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
src/backend/api/mcp_server.py
DELETED
@@ -1,97 +0,0 @@
import platform

import uvicorn
from backend.device import get_device_name
from backend.models.device import DeviceInfo
from constants import APP_VERSION, DEVICE
from context import Context
from fastapi import FastAPI, Request
from fastapi_mcp import FastApiMCP
from state import get_settings
from fastapi.middleware.cors import CORSMiddleware
from models.interface_types import InterfaceType
from fastapi.staticfiles import StaticFiles

app_settings = get_settings()
app = FastAPI(
    title="FastSD CPU",
    description="Fast stable diffusion on CPU",
    version=APP_VERSION,
    license_info={
        "name": "MIT",
        "identifier": "MIT",
    },
    describe_all_responses=True,
    describe_full_response_schema=True,
)
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
print(app_settings.settings.lcm_diffusion_setting)

context = Context(InterfaceType.API_SERVER)
app.mount("/results", StaticFiles(directory="results"), name="results")


@app.get(
    "/info",
    description="Get system information",
    summary="Get system information",
    operation_id="get_system_info",
)
async def info() -> dict:
    device_info = DeviceInfo(
        device_type=DEVICE,
        device_name=get_device_name(),
        os=platform.system(),
        platform=platform.platform(),
        processor=platform.processor(),
    )
    return device_info.model_dump()


@app.post(
    "/generate",
    description="Generate image from text prompt",
    summary="Text to image generation",
    operation_id="generate",
)
async def generate(
    prompt: str,
    request: Request,
) -> str:
    """
    Returns URL of the generated image for text prompt
    """

    app_settings.settings.lcm_diffusion_setting.prompt = prompt
    images = context.generate_text_to_image(app_settings.settings)
    image_names = context.save_images(
        images,
        app_settings.settings,
    )
    url = request.url_for("results", path=image_names[0])
    image_url = f"The generated image available at the URL {url}"
    return image_url


def start_mcp_server(port: int = 8000):
    mcp = FastApiMCP(
        app,
        name="FastSDCPU MCP",
        description="MCP server for FastSD CPU API",
        base_url=f"http://localhost:{port}",
    )

    mcp.mount()
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )
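Not part of this diff: the deleted CLI started this server via its --mcp flag; calling the module directly amounts to the sketch below, where port 8000 is simply the function's own default.

# Illustrative sketch only, not code from the repository.
from backend.api.mcp_server import start_mcp_server

# Mounts the FastAPI app as an MCP server and serves it with uvicorn on 0.0.0.0.
start_mcp_server(port=8000)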
src/backend/api/models/response.py
DELETED
@@ -1,16 +0,0 @@
from typing import List

from pydantic import BaseModel


class StableDiffusionResponse(BaseModel):
    """
    Stable diffusion response model

    Attributes:
        images (List[str]): List of JPEG image as base64 encoded
        latency (float): Latency in seconds
    """

    images: List[str]
    latency: float
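Not part of this diff: a hand-built instance of the response model above, for illustration only; the base64 string is a placeholder.

# Illustrative sketch only, not code from the repository.
from backend.api.models.response import StableDiffusionResponse

response = StableDiffusionResponse(images=["<base64-encoded JPEG>"], latency=1.23)
print(response.model_dump())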
src/backend/api/web.py
DELETED
@@ -1,112 +0,0 @@
import platform

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from backend.api.models.response import StableDiffusionResponse
from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str
from backend.device import get_device_name
from backend.models.device import DeviceInfo
from backend.models.lcmdiffusion_setting import DiffusionTask, LCMDiffusionSetting
from constants import APP_VERSION, DEVICE
from context import Context
from models.interface_types import InterfaceType
from state import get_settings

app_settings = get_settings()
app = FastAPI(
    title="FastSD CPU",
    description="Fast stable diffusion on CPU",
    version=APP_VERSION,
    license_info={
        "name": "MIT",
        "identifier": "MIT",
    },
    docs_url="/api/docs",
    redoc_url="/api/redoc",
    openapi_url="/api/openapi.json",
)
print(app_settings.settings.lcm_diffusion_setting)
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
context = Context(InterfaceType.API_SERVER)


@app.get("/api/")
async def root():
    return {"message": "Welcome to FastSD CPU API"}


@app.get(
    "/api/info",
    description="Get system information",
    summary="Get system information",
)
async def info():
    device_info = DeviceInfo(
        device_type=DEVICE,
        device_name=get_device_name(),
        os=platform.system(),
        platform=platform.platform(),
        processor=platform.processor(),
    )
    return device_info.model_dump()


@app.get(
    "/api/config",
    description="Get current configuration",
    summary="Get configurations",
)
async def config():
    return app_settings.settings


@app.get(
    "/api/models",
    description="Get available models",
    summary="Get available models",
)
async def models():
    return {
        "lcm_lora_models": app_settings.lcm_lora_models,
        "stable_diffusion": app_settings.stable_diffsuion_models,
        "openvino_models": app_settings.openvino_lcm_models,
        "lcm_models": app_settings.lcm_models,
    }


@app.post(
    "/api/generate",
    description="Generate image(Text to image,Image to Image)",
    summary="Generate image(Text to image,Image to Image)",
)
async def generate(diffusion_config: LCMDiffusionSetting) -> StableDiffusionResponse:
    app_settings.settings.lcm_diffusion_setting = diffusion_config
    if diffusion_config.diffusion_task == DiffusionTask.image_to_image:
        app_settings.settings.lcm_diffusion_setting.init_image = base64_image_to_pil(
            diffusion_config.init_image
        )

    images = context.generate_text_to_image(app_settings.settings)

    images_base64 = [pil_image_to_base64_str(img) for img in images]
    return StableDiffusionResponse(
        latency=round(context.latency, 2),
        images=images_base64,
    )


def start_web_server(port: int = 8000):
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )
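Not part of this diff: a minimal client sketch against the read-only endpoints above, assuming the server is running on its default port (the deleted CLI started it via --api) and that the requests package is installed.

# Illustrative sketch only, not code from the repository.
import requests

base_url = "http://localhost:8000"
print(requests.get(f"{base_url}/api/info").json())    # device and platform details
print(requests.get(f"{base_url}/api/models").json())  # model lists read from the config files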
src/backend/base64_image.py
DELETED
@@ -1,21 +0,0 @@
from io import BytesIO
from base64 import b64encode, b64decode
from PIL import Image


def pil_image_to_base64_str(
    image: Image,
    format: str = "JPEG",
) -> str:
    buffer = BytesIO()
    image.save(buffer, format=format)
    buffer.seek(0)
    img_base64 = b64encode(buffer.getvalue()).decode("utf-8")
    return img_base64


def base64_image_to_pil(base64_str) -> Image:
    image_data = b64decode(base64_str)
    image_buffer = BytesIO(image_data)
    image = Image.open(image_buffer)
    return image
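Not part of this diff: a round-trip check of the two helpers above; "photo.jpg" is a placeholder path.

# Illustrative sketch only, not code from the repository.
from PIL import Image
from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str

encoded = pil_image_to_base64_str(Image.open("photo.jpg"), format="JPEG")
decoded = base64_image_to_pil(encoded)
print(decoded.size)  # should match the original image dimensions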
src/backend/controlnet.py
DELETED
@@ -1,90 +0,0 @@
import logging
from PIL import Image
from diffusers import ControlNetModel
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    ControlNetSetting,
)


# Prepares ControlNet adapters for use with FastSD CPU
#
# This function loads the ControlNet adapters defined by the
# _lcm_diffusion_setting.controlnet_ object and returns a dictionary
# with the pipeline arguments required to use the loaded adapters
def load_controlnet_adapters(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    logging.info("Loading ControlNet adapter")
    controlnet_adapter = ControlNetModel.from_single_file(
        lcm_diffusion_setting.controlnet.adapter_path,
        # local_files_only=True,
        use_safetensors=True,
    )
    controlnet_args["controlnet"] = controlnet_adapter
    return controlnet_args


# Updates the ControlNet pipeline arguments to use for image generation
#
# This function uses the contents of the _lcm_diffusion_setting.controlnet_
# object to generate a dictionary with the corresponding pipeline arguments
# to be used for image generation; in particular, it sets the ControlNet control
# image and conditioning scale
def update_controlnet_arguments(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    controlnet_args["controlnet_conditioning_scale"] = (
        lcm_diffusion_setting.controlnet.conditioning_scale
    )
    if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
        controlnet_args["image"] = lcm_diffusion_setting.controlnet._control_image
    elif lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
        controlnet_args["control_image"] = (
            lcm_diffusion_setting.controlnet._control_image
        )
    return controlnet_args


# Helper function to adjust ControlNet settings from a dictionary
def controlnet_settings_from_dict(
    lcm_diffusion_setting,
    dictionary,
) -> None:
    if lcm_diffusion_setting is None or dictionary is None:
        logging.error("Invalid arguments!")
        return
    if (
        "controlnet" not in dictionary
        or dictionary["controlnet"] is None
        or len(dictionary["controlnet"]) == 0
    ):
        logging.warning("ControlNet settings not found, ControlNet will be disabled")
        lcm_diffusion_setting.controlnet = None
        return

    controlnet = ControlNetSetting()
    controlnet.enabled = dictionary["controlnet"][0]["enabled"]
    controlnet.conditioning_scale = dictionary["controlnet"][0]["conditioning_scale"]
    controlnet.adapter_path = dictionary["controlnet"][0]["adapter_path"]
    controlnet._control_image = None
    image_path = dictionary["controlnet"][0]["control_image"]
    if controlnet.enabled:
        try:
            controlnet._control_image = Image.open(image_path)
        except (AttributeError, FileNotFoundError) as err:
            print(err)
        if controlnet._control_image is None:
            logging.error("Wrong ControlNet control image! Disabling ControlNet")
            controlnet.enabled = False
    lcm_diffusion_setting.controlnet = controlnet
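Not part of this diff: the shape of the custom-settings dictionary that controlnet_settings_from_dict reads, reconstructed from the key accesses above; both paths are placeholders.

# Illustrative sketch only, not code from the repository.
custom_settings = {
    "controlnet": [
        {
            "enabled": True,
            "conditioning_scale": 0.5,
            "adapter_path": "controlnet_models/control_canny.safetensors",  # placeholder
            "control_image": "input_images/sketch.png",                     # placeholder
        }
    ]
}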
src/backend/device.py
DELETED
@@ -1,23 +0,0 @@
import platform
from constants import DEVICE
import torch
import openvino as ov

core = ov.Core()


def is_openvino_device() -> bool:
    if DEVICE.lower() == "cpu" or DEVICE.lower()[0] == "g" or DEVICE.lower()[0] == "n":
        return True
    else:
        return False


def get_device_name() -> str:
    if DEVICE == "cuda" or DEVICE == "mps":
        default_gpu_index = torch.cuda.current_device()
        return torch.cuda.get_device_name(default_gpu_index)
    elif platform.system().lower() == "darwin":
        return platform.processor()
    elif is_openvino_device():
        return core.get_property(DEVICE.upper(), "FULL_DEVICE_NAME")
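Not part of this diff: a quick check of the helpers above; the output depends on the DEVICE constant and the installed OpenVINO runtime.

# Illustrative sketch only, not code from the repository.
from backend.device import get_device_name, is_openvino_device

print(f"Device name: {get_device_name()} (OpenVINO device: {is_openvino_device()})")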
src/backend/gguf/gguf_diffusion.py
DELETED
@@ -1,319 +0,0 @@
"""
Wrapper class to call the stablediffusion.cpp shared library for GGUF support
"""

import ctypes
import platform
from ctypes import (
    POINTER,
    c_bool,
    c_char_p,
    c_float,
    c_int,
    c_int64,
    c_void_p,
)
from dataclasses import dataclass
from os import path
from typing import List, Any

import numpy as np
from PIL import Image

from backend.gguf.sdcpp_types import (
    RngType,
    SampleMethod,
    Schedule,
    SDCPPLogLevel,
    SDImage,
    SdType,
)


@dataclass
class ModelConfig:
    model_path: str = ""
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False
    n_threads: int = 4
    wtype: SdType = SdType.SD_TYPE_Q4_0
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False


@dataclass
class Txt2ImgConfig:
    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1
    batch_count: int = 2
    control_cond: Image = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    input_id_images_path: bytes = b""


class GGUFDiffusion:
    """GGUF Diffusion
    To support GGUF diffusion model based on stablediffusion.cpp
    https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
    Implemented based on stablediffusion.h
    """

    def __init__(
        self,
        libpath: str,
        config: ModelConfig,
        logging_enabled: bool = False,
    ):
        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            raise ValueError(f"Error: {e}")

        if not config.clip_l_path or not path.exists(config.clip_l_path):
            raise ValueError(
                "CLIP model file not found, please check readme.md for GGUF model usage"
            )

        if not config.t5xxl_path or not path.exists(config.t5xxl_path):
            raise ValueError(
                "T5XXL model file not found, please check readme.md for GGUF model usage"
            )

        if not config.diffusion_model_path or not path.exists(
            config.diffusion_model_path
        ):
            raise ValueError(
                "Diffusion model file not found, please check readme.md for GGUF model usage"
            )

        if not config.vae_path or not path.exists(config.vae_path):
            raise ValueError(
                "VAE model file not found, please check readme.md for GGUF model usage"
            )

        self.model_config = config

        self.libsdcpp.new_sd_ctx.argtypes = [
            c_char_p,  # const char* model_path
            c_char_p,  # const char* clip_l_path
            c_char_p,  # const char* t5xxl_path
            c_char_p,  # const char* diffusion_model_path
            c_char_p,  # const char* vae_path
            c_char_p,  # const char* taesd_path
            c_char_p,  # const char* control_net_path_c_str
            c_char_p,  # const char* lora_model_dir
            c_char_p,  # const char* embed_dir_c_str
            c_char_p,  # const char* stacked_id_embed_dir_c_str
            c_bool,  # bool vae_decode_only
            c_bool,  # bool vae_tiling
            c_bool,  # bool free_params_immediately
            c_int,  # int n_threads
            SdType,  # enum sd_type_t wtype
            RngType,  # enum rng_type_t rng_type
            Schedule,  # enum schedule_t s
            c_bool,  # bool keep_clip_on_cpu
            c_bool,  # bool keep_control_net_cpu
            c_bool,  # bool keep_vae_on_cpu
        ]

        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)

        self.sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(self.model_config.model_path),
            self._str_to_bytes(self.model_config.clip_l_path),
            self._str_to_bytes(self.model_config.t5xxl_path),
            self._str_to_bytes(self.model_config.diffusion_model_path),
            self._str_to_bytes(self.model_config.vae_path),
            self._str_to_bytes(self.model_config.taesd_path),
            self._str_to_bytes(self.model_config.control_net_path),
            self._str_to_bytes(self.model_config.lora_model_dir),
            self._str_to_bytes(self.model_config.embed_dir),
            self._str_to_bytes(self.model_config.stacked_id_embed_dir),
            self.model_config.vae_decode_only,
            self.model_config.vae_tiling,
            self.model_config.free_params_immediately,
            self.model_config.n_threads,
            self.model_config.wtype,
            self.model_config.rng_type,
            self.model_config.schedule,
            self.model_config.keep_clip_on_cpu,
            self.model_config.keep_control_net_cpu,
            self.model_config.keep_vae_on_cpu,
        )

        if logging_enabled:
            self._set_logcallback()

    def _set_logcallback(self):
        print("Setting logging callback")
        # Define function callback
        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )

        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None
        # Convert the Python callback to a C func pointer
        self.c_log_callback = SdLogCallbackType(
            self.log_callback
        )  # prevent GC, keep callback as member variable
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")
        lib_name = "stable-diffusion.dll"
        sdcpp_lib_path = ""

        if system_name == "Windows":
            sdcpp_lib_path = path.join(root_path, lib_name)
        elif system_name == "Linux":
            lib_name = "libstable-diffusion.so"
            sdcpp_lib_path = path.join(root_path, lib_name)
        elif system_name == "Darwin":
            lib_name = "libstable-diffusion.dylib"
            sdcpp_lib_path = path.join(root_path, lib_name)
        else:
            print("Unknown platform.")

        return sdcpp_lib_path

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        print(f"{text.decode('utf-8')}", end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        if in_str:
            return in_str.encode(encoding)
        else:
            return b""

    def generate_text2mg(self, txt2img_cfg: Txt2ImgConfig) -> List[Any]:
        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,  # sd_ctx_t* sd_ctx (pointer to context object)
            c_char_p,  # const char* prompt
            c_char_p,  # const char* negative_prompt
            c_int,  # int clip_skip
            c_float,  # float cfg_scale
            c_float,  # float guidance
            c_int,  # int width
            c_int,  # int height
            SampleMethod,  # enum sample_method_t sample_method
            c_int,  # int sample_steps
            c_int64,  # int64_t seed
            c_int,  # int batch_count
            POINTER(SDImage),  # const sd_image_t* control_cond (pointer to SDImage)
            c_float,  # float control_strength
            c_float,  # float style_strength
            c_bool,  # bool normalize_input
            c_char_p,  # const char* input_id_images_path
        ]

        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            txt2img_cfg.input_id_images_path,
        )

        images = self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

        return images

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        images = []
        if image_buffer:
            for i in range(batch_count):
                image = image_buffer[i]
                print(
                    f"Generated image: {image.width}x{image.height} with {image.channel} channels"
                )

                width = image.width
                height = image.height
                channels = image.channel
                pixel_data = np.ctypeslib.as_array(
                    image.data, shape=(height, width, channels)
                )

                if channels == 1:
                    pil_image = Image.fromarray(pixel_data.squeeze(), mode="L")
                elif channels == 3:
                    pil_image = Image.fromarray(pixel_data, mode="RGB")
                elif channels == 4:
                    pil_image = Image.fromarray(pixel_data, mode="RGBA")
                else:
                    raise ValueError(f"Unsupported number of channels: {channels}")

                images.append(pil_image)
        return images

    def terminate(self):
        if self.libsdcpp:
            if self.sd_ctx:
                self.libsdcpp.free_sd_ctx.argtypes = [c_void_p]
                self.libsdcpp.free_sd_ctx.restype = None
                self.libsdcpp.free_sd_ctx(self.sd_ctx)
                del self.sd_ctx
                self.sd_ctx = None
            del self.libsdcpp
            self.libsdcpp = None

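A minimal usage sketch for the wrapper above; all model file paths and the shared-library directory are placeholders, and the library directory is assumed to contain the stablediffusion.cpp binary named as in _get_sdcpp_shared_lib_path.

# Hypothetical usage of GGUFDiffusion; every path below is a placeholder.
from backend.gguf.gguf_diffusion import GGUFDiffusion, ModelConfig, Txt2ImgConfig

config = ModelConfig()
config.diffusion_model_path = "models/flux1-schnell-q4_0.gguf"  # placeholder
config.clip_l_path = "models/clip_l.safetensors"                # placeholder
config.t5xxl_path = "models/t5xxl_q4_0.gguf"                    # placeholder
config.vae_path = "models/ae.safetensors"                       # placeholder

sd = GGUFDiffusion("path/to/lib/dir", config, logging_enabled=True)  # dir is a placeholder
cfg = Txt2ImgConfig()
cfg.prompt = "a cup of coffee on a wooden table"
cfg.batch_count = 1
images = sd.generate_text2mg(cfg)  # list of PIL images decoded from the native buffer
images[0].save("out.png")
sd.terminate()  # frees the native sd_ctx
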
src/backend/gguf/sdcpp_types.py
DELETED
@@ -1,104 +0,0 @@
"""
Ctypes for stablediffusion.cpp shared library
This is as per the stablediffusion.h file
"""

from enum import IntEnum
from ctypes import (
    c_int,
    c_uint32,
    c_uint8,
    POINTER,
    Structure,
)


class CtypesEnum(IntEnum):
    """A ctypes-compatible IntEnum superclass."""

    @classmethod
    def from_param(cls, obj):
        return int(obj)


class RngType(CtypesEnum):
    STD_DEFAULT_RNG = 0
    CUDA_RNG = 1


class SampleMethod(CtypesEnum):
    EULER_A = 0
    EULER = 1
    HEUN = 2
    DPM2 = 3
    DPMPP2S_A = 4
    DPMPP2M = 5
    DPMPP2Mv2 = 6
    IPNDM = 7
    IPNDM_V = 7
    LCM = 8
    N_SAMPLE_METHODS = 9


class Schedule(CtypesEnum):
    DEFAULT = 0
    DISCRETE = 1
    KARRAS = 2
    EXPONENTIAL = 3
    AYS = 4
    GITS = 5
    N_SCHEDULES = 5


class SdType(CtypesEnum):
    SD_TYPE_F32 = 0
    SD_TYPE_F16 = 1
    SD_TYPE_Q4_0 = 2
    SD_TYPE_Q4_1 = 3
    # SD_TYPE_Q4_2 = 4, support has been removed
    # SD_TYPE_Q4_3 = 5, support has been removed
    SD_TYPE_Q5_0 = 6
    SD_TYPE_Q5_1 = 7
    SD_TYPE_Q8_0 = 8
    SD_TYPE_Q8_1 = 9
    SD_TYPE_Q2_K = 10
    SD_TYPE_Q3_K = 11
    SD_TYPE_Q4_K = 12
    SD_TYPE_Q5_K = 13
    SD_TYPE_Q6_K = 14
    SD_TYPE_Q8_K = 15
    SD_TYPE_IQ2_XXS = 16
    SD_TYPE_IQ2_XS = 17
    SD_TYPE_IQ3_XXS = 18
    SD_TYPE_IQ1_S = 19
    SD_TYPE_IQ4_NL = 20
    SD_TYPE_IQ3_S = 21
    SD_TYPE_IQ2_S = 22
    SD_TYPE_IQ4_XS = 23
    SD_TYPE_I8 = 24
    SD_TYPE_I16 = 25
    SD_TYPE_I32 = 26
    SD_TYPE_I64 = 27
    SD_TYPE_F64 = 28
    SD_TYPE_IQ1_M = 29
    SD_TYPE_BF16 = 30
    SD_TYPE_Q4_0_4_4 = 31
    SD_TYPE_Q4_0_4_8 = 32
    SD_TYPE_Q4_0_8_8 = 33
    SD_TYPE_COUNT = 34


class SDImage(Structure):
    _fields_ = [
        ("width", c_uint32),
        ("height", c_uint32),
        ("channel", c_uint32),
        ("data", POINTER(c_uint8)),
    ]


class SDCPPLogLevel(c_int):
    SD_LOG_LEVEL_DEBUG = 0
    SD_LOG_LEVEL_INFO = 1
    SD_LOG_LEVEL_WARNING = 2
    SD_LOG_LEVEL_ERROR = 3

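The from_param hook above is what lets these IntEnum members be passed straight into ctypes argument slots that declare the enum class; a short sketch using only the types defined in this file:

from backend.gguf.sdcpp_types import SampleMethod, Schedule, SDImage

# CtypesEnum.from_param converts a member to its int value, so members can be
# listed directly in a foreign function's argtypes and passed as arguments.
print(int(SampleMethod.EULER_A), SampleMethod.from_param(SampleMethod.LCM))  # 0 8
print(Schedule.KARRAS.value)  # 2

# SDImage mirrors the sd_image_t struct returned by txt2img(); data defaults to NULL.
img = SDImage(width=512, height=512, channel=3)
print(img.width, img.height, img.channel)
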
src/backend/image_saver.py
DELETED
@@ -1,75 +0,0 @@
import json
from os import path, mkdir
from typing import Any
from uuid import uuid4
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
from utils import get_image_file_extension


def get_exclude_keys():
    exclude_keys = {
        "init_image": True,
        "generated_images": True,
        "lora": {
            "models_dir": True,
            "path": True,
        },
        "dirs": True,
        "controlnet": {
            "adapter_path": True,
        },
    }
    return exclude_keys


class ImageSaver:
    @staticmethod
    def save_images(
        output_path: str,
        images: Any,
        folder_name: str = "",
        format: str = "PNG",
        jpeg_quality: int = 90,
        lcm_diffusion_setting: LCMDiffusionSetting = None,
    ) -> list[str]:
        gen_id = uuid4()
        image_ids = []

        if images:
            image_seeds = []

            for index, image in enumerate(images):

                image_seed = image.info.get("image_seed")
                if image_seed is not None:
                    image_seeds.append(image_seed)

                if not path.exists(output_path):
                    mkdir(output_path)

                if folder_name:
                    out_path = path.join(
                        output_path,
                        folder_name,
                    )
                else:
                    out_path = output_path

                if not path.exists(out_path):
                    mkdir(out_path)
                image_extension = get_image_file_extension(format)
                image_file_name = f"{gen_id}-{index+1}{image_extension}"
                image_ids.append(image_file_name)
                image.save(path.join(out_path, image_file_name), quality=jpeg_quality)
            if lcm_diffusion_setting:
                data = lcm_diffusion_setting.model_dump(exclude=get_exclude_keys())
                if image_seeds:
                    data["image_seeds"] = image_seeds
                with open(path.join(out_path, f"{gen_id}.json"), "w") as json_file:
                    json.dump(
                        data,
                        json_file,
                        indent=4,
                    )
        return image_ids

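A hedged usage sketch of ImageSaver; the output path and folder name are placeholders, and passing an LCMDiffusionSetting additionally writes a <gen_id>.json metadata file as shown above.

# Hypothetical usage of ImageSaver with two in-memory PIL images.
from PIL import Image
from backend.image_saver import ImageSaver
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting

imgs = [Image.new("RGB", (512, 512), "gray") for _ in range(2)]
ids = ImageSaver.save_images(
    "results",                                   # output_path (placeholder)
    imgs,
    folder_name="session-01",                    # optional subfolder (placeholder)
    format="PNG",
    lcm_diffusion_setting=LCMDiffusionSetting(),  # also dumps generation settings to JSON
)
print(ids)  # e.g. ["<uuid>-1.png", "<uuid>-2.png"]
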
src/backend/lcm_text_to_image.py
DELETED
@@ -1,577 +0,0 @@
import gc
from math import ceil
from typing import Any, List
import random

import numpy as np
import torch
from backend.device import is_openvino_device
from backend.controlnet import (
    load_controlnet_adapters,
    update_controlnet_arguments,
)
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    LCMDiffusionSetting,
    LCMLora,
)
from backend.openvino.pipelines import (
    get_ov_image_to_image_pipeline,
    get_ov_text_to_image_pipeline,
    ov_load_taesd,
)
from backend.pipelines.lcm import (
    get_image_to_image_pipeline,
    get_lcm_model_pipeline,
    load_taesd,
)
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline
from constants import DEVICE, GGUF_THREADS
from diffusers import LCMScheduler
from image_ops import resize_pil_image
from backend.openvino.flux_pipeline import get_flux_pipeline
from backend.openvino.ov_hc_stablediffusion_pipeline import OvHcLatentConsistency
from backend.gguf.gguf_diffusion import (
    GGUFDiffusion,
    ModelConfig,
    Txt2ImgConfig,
    SampleMethod,
)
from paths import get_app_path
from pprint import pprint

try:
    # support for token merging; keeping it optional for now
    import tomesd
except ImportError:
    print("tomesd library unavailable; disabling token merging support")
    tomesd = None


class LCMTextToImage:
    def __init__(
        self,
        device: str = "cpu",
    ) -> None:
        self.pipeline = None
        self.use_openvino = False
        self.device = ""
        self.previous_model_id = None
        self.previous_use_tae_sd = False
        self.previous_use_lcm_lora = False
        self.previous_ov_model_id = ""
        self.previous_token_merging = 0.0
        self.previous_safety_checker = False
        self.previous_use_openvino = False
        self.img_to_img_pipeline = None
        self.is_openvino_init = False
        self.previous_lora = None
        self.task_type = DiffusionTask.text_to_image
        self.previous_use_gguf_model = False
        self.previous_gguf_model = None
        self.torch_data_type = (
            torch.float32 if is_openvino_device() or DEVICE == "mps" else torch.float16
        )
        self.ov_model_id = None
        print(f"Torch datatype : {self.torch_data_type}")

    def _pipeline_to_device(self):
        print(f"Pipeline device : {DEVICE}")
        print(f"Pipeline dtype : {self.torch_data_type}")
        self.pipeline.to(
            torch_device=DEVICE,
            torch_dtype=self.torch_data_type,
        )

    def _add_freeu(self):
        pipeline_class = self.pipeline.__class__.__name__
        if isinstance(self.pipeline.scheduler, LCMScheduler):
            if pipeline_class == "StableDiffusionPipeline":
                print("Add FreeU - SD")
                self.pipeline.enable_freeu(
                    s1=0.9,
                    s2=0.2,
                    b1=1.2,
                    b2=1.4,
                )
            elif pipeline_class == "StableDiffusionXLPipeline":
                print("Add FreeU - SDXL")
                self.pipeline.enable_freeu(
                    s1=0.6,
                    s2=0.4,
                    b1=1.1,
                    b2=1.2,
                )

    def _enable_vae_tiling(self):
        self.pipeline.vae.enable_tiling()

    def _update_lcm_scheduler_params(self):
        if isinstance(self.pipeline.scheduler, LCMScheduler):
            self.pipeline.scheduler = LCMScheduler.from_config(
                self.pipeline.scheduler.config,
                beta_start=0.001,
                beta_end=0.01,
            )

    def _is_hetero_pipeline(self) -> bool:
        return "square" in self.ov_model_id.lower()

    def _load_ov_hetero_pipeline(self):
        print("Loading Heterogeneous Compute pipeline")
        if DEVICE.upper() == "NPU":
            device = ["NPU", "NPU", "NPU"]
            self.pipeline = OvHcLatentConsistency(self.ov_model_id, device)
        else:
            self.pipeline = OvHcLatentConsistency(self.ov_model_id)

    def _generate_images_hetero_compute(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
    ):
        print("Using OpenVINO")
        if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
            return [
                self.pipeline.generate(
                    prompt=lcm_diffusion_setting.prompt,
                    neg_prompt=lcm_diffusion_setting.negative_prompt,
                    init_image=None,
                    strength=1.0,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                )
            ]
        else:
            return [
                self.pipeline.generate(
                    prompt=lcm_diffusion_setting.prompt,
                    neg_prompt=lcm_diffusion_setting.negative_prompt,
                    init_image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                )
            ]

    def _is_valid_mode(
        self,
        modes: List,
    ) -> bool:
        return modes.count(True) == 1 or modes.count(False) == 3

    def _validate_mode(
        self,
        modes: List,
    ) -> None:
        if not self._is_valid_mode(modes):
            raise ValueError("Invalid mode, delete configs/settings.yaml and retry!")

    def init(
        self,
        device: str = "cpu",
        lcm_diffusion_setting: LCMDiffusionSetting = LCMDiffusionSetting(),
    ) -> None:
        # Mode validation: either LCM-LoRA or OpenVINO or GGUF

        modes = [
            lcm_diffusion_setting.use_gguf_model,
            lcm_diffusion_setting.use_openvino,
            lcm_diffusion_setting.use_lcm_lora,
        ]
        self._validate_mode(modes)
        self.device = device
        self.use_openvino = lcm_diffusion_setting.use_openvino
        model_id = lcm_diffusion_setting.lcm_model_id
        use_local_model = lcm_diffusion_setting.use_offline_model
        use_tiny_auto_encoder = lcm_diffusion_setting.use_tiny_auto_encoder
        use_lora = lcm_diffusion_setting.use_lcm_lora
        lcm_lora: LCMLora = lcm_diffusion_setting.lcm_lora
        token_merging = lcm_diffusion_setting.token_merging
        self.ov_model_id = lcm_diffusion_setting.openvino_lcm_model_id

        if lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
            lcm_diffusion_setting.init_image = resize_pil_image(
                lcm_diffusion_setting.init_image,
                lcm_diffusion_setting.image_width,
                lcm_diffusion_setting.image_height,
            )

        if (
            self.pipeline is None
            or self.previous_model_id != model_id
            or self.previous_use_tae_sd != use_tiny_auto_encoder
            or self.previous_lcm_lora_base_id != lcm_lora.base_model_id
            or self.previous_lcm_lora_id != lcm_lora.lcm_lora_id
            or self.previous_use_lcm_lora != use_lora
            or self.previous_ov_model_id != self.ov_model_id
            or self.previous_token_merging != token_merging
            or self.previous_safety_checker != lcm_diffusion_setting.use_safety_checker
            or self.previous_use_openvino != lcm_diffusion_setting.use_openvino
            or self.previous_use_gguf_model != lcm_diffusion_setting.use_gguf_model
            or self.previous_gguf_model != lcm_diffusion_setting.gguf_model
            or (
                self.use_openvino
                and (
                    self.previous_task_type != lcm_diffusion_setting.diffusion_task
                    or self.previous_lora != lcm_diffusion_setting.lora
                )
            )
            or lcm_diffusion_setting.rebuild_pipeline
        ):
            if self.use_openvino and is_openvino_device():
                if self.pipeline:
                    del self.pipeline
                    self.pipeline = None
                    gc.collect()
                self.is_openvino_init = True
                if (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.text_to_image.value
                ):
                    print(
                        f"***** Init Text to image (OpenVINO) - {self.ov_model_id} *****"
                    )
                    if "flux" in self.ov_model_id.lower():
                        print("Loading OpenVINO Flux pipeline")
                        self.pipeline = get_flux_pipeline(
                            self.ov_model_id,
                            lcm_diffusion_setting.use_tiny_auto_encoder,
                        )
                    elif self._is_hetero_pipeline():
                        self._load_ov_hetero_pipeline()
                    else:
                        self.pipeline = get_ov_text_to_image_pipeline(
                            self.ov_model_id,
                            use_local_model,
                        )
                elif (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.image_to_image.value
                ):
                    if not self.pipeline and self._is_hetero_pipeline():
                        self._load_ov_hetero_pipeline()
                    else:
                        print(
                            f"***** Image to image (OpenVINO) - {self.ov_model_id} *****"
                        )
                        self.pipeline = get_ov_image_to_image_pipeline(
                            self.ov_model_id,
                            use_local_model,
                        )
            elif lcm_diffusion_setting.use_gguf_model:
                model = lcm_diffusion_setting.gguf_model.diffusion_path
                print(f"***** Init Text to image (GGUF) - {model} *****")
                # if self.pipeline:
                #     self.pipeline.terminate()
                #     del self.pipeline
                #     self.pipeline = None
                self._init_gguf_diffusion(lcm_diffusion_setting)
            else:
                if self.pipeline or self.img_to_img_pipeline:
                    self.pipeline = None
                    self.img_to_img_pipeline = None
                    gc.collect()

                controlnet_args = load_controlnet_adapters(lcm_diffusion_setting)
                if use_lora:
                    print(
                        f"***** Init LCM-LoRA pipeline - {lcm_lora.base_model_id} *****"
                    )
                    self.pipeline = get_lcm_lora_pipeline(
                        lcm_lora.base_model_id,
                        lcm_lora.lcm_lora_id,
                        use_local_model,
                        torch_data_type=self.torch_data_type,
                        pipeline_args=controlnet_args,
                    )

                else:
                    print(f"***** Init LCM Model pipeline - {model_id} *****")
                    self.pipeline = get_lcm_model_pipeline(
                        model_id,
                        use_local_model,
                        controlnet_args,
                    )

                self.img_to_img_pipeline = get_image_to_image_pipeline(self.pipeline)

            if tomesd and token_merging > 0.001:
                print(f"***** Token Merging: {token_merging} *****")
                tomesd.apply_patch(self.pipeline, ratio=token_merging)
                tomesd.apply_patch(self.img_to_img_pipeline, ratio=token_merging)

            if use_tiny_auto_encoder:
                if self.use_openvino and is_openvino_device():
                    if self.pipeline.__class__.__name__ != "OVFluxPipeline":
                        print("Using Tiny Auto Encoder (OpenVINO)")
                        ov_load_taesd(
                            self.pipeline,
                            use_local_model,
                        )
                else:
                    print("Using Tiny Auto Encoder")
                    load_taesd(
                        self.pipeline,
                        use_local_model,
                        self.torch_data_type,
                    )
                    load_taesd(
                        self.img_to_img_pipeline,
                        use_local_model,
                        self.torch_data_type,
                    )

            if not self.use_openvino and not is_openvino_device():
                self._pipeline_to_device()

            if not self._is_hetero_pipeline():
                if (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.image_to_image.value
                    and lcm_diffusion_setting.use_openvino
                ):
                    self.pipeline.scheduler = LCMScheduler.from_config(
                        self.pipeline.scheduler.config,
                    )
                else:
                    if not lcm_diffusion_setting.use_gguf_model:
                        self._update_lcm_scheduler_params()

            if use_lora:
                self._add_freeu()

            self.previous_model_id = model_id
            self.previous_ov_model_id = self.ov_model_id
            self.previous_use_tae_sd = use_tiny_auto_encoder
            self.previous_lcm_lora_base_id = lcm_lora.base_model_id
            self.previous_lcm_lora_id = lcm_lora.lcm_lora_id
            self.previous_use_lcm_lora = use_lora
            self.previous_token_merging = lcm_diffusion_setting.token_merging
            self.previous_safety_checker = lcm_diffusion_setting.use_safety_checker
            self.previous_use_openvino = lcm_diffusion_setting.use_openvino
            self.previous_task_type = lcm_diffusion_setting.diffusion_task
            self.previous_lora = lcm_diffusion_setting.lora.model_copy(deep=True)
            self.previous_use_gguf_model = lcm_diffusion_setting.use_gguf_model
            self.previous_gguf_model = lcm_diffusion_setting.gguf_model.model_copy(
                deep=True
            )
            lcm_diffusion_setting.rebuild_pipeline = False
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                print(f"Pipeline : {self.pipeline}")
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                if self.use_openvino and is_openvino_device():
                    print(f"Pipeline : {self.pipeline}")
                else:
                    print(f"Pipeline : {self.img_to_img_pipeline}")
            if self.use_openvino:
                if lcm_diffusion_setting.lora.enabled:
                    print("Warning: Lora models not supported in OpenVINO mode")
            elif not lcm_diffusion_setting.use_gguf_model:
                adapters = self.pipeline.get_active_adapters()
                print(f"Active adapters : {adapters}")

    def _get_timesteps(self):
        time_steps = self.pipeline.scheduler.config.get("timesteps")
        time_steps_value = [int(time_steps)] if time_steps else None
        return time_steps_value

    def generate(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
        reshape: bool = False,
    ) -> Any:
        guidance_scale = lcm_diffusion_setting.guidance_scale
        img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
        check_step_value = int(
            lcm_diffusion_setting.inference_steps * lcm_diffusion_setting.strength
        )
        if (
            lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value
            and check_step_value < 1
        ):
            img_to_img_inference_steps = ceil(1 / lcm_diffusion_setting.strength)
            print(
                f"Strength: {lcm_diffusion_setting.strength}, {img_to_img_inference_steps}"
            )

        pipeline_extra_args = {}

        if lcm_diffusion_setting.use_seed:
            cur_seed = lcm_diffusion_setting.seed
            # for multiple images with a fixed seed, use sequential seeds
            seeds = [
                (cur_seed + i) for i in range(lcm_diffusion_setting.number_of_images)
            ]
        else:
            seeds = [
                random.randint(0, 999999999)
                for i in range(lcm_diffusion_setting.number_of_images)
            ]

        if self.use_openvino:
            # no support for generators; try at least to ensure reproducible results for single images
            np.random.seed(seeds[0])
            if self._is_hetero_pipeline():
                torch.manual_seed(seeds[0])
                lcm_diffusion_setting.seed = seeds[0]
        else:
            pipeline_extra_args["generator"] = [
                torch.Generator(device=self.device).manual_seed(s) for s in seeds
            ]

        is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
        if is_openvino_pipe and not self._is_hetero_pipeline():
            print("Using OpenVINO")
            if reshape and not self.is_openvino_init:
                print("Reshape and compile")
                self.pipeline.reshape(
                    batch_size=-1,
                    height=lcm_diffusion_setting.image_height,
                    width=lcm_diffusion_setting.image_width,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                )
                self.pipeline.compile()

            if self.is_openvino_init:
                self.is_openvino_init = False

        if is_openvino_pipe and self._is_hetero_pipeline():
            return self._generate_images_hetero_compute(lcm_diffusion_setting)
        elif lcm_diffusion_setting.use_gguf_model:
            return self._generate_images_gguf(lcm_diffusion_setting)

        if lcm_diffusion_setting.clip_skip > 1:
            # We follow the convention that "CLIP Skip == 2" means "skip
            # the last layer", so "CLIP Skip == 1" means "no skipping"
            pipeline_extra_args["clip_skip"] = lcm_diffusion_setting.clip_skip - 1

        if not lcm_diffusion_setting.use_safety_checker:
            self.pipeline.safety_checker = None
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
                and not is_openvino_pipe
            ):
                self.img_to_img_pipeline.safety_checker = None

        if (
            not lcm_diffusion_setting.use_lcm_lora
            and not lcm_diffusion_setting.use_openvino
            and lcm_diffusion_setting.guidance_scale != 1.0
        ):
            print("Not using LCM-LoRA so setting guidance_scale 1.0")
            guidance_scale = 1.0

        controlnet_args = update_controlnet_arguments(lcm_diffusion_setting)
        if lcm_diffusion_setting.use_openvino:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                result_images = self.pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps * 3,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images

        else:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                    timesteps=self._get_timesteps(),
                    **pipeline_extra_args,
                    **controlnet_args,
                ).images

            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                result_images = self.img_to_img_pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                    **pipeline_extra_args,
                    **controlnet_args,
                ).images

        for i, seed in enumerate(seeds):
            result_images[i].info["image_seed"] = seed

        return result_images

    def _init_gguf_diffusion(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
    ):
        config = ModelConfig()
        config.model_path = lcm_diffusion_setting.gguf_model.diffusion_path
        config.diffusion_model_path = lcm_diffusion_setting.gguf_model.diffusion_path
        config.clip_l_path = lcm_diffusion_setting.gguf_model.clip_path
        config.t5xxl_path = lcm_diffusion_setting.gguf_model.t5xxl_path
        config.vae_path = lcm_diffusion_setting.gguf_model.vae_path
        config.n_threads = GGUF_THREADS
        print(f"GGUF Threads : {GGUF_THREADS} ")
        print("GGUF - Model config")
        pprint(lcm_diffusion_setting.gguf_model.model_dump())
        self.pipeline = GGUFDiffusion(
            get_app_path(),  # Place DLL in fastsdcpu folder
            config,
            True,
        )

    def _generate_images_gguf(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
    ):
        if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
            t2iconfig = Txt2ImgConfig()
            t2iconfig.prompt = lcm_diffusion_setting.prompt
            t2iconfig.batch_count = lcm_diffusion_setting.number_of_images
            t2iconfig.cfg_scale = lcm_diffusion_setting.guidance_scale
            t2iconfig.height = lcm_diffusion_setting.image_height
            t2iconfig.width = lcm_diffusion_setting.image_width
            t2iconfig.sample_steps = lcm_diffusion_setting.inference_steps
            t2iconfig.sample_method = SampleMethod.EULER
            if lcm_diffusion_setting.use_seed:
                t2iconfig.seed = lcm_diffusion_setting.seed
            else:
                t2iconfig.seed = -1

            return self.pipeline.generate_text2mg(t2iconfig)

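A hedged end-to-end sketch of the class above on the CPU/LCM-LoRA path; the prompt and output filename are placeholders, and it assumes the default model ids bundled with the settings model.

# Hypothetical usage of LCMTextToImage (exactly one of the three modes must be active).
from backend.lcm_text_to_image import LCMTextToImage
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting

settings = LCMDiffusionSetting()
settings.prompt = "a watercolor painting of a lighthouse"
settings.use_lcm_lora = True          # LCM-LoRA mode; OpenVINO and GGUF stay off
settings.inference_steps = 4
settings.number_of_images = 1

lcm = LCMTextToImage()
lcm.init(device="cpu", lcm_diffusion_setting=settings)  # builds or rebuilds the pipeline lazily
images = lcm.generate(settings)                         # list of PIL images, seed stored in .info
images[0].save("lighthouse.png")
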
src/backend/lora.py
DELETED
@@ -1,136 +0,0 @@
import glob
from os import path
from paths import get_file_name, FastStableDiffusionPaths
from pathlib import Path


# A basic class to keep track of the currently loaded LoRAs and
# their weights; the diffusers function \c get_active_adapters()
# returns a list of adapter names but not their weights so we need
# a way to keep track of the current LoRA weights to set whenever
# a new LoRA is loaded
class _lora_info:
    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        self.adapter_name = get_file_name(path)
        self.weight = weight

    def __del__(self):
        self.path = None
        self.adapter_name = None


_loaded_loras = []
_current_pipeline = None


# This function loads a LoRA from the LoRA path setting, so it's
# possible to load multiple LoRAs by calling this function more than
# once with a different LoRA path setting; note that if you plan to
# load multiple LoRAs and dynamically change their weights, you
# might want to set the LoRA fuse option to False
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    if not lcm_diffusion_setting.lora.path:
        raise Exception("Empty lora model path")

    if not path.exists(lcm_diffusion_setting.lora.path):
        raise Exception("Lora model path is invalid")

    # If the pipeline has been rebuilt since the last call, remove all
    # references to previously loaded LoRAs and store the new pipeline
    global _loaded_loras
    global _current_pipeline
    if pipeline != _current_pipeline:
        for lora in _loaded_loras:
            del lora
        del _loaded_loras
        _loaded_loras = []
        _current_pipeline = pipeline

    current_lora = _lora_info(
        lcm_diffusion_setting.lora.path,
        lcm_diffusion_setting.lora.weight,
    )
    _loaded_loras.append(current_lora)

    if lcm_diffusion_setting.lora.enabled:
        print(f"LoRA adapter name : {current_lora.adapter_name}")
        pipeline.load_lora_weights(
            FastStableDiffusionPaths.get_lora_models_path(),
            weight_name=Path(lcm_diffusion_setting.lora.path).name,
            local_files_only=True,
            adapter_name=current_lora.adapter_name,
        )
        update_lora_weights(
            pipeline,
            lcm_diffusion_setting,
        )

        if lcm_diffusion_setting.lora.fuse:
            pipeline.fuse_lora()


def get_lora_models(root_dir: str):
    lora_models = glob.glob(f"{root_dir}/**/*.safetensors", recursive=True)
    lora_models_map = {}
    for file_path in lora_models:
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map


# This function returns a list of (adapter_name, weight) tuples for the
# currently loaded LoRAs
def get_active_lora_weights():
    active_loras = []
    for lora_info in _loaded_loras:
        active_loras.append(
            (
                lora_info.adapter_name,
                lora_info.weight,
            )
        )
    return active_loras


# This function receives a pipeline, an lcm_diffusion_setting object and
# an optional list of updated (adapter_name, weight) tuples
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    global _loaded_loras
    global _current_pipeline
    if pipeline != _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return
    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            if _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    adapter_weights = zip(adapter_names, adapter_weights)
    print(f"Adapters: {list(adapter_weights)}")

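A hedged sketch of the LoRA helpers above; `pipeline` (a diffusers pipeline) and `settings` (an LCMDiffusionSetting) are assumed to exist already, and the directory and LoRA name are placeholders.

# Hypothetical usage; `pipeline` and `settings` are assumed to be set up elsewhere.
from backend.lora import (
    get_lora_models,
    load_lora_weight,
    get_active_lora_weights,
    update_lora_weights,
)

available = get_lora_models("lora_models")        # {name: /path/to/file.safetensors}
settings.lora.path = available["my_style_lora"]   # placeholder LoRA name
settings.lora.weight = 0.7
settings.lora.fuse = False                        # keep unfused so weights can be changed later
load_lora_weight(pipeline, settings)

print(get_active_lora_weights())                  # e.g. [("my_style_lora", 0.7)]
update_lora_weights(pipeline, settings, [("my_style_lora", 0.4)])
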
src/backend/models/device.py
DELETED
@@ -1,9 +0,0 @@
from pydantic import BaseModel


class DeviceInfo(BaseModel):
    device_type: str
    device_name: str
    os: str
    platform: str
    processor: str

src/backend/models/gen_images.py
DELETED
@@ -1,17 +0,0 @@
from pydantic import BaseModel
from enum import Enum
from paths import FastStableDiffusionPaths


class ImageFormat(str, Enum):
    """Image format"""

    JPEG = "jpeg"
    PNG = "png"


class GeneratedImages(BaseModel):
    path: str = FastStableDiffusionPaths.get_results_path()
    format: str = ImageFormat.PNG.value.upper()
    save_image: bool = True
    save_image_quality: int = 90

src/backend/models/lcmdiffusion_setting.py
DELETED
@@ -1,76 +0,0 @@
from enum import Enum
from PIL import Image
from typing import Any, Optional, Union

from constants import LCM_DEFAULT_MODEL, LCM_DEFAULT_MODEL_OPENVINO
from paths import FastStableDiffusionPaths
from pydantic import BaseModel


class LCMLora(BaseModel):
    base_model_id: str = "Lykon/dreamshaper-8"
    lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"


class DiffusionTask(str, Enum):
    """Diffusion task types"""

    text_to_image = "text_to_image"
    image_to_image = "image_to_image"


class Lora(BaseModel):
    models_dir: str = FastStableDiffusionPaths.get_lora_models_path()
    path: Optional[Any] = None
    weight: Optional[float] = 0.5
    fuse: bool = True
    enabled: bool = False


class ControlNetSetting(BaseModel):
    adapter_path: Optional[str] = None  # ControlNet adapter path
    conditioning_scale: float = 0.5
    enabled: bool = False
    _control_image: Image = None  # Control image, PIL image


class GGUFModel(BaseModel):
    gguf_models: str = FastStableDiffusionPaths.get_gguf_models_path()
    diffusion_path: Optional[str] = None
    clip_path: Optional[str] = None
    t5xxl_path: Optional[str] = None
    vae_path: Optional[str] = None


class LCMDiffusionSetting(BaseModel):
    lcm_model_id: str = LCM_DEFAULT_MODEL
    openvino_lcm_model_id: str = LCM_DEFAULT_MODEL_OPENVINO
    use_offline_model: bool = False
    use_lcm_lora: bool = False
    lcm_lora: Optional[LCMLora] = LCMLora()
    use_tiny_auto_encoder: bool = False
    use_openvino: bool = False
    prompt: str = ""
    negative_prompt: str = ""
    init_image: Any = None
    strength: Optional[float] = 0.6
    image_height: Optional[int] = 512
    image_width: Optional[int] = 512
    inference_steps: Optional[int] = 1
    guidance_scale: Optional[float] = 1
    clip_skip: Optional[int] = 1
    token_merging: Optional[float] = 0
    number_of_images: Optional[int] = 1
    seed: Optional[int] = 123123
    use_seed: bool = False
    use_safety_checker: bool = False
    diffusion_task: str = DiffusionTask.text_to_image.value
    lora: Optional[Lora] = Lora()
    controlnet: Optional[Union[ControlNetSetting, list[ControlNetSetting]]] = None
    dirs: dict = {
        "controlnet": FastStableDiffusionPaths.get_controlnet_models_path(),
        "lora": FastStableDiffusionPaths.get_lora_models_path(),
    }
    rebuild_pipeline: bool = False
    use_gguf_model: bool = False
    gguf_model: Optional[GGUFModel] = GGUFModel()

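A short example of building and serializing the settings model above; the prompt and excluded fields are illustrative only.

# Example of constructing and dumping LCMDiffusionSetting (pydantic v2).
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting, DiffusionTask

cfg = LCMDiffusionSetting(
    prompt="a bowl of ramen, studio lighting",
    diffusion_task=DiffusionTask.text_to_image.value,
    image_width=768,
    image_height=512,
    use_seed=True,
    seed=42,
)
print(cfg.model_dump(exclude={"init_image", "dirs"}))  # plain dict, heavy fields left out
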
src/backend/models/upscale.py
DELETED
@@ -1,9 +0,0 @@
from enum import Enum


class UpscaleMode(str, Enum):
    """Upscale modes"""

    normal = "normal"
    sd_upscale = "sd_upscale"
    aura_sr = "aura_sr"

src/backend/openvino/custom_ov_model_vae_decoder.py
DELETED
@@ -1,21 +0,0 @@
from backend.device import is_openvino_device

if is_openvino_device():
    from optimum.intel.openvino.modeling_diffusion import OVModelVaeDecoder


    class CustomOVModelVaeDecoder(OVModelVaeDecoder):
        def __init__(
            self,
            model,
            parent_model,
            ov_config=None,
            model_dir=None,
        ):
            super(OVModelVaeDecoder, self).__init__(
                model,
                parent_model,
                ov_config,
                "vae_decoder",
                model_dir,
            )

src/backend/openvino/flux_pipeline.py
DELETED
@@ -1,36 +0,0 @@
-from pathlib import Path
-
-from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO, TAEF1_MODEL_OPENVINO
-from huggingface_hub import snapshot_download
-
-from backend.openvino.ovflux import (
-    TEXT_ENCODER_2_PATH,
-    TEXT_ENCODER_PATH,
-    TRANSFORMER_PATH,
-    VAE_DECODER_PATH,
-    init_pipeline,
-)
-
-
-def get_flux_pipeline(
-    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
-    use_taef1: bool = False,
-    taef1_path: str = TAEF1_MODEL_OPENVINO,
-):
-    model_dir = Path(snapshot_download(model_id))
-    vae_dir = Path(snapshot_download(taef1_path)) if use_taef1 else model_dir
-
-    model_dict = {
-        "transformer": model_dir / TRANSFORMER_PATH,
-        "text_encoder": model_dir / TEXT_ENCODER_PATH,
-        "text_encoder_2": model_dir / TEXT_ENCODER_2_PATH,
-        "vae": vae_dir / VAE_DECODER_PATH,
-    }
-    ov_pipe = init_pipeline(
-        model_dir,
-        model_dict,
-        device=DEVICE.upper(),
-        use_taef1=use_taef1,
-    )
-
-    return ov_pipe
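
A minimal sketch of how the deleted helper was presumably called; the prompt, step count, and output handling below are illustrative assumptions, while the call signature follows the OVFluxPipeline defined in ovflux.py further down.

# Hypothetical usage of get_flux_pipeline; prompt and sizes are examples.
pipe = get_flux_pipeline(use_taef1=True)
result = pipe(
    prompt="a lighthouse at sunset, watercolor",
    num_inference_steps=4,
    height=512,
    width=512,
)
result.images[0].save("flux_out.png")
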
src/backend/openvino/ov_hc_stablediffusion_pipeline.py
DELETED
@@ -1,93 +0,0 @@
-"""This is an experimental pipeline used to test AI PC NPU and GPU"""
-
-from pathlib import Path
-
-from diffusers import EulerDiscreteScheduler,LCMScheduler
-from huggingface_hub import snapshot_download
-from PIL import Image
-from backend.openvino.stable_diffusion_engine import (
-    StableDiffusionEngineAdvanced,
-    LatentConsistencyEngineAdvanced
-)
-
-
-class OvHcStableDiffusion:
-    "OpenVINO Heterogeneous compute Stablediffusion"
-
-    def __init__(
-        self,
-        model_path,
-        device: list = ["GPU", "NPU", "GPU", "GPU"],
-    ):
-        model_dir = Path(snapshot_download(model_path))
-        self.scheduler = EulerDiscreteScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-        )
-        self.ov_sd_pipleline = StableDiffusionEngineAdvanced(
-            model=model_dir,
-            device=device,
-        )
-
-    def generate(
-        self,
-        prompt: str,
-        neg_prompt: str,
-        init_image: Image = None,
-        strength: float = 1.0,
-    ):
-        image = self.ov_sd_pipleline(
-            prompt=prompt,
-            negative_prompt=neg_prompt,
-            init_image=init_image,
-            strength=strength,
-            num_inference_steps=25,
-            scheduler=self.scheduler,
-        )
-        image_rgb = image[..., ::-1]
-        return Image.fromarray(image_rgb)
-
-
-class OvHcLatentConsistency:
-    """
-    OpenVINO Heterogeneous compute Latent consistency models
-    For the current Intel Cor Ultra, the Text Encoder and Unet can run on NPU
-    Supports following - Text to image , Image to image and image variations
-    """
-
-    def __init__(
-        self,
-        model_path,
-        device: list = ["NPU", "NPU", "GPU"],
-    ):
-
-        model_dir = Path(snapshot_download(model_path))
-
-        self.scheduler = LCMScheduler(
-            beta_start=0.001,
-            beta_end=0.01,
-        )
-        self.ov_sd_pipleline = LatentConsistencyEngineAdvanced(
-            model=model_dir,
-            device=device,
-        )
-
-    def generate(
-        self,
-        prompt: str,
-        neg_prompt: str,
-        init_image: Image = None,
-        num_inference_steps=4,
-        strength: float = 0.5,
-    ):
-        image = self.ov_sd_pipleline(
-            prompt=prompt,
-            init_image = init_image,
-            strength = strength,
-            num_inference_steps=num_inference_steps,
-            scheduler=self.scheduler,
-            seed=None,
-        )
-
-        return image
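
A minimal usage sketch for the deleted heterogeneous-compute wrapper; the model repository id is an assumption, device placement follows the defaults above, and the return type of the underlying engine is assumed to be a PIL image.

# Hypothetical usage; the repo id is illustrative, devices default to NPU/NPU/GPU.
lcm = OvHcLatentConsistency(model_path="some-org/lcm-openvino-int8")  # assumed repo id
image = lcm.generate(
    prompt="an isometric city block, soft lighting",
    neg_prompt="blurry, low quality",
    num_inference_steps=4,
)
image.save("lcm_npu.png")  # assuming the engine returns a PIL.Image
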
src/backend/openvino/ovflux.py
DELETED
|
@@ -1,675 +0,0 @@
|
|
| 1 |
-
"""Based on https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/flux.1-image-generation/flux_helper.py"""
|
| 2 |
-
|
| 3 |
-
import inspect
|
| 4 |
-
import json
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Any, Dict, List, Optional, Union
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
import openvino as ov
|
| 10 |
-
import torch
|
| 11 |
-
from diffusers.image_processor import VaeImageProcessor
|
| 12 |
-
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 13 |
-
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 14 |
-
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 15 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
-
from transformers import AutoTokenizer
|
| 17 |
-
|
| 18 |
-
TRANSFORMER_PATH = Path("transformer/transformer.xml")
|
| 19 |
-
VAE_DECODER_PATH = Path("vae/vae_decoder.xml")
|
| 20 |
-
TEXT_ENCODER_PATH = Path("text_encoder/text_encoder.xml")
|
| 21 |
-
TEXT_ENCODER_2_PATH = Path("text_encoder_2/text_encoder_2.xml")
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def cleanup_torchscript_cache():
|
| 25 |
-
"""
|
| 26 |
-
Helper for removing cached model representation
|
| 27 |
-
"""
|
| 28 |
-
torch._C._jit_clear_class_registry()
|
| 29 |
-
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
|
| 30 |
-
torch.jit._state._clear_class_state()
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def _prepare_latent_image_ids(
|
| 34 |
-
batch_size, height, width, device=torch.device("cpu"), dtype=torch.float32
|
| 35 |
-
):
|
| 36 |
-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
| 37 |
-
latent_image_ids[..., 1] = (
|
| 38 |
-
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
| 39 |
-
)
|
| 40 |
-
latent_image_ids[..., 2] = (
|
| 41 |
-
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
|
| 45 |
-
latent_image_ids.shape
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
| 49 |
-
latent_image_ids = latent_image_ids.reshape(
|
| 50 |
-
batch_size,
|
| 51 |
-
latent_image_id_height * latent_image_id_width,
|
| 52 |
-
latent_image_id_channels,
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
return latent_image_ids.to(device=device, dtype=dtype)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
| 59 |
-
assert dim % 2 == 0, "The dimension must be even."
|
| 60 |
-
|
| 61 |
-
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
|
| 62 |
-
omega = 1.0 / (theta**scale)
|
| 63 |
-
|
| 64 |
-
batch_size, seq_length = pos.shape
|
| 65 |
-
out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0)
|
| 66 |
-
cos_out = torch.cos(out)
|
| 67 |
-
sin_out = torch.sin(out)
|
| 68 |
-
|
| 69 |
-
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
| 70 |
-
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
| 71 |
-
return out.float()
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def calculate_shift(
|
| 75 |
-
image_seq_len,
|
| 76 |
-
base_seq_len: int = 256,
|
| 77 |
-
max_seq_len: int = 4096,
|
| 78 |
-
base_shift: float = 0.5,
|
| 79 |
-
max_shift: float = 1.16,
|
| 80 |
-
):
|
| 81 |
-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 82 |
-
b = base_shift - m * base_seq_len
|
| 83 |
-
mu = image_seq_len * m + b
|
| 84 |
-
return mu
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 88 |
-
def retrieve_timesteps(
|
| 89 |
-
scheduler,
|
| 90 |
-
num_inference_steps: Optional[int] = None,
|
| 91 |
-
timesteps: Optional[List[int]] = None,
|
| 92 |
-
sigmas: Optional[List[float]] = None,
|
| 93 |
-
**kwargs,
|
| 94 |
-
):
|
| 95 |
-
"""
|
| 96 |
-
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 97 |
-
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 98 |
-
|
| 99 |
-
Args:
|
| 100 |
-
scheduler (`SchedulerMixin`):
|
| 101 |
-
The scheduler to get timesteps from.
|
| 102 |
-
num_inference_steps (`int`):
|
| 103 |
-
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 104 |
-
must be `None`.
|
| 105 |
-
device (`str` or `torch.device`, *optional*):
|
| 106 |
-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 107 |
-
timesteps (`List[int]`, *optional*):
|
| 108 |
-
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 109 |
-
`num_inference_steps` and `sigmas` must be `None`.
|
| 110 |
-
sigmas (`List[float]`, *optional*):
|
| 111 |
-
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 112 |
-
`num_inference_steps` and `timesteps` must be `None`.
|
| 113 |
-
|
| 114 |
-
Returns:
|
| 115 |
-
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 116 |
-
second element is the number of inference steps.
|
| 117 |
-
"""
|
| 118 |
-
if timesteps is not None and sigmas is not None:
|
| 119 |
-
raise ValueError(
|
| 120 |
-
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
| 121 |
-
)
|
| 122 |
-
if timesteps is not None:
|
| 123 |
-
accepts_timesteps = "timesteps" in set(
|
| 124 |
-
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
| 125 |
-
)
|
| 126 |
-
if not accepts_timesteps:
|
| 127 |
-
raise ValueError(
|
| 128 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 129 |
-
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 130 |
-
)
|
| 131 |
-
scheduler.set_timesteps(timesteps=timesteps, **kwargs)
|
| 132 |
-
timesteps = scheduler.timesteps
|
| 133 |
-
num_inference_steps = len(timesteps)
|
| 134 |
-
elif sigmas is not None:
|
| 135 |
-
accept_sigmas = "sigmas" in set(
|
| 136 |
-
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
| 137 |
-
)
|
| 138 |
-
if not accept_sigmas:
|
| 139 |
-
raise ValueError(
|
| 140 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 141 |
-
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 142 |
-
)
|
| 143 |
-
scheduler.set_timesteps(sigmas=sigmas, **kwargs)
|
| 144 |
-
timesteps = scheduler.timesteps
|
| 145 |
-
num_inference_steps = len(timesteps)
|
| 146 |
-
else:
|
| 147 |
-
scheduler.set_timesteps(num_inference_steps, **kwargs)
|
| 148 |
-
timesteps = scheduler.timesteps
|
| 149 |
-
return timesteps, num_inference_steps
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
class OVFluxPipeline(DiffusionPipeline):
|
| 153 |
-
def __init__(
|
| 154 |
-
self,
|
| 155 |
-
scheduler,
|
| 156 |
-
transformer,
|
| 157 |
-
vae,
|
| 158 |
-
text_encoder,
|
| 159 |
-
text_encoder_2,
|
| 160 |
-
tokenizer,
|
| 161 |
-
tokenizer_2,
|
| 162 |
-
transformer_config,
|
| 163 |
-
vae_config,
|
| 164 |
-
):
|
| 165 |
-
super().__init__()
|
| 166 |
-
|
| 167 |
-
self.register_modules(
|
| 168 |
-
vae=vae,
|
| 169 |
-
text_encoder=text_encoder,
|
| 170 |
-
text_encoder_2=text_encoder_2,
|
| 171 |
-
tokenizer=tokenizer,
|
| 172 |
-
tokenizer_2=tokenizer_2,
|
| 173 |
-
transformer=transformer,
|
| 174 |
-
scheduler=scheduler,
|
| 175 |
-
)
|
| 176 |
-
self.vae_config = vae_config
|
| 177 |
-
self.transformer_config = transformer_config
|
| 178 |
-
self.vae_scale_factor = 2 ** (
|
| 179 |
-
len(self.vae_config.get("block_out_channels", [0] * 16))
|
| 180 |
-
if hasattr(self, "vae") and self.vae is not None
|
| 181 |
-
else 16
|
| 182 |
-
)
|
| 183 |
-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 184 |
-
self.tokenizer_max_length = (
|
| 185 |
-
self.tokenizer.model_max_length
|
| 186 |
-
if hasattr(self, "tokenizer") and self.tokenizer is not None
|
| 187 |
-
else 77
|
| 188 |
-
)
|
| 189 |
-
self.default_sample_size = 64
|
| 190 |
-
|
| 191 |
-
def _get_t5_prompt_embeds(
|
| 192 |
-
self,
|
| 193 |
-
prompt: Union[str, List[str]] = None,
|
| 194 |
-
num_images_per_prompt: int = 1,
|
| 195 |
-
max_sequence_length: int = 512,
|
| 196 |
-
):
|
| 197 |
-
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 198 |
-
batch_size = len(prompt)
|
| 199 |
-
|
| 200 |
-
text_inputs = self.tokenizer_2(
|
| 201 |
-
prompt,
|
| 202 |
-
padding="max_length",
|
| 203 |
-
max_length=max_sequence_length,
|
| 204 |
-
truncation=True,
|
| 205 |
-
return_length=False,
|
| 206 |
-
return_overflowing_tokens=False,
|
| 207 |
-
return_tensors="pt",
|
| 208 |
-
)
|
| 209 |
-
text_input_ids = text_inputs.input_ids
|
| 210 |
-
prompt_embeds = torch.from_numpy(self.text_encoder_2(text_input_ids)[0])
|
| 211 |
-
|
| 212 |
-
_, seq_len, _ = prompt_embeds.shape
|
| 213 |
-
|
| 214 |
-
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 215 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 216 |
-
prompt_embeds = prompt_embeds.view(
|
| 217 |
-
batch_size * num_images_per_prompt, seq_len, -1
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
return prompt_embeds
|
| 221 |
-
|
| 222 |
-
def _get_clip_prompt_embeds(
|
| 223 |
-
self,
|
| 224 |
-
prompt: Union[str, List[str]],
|
| 225 |
-
num_images_per_prompt: int = 1,
|
| 226 |
-
):
|
| 227 |
-
|
| 228 |
-
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 229 |
-
batch_size = len(prompt)
|
| 230 |
-
|
| 231 |
-
text_inputs = self.tokenizer(
|
| 232 |
-
prompt,
|
| 233 |
-
padding="max_length",
|
| 234 |
-
max_length=self.tokenizer_max_length,
|
| 235 |
-
truncation=True,
|
| 236 |
-
return_overflowing_tokens=False,
|
| 237 |
-
return_length=False,
|
| 238 |
-
return_tensors="pt",
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
text_input_ids = text_inputs.input_ids
|
| 242 |
-
prompt_embeds = torch.from_numpy(self.text_encoder(text_input_ids)[1])
|
| 243 |
-
|
| 244 |
-
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 245 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 246 |
-
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 247 |
-
|
| 248 |
-
return prompt_embeds
|
| 249 |
-
|
| 250 |
-
def encode_prompt(
|
| 251 |
-
self,
|
| 252 |
-
prompt: Union[str, List[str]],
|
| 253 |
-
prompt_2: Union[str, List[str]],
|
| 254 |
-
num_images_per_prompt: int = 1,
|
| 255 |
-
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 256 |
-
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 257 |
-
max_sequence_length: int = 512,
|
| 258 |
-
):
|
| 259 |
-
r"""
|
| 260 |
-
|
| 261 |
-
Args:
|
| 262 |
-
prompt (`str` or `List[str]`, *optional*):
|
| 263 |
-
prompt to be encoded
|
| 264 |
-
prompt_2 (`str` or `List[str]`, *optional*):
|
| 265 |
-
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 266 |
-
used in all text-encoders
|
| 267 |
-
num_images_per_prompt (`int`):
|
| 268 |
-
number of images that should be generated per prompt
|
| 269 |
-
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 270 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 271 |
-
provided, text embeddings will be generated from `prompt` input argument.
|
| 272 |
-
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 273 |
-
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 274 |
-
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 275 |
-
lora_scale (`float`, *optional*):
|
| 276 |
-
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 277 |
-
"""
|
| 278 |
-
|
| 279 |
-
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 280 |
-
if prompt is not None:
|
| 281 |
-
batch_size = len(prompt)
|
| 282 |
-
else:
|
| 283 |
-
batch_size = prompt_embeds.shape[0]
|
| 284 |
-
|
| 285 |
-
if prompt_embeds is None:
|
| 286 |
-
prompt_2 = prompt_2 or prompt
|
| 287 |
-
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 288 |
-
|
| 289 |
-
# We only use the pooled prompt output from the CLIPTextModel
|
| 290 |
-
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 291 |
-
prompt=prompt,
|
| 292 |
-
num_images_per_prompt=num_images_per_prompt,
|
| 293 |
-
)
|
| 294 |
-
prompt_embeds = self._get_t5_prompt_embeds(
|
| 295 |
-
prompt=prompt_2,
|
| 296 |
-
num_images_per_prompt=num_images_per_prompt,
|
| 297 |
-
max_sequence_length=max_sequence_length,
|
| 298 |
-
)
|
| 299 |
-
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3)
|
| 300 |
-
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
| 301 |
-
|
| 302 |
-
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 303 |
-
|
| 304 |
-
def check_inputs(
|
| 305 |
-
self,
|
| 306 |
-
prompt,
|
| 307 |
-
prompt_2,
|
| 308 |
-
height,
|
| 309 |
-
width,
|
| 310 |
-
prompt_embeds=None,
|
| 311 |
-
pooled_prompt_embeds=None,
|
| 312 |
-
max_sequence_length=None,
|
| 313 |
-
):
|
| 314 |
-
if height % 8 != 0 or width % 8 != 0:
|
| 315 |
-
raise ValueError(
|
| 316 |
-
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
| 317 |
-
)
|
| 318 |
-
|
| 319 |
-
if prompt is not None and prompt_embeds is not None:
|
| 320 |
-
raise ValueError(
|
| 321 |
-
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 322 |
-
" only forward one of the two."
|
| 323 |
-
)
|
| 324 |
-
elif prompt_2 is not None and prompt_embeds is not None:
|
| 325 |
-
raise ValueError(
|
| 326 |
-
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 327 |
-
" only forward one of the two."
|
| 328 |
-
)
|
| 329 |
-
elif prompt is None and prompt_embeds is None:
|
| 330 |
-
raise ValueError(
|
| 331 |
-
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 332 |
-
)
|
| 333 |
-
elif prompt is not None and (
|
| 334 |
-
not isinstance(prompt, str) and not isinstance(prompt, list)
|
| 335 |
-
):
|
| 336 |
-
raise ValueError(
|
| 337 |
-
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
| 338 |
-
)
|
| 339 |
-
elif prompt_2 is not None and (
|
| 340 |
-
not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
|
| 341 |
-
):
|
| 342 |
-
raise ValueError(
|
| 343 |
-
f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
|
| 344 |
-
)
|
| 345 |
-
|
| 346 |
-
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 347 |
-
raise ValueError(
|
| 348 |
-
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
if max_sequence_length is not None and max_sequence_length > 512:
|
| 352 |
-
raise ValueError(
|
| 353 |
-
f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
@staticmethod
|
| 357 |
-
def _prepare_latent_image_ids(batch_size, height, width):
|
| 358 |
-
return _prepare_latent_image_ids(batch_size, height, width)
|
| 359 |
-
|
| 360 |
-
@staticmethod
|
| 361 |
-
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 362 |
-
latents = latents.view(
|
| 363 |
-
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
| 364 |
-
)
|
| 365 |
-
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 366 |
-
latents = latents.reshape(
|
| 367 |
-
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
| 368 |
-
)
|
| 369 |
-
|
| 370 |
-
return latents
|
| 371 |
-
|
| 372 |
-
@staticmethod
|
| 373 |
-
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 374 |
-
batch_size, num_patches, channels = latents.shape
|
| 375 |
-
|
| 376 |
-
height = height // vae_scale_factor
|
| 377 |
-
width = width // vae_scale_factor
|
| 378 |
-
|
| 379 |
-
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
| 380 |
-
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 381 |
-
|
| 382 |
-
latents = latents.reshape(
|
| 383 |
-
batch_size, channels // (2 * 2), height * 2, width * 2
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
return latents
|
| 387 |
-
|
| 388 |
-
def prepare_latents(
|
| 389 |
-
self,
|
| 390 |
-
batch_size,
|
| 391 |
-
num_channels_latents,
|
| 392 |
-
height,
|
| 393 |
-
width,
|
| 394 |
-
generator,
|
| 395 |
-
latents=None,
|
| 396 |
-
):
|
| 397 |
-
height = 2 * (int(height) // self.vae_scale_factor)
|
| 398 |
-
width = 2 * (int(width) // self.vae_scale_factor)
|
| 399 |
-
|
| 400 |
-
shape = (batch_size, num_channels_latents, height, width)
|
| 401 |
-
|
| 402 |
-
if latents is not None:
|
| 403 |
-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width)
|
| 404 |
-
return latents, latent_image_ids
|
| 405 |
-
|
| 406 |
-
if isinstance(generator, list) and len(generator) != batch_size:
|
| 407 |
-
raise ValueError(
|
| 408 |
-
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 409 |
-
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 410 |
-
)
|
| 411 |
-
|
| 412 |
-
latents = randn_tensor(shape, generator=generator)
|
| 413 |
-
latents = self._pack_latents(
|
| 414 |
-
latents, batch_size, num_channels_latents, height, width
|
| 415 |
-
)
|
| 416 |
-
|
| 417 |
-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width)
|
| 418 |
-
|
| 419 |
-
return latents, latent_image_ids
|
| 420 |
-
|
| 421 |
-
@property
|
| 422 |
-
def guidance_scale(self):
|
| 423 |
-
return self._guidance_scale
|
| 424 |
-
|
| 425 |
-
@property
|
| 426 |
-
def num_timesteps(self):
|
| 427 |
-
return self._num_timesteps
|
| 428 |
-
|
| 429 |
-
@property
|
| 430 |
-
def interrupt(self):
|
| 431 |
-
return self._interrupt
|
| 432 |
-
|
| 433 |
-
def __call__(
|
| 434 |
-
self,
|
| 435 |
-
prompt: Union[str, List[str]] = None,
|
| 436 |
-
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 437 |
-
height: Optional[int] = None,
|
| 438 |
-
width: Optional[int] = None,
|
| 439 |
-
negative_prompt: str = None,
|
| 440 |
-
num_inference_steps: int = 28,
|
| 441 |
-
timesteps: List[int] = None,
|
| 442 |
-
guidance_scale: float = 7.0,
|
| 443 |
-
num_images_per_prompt: Optional[int] = 1,
|
| 444 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 445 |
-
latents: Optional[torch.FloatTensor] = None,
|
| 446 |
-
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 447 |
-
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 448 |
-
output_type: Optional[str] = "pil",
|
| 449 |
-
return_dict: bool = True,
|
| 450 |
-
max_sequence_length: int = 512,
|
| 451 |
-
):
|
| 452 |
-
r"""
|
| 453 |
-
Function invoked when calling the pipeline for generation.
|
| 454 |
-
|
| 455 |
-
Args:
|
| 456 |
-
prompt (`str` or `List[str]`, *optional*):
|
| 457 |
-
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 458 |
-
instead.
|
| 459 |
-
prompt_2 (`str` or `List[str]`, *optional*):
|
| 460 |
-
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 461 |
-
will be used instead
|
| 462 |
-
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 463 |
-
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 464 |
-
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 465 |
-
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 466 |
-
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 467 |
-
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 468 |
-
expense of slower inference.
|
| 469 |
-
timesteps (`List[int]`, *optional*):
|
| 470 |
-
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 471 |
-
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 472 |
-
passed will be used. Must be in descending order.
|
| 473 |
-
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 474 |
-
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 475 |
-
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 476 |
-
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 477 |
-
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 478 |
-
usually at the expense of lower image quality.
|
| 479 |
-
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 480 |
-
The number of images to generate per prompt.
|
| 481 |
-
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 482 |
-
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 483 |
-
to make generation deterministic.
|
| 484 |
-
latents (`torch.FloatTensor`, *optional*):
|
| 485 |
-
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 486 |
-
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 487 |
-
tensor will ge generated by sampling using the supplied random `generator`.
|
| 488 |
-
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 489 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 490 |
-
provided, text embeddings will be generated from `prompt` input argument.
|
| 491 |
-
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 492 |
-
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 493 |
-
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 494 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 495 |
-
The output format of the generate image. Choose between
|
| 496 |
-
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 497 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 498 |
-
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 499 |
-
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 500 |
-
Returns:
|
| 501 |
-
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 502 |
-
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 503 |
-
images.
|
| 504 |
-
"""
|
| 505 |
-
|
| 506 |
-
height = height or self.default_sample_size * self.vae_scale_factor
|
| 507 |
-
width = width or self.default_sample_size * self.vae_scale_factor
|
| 508 |
-
|
| 509 |
-
# 1. Check inputs. Raise error if not correct
|
| 510 |
-
self.check_inputs(
|
| 511 |
-
prompt,
|
| 512 |
-
prompt_2,
|
| 513 |
-
height,
|
| 514 |
-
width,
|
| 515 |
-
prompt_embeds=prompt_embeds,
|
| 516 |
-
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 517 |
-
max_sequence_length=max_sequence_length,
|
| 518 |
-
)
|
| 519 |
-
|
| 520 |
-
self._guidance_scale = guidance_scale
|
| 521 |
-
self._interrupt = False
|
| 522 |
-
|
| 523 |
-
# 2. Define call parameters
|
| 524 |
-
if prompt is not None and isinstance(prompt, str):
|
| 525 |
-
batch_size = 1
|
| 526 |
-
elif prompt is not None and isinstance(prompt, list):
|
| 527 |
-
batch_size = len(prompt)
|
| 528 |
-
else:
|
| 529 |
-
batch_size = prompt_embeds.shape[0]
|
| 530 |
-
|
| 531 |
-
(
|
| 532 |
-
prompt_embeds,
|
| 533 |
-
pooled_prompt_embeds,
|
| 534 |
-
text_ids,
|
| 535 |
-
) = self.encode_prompt(
|
| 536 |
-
prompt=prompt,
|
| 537 |
-
prompt_2=prompt_2,
|
| 538 |
-
prompt_embeds=prompt_embeds,
|
| 539 |
-
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 540 |
-
num_images_per_prompt=num_images_per_prompt,
|
| 541 |
-
max_sequence_length=max_sequence_length,
|
| 542 |
-
)
|
| 543 |
-
|
| 544 |
-
# 4. Prepare latent variables
|
| 545 |
-
num_channels_latents = self.transformer_config.get("in_channels", 64) // 4
|
| 546 |
-
latents, latent_image_ids = self.prepare_latents(
|
| 547 |
-
batch_size * num_images_per_prompt,
|
| 548 |
-
num_channels_latents,
|
| 549 |
-
height,
|
| 550 |
-
width,
|
| 551 |
-
generator,
|
| 552 |
-
latents,
|
| 553 |
-
)
|
| 554 |
-
|
| 555 |
-
# 5. Prepare timesteps
|
| 556 |
-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 557 |
-
image_seq_len = latents.shape[1]
|
| 558 |
-
mu = calculate_shift(
|
| 559 |
-
image_seq_len,
|
| 560 |
-
self.scheduler.config.base_image_seq_len,
|
| 561 |
-
self.scheduler.config.max_image_seq_len,
|
| 562 |
-
self.scheduler.config.base_shift,
|
| 563 |
-
self.scheduler.config.max_shift,
|
| 564 |
-
)
|
| 565 |
-
timesteps, num_inference_steps = retrieve_timesteps(
|
| 566 |
-
scheduler=self.scheduler,
|
| 567 |
-
num_inference_steps=num_inference_steps,
|
| 568 |
-
timesteps=timesteps,
|
| 569 |
-
sigmas=sigmas,
|
| 570 |
-
mu=mu,
|
| 571 |
-
)
|
| 572 |
-
num_warmup_steps = max(
|
| 573 |
-
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
| 574 |
-
)
|
| 575 |
-
self._num_timesteps = len(timesteps)
|
| 576 |
-
|
| 577 |
-
# 6. Denoising loop
|
| 578 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 579 |
-
for i, t in enumerate(timesteps):
|
| 580 |
-
if self.interrupt:
|
| 581 |
-
continue
|
| 582 |
-
|
| 583 |
-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 584 |
-
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 585 |
-
|
| 586 |
-
# handle guidance
|
| 587 |
-
if self.transformer_config.get("guidance_embeds"):
|
| 588 |
-
guidance = torch.tensor([guidance_scale])
|
| 589 |
-
guidance = guidance.expand(latents.shape[0])
|
| 590 |
-
else:
|
| 591 |
-
guidance = None
|
| 592 |
-
|
| 593 |
-
transformer_input = {
|
| 594 |
-
"hidden_states": latents,
|
| 595 |
-
"timestep": timestep / 1000,
|
| 596 |
-
"pooled_projections": pooled_prompt_embeds,
|
| 597 |
-
"encoder_hidden_states": prompt_embeds,
|
| 598 |
-
"txt_ids": text_ids,
|
| 599 |
-
"img_ids": latent_image_ids,
|
| 600 |
-
}
|
| 601 |
-
if guidance is not None:
|
| 602 |
-
transformer_input["guidance"] = guidance
|
| 603 |
-
|
| 604 |
-
noise_pred = torch.from_numpy(self.transformer(transformer_input)[0])
|
| 605 |
-
|
| 606 |
-
latents = self.scheduler.step(
|
| 607 |
-
noise_pred, t, latents, return_dict=False
|
| 608 |
-
)[0]
|
| 609 |
-
|
| 610 |
-
# call the callback, if provided
|
| 611 |
-
if i == len(timesteps) - 1 or (
|
| 612 |
-
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 613 |
-
):
|
| 614 |
-
progress_bar.update()
|
| 615 |
-
|
| 616 |
-
if output_type == "latent":
|
| 617 |
-
image = latents
|
| 618 |
-
|
| 619 |
-
else:
|
| 620 |
-
latents = self._unpack_latents(
|
| 621 |
-
latents, height, width, self.vae_scale_factor
|
| 622 |
-
)
|
| 623 |
-
latents = latents / self.vae_config.get(
|
| 624 |
-
"scaling_factor"
|
| 625 |
-
) + self.vae_config.get("shift_factor")
|
| 626 |
-
image = self.vae(latents)[0]
|
| 627 |
-
image = self.image_processor.postprocess(
|
| 628 |
-
torch.from_numpy(image), output_type=output_type
|
| 629 |
-
)
|
| 630 |
-
|
| 631 |
-
if not return_dict:
|
| 632 |
-
return (image,)
|
| 633 |
-
|
| 634 |
-
return FluxPipelineOutput(images=image)
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
def init_pipeline(
|
| 638 |
-
model_dir,
|
| 639 |
-
models_dict: Dict[str, Any],
|
| 640 |
-
device: str,
|
| 641 |
-
use_taef1: bool = False,
|
| 642 |
-
):
|
| 643 |
-
pipeline_args = {}
|
| 644 |
-
|
| 645 |
-
print("OpenVINO FLUX Model compilation")
|
| 646 |
-
core = ov.Core()
|
| 647 |
-
for model_name, model_path in models_dict.items():
|
| 648 |
-
pipeline_args[model_name] = core.compile_model(model_path, device)
|
| 649 |
-
if model_name == "vae" and use_taef1:
|
| 650 |
-
print(f"✅ VAE(TAEF1) - Done!")
|
| 651 |
-
else:
|
| 652 |
-
print(f"✅ {model_name} - Done!")
|
| 653 |
-
|
| 654 |
-
transformer_path = models_dict["transformer"]
|
| 655 |
-
transformer_config_path = transformer_path.parent / "config.json"
|
| 656 |
-
with transformer_config_path.open("r") as f:
|
| 657 |
-
transformer_config = json.load(f)
|
| 658 |
-
vae_path = models_dict["vae"]
|
| 659 |
-
vae_config_path = vae_path.parent / "config.json"
|
| 660 |
-
with vae_config_path.open("r") as f:
|
| 661 |
-
vae_config = json.load(f)
|
| 662 |
-
|
| 663 |
-
pipeline_args["vae_config"] = vae_config
|
| 664 |
-
pipeline_args["transformer_config"] = transformer_config
|
| 665 |
-
|
| 666 |
-
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_dir / "scheduler")
|
| 667 |
-
|
| 668 |
-
tokenizer = AutoTokenizer.from_pretrained(model_dir / "tokenizer")
|
| 669 |
-
tokenizer_2 = AutoTokenizer.from_pretrained(model_dir / "tokenizer_2")
|
| 670 |
-
|
| 671 |
-
pipeline_args["scheduler"] = scheduler
|
| 672 |
-
pipeline_args["tokenizer"] = tokenizer
|
| 673 |
-
pipeline_args["tokenizer_2"] = tokenizer_2
|
| 674 |
-
ov_pipe = OVFluxPipeline(**pipeline_args)
|
| 675 |
-
return ov_pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
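
The deleted ovflux.py above packs latents into 2x2 patches before they are fed to the transformer (_pack_latents). Below is a standalone sketch of that shape bookkeeping; the batch, channel, and spatial sizes are assumptions chosen only to show the resulting sequence layout.

# Shape sketch of the 2x2 latent packing used by the Flux pipeline above.
import torch

B, C, H, W = 1, 16, 64, 64                     # assumed latent shape
latents = torch.randn(B, C, H, W)

packed = latents.view(B, C, H // 2, 2, W // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(B, (H // 2) * (W // 2), C * 4)
print(packed.shape)                            # torch.Size([1, 1024, 64])
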
src/backend/openvino/pipelines.py
DELETED
@@ -1,75 +0,0 @@
-from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO
-from backend.tiny_decoder import get_tiny_decoder_vae_model
-from typing import Any
-from backend.device import is_openvino_device
-from paths import get_base_folder_name
-
-if is_openvino_device():
-    from huggingface_hub import snapshot_download
-    from optimum.intel.openvino.modeling_diffusion import OVBaseModel
-
-    from optimum.intel.openvino.modeling_diffusion import (
-        OVStableDiffusionPipeline,
-        OVStableDiffusionImg2ImgPipeline,
-        OVStableDiffusionXLPipeline,
-        OVStableDiffusionXLImg2ImgPipeline,
-    )
-    from backend.openvino.custom_ov_model_vae_decoder import CustomOVModelVaeDecoder
-
-
-def ov_load_taesd(
-    pipeline: Any,
-    use_local_model: bool = False,
-):
-    taesd_dir = snapshot_download(
-        repo_id=get_tiny_decoder_vae_model(pipeline.__class__.__name__),
-        local_files_only=use_local_model,
-    )
-    pipeline.vae_decoder = CustomOVModelVaeDecoder(
-        model=OVBaseModel.load_model(f"{taesd_dir}/vae_decoder/openvino_model.xml"),
-        parent_model=pipeline,
-        model_dir=taesd_dir,
-    )
-
-
-def get_ov_text_to_image_pipeline(
-    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
-    use_local_model: bool = False,
-) -> Any:
-    if "xl" in get_base_folder_name(model_id).lower():
-        pipeline = OVStableDiffusionXLPipeline.from_pretrained(
-            model_id,
-            local_files_only=use_local_model,
-            ov_config={"CACHE_DIR": ""},
-            device=DEVICE.upper(),
-        )
-    else:
-        pipeline = OVStableDiffusionPipeline.from_pretrained(
-            model_id,
-            local_files_only=use_local_model,
-            ov_config={"CACHE_DIR": ""},
-            device=DEVICE.upper(),
-        )
-
-    return pipeline
-
-
-def get_ov_image_to_image_pipeline(
-    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
-    use_local_model: bool = False,
-) -> Any:
-    if "xl" in get_base_folder_name(model_id).lower():
-        pipeline = OVStableDiffusionXLImg2ImgPipeline.from_pretrained(
-            model_id,
-            local_files_only=use_local_model,
-            ov_config={"CACHE_DIR": ""},
-            device=DEVICE.upper(),
-        )
-    else:
-        pipeline = OVStableDiffusionImg2ImgPipeline.from_pretrained(
-            model_id,
-            local_files_only=use_local_model,
-            ov_config={"CACHE_DIR": ""},
-            device=DEVICE.upper(),
-        )
-    return pipeline
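
A minimal sketch of how the deleted factory functions were presumably wired together: build the text-to-image pipeline, then swap in the tiny TAESD VAE decoder via ov_load_taesd. The prompt, step count, and output handling are illustrative assumptions; the call style mirrors the optimum-intel diffusion pipelines imported above.

# Hypothetical usage of the deleted OpenVINO pipeline factories.
pipeline = get_ov_text_to_image_pipeline()      # defaults to LCM_DEFAULT_MODEL_OPENVINO
ov_load_taesd(pipeline)                         # replace the VAE decoder with TAESD
result = pipeline(
    prompt="a cabin in a snowy forest",
    num_inference_steps=4,
    guidance_scale=1.0,
)
result.images[0].save("openvino_lcm.png")
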
src/backend/openvino/stable_diffusion_engine.py
DELETED
|
@@ -1,1817 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright(C) 2022-2023 Intel Corporation
|
| 3 |
-
SPDX - License - Identifier: Apache - 2.0
|
| 4 |
-
|
| 5 |
-
"""
|
| 6 |
-
import inspect
|
| 7 |
-
from typing import Union, Optional, Any, List, Dict
|
| 8 |
-
import numpy as np
|
| 9 |
-
# openvino
|
| 10 |
-
from openvino.runtime import Core
|
| 11 |
-
# tokenizer
|
| 12 |
-
from transformers import CLIPTokenizer
|
| 13 |
-
import torch
|
| 14 |
-
import random
|
| 15 |
-
|
| 16 |
-
from diffusers import DiffusionPipeline
|
| 17 |
-
from diffusers.schedulers import (DDIMScheduler,
|
| 18 |
-
LMSDiscreteScheduler,
|
| 19 |
-
PNDMScheduler,
|
| 20 |
-
EulerDiscreteScheduler,
|
| 21 |
-
EulerAncestralDiscreteScheduler)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
from diffusers.image_processor import VaeImageProcessor
|
| 25 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
-
from diffusers.utils import PIL_INTERPOLATION
|
| 27 |
-
|
| 28 |
-
import cv2
|
| 29 |
-
import os
|
| 30 |
-
import sys
|
| 31 |
-
|
| 32 |
-
# for multithreading
|
| 33 |
-
import concurrent.futures
|
| 34 |
-
|
| 35 |
-
#For GIF
|
| 36 |
-
import PIL
|
| 37 |
-
from PIL import Image
|
| 38 |
-
import glob
|
| 39 |
-
import json
|
| 40 |
-
import time
|
| 41 |
-
|
| 42 |
-
def scale_fit_to_window(dst_width:int, dst_height:int, image_width:int, image_height:int):
|
| 43 |
-
"""
|
| 44 |
-
Preprocessing helper function for calculating image size for resize with peserving original aspect ratio
|
| 45 |
-
and fitting image to specific window size
|
| 46 |
-
|
| 47 |
-
Parameters:
|
| 48 |
-
dst_width (int): destination window width
|
| 49 |
-
dst_height (int): destination window height
|
| 50 |
-
image_width (int): source image width
|
| 51 |
-
image_height (int): source image height
|
| 52 |
-
Returns:
|
| 53 |
-
result_width (int): calculated width for resize
|
| 54 |
-
result_height (int): calculated height for resize
|
| 55 |
-
"""
|
| 56 |
-
im_scale = min(dst_height / image_height, dst_width / image_width)
|
| 57 |
-
return int(im_scale * image_width), int(im_scale * image_height)
|
| 58 |
-
|
| 59 |
-
def preprocess(image: PIL.Image.Image, ht=512, wt=512):
|
| 60 |
-
"""
|
| 61 |
-
Image preprocessing function. Takes image in PIL.Image format, resizes it to keep aspect ration and fits to model input window 512x512,
|
| 62 |
-
then converts it to np.ndarray and adds padding with zeros on right or bottom side of image (depends from aspect ratio), after that
|
| 63 |
-
converts data to float32 data type and change range of values from [0, 255] to [-1, 1], finally, converts data layout from planar NHWC to NCHW.
|
| 64 |
-
The function returns preprocessed input tensor and padding size, which can be used in postprocessing.
|
| 65 |
-
|
| 66 |
-
Parameters:
|
| 67 |
-
image (PIL.Image.Image): input image
|
| 68 |
-
Returns:
|
| 69 |
-
image (np.ndarray): preprocessed image tensor
|
| 70 |
-
meta (Dict): dictionary with preprocessing metadata info
|
| 71 |
-
"""
|
| 72 |
-
|
| 73 |
-
src_width, src_height = image.size
|
| 74 |
-
image = image.convert('RGB')
|
| 75 |
-
dst_width, dst_height = scale_fit_to_window(
|
| 76 |
-
wt, ht, src_width, src_height)
|
| 77 |
-
image = np.array(image.resize((dst_width, dst_height),
|
| 78 |
-
resample=PIL.Image.Resampling.LANCZOS))[None, :]
|
| 79 |
-
|
| 80 |
-
pad_width = wt - dst_width
|
| 81 |
-
pad_height = ht - dst_height
|
| 82 |
-
pad = ((0, 0), (0, pad_height), (0, pad_width), (0, 0))
|
| 83 |
-
image = np.pad(image, pad, mode="constant")
|
| 84 |
-
image = image.astype(np.float32) / 255.0
|
| 85 |
-
image = 2.0 * image - 1.0
|
| 86 |
-
image = image.transpose(0, 3, 1, 2)
|
| 87 |
-
|
| 88 |
-
return image, {"padding": pad, "src_width": src_width, "src_height": src_height}
|
| 89 |
-
|
| 90 |
-
def try_enable_npu_turbo(device, core):
|
| 91 |
-
import platform
|
| 92 |
-
if "windows" in platform.system().lower():
|
| 93 |
-
if "NPU" in device and "3720" not in core.get_property('NPU', 'DEVICE_ARCHITECTURE'):
|
| 94 |
-
try:
|
| 95 |
-
core.set_property(properties={'NPU_TURBO': 'YES'},device_name='NPU')
|
| 96 |
-
except:
|
| 97 |
-
print(f"Failed loading NPU_TURBO for device {device}. Skipping... ")
|
| 98 |
-
else:
|
| 99 |
-
print_npu_turbo_art()
|
| 100 |
-
else:
|
| 101 |
-
print(f"Skipping NPU_TURBO for device {device}")
|
| 102 |
-
elif "linux" in platform.system().lower():
|
| 103 |
-
if os.path.isfile('/sys/module/intel_vpu/parameters/test_mode'):
|
| 104 |
-
with open('/sys/module/intel_vpu/version', 'r') as f:
|
| 105 |
-
version = f.readline().split()[0]
|
| 106 |
-
if tuple(map(int, version.split('.'))) < tuple(map(int, '1.9.0'.split('.'))):
|
| 107 |
-
print(f"The driver intel_vpu-1.9.0 (or later) needs to be loaded for NPU Turbo (currently {version}). Skipping...")
|
| 108 |
-
else:
|
| 109 |
-
with open('/sys/module/intel_vpu/parameters/test_mode', 'r') as tm_file:
|
| 110 |
-
test_mode = int(tm_file.readline().split()[0])
|
| 111 |
-
if test_mode == 512:
|
| 112 |
-
print_npu_turbo_art()
|
| 113 |
-
else:
|
| 114 |
-
print("The driver >=intel_vpu-1.9.0 was must be loaded with "
|
| 115 |
-
"\"modprobe intel_vpu test_mode=512\" to enable NPU_TURBO "
|
| 116 |
-
f"(currently test_mode={test_mode}). Skipping...")
|
| 117 |
-
else:
|
| 118 |
-
print(f"The driver >=intel_vpu-1.9.0 must be loaded with \"modprobe intel_vpu test_mode=512\" to enable NPU_TURBO. Skipping...")
|
| 119 |
-
else:
|
| 120 |
-
print(f"This platform ({platform.system()}) does not support NPU Turbo")
|
| 121 |
-
|
| 122 |
-
def result(var):
|
| 123 |
-
return next(iter(var.values()))
|
| 124 |
-
|
| 125 |
-
class StableDiffusionEngineAdvanced(DiffusionPipeline):
|
| 126 |
-
def __init__(self, model="runwayml/stable-diffusion-v1-5",
|
| 127 |
-
tokenizer="openai/clip-vit-large-patch14",
|
| 128 |
-
device=["CPU", "CPU", "CPU", "CPU"]):
|
| 129 |
-
try:
|
| 130 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
|
| 131 |
-
except:
|
| 132 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
|
| 133 |
-
self.tokenizer.save_pretrained(model)
|
| 134 |
-
|
| 135 |
-
self.core = Core()
|
| 136 |
-
self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})
|
| 137 |
-
try_enable_npu_turbo(device, self.core)
|
| 138 |
-
|
| 139 |
-
print("Loading models... ")
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
|
| 144 |
-
futures = {
|
| 145 |
-
"unet_time_proj": executor.submit(self.core.compile_model, os.path.join(model, "unet_time_proj.xml"), device[0]),
|
| 146 |
-
"text": executor.submit(self.load_model, model, "text_encoder", device[0]),
|
| 147 |
-
"unet": executor.submit(self.load_model, model, "unet_int8", device[1]),
|
| 148 |
-
"unet_neg": executor.submit(self.load_model, model, "unet_int8", device[2]) if device[1] != device[2] else None,
|
| 149 |
-
"vae_decoder": executor.submit(self.load_model, model, "vae_decoder", device[3]),
|
| 150 |
-
"vae_encoder": executor.submit(self.load_model, model, "vae_encoder", device[3])
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
self.unet_time_proj = futures["unet_time_proj"].result()
|
| 154 |
-
self.text_encoder = futures["text"].result()
|
| 155 |
-
self.unet = futures["unet"].result()
|
| 156 |
-
self.unet_neg = futures["unet_neg"].result() if futures["unet_neg"] else self.unet
|
| 157 |
-
self.vae_decoder = futures["vae_decoder"].result()
|
| 158 |
-
self.vae_encoder = futures["vae_encoder"].result()
|
| 159 |
-
print("Text Device:", device[0])
|
| 160 |
-
print("unet Device:", device[1])
|
| 161 |
-
print("unet-neg Device:", device[2])
|
| 162 |
-
print("VAE Device:", device[3])
|
| 163 |
-
|
| 164 |
-
self._text_encoder_output = self.text_encoder.output(0)
|
| 165 |
-
self._vae_d_output = self.vae_decoder.output(0)
|
| 166 |
-
self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
|
| 167 |
-
|
| 168 |
-
self.set_dimensions()
|
| 169 |
-
self.infer_request_neg = self.unet_neg.create_infer_request()
|
| 170 |
-
self.infer_request = self.unet.create_infer_request()
|
| 171 |
-
self.infer_request_time_proj = self.unet_time_proj.create_infer_request()
|
| 172 |
-
self.time_proj_constants = np.load(os.path.join(model, "time_proj_constants.npy"))
|
| 173 |
-
|
| 174 |
-
def load_model(self, model, model_name, device):
|
| 175 |
-
if "NPU" in device:
|
| 176 |
-
with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
|
| 177 |
-
return self.core.import_model(f.read(), device)
|
| 178 |
-
return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
|
| 179 |
-
|
| 180 |
-
def set_dimensions(self):
|
| 181 |
-
latent_shape = self.unet.input("latent_model_input").shape
|
| 182 |
-
if latent_shape[1] == 4:
|
| 183 |
-
self.height = latent_shape[2] * 8
|
| 184 |
-
self.width = latent_shape[3] * 8
|
| 185 |
-
else:
|
| 186 |
-
self.height = latent_shape[1] * 8
|
| 187 |
-
self.width = latent_shape[2] * 8
|
| 188 |
-
|
| 189 |
-
def __call__(
|
| 190 |
-
self,
|
| 191 |
-
prompt,
|
| 192 |
-
init_image = None,
|
| 193 |
-
negative_prompt=None,
|
| 194 |
-
scheduler=None,
|
| 195 |
-
strength = 0.5,
|
| 196 |
-
num_inference_steps = 32,
|
| 197 |
-
guidance_scale = 7.5,
|
| 198 |
-
eta = 0.0,
|
| 199 |
-
create_gif = False,
|
| 200 |
-
model = None,
|
| 201 |
-
callback = None,
|
| 202 |
-
callback_userdata = None
|
| 203 |
-
):
|
| 204 |
-
|
| 205 |
-
# extract condition
|
| 206 |
-
text_input = self.tokenizer(
|
| 207 |
-
prompt,
|
| 208 |
-
padding="max_length",
|
| 209 |
-
max_length=self.tokenizer.model_max_length,
|
| 210 |
-
truncation=True,
|
| 211 |
-
return_tensors="np",
|
| 212 |
-
)
|
| 213 |
-
text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]
|
| 214 |
-
|
| 215 |
-
# do classifier free guidance
|
| 216 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
| 217 |
-
if do_classifier_free_guidance:
|
| 218 |
-
|
| 219 |
-
if negative_prompt is None:
|
| 220 |
-
uncond_tokens = [""]
|
| 221 |
-
elif isinstance(negative_prompt, str):
|
| 222 |
-
uncond_tokens = [negative_prompt]
|
| 223 |
-
else:
|
| 224 |
-
uncond_tokens = negative_prompt
|
| 225 |
-
|
| 226 |
-
tokens_uncond = self.tokenizer(
|
| 227 |
-
uncond_tokens,
|
| 228 |
-
padding="max_length",
|
| 229 |
-
max_length=self.tokenizer.model_max_length, #truncation=True,
|
| 230 |
-
return_tensors="np"
|
| 231 |
-
)
|
| 232 |
-
uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
|
| 233 |
-
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
| 234 |
-
|
| 235 |
-
# set timesteps
|
| 236 |
-
accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 237 |
-
extra_set_kwargs = {}
|
| 238 |
-
|
| 239 |
-
if accepts_offset:
|
| 240 |
-
extra_set_kwargs["offset"] = 1
|
| 241 |
-
|
| 242 |
-
scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
| 243 |
-
|
| 244 |
-
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
|
| 245 |
-
latent_timestep = timesteps[:1]
|
| 246 |
-
|
| 247 |
-
# get the initial random noise unless the user supplied it
|
| 248 |
-
latents, meta = self.prepare_latents(init_image, latent_timestep, scheduler)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 252 |
-
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 253 |
-
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 254 |
-
# and should be between [0, 1]
|
| 255 |
-
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
|
| 256 |
-
extra_step_kwargs = {}
|
| 257 |
-
if accepts_eta:
|
| 258 |
-
extra_step_kwargs["eta"] = eta
|
| 259 |
-
if create_gif:
|
| 260 |
-
frames = []
|
| 261 |
-
|
| 262 |
-
for i, t in enumerate(self.progress_bar(timesteps)):
|
| 263 |
-
if callback:
|
| 264 |
-
callback(i, callback_userdata)
|
| 265 |
-
|
| 266 |
-
# expand the latents if we are doing classifier free guidance
|
| 267 |
-
noise_pred = []
|
| 268 |
-
latent_model_input = latents
|
| 269 |
-
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
| 270 |
-
|
| 271 |
-
latent_model_input_neg = latent_model_input
|
| 272 |
-
if self.unet.input("latent_model_input").shape[1] != 4:
|
| 273 |
-
#print("In transpose")
|
| 274 |
-
try:
|
| 275 |
-
latent_model_input = latent_model_input.permute(0,2,3,1)
|
| 276 |
-
except:
|
| 277 |
-
latent_model_input = latent_model_input.transpose(0,2,3,1)
|
| 278 |
-
|
| 279 |
-
if self.unet_neg.input("latent_model_input").shape[1] != 4:
|
| 280 |
-
#print("In transpose")
|
| 281 |
-
try:
|
| 282 |
-
latent_model_input_neg = latent_model_input_neg.permute(0,2,3,1)
|
| 283 |
-
except:
|
| 284 |
-
latent_model_input_neg = latent_model_input_neg.transpose(0,2,3,1)
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
time_proj_constants_fp16 = np.float16(self.time_proj_constants)
|
| 288 |
-
t_scaled_fp16 = time_proj_constants_fp16 * np.float16(t)
|
| 289 |
-
cosine_t_fp16 = np.cos(t_scaled_fp16)
|
| 290 |
-
sine_t_fp16 = np.sin(t_scaled_fp16)
|
| 291 |
-
|
| 292 |
-
t_scaled = self.time_proj_constants * np.float32(t)
|
| 293 |
-
|
| 294 |
-
cosine_t = np.cos(t_scaled)
|
| 295 |
-
sine_t = np.sin(t_scaled)
|
| 296 |
-
|
| 297 |
-
time_proj_dict = {"sine_t" : np.float32(sine_t), "cosine_t" : np.float32(cosine_t)}
|
| 298 |
-
self.infer_request_time_proj.start_async(time_proj_dict)
|
| 299 |
-
self.infer_request_time_proj.wait()
|
| 300 |
-
time_proj = self.infer_request_time_proj.get_output_tensor(0).data.astype(np.float32)
|
| 301 |
-
|
| 302 |
-
input_tens_neg_dict = {"time_proj": np.float32(time_proj), "latent_model_input":latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0)}
|
| 303 |
-
input_tens_dict = {"time_proj": np.float32(time_proj), "latent_model_input":latent_model_input, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0)}
|
| 304 |
-
|
| 305 |
-
self.infer_request_neg.start_async(input_tens_neg_dict)
|
| 306 |
-
self.infer_request.start_async(input_tens_dict)
|
| 307 |
-
self.infer_request_neg.wait()
|
| 308 |
-
self.infer_request.wait()
|
| 309 |
-
|
| 310 |
-
noise_pred_neg = self.infer_request_neg.get_output_tensor(0)
|
| 311 |
-
noise_pred_pos = self.infer_request.get_output_tensor(0)
|
| 312 |
-
|
| 313 |
-
noise_pred.append(noise_pred_neg.data.astype(np.float32))
|
| 314 |
-
noise_pred.append(noise_pred_pos.data.astype(np.float32))
|
| 315 |
-
|
| 316 |
-
# perform guidance
|
| 317 |
-
if do_classifier_free_guidance:
|
| 318 |
-
noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
|
| 319 |
-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 320 |
-
|
| 321 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 322 |
-
latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
|
| 323 |
-
|
| 324 |
-
if create_gif:
|
| 325 |
-
frames.append(latents)
|
| 326 |
-
|
| 327 |
-
if callback:
|
| 328 |
-
callback(num_inference_steps, callback_userdata)
|
| 329 |
-
|
| 330 |
-
# scale and decode the image latents with vae
|
| 331 |
-
latents = 1 / 0.18215 * latents
|
| 332 |
-
|
| 333 |
-
start = time.time()
|
| 334 |
-
image = self.vae_decoder(latents)[self._vae_d_output]
|
| 335 |
-
print("Decoder ended:",time.time() - start)
|
| 336 |
-
|
| 337 |
-
image = self.postprocess_image(image, meta)
|
| 338 |
-
|
| 339 |
-
if create_gif:
|
| 340 |
-
gif_folder=os.path.join(model,"../../../gif")
|
| 341 |
-
print("gif_folder:",gif_folder)
|
| 342 |
-
if not os.path.exists(gif_folder):
|
| 343 |
-
os.makedirs(gif_folder)
|
| 344 |
-
for i in range(0,len(frames)):
|
| 345 |
-
image = self.vae_decoder(frames[i]*(1/0.18215))[self._vae_d_output]
|
| 346 |
-
image = self.postprocess_image(image, meta)
|
| 347 |
-
output = gif_folder + "/" + str(i).zfill(3) +".png"
|
| 348 |
-
cv2.imwrite(output, image)
|
| 349 |
-
with open(os.path.join(gif_folder, "prompt.json"), "w") as file:
|
| 350 |
-
json.dump({"prompt": prompt}, file)
|
| 351 |
-
frames_image = [Image.open(image) for image in glob.glob(f"{gif_folder}/*.png")]
|
| 352 |
-
frame_one = frames_image[0]
|
| 353 |
-
gif_file=os.path.join(gif_folder,"stable_diffusion.gif")
|
| 354 |
-
frame_one.save(gif_file, format="GIF", append_images=frames_image, save_all=True, duration=100, loop=0)
|
| 355 |
-
|
| 356 |
-
return image
|
| 357 |
-
|
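The guidance step in the loop above blends the negative- and positive-prompt noise predictions with the standard classifier-free guidance formula. A minimal standalone sketch of that combination (array shapes here are illustrative, not taken from this file):

```python
# Standalone sketch of the classifier-free guidance blend used in the loop above.
# Shapes are illustrative; the pipeline works on (1, 4, H/8, W/8) latents.
import numpy as np

def apply_cfg(noise_uncond: np.ndarray, noise_text: np.ndarray, guidance_scale: float) -> np.ndarray:
    # Push the prediction away from the unconditional estimate, toward the text-conditioned one.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

uncond = np.zeros((1, 4, 64, 64), dtype=np.float32)
text = np.ones((1, 4, 64, 64), dtype=np.float32)
assert np.allclose(apply_cfg(uncond, text, 1.0), text)  # scale 1.0 returns the text prediction
```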
| 358 |
-
def prepare_latents(self, image:PIL.Image.Image = None, latent_timestep:torch.Tensor = None, scheduler = LMSDiscreteScheduler):
|
| 359 |
-
"""
|
| 360 |
-
Function for getting initial latents for starting generation
|
| 361 |
-
|
| 362 |
-
Parameters:
|
| 363 |
-
image (PIL.Image.Image, *optional*, None):
|
| 364 |
-
Input image for generation; if not provided, random noise will be used as the starting point
|
| 365 |
-
latent_timestep (torch.Tensor, *optional*, None):
|
| 366 |
-
Initial timestep predicted by the scheduler, required for mixing the latent image with noise
|
| 367 |
-
Returns:
|
| 368 |
-
latents (np.ndarray):
|
| 369 |
-
Image encoded in latent space
|
| 370 |
-
"""
|
| 371 |
-
latents_shape = (1, 4, self.height // 8, self.width // 8)
|
| 372 |
-
|
| 373 |
-
noise = np.random.randn(*latents_shape).astype(np.float32)
|
| 374 |
-
if image is None:
|
| 375 |
-
##print("Image is NONE")
|
| 376 |
-
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
|
| 377 |
-
if isinstance(scheduler, LMSDiscreteScheduler):
|
| 378 |
-
|
| 379 |
-
noise = noise * scheduler.sigmas[0].numpy()
|
| 380 |
-
return noise, {}
|
| 381 |
-
elif isinstance(scheduler, EulerDiscreteScheduler) or isinstance(scheduler,EulerAncestralDiscreteScheduler):
|
| 382 |
-
|
| 383 |
-
noise = noise * scheduler.sigmas.max().numpy()
|
| 384 |
-
return noise, {}
|
| 385 |
-
else:
|
| 386 |
-
return noise, {}
|
| 387 |
-
input_image, meta = preprocess(image,self.height,self.width)
|
| 388 |
-
|
| 389 |
-
moments = self.vae_encoder(input_image)[self._vae_e_output]
|
| 390 |
-
|
| 391 |
-
mean, logvar = np.split(moments, 2, axis=1)
|
| 392 |
-
|
| 393 |
-
std = np.exp(logvar * 0.5)
|
| 394 |
-
latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
|
| 398 |
-
return latents, meta
|
| 399 |
-
|
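The image branch of prepare_latents above samples a latent from the VAE encoder output ("moments") and scales it by 0.18215. A self-contained sketch of that sampling step, with an assumed moments shape of (1, 8, 64, 64):

```python
# Reparameterization-style sampling of an image latent from VAE "moments"
# (concatenated mean and log-variance), as in prepare_latents above.
import numpy as np

def sample_latent(moments: np.ndarray, scaling_factor: float = 0.18215) -> np.ndarray:
    mean, logvar = np.split(moments, 2, axis=1)      # split channel dim: 8 -> 4 + 4
    std = np.exp(logvar * 0.5)
    return (mean + std * np.random.randn(*mean.shape).astype(np.float32)) * scaling_factor

latent = sample_latent(np.zeros((1, 8, 64, 64), dtype=np.float32))
print(latent.shape)  # (1, 4, 64, 64)
```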
| 400 |
-
def postprocess_image(self, image:np.ndarray, meta:Dict):
|
| 401 |
-
"""
|
| 402 |
-
Postprocessing for the decoded image. Takes the image generated by the VAE decoder, unpads it to the initial image size (if required),
|
| 403 |
-
normalizes it and converts it to the [0, 255] pixel range. Optionally converts it from np.ndarray to PIL.Image format
|
| 404 |
-
|
| 405 |
-
Parameters:
|
| 406 |
-
image (np.ndarray):
|
| 407 |
-
Generated image
|
| 408 |
-
meta (Dict):
|
| 409 |
-
Metadata obtained during the latent preparation step; can be empty
|
| 410 |
-
output_type (str, *optional*, pil):
|
| 411 |
-
Output format for result, can be pil or numpy
|
| 412 |
-
Returns:
|
| 413 |
-
image (List of np.ndarray or PIL.Image.Image):
|
| 414 |
-
Postprocessed images
|
| 415 |
-
|
| 416 |
-
if "src_height" in meta:
|
| 417 |
-
orig_height, orig_width = meta["src_height"], meta["src_width"]
|
| 418 |
-
image = [cv2.resize(img, (orig_width, orig_height))
|
| 419 |
-
for img in image]
|
| 420 |
-
|
| 421 |
-
return image
|
| 422 |
-
"""
|
| 423 |
-
if "padding" in meta:
|
| 424 |
-
pad = meta["padding"]
|
| 425 |
-
(_, end_h), (_, end_w) = pad[1:3]
|
| 426 |
-
h, w = image.shape[2:]
|
| 427 |
-
#print("image shape",image.shape[2:])
|
| 428 |
-
unpad_h = h - end_h
|
| 429 |
-
unpad_w = w - end_w
|
| 430 |
-
image = image[:, :, :unpad_h, :unpad_w]
|
| 431 |
-
image = np.clip(image / 2 + 0.5, 0, 1)
|
| 432 |
-
image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
if "src_height" in meta:
|
| 437 |
-
orig_height, orig_width = meta["src_height"], meta["src_width"]
|
| 438 |
-
image = cv2.resize(image, (orig_width, orig_height))
|
| 439 |
-
|
| 440 |
-
return image
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
def get_timesteps(self, num_inference_steps:int, strength:float, scheduler):
|
| 446 |
-
"""
|
| 447 |
-
Helper function for getting scheduler timesteps for generation
|
| 448 |
-
In case of image-to-image generation, it updates the number of steps according to strength
|
| 449 |
-
|
| 450 |
-
Parameters:
|
| 451 |
-
num_inference_steps (int):
|
| 452 |
-
number of inference steps for generation
|
| 453 |
-
strength (float):
|
| 454 |
-
value between 0.0 and 1.0 that controls the amount of noise that is added to the input image.
|
| 455 |
-
Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
|
| 456 |
-
"""
|
| 457 |
-
# get the original timestep using init_timestep
|
| 458 |
-
|
| 459 |
-
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 460 |
-
|
| 461 |
-
t_start = max(num_inference_steps - init_timestep, 0)
|
| 462 |
-
timesteps = scheduler.timesteps[t_start:]
|
| 463 |
-
|
| 464 |
-
return timesteps, num_inference_steps - t_start
|
| 465 |
-
|
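get_timesteps above trims the schedule so that higher strength keeps more denoising steps. A quick, standalone illustration of the arithmetic (the timestep values below are placeholders):

```python
# Standalone illustration of how strength selects a suffix of the timestep schedule.
def trim_schedule(num_inference_steps: int, strength: float, timesteps):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start:], num_inference_steps - t_start

kept, remaining = trim_schedule(32, 0.5, list(range(32)))
assert remaining == 16 and len(kept) == 16  # strength 0.5 keeps the last 16 of 32 steps
```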
| 466 |
-
class StableDiffusionEngine(DiffusionPipeline):
|
| 467 |
-
def __init__(
|
| 468 |
-
self,
|
| 469 |
-
model="bes-dev/stable-diffusion-v1-4-openvino",
|
| 470 |
-
tokenizer="openai/clip-vit-large-patch14",
|
| 471 |
-
device=["CPU","CPU","CPU","CPU"]):
|
| 472 |
-
|
| 473 |
-
self.core = Core()
|
| 474 |
-
self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})
|
| 475 |
-
|
| 476 |
-
self.batch_size = 2 if device[1] == device[2] and device[1] == "GPU" else 1
|
| 477 |
-
try_enable_npu_turbo(device, self.core)
|
| 478 |
-
|
| 479 |
-
try:
|
| 480 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
|
| 481 |
-
except Exception as e:
|
| 482 |
-
print("Local tokenizer not found. Attempting to download...")
|
| 483 |
-
self.tokenizer = self.download_tokenizer(tokenizer, model)
|
| 484 |
-
|
| 485 |
-
print("Loading models... ")
|
| 486 |
-
|
| 487 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
|
| 488 |
-
text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
|
| 489 |
-
vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[3])
|
| 490 |
-
vae_en_future = executor.submit(self.load_model, model, "vae_encoder", device[3])
|
| 491 |
-
|
| 492 |
-
if self.batch_size == 1:
|
| 493 |
-
if "int8" not in model:
|
| 494 |
-
unet_future = executor.submit(self.load_model, model, "unet_bs1", device[1])
|
| 495 |
-
unet_neg_future = executor.submit(self.load_model, model, "unet_bs1", device[2]) if device[1] != device[2] else None
|
| 496 |
-
else:
|
| 497 |
-
unet_future = executor.submit(self.load_model, model, "unet_int8a16", device[1])
|
| 498 |
-
unet_neg_future = executor.submit(self.load_model, model, "unet_int8a16", device[2]) if device[1] != device[2] else None
|
| 499 |
-
else:
|
| 500 |
-
unet_future = executor.submit(self.load_model, model, "unet", device[1])
|
| 501 |
-
unet_neg_future = None
|
| 502 |
-
|
| 503 |
-
self.unet = unet_future.result()
|
| 504 |
-
self.unet_neg = unet_neg_future.result() if unet_neg_future else self.unet
|
| 505 |
-
self.text_encoder = text_future.result()
|
| 506 |
-
self.vae_decoder = vae_de_future.result()
|
| 507 |
-
self.vae_encoder = vae_en_future.result()
|
| 508 |
-
print("Text Device:", device[0])
|
| 509 |
-
print("unet Device:", device[1])
|
| 510 |
-
print("unet-neg Device:", device[2])
|
| 511 |
-
print("VAE Device:", device[3])
|
| 512 |
-
|
| 513 |
-
self._text_encoder_output = self.text_encoder.output(0)
|
| 514 |
-
self._unet_output = self.unet.output(0)
|
| 515 |
-
self._vae_d_output = self.vae_decoder.output(0)
|
| 516 |
-
self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
|
| 517 |
-
|
| 518 |
-
self.unet_input_tensor_name = "sample" if 'sample' in self.unet.input(0).names else "latent_model_input"
|
| 519 |
-
|
| 520 |
-
if self.batch_size == 1:
|
| 521 |
-
self.infer_request = self.unet.create_infer_request()
|
| 522 |
-
self.infer_request_neg = self.unet_neg.create_infer_request()
|
| 523 |
-
self._unet_neg_output = self.unet_neg.output(0)
|
| 524 |
-
else:
|
| 525 |
-
self.infer_request = None
|
| 526 |
-
self.infer_request_neg = None
|
| 527 |
-
self._unet_neg_output = None
|
| 528 |
-
|
| 529 |
-
self.set_dimensions()
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
def load_model(self, model, model_name, device):
|
| 534 |
-
if "NPU" in device:
|
| 535 |
-
with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
|
| 536 |
-
return self.core.import_model(f.read(), device)
|
| 537 |
-
return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
|
| 538 |
-
|
| 539 |
-
def set_dimensions(self):
|
| 540 |
-
latent_shape = self.unet.input(self.unet_input_tensor_name).shape
|
| 541 |
-
if latent_shape[1] == 4:
|
| 542 |
-
self.height = latent_shape[2] * 8
|
| 543 |
-
self.width = latent_shape[3] * 8
|
| 544 |
-
else:
|
| 545 |
-
self.height = latent_shape[1] * 8
|
| 546 |
-
self.width = latent_shape[2] * 8
|
| 547 |
-
|
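set_dimensions derives the output resolution from the compiled UNet's static latent shape, using the VAE scale factor of 8 and handling both channel-first and channel-last layouts. A small sketch of that mapping (the shapes are examples only):

```python
# How a static latent input shape maps to pixel dimensions (VAE scale factor 8),
# covering both layouts handled by set_dimensions above.
def infer_hw(latent_shape):
    if latent_shape[1] == 4:                          # NCHW: (1, 4, H/8, W/8)
        return latent_shape[2] * 8, latent_shape[3] * 8
    return latent_shape[1] * 8, latent_shape[2] * 8   # NHWC: (1, H/8, W/8, 4)

assert infer_hw((1, 4, 64, 64)) == (512, 512)
assert infer_hw((1, 64, 96, 4)) == (512, 768)
```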
| 548 |
-
def __call__(
|
| 549 |
-
self,
|
| 550 |
-
prompt,
|
| 551 |
-
init_image=None,
|
| 552 |
-
negative_prompt=None,
|
| 553 |
-
scheduler=None,
|
| 554 |
-
strength=0.5,
|
| 555 |
-
num_inference_steps=32,
|
| 556 |
-
guidance_scale=7.5,
|
| 557 |
-
eta=0.0,
|
| 558 |
-
create_gif=False,
|
| 559 |
-
model=None,
|
| 560 |
-
callback=None,
|
| 561 |
-
callback_userdata=None
|
| 562 |
-
):
|
| 563 |
-
# extract condition
|
| 564 |
-
text_input = self.tokenizer(
|
| 565 |
-
prompt,
|
| 566 |
-
padding="max_length",
|
| 567 |
-
max_length=self.tokenizer.model_max_length,
|
| 568 |
-
truncation=True,
|
| 569 |
-
return_tensors="np",
|
| 570 |
-
)
|
| 571 |
-
text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
# do classifier free guidance
|
| 575 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
| 576 |
-
if do_classifier_free_guidance:
|
| 577 |
-
if negative_prompt is None:
|
| 578 |
-
uncond_tokens = [""]
|
| 579 |
-
elif isinstance(negative_prompt, str):
|
| 580 |
-
uncond_tokens = [negative_prompt]
|
| 581 |
-
else:
|
| 582 |
-
uncond_tokens = negative_prompt
|
| 583 |
-
|
| 584 |
-
tokens_uncond = self.tokenizer(
|
| 585 |
-
uncond_tokens,
|
| 586 |
-
padding="max_length",
|
| 587 |
-
max_length=self.tokenizer.model_max_length, # truncation=True,
|
| 588 |
-
return_tensors="np"
|
| 589 |
-
)
|
| 590 |
-
uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
|
| 591 |
-
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
| 592 |
-
|
| 593 |
-
# set timesteps
|
| 594 |
-
accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 595 |
-
extra_set_kwargs = {}
|
| 596 |
-
|
| 597 |
-
if accepts_offset:
|
| 598 |
-
extra_set_kwargs["offset"] = 1
|
| 599 |
-
|
| 600 |
-
scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
| 601 |
-
|
| 602 |
-
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
|
| 603 |
-
latent_timestep = timesteps[:1]
|
| 604 |
-
|
| 605 |
-
# get the initial random noise unless the user supplied it
|
| 606 |
-
latents, meta = self.prepare_latents(init_image, latent_timestep, scheduler,model)
|
| 607 |
-
|
| 608 |
-
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 609 |
-
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 610 |
-
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 611 |
-
# and should be between [0, 1]
|
| 612 |
-
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
|
| 613 |
-
extra_step_kwargs = {}
|
| 614 |
-
if accepts_eta:
|
| 615 |
-
extra_step_kwargs["eta"] = eta
|
| 616 |
-
if create_gif:
|
| 617 |
-
frames = []
|
| 618 |
-
|
| 619 |
-
for i, t in enumerate(self.progress_bar(timesteps)):
|
| 620 |
-
if callback:
|
| 621 |
-
callback(i, callback_userdata)
|
| 622 |
-
|
| 623 |
-
if self.batch_size == 1:
|
| 624 |
-
# expand the latents if we are doing classifier free guidance
|
| 625 |
-
noise_pred = []
|
| 626 |
-
latent_model_input = latents
|
| 627 |
-
|
| 628 |
-
#Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
| 629 |
-
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
| 630 |
-
latent_model_input_pos = latent_model_input
|
| 631 |
-
latent_model_input_neg = latent_model_input
|
| 632 |
-
|
| 633 |
-
if self.unet.input(self.unet_input_tensor_name).shape[1] != 4:
|
| 634 |
-
try:
|
| 635 |
-
latent_model_input_pos = latent_model_input_pos.permute(0,2,3,1)
|
| 636 |
-
except:
|
| 637 |
-
latent_model_input_pos = latent_model_input_pos.transpose(0,2,3,1)
|
| 638 |
-
|
| 639 |
-
if self.unet_neg.input(self.unet_input_tensor_name).shape[1] != 4:
|
| 640 |
-
try:
|
| 641 |
-
latent_model_input_neg = latent_model_input_neg.permute(0,2,3,1)
|
| 642 |
-
except:
|
| 643 |
-
latent_model_input_neg = latent_model_input_neg.transpose(0,2,3,1)
|
| 644 |
-
|
| 645 |
-
if "sample" in self.unet_input_tensor_name:
|
| 646 |
-
input_tens_neg_dict = {"sample" : latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0), "timestep": np.expand_dims(np.float32(t), axis=0)}
|
| 647 |
-
input_tens_pos_dict = {"sample" : latent_model_input_pos, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0), "timestep": np.expand_dims(np.float32(t), axis=0)}
|
| 648 |
-
else:
|
| 649 |
-
input_tens_neg_dict = {"latent_model_input" : latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0), "t": np.expand_dims(np.float32(t), axis=0)}
|
| 650 |
-
input_tens_pos_dict = {"latent_model_input" : latent_model_input_pos, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0), "t": np.expand_dims(np.float32(t), axis=0)}
|
| 651 |
-
|
| 652 |
-
self.infer_request_neg.start_async(input_tens_neg_dict)
|
| 653 |
-
self.infer_request.start_async(input_tens_pos_dict)
|
| 654 |
-
|
| 655 |
-
self.infer_request_neg.wait()
|
| 656 |
-
self.infer_request.wait()
|
| 657 |
-
|
| 658 |
-
noise_pred_neg = self.infer_request_neg.get_output_tensor(0)
|
| 659 |
-
noise_pred_pos = self.infer_request.get_output_tensor(0)
|
| 660 |
-
|
| 661 |
-
noise_pred.append(noise_pred_neg.data.astype(np.float32))
|
| 662 |
-
noise_pred.append(noise_pred_pos.data.astype(np.float32))
|
| 663 |
-
else:
|
| 664 |
-
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
| 665 |
-
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
| 666 |
-
noise_pred = self.unet([latent_model_input, np.array(t, dtype=np.float32), text_embeddings])[self._unet_output]
|
| 667 |
-
|
| 668 |
-
if do_classifier_free_guidance:
|
| 669 |
-
noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
|
| 670 |
-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 671 |
-
|
| 672 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 673 |
-
latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
|
| 674 |
-
|
| 675 |
-
if create_gif:
|
| 676 |
-
frames.append(latents)
|
| 677 |
-
|
| 678 |
-
if callback:
|
| 679 |
-
callback(num_inference_steps, callback_userdata)
|
| 680 |
-
|
| 681 |
-
# scale and decode the image latents with vae
|
| 682 |
-
#if self.height == 512 and self.width == 512:
|
| 683 |
-
latents = 1 / 0.18215 * latents
|
| 684 |
-
image = self.vae_decoder(latents)[self._vae_d_output]
|
| 685 |
-
image = self.postprocess_image(image, meta)
|
| 686 |
-
|
| 687 |
-
return image
|
| 688 |
-
|
| 689 |
-
def prepare_latents(self, image: PIL.Image.Image = None, latent_timestep: torch.Tensor = None,
|
| 690 |
-
scheduler=LMSDiscreteScheduler,model=None):
|
| 691 |
-
"""
|
| 692 |
-
Function for getting initial latents for starting generation
|
| 693 |
-
|
| 694 |
-
Parameters:
|
| 695 |
-
image (PIL.Image.Image, *optional*, None):
|
| 696 |
-
Input image for generation; if not provided, random noise will be used as the starting point
|
| 697 |
-
latent_timestep (torch.Tensor, *optional*, None):
|
| 698 |
-
Initial timestep predicted by the scheduler, required for mixing the latent image with noise
|
| 699 |
-
Returns:
|
| 700 |
-
latents (np.ndarray):
|
| 701 |
-
Image encoded in latent space
|
| 702 |
-
"""
|
| 703 |
-
latents_shape = (1, 4, self.height // 8, self.width // 8)
|
| 704 |
-
|
| 705 |
-
noise = np.random.randn(*latents_shape).astype(np.float32)
|
| 706 |
-
if image is None:
|
| 707 |
-
#print("Image is NONE")
|
| 708 |
-
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
|
| 709 |
-
if isinstance(scheduler, LMSDiscreteScheduler):
|
| 710 |
-
|
| 711 |
-
noise = noise * scheduler.sigmas[0].numpy()
|
| 712 |
-
return noise, {}
|
| 713 |
-
elif isinstance(scheduler, EulerDiscreteScheduler):
|
| 714 |
-
|
| 715 |
-
noise = noise * scheduler.sigmas.max().numpy()
|
| 716 |
-
return noise, {}
|
| 717 |
-
else:
|
| 718 |
-
return noise, {}
|
| 719 |
-
input_image, meta = preprocess(image, self.height, self.width)
|
| 720 |
-
|
| 721 |
-
moments = self.vae_encoder(input_image)[self._vae_e_output]
|
| 722 |
-
|
| 723 |
-
if "sd_2.1" in model:
|
| 724 |
-
latents = moments * 0.18215
|
| 725 |
-
|
| 726 |
-
else:
|
| 727 |
-
|
| 728 |
-
mean, logvar = np.split(moments, 2, axis=1)
|
| 729 |
-
|
| 730 |
-
std = np.exp(logvar * 0.5)
|
| 731 |
-
latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
|
| 732 |
-
|
| 733 |
-
latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
|
| 734 |
-
return latents, meta
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
def postprocess_image(self, image: np.ndarray, meta: Dict):
|
| 738 |
-
"""
|
| 739 |
-
Postprocessing for the decoded image. Takes the image generated by the VAE decoder, unpads it to the initial image size (if required),
|
| 740 |
-
normalizes it and converts it to the [0, 255] pixel range. Optionally converts it from np.ndarray to PIL.Image format
|
| 741 |
-
|
| 742 |
-
Parameters:
|
| 743 |
-
image (np.ndarray):
|
| 744 |
-
Generated image
|
| 745 |
-
meta (Dict):
|
| 746 |
-
Metadata obtained during the latent preparation step; can be empty
|
| 747 |
-
output_type (str, *optional*, pil):
|
| 748 |
-
Output format for result, can be pil or numpy
|
| 749 |
-
Returns:
|
| 750 |
-
image (List of np.ndarray or PIL.Image.Image):
|
| 751 |
-
Postprocessed images
|
| 752 |
-
|
| 753 |
-
if "src_height" in meta:
|
| 754 |
-
orig_height, orig_width = meta["src_height"], meta["src_width"]
|
| 755 |
-
image = [cv2.resize(img, (orig_width, orig_height))
|
| 756 |
-
for img in image]
|
| 757 |
-
|
| 758 |
-
return image
|
| 759 |
-
"""
|
| 760 |
-
if "padding" in meta:
|
| 761 |
-
pad = meta["padding"]
|
| 762 |
-
(_, end_h), (_, end_w) = pad[1:3]
|
| 763 |
-
h, w = image.shape[2:]
|
| 764 |
-
# print("image shape",image.shape[2:])
|
| 765 |
-
unpad_h = h - end_h
|
| 766 |
-
unpad_w = w - end_w
|
| 767 |
-
image = image[:, :, :unpad_h, :unpad_w]
|
| 768 |
-
image = np.clip(image / 2 + 0.5, 0, 1)
|
| 769 |
-
image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
|
| 770 |
-
|
| 771 |
-
if "src_height" in meta:
|
| 772 |
-
orig_height, orig_width = meta["src_height"], meta["src_width"]
|
| 773 |
-
image = cv2.resize(image, (orig_width, orig_height))
|
| 774 |
-
|
| 775 |
-
return image
|
| 776 |
-
|
| 777 |
-
# image = (image / 2 + 0.5).clip(0, 1)
|
| 778 |
-
# image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
|
| 779 |
-
|
| 780 |
-
def get_timesteps(self, num_inference_steps: int, strength: float, scheduler):
|
| 781 |
-
"""
|
| 782 |
-
Helper function for getting scheduler timesteps for generation
|
| 783 |
-
In case of image-to-image generation, it updates the number of steps according to strength
|
| 784 |
-
|
| 785 |
-
Parameters:
|
| 786 |
-
num_inference_steps (int):
|
| 787 |
-
number of inference steps for generation
|
| 788 |
-
strength (float):
|
| 789 |
-
value between 0.0 and 1.0 that controls the amount of noise that is added to the input image.
|
| 790 |
-
Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
|
| 791 |
-
"""
|
| 792 |
-
# get the original timestep using init_timestep
|
| 793 |
-
|
| 794 |
-
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 795 |
-
|
| 796 |
-
t_start = max(num_inference_steps - init_timestep, 0)
|
| 797 |
-
timesteps = scheduler.timesteps[t_start:]
|
| 798 |
-
|
| 799 |
-
return timesteps, num_inference_steps - t_start
|
| 800 |
-
|
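A hedged usage sketch for the class above; the model directory, device assignment, and scheduler configuration are assumptions rather than values taken from this repository:

```python
# Hypothetical driver for StableDiffusionEngine; paths and devices are placeholders.
import cv2
from diffusers import EulerDiscreteScheduler

engine = StableDiffusionEngine(
    model="models/sd_1.5_square",                  # assumed OpenVINO IR directory
    device=["CPU", "GPU", "GPU", "CPU"],           # text encoder, unet, unet-neg, vae
)
scheduler = EulerDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
bgr_image = engine(
    prompt="a watercolor painting of a fox",
    negative_prompt="blurry, low quality",
    scheduler=scheduler,
    strength=1.0,                                  # use the full schedule for text-to-image
    num_inference_steps=25,
    guidance_scale=7.5,
    model="models/sd_1.5_square",
)
cv2.imwrite("result.png", bgr_image)               # postprocess_image returns a BGR uint8 array
```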
| 801 |
-
class LatentConsistencyEngine(DiffusionPipeline):
|
| 802 |
-
def __init__(
|
| 803 |
-
self,
|
| 804 |
-
model="SimianLuo/LCM_Dreamshaper_v7",
|
| 805 |
-
tokenizer="openai/clip-vit-large-patch14",
|
| 806 |
-
device=["CPU", "CPU", "CPU"],
|
| 807 |
-
):
|
| 808 |
-
super().__init__()
|
| 809 |
-
try:
|
| 810 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
|
| 811 |
-
except:
|
| 812 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
|
| 813 |
-
self.tokenizer.save_pretrained(model)
|
| 814 |
-
|
| 815 |
-
self.core = Core()
|
| 816 |
-
self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')}) # adding caching to reduce init time
|
| 817 |
-
try_enable_npu_turbo(device, self.core)
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
|
| 821 |
-
text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
|
| 822 |
-
unet_future = executor.submit(self.load_model, model, "unet", device[1])
|
| 823 |
-
vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[2])
|
| 824 |
-
|
| 825 |
-
print("Text Device:", device[0])
|
| 826 |
-
self.text_encoder = text_future.result()
|
| 827 |
-
self._text_encoder_output = self.text_encoder.output(0)
|
| 828 |
-
|
| 829 |
-
print("Unet Device:", device[1])
|
| 830 |
-
self.unet = unet_future.result()
|
| 831 |
-
self._unet_output = self.unet.output(0)
|
| 832 |
-
self.infer_request = self.unet.create_infer_request()
|
| 833 |
-
|
| 834 |
-
print(f"VAE Device: {device[2]}")
|
| 835 |
-
self.vae_decoder = vae_de_future.result()
|
| 836 |
-
self.infer_request_vae = self.vae_decoder.create_infer_request()
|
| 837 |
-
self.safety_checker = None #pipe.safety_checker
|
| 838 |
-
self.feature_extractor = None #pipe.feature_extractor
|
| 839 |
-
self.vae_scale_factor = 2 ** 3
|
| 840 |
-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 841 |
-
|
| 842 |
-
def load_model(self, model, model_name, device):
|
| 843 |
-
if "NPU" in device:
|
| 844 |
-
with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
|
| 845 |
-
return self.core.import_model(f.read(), device)
|
| 846 |
-
return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
|
| 847 |
-
|
| 848 |
-
def _encode_prompt(
|
| 849 |
-
self,
|
| 850 |
-
prompt,
|
| 851 |
-
num_images_per_prompt,
|
| 852 |
-
prompt_embeds: None,
|
| 853 |
-
):
|
| 854 |
-
r"""
|
| 855 |
-
Encodes the prompt into text encoder hidden states.
|
| 856 |
-
Args:
|
| 857 |
-
prompt (`str` or `List[str]`, *optional*):
|
| 858 |
-
prompt to be encoded
|
| 859 |
-
num_images_per_prompt (`int`):
|
| 860 |
-
number of images that should be generated per prompt
|
| 861 |
-
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 862 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 863 |
-
provided, text embeddings will be generated from `prompt` input argument.
|
| 864 |
-
"""
|
| 865 |
-
|
| 866 |
-
if prompt_embeds is None:
|
| 867 |
-
|
| 868 |
-
text_inputs = self.tokenizer(
|
| 869 |
-
prompt,
|
| 870 |
-
padding="max_length",
|
| 871 |
-
max_length=self.tokenizer.model_max_length,
|
| 872 |
-
truncation=True,
|
| 873 |
-
return_tensors="pt",
|
| 874 |
-
)
|
| 875 |
-
text_input_ids = text_inputs.input_ids
|
| 876 |
-
untruncated_ids = self.tokenizer(
|
| 877 |
-
prompt, padding="longest", return_tensors="pt"
|
| 878 |
-
).input_ids
|
| 879 |
-
|
| 880 |
-
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
| 881 |
-
-1
|
| 882 |
-
] and not torch.equal(text_input_ids, untruncated_ids):
|
| 883 |
-
removed_text = self.tokenizer.batch_decode(
|
| 884 |
-
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 885 |
-
)
|
| 886 |
-
logger.warning(
|
| 887 |
-
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 888 |
-
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 889 |
-
)
|
| 890 |
-
|
| 891 |
-
prompt_embeds = self.text_encoder(text_input_ids, share_inputs=True, share_outputs=True)
|
| 892 |
-
prompt_embeds = torch.from_numpy(prompt_embeds[0])
|
| 893 |
-
|
| 894 |
-
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 895 |
-
# duplicate text embeddings for each generation per prompt
|
| 896 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 897 |
-
prompt_embeds = prompt_embeds.view(
|
| 898 |
-
bs_embed * num_images_per_prompt, seq_len, -1
|
| 899 |
-
)
|
| 900 |
-
|
| 901 |
-
# Don't need to get uncond prompt embedding because of LCM Guided Distillation
|
| 902 |
-
return prompt_embeds
|
| 903 |
-
|
| 904 |
-
def run_safety_checker(self, image, dtype):
|
| 905 |
-
if self.safety_checker is None:
|
| 906 |
-
has_nsfw_concept = None
|
| 907 |
-
else:
|
| 908 |
-
if torch.is_tensor(image):
|
| 909 |
-
feature_extractor_input = self.image_processor.postprocess(
|
| 910 |
-
image, output_type="pil"
|
| 911 |
-
)
|
| 912 |
-
else:
|
| 913 |
-
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| 914 |
-
safety_checker_input = self.feature_extractor(
|
| 915 |
-
feature_extractor_input, return_tensors="pt"
|
| 916 |
-
)
|
| 917 |
-
image, has_nsfw_concept = self.safety_checker(
|
| 918 |
-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
| 919 |
-
)
|
| 920 |
-
return image, has_nsfw_concept
|
| 921 |
-
|
| 922 |
-
def prepare_latents(
|
| 923 |
-
self, batch_size, num_channels_latents, height, width, dtype, latents=None
|
| 924 |
-
):
|
| 925 |
-
shape = (
|
| 926 |
-
batch_size,
|
| 927 |
-
num_channels_latents,
|
| 928 |
-
height // self.vae_scale_factor,
|
| 929 |
-
width // self.vae_scale_factor,
|
| 930 |
-
)
|
| 931 |
-
if latents is None:
|
| 932 |
-
latents = torch.randn(shape, dtype=dtype)
|
| 933 |
-
# scale the initial noise by the standard deviation required by the scheduler
|
| 934 |
-
return latents
|
| 935 |
-
|
| 936 |
-
def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
| 937 |
-
"""
|
| 938 |
-
see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 939 |
-
Args:
|
| 940 |
-
w: torch.Tensor: guidance-scale values to generate embedding vectors for
|
| 941 |
-
embedding_dim: int: dimension of the embeddings to generate
|
| 942 |
-
dtype: data type of the generated embeddings
|
| 943 |
-
Returns:
|
| 944 |
-
embedding vectors with shape `(len(timesteps), embedding_dim)`
|
| 945 |
-
"""
|
| 946 |
-
assert len(w.shape) == 1
|
| 947 |
-
w = w * 1000.0
|
| 948 |
-
|
| 949 |
-
half_dim = embedding_dim // 2
|
| 950 |
-
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 951 |
-
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 952 |
-
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 953 |
-
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 954 |
-
if embedding_dim % 2 == 1: # zero pad
|
| 955 |
-
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 956 |
-
assert emb.shape == (w.shape[0], embedding_dim)
|
| 957 |
-
return emb
|
| 958 |
-
|
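get_w_embedding above is a sinusoidal embedding of the guidance scale (multiplied by 1000), following the referenced VDM code. A compact standalone version that reproduces the same shapes:

```python
# Sinusoidal guidance-scale embedding, mirroring get_w_embedding above.
import torch

def w_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float32)
                     * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1)))
    angles = w.to(torch.float32)[:, None] * freq[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

emb = w_embedding(torch.tensor([8.0]))   # one guidance scale -> one embedding row
print(emb.shape)                         # torch.Size([1, 256])
```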
| 959 |
-
@torch.no_grad()
|
| 960 |
-
def __call__(
|
| 961 |
-
self,
|
| 962 |
-
prompt: Union[str, List[str]] = None,
|
| 963 |
-
height: Optional[int] = 512,
|
| 964 |
-
width: Optional[int] = 512,
|
| 965 |
-
guidance_scale: float = 7.5,
|
| 966 |
-
scheduler = None,
|
| 967 |
-
num_images_per_prompt: Optional[int] = 1,
|
| 968 |
-
latents: Optional[torch.FloatTensor] = None,
|
| 969 |
-
num_inference_steps: int = 4,
|
| 970 |
-
lcm_origin_steps: int = 50,
|
| 971 |
-
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 972 |
-
output_type: Optional[str] = "pil",
|
| 973 |
-
return_dict: bool = True,
|
| 974 |
-
model: Optional[Dict[str, any]] = None,
|
| 975 |
-
seed: Optional[int] = 1234567,
|
| 976 |
-
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 977 |
-
callback = None,
|
| 978 |
-
callback_userdata = None
|
| 979 |
-
):
|
| 980 |
-
|
| 981 |
-
# 1. Define call parameters
|
| 982 |
-
if prompt is not None and isinstance(prompt, str):
|
| 983 |
-
batch_size = 1
|
| 984 |
-
elif prompt is not None and isinstance(prompt, list):
|
| 985 |
-
batch_size = len(prompt)
|
| 986 |
-
else:
|
| 987 |
-
batch_size = prompt_embeds.shape[0]
|
| 988 |
-
|
| 989 |
-
if seed is not None:
|
| 990 |
-
torch.manual_seed(seed)
|
| 991 |
-
|
| 992 |
-
#print("After Step 1: batch size is ", batch_size)
|
| 993 |
-
# do_classifier_free_guidance = guidance_scale > 0.0
|
| 994 |
-
# In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
|
| 995 |
-
|
| 996 |
-
# 2. Encode input prompt
|
| 997 |
-
prompt_embeds = self._encode_prompt(
|
| 998 |
-
prompt,
|
| 999 |
-
num_images_per_prompt,
|
| 1000 |
-
prompt_embeds=prompt_embeds,
|
| 1001 |
-
)
|
| 1002 |
-
#print("After Step 2: prompt embeds is ", prompt_embeds)
|
| 1003 |
-
#print("After Step 2: scheduler is ", scheduler )
|
| 1004 |
-
# 3. Prepare timesteps
|
| 1005 |
-
scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
|
| 1006 |
-
timesteps = scheduler.timesteps
|
| 1007 |
-
|
| 1008 |
-
#print("After Step 3: timesteps is ", timesteps)
|
| 1009 |
-
|
| 1010 |
-
# 4. Prepare latent variable
|
| 1011 |
-
num_channels_latents = 4
|
| 1012 |
-
latents = self.prepare_latents(
|
| 1013 |
-
batch_size * num_images_per_prompt,
|
| 1014 |
-
num_channels_latents,
|
| 1015 |
-
height,
|
| 1016 |
-
width,
|
| 1017 |
-
prompt_embeds.dtype,
|
| 1018 |
-
latents,
|
| 1019 |
-
)
|
| 1020 |
-
latents = latents * scheduler.init_noise_sigma
|
| 1021 |
-
|
| 1022 |
-
#print("After Step 4: ")
|
| 1023 |
-
bs = batch_size * num_images_per_prompt
|
| 1024 |
-
|
| 1025 |
-
# 5. Get Guidance Scale Embedding
|
| 1026 |
-
w = torch.tensor(guidance_scale).repeat(bs)
|
| 1027 |
-
w_embedding = self.get_w_embedding(w, embedding_dim=256)
|
| 1028 |
-
#print("After Step 5: ")
|
| 1029 |
-
# 6. LCM MultiStep Sampling Loop:
|
| 1030 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1031 |
-
for i, t in enumerate(timesteps):
|
| 1032 |
-
if callback:
|
| 1033 |
-
callback(i+1, callback_userdata)
|
| 1034 |
-
|
| 1035 |
-
ts = torch.full((bs,), t, dtype=torch.long)
|
| 1036 |
-
|
| 1037 |
-
# model prediction (v-prediction, eps, x)
|
| 1038 |
-
model_pred = self.unet([latents, ts, prompt_embeds, w_embedding],share_inputs=True, share_outputs=True)[0]
|
| 1039 |
-
|
| 1040 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 1041 |
-
latents, denoised = scheduler.step(
|
| 1042 |
-
torch.from_numpy(model_pred), t, latents, return_dict=False
|
| 1043 |
-
)
|
| 1044 |
-
progress_bar.update()
|
| 1045 |
-
|
| 1046 |
-
#print("After Step 6: ")
|
| 1047 |
-
|
| 1048 |
-
vae_start = time.time()
|
| 1049 |
-
|
| 1050 |
-
if not output_type == "latent":
|
| 1051 |
-
image = torch.from_numpy(self.vae_decoder(denoised / 0.18215, share_inputs=True, share_outputs=True)[0])
|
| 1052 |
-
else:
|
| 1053 |
-
image = denoised
|
| 1054 |
-
|
| 1055 |
-
print("Decoder Ended: ", time.time() - vae_start)
|
| 1056 |
-
#post_start = time.time()
|
| 1057 |
-
|
| 1058 |
-
#if has_nsfw_concept is None:
|
| 1059 |
-
do_denormalize = [True] * image.shape[0]
|
| 1060 |
-
#else:
|
| 1061 |
-
# do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 1062 |
-
|
| 1063 |
-
#print ("After do_denormalize: image is ", image)
|
| 1064 |
-
|
| 1065 |
-
image = self.image_processor.postprocess(
|
| 1066 |
-
image, output_type=output_type, do_denormalize=do_denormalize
|
| 1067 |
-
)
|
| 1068 |
-
|
| 1069 |
-
return image[0]
|
| 1070 |
-
|
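A hedged usage sketch for LatentConsistencyEngine; the local model directory and the diffusers LCMScheduler settings are assumptions, not values taken from this repository:

```python
# Hypothetical driver for LatentConsistencyEngine; paths and settings are placeholders.
from diffusers import LCMScheduler

pipe = LatentConsistencyEngine(
    model="models/lcm_dreamshaper_v7",   # assumed OpenVINO model directory
    device=["CPU", "GPU", "CPU"],        # text encoder, unet, vae decoder
)
scheduler = LCMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
image = pipe(
    prompt="a lighthouse on a cliff at sunset, highly detailed",
    scheduler=scheduler,
    num_inference_steps=4,               # LCM needs only a few steps
    guidance_scale=8.0,
    lcm_origin_steps=50,
    seed=42,
)
image.save("lighthouse.png")             # output_type="pil" returns a PIL image
```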
| 1071 |
-
class LatentConsistencyEngineAdvanced(DiffusionPipeline):
|
| 1072 |
-
def __init__(
|
| 1073 |
-
self,
|
| 1074 |
-
model="SimianLuo/LCM_Dreamshaper_v7",
|
| 1075 |
-
tokenizer="openai/clip-vit-large-patch14",
|
| 1076 |
-
device=["CPU", "CPU", "CPU"],
|
| 1077 |
-
):
|
| 1078 |
-
super().__init__()
|
| 1079 |
-
try:
|
| 1080 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
|
| 1081 |
-
except:
|
| 1082 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
|
| 1083 |
-
self.tokenizer.save_pretrained(model)
|
| 1084 |
-
|
| 1085 |
-
self.core = Core()
|
| 1086 |
-
self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')}) # adding caching to reduce init time
|
| 1087 |
-
#try_enable_npu_turbo(device, self.core)
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
|
| 1091 |
-
text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
|
| 1092 |
-
unet_future = executor.submit(self.load_model, model, "unet", device[1])
|
| 1093 |
-
vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[2])
|
| 1094 |
-
vae_encoder_future = executor.submit(self.load_model, model, "vae_encoder", device[2])
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
print("Text Device:", device[0])
|
| 1098 |
-
self.text_encoder = text_future.result()
|
| 1099 |
-
self._text_encoder_output = self.text_encoder.output(0)
|
| 1100 |
-
|
| 1101 |
-
print("Unet Device:", device[1])
|
| 1102 |
-
self.unet = unet_future.result()
|
| 1103 |
-
self._unet_output = self.unet.output(0)
|
| 1104 |
-
self.infer_request = self.unet.create_infer_request()
|
| 1105 |
-
|
| 1106 |
-
print(f"VAE Device: {device[2]}")
|
| 1107 |
-
self.vae_decoder = vae_de_future.result()
|
| 1108 |
-
self.vae_encoder = vae_encoder_future.result()
|
| 1109 |
-
self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
|
| 1110 |
-
|
| 1111 |
-
self.infer_request_vae = self.vae_decoder.create_infer_request()
|
| 1112 |
-
self.safety_checker = None #pipe.safety_checker
|
| 1113 |
-
self.feature_extractor = None #pipe.feature_extractor
|
| 1114 |
-
self.vae_scale_factor = 2 ** 3
|
| 1115 |
-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 1116 |
-
|
| 1117 |
-
def load_model(self, model, model_name, device):
|
| 1118 |
-
print(f"Compiling the {model_name} to {device} ...")
|
| 1119 |
-
return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
|
| 1120 |
-
|
| 1121 |
-
def get_timesteps(self, num_inference_steps:int, strength:float, scheduler):
|
| 1122 |
-
"""
|
| 1123 |
-
Helper function for getting scheduler timesteps for generation
|
| 1124 |
-
In case of image-to-image generation, it updates the number of steps according to strength
|
| 1125 |
-
|
| 1126 |
-
Parameters:
|
| 1127 |
-
num_inference_steps (int):
|
| 1128 |
-
number of inference steps for generation
|
| 1129 |
-
strength (float):
|
| 1130 |
-
value between 0.0 and 1.0 that controls the amount of noise that is added to the input image.
|
| 1131 |
-
Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
|
| 1132 |
-
"""
|
| 1133 |
-
# get the original timestep using init_timestep
|
| 1134 |
-
|
| 1135 |
-
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 1136 |
-
|
| 1137 |
-
t_start = max(num_inference_steps - init_timestep, 0)
|
| 1138 |
-
timesteps = scheduler.timesteps[t_start:]
|
| 1139 |
-
|
| 1140 |
-
return timesteps, num_inference_steps - t_start
|
| 1141 |
-
|
| 1142 |
-
def _encode_prompt(
|
| 1143 |
-
self,
|
| 1144 |
-
prompt,
|
| 1145 |
-
num_images_per_prompt,
|
| 1146 |
-
prompt_embeds: None,
|
| 1147 |
-
):
|
| 1148 |
-
r"""
|
| 1149 |
-
Encodes the prompt into text encoder hidden states.
|
| 1150 |
-
Args:
|
| 1151 |
-
prompt (`str` or `List[str]`, *optional*):
|
| 1152 |
-
prompt to be encoded
|
| 1153 |
-
num_images_per_prompt (`int`):
|
| 1154 |
-
number of images that should be generated per prompt
|
| 1155 |
-
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1156 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 1157 |
-
provided, text embeddings will be generated from `prompt` input argument.
|
| 1158 |
-
"""
|
| 1159 |
-
|
| 1160 |
-
if prompt_embeds is None:
|
| 1161 |
-
|
| 1162 |
-
text_inputs = self.tokenizer(
|
| 1163 |
-
prompt,
|
| 1164 |
-
padding="max_length",
|
| 1165 |
-
max_length=self.tokenizer.model_max_length,
|
| 1166 |
-
truncation=True,
|
| 1167 |
-
return_tensors="pt",
|
| 1168 |
-
)
|
| 1169 |
-
text_input_ids = text_inputs.input_ids
|
| 1170 |
-
untruncated_ids = self.tokenizer(
|
| 1171 |
-
prompt, padding="longest", return_tensors="pt"
|
| 1172 |
-
).input_ids
|
| 1173 |
-
|
| 1174 |
-
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
| 1175 |
-
-1
|
| 1176 |
-
] and not torch.equal(text_input_ids, untruncated_ids):
|
| 1177 |
-
removed_text = self.tokenizer.batch_decode(
|
| 1178 |
-
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 1179 |
-
)
|
| 1180 |
-
logger.warning(
|
| 1181 |
-
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 1182 |
-
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 1183 |
-
)
|
| 1184 |
-
|
| 1185 |
-
prompt_embeds = self.text_encoder(text_input_ids, share_inputs=True, share_outputs=True)
|
| 1186 |
-
prompt_embeds = torch.from_numpy(prompt_embeds[0])
|
| 1187 |
-
|
| 1188 |
-
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 1189 |
-
# duplicate text embeddings for each generation per prompt
|
| 1190 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 1191 |
-
prompt_embeds = prompt_embeds.view(
|
| 1192 |
-
bs_embed * num_images_per_prompt, seq_len, -1
|
| 1193 |
-
)
|
| 1194 |
-
|
| 1195 |
-
# Don't need to get uncond prompt embedding because of LCM Guided Distillation
|
| 1196 |
-
return prompt_embeds
|
| 1197 |
-
|
| 1198 |
-
def run_safety_checker(self, image, dtype):
|
| 1199 |
-
if self.safety_checker is None:
|
| 1200 |
-
has_nsfw_concept = None
|
| 1201 |
-
else:
|
| 1202 |
-
if torch.is_tensor(image):
|
| 1203 |
-
feature_extractor_input = self.image_processor.postprocess(
|
| 1204 |
-
image, output_type="pil"
|
| 1205 |
-
)
|
| 1206 |
-
else:
|
| 1207 |
-
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| 1208 |
-
safety_checker_input = self.feature_extractor(
|
| 1209 |
-
feature_extractor_input, return_tensors="pt"
|
| 1210 |
-
)
|
| 1211 |
-
image, has_nsfw_concept = self.safety_checker(
|
| 1212 |
-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
| 1213 |
-
)
|
| 1214 |
-
return image, has_nsfw_concept
|
| 1215 |
-
|
| 1216 |
-
def prepare_latents(
|
| 1217 |
-
self,image,timestep,batch_size, num_channels_latents, height, width, dtype, scheduler,latents=None,
|
| 1218 |
-
):
|
| 1219 |
-
shape = (
|
| 1220 |
-
batch_size,
|
| 1221 |
-
num_channels_latents,
|
| 1222 |
-
height // self.vae_scale_factor,
|
| 1223 |
-
width // self.vae_scale_factor,
|
| 1224 |
-
)
|
| 1225 |
-
if image:
|
| 1226 |
-
#latents_shape = (1, 4, 512, 512 // 8)
|
| 1227 |
-
#input_image, meta = preprocess(image,512,512)
|
| 1228 |
-
latents_shape = (1, 4, 512 // 8, 512 // 8)
|
| 1229 |
-
noise = np.random.randn(*latents_shape).astype(np.float32)
|
| 1230 |
-
input_image,meta = preprocess(image,512,512)
|
| 1231 |
-
moments = self.vae_encoder(input_image)[self._vae_e_output]
|
| 1232 |
-
mean, logvar = np.split(moments, 2, axis=1)
|
| 1233 |
-
std = np.exp(logvar * 0.5)
|
| 1234 |
-
latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
|
| 1235 |
-
noise = torch.randn(shape, dtype=dtype)
|
| 1236 |
-
#latents = scheduler.add_noise(init_latents, noise, timestep)
|
| 1237 |
-
latents = scheduler.add_noise(torch.from_numpy(latents), noise, timestep)
|
| 1238 |
-
|
| 1239 |
-
else:
|
| 1240 |
-
latents = torch.randn(shape, dtype=dtype)
|
| 1241 |
-
# scale the initial noise by the standard deviation required by the scheduler
|
| 1242 |
-
return latents
|
| 1243 |
-
|
| 1244 |
-
def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
| 1245 |
-
"""
|
| 1246 |
-
see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 1247 |
-
Args:
|
| 1248 |
-
w: torch.Tensor: guidance-scale values to generate embedding vectors for
|
| 1249 |
-
embedding_dim: int: dimension of the embeddings to generate
|
| 1250 |
-
dtype: data type of the generated embeddings
|
| 1251 |
-
Returns:
|
| 1252 |
-
embedding vectors with shape `(len(timesteps), embedding_dim)`
|
| 1253 |
-
"""
|
| 1254 |
-
assert len(w.shape) == 1
|
| 1255 |
-
w = w * 1000.0
|
| 1256 |
-
|
| 1257 |
-
half_dim = embedding_dim // 2
|
| 1258 |
-
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 1259 |
-
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 1260 |
-
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 1261 |
-
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 1262 |
-
if embedding_dim % 2 == 1: # zero pad
|
| 1263 |
-
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 1264 |
-
assert emb.shape == (w.shape[0], embedding_dim)
|
| 1265 |
-
return emb
|
| 1266 |
-
|
| 1267 |
-
@torch.no_grad()
|
| 1268 |
-
def __call__(
|
| 1269 |
-
self,
|
| 1270 |
-
prompt: Union[str, List[str]] = None,
|
| 1271 |
-
init_image: Optional[PIL.Image.Image] = None,
|
| 1272 |
-
strength: Optional[float] = 0.8,
|
| 1273 |
-
height: Optional[int] = 512,
|
| 1274 |
-
width: Optional[int] = 512,
|
| 1275 |
-
guidance_scale: float = 7.5,
|
| 1276 |
-
scheduler = None,
|
| 1277 |
-
num_images_per_prompt: Optional[int] = 1,
|
| 1278 |
-
latents: Optional[torch.FloatTensor] = None,
|
| 1279 |
-
num_inference_steps: int = 4,
|
| 1280 |
-
lcm_origin_steps: int = 50,
|
| 1281 |
-
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1282 |
-
output_type: Optional[str] = "pil",
|
| 1283 |
-
return_dict: bool = True,
|
| 1284 |
-
model: Optional[Dict[str, any]] = None,
|
| 1285 |
-
seed: Optional[int] = 1234567,
|
| 1286 |
-
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1287 |
-
callback = None,
|
| 1288 |
-
callback_userdata = None
|
| 1289 |
-
):
|
| 1290 |
-
|
| 1291 |
-
# 1. Define call parameters
|
| 1292 |
-
if prompt is not None and isinstance(prompt, str):
|
| 1293 |
-
batch_size = 1
|
| 1294 |
-
elif prompt is not None and isinstance(prompt, list):
|
| 1295 |
-
batch_size = len(prompt)
|
| 1296 |
-
else:
|
| 1297 |
-
batch_size = prompt_embeds.shape[0]
|
| 1298 |
-
|
| 1299 |
-
if seed is not None:
|
| 1300 |
-
torch.manual_seed(seed)
|
| 1301 |
-
|
| 1302 |
-
#print("After Step 1: batch size is ", batch_size)
|
| 1303 |
-
# do_classifier_free_guidance = guidance_scale > 0.0
|
| 1304 |
-
# In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
|
| 1305 |
-
|
| 1306 |
-
# 2. Encode input prompt
|
| 1307 |
-
prompt_embeds = self._encode_prompt(
|
| 1308 |
-
prompt,
|
| 1309 |
-
num_images_per_prompt,
|
| 1310 |
-
prompt_embeds=prompt_embeds,
|
| 1311 |
-
)
|
| 1312 |
-
#print("After Step 2: prompt embeds is ", prompt_embeds)
|
| 1313 |
-
#print("After Step 2: scheduler is ", scheduler )
|
| 1314 |
-
# 3. Prepare timesteps
|
| 1315 |
-
#scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
|
| 1316 |
-
latent_timestep = None
|
| 1317 |
-
if init_image:
|
| 1318 |
-
scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
|
| 1319 |
-
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
|
| 1320 |
-
latent_timestep = timesteps[:1]
|
| 1321 |
-
else:
|
| 1322 |
-
scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
|
| 1323 |
-
timesteps = scheduler.timesteps
|
| 1324 |
-
#timesteps = scheduler.timesteps
|
| 1325 |
-
#latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 1326 |
-
#print("timesteps: ", latent_timestep)
|
| 1327 |
-
|
| 1328 |
-
#print("After Step 3: timesteps is ", timesteps)
|
| 1329 |
-
|
| 1330 |
-
# 4. Prepare latent variable
|
| 1331 |
-
num_channels_latents = 4
|
| 1332 |
-
latents = self.prepare_latents(
|
| 1333 |
-
init_image,
|
| 1334 |
-
latent_timestep,
|
| 1335 |
-
batch_size * num_images_per_prompt,
|
| 1336 |
-
num_channels_latents,
|
| 1337 |
-
height,
|
| 1338 |
-
width,
|
| 1339 |
-
prompt_embeds.dtype,
|
| 1340 |
-
scheduler,
|
| 1341 |
-
latents,
|
| 1342 |
-
)
|
| 1343 |
-
|
| 1344 |
-
latents = latents * scheduler.init_noise_sigma
|
| 1345 |
-
|
| 1346 |
-
#print("After Step 4: ")
|
| 1347 |
-
bs = batch_size * num_images_per_prompt
|
| 1348 |
-
|
| 1349 |
-
# 5. Get Guidance Scale Embedding
|
| 1350 |
-
w = torch.tensor(guidance_scale).repeat(bs)
|
| 1351 |
-
w_embedding = self.get_w_embedding(w, embedding_dim=256)
|
| 1352 |
-
#print("After Step 5: ")
|
| 1353 |
-
# 6. LCM MultiStep Sampling Loop:
|
| 1354 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1355 |
-
for i, t in enumerate(timesteps):
|
| 1356 |
-
if callback:
|
| 1357 |
-
callback(i+1, callback_userdata)
|
| 1358 |
-
|
| 1359 |
-
ts = torch.full((bs,), t, dtype=torch.long)
|
| 1360 |
-
|
| 1361 |
-
# model prediction (v-prediction, eps, x)
|
| 1362 |
-
model_pred = self.unet([latents, ts, prompt_embeds, w_embedding],share_inputs=True, share_outputs=True)[0]
|
| 1363 |
-
|
| 1364 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 1365 |
-
latents, denoised = scheduler.step(
|
| 1366 |
-
torch.from_numpy(model_pred), t, latents, return_dict=False
|
| 1367 |
-
)
|
| 1368 |
-
progress_bar.update()
|
| 1369 |
-
|
| 1370 |
-
#print("After Step 6: ")
|
| 1371 |
-
|
| 1372 |
-
vae_start = time.time()
|
| 1373 |
-
|
| 1374 |
-
if not output_type == "latent":
|
| 1375 |
-
image = torch.from_numpy(self.vae_decoder(denoised / 0.18215, share_inputs=True, share_outputs=True)[0])
|
| 1376 |
-
else:
|
| 1377 |
-
image = denoised
|
| 1378 |
-
|
| 1379 |
-
print("Decoder Ended: ", time.time() - vae_start)
|
| 1380 |
-
#post_start = time.time()
|
| 1381 |
-
|
| 1382 |
-
#if has_nsfw_concept is None:
|
| 1383 |
-
do_denormalize = [True] * image.shape[0]
|
| 1384 |
-
#else:
|
| 1385 |
-
# do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 1386 |
-
|
| 1387 |
-
#print ("After do_denormalize: image is ", image)
|
| 1388 |
-
|
| 1389 |
-
image = self.image_processor.postprocess(
|
| 1390 |
-
image, output_type=output_type, do_denormalize=do_denormalize
|
| 1391 |
-
)
|
| 1392 |
-
|
| 1393 |
-
return image[0]
|
| 1394 |
-
|
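The Advanced variant adds an image-to-image path driven by init_image and strength. A hedged sketch of that call; the file names and strength value are assumptions:

```python
# Hypothetical image-to-image call for LatentConsistencyEngineAdvanced.
from PIL import Image
from diffusers import LCMScheduler

pipe = LatentConsistencyEngineAdvanced(
    model="models/lcm_dreamshaper_v7",   # assumed OpenVINO model directory
    device=["CPU", "CPU", "CPU"],
)
scheduler = LCMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
result = pipe(
    prompt="the same scene in winter, covered in snow",
    init_image=Image.open("summer.png"),
    strength=0.6,                        # lower strength preserves more of the source image
    scheduler=scheduler,
    num_inference_steps=4,
    guidance_scale=8.0,
    seed=7,
)
result.save("winter.png")
```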
| 1395 |
-
class StableDiffusionEngineReferenceOnly(DiffusionPipeline):
|
| 1396 |
-
def __init__(
|
| 1397 |
-
self,
|
| 1398 |
-
#scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
| 1399 |
-
model="bes-dev/stable-diffusion-v1-4-openvino",
|
| 1400 |
-
tokenizer="openai/clip-vit-large-patch14",
|
| 1401 |
-
device=["CPU","CPU","CPU"]
|
| 1402 |
-
):
|
| 1403 |
-
#self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
|
| 1404 |
-
try:
|
| 1405 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(model,local_files_only=True)
|
| 1406 |
-
except:
|
| 1407 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
|
| 1408 |
-
self.tokenizer.save_pretrained(model)
|
| 1409 |
-
|
| 1410 |
-
#self.scheduler = scheduler
|
| 1411 |
-
# models
|
| 1412 |
-
|
| 1413 |
-
self.core = Core()
|
| 1414 |
-
self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')}) #adding caching to reduce init time
|
| 1415 |
-
# text features
|
| 1416 |
-
|
| 1417 |
-
print("Text Device:",device[0])
|
| 1418 |
-
self.text_encoder = self.core.compile_model(os.path.join(model, "text_encoder.xml"), device[0])
|
| 1419 |
-
|
| 1420 |
-
self._text_encoder_output = self.text_encoder.output(0)
|
| 1421 |
-
|
| 1422 |
-
# diffusion
|
| 1423 |
-
print("unet_w Device:",device[1])
|
| 1424 |
-
self.unet_w = self.core.compile_model(os.path.join(model, "unet_reference_write.xml"), device[1])
|
| 1425 |
-
self._unet_w_output = self.unet_w.output(0)
|
| 1426 |
-
self.latent_shape = tuple(self.unet_w.inputs[0].shape)[1:]
|
| 1427 |
-
|
| 1428 |
-
print("unet_r Device:",device[1])
|
| 1429 |
-
self.unet_r = self.core.compile_model(os.path.join(model, "unet_reference_read.xml"), device[1])
|
| 1430 |
-
self._unet_r_output = self.unet_r.output(0)
|
| 1431 |
-
# decoder
|
| 1432 |
-
print("Vae Device:",device[2])
|
| 1433 |
-
|
| 1434 |
-
self.vae_decoder = self.core.compile_model(os.path.join(model, "vae_decoder.xml"), device[2])
|
| 1435 |
-
|
| 1436 |
-
# encoder
|
| 1437 |
-
|
| 1438 |
-
self.vae_encoder = self.core.compile_model(os.path.join(model, "vae_encoder.xml"), device[2])
|
| 1439 |
-
|
| 1440 |
-
self.init_image_shape = tuple(self.vae_encoder.inputs[0].shape)[2:]
|
| 1441 |
-
|
| 1442 |
-
self._vae_d_output = self.vae_decoder.output(0)
|
| 1443 |
-
self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder is not None else None
|
| 1444 |
-
|
| 1445 |
-
self.height = self.unet_w.input(0).shape[2] * 8
|
| 1446 |
-
self.width = self.unet_w.input(0).shape[3] * 8
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
|
| 1450 |
-
def __call__(
|
| 1451 |
-
self,
|
| 1452 |
-
prompt,
|
| 1453 |
-
image = None,
|
| 1454 |
-
negative_prompt=None,
|
| 1455 |
-
scheduler=None,
|
| 1456 |
-
strength = 1.0,
|
| 1457 |
-
num_inference_steps = 32,
|
| 1458 |
-
guidance_scale = 7.5,
|
| 1459 |
-
eta = 0.0,
|
| 1460 |
-
create_gif = False,
|
| 1461 |
-
model = None,
|
| 1462 |
-
callback = None,
|
| 1463 |
-
callback_userdata = None
|
| 1464 |
-
):
|
| 1465 |
-
# extract condition
|
| 1466 |
-
text_input = self.tokenizer(
|
| 1467 |
-
prompt,
|
| 1468 |
-
padding="max_length",
|
| 1469 |
-
max_length=self.tokenizer.model_max_length,
|
| 1470 |
-
truncation=True,
|
| 1471 |
-
return_tensors="np",
|
| 1472 |
-
)
|
| 1473 |
-
text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]
|
| 1474 |
-
|
| 1475 |
-
|
| 1476 |
-
# do classifier free guidance
|
| 1477 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
| 1478 |
-
if do_classifier_free_guidance:
|
| 1479 |
-
|
| 1480 |
-
if negative_prompt is None:
|
| 1481 |
-
uncond_tokens = [""]
|
| 1482 |
-
elif isinstance(negative_prompt, str):
|
| 1483 |
-
uncond_tokens = [negative_prompt]
|
| 1484 |
-
else:
|
| 1485 |
-
uncond_tokens = negative_prompt
|
| 1486 |
-
|
| 1487 |
-
tokens_uncond = self.tokenizer(
|
| 1488 |
-
uncond_tokens,
|
| 1489 |
-
padding="max_length",
|
| 1490 |
-
max_length=self.tokenizer.model_max_length, #truncation=True,
|
| 1491 |
-
return_tensors="np"
|
| 1492 |
-
)
|
| 1493 |
-
uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
|
| 1494 |
-
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
| 1495 |
-
|
| 1496 |
-
# set timesteps
|
| 1497 |
-
accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 1498 |
-
extra_set_kwargs = {}
|
| 1499 |
-
|
| 1500 |
-
if accepts_offset:
|
| 1501 |
-
extra_set_kwargs["offset"] = 1
|
| 1502 |
-
|
| 1503 |
-
scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
| 1504 |
-
|
| 1505 |
-
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
|
| 1506 |
-
latent_timestep = timesteps[:1]
|
| 1507 |
-
|
| 1508 |
-
ref_image = self.prepare_image(
|
| 1509 |
-
image=image,
|
| 1510 |
-
width=512,
|
| 1511 |
-
height=512,
|
| 1512 |
-
)
|
| 1513 |
-
# get the initial random noise unless the user supplied it
|
| 1514 |
-
latents, meta = self.prepare_latents(None, latent_timestep, scheduler)
|
        # ref_image_latents, _ = self.prepare_latents(init_image, latent_timestep, scheduler)
        ref_image_latents = self.ov_prepare_ref_latents(ref_image)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        if create_gif:
            frames = []

        for i, t in enumerate(self.progress_bar(timesteps)):
            if callback:
                callback(i, callback_userdata)

            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # ref only part
            noise = randn_tensor(
                ref_image_latents.shape
            )

            ref_xt = scheduler.add_noise(
                torch.from_numpy(ref_image_latents),
                noise,
                t.reshape(
                    1,
                ),
            ).numpy()
            ref_xt = np.concatenate([ref_xt] * 2) if do_classifier_free_guidance else ref_xt
            ref_xt = scheduler.scale_model_input(ref_xt, t)

            # MODE = "write"
            result_w_dict = self.unet_w([
                ref_xt,
                t,
                text_embeddings
            ])
            down_0_attn0 = result_w_dict["/unet/down_blocks.0/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_0_attn1 = result_w_dict["/unet/down_blocks.0/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_1_attn0 = result_w_dict["/unet/down_blocks.1/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_1_attn1 = result_w_dict["/unet/down_blocks.1/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_2_attn0 = result_w_dict["/unet/down_blocks.2/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_2_attn1 = result_w_dict["/unet/down_blocks.2/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            mid_attn0 = result_w_dict["/unet/mid_block/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_1_attn0 = result_w_dict["/unet/up_blocks.1/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_1_attn1 = result_w_dict["/unet/up_blocks.1/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_1_attn2 = result_w_dict["/unet/up_blocks.1/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_2_attn0 = result_w_dict["/unet/up_blocks.2/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_2_attn1 = result_w_dict["/unet/up_blocks.2/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_2_attn2 = result_w_dict["/unet/up_blocks.2/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_3_attn0 = result_w_dict["/unet/up_blocks.3/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_3_attn1 = result_w_dict["/unet/up_blocks.3/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_3_attn2 = result_w_dict["/unet/up_blocks.3/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]

            # MODE = "read"
            noise_pred = self.unet_r([
                latent_model_input, t, text_embeddings, down_0_attn0, down_0_attn1, down_1_attn0,
                down_1_attn1, down_2_attn0, down_2_attn1, mid_attn0, up_1_attn0, up_1_attn1, up_1_attn2,
                up_2_attn0, up_2_attn1, up_2_attn2, up_3_attn0, up_3_attn1, up_3_attn2
            ])[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()

            if create_gif:
                frames.append(latents)

        if callback:
            callback(num_inference_steps, callback_userdata)

        # scale and decode the image latents with vae
        image = self.vae_decoder(latents)[self._vae_d_output]

        image = self.postprocess_image(image, meta)

        if create_gif:
            gif_folder = os.path.join(model, "../../../gif")
            if not os.path.exists(gif_folder):
                os.makedirs(gif_folder)
            for i in range(0, len(frames)):
                image = self.vae_decoder(frames[i])[self._vae_d_output]
                image = self.postprocess_image(image, meta)
                output = gif_folder + "/" + str(i).zfill(3) + ".png"
                cv2.imwrite(output, image)
            with open(os.path.join(gif_folder, "prompt.json"), "w") as file:
                json.dump({"prompt": prompt}, file)
            frames_image = [Image.open(image) for image in glob.glob(f"{gif_folder}/*.png")]
            frame_one = frames_image[0]
            gif_file = os.path.join(gif_folder, "stable_diffusion.gif")
            frame_one.save(gif_file, format="GIF", append_images=frames_image, save_all=True, duration=100, loop=0)

        return image

    def ov_prepare_ref_latents(self, refimage, vae_scaling_factor=0.18215):
        # refimage = refimage.to(device=device, dtype=dtype)

        # encode the mask image into latents space so we can concatenate it to the latents
        moments = self.vae_encoder(refimage)[0]
        mean, logvar = np.split(moments, 2, axis=1)
        std = np.exp(logvar * 0.5)
        ref_image_latents = (mean + std * np.random.randn(*mean.shape))
        ref_image_latents = vae_scaling_factor * ref_image_latents
        # ref_image_latents = scheduler.add_noise(torch.from_numpy(ref_image_latents), torch.from_numpy(noise), latent_timestep).numpy()

        # aligning device to prevent device errors when concating it with the latent model input
        # ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
        return ref_image_latents

    def prepare_latents(self, image: PIL.Image.Image = None, latent_timestep: torch.Tensor = None, scheduler=LMSDiscreteScheduler):
        """
        Function for getting initial latents for starting generation

        Parameters:
            image (PIL.Image.Image, *optional*, None):
                Input image for generation, if not provided random noise will be used as starting point
            latent_timestep (torch.Tensor, *optional*, None):
                Predicted by scheduler initial step for image generation, required for latent image mixing with noise
        Returns:
            latents (np.ndarray):
                Image encoded in latent space
        """
        latents_shape = (1, 4, self.height // 8, self.width // 8)

        noise = np.random.randn(*latents_shape).astype(np.float32)
        if image is None:
            # print("Image is NONE")
            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
            if isinstance(scheduler, LMSDiscreteScheduler):
                noise = noise * scheduler.sigmas[0].numpy()
                return noise, {}
            elif isinstance(scheduler, EulerDiscreteScheduler):
                noise = noise * scheduler.sigmas.max().numpy()
                return noise, {}
            else:
                return noise, {}
        input_image, meta = preprocess(image, self.height, self.width)

        moments = self.vae_encoder(input_image)[self._vae_e_output]

        mean, logvar = np.split(moments, 2, axis=1)

        std = np.exp(logvar * 0.5)
        latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215

        latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
        return latents, meta

    def postprocess_image(self, image: np.ndarray, meta: Dict):
        """
        Postprocessing for decoded image. Takes generated image decoded by VAE decoder, unpads it to the initial image size (if required),
        normalizes and converts to [0, 255] pixels range. Optionally, converts it from np.ndarray to PIL.Image format

        Parameters:
            image (np.ndarray):
                Generated image
            meta (Dict):
                Metadata obtained on latents preparing step, can be empty
            output_type (str, *optional*, pil):
                Output format for result, can be pil or numpy
        Returns:
            image (List of np.ndarray or PIL.Image.Image):
                Postprocessed images

        if "src_height" in meta:
            orig_height, orig_width = meta["src_height"], meta["src_width"]
            image = [cv2.resize(img, (orig_width, orig_height))
                     for img in image]

        return image
        """
        if "padding" in meta:
            pad = meta["padding"]
            (_, end_h), (_, end_w) = pad[1:3]
            h, w = image.shape[2:]
            # print("image shape", image.shape[2:])
            unpad_h = h - end_h
            unpad_w = w - end_w
            image = image[:, :, :unpad_h, :unpad_w]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)

        if "src_height" in meta:
            orig_height, orig_width = meta["src_height"], meta["src_width"]
            image = cv2.resize(image, (orig_width, orig_height))

        return image

        # image = (image / 2 + 0.5).clip(0, 1)
        # image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)

    def get_timesteps(self, num_inference_steps: int, strength: float, scheduler):
        """
        Helper function for getting scheduler timesteps for generation
        In case of image-to-image generation, it updates number of steps according to strength

        Parameters:
            num_inference_steps (int):
                number of inference steps for generation
            strength (float):
                value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
                Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    def prepare_image(
        self,
        image,
        width,
        height,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if not isinstance(image, np.ndarray):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []

                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)

                image = images

                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5
                image = image.transpose(0, 3, 1, 2)
            elif isinstance(image[0], np.ndarray):
                image = np.concatenate(image, dim=0)

        if do_classifier_free_guidance and not guess_mode:
            image = np.concatenate([image] * 2)

        return image


def print_npu_turbo_art():
    random_number = random.randint(1, 3)

    if random_number == 1:
        print(" ")
        print(" ___ ___ ___ ___ ___ ___ ")
        print(" /\ \ /\ \ /\ \ /\ \ /\ \ _____ /\ \ ")
        print(" \:\ \ /::\ \ \:\ \ ___ \:\ \ /::\ \ /::\ \ /::\ \ ")
        print(" \:\ \ /:/\:\__\ \:\ \ /\__\ \:\ \ /:/\:\__\ /:/\:\ \ /:/\:\ \ ")
        print(" _____\:\ \ /:/ /:/ / ___ \:\ \ /:/ / ___ \:\ \ /:/ /:/ / /:/ /::\__\ /:/ \:\ \ ")
        print(" /::::::::\__\ /:/_/:/ / /\ \ \:\__\ /:/__/ /\ \ \:\__\ /:/_/:/__/___ /:/_/:/\:|__| /:/__/ \:\__\ ")
        print(" \:\~~\~~\/__/ \:\/:/ / \:\ \ /:/ / /::\ \ \:\ \ /:/ / \:\/:::::/ / \:\/:/ /:/ / \:\ \ /:/ / ")
        print(" \:\ \ \::/__/ \:\ /:/ / /:/\:\ \ \:\ /:/ / \::/~~/~~~~ \::/_/:/ / \:\ /:/ / ")
        print(" \:\ \ \:\ \ \:\/:/ / \/__\:\ \ \:\/:/ / \:\~~\ \:\/:/ / \:\/:/ / ")
        print(" \:\__\ \:\__\ \::/ / \:\__\ \::/ / \:\__\ \::/ / \::/ / ")
        print(" \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ ")
        print(" ")
    elif random_number == 2:
        print(" _ _ ____ _ _ _____ _ _ ____ ____ ___ ")
        print("| \ | | | _ \ | | | | |_ _| | | | | | _ \ | __ ) / _ \ ")
        print("| \| | | |_) | | | | | | | | | | | | |_) | | _ \ | | | |")
        print("| |\ | | __/ | |_| | | | | |_| | | _ < | |_) | | |_| |")
        print("|_| \_| |_| \___/ |_| \___/ |_| \_\ |____/ \___/ ")
        print(" ")
    else:
        print("")
        print(" ) ( ( ) ")
        print(" ( /( )\ ) * ) )\ ) ( ( /( ")
        print(" )\()) (()/( ( ` ) /( ( (()/( ( )\ )\()) ")
        print("((_)\ /(_)) )\ ( )(_)) )\ /(_)) )((_) ((_)\ ")
        print(" _((_) (_)) _ ((_) (_(_()) _ ((_) (_)) ((_)_ ((_) ")
        print("| \| | | _ \ | | | | |_ _| | | | | | _ \ | _ ) / _ \ ")
        print("| .` | | _/ | |_| | | | | |_| | | / | _ \ | (_) | ")
        print("|_|\_| |_| \___/ |_| \___/ |_|_\ |___/ \___/ ")
        print(" ")
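For reference, the img2img strength handling in the removed get_timesteps method reduces to a couple of lines of integer arithmetic. A minimal standalone sketch with made-up example numbers (not taken from the repository):

# Sketch: how strength trims the schedule, e.g. 20 steps at strength 0.4 keeps the last 8 steps.
num_inference_steps = 20
strength = 0.4
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 8
t_start = max(num_inference_steps - init_timestep, 0)                          # 12
print(f"run the last {num_inference_steps - t_start} of {num_inference_steps} scheduler timesteps")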
src/backend/pipelines/lcm.py
DELETED
@@ -1,122 +0,0 @@
from constants import LCM_DEFAULT_MODEL
from diffusers import (
    DiffusionPipeline,
    AutoencoderTiny,
    UNet2DConditionModel,
    LCMScheduler,
    StableDiffusionPipeline,
)
import torch
from backend.tiny_decoder import get_tiny_decoder_vae_model
from typing import Any
from diffusers import (
    LCMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    StableDiffusionControlNetPipeline,
)
import pathlib


def _get_lcm_pipeline_from_base_model(
    lcm_model_id: str,
    base_model_id: str,
    use_local_model: bool,
):
    pipeline = None
    unet = UNet2DConditionModel.from_pretrained(
        lcm_model_id,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
        resume_download=True,
    )
    pipeline = DiffusionPipeline.from_pretrained(
        base_model_id,
        unet=unet,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
        resume_download=True,
    )
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    return pipeline


def load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
    torch_data_type: torch.dtype = torch.float32,
):
    vae_model = get_tiny_decoder_vae_model(pipeline.__class__.__name__)
    pipeline.vae = AutoencoderTiny.from_pretrained(
        vae_model,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )


def get_lcm_model_pipeline(
    model_id: str = LCM_DEFAULT_MODEL,
    use_local_model: bool = False,
    pipeline_args={},
):
    pipeline = None
    if model_id == "latent-consistency/lcm-sdxl":
        pipeline = _get_lcm_pipeline_from_base_model(
            model_id,
            "stabilityai/stable-diffusion-xl-base-1.0",
            use_local_model,
        )

    elif model_id == "latent-consistency/lcm-ssd-1b":
        pipeline = _get_lcm_pipeline_from_base_model(
            model_id,
            "segmind/SSD-1B",
            use_local_model,
        )
    elif pathlib.Path(model_id).suffix == ".safetensors":
        # When loading a .safetensors model, the pipeline has to be created
        # with StableDiffusionPipeline() since it's the only class that
        # defines the method from_single_file()
        dummy_pipeline = StableDiffusionPipeline.from_single_file(
            model_id,
            safety_checker=None,
            run_safety_checker=False,
            load_safety_checker=False,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        if 'lcm' in model_id.lower():
            dummy_pipeline.scheduler = LCMScheduler.from_config(dummy_pipeline.scheduler.config)

        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **pipeline_args,
        )
        del dummy_pipeline
    else:
        # pipeline = DiffusionPipeline.from_pretrained(
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            **pipeline_args,
        )

    return pipeline


def get_image_to_image_pipeline(pipeline: Any) -> Any:
    components = pipeline.components
    pipeline_class = pipeline.__class__.__name__
    if (
        pipeline_class == "LatentConsistencyModelPipeline"
        or pipeline_class == "StableDiffusionPipeline"
    ):
        return StableDiffusionImg2ImgPipeline(**components)
    elif pipeline_class == "StableDiffusionControlNetPipeline":
        return AutoPipelineForImage2Image.from_pipe(pipeline)
    elif pipeline_class == "StableDiffusionXLPipeline":
        return StableDiffusionXLImg2ImgPipeline(**components)
    else:
        raise Exception(f"Unknown pipeline {pipeline_class}")
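A rough usage sketch of the removed helpers (illustrative only: backend.pipelines.lcm no longer exists after this commit, and the model id shown is just one of the branches the code handled):

from backend.pipelines.lcm import get_lcm_model_pipeline, get_image_to_image_pipeline, load_taesd

pipe = get_lcm_model_pipeline("latent-consistency/lcm-ssd-1b")  # resolves the SSD-1B base model internally
load_taesd(pipe)                                                # optional: swap in the tiny TAESD VAE decoder
img2img = get_image_to_image_pipeline(pipe)                     # rewraps the same components for image-to-image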
src/backend/pipelines/lcm_lora.py
DELETED
@@ -1,81 +0,0 @@
import pathlib
from os import path

import torch
from diffusers import (
    AutoPipelineForText2Image,
    LCMScheduler,
    StableDiffusionPipeline,
)


def load_lcm_weights(
    pipeline,
    use_local_model,
    lcm_lora_id,
):
    kwargs = {
        "local_files_only": use_local_model,
        "weight_name": "pytorch_lora_weights.safetensors",
    }
    pipeline.load_lora_weights(
        lcm_lora_id,
        **kwargs,
        adapter_name="lcm",
    )


def get_lcm_lora_pipeline(
    base_model_id: str,
    lcm_lora_id: str,
    use_local_model: bool,
    torch_data_type: torch.dtype,
    pipeline_args={},
):
    if pathlib.Path(base_model_id).suffix == ".safetensors":
        # SD 1.5 models only
        # When loading a .safetensors model, the pipeline has to be created
        # with StableDiffusionPipeline() since it's the only class that
        # defines the method from_single_file(); afterwards a new pipeline
        # is created using AutoPipelineForText2Image() for ControlNet
        # support, in case ControlNet is enabled
        if not path.exists(base_model_id):
            raise FileNotFoundError(
                f"Model file not found. Please check your model path: {base_model_id}"
            )
        print("Using single file Safetensors model (Supported models - SD 1.5 models)")

        dummy_pipeline = StableDiffusionPipeline.from_single_file(
            base_model_id,
            torch_dtype=torch_data_type,
            safety_checker=None,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **pipeline_args,
        )
        del dummy_pipeline
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            base_model_id,
            torch_dtype=torch_data_type,
            local_files_only=use_local_model,
            **pipeline_args,
        )

    load_lcm_weights(
        pipeline,
        use_local_model,
        lcm_lora_id,
    )
    # Always fuse LCM-LoRA
    # pipeline.fuse_lora()

    if "lcm" in lcm_lora_id.lower() or "hypersd" in lcm_lora_id.lower():
        print("LCM LoRA model detected so using recommended LCMScheduler")
        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

    # pipeline.unet.to(memory_format=torch.channels_last)
    return pipeline
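A minimal sketch of how the removed get_lcm_lora_pipeline was meant to be called; the base model and adapter ids below are examples, not anything mandated by the code:

import torch
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline

pipe = get_lcm_lora_pipeline(
    base_model_id="runwayml/stable-diffusion-v1-5",    # example SD 1.5 base model
    lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",  # "lcm" in the id triggers the LCMScheduler swap
    use_local_model=False,
    torch_data_type=torch.float32,
)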
src/backend/tiny_decoder.py
DELETED
@@ -1,32 +0,0 @@
from constants import (
    TAESD_MODEL,
    TAESDXL_MODEL,
    TAESD_MODEL_OPENVINO,
    TAESDXL_MODEL_OPENVINO,
)


def get_tiny_decoder_vae_model(pipeline_class) -> str:
    print(f"Pipeline class : {pipeline_class}")
    if (
        pipeline_class == "LatentConsistencyModelPipeline"
        or pipeline_class == "StableDiffusionPipeline"
        or pipeline_class == "StableDiffusionImg2ImgPipeline"
        or pipeline_class == "StableDiffusionControlNetPipeline"
        or pipeline_class == "StableDiffusionControlNetImg2ImgPipeline"
    ):
        return TAESD_MODEL
    elif (
        pipeline_class == "StableDiffusionXLPipeline"
        or pipeline_class == "StableDiffusionXLImg2ImgPipeline"
    ):
        return TAESDXL_MODEL
    elif (
        pipeline_class == "OVStableDiffusionPipeline"
        or pipeline_class == "OVStableDiffusionImg2ImgPipeline"
    ):
        return TAESD_MODEL_OPENVINO
    elif pipeline_class == "OVStableDiffusionXLPipeline":
        return TAESDXL_MODEL_OPENVINO
    else:
        raise Exception("No valid pipeline class found!")
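The mapping above keys purely off the pipeline class name, so a quick sanity check looks like this (sketch; TAESD_MODEL and friends are repo ids defined in the removed constants module):

from backend.tiny_decoder import get_tiny_decoder_vae_model

print(get_tiny_decoder_vae_model("StableDiffusionPipeline"))    # -> TAESD_MODEL
print(get_tiny_decoder_vae_model("StableDiffusionXLPipeline"))  # -> TAESDXL_MODEL
print(get_tiny_decoder_vae_model("OVStableDiffusionPipeline"))  # -> TAESD_MODEL_OPENVINO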
src/backend/upscale/aura_sr.py
DELETED
|
@@ -1,1004 +0,0 @@
|
|
| 1 |
-
# AuraSR: GAN-based Super-Resolution for real-world, a reproduction of the GigaGAN* paper. Implementation is
|
| 2 |
-
# based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there.
|
| 3 |
-
#
|
| 4 |
-
# https://mingukkang.github.io/GigaGAN/
|
| 5 |
-
from math import log2, ceil
|
| 6 |
-
from functools import partial
|
| 7 |
-
from typing import Any, Optional, List, Iterable
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from torchvision import transforms
|
| 11 |
-
from PIL import Image
|
| 12 |
-
from torch import nn, einsum, Tensor
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
|
| 15 |
-
from einops import rearrange, repeat, reduce
|
| 16 |
-
from einops.layers.torch import Rearrange
|
| 17 |
-
from torchvision.utils import save_image
|
| 18 |
-
import math
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def get_same_padding(size, kernel, dilation, stride):
|
| 22 |
-
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class AdaptiveConv2DMod(nn.Module):
|
| 26 |
-
def __init__(
|
| 27 |
-
self,
|
| 28 |
-
dim,
|
| 29 |
-
dim_out,
|
| 30 |
-
kernel,
|
| 31 |
-
*,
|
| 32 |
-
demod=True,
|
| 33 |
-
stride=1,
|
| 34 |
-
dilation=1,
|
| 35 |
-
eps=1e-8,
|
| 36 |
-
num_conv_kernels=1, # set this to be greater than 1 for adaptive
|
| 37 |
-
):
|
| 38 |
-
super().__init__()
|
| 39 |
-
self.eps = eps
|
| 40 |
-
|
| 41 |
-
self.dim_out = dim_out
|
| 42 |
-
|
| 43 |
-
self.kernel = kernel
|
| 44 |
-
self.stride = stride
|
| 45 |
-
self.dilation = dilation
|
| 46 |
-
self.adaptive = num_conv_kernels > 1
|
| 47 |
-
|
| 48 |
-
self.weights = nn.Parameter(
|
| 49 |
-
torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
self.demod = demod
|
| 53 |
-
|
| 54 |
-
nn.init.kaiming_normal_(
|
| 55 |
-
self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
def forward(
|
| 59 |
-
self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
|
| 60 |
-
):
|
| 61 |
-
"""
|
| 62 |
-
notation
|
| 63 |
-
|
| 64 |
-
b - batch
|
| 65 |
-
n - convs
|
| 66 |
-
o - output
|
| 67 |
-
i - input
|
| 68 |
-
k - kernel
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
b, h = fmap.shape[0], fmap.shape[-2]
|
| 72 |
-
|
| 73 |
-
# account for feature map that has been expanded by the scale in the first dimension
|
| 74 |
-
# due to multiscale inputs and outputs
|
| 75 |
-
|
| 76 |
-
if mod.shape[0] != b:
|
| 77 |
-
mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])
|
| 78 |
-
|
| 79 |
-
if exists(kernel_mod):
|
| 80 |
-
kernel_mod_has_el = kernel_mod.numel() > 0
|
| 81 |
-
|
| 82 |
-
assert self.adaptive or not kernel_mod_has_el
|
| 83 |
-
|
| 84 |
-
if kernel_mod_has_el and kernel_mod.shape[0] != b:
|
| 85 |
-
kernel_mod = repeat(
|
| 86 |
-
kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
# prepare weights for modulation
|
| 90 |
-
|
| 91 |
-
weights = self.weights
|
| 92 |
-
|
| 93 |
-
if self.adaptive:
|
| 94 |
-
weights = repeat(weights, "... -> b ...", b=b)
|
| 95 |
-
|
| 96 |
-
# determine an adaptive weight and 'select' the kernel to use with softmax
|
| 97 |
-
|
| 98 |
-
assert exists(kernel_mod) and kernel_mod.numel() > 0
|
| 99 |
-
|
| 100 |
-
kernel_attn = kernel_mod.softmax(dim=-1)
|
| 101 |
-
kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")
|
| 102 |
-
|
| 103 |
-
weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")
|
| 104 |
-
|
| 105 |
-
# do the modulation, demodulation, as done in stylegan2
|
| 106 |
-
|
| 107 |
-
mod = rearrange(mod, "b i -> b 1 i 1 1")
|
| 108 |
-
|
| 109 |
-
weights = weights * (mod + 1)
|
| 110 |
-
|
| 111 |
-
if self.demod:
|
| 112 |
-
inv_norm = (
|
| 113 |
-
reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
|
| 114 |
-
.clamp(min=self.eps)
|
| 115 |
-
.rsqrt()
|
| 116 |
-
)
|
| 117 |
-
weights = weights * inv_norm
|
| 118 |
-
|
| 119 |
-
fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")
|
| 120 |
-
|
| 121 |
-
weights = rearrange(weights, "b o ... -> (b o) ...")
|
| 122 |
-
|
| 123 |
-
padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
|
| 124 |
-
fmap = F.conv2d(fmap, weights, padding=padding, groups=b)
|
| 125 |
-
|
| 126 |
-
return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
class Attend(nn.Module):
|
| 130 |
-
def __init__(self, dropout=0.0, flash=False):
|
| 131 |
-
super().__init__()
|
| 132 |
-
self.dropout = dropout
|
| 133 |
-
self.attn_dropout = nn.Dropout(dropout)
|
| 134 |
-
self.scale = nn.Parameter(torch.randn(1))
|
| 135 |
-
self.flash = flash
|
| 136 |
-
|
| 137 |
-
def flash_attn(self, q, k, v):
|
| 138 |
-
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 139 |
-
out = F.scaled_dot_product_attention(
|
| 140 |
-
q, k, v, dropout_p=self.dropout if self.training else 0.0
|
| 141 |
-
)
|
| 142 |
-
return out
|
| 143 |
-
|
| 144 |
-
def forward(self, q, k, v):
|
| 145 |
-
if self.flash:
|
| 146 |
-
return self.flash_attn(q, k, v)
|
| 147 |
-
|
| 148 |
-
scale = q.shape[-1] ** -0.5
|
| 149 |
-
|
| 150 |
-
# similarity
|
| 151 |
-
sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
|
| 152 |
-
|
| 153 |
-
# attention
|
| 154 |
-
attn = sim.softmax(dim=-1)
|
| 155 |
-
attn = self.attn_dropout(attn)
|
| 156 |
-
|
| 157 |
-
# aggregate values
|
| 158 |
-
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
| 159 |
-
|
| 160 |
-
return out
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
def exists(x):
|
| 164 |
-
return x is not None
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def default(val, d):
|
| 168 |
-
if exists(val):
|
| 169 |
-
return val
|
| 170 |
-
return d() if callable(d) else d
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def cast_tuple(t, length=1):
|
| 174 |
-
if isinstance(t, tuple):
|
| 175 |
-
return t
|
| 176 |
-
return (t,) * length
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
def identity(t, *args, **kwargs):
|
| 180 |
-
return t
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
def is_power_of_two(n):
|
| 184 |
-
return log2(n).is_integer()
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
def null_iterator():
|
| 188 |
-
while True:
|
| 189 |
-
yield None
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
def Downsample(dim, dim_out=None):
|
| 193 |
-
return nn.Sequential(
|
| 194 |
-
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
|
| 195 |
-
nn.Conv2d(dim * 4, default(dim_out, dim), 1),
|
| 196 |
-
)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
class RMSNorm(nn.Module):
|
| 200 |
-
def __init__(self, dim):
|
| 201 |
-
super().__init__()
|
| 202 |
-
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
| 203 |
-
self.eps = 1e-4
|
| 204 |
-
|
| 205 |
-
def forward(self, x):
|
| 206 |
-
return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
# building block modules
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
class Block(nn.Module):
|
| 213 |
-
def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
|
| 214 |
-
super().__init__()
|
| 215 |
-
self.proj = AdaptiveConv2DMod(
|
| 216 |
-
dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
|
| 217 |
-
)
|
| 218 |
-
self.kernel = 3
|
| 219 |
-
self.dilation = 1
|
| 220 |
-
self.stride = 1
|
| 221 |
-
|
| 222 |
-
self.act = nn.SiLU()
|
| 223 |
-
|
| 224 |
-
def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
|
| 225 |
-
conv_mods_iter = default(conv_mods_iter, null_iterator())
|
| 226 |
-
|
| 227 |
-
x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))
|
| 228 |
-
|
| 229 |
-
x = self.act(x)
|
| 230 |
-
return x
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
class ResnetBlock(nn.Module):
|
| 234 |
-
def __init__(
|
| 235 |
-
self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
|
| 236 |
-
):
|
| 237 |
-
super().__init__()
|
| 238 |
-
style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])
|
| 239 |
-
|
| 240 |
-
self.block1 = Block(
|
| 241 |
-
dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
|
| 242 |
-
)
|
| 243 |
-
self.block2 = Block(
|
| 244 |
-
dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
|
| 245 |
-
)
|
| 246 |
-
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
| 247 |
-
|
| 248 |
-
def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
|
| 249 |
-
h = self.block1(x, conv_mods_iter=conv_mods_iter)
|
| 250 |
-
h = self.block2(h, conv_mods_iter=conv_mods_iter)
|
| 251 |
-
|
| 252 |
-
return h + self.res_conv(x)
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
class LinearAttention(nn.Module):
|
| 256 |
-
def __init__(self, dim, heads=4, dim_head=32):
|
| 257 |
-
super().__init__()
|
| 258 |
-
self.scale = dim_head**-0.5
|
| 259 |
-
self.heads = heads
|
| 260 |
-
hidden_dim = dim_head * heads
|
| 261 |
-
|
| 262 |
-
self.norm = RMSNorm(dim)
|
| 263 |
-
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
| 264 |
-
|
| 265 |
-
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))
|
| 266 |
-
|
| 267 |
-
def forward(self, x):
|
| 268 |
-
b, c, h, w = x.shape
|
| 269 |
-
|
| 270 |
-
x = self.norm(x)
|
| 271 |
-
|
| 272 |
-
qkv = self.to_qkv(x).chunk(3, dim=1)
|
| 273 |
-
q, k, v = map(
|
| 274 |
-
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
|
| 275 |
-
)
|
| 276 |
-
|
| 277 |
-
q = q.softmax(dim=-2)
|
| 278 |
-
k = k.softmax(dim=-1)
|
| 279 |
-
|
| 280 |
-
q = q * self.scale
|
| 281 |
-
|
| 282 |
-
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
|
| 283 |
-
|
| 284 |
-
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
|
| 285 |
-
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
|
| 286 |
-
return self.to_out(out)
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
class Attention(nn.Module):
|
| 290 |
-
def __init__(self, dim, heads=4, dim_head=32, flash=False):
|
| 291 |
-
super().__init__()
|
| 292 |
-
self.heads = heads
|
| 293 |
-
hidden_dim = dim_head * heads
|
| 294 |
-
|
| 295 |
-
self.norm = RMSNorm(dim)
|
| 296 |
-
|
| 297 |
-
self.attend = Attend(flash=flash)
|
| 298 |
-
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
| 299 |
-
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
| 300 |
-
|
| 301 |
-
def forward(self, x):
|
| 302 |
-
b, c, h, w = x.shape
|
| 303 |
-
x = self.norm(x)
|
| 304 |
-
qkv = self.to_qkv(x).chunk(3, dim=1)
|
| 305 |
-
|
| 306 |
-
q, k, v = map(
|
| 307 |
-
lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
|
| 308 |
-
)
|
| 309 |
-
|
| 310 |
-
out = self.attend(q, k, v)
|
| 311 |
-
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
|
| 312 |
-
|
| 313 |
-
return self.to_out(out)
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
# feedforward
|
| 317 |
-
def FeedForward(dim, mult=4):
|
| 318 |
-
return nn.Sequential(
|
| 319 |
-
RMSNorm(dim),
|
| 320 |
-
nn.Conv2d(dim, dim * mult, 1),
|
| 321 |
-
nn.GELU(),
|
| 322 |
-
nn.Conv2d(dim * mult, dim, 1),
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
# transformers
|
| 327 |
-
class Transformer(nn.Module):
|
| 328 |
-
def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
|
| 329 |
-
super().__init__()
|
| 330 |
-
self.layers = nn.ModuleList([])
|
| 331 |
-
|
| 332 |
-
for _ in range(depth):
|
| 333 |
-
self.layers.append(
|
| 334 |
-
nn.ModuleList(
|
| 335 |
-
[
|
| 336 |
-
Attention(
|
| 337 |
-
dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
|
| 338 |
-
),
|
| 339 |
-
FeedForward(dim=dim, mult=ff_mult),
|
| 340 |
-
]
|
| 341 |
-
)
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
-
def forward(self, x):
|
| 345 |
-
for attn, ff in self.layers:
|
| 346 |
-
x = attn(x) + x
|
| 347 |
-
x = ff(x) + x
|
| 348 |
-
|
| 349 |
-
return x
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
class LinearTransformer(nn.Module):
|
| 353 |
-
def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
|
| 354 |
-
super().__init__()
|
| 355 |
-
self.layers = nn.ModuleList([])
|
| 356 |
-
|
| 357 |
-
for _ in range(depth):
|
| 358 |
-
self.layers.append(
|
| 359 |
-
nn.ModuleList(
|
| 360 |
-
[
|
| 361 |
-
LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
|
| 362 |
-
FeedForward(dim=dim, mult=ff_mult),
|
| 363 |
-
]
|
| 364 |
-
)
|
| 365 |
-
)
|
| 366 |
-
|
| 367 |
-
def forward(self, x):
|
| 368 |
-
for attn, ff in self.layers:
|
| 369 |
-
x = attn(x) + x
|
| 370 |
-
x = ff(x) + x
|
| 371 |
-
|
| 372 |
-
return x
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
class NearestNeighborhoodUpsample(nn.Module):
|
| 376 |
-
def __init__(self, dim, dim_out=None):
|
| 377 |
-
super().__init__()
|
| 378 |
-
dim_out = default(dim_out, dim)
|
| 379 |
-
self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)
|
| 380 |
-
|
| 381 |
-
def forward(self, x):
|
| 382 |
-
|
| 383 |
-
if x.shape[0] >= 64:
|
| 384 |
-
x = x.contiguous()
|
| 385 |
-
|
| 386 |
-
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 387 |
-
x = self.conv(x)
|
| 388 |
-
|
| 389 |
-
return x
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
class EqualLinear(nn.Module):
|
| 393 |
-
def __init__(self, dim, dim_out, lr_mul=1, bias=True):
|
| 394 |
-
super().__init__()
|
| 395 |
-
self.weight = nn.Parameter(torch.randn(dim_out, dim))
|
| 396 |
-
if bias:
|
| 397 |
-
self.bias = nn.Parameter(torch.zeros(dim_out))
|
| 398 |
-
|
| 399 |
-
self.lr_mul = lr_mul
|
| 400 |
-
|
| 401 |
-
def forward(self, input):
|
| 402 |
-
return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
class StyleGanNetwork(nn.Module):
|
| 406 |
-
def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
|
| 407 |
-
super().__init__()
|
| 408 |
-
self.dim_in = dim_in
|
| 409 |
-
self.dim_out = dim_out
|
| 410 |
-
self.dim_text_latent = dim_text_latent
|
| 411 |
-
|
| 412 |
-
layers = []
|
| 413 |
-
for i in range(depth):
|
| 414 |
-
is_first = i == 0
|
| 415 |
-
|
| 416 |
-
if is_first:
|
| 417 |
-
dim_in_layer = dim_in + dim_text_latent
|
| 418 |
-
else:
|
| 419 |
-
dim_in_layer = dim_out
|
| 420 |
-
|
| 421 |
-
dim_out_layer = dim_out
|
| 422 |
-
|
| 423 |
-
layers.extend(
|
| 424 |
-
[EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)]
|
| 425 |
-
)
|
| 426 |
-
|
| 427 |
-
self.net = nn.Sequential(*layers)
|
| 428 |
-
|
| 429 |
-
def forward(self, x, text_latent=None):
|
| 430 |
-
x = F.normalize(x, dim=1)
|
| 431 |
-
if self.dim_text_latent > 0:
|
| 432 |
-
assert exists(text_latent)
|
| 433 |
-
x = torch.cat((x, text_latent), dim=-1)
|
| 434 |
-
return self.net(x)
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
class UnetUpsampler(torch.nn.Module):
|
| 438 |
-
|
| 439 |
-
def __init__(
|
| 440 |
-
self,
|
| 441 |
-
dim: int,
|
| 442 |
-
*,
|
| 443 |
-
image_size: int,
|
| 444 |
-
input_image_size: int,
|
| 445 |
-
init_dim: Optional[int] = None,
|
| 446 |
-
out_dim: Optional[int] = None,
|
| 447 |
-
style_network: Optional[dict] = None,
|
| 448 |
-
up_dim_mults: tuple = (1, 2, 4, 8, 16),
|
| 449 |
-
down_dim_mults: tuple = (4, 8, 16),
|
| 450 |
-
channels: int = 3,
|
| 451 |
-
resnet_block_groups: int = 8,
|
| 452 |
-
full_attn: tuple = (False, False, False, True, True),
|
| 453 |
-
flash_attn: bool = True,
|
| 454 |
-
self_attn_dim_head: int = 64,
|
| 455 |
-
self_attn_heads: int = 8,
|
| 456 |
-
attn_depths: tuple = (2, 2, 2, 2, 4),
|
| 457 |
-
mid_attn_depth: int = 4,
|
| 458 |
-
num_conv_kernels: int = 4,
|
| 459 |
-
resize_mode: str = "bilinear",
|
| 460 |
-
unconditional: bool = True,
|
| 461 |
-
skip_connect_scale: Optional[float] = None,
|
| 462 |
-
):
|
| 463 |
-
super().__init__()
|
| 464 |
-
self.style_network = style_network = StyleGanNetwork(**style_network)
|
| 465 |
-
self.unconditional = unconditional
|
| 466 |
-
assert not (
|
| 467 |
-
unconditional
|
| 468 |
-
and exists(style_network)
|
| 469 |
-
and style_network.dim_text_latent > 0
|
| 470 |
-
)
|
| 471 |
-
|
| 472 |
-
assert is_power_of_two(image_size) and is_power_of_two(
|
| 473 |
-
input_image_size
|
| 474 |
-
), "both output image size and input image size must be power of 2"
|
| 475 |
-
assert (
|
| 476 |
-
input_image_size < image_size
|
| 477 |
-
), "input image size must be smaller than the output image size, thus upsampling"
|
| 478 |
-
|
| 479 |
-
self.image_size = image_size
|
| 480 |
-
self.input_image_size = input_image_size
|
| 481 |
-
|
| 482 |
-
style_embed_split_dims = []
|
| 483 |
-
|
| 484 |
-
self.channels = channels
|
| 485 |
-
input_channels = channels
|
| 486 |
-
|
| 487 |
-
init_dim = default(init_dim, dim)
|
| 488 |
-
|
| 489 |
-
up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
|
| 490 |
-
init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
|
| 491 |
-
down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
|
| 492 |
-
self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)
|
| 493 |
-
|
| 494 |
-
up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
|
| 495 |
-
down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
|
| 496 |
-
|
| 497 |
-
block_klass = partial(
|
| 498 |
-
ResnetBlock,
|
| 499 |
-
groups=resnet_block_groups,
|
| 500 |
-
num_conv_kernels=num_conv_kernels,
|
| 501 |
-
style_dims=style_embed_split_dims,
|
| 502 |
-
)
|
| 503 |
-
|
| 504 |
-
FullAttention = partial(Transformer, flash_attn=flash_attn)
|
| 505 |
-
*_, mid_dim = up_dims
|
| 506 |
-
|
| 507 |
-
self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)
|
| 508 |
-
|
| 509 |
-
self.downs = nn.ModuleList([])
|
| 510 |
-
self.ups = nn.ModuleList([])
|
| 511 |
-
|
| 512 |
-
block_count = 6
|
| 513 |
-
|
| 514 |
-
for ind, (
|
| 515 |
-
(dim_in, dim_out),
|
| 516 |
-
layer_full_attn,
|
| 517 |
-
layer_attn_depth,
|
| 518 |
-
) in enumerate(zip(down_in_out, full_attn, attn_depths)):
|
| 519 |
-
attn_klass = FullAttention if layer_full_attn else LinearTransformer
|
| 520 |
-
|
| 521 |
-
blocks = []
|
| 522 |
-
for i in range(block_count):
|
| 523 |
-
blocks.append(block_klass(dim_in, dim_in))
|
| 524 |
-
|
| 525 |
-
self.downs.append(
|
| 526 |
-
nn.ModuleList(
|
| 527 |
-
[
|
| 528 |
-
nn.ModuleList(blocks),
|
| 529 |
-
nn.ModuleList(
|
| 530 |
-
[
|
| 531 |
-
(
|
| 532 |
-
attn_klass(
|
| 533 |
-
dim_in,
|
| 534 |
-
dim_head=self_attn_dim_head,
|
| 535 |
-
heads=self_attn_heads,
|
| 536 |
-
depth=layer_attn_depth,
|
| 537 |
-
)
|
| 538 |
-
if layer_full_attn
|
| 539 |
-
else None
|
| 540 |
-
),
|
| 541 |
-
nn.Conv2d(
|
| 542 |
-
dim_in, dim_out, kernel_size=3, stride=2, padding=1
|
| 543 |
-
),
|
| 544 |
-
]
|
| 545 |
-
),
|
| 546 |
-
]
|
| 547 |
-
)
|
| 548 |
-
)
|
| 549 |
-
|
| 550 |
-
self.mid_block1 = block_klass(mid_dim, mid_dim)
|
| 551 |
-
self.mid_attn = FullAttention(
|
| 552 |
-
mid_dim,
|
| 553 |
-
dim_head=self_attn_dim_head,
|
| 554 |
-
heads=self_attn_heads,
|
| 555 |
-
depth=mid_attn_depth,
|
| 556 |
-
)
|
| 557 |
-
self.mid_block2 = block_klass(mid_dim, mid_dim)
|
| 558 |
-
|
| 559 |
-
*_, last_dim = up_dims
|
| 560 |
-
|
| 561 |
-
for ind, (
|
| 562 |
-
(dim_in, dim_out),
|
| 563 |
-
layer_full_attn,
|
| 564 |
-
layer_attn_depth,
|
| 565 |
-
) in enumerate(
|
| 566 |
-
zip(
|
| 567 |
-
reversed(up_in_out),
|
| 568 |
-
reversed(full_attn),
|
| 569 |
-
reversed(attn_depths),
|
| 570 |
-
)
|
| 571 |
-
):
|
| 572 |
-
attn_klass = FullAttention if layer_full_attn else LinearTransformer
|
| 573 |
-
|
| 574 |
-
blocks = []
|
| 575 |
-
input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
|
| 576 |
-
for i in range(block_count):
|
| 577 |
-
blocks.append(block_klass(input_dim, dim_in))
|
| 578 |
-
|
| 579 |
-
self.ups.append(
|
| 580 |
-
nn.ModuleList(
|
| 581 |
-
[
|
| 582 |
-
nn.ModuleList(blocks),
|
| 583 |
-
nn.ModuleList(
|
| 584 |
-
[
|
| 585 |
-
NearestNeighborhoodUpsample(
|
| 586 |
-
last_dim if ind == 0 else dim_out,
|
| 587 |
-
dim_in,
|
| 588 |
-
),
|
| 589 |
-
(
|
| 590 |
-
attn_klass(
|
| 591 |
-
dim_in,
|
| 592 |
-
dim_head=self_attn_dim_head,
|
| 593 |
-
heads=self_attn_heads,
|
| 594 |
-
depth=layer_attn_depth,
|
| 595 |
-
)
|
| 596 |
-
if layer_full_attn
|
| 597 |
-
else None
|
| 598 |
-
),
|
| 599 |
-
]
|
| 600 |
-
),
|
| 601 |
-
]
|
| 602 |
-
)
|
| 603 |
-
)
|
| 604 |
-
|
| 605 |
-
self.out_dim = default(out_dim, channels)
|
| 606 |
-
self.final_res_block = block_klass(dim, dim)
|
| 607 |
-
        self.final_to_rgb = nn.Conv2d(dim, channels, 1)
        self.resize_mode = resize_mode
        self.style_to_conv_modulations = nn.Linear(
            style_network.dim_out, sum(style_embed_split_dims)
        )
        self.style_embed_split_dims = style_embed_split_dims

    @property
    def allowable_rgb_resolutions(self):
        input_res_base = int(log2(self.input_image_size))
        output_res_base = int(log2(self.image_size))
        allowed_rgb_res_base = list(range(input_res_base, output_res_base))
        return [*map(lambda p: 2**p, allowed_rgb_res_base)]

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def total_params(self):
        return sum([p.numel() for p in self.parameters()])

    def resize_image_to(self, x, size):
        return F.interpolate(x, (size, size), mode=self.resize_mode)

    def forward(
        self,
        lowres_image: torch.Tensor,
        styles: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None,
        global_text_tokens: Optional[torch.Tensor] = None,
        return_all_rgbs: bool = False,
    ):
        x = lowres_image

        noise_scale = 0.001  # Adjust the scale of the noise as needed
        noise_aug = torch.randn_like(x) * noise_scale
        x = x + noise_aug
        x = x.clamp(0, 1)

        shape = x.shape
        batch_size = shape[0]

        assert shape[-2:] == ((self.input_image_size,) * 2)

        # styles
        if not exists(styles):
            assert exists(self.style_network)

            noise = default(
                noise,
                torch.randn(
                    (batch_size, self.style_network.dim_in), device=self.device
                ),
            )
            styles = self.style_network(noise, global_text_tokens)

        # project styles to conv modulations
        conv_mods = self.style_to_conv_modulations(styles)
        conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
        conv_mods = iter(conv_mods)

        x = self.init_conv(x)

        h = []
        for blocks, (attn, downsample) in self.downs:
            for block in blocks:
                x = block(x, conv_mods_iter=conv_mods)
                h.append(x)

            if attn is not None:
                x = attn(x)

            x = downsample(x)

        x = self.mid_block1(x, conv_mods_iter=conv_mods)
        x = self.mid_attn(x)
        x = self.mid_block2(x, conv_mods_iter=conv_mods)

        for (
            blocks,
            (
                upsample,
                attn,
            ),
        ) in self.ups:
            x = upsample(x)
            for block in blocks:
                if h != []:
                    res = h.pop()
                    res = res * self.skip_connect_scale
                    x = torch.cat((x, res), dim=1)

                x = block(x, conv_mods_iter=conv_mods)

            if attn is not None:
                x = attn(x)

        x = self.final_res_block(x, conv_mods_iter=conv_mods)
        rgb = self.final_to_rgb(x)

        if not return_all_rgbs:
            return rgb

        return rgb, []


def tile_image(image, chunk_size=64):
    c, h, w = image.shape
    h_chunks = ceil(h / chunk_size)
    w_chunks = ceil(w / chunk_size)
    tiles = []
    for i in range(h_chunks):
        for j in range(w_chunks):
            tile = image[
                :,
                i * chunk_size : (i + 1) * chunk_size,
                j * chunk_size : (j + 1) * chunk_size,
            ]
            tiles.append(tile)
    return tiles, h_chunks, w_chunks


# This helps create a checkerboard pattern with some edge blending
def create_checkerboard_weights(tile_size):
    x = torch.linspace(-1, 1, tile_size)
    y = torch.linspace(-1, 1, tile_size)

    x, y = torch.meshgrid(x, y, indexing="ij")
    d = torch.sqrt(x * x + y * y)
    sigma, mu = 0.5, 0.0
    weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2)))

    # saturate the values to make sure we get high weights in the center
    weights = weights**8

    return weights / weights.max()  # Normalize to [0, 1]


def repeat_weights(weights, image_size):
    tile_size = weights.shape[0]
    repeats = (
        math.ceil(image_size[0] / tile_size),
        math.ceil(image_size[1] / tile_size),
    )
    return weights.repeat(repeats)[: image_size[0], : image_size[1]]


def create_offset_weights(weights, image_size):
    tile_size = weights.shape[0]
    offset = tile_size // 2
    full_weights = repeat_weights(
        weights, (image_size[0] + offset, image_size[1] + offset)
    )
    return full_weights[offset:, offset:]


def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
    # Determine the shape of the output tensor
    c = tiles[0].shape[0]
    h = h_chunks * chunk_size
    w = w_chunks * chunk_size

    # Create an empty tensor to hold the merged image
    merged = torch.zeros((c, h, w), dtype=tiles[0].dtype)

    # Iterate over the tiles and place them in the correct position
    for idx, tile in enumerate(tiles):
        i = idx // w_chunks
        j = idx % w_chunks

        h_start = i * chunk_size
        w_start = j * chunk_size

        tile_h, tile_w = tile.shape[1:]
        merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile

    return merged


class AuraSR:
    def __init__(self, config: dict[str, Any], device: str = "cuda"):
        self.upsampler = UnetUpsampler(**config).to(device)
        self.input_image_size = config["input_image_size"]

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "fal-ai/AuraSR",
        use_safetensors: bool = True,
        device: str = "cuda",
    ):
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        # Check if model_id is a local file
        if Path(model_id).is_file():
            local_file = Path(model_id)
            if local_file.suffix == ".safetensors":
                use_safetensors = True
            elif local_file.suffix == ".ckpt":
                use_safetensors = False
            else:
                raise ValueError(
                    f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files."
                )

            # For local files, we need to provide the config separately
            config_path = local_file.with_name("config.json")
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )

            config = json.loads(config_path.read_text())
            hf_model_path = local_file.parent
        else:
            hf_model_path = Path(
                snapshot_download(model_id, ignore_patterns=["*.ckpt"])
            )
            config = json.loads((hf_model_path / "config.json").read_text())

        model = cls(config, device)

        if use_safetensors:
            try:
                from safetensors.torch import load_file

                checkpoint = load_file(
                    hf_model_path / "model.safetensors"
                    if not Path(model_id).is_file()
                    else model_id
                )
            except ImportError:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                )
        else:
            checkpoint = torch.load(
                hf_model_path / "model.ckpt"
                if not Path(model_id).is_file()
                else model_id
            )

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model

    @torch.no_grad()
    def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
        tensor_transform = transforms.ToTensor()
        device = self.upsampler.device

        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape
        pad_h = (
            self.input_image_size - h % self.input_image_size
        ) % self.input_image_size
        pad_w = (
            self.input_image_size - w % self.input_image_size
        ) % self.input_image_size

        # Pad the image
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        ).squeeze(0)
        tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)

        # Batch processing of tiles
        num_tiles = len(tiles)
        batches = [
            tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size)
        ]
        reconstructed_tiles = []

        for batch in batches:
            model_input = torch.stack(batch).to(device)
            generator_output = self.upsampler(
                lowres_image=model_input,
                noise=torch.randn(model_input.shape[0], 128, device=device),
            )
            reconstructed_tiles.extend(
                list(generator_output.clamp_(0, 1).detach().cpu())
            )

        merged_tensor = merge_tiles(
            reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
        )
        unpadded = merged_tensor[:, : h * 4, : w * 4]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)

    # Tiled 4x upscaling with overlapping tiles to reduce seam artifacts
    # weights options are 'checkboard' and 'constant'
    @torch.no_grad()
    def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"):
        tensor_transform = transforms.ToTensor()
        device = self.upsampler.device

        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape

        # Calculate paddings
        pad_h = (
            self.input_image_size - h % self.input_image_size
        ) % self.input_image_size
        pad_w = (
            self.input_image_size - w % self.input_image_size
        ) % self.input_image_size

        # Pad the image
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        ).squeeze(0)

        # Function to process tiles
        def process_tiles(tiles, h_chunks, w_chunks):
            num_tiles = len(tiles)
            batches = [
                tiles[i : i + max_batch_size]
                for i in range(0, num_tiles, max_batch_size)
            ]
            reconstructed_tiles = []

            for batch in batches:
                model_input = torch.stack(batch).to(device)
                generator_output = self.upsampler(
                    lowres_image=model_input,
                    noise=torch.randn(model_input.shape[0], 128, device=device),
                )
                reconstructed_tiles.extend(
                    list(generator_output.clamp_(0, 1).detach().cpu())
                )

            return merge_tiles(
                reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
            )

        # First pass
        tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
        result1 = process_tiles(tiles1, h_chunks1, w_chunks1)

        # Second pass with offset
        offset = self.input_image_size // 2
        image_tensor_offset = torch.nn.functional.pad(
            image_tensor, (offset, offset, offset, offset), mode="reflect"
        ).squeeze(0)

        tiles2, h_chunks2, w_chunks2 = tile_image(
            image_tensor_offset, self.input_image_size
        )
        result2 = process_tiles(tiles2, h_chunks2, w_chunks2)

        # unpad
        offset_4x = offset * 4
        result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]

        if weight_type == "checkboard":
            weight_tile = create_checkerboard_weights(self.input_image_size * 4)

            weight_shape = result2_interior.shape[1:]
            weights_1 = create_offset_weights(weight_tile, weight_shape)
            weights_2 = repeat_weights(weight_tile, weight_shape)

            normalizer = weights_1 + weights_2
            weights_1 = weights_1 / normalizer
            weights_2 = weights_2 / normalizer

            weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1)
            weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1)
        elif weight_type == "constant":
            weights_1 = torch.ones_like(result2_interior) * 0.5
            weights_2 = weights_1
        else:
            raise ValueError(
                "weight_type should be either 'checkboard' or 'constant' but got",
                weight_type,
            )

        result1 = result1 * weights_2
        result2 = result2_interior * weights_1

        # Average the overlapping region
        result1 = result1 + result2

        # Remove padding
        unpadded = result1[:, : h * 4, : w * 4]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)
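For orientation, here is a minimal usage sketch of the AuraSR wrapper defined above. The model id "fal/AuraSR-v2" is the one used by the deleted aura_sr_upscale.py below; the file paths are illustrative placeholders, not values from this repository.

    from PIL import Image
    from backend.upscale.aura_sr import AuraSR  # import path assumes the removed src/ layout

    aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
    image_in = Image.open("input.png")  # placeholder path
    # upscale_4x_overlapped runs two offset tiling passes and blends them to hide seams
    image_out = aura_sr.upscale_4x_overlapped(image_in, max_batch_size=4, weight_type="checkboard")
    image_out.save("input_4x.png")  # placeholder path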
src/backend/upscale/aura_sr_upscale.py DELETED
@@ -1,9 +0,0 @@
from backend.upscale.aura_sr import AuraSR
from PIL import Image


def upscale_aura_sr(image_path: str):

    aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
    image_in = Image.open(image_path)  # .resize((256, 256))
    return aura_sr.upscale_4x(image_in)
src/backend/upscale/edsr_upscale_onnx.py DELETED
@@ -1,37 +0,0 @@
import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download
from PIL import Image


def upscale_edsr_2x(image_path: str):
    input_image = Image.open(image_path).convert("RGB")
    input_image = np.array(input_image).astype("float32")
    input_image = np.transpose(input_image, (2, 0, 1))
    img_arr = np.expand_dims(input_image, axis=0)

    if np.max(img_arr) > 256:  # 16-bit image
        max_range = 65535
    else:
        max_range = 255.0
    img = img_arr / max_range

    model_path = hf_hub_download(
        repo_id="rupeshs/edsr-onnx",
        filename="edsr_onnxsim_2x.onnx",
    )
    sess = onnxruntime.InferenceSession(model_path)

    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    output = sess.run(
        [output_name],
        {input_name: img},
    )[0]

    result = output.squeeze()
    result = result.clip(0, 1)
    image_array = np.transpose(result, (1, 2, 0))
    image_array = np.uint8(image_array * 255)
    upscaled_image = Image.fromarray(image_array)
    return upscaled_image
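A minimal calling sketch for the EDSR helper above; the paths are placeholders, and the ONNX model is fetched from the rupeshs/edsr-onnx repository on first use.

    from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x  # assumes the removed src/ layout

    upscaled = upscale_edsr_2x("photo.png")  # placeholder input path
    upscaled.save("photo_2x.png")            # placeholder output path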
src/backend/upscale/tiled_upscale.py DELETED
@@ -1,237 +0,0 @@
import time
import math
import logging
from PIL import Image, ImageDraw, ImageFilter
from backend.models.lcmdiffusion_setting import DiffusionTask
from context import Context
from constants import DEVICE


def generate_upscaled_image(
    config,
    input_path=None,
    strength=0.3,
    scale_factor=2.0,
    tile_overlap=16,
    upscale_settings=None,
    context: Context = None,
    output_path=None,
    image_format="PNG",
):
    if config == None or (
        input_path == None or input_path == "" and upscale_settings == None
    ):
        logging.error("Wrong arguments in tiled upscale function call!")
        return

    # Use the upscale_settings dict if provided; otherwise, build the
    # upscale_settings dict using the function arguments and default values
    if upscale_settings == None:
        upscale_settings = {
            "source_file": input_path,
            "target_file": None,
            "output_format": image_format,
            "strength": strength,
            "scale_factor": scale_factor,
            "prompt": config.lcm_diffusion_setting.prompt,
            "tile_overlap": tile_overlap,
            "tile_size": 256,
            "tiles": [],
        }
        source_image = Image.open(input_path)  # PIL image
    else:
        source_image = Image.open(upscale_settings["source_file"])

    upscale_settings["source_image"] = source_image

    if upscale_settings["target_file"]:
        result = Image.open(upscale_settings["target_file"])
    else:
        result = Image.new(
            mode="RGBA",
            size=(
                source_image.size[0] * int(upscale_settings["scale_factor"]),
                source_image.size[1] * int(upscale_settings["scale_factor"]),
            ),
            color=(0, 0, 0, 0),
        )
    upscale_settings["target_image"] = result

    # If the custom tile definition array 'tiles' is empty, proceed with the
    # default tiled upscale task by defining all the possible image tiles; note
    # that the actual tile size is 'tile_size' + 'tile_overlap' and the target
    # image width and height are no longer constrained to multiples of 256 but
    # are instead multiples of the actual tile size
    if len(upscale_settings["tiles"]) == 0:
        tile_size = upscale_settings["tile_size"]
        scale_factor = upscale_settings["scale_factor"]
        tile_overlap = upscale_settings["tile_overlap"]
        total_cols = math.ceil(
            source_image.size[0] / tile_size
        )  # Image width / tile size
        total_rows = math.ceil(
            source_image.size[1] / tile_size
        )  # Image height / tile size
        for y in range(0, total_rows):
            y_offset = tile_overlap if y > 0 else 0  # Tile mask offset
            for x in range(0, total_cols):
                x_offset = tile_overlap if x > 0 else 0  # Tile mask offset
                x1 = x * tile_size
                y1 = y * tile_size
                w = tile_size + (tile_overlap if x < total_cols - 1 else 0)
                h = tile_size + (tile_overlap if y < total_rows - 1 else 0)
                mask_box = (  # Default tile mask box definition
                    x_offset,
                    y_offset,
                    int(w * scale_factor),
                    int(h * scale_factor),
                )
                upscale_settings["tiles"].append(
                    {
                        "x": x1,
                        "y": y1,
                        "w": w,
                        "h": h,
                        "mask_box": mask_box,
                        "prompt": upscale_settings["prompt"],  # Use top level prompt if available
                        "scale_factor": scale_factor,
                    }
                )

    # Generate the output image tiles
    for i in range(0, len(upscale_settings["tiles"])):
        generate_upscaled_tile(
            config,
            i,
            upscale_settings,
            context=context,
        )

    # Save completed upscaled image
    if upscale_settings["output_format"].upper() == "JPEG":
        result_rgb = result.convert("RGB")
        result.close()
        result = result_rgb
    result.save(output_path)
    result.close()
    source_image.close()
    return


def get_current_tile(
    config,
    context,
    strength,
):
    config.lcm_diffusion_setting.strength = strength
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
    if (
        config.lcm_diffusion_setting.use_tiny_auto_encoder
        and config.lcm_diffusion_setting.use_openvino
    ):
        config.lcm_diffusion_setting.use_tiny_auto_encoder = False
    current_tile = context.generate_text_to_image(
        settings=config,
        reshape=True,
        device=DEVICE,
        save_config=False,
    )[0]
    return current_tile


# Generates a single tile from the source image as defined in the
# upscale_settings["tiles"] array with the corresponding index and pastes the
# generated tile into the target image using the corresponding mask and scale
# factor; note that scale factor for the target image and the individual tiles
# can be different, this function will adjust scale factors as needed
def generate_upscaled_tile(
    config,
    index,
    upscale_settings,
    context: Context = None,
):
    if config == None or upscale_settings == None:
        logging.error("Wrong arguments in tile creation function call!")
        return

    x = upscale_settings["tiles"][index]["x"]
    y = upscale_settings["tiles"][index]["y"]
    w = upscale_settings["tiles"][index]["w"]
    h = upscale_settings["tiles"][index]["h"]
    tile_prompt = upscale_settings["tiles"][index]["prompt"]
    scale_factor = upscale_settings["scale_factor"]
    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
    target_width = int(w * tile_scale_factor)
    target_height = int(h * tile_scale_factor)
    strength = upscale_settings["strength"]
    source_image = upscale_settings["source_image"]
    target_image = upscale_settings["target_image"]
    mask_image = generate_tile_mask(config, index, upscale_settings)

    config.lcm_diffusion_setting.number_of_images = 1
    config.lcm_diffusion_setting.prompt = tile_prompt
    config.lcm_diffusion_setting.image_width = target_width
    config.lcm_diffusion_setting.image_height = target_height
    config.lcm_diffusion_setting.init_image = source_image.crop((x, y, x + w, y + h))

    current_tile = None
    print(f"[SD Upscale] Generating tile {index + 1}/{len(upscale_settings['tiles'])} ")
    if tile_prompt == None or tile_prompt == "":
        config.lcm_diffusion_setting.prompt = ""
        config.lcm_diffusion_setting.negative_prompt = ""
        current_tile = get_current_tile(config, context, strength)
    else:
        # Attempt to use img2img with low denoising strength to
        # generate the tiles with the extra aid of a prompt
        # context = get_context(InterfaceType.CLI)
        current_tile = get_current_tile(config, context, strength)

    if math.isclose(scale_factor, tile_scale_factor):
        target_image.paste(
            current_tile, (int(x * scale_factor), int(y * scale_factor)), mask_image
        )
    else:
        target_image.paste(
            current_tile.resize((int(w * scale_factor), int(h * scale_factor))),
            (int(x * scale_factor), int(y * scale_factor)),
            mask_image.resize((int(w * scale_factor), int(h * scale_factor))),
        )
    mask_image.close()
    current_tile.close()
    config.lcm_diffusion_setting.init_image.close()


# Generate tile mask using the box definition in the upscale_settings["tiles"]
# array with the corresponding index; note that tile masks for the default
# tiled upscale task can be reused but that would complicate the code, so
# new tile masks are instead created for each tile
def generate_tile_mask(
    config,
    index,
    upscale_settings,
):
    scale_factor = upscale_settings["scale_factor"]
    tile_overlap = upscale_settings["tile_overlap"]
    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
    w = int(upscale_settings["tiles"][index]["w"] * tile_scale_factor)
    h = int(upscale_settings["tiles"][index]["h"] * tile_scale_factor)
    # The Stable Diffusion pipeline automatically adjusts the output size
    # to multiples of 8 pixels; the mask must be created with the same
    # size as the output tile
    w = w - (w % 8)
    h = h - (h % 8)
    mask_box = upscale_settings["tiles"][index]["mask_box"]
    if mask_box == None:
        # Build a default solid mask with soft/transparent edges
        mask_box = (
            tile_overlap,
            tile_overlap,
            w - tile_overlap,
            h - tile_overlap,
        )
    mask_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)
    mask_draw.rectangle(tuple(mask_box), fill=(0, 0, 0))
    mask_blur = mask_image.filter(ImageFilter.BoxBlur(tile_overlap - 1))
    mask_image.close()
    return mask_blur
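For reference, a hedged sketch of how generate_upscaled_image above could be driven with an explicit upscale_settings dictionary; config and context are assumed to be an already initialized Settings object and Context, and every concrete value below is illustrative rather than taken from this repository.

    upscale_settings = {
        "source_file": "input.png",   # placeholder path
        "target_file": None,          # or an existing upscaled image to patch
        "output_format": "PNG",
        "strength": 0.3,
        "scale_factor": 2.0,
        "prompt": "",
        "tile_overlap": 16,
        "tile_size": 256,
        "tiles": [],                  # empty -> default whole-image tiling as described above
    }
    generate_upscaled_image(
        config,                        # assumed Settings object
        "input.png",                   # input_path, same placeholder as source_file
        upscale_settings=upscale_settings,
        context=context,               # assumed initialized Context
        output_path="input_sd_2x.png",
        image_format="PNG",
    )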
src/backend/upscale/upscaler.py DELETED
@@ -1,52 +0,0 @@
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.models.upscale import UpscaleMode
from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x
from backend.upscale.aura_sr_upscale import upscale_aura_sr
from backend.upscale.tiled_upscale import generate_upscaled_image
from context import Context
from PIL import Image
from state import get_settings


config = get_settings()


def upscale_image(
    context: Context,
    src_image_path: str,
    dst_image_path: str,
    scale_factor: int = 2,
    upscale_mode: UpscaleMode = UpscaleMode.normal.value,
    strength: float = 0.1,
):
    if upscale_mode == UpscaleMode.normal.value:
        upscaled_img = upscale_edsr_2x(src_image_path)
        upscaled_img.save(dst_image_path)
        print(f"Upscaled image saved {dst_image_path}")
    elif upscale_mode == UpscaleMode.aura_sr.value:
        upscaled_img = upscale_aura_sr(src_image_path)
        upscaled_img.save(dst_image_path)
        print(f"Upscaled image saved {dst_image_path}")
    else:
        config.settings.lcm_diffusion_setting.strength = (
            0.3 if config.settings.lcm_diffusion_setting.use_openvino else strength
        )
        config.settings.lcm_diffusion_setting.diffusion_task = (
            DiffusionTask.image_to_image.value
        )

        generate_upscaled_image(
            config.settings,
            src_image_path,
            config.settings.lcm_diffusion_setting.strength,
            upscale_settings=None,
            context=context,
            tile_overlap=(
                32 if config.settings.lcm_diffusion_setting.use_openvino else 16
            ),
            output_path=dst_image_path,
            image_format=config.settings.generated_images.format,
        )
        print(f"Upscaled image saved {dst_image_path}")

    return [Image.open(dst_image_path)]
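A short sketch of calling the upscale_image entry point above from a script; the Context construction and the paths are assumptions for illustration (InterfaceType.CLI mirrors the commented hint in tiled_upscale.py).

    from backend.models.upscale import UpscaleMode
    from backend.upscale.upscaler import upscale_image
    from context import Context
    from models.interface_types import InterfaceType

    context = Context(InterfaceType.CLI)        # assumed interface type
    images = upscale_image(
        context,
        src_image_path="photo.png",             # placeholder
        dst_image_path="photo_2x.png",          # placeholder
        upscale_mode=UpscaleMode.normal.value,  # EDSR 2x path
    )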
src/constants.py DELETED
@@ -1,25 +0,0 @@
from os import environ, cpu_count

cpu_cores = cpu_count()
cpus = cpu_cores // 2 if cpu_cores else 0
APP_VERSION = "v1.0.0 beta 200"
LCM_DEFAULT_MODEL = "stabilityai/sd-turbo"
LCM_DEFAULT_MODEL_OPENVINO = "rupeshs/sd-turbo-openvino"
APP_NAME = "FastSD CPU"
APP_SETTINGS_FILE = "settings.yaml"
RESULTS_DIRECTORY = "results"
CONFIG_DIRECTORY = "configs"
DEVICE = environ.get("DEVICE", "cpu")
SD_MODELS_FILE = "stable-diffusion-models.txt"
LCM_LORA_MODELS_FILE = "lcm-lora-models.txt"
OPENVINO_LCM_MODELS_FILE = "openvino-lcm-models.txt"
TAESD_MODEL = "madebyollin/taesd"
TAESDXL_MODEL = "madebyollin/taesdxl"
TAESD_MODEL_OPENVINO = "deinferno/taesd-openvino"
LCM_MODELS_FILE = "lcm-models.txt"
TAESDXL_MODEL_OPENVINO = "rupeshs/taesdxl-openvino"
LORA_DIRECTORY = "lora_models"
CONTROLNET_DIRECTORY = "controlnet_models"
MODELS_DIRECTORY = "models"
GGUF_THREADS = environ.get("GGUF_THREADS", cpus)
TAEF1_MODEL_OPENVINO = "rupeshs/taef1-openvino"
src/context.py DELETED
@@ -1,85 +0,0 @@
from typing import Any
from app_settings import Settings
from models.interface_types import InterfaceType
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.lcm_text_to_image import LCMTextToImage
from time import perf_counter
from backend.image_saver import ImageSaver
from pprint import pprint


class Context:
    def __init__(
        self,
        interface_type: InterfaceType,
        device="cpu",
    ):
        self.interface_type = interface_type.value
        self.lcm_text_to_image = LCMTextToImage(device)
        self._latency = 0

    @property
    def latency(self):
        return self._latency

    def generate_text_to_image(
        self,
        settings: Settings,
        reshape: bool = False,
        device: str = "cpu",
        save_config=True,
    ) -> Any:
        if (
            settings.lcm_diffusion_setting.use_tiny_auto_encoder
            and settings.lcm_diffusion_setting.use_openvino
        ):
            print(
                "WARNING: Tiny AutoEncoder is not supported in Image to image mode (OpenVINO)"
            )
        tick = perf_counter()
        from state import get_settings

        if (
            settings.lcm_diffusion_setting.diffusion_task
            == DiffusionTask.text_to_image.value
        ):
            settings.lcm_diffusion_setting.init_image = None

        if save_config:
            get_settings().save()

        pprint(settings.lcm_diffusion_setting.model_dump())
        if not settings.lcm_diffusion_setting.lcm_lora:
            return None
        self.lcm_text_to_image.init(
            device,
            settings.lcm_diffusion_setting,
        )
        images = self.lcm_text_to_image.generate(
            settings.lcm_diffusion_setting,
            reshape,
        )
        elapsed = perf_counter() - tick
        self._latency = elapsed
        print(f"Latency : {elapsed:.2f} seconds")
        if settings.lcm_diffusion_setting.controlnet:
            if settings.lcm_diffusion_setting.controlnet.enabled:
                images.append(settings.lcm_diffusion_setting.controlnet._control_image)
        return images

    def save_images(
        self,
        images: Any,
        settings: Settings,
    ) -> list[str]:
        saved_images = []
        if images and settings.generated_images.save_image:
            saved_images = ImageSaver.save_images(
                settings.generated_images.path,
                images=images,
                lcm_diffusion_setting=settings.lcm_diffusion_setting,
                format=settings.generated_images.format,
                jpeg_quality=settings.generated_images.save_image_quality,
            )
        return saved_images
src/frontend/cli_interactive.py DELETED
@@ -1,661 +0,0 @@
from os import path
from PIL import Image
from typing import Any

from constants import DEVICE
from paths import FastStableDiffusionPaths
from backend.upscale.upscaler import upscale_image
from backend.upscale.tiled_upscale import generate_upscaled_image
from frontend.webui.image_variations_ui import generate_image_variations
from backend.lora import (
    get_active_lora_weights,
    update_lora_weights,
    load_lora_weight,
)
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    ControlNetSetting,
)


_batch_count = 1
_edit_lora_settings = False


def user_value(
    value_type: type,
    message: str,
    default_value: Any,
) -> Any:
    try:
        value = value_type(input(message))
    except:
        value = default_value
    return value


def interactive_mode(
    config,
    context,
):
    print("=============================================")
    print("Welcome to FastSD CPU Interactive CLI")
    print("=============================================")
    while True:
        print("> 1. Text to Image")
        print("> 2. Image to Image")
        print("> 3. Image Variations")
        print("> 4. EDSR Upscale")
        print("> 5. SD Upscale")
        print("> 6. Edit default generation settings")
        print("> 7. Edit LoRA settings")
        print("> 8. Edit ControlNet settings")
        print("> 9. Edit negative prompt")
        print("> 10. Quit")
        option = user_value(
            int,
            "Enter a Diffusion Task number (1): ",
            1,
        )
        if option not in range(1, 11):
            print("Wrong Diffusion Task number!")
            exit()

        if option == 1:
            interactive_txt2img(
                config,
                context,
            )
        elif option == 2:
            interactive_img2img(
                config,
                context,
            )
        elif option == 3:
            interactive_variations(
                config,
                context,
            )
        elif option == 4:
            interactive_edsr(
                config,
                context,
            )
        elif option == 5:
            interactive_sdupscale(
                config,
                context,
            )
        elif option == 6:
            interactive_settings(
                config,
                context,
            )
        elif option == 7:
            interactive_lora(
                config,
                context,
                True,
            )
        elif option == 8:
            interactive_controlnet(
                config,
                context,
                True,
            )
        elif option == 9:
            interactive_negative(
                config,
                context,
            )
        elif option == 10:
            exit()


def interactive_negative(
    config,
    context,
):
    settings = config.lcm_diffusion_setting
    print(f"Current negative prompt: '{settings.negative_prompt}'")
    user_input = input("Write a negative prompt (set guidance > 1.0): ")
    if user_input == "":
        return
    else:
        settings.negative_prompt = user_input


def interactive_controlnet(
    config,
    context,
    menu_flag=False,
):
    """
    @param menu_flag: Indicates whether this function was called from the main
    interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
    """
    settings = config.lcm_diffusion_setting
    if not settings.controlnet:
        settings.controlnet = ControlNetSetting()

    current_enabled = settings.controlnet.enabled
    current_adapter_path = settings.controlnet.adapter_path
    current_conditioning_scale = settings.controlnet.conditioning_scale
    current_control_image = settings.controlnet._control_image

    option = input("Enable ControlNet? (y/N): ")
    settings.controlnet.enabled = True if option.upper() == "Y" else False
    if settings.controlnet.enabled:
        option = input(
            f"Enter ControlNet adapter path ({settings.controlnet.adapter_path}): "
        )
        if option != "":
            settings.controlnet.adapter_path = option
        settings.controlnet.conditioning_scale = user_value(
            float,
            f"Enter ControlNet conditioning scale ({settings.controlnet.conditioning_scale}): ",
            settings.controlnet.conditioning_scale,
        )
        option = input(
            f"Enter ControlNet control image path (Leave empty to reuse current): "
        )
        if option != "":
            try:
                new_image = Image.open(option)
                settings.controlnet._control_image = new_image
            except (AttributeError, FileNotFoundError) as e:
                settings.controlnet._control_image = None
        if (
            not settings.controlnet.adapter_path
            or not path.exists(settings.controlnet.adapter_path)
            or not settings.controlnet._control_image
        ):
            print("Invalid ControlNet settings! Disabling ControlNet")
            settings.controlnet.enabled = False

    if (
        settings.controlnet.enabled != current_enabled
        or settings.controlnet.adapter_path != current_adapter_path
    ):
        settings.rebuild_pipeline = True


def interactive_lora(
    config,
    context,
    menu_flag=False,
):
    """
    @param menu_flag: Indicates whether this function was called from the main
    interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
    """
    if context == None or context.lcm_text_to_image.pipeline == None:
        print("Diffusion pipeline not initialized, please run a generation task first!")
        return

    print("> 1. Change LoRA weights")
    print("> 2. Load new LoRA model")
    option = user_value(
        int,
        "Enter a LoRA option (1): ",
        1,
    )
    if option not in range(1, 3):
        print("Wrong LoRA option!")
        return

    if option == 1:
        update_weights = []
        active_weights = get_active_lora_weights()
        for lora in active_weights:
            weight = user_value(
                float,
                f"Enter a new LoRA weight for {lora[0]} ({lora[1]}): ",
                lora[1],
            )
            update_weights.append(
                (
                    lora[0],
                    weight,
                )
            )
        if len(update_weights) > 0:
            update_lora_weights(
                context.lcm_text_to_image.pipeline,
                config.lcm_diffusion_setting,
                update_weights,
            )
    elif option == 2:
        # Load a new LoRA
        settings = config.lcm_diffusion_setting
        settings.lora.fuse = False
        settings.lora.enabled = False
        settings.lora.path = input("Enter LoRA model path: ")
        settings.lora.weight = user_value(
            float,
            "Enter a LoRA weight (0.5): ",
            0.5,
        )
        if not path.exists(settings.lora.path):
            print("Invalid LoRA model path!")
            return
        settings.lora.enabled = True
        load_lora_weight(context.lcm_text_to_image.pipeline, settings)

    if menu_flag:
        global _edit_lora_settings
        _edit_lora_settings = False
        option = input("Edit LoRA settings after every generation? (y/N): ")
        if option.upper() == "Y":
            _edit_lora_settings = True


def interactive_settings(
    config,
    context,
):
    global _batch_count
    settings = config.lcm_diffusion_setting
    print("Enter generation settings (leave empty to use current value)")
    print("> 1. Use LCM")
    print("> 2. Use LCM-Lora")
    print("> 3. Use OpenVINO")
    option = user_value(
        int,
        "Select inference model option (1): ",
        1,
    )
    if option not in range(1, 4):
        print("Wrong inference model option! Falling back to defaults")
        return

    settings.use_lcm_lora = False
    settings.use_openvino = False
    if option == 1:
        lcm_model_id = input(f"Enter LCM model ID ({settings.lcm_model_id}): ")
        if lcm_model_id != "":
            settings.lcm_model_id = lcm_model_id
    elif option == 2:
        settings.use_lcm_lora = True
        lcm_lora_id = input(
            f"Enter LCM-Lora model ID ({settings.lcm_lora.lcm_lora_id}): "
        )
        if lcm_lora_id != "":
            settings.lcm_lora.lcm_lora_id = lcm_lora_id
        base_model_id = input(
            f"Enter Base model ID ({settings.lcm_lora.base_model_id}): "
        )
        if base_model_id != "":
            settings.lcm_lora.base_model_id = base_model_id
    elif option == 3:
        settings.use_openvino = True
        openvino_lcm_model_id = input(
            f"Enter OpenVINO model ID ({settings.openvino_lcm_model_id}): "
        )
        if openvino_lcm_model_id != "":
            settings.openvino_lcm_model_id = openvino_lcm_model_id

    settings.use_offline_model = True
    settings.use_tiny_auto_encoder = True
    option = input("Work offline? (Y/n): ")
    if option.upper() == "N":
        settings.use_offline_model = False
    option = input("Use Tiny Auto Encoder? (Y/n): ")
    if option.upper() == "N":
        settings.use_tiny_auto_encoder = False

    settings.image_width = user_value(
        int,
        f"Image width ({settings.image_width}): ",
        settings.image_width,
    )
    settings.image_height = user_value(
        int,
        f"Image height ({settings.image_height}): ",
        settings.image_height,
    )
    settings.inference_steps = user_value(
        int,
        f"Inference steps ({settings.inference_steps}): ",
        settings.inference_steps,
    )
    settings.guidance_scale = user_value(
        float,
        f"Guidance scale ({settings.guidance_scale}): ",
        settings.guidance_scale,
    )
    settings.number_of_images = user_value(
        int,
        f"Number of images per batch ({settings.number_of_images}): ",
        settings.number_of_images,
    )
    _batch_count = user_value(
        int,
        f"Batch count ({_batch_count}): ",
        _batch_count,
    )
    # output_format = user_value(int, f"Output format (PNG)", 1)
    print(config.lcm_diffusion_setting)


def interactive_txt2img(
    config,
    context,
):
    global _batch_count
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
    user_input = input("Write a prompt (write 'exit' to quit): ")
    while True:
        if user_input == "exit":
            return
        elif user_input == "":
            user_input = config.lcm_diffusion_setting.prompt
        config.lcm_diffusion_setting.prompt = user_input
        for _ in range(0, _batch_count):
            images = context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
            context.save_images(
                images,
                config,
            )
        if _edit_lora_settings:
            interactive_lora(
                config,
                context,
            )
        user_input = input("Write a prompt: ")


def interactive_img2img(
    config,
    context,
):
    global _batch_count
    settings = config.lcm_diffusion_setting
    settings.diffusion_task = DiffusionTask.image_to_image.value
    steps = settings.inference_steps
    source_path = input("Image path: ")
    if source_path == "":
        print("Error : You need to provide a file in img2img mode")
        return
    settings.strength = user_value(
        float,
        f"img2img strength ({settings.strength}): ",
        settings.strength,
    )
    settings.inference_steps = int(steps / settings.strength + 1)
    user_input = input("Write a prompt (write 'exit' to quit): ")
    while True:
        if user_input == "exit":
            settings.inference_steps = steps
            return
        settings.init_image = Image.open(source_path)
        settings.prompt = user_input
        for _ in range(0, _batch_count):
            images = context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
            context.save_images(
                images,
                config,
            )
        new_path = input(f"Image path ({source_path}): ")
        if new_path != "":
            source_path = new_path
        settings.strength = user_value(
            float,
            f"img2img strength ({settings.strength}): ",
            settings.strength,
        )
        if _edit_lora_settings:
            interactive_lora(
                config,
                context,
            )
        settings.inference_steps = int(steps / settings.strength + 1)
        user_input = input("Write a prompt: ")


def interactive_variations(
    config,
    context,
):
    global _batch_count
    settings = config.lcm_diffusion_setting
    settings.diffusion_task = DiffusionTask.image_to_image.value
    steps = settings.inference_steps
    source_path = input("Image path: ")
    if source_path == "":
        print("Error : You need to provide a file in Image variations mode")
        return
    settings.strength = user_value(
        float,
        f"Image variations strength ({settings.strength}): ",
        settings.strength,
    )
    settings.inference_steps = int(steps / settings.strength + 1)
    while True:
        settings.init_image = Image.open(source_path)
        settings.prompt = ""
        for i in range(0, _batch_count):
            generate_image_variations(
                settings.init_image,
                settings.strength,
            )
        if _edit_lora_settings:
            interactive_lora(
                config,
                context,
            )
        user_input = input("Continue in Image variations mode? (Y/n): ")
        if user_input.upper() == "N":
            settings.inference_steps = steps
            return
        new_path = input(f"Image path ({source_path}): ")
        if new_path != "":
            source_path = new_path
        settings.strength = user_value(
            float,
            f"Image variations strength ({settings.strength}): ",
            settings.strength,
        )
        settings.inference_steps = int(steps / settings.strength + 1)


def interactive_edsr(
    config,
    context,
):
    source_path = input("Image path: ")
    if source_path == "":
        print("Error : You need to provide a file in EDSR mode")
        return
    while True:
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            source_path,
            2,
            config.generated_images.format,
        )
        result = upscale_image(
            context,
            source_path,
            output_path,
            2,
        )
        user_input = input("Continue in EDSR upscale mode? (Y/n): ")
        if user_input.upper() == "N":
            return
        new_path = input(f"Image path ({source_path}): ")
        if new_path != "":
            source_path = new_path


def interactive_sdupscale_settings(config):
    steps = config.lcm_diffusion_setting.inference_steps
    custom_settings = {}
    print("> 1. Upscale whole image")
    print("> 2. Define custom tiles (advanced)")
    option = user_value(
        int,
        "Select an SD Upscale option (1): ",
        1,
    )
    if option not in range(1, 3):
        print("Wrong SD Upscale option!")
        return

    # custom_settings["source_file"] = args.file
    custom_settings["source_file"] = ""
    new_path = input(f"Input image path ({custom_settings['source_file']}): ")
    if new_path != "":
        custom_settings["source_file"] = new_path
    if custom_settings["source_file"] == "":
        print("Error : You need to provide a file in SD Upscale mode")
        return
    custom_settings["target_file"] = None
    if option == 2:
        custom_settings["target_file"] = input("Image to patch: ")
        if custom_settings["target_file"] == "":
            print("No target file provided, upscaling whole input image instead!")
            custom_settings["target_file"] = None
            option = 1
    custom_settings["output_format"] = config.generated_images.format
    custom_settings["strength"] = user_value(
        float,
        f"SD Upscale strength ({config.lcm_diffusion_setting.strength}): ",
        config.lcm_diffusion_setting.strength,
    )
    config.lcm_diffusion_setting.inference_steps = int(
        steps / custom_settings["strength"] + 1
    )
    if option == 1:
        custom_settings["scale_factor"] = user_value(
            float,
            f"Scale factor (2.0): ",
            2.0,
        )
        custom_settings["tile_size"] = user_value(
            int,
            f"Split input image into tiles of the following size, in pixels (256): ",
            256,
        )
        custom_settings["tile_overlap"] = user_value(
            int,
            f"Tile overlap, in pixels (16): ",
            16,
        )
    elif option == 2:
        custom_settings["scale_factor"] = user_value(
            float,
            "Input image to Image-to-patch scale_factor (2.0): ",
            2.0,
        )
        custom_settings["tile_size"] = 256
        custom_settings["tile_overlap"] = 16
    custom_settings["prompt"] = input(
        "Write a prompt describing the input image (optional): "
    )
    custom_settings["tiles"] = []
    if option == 2:
        add_tile = True
        while add_tile:
            print("=== Define custom SD Upscale tile ===")
            tile_x = user_value(
                int,
                "Enter tile's X position: ",
                0,
            )
            tile_y = user_value(
                int,
                "Enter tile's Y position: ",
                0,
            )
            tile_w = user_value(
                int,
                "Enter tile's width (256): ",
                256,
            )
            tile_h = user_value(
                int,
                "Enter tile's height (256): ",
                256,
            )
            tile_scale = user_value(
                float,
                "Enter tile's scale factor (2.0): ",
                2.0,
            )
            tile_prompt = input("Enter tile's prompt (optional): ")
            custom_settings["tiles"].append(
                {
                    "x": tile_x,
                    "y": tile_y,
                    "w": tile_w,
                    "h": tile_h,
                    "mask_box": None,
                    "prompt": tile_prompt,
                    "scale_factor": tile_scale,
                }
            )
            tile_option = input("Do you want to define another tile? (y/N): ")
            if tile_option == "" or tile_option.upper() == "N":
                add_tile = False

    return custom_settings


def interactive_sdupscale(
    config,
    context,
):
    settings = config.lcm_diffusion_setting
    settings.diffusion_task = DiffusionTask.image_to_image.value
    settings.init_image = ""
    source_path = ""
    steps = settings.inference_steps

    while True:
        custom_upscale_settings = None
        option = input("Edit custom SD Upscale settings? (y/N): ")
        if option.upper() == "Y":
|
| 624 |
-
config.lcm_diffusion_setting.inference_steps = steps
|
| 625 |
-
custom_upscale_settings = interactive_sdupscale_settings(config)
|
| 626 |
-
if not custom_upscale_settings:
|
| 627 |
-
return
|
| 628 |
-
source_path = custom_upscale_settings["source_file"]
|
| 629 |
-
else:
|
| 630 |
-
new_path = input(f"Image path ({source_path}): ")
|
| 631 |
-
if new_path != "":
|
| 632 |
-
source_path = new_path
|
| 633 |
-
if source_path == "":
|
| 634 |
-
print("Error : You need to provide a file in SD Upscale mode")
|
| 635 |
-
return
|
| 636 |
-
settings.strength = user_value(
|
| 637 |
-
float,
|
| 638 |
-
f"SD Upscale strength ({settings.strength}): ",
|
| 639 |
-
settings.strength,
|
| 640 |
-
)
|
| 641 |
-
settings.inference_steps = int(steps / settings.strength + 1)
|
| 642 |
-
|
| 643 |
-
output_path = FastStableDiffusionPaths.get_upscale_filepath(
|
| 644 |
-
source_path,
|
| 645 |
-
2,
|
| 646 |
-
config.generated_images.format,
|
| 647 |
-
)
|
| 648 |
-
generate_upscaled_image(
|
| 649 |
-
config,
|
| 650 |
-
source_path,
|
| 651 |
-
settings.strength,
|
| 652 |
-
upscale_settings=custom_upscale_settings,
|
| 653 |
-
context=context,
|
| 654 |
-
tile_overlap=32 if settings.use_openvino else 16,
|
| 655 |
-
output_path=output_path,
|
| 656 |
-
image_format=config.generated_images.format,
|
| 657 |
-
)
|
| 658 |
-
user_input = input("Continue in SD Upscale mode? (Y/n): ")
|
| 659 |
-
if user_input.upper() == "N":
|
| 660 |
-
settings.inference_steps = steps
|
| 661 |
-
return
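The interactive modes above rely on a user_value(value_type, message, default_value) prompt helper defined earlier in cli_interactive.py, outside this hunk: it reads a line from stdin, casts it to the requested type, and falls back to the default on empty or invalid input. They also rescale inference_steps by 1 / strength so that image-to-image runs keep roughly the same effective number of denoising steps. A minimal sketch of the assumed helper, for reference only (the deleted implementation may differ):

def user_value(value_type, message, default_value):
    # Prompt on stdin; empty or unparsable input falls back to the default value.
    try:
        value = value_type(input(message))
    except ValueError:
        value = default_value
    return value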
src/frontend/gui/app_window.py
DELETED
@@ -1,595 +0,0 @@
from datetime import datetime

from app_settings import AppSettings
from backend.models.lcmdiffusion_setting import DiffusionTask
from constants import (
    APP_NAME,
    APP_VERSION,
    LCM_DEFAULT_MODEL,
    LCM_DEFAULT_MODEL_OPENVINO,
)
from context import Context
from frontend.gui.image_variations_widget import ImageVariationsWidget
from frontend.gui.upscaler_widget import UpscalerWidget
from frontend.gui.img2img_widget import Img2ImgWidget
from frontend.utils import (
    enable_openvino_controls,
    get_valid_model_id,
    is_reshape_required,
)
from paths import FastStableDiffusionPaths
from PyQt5 import QtCore, QtWidgets
from PyQt5.QtCore import QSize, Qt, QThreadPool, QUrl
from PyQt5.QtGui import QDesktopServices
from PyQt5.QtWidgets import (
    QCheckBox,
    QComboBox,
    QFileDialog,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QMainWindow,
    QPushButton,
    QSizePolicy,
    QSlider,
    QSpacerItem,
    QTabWidget,
    QToolButton,
    QVBoxLayout,
    QWidget,
)

from models.interface_types import InterfaceType
from frontend.gui.base_widget import BaseWidget

# DPI scale fix
QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True)
QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps, True)


class MainWindow(QMainWindow):
    settings_changed = QtCore.pyqtSignal()
    """ This signal is used for enabling/disabling the negative prompt field for
    modes that support it; in particular, negative prompt is supported with OpenVINO models
    and in LCM-LoRA mode but not in LCM mode
    """

    def __init__(self, config: AppSettings):
        super().__init__()
        self.config = config
        # Prevent saved LoRA and ControlNet settings from being used by
        # default; in GUI mode, the user must explicitly enable those
        if self.config.settings.lcm_diffusion_setting.lora:
            self.config.settings.lcm_diffusion_setting.lora.enabled = False
        if self.config.settings.lcm_diffusion_setting.controlnet:
            self.config.settings.lcm_diffusion_setting.controlnet.enabled = False
        self.setWindowTitle(APP_NAME)
        self.setFixedSize(QSize(600, 670))
        self.init_ui()
        self.pipeline = None
        self.threadpool = QThreadPool()
        self.device = "cpu"
        self.previous_width = 0
        self.previous_height = 0
        self.previous_model = ""
        self.previous_num_of_images = 0
        self.context = Context(InterfaceType.GUI)
        self.init_ui_values()
        self.gen_images = []
        self.image_index = 0
        print(f"Output path : {self.config.settings.generated_images.path}")

    def init_ui_values(self):
        self.lcm_model.setEnabled(
            not self.config.settings.lcm_diffusion_setting.use_openvino
        )
        self.guidance.setValue(
            int(self.config.settings.lcm_diffusion_setting.guidance_scale * 10)
        )
        self.seed_value.setEnabled(self.config.settings.lcm_diffusion_setting.use_seed)
        self.safety_checker.setChecked(
            self.config.settings.lcm_diffusion_setting.use_safety_checker
        )
        self.use_openvino_check.setChecked(
            self.config.settings.lcm_diffusion_setting.use_openvino
        )
        self.width.setCurrentText(
            str(self.config.settings.lcm_diffusion_setting.image_width)
        )
        self.height.setCurrentText(
            str(self.config.settings.lcm_diffusion_setting.image_height)
        )
        self.inference_steps.setValue(
            int(self.config.settings.lcm_diffusion_setting.inference_steps)
        )
        self.clip_skip.setValue(
            int(self.config.settings.lcm_diffusion_setting.clip_skip)
        )
        self.token_merging.setValue(
            int(self.config.settings.lcm_diffusion_setting.token_merging * 100)
        )
        self.seed_check.setChecked(self.config.settings.lcm_diffusion_setting.use_seed)
        self.seed_value.setText(str(self.config.settings.lcm_diffusion_setting.seed))
        self.use_local_model_folder.setChecked(
            self.config.settings.lcm_diffusion_setting.use_offline_model
        )
        self.results_path.setText(self.config.settings.generated_images.path)
        self.num_images.setValue(
            self.config.settings.lcm_diffusion_setting.number_of_images
        )
        self.use_tae_sd.setChecked(
            self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder
        )
        self.use_lcm_lora.setChecked(
            self.config.settings.lcm_diffusion_setting.use_lcm_lora
        )
        self.lcm_model.setCurrentText(
            get_valid_model_id(
                self.config.lcm_models,
                self.config.settings.lcm_diffusion_setting.lcm_model_id,
                LCM_DEFAULT_MODEL,
            )
        )
        self.base_model_id.setCurrentText(
            get_valid_model_id(
                self.config.stable_diffsuion_models,
                self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id,
            )
        )
        self.lcm_lora_id.setCurrentText(
            get_valid_model_id(
                self.config.lcm_lora_models,
                self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id,
            )
        )
        self.openvino_lcm_model_id.setCurrentText(
            get_valid_model_id(
                self.config.openvino_lcm_models,
                self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id,
                LCM_DEFAULT_MODEL_OPENVINO,
            )
        )
        self.openvino_lcm_model_id.setEnabled(
            self.config.settings.lcm_diffusion_setting.use_openvino
        )

    def init_ui(self):
        self.create_main_tab()
        self.create_settings_tab()
        self.create_about_tab()
        self.show()

    def create_main_tab(self):
        self.tab_widget = QTabWidget(self)
        self.tab_main = BaseWidget(self.config, self)
        self.tab_settings = QWidget()
        self.tab_about = QWidget()
        self.img2img_tab = Img2ImgWidget(self.config, self)
        self.variations_tab = ImageVariationsWidget(self.config, self)
        self.upscaler_tab = UpscalerWidget(self.config, self)

        # Add main window tabs here
        self.tab_widget.addTab(self.tab_main, "Text to Image")
        self.tab_widget.addTab(self.img2img_tab, "Image to Image")
        self.tab_widget.addTab(self.variations_tab, "Image Variations")
        self.tab_widget.addTab(self.upscaler_tab, "Upscaler")
        self.tab_widget.addTab(self.tab_settings, "Settings")
        self.tab_widget.addTab(self.tab_about, "About")

        self.setCentralWidget(self.tab_widget)
        self.use_seed = False

    def create_settings_tab(self):
        self.lcm_model_label = QLabel("Latent Consistency Model:")
        # self.lcm_model = QLineEdit(LCM_DEFAULT_MODEL)
        self.lcm_model = QComboBox(self)
        self.lcm_model.addItems(self.config.lcm_models)
        self.lcm_model.currentIndexChanged.connect(self.on_lcm_model_changed)

        self.use_lcm_lora = QCheckBox("Use LCM LoRA")
        self.use_lcm_lora.setChecked(False)
        self.use_lcm_lora.stateChanged.connect(self.use_lcm_lora_changed)

        self.lora_base_model_id_label = QLabel("Lora base model ID :")
        self.base_model_id = QComboBox(self)
        self.base_model_id.addItems(self.config.stable_diffsuion_models)
        self.base_model_id.currentIndexChanged.connect(self.on_base_model_id_changed)

        self.lcm_lora_model_id_label = QLabel("LCM LoRA model ID :")
        self.lcm_lora_id = QComboBox(self)
        self.lcm_lora_id.addItems(self.config.lcm_lora_models)
        self.lcm_lora_id.currentIndexChanged.connect(self.on_lcm_lora_id_changed)

        self.inference_steps_value = QLabel("Number of inference steps: 4")
        self.inference_steps = QSlider(orientation=Qt.Orientation.Horizontal)
        self.inference_steps.setMaximum(25)
        self.inference_steps.setMinimum(1)
        self.inference_steps.setValue(4)
        self.inference_steps.valueChanged.connect(self.update_steps_label)

        self.num_images_value = QLabel("Number of images: 1")
        self.num_images = QSlider(orientation=Qt.Orientation.Horizontal)
        self.num_images.setMaximum(100)
        self.num_images.setMinimum(1)
        self.num_images.setValue(1)
        self.num_images.valueChanged.connect(self.update_num_images_label)

        self.guidance_value = QLabel("Guidance scale: 1")
        self.guidance = QSlider(orientation=Qt.Orientation.Horizontal)
        self.guidance.setMaximum(20)
        self.guidance.setMinimum(10)
        self.guidance.setValue(10)
        self.guidance.valueChanged.connect(self.update_guidance_label)

        self.clip_skip_value = QLabel("CLIP Skip: 1")
        self.clip_skip = QSlider(orientation=Qt.Orientation.Horizontal)
        self.clip_skip.setMaximum(12)
        self.clip_skip.setMinimum(1)
        self.clip_skip.setValue(1)
        self.clip_skip.valueChanged.connect(self.update_clip_skip_label)

        self.token_merging_value = QLabel("Token Merging: 0")
        self.token_merging = QSlider(orientation=Qt.Orientation.Horizontal)
        self.token_merging.setMaximum(100)
        self.token_merging.setMinimum(0)
        self.token_merging.setValue(0)
        self.token_merging.valueChanged.connect(self.update_token_merging_label)

        self.width_value = QLabel("Width :")
        self.width = QComboBox(self)
        self.width.addItem("256")
        self.width.addItem("512")
        self.width.addItem("768")
        self.width.addItem("1024")
        self.width.setCurrentText("512")
        self.width.currentIndexChanged.connect(self.on_width_changed)

        self.height_value = QLabel("Height :")
        self.height = QComboBox(self)
        self.height.addItem("256")
        self.height.addItem("512")
        self.height.addItem("768")
        self.height.addItem("1024")
        self.height.setCurrentText("512")
        self.height.currentIndexChanged.connect(self.on_height_changed)

        self.seed_check = QCheckBox("Use seed")
        self.seed_value = QLineEdit()
        self.seed_value.setInputMask("9999999999")
        self.seed_value.setText("123123")
        self.seed_check.stateChanged.connect(self.seed_changed)

        self.safety_checker = QCheckBox("Use safety checker")
        self.safety_checker.setChecked(True)
        self.safety_checker.stateChanged.connect(self.use_safety_checker_changed)

        self.use_openvino_check = QCheckBox("Use OpenVINO")
        self.use_openvino_check.setChecked(False)
        self.openvino_model_label = QLabel("OpenVINO LCM model:")
        self.use_local_model_folder = QCheckBox(
            "Use locally cached model or downloaded model folder(offline)"
        )
        self.openvino_lcm_model_id = QComboBox(self)
        self.openvino_lcm_model_id.addItems(self.config.openvino_lcm_models)
        self.openvino_lcm_model_id.currentIndexChanged.connect(
            self.on_openvino_lcm_model_id_changed
        )

        self.use_openvino_check.setEnabled(enable_openvino_controls())
        self.use_local_model_folder.setChecked(False)
        self.use_local_model_folder.stateChanged.connect(self.use_offline_model_changed)
        self.use_openvino_check.stateChanged.connect(self.use_openvino_changed)

        self.use_tae_sd = QCheckBox(
            "Use Tiny Auto Encoder - TAESD (Fast, moderate quality)"
        )
        self.use_tae_sd.setChecked(False)
        self.use_tae_sd.stateChanged.connect(self.use_tae_sd_changed)

        hlayout = QHBoxLayout()
        hlayout.addWidget(self.seed_check)
        hlayout.addWidget(self.seed_value)
        hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)
        slider_hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)

        self.results_path_label = QLabel("Output path:")
        self.results_path = QLineEdit()
        self.results_path.textChanged.connect(self.on_path_changed)
        self.browse_folder_btn = QToolButton()
        self.browse_folder_btn.setText("...")
        self.browse_folder_btn.clicked.connect(self.on_browse_folder)

        self.reset = QPushButton("Reset All")
        self.reset.clicked.connect(self.reset_all_settings)

        vlayout = QVBoxLayout()
        vspacer = QSpacerItem(20, 20, QSizePolicy.Minimum, QSizePolicy.Expanding)
        vlayout.addItem(hspacer)
        vlayout.setSpacing(3)
        vlayout.addWidget(self.lcm_model_label)
        vlayout.addWidget(self.lcm_model)
        vlayout.addWidget(self.use_local_model_folder)
        vlayout.addWidget(self.use_lcm_lora)
        vlayout.addWidget(self.lora_base_model_id_label)
        vlayout.addWidget(self.base_model_id)
        vlayout.addWidget(self.lcm_lora_model_id_label)
        vlayout.addWidget(self.lcm_lora_id)
        vlayout.addWidget(self.use_openvino_check)
        vlayout.addWidget(self.openvino_model_label)
        vlayout.addWidget(self.openvino_lcm_model_id)
        vlayout.addWidget(self.use_tae_sd)
        vlayout.addItem(slider_hspacer)
        vlayout.addWidget(self.inference_steps_value)
        vlayout.addWidget(self.inference_steps)
        vlayout.addWidget(self.num_images_value)
        vlayout.addWidget(self.num_images)
        vlayout.addWidget(self.width_value)
        vlayout.addWidget(self.width)
        vlayout.addWidget(self.height_value)
        vlayout.addWidget(self.height)
        vlayout.addWidget(self.guidance_value)
        vlayout.addWidget(self.guidance)
        vlayout.addWidget(self.clip_skip_value)
        vlayout.addWidget(self.clip_skip)
        vlayout.addWidget(self.token_merging_value)
        vlayout.addWidget(self.token_merging)
        vlayout.addLayout(hlayout)
        vlayout.addWidget(self.safety_checker)

        vlayout.addWidget(self.results_path_label)
        hlayout_path = QHBoxLayout()
        hlayout_path.addWidget(self.results_path)
        hlayout_path.addWidget(self.browse_folder_btn)
        vlayout.addLayout(hlayout_path)
        self.tab_settings.setLayout(vlayout)
        hlayout_reset = QHBoxLayout()
        hspacer = QSpacerItem(20, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
        hlayout_reset.addItem(hspacer)
        hlayout_reset.addWidget(self.reset)
        vlayout.addLayout(hlayout_reset)
        vlayout.addItem(vspacer)

    def create_about_tab(self):
        self.label = QLabel()
        self.label.setAlignment(Qt.AlignCenter)
        current_year = datetime.now().year
        self.label.setText(
            f"""<h1>FastSD CPU {APP_VERSION}</h1>
               <h3>(c)2023 - {current_year} Rupesh Sreeraman</h3>
               <h3>Faster stable diffusion on CPU</h3>
               <h3>Based on Latent Consistency Models</h3>
               <h3>GitHub : https://github.com/rupeshs/fastsdcpu/</h3>"""
        )

        vlayout = QVBoxLayout()
        vlayout.addWidget(self.label)
        self.tab_about.setLayout(vlayout)

    def show_image(self, pixmap):
        image_width = self.config.settings.lcm_diffusion_setting.image_width
        image_height = self.config.settings.lcm_diffusion_setting.image_height
        if image_width > 512 or image_height > 512:
            new_width = 512 if image_width > 512 else image_width
            new_height = 512 if image_height > 512 else image_height
            self.img.setPixmap(
                pixmap.scaled(
                    new_width,
                    new_height,
                    Qt.KeepAspectRatio,
                )
            )
        else:
            self.img.setPixmap(pixmap)

    def on_show_next_image(self):
        if self.image_index != len(self.gen_images) - 1 and len(self.gen_images) > 0:
            self.previous_img_btn.setEnabled(True)
            self.image_index += 1
            self.show_image(self.gen_images[self.image_index])
            if self.image_index == len(self.gen_images) - 1:
                self.next_img_btn.setEnabled(False)

    def on_open_results_folder(self):
        QDesktopServices.openUrl(
            QUrl.fromLocalFile(self.config.settings.generated_images.path)
        )

    def on_show_previous_image(self):
        if self.image_index != 0:
            self.next_img_btn.setEnabled(True)
            self.image_index -= 1
            self.show_image(self.gen_images[self.image_index])
            if self.image_index == 0:
                self.previous_img_btn.setEnabled(False)

    def on_path_changed(self, text):
        self.config.settings.generated_images.path = text

    def on_browse_folder(self):
        options = QFileDialog.Options()
        options |= QFileDialog.ShowDirsOnly

        folder_path = QFileDialog.getExistingDirectory(
            self, "Select a Folder", "", options=options
        )

        if folder_path:
            self.config.settings.generated_images.path = folder_path
            self.results_path.setText(folder_path)

    def on_width_changed(self, index):
        width_txt = self.width.itemText(index)
        self.config.settings.lcm_diffusion_setting.image_width = int(width_txt)

    def on_height_changed(self, index):
        height_txt = self.height.itemText(index)
        self.config.settings.lcm_diffusion_setting.image_height = int(height_txt)

    def on_lcm_model_changed(self, index):
        model_id = self.lcm_model.itemText(index)
        self.config.settings.lcm_diffusion_setting.lcm_model_id = model_id

    def on_base_model_id_changed(self, index):
        model_id = self.base_model_id.itemText(index)
        self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = model_id

    def on_lcm_lora_id_changed(self, index):
        model_id = self.lcm_lora_id.itemText(index)
        self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = model_id

    def on_openvino_lcm_model_id_changed(self, index):
        model_id = self.openvino_lcm_model_id.itemText(index)
        self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id

    def use_openvino_changed(self, state):
        if state == 2:
            self.lcm_model.setEnabled(False)
            self.use_lcm_lora.setEnabled(False)
            self.lcm_lora_id.setEnabled(False)
            self.base_model_id.setEnabled(False)
            self.openvino_lcm_model_id.setEnabled(True)
            self.config.settings.lcm_diffusion_setting.use_openvino = True
        else:
            self.lcm_model.setEnabled(True)
            self.use_lcm_lora.setEnabled(True)
            self.lcm_lora_id.setEnabled(True)
            self.base_model_id.setEnabled(True)
            self.openvino_lcm_model_id.setEnabled(False)
            self.config.settings.lcm_diffusion_setting.use_openvino = False
        self.settings_changed.emit()

    def use_tae_sd_changed(self, state):
        if state == 2:
            self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = True
        else:
            self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = False

    def use_offline_model_changed(self, state):
        if state == 2:
            self.config.settings.lcm_diffusion_setting.use_offline_model = True
        else:
            self.config.settings.lcm_diffusion_setting.use_offline_model = False

    def use_lcm_lora_changed(self, state):
        if state == 2:
            self.lcm_model.setEnabled(False)
            self.lcm_lora_id.setEnabled(True)
            self.base_model_id.setEnabled(True)
            self.config.settings.lcm_diffusion_setting.use_lcm_lora = True
        else:
            self.lcm_model.setEnabled(True)
            self.lcm_lora_id.setEnabled(False)
            self.base_model_id.setEnabled(False)
            self.config.settings.lcm_diffusion_setting.use_lcm_lora = False
        self.settings_changed.emit()

    def update_clip_skip_label(self, value):
        self.clip_skip_value.setText(f"CLIP Skip: {value}")
        self.config.settings.lcm_diffusion_setting.clip_skip = value

    def update_token_merging_label(self, value):
        val = round(int(value) / 100, 1)
        self.token_merging_value.setText(f"Token Merging: {val}")
        self.config.settings.lcm_diffusion_setting.token_merging = val

    def use_safety_checker_changed(self, state):
        if state == 2:
            self.config.settings.lcm_diffusion_setting.use_safety_checker = True
        else:
            self.config.settings.lcm_diffusion_setting.use_safety_checker = False

    def update_steps_label(self, value):
        self.inference_steps_value.setText(f"Number of inference steps: {value}")
        self.config.settings.lcm_diffusion_setting.inference_steps = value

    def update_num_images_label(self, value):
        self.num_images_value.setText(f"Number of images: {value}")
        self.config.settings.lcm_diffusion_setting.number_of_images = value

    def update_guidance_label(self, value):
        val = round(int(value) / 10, 1)
        self.guidance_value.setText(f"Guidance scale: {val}")
        self.config.settings.lcm_diffusion_setting.guidance_scale = val

    def seed_changed(self, state):
        if state == 2:
            self.seed_value.setEnabled(True)
            self.config.settings.lcm_diffusion_setting.use_seed = True
        else:
            self.seed_value.setEnabled(False)
            self.config.settings.lcm_diffusion_setting.use_seed = False

    def get_seed_value(self) -> int:
        use_seed = self.config.settings.lcm_diffusion_setting.use_seed
        seed_value = int(self.seed_value.text()) if use_seed else -1
        return seed_value

    # def text_to_image(self):
    #     self.img.setText("Please wait...")
    #     worker = ImageGeneratorWorker(self.generate_image)
    #     self.threadpool.start(worker)

    def closeEvent(self, event):
        self.config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
        print(self.config.settings.lcm_diffusion_setting)
        print("Saving settings")
        self.config.save()

    def reset_all_settings(self):
        self.use_local_model_folder.setChecked(False)
        self.width.setCurrentText("512")
        self.height.setCurrentText("512")
        self.inference_steps.setValue(4)
        self.guidance.setValue(10)
        self.clip_skip.setValue(1)
        self.token_merging.setValue(0)
        self.use_openvino_check.setChecked(False)
        self.seed_check.setChecked(False)
        self.safety_checker.setChecked(False)
        self.results_path.setText(FastStableDiffusionPaths().get_results_path())
        self.use_tae_sd.setChecked(False)
        self.use_lcm_lora.setChecked(False)

    def prepare_generation_settings(self, config):
        """Populate config settings with the values set by the user in the GUI"""
        config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
        config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = (
            self.lcm_lora_id.currentText()
        )
        config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = (
            self.base_model_id.currentText()
        )

        if config.settings.lcm_diffusion_setting.use_openvino:
            model_id = self.openvino_lcm_model_id.currentText()
            config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
        else:
            model_id = self.lcm_model.currentText()
            config.settings.lcm_diffusion_setting.lcm_model_id = model_id

        config.reshape_required = False
        config.model_id = model_id
        if config.settings.lcm_diffusion_setting.use_openvino:
            # Detect dimension change
            config.reshape_required = is_reshape_required(
                self.previous_width,
                config.settings.lcm_diffusion_setting.image_width,
                self.previous_height,
                config.settings.lcm_diffusion_setting.image_height,
                self.previous_model,
                model_id,
                self.previous_num_of_images,
                config.settings.lcm_diffusion_setting.number_of_images,
            )
        config.settings.lcm_diffusion_setting.diffusion_task = (
            DiffusionTask.text_to_image.value
        )

    def store_dimension_settings(self):
        """These values are only needed for OpenVINO model reshape"""
        self.previous_width = self.config.settings.lcm_diffusion_setting.image_width
        self.previous_height = self.config.settings.lcm_diffusion_setting.image_height
        self.previous_model = self.config.model_id
        self.previous_num_of_images = (
            self.config.settings.lcm_diffusion_setting.number_of_images
        )
src/frontend/gui/base_widget.py
DELETED
@@ -1,199 +0,0 @@
from PIL.ImageQt import ImageQt
from PyQt5 import QtCore
from PyQt5.QtCore import QSize, Qt, QUrl
from PyQt5.QtGui import (
    QDesktopServices,
    QPixmap,
)
from PyQt5.QtWidgets import (
    QApplication,
    QHBoxLayout,
    QLabel,
    QPushButton,
    QSizePolicy,
    QTextEdit,
    QToolButton,
    QVBoxLayout,
    QWidget,
)

from app_settings import AppSettings
from constants import DEVICE
from frontend.gui.image_generator_worker import ImageGeneratorWorker


class ImageLabel(QLabel):
    """Defines a simple QLabel widget"""

    changed = QtCore.pyqtSignal()

    def __init__(self, text: str):
        super().__init__(text)
        self.setAlignment(Qt.AlignCenter)
        self.resize(512, 512)
        self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
        self.sizeHint = QSize(512, 512)
        self.setAcceptDrops(False)

    def show_image(self, pixmap: QPixmap = None):
        """Updates the widget pixmap"""
        if pixmap == None or pixmap.isNull():
            return
        self.current_pixmap = pixmap
        self.changed.emit()

        # Resize the pixmap to the widget dimensions
        image_width = self.current_pixmap.width()
        image_height = self.current_pixmap.height()
        if image_width > 512 or image_height > 512:
            new_width = 512 if image_width > 512 else image_width
            new_height = 512 if image_height > 512 else image_height
            self.setPixmap(
                self.current_pixmap.scaled(
                    new_width,
                    new_height,
                    Qt.KeepAspectRatio,
                )
            )
        else:
            self.setPixmap(self.current_pixmap)


class BaseWidget(QWidget):
    def __init__(self, config: AppSettings, parent):
        super().__init__()
        self.config = config
        self.gen_images = []
        self.image_index = 0
        self.config = config
        self.parent = parent

        # Initialize GUI widgets
        self.prev_btn = QToolButton()
        self.prev_btn.setText("<")
        self.prev_btn.clicked.connect(self.on_show_previous_image)
        self.img = ImageLabel("<<Image>>")
        self.next_btn = QToolButton()
        self.next_btn.setText(">")
        self.next_btn.clicked.connect(self.on_show_next_image)
        self.prompt = QTextEdit()
        self.prompt.setPlaceholderText("A fantasy landscape")
        self.prompt.setAcceptRichText(False)
        self.prompt.setFixedHeight(40)
        self.neg_prompt = QTextEdit()
        self.neg_prompt.setPlaceholderText("")
        self.neg_prompt.setAcceptRichText(False)
        self.neg_prompt_label = QLabel("Negative prompt (Set guidance scale > 1.0):")
        self.neg_prompt.setFixedHeight(35)
        self.neg_prompt.setEnabled(False)
        self.generate = QPushButton("Generate")
        self.generate.clicked.connect(self.generate_click)
        self.browse_results = QPushButton("...")
        self.browse_results.setFixedWidth(30)
        self.browse_results.clicked.connect(self.on_open_results_folder)
        self.browse_results.setToolTip("Open output folder")

        # Create the image navigation layout
        ilayout = QHBoxLayout()
        ilayout.addWidget(self.prev_btn)
        ilayout.addWidget(self.img)
        ilayout.addWidget(self.next_btn)

        # Create the generate button layout
        hlayout = QHBoxLayout()
        hlayout.addWidget(self.neg_prompt)
        hlayout.addWidget(self.generate)
        hlayout.addWidget(self.browse_results)

        # Create the actual widget layout
        vlayout = QVBoxLayout()
        vlayout.addLayout(ilayout)
        # vlayout.addItem(self.vspacer)
        vlayout.addWidget(self.prompt)
        vlayout.addWidget(self.neg_prompt_label)
        vlayout.addLayout(hlayout)
        self.setLayout(vlayout)

        self.parent.settings_changed.connect(self.on_settings_changed)

    def generate_image(self):
        self.parent.prepare_generation_settings(self.config)
        self.config.settings.lcm_diffusion_setting.prompt = self.prompt.toPlainText()
        self.config.settings.lcm_diffusion_setting.negative_prompt = (
            self.neg_prompt.toPlainText()
        )
        images = self.parent.context.generate_text_to_image(
            self.config.settings,
            self.config.reshape_required,
            DEVICE,
        )
        self.parent.context.save_images(
            images,
            self.config.settings,
        )
        self.prepare_images(images)
        self.after_generation()

    def prepare_images(self, images):
        """Prepares the generated images to be displayed in the Qt widget"""
        self.image_index = 0
        self.gen_images = []
        for img in images:
            im = ImageQt(img).copy()
            pixmap = QPixmap.fromImage(im)
            self.gen_images.append(pixmap)

        if len(self.gen_images) > 1:
            self.next_btn.setEnabled(True)
            self.prev_btn.setEnabled(False)
        else:
            self.next_btn.setEnabled(False)
            self.prev_btn.setEnabled(False)

        self.img.show_image(pixmap=self.gen_images[0])

    def on_show_next_image(self):
        if self.image_index != len(self.gen_images) - 1 and len(self.gen_images) > 0:
            self.prev_btn.setEnabled(True)
            self.image_index += 1
            self.img.show_image(pixmap=self.gen_images[self.image_index])
            if self.image_index == len(self.gen_images) - 1:
                self.next_btn.setEnabled(False)

    def on_show_previous_image(self):
        if self.image_index != 0:
            self.next_btn.setEnabled(True)
            self.image_index -= 1
            self.img.show_image(pixmap=self.gen_images[self.image_index])
            if self.image_index == 0:
                self.prev_btn.setEnabled(False)

    def on_open_results_folder(self):
        QDesktopServices.openUrl(
            QUrl.fromLocalFile(self.config.settings.generated_images.path)
        )

    def generate_click(self):
        self.img.setText("Please wait...")
        self.before_generation()
        worker = ImageGeneratorWorker(self.generate_image)
        self.parent.threadpool.start(worker)

    def before_generation(self):
        """Call this function before running a generation task"""
        self.img.setEnabled(False)
        self.generate.setEnabled(False)
        self.browse_results.setEnabled(False)

    def after_generation(self):
        """Call this function after running a generation task"""
        self.img.setEnabled(True)
        self.generate.setEnabled(True)
        self.browse_results.setEnabled(True)
        self.parent.store_dimension_settings()

    def on_settings_changed(self):
        self.neg_prompt.setEnabled(
            self.config.settings.lcm_diffusion_setting.use_openvino
            or self.config.settings.lcm_diffusion_setting.use_lcm_lora
        )
src/frontend/gui/image_generator_worker.py
DELETED
@@ -1,37 +0,0 @@
from PyQt5.QtCore import (
    pyqtSlot,
    QRunnable,
    pyqtSignal,
    pyqtSlot,
)
from PyQt5.QtCore import QObject
import traceback
import sys


class WorkerSignals(QObject):
    finished = pyqtSignal()
    error = pyqtSignal(tuple)
    result = pyqtSignal(object)


class ImageGeneratorWorker(QRunnable):
    def __init__(self, fn, *args, **kwargs):
        super(ImageGeneratorWorker, self).__init__()
        self.fn = fn
        self.args = args
        self.kwargs = kwargs
        self.signals = WorkerSignals()

    @pyqtSlot()
    def run(self):
        try:
            result = self.fn(*self.args, **self.kwargs)
        except:
            traceback.print_exc()
            exctype, value = sys.exc_info()[:2]
            self.signals.error.emit((exctype, value, traceback.format_exc()))
        else:
            self.signals.result.emit(result)
        finally:
            self.signals.finished.emit()
src/frontend/gui/image_variations_widget.py
DELETED
@@ -1,35 +0,0 @@
from PIL import Image
from PyQt5.QtWidgets import QApplication

from app_settings import AppSettings
from backend.models.lcmdiffusion_setting import DiffusionTask
from frontend.gui.img2img_widget import Img2ImgWidget
from frontend.webui.image_variations_ui import generate_image_variations


class ImageVariationsWidget(Img2ImgWidget):
    def __init__(self, config: AppSettings, parent):
        super().__init__(config, parent)
        # Hide prompt and negative prompt widgets
        self.prompt.hide()
        self.neg_prompt_label.hide()
        self.neg_prompt.setEnabled(False)

    def generate_image(self):
        self.parent.prepare_generation_settings(self.config)
        self.config.settings.lcm_diffusion_setting.diffusion_task = (
            DiffusionTask.image_to_image.value
        )
        self.config.settings.lcm_diffusion_setting.prompt = ""
        self.config.settings.lcm_diffusion_setting.negative_prompt = ""
        self.config.settings.lcm_diffusion_setting.init_image = Image.open(
            self.img_path.text()
        )
        self.config.settings.lcm_diffusion_setting.strength = self.strength.value() / 10

        images = generate_image_variations(
            self.config.settings.lcm_diffusion_setting.init_image,
            self.config.settings.lcm_diffusion_setting.strength,
        )
        self.prepare_images(images)
        self.after_generation()