HaileyStorm
commited on
Commit
β’
d4b964f
1
Parent(s):
c0347e0
Upload 2 files
Browse files- infer.py +73 -0
- merge_compare.py +263 -0
infer.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import FluxPipeline, FluxTransformer2DModel
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
|
5 |
+
# Configuration
|
6 |
+
MODEL_DIR = "./merged_models/2.5_1"
|
7 |
+
IMAGE_OUTPUT_DIR = "./"
|
8 |
+
IMAGE_PREFIX = "flowers_2.5_1"
|
9 |
+
DEVICE = torch.device("cpu")
|
10 |
+
# If True, uses pipeline.enable_sequential_cpu_offload(). Make sure device is CPU.
|
11 |
+
USE_CPU_OFFLOAD = True
|
12 |
+
SEED = 0
|
13 |
+
# At least 880x656 fits on 24GB GPU w/ sequential offload
|
14 |
+
IMAGE_WIDTH = 1280
|
15 |
+
IMAGE_HEIGHT = 1024
|
16 |
+
NUM_STEPS = 10 # Try ~4-8 for 10:1 and ~8-16+ for 4:1 and 2.5:1 ("Default" 6, 10, 16)
|
17 |
+
NUM_IMAGES = 4
|
18 |
+
CFG = 3.5
|
19 |
+
PROMPT = ("Impressionistic tableau medium shot painting with soft, blended brushstrokes and muted colors complemented "
|
20 |
+
"by sporadic vibrant highlights.")
|
21 |
+
PROMPT2 = ("Impressionistic tableau painting with soft brushstrokes and muted colors, accented by vibrant highlights, "
|
22 |
+
"of a tranquil courtyard surrounded by wildflowers. Madison, a 19-year-old woman with light dirty blond "
|
23 |
+
"hair and bubblegum-pink highlights in a ponytail, brown eyes, and soft facial features, stands beside "
|
24 |
+
"Amelia, a tall mid-20s woman with deep auburn hair in a messy bun, summer sky-blue eyes, and pronounced "
|
25 |
+
"cheekbones. Together, they exude harmony and intrigue, their contrasting features complementing each "
|
26 |
+
"other.")
|
27 |
+
|
28 |
+
print("Loading model...")
|
29 |
+
transformer = FluxTransformer2DModel.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16, use_safetensors=True)
|
30 |
+
print("Creating pipeline...")
|
31 |
+
pipeline = FluxPipeline.from_pretrained(
|
32 |
+
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
|
33 |
+
, use_safetensors=True, local_dir="./models/dev/", local_dir_use_symlinks=False,
|
34 |
+
ignore_patterns=["flux1-dev.sft", "flux1-dev.safetensors"]).to(DEVICE)
|
35 |
+
pipeline.enable_sequential_cpu_offload()
|
36 |
+
print("Generating image...")
|
37 |
+
# Params:
|
38 |
+
# prompt β The prompt or prompts to guide the image generation. If not defined, one has to pass prompt_embeds. instead.
|
39 |
+
# prompt_2 β The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. If not defined, prompt is will be used instead
|
40 |
+
# height β The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
41 |
+
# width β The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
42 |
+
# num_inference_steps β The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
|
43 |
+
# timesteps β Custom timesteps to use for the denoising process with schedulers which support a timesteps argument in their set_timesteps method. If not defined, the default behavior when num_inference_steps is passed will be used. Must be in descending order.
|
44 |
+
# guidance_scale β Guidance scale as defined in [Classifier-Free Diffusion Guidance](https:// arxiv. org/ abs/ 2207.12598 ). guidance_scale is defined as w of equation 2. of [Imagen Paper](https:// arxiv. org/ pdf/ 2205.11487.pdf ). Guidance scale is enabled by setting guidance_scale > 1. Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.
|
45 |
+
# num_images_per_prompt β The number of images to generate per prompt.
|
46 |
+
# generator β One or a list of [torch generator(s)](https:// pytorch. org/ docs/ stable/ generated/ torch. Generator. html ) to make generation deterministic.
|
47 |
+
# latents β Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random generator.
|
48 |
+
# prompt_embeds β Pre-generated text embeddings. Can be used to easily tweak text inputs, e. g. prompt weighting. If not provided, text embeddings will be generated from prompt input argument.
|
49 |
+
# pooled_prompt_embeds β Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, e. g. prompt weighting. If not provided, pooled text embeddings will be generated from prompt input argument.
|
50 |
+
# output_type β The output format of the generate image. Choose between [PIL](https:// pillow. readthedocs. io/ en/ stable/ ): PIL. Image. Image or np. array.
|
51 |
+
# return_dict β Whether or not to return a [~pipelines. flux. FluxPipelineOutput] instead of a plain tuple.
|
52 |
+
# joint_attention_kwargs β A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self. processor in [diffusers. models. attention_processor](https:// github. com/ huggingface/ diffusers/ blob/ main/ src/ diffusers/ models/ attention_processor. py ).
|
53 |
+
# callback_on_step_end β A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include a list of all tensors as specified by callback_on_step_end_tensor_inputs.
|
54 |
+
# callback_on_step_end_tensor_inputs β The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as callback_kwargs argument. You will only be able to include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
|
55 |
+
# max_sequence_length β Maximum sequence length to use with the prompt.
|
56 |
+
# Returns:
|
57 |
+
# [~pipelines. flux. FluxPipelineOutput] if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.
|
58 |
+
images = pipeline(
|
59 |
+
prompt=PROMPT,
|
60 |
+
prompt_2=PROMPT2,
|
61 |
+
guidance_scale=CFG,
|
62 |
+
num_inference_steps=NUM_STEPS,
|
63 |
+
height=IMAGE_HEIGHT,
|
64 |
+
width=IMAGE_WIDTH,
|
65 |
+
max_sequence_length=512,
|
66 |
+
generator=torch.manual_seed(42),
|
67 |
+
num_images_per_prompt=NUM_IMAGES,
|
68 |
+
).images
|
69 |
+
for i, image in enumerate(images):
|
70 |
+
print("Saving image...")
|
71 |
+
path = os.path.join(IMAGE_OUTPUT_DIR, f"{IMAGE_PREFIX}_{i}.png")
|
72 |
+
image.save(path)
|
73 |
+
print("Done.")
|
merge_compare.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import glob
|
4 |
+
from multiprocessing import Pool
|
5 |
+
import time
|
6 |
+
from tqdm import tqdm
|
7 |
+
import torch
|
8 |
+
from safetensors.torch import load_file
|
9 |
+
from diffusers import FluxTransformer2DModel, FluxPipeline
|
10 |
+
from huggingface_hub import snapshot_download
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
# Configuration
|
14 |
+
DEVICE = torch.device("cpu")
|
15 |
+
# If True, uses pipeline.enable_sequential_cpu_offload(). Make sure device is CPU.
|
16 |
+
USE_CPU_OFFLOAD = True
|
17 |
+
DTYPE = torch.bfloat16
|
18 |
+
NUM_WORKERS = 1
|
19 |
+
SEED = 0
|
20 |
+
IMAGE_WIDTH = 880 # 688
|
21 |
+
IMAGE_HEIGHT = 656 # 512
|
22 |
+
|
23 |
+
PROMPTS = [
|
24 |
+
"a tiny astronaut hatching from an egg on the moon",
|
25 |
+
#"photo of a female cyberpunk hacker, plugged in and hacking, far future, neon lights"
|
26 |
+
'photo of a man on a beach holding a sign that says "Premature optimization is the root of all evil - test your shit!"'
|
27 |
+
]
|
28 |
+
STEP_COUNTS = [4, 8, 16, 32, 50]
|
29 |
+
MERGE_RATIOS = [
|
30 |
+
# (1, 0), (4, 1), (3, 1), (2, 1), (1, 1), (1, 2), (1, 3), (1, 4), (0, 1)
|
31 |
+
(1, 0), (12, 1), (10, 1), (7, 1), (5.5, 1), (4, 1), (3.5, 1), (3, 1), (2.5, 1), (2, 1), (1.5, 1), (0, 1)
|
32 |
+
]
|
33 |
+
MERGE_LABELS = [
|
34 |
+
# "Pure Schnell", "4:1", "3:1", "2:1", "1:1 Merge", "1:2", "1:3", "1:4", "Pure Dev"
|
35 |
+
"Pure Schnell", "12:1", "10:1", "7:1", "5.5:1", "4:1", "3.5:1", "3:1", "2.5:1", "2:1", "1.5:1", "Pure Dev"
|
36 |
+
]
|
37 |
+
assert len(MERGE_RATIOS) == len(MERGE_LABELS)
|
38 |
+
|
39 |
+
# Output directories
|
40 |
+
IMAGE_OUTPUT_DIR = "./outputs"
|
41 |
+
MODEL_OUTPUT_DIR = "./merged_models"
|
42 |
+
SAVE_MODELS = False
|
43 |
+
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
|
44 |
+
|
45 |
+
|
46 |
+
# Utility function for cleanup
|
47 |
+
def cleanup():
|
48 |
+
gc.collect()
|
49 |
+
torch.cuda.empty_cache()
|
50 |
+
|
51 |
+
|
52 |
+
# Start timing
|
53 |
+
start_time = time.time()
|
54 |
+
|
55 |
+
|
56 |
+
def merge_models(dev_shards, schnell_shards, ratio):
|
57 |
+
schnell_weight, dev_weight = ratio
|
58 |
+
total_weight = schnell_weight + dev_weight
|
59 |
+
|
60 |
+
merged_state_dict = {}
|
61 |
+
guidance_state_dict = {}
|
62 |
+
|
63 |
+
for i in tqdm(range(len(dev_shards)), "Processing shards...", dynamic_ncols=True):
|
64 |
+
state_dict_dev = load_file(dev_shards[i])
|
65 |
+
state_dict_schnell = load_file(schnell_shards[i])
|
66 |
+
|
67 |
+
keys = list(state_dict_dev.keys())
|
68 |
+
for k in tqdm(keys, f"\tProcessing keys of shard {i}...", dynamic_ncols=True):
|
69 |
+
if "guidance" not in k:
|
70 |
+
merged_state_dict[k] = (
|
71 |
+
state_dict_schnell[k] * schnell_weight +
|
72 |
+
state_dict_dev[k] * dev_weight
|
73 |
+
) / total_weight
|
74 |
+
else:
|
75 |
+
guidance_state_dict[k] = state_dict_dev[k]
|
76 |
+
|
77 |
+
merged_state_dict.update(guidance_state_dict)
|
78 |
+
return merged_state_dict
|
79 |
+
|
80 |
+
|
81 |
+
# Function to create merged model
|
82 |
+
def create_merged_model(dev_ckpt, schnell_ckpt, ratio):
|
83 |
+
config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
|
84 |
+
model = FluxTransformer2DModel.from_config(config)
|
85 |
+
|
86 |
+
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
|
87 |
+
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
|
88 |
+
|
89 |
+
merged_state_dict = merge_models(dev_shards, schnell_shards, ratio)
|
90 |
+
model.load_state_dict(merged_state_dict)
|
91 |
+
del merged_state_dict
|
92 |
+
cleanup()
|
93 |
+
|
94 |
+
return model.to(DTYPE)
|
95 |
+
|
96 |
+
|
97 |
+
def generate_image(pipeline, prompt, num_steps, output_path):
|
98 |
+
if not os.path.exists(output_path):
|
99 |
+
# Params:
|
100 |
+
# prompt β The prompt or prompts to guide the image generation. If not defined, one has to pass prompt_embeds. instead.
|
101 |
+
# prompt_2 β The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. If not defined, prompt is will be used instead
|
102 |
+
# height β The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
103 |
+
# width β The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
104 |
+
# num_inference_steps β The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
|
105 |
+
# timesteps β Custom timesteps to use for the denoising process with schedulers which support a timesteps argument in their set_timesteps method. If not defined, the default behavior when num_inference_steps is passed will be used. Must be in descending order.
|
106 |
+
# guidance_scale β Guidance scale as defined in [Classifier-Free Diffusion Guidance](https:// arxiv. org/ abs/ 2207.12598 ). guidance_scale is defined as w of equation 2. of [Imagen Paper](https:// arxiv. org/ pdf/ 2205.11487.pdf ). Guidance scale is enabled by setting guidance_scale > 1. Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.
|
107 |
+
# num_images_per_prompt β The number of images to generate per prompt.
|
108 |
+
# generator β One or a list of [torch generator(s)](https:// pytorch. org/ docs/ stable/ generated/ torch. Generator. html ) to make generation deterministic.
|
109 |
+
# latents β Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random generator.
|
110 |
+
# prompt_embeds β Pre-generated text embeddings. Can be used to easily tweak text inputs, e. g. prompt weighting. If not provided, text embeddings will be generated from prompt input argument.
|
111 |
+
# pooled_prompt_embeds β Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, e. g. prompt weighting. If not provided, pooled text embeddings will be generated from prompt input argument.
|
112 |
+
# output_type β The output format of the generate image. Choose between [PIL](https:// pillow. readthedocs. io/ en/ stable/ ): PIL. Image. Image or np. array.
|
113 |
+
# return_dict β Whether or not to return a [~pipelines. flux. FluxPipelineOutput] instead of a plain tuple.
|
114 |
+
# joint_attention_kwargs β A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self. processor in [diffusers. models. attention_processor](https:// github. com/ huggingface/ diffusers/ blob/ main/ src/ diffusers/ models/ attention_processor. py ).
|
115 |
+
# callback_on_step_end β A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include a list of all tensors as specified by callback_on_step_end_tensor_inputs.
|
116 |
+
# callback_on_step_end_tensor_inputs β The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as callback_kwargs argument. You will only be able to include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
|
117 |
+
# max_sequence_length β Maximum sequence length to use with the prompt.
|
118 |
+
# Returns:
|
119 |
+
# [~pipelines. flux. FluxPipelineOutput] if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.
|
120 |
+
image = pipeline(
|
121 |
+
prompt=prompt,
|
122 |
+
guidance_scale=3.5,
|
123 |
+
num_inference_steps=num_steps,
|
124 |
+
height=IMAGE_HEIGHT,
|
125 |
+
width=IMAGE_WIDTH,
|
126 |
+
max_sequence_length=512,
|
127 |
+
generator=torch.manual_seed(SEED),
|
128 |
+
).images[0]
|
129 |
+
image.save(output_path)
|
130 |
+
else:
|
131 |
+
print("Image already exists, skipping...")
|
132 |
+
|
133 |
+
|
134 |
+
def process_model(ratio, label, dev_ckpt, schnell_ckpt):
|
135 |
+
image_output_dir = os.path.join(IMAGE_OUTPUT_DIR, label.replace(":", "_"))
|
136 |
+
os.makedirs(image_output_dir, exist_ok=True)
|
137 |
+
existing_images = len([name for name in os.listdir(image_output_dir) if os.path.isfile(os.path.join(image_output_dir, name))])
|
138 |
+
if existing_images == len(PROMPTS) * len(STEP_COUNTS):
|
139 |
+
print(f"\nModel {label} already complete, skipping...")
|
140 |
+
return
|
141 |
+
else:
|
142 |
+
print(f"\nProcessing {label} model...")
|
143 |
+
|
144 |
+
if ratio == (1, 0): # Pure Schnell
|
145 |
+
model = FluxTransformer2DModel.from_pretrained(schnell_ckpt, subfolder="transformer", torch_dtype=DTYPE)
|
146 |
+
elif ratio == (0, 1): # Pure Dev
|
147 |
+
model = FluxTransformer2DModel.save_pretrained().from_pretrained(dev_ckpt, subfolder="transformer", torch_dtype=DTYPE)
|
148 |
+
else:
|
149 |
+
model = create_merged_model(dev_ckpt, schnell_ckpt, ratio)
|
150 |
+
|
151 |
+
if SAVE_MODELS:
|
152 |
+
model_output_dir = os.path.join(MODEL_OUTPUT_DIR, label.replace(":", "_"))
|
153 |
+
print(f"Saving model to {model_output_dir}...")
|
154 |
+
model.save_pretrained(model_output_dir, max_shared_size="50GB", safe_serialization=True)
|
155 |
+
|
156 |
+
pipeline = FluxPipeline.from_pretrained(
|
157 |
+
dev_ckpt,
|
158 |
+
transformer=model,
|
159 |
+
torch_dtype=DTYPE,
|
160 |
+
).to(DEVICE)
|
161 |
+
if USE_CPU_OFFLOAD:
|
162 |
+
pipeline.enable_sequential_cpu_offload()
|
163 |
+
#pipeline.enable_xformers_memory_efficient_attention()
|
164 |
+
|
165 |
+
for prompt_idx, prompt in enumerate(PROMPTS):
|
166 |
+
for step_count in STEP_COUNTS:
|
167 |
+
output_path = os.path.join(
|
168 |
+
image_output_dir,
|
169 |
+
f"prompt{prompt_idx + 1}_steps{step_count}.png"
|
170 |
+
)
|
171 |
+
generate_image(pipeline, prompt, step_count, output_path)
|
172 |
+
|
173 |
+
del pipeline
|
174 |
+
del model
|
175 |
+
cleanup()
|
176 |
+
|
177 |
+
|
178 |
+
def main():
|
179 |
+
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", ignore_patterns=["flux1-dev.sft","flux1-dev.safetensors"],
|
180 |
+
local_dir="./models/dev/")
|
181 |
+
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*",
|
182 |
+
local_dir="./models/schnell/")
|
183 |
+
|
184 |
+
with Pool(NUM_WORKERS) as pool:
|
185 |
+
results = [
|
186 |
+
pool.apply_async(
|
187 |
+
process_model,
|
188 |
+
(ratio, label, dev_ckpt, schnell_ckpt)
|
189 |
+
)
|
190 |
+
for ratio, label in zip(MERGE_RATIOS, MERGE_LABELS)
|
191 |
+
]
|
192 |
+
|
193 |
+
for result in tqdm(results):
|
194 |
+
result.get() # This will block until the result is ready
|
195 |
+
|
196 |
+
pool.close()
|
197 |
+
pool.join()
|
198 |
+
|
199 |
+
|
200 |
+
def create_image_grid(image_paths, output_path, padding=10):
|
201 |
+
width = IMAGE_WIDTH // 2
|
202 |
+
height = IMAGE_HEIGHT // 2
|
203 |
+
images = [Image.open(path).resize((width, height)) for path in image_paths]
|
204 |
+
|
205 |
+
grid_cols = len(MERGE_RATIOS)
|
206 |
+
grid_rows = len(STEP_COUNTS)
|
207 |
+
top_pad = 250
|
208 |
+
left_pad = 200
|
209 |
+
grid_width = (width * grid_cols) + (padding * (grid_cols + 1)) + left_pad
|
210 |
+
grid_height = (height * grid_rows) + (padding * (grid_rows + 1)) + top_pad
|
211 |
+
|
212 |
+
grid_image = Image.new('RGB', (grid_width, grid_height), color=(255, 255, 255))
|
213 |
+
|
214 |
+
for idx, img in enumerate(images):
|
215 |
+
row = idx // grid_cols
|
216 |
+
col = idx % grid_cols
|
217 |
+
x_position = (col * width) + (padding * (col + 1)) + left_pad
|
218 |
+
y_position = (row * height) + (padding * (row + 1)) + top_pad
|
219 |
+
grid_image.paste(img, (x_position, y_position))
|
220 |
+
|
221 |
+
grid_image.save(output_path)
|
222 |
+
|
223 |
+
|
224 |
+
# Run the main process
|
225 |
+
main()
|
226 |
+
|
227 |
+
# Create the image grids
|
228 |
+
print("Creating image comparison grid...")
|
229 |
+
# Reconstruct the image paths
|
230 |
+
all_image_paths = [
|
231 |
+
os.path.join(
|
232 |
+
IMAGE_OUTPUT_DIR,
|
233 |
+
label.replace(":", "_"),
|
234 |
+
f"prompt{prompt_idx + 1}_steps{step_count}.png"
|
235 |
+
)
|
236 |
+
for prompt_idx in range(len(PROMPTS))
|
237 |
+
for step_count in STEP_COUNTS
|
238 |
+
for label in MERGE_LABELS
|
239 |
+
]
|
240 |
+
missing_images = [path for path in all_image_paths if not os.path.exists(path)]
|
241 |
+
if missing_images:
|
242 |
+
print(f"Warning: {len(missing_images)} images were not generated:")
|
243 |
+
for path in missing_images[:5]: # Show first 5
|
244 |
+
print(f" β’ {path}")
|
245 |
+
if len(missing_images) > 5:
|
246 |
+
print(f" (and {len(missing_images) - 5} more...)")
|
247 |
+
|
248 |
+
# Create grid images
|
249 |
+
for prompt_idx in range(len(PROMPTS)):
|
250 |
+
prompt_images = [path for path in all_image_paths if f"prompt{prompt_idx + 1}" in path]
|
251 |
+
grid_output_path = os.path.join(IMAGE_OUTPUT_DIR, f"grid_prompt{prompt_idx + 1}.png")
|
252 |
+
create_image_grid(prompt_images, grid_output_path)
|
253 |
+
|
254 |
+
# Final report
|
255 |
+
end_time = time.time()
|
256 |
+
total_time = end_time - start_time
|
257 |
+
num_images = len(all_image_paths)
|
258 |
+
|
259 |
+
print(f"\nProcessing complete!")
|
260 |
+
print(f"Total time: {total_time:.2f} seconds")
|
261 |
+
print(f"Total images generated: {num_images}")
|
262 |
+
print(f"Average time per image: {total_time / num_images:.2f} seconds")
|
263 |
+
print(f"Output directory: {IMAGE_OUTPUT_DIR}")
|