HaileyStorm committed on
Commit d4b964f • 1 Parent(s): c0347e0

Upload 2 files

Files changed (2)
  1. infer.py +73 -0
  2. merge_compare.py +263 -0
infer.py ADDED
@@ -0,0 +1,73 @@
+ from diffusers import FluxPipeline, FluxTransformer2DModel
+ import torch
+ import os
+
+ # Configuration
+ MODEL_DIR = "./merged_models/2.5_1"
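+ # (This is the layout merge_compare.py writes when SAVE_MODELS is enabled:
+ # label "2.5:1" becomes folder "2.5_1".)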
+ IMAGE_OUTPUT_DIR = "./"
+ IMAGE_PREFIX = "flowers_2.5_1"
+ DEVICE = torch.device("cpu")
+ # If True, uses pipeline.enable_sequential_cpu_offload(). Make sure device is CPU.
+ USE_CPU_OFFLOAD = True
+ SEED = 0
+ # At least 880x656 fits on a 24GB GPU w/ sequential offload
+ IMAGE_WIDTH = 1280
+ IMAGE_HEIGHT = 1024
+ NUM_STEPS = 10  # Try ~4-8 for 10:1 and ~8-16+ for 4:1 and 2.5:1 ("Default" 6, 10, 16)
+ NUM_IMAGES = 4
+ CFG = 3.5
+ PROMPT = ("Impressionistic tableau medium shot painting with soft, blended brushstrokes and muted colors complemented "
+           "by sporadic vibrant highlights.")
+ PROMPT2 = ("Impressionistic tableau painting with soft brushstrokes and muted colors, accented by vibrant highlights, "
+            "of a tranquil courtyard surrounded by wildflowers. Madison, a 19-year-old woman with light dirty blond "
+            "hair and bubblegum-pink highlights in a ponytail, brown eyes, and soft facial features, stands beside "
+            "Amelia, a tall mid-20s woman with deep auburn hair in a messy bun, summer sky-blue eyes, and pronounced "
+            "cheekbones. Together, they exude harmony and intrigue, their contrasting features complementing each "
+            "other.")
+
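+ # PROMPT sets the style; PROMPT2 carries the detailed scene. Per the parameter
+ # notes below, prompt_2 is routed to tokenizer_2/text_encoder_2 (FLUX's T5
+ # encoder), while prompt goes to the first (CLIP) encoder.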
+ print("Loading model...")
+ transformer = FluxTransformer2DModel.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16, use_safetensors=True)
+ print("Creating pipeline...")
+ pipeline = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16,
+     use_safetensors=True, local_dir="./models/dev/", local_dir_use_symlinks=False,
+     ignore_patterns=["flux1-dev.sft", "flux1-dev.safetensors"]).to(DEVICE)
+ if USE_CPU_OFFLOAD:
+     pipeline.enable_sequential_cpu_offload()
+ print("Generating image...")
+ # Params:
+ # prompt – The prompt or prompts to guide the image generation. If not defined, prompt_embeds must be passed instead.
+ # prompt_2 – The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. If not defined, prompt will be used instead.
+ # height – The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ # width – The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ # num_inference_steps – The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
+ # timesteps – Custom timesteps to use for the denoising process with schedulers which support a timesteps argument in their set_timesteps method. If not defined, the default behavior when num_inference_steps is passed will be used. Must be in descending order.
+ # guidance_scale – Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). guidance_scale is defined as w of equation 2 of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting guidance_scale > 1. A higher guidance scale encourages the model to generate images closely linked to the text prompt, usually at the expense of lower image quality.
+ # num_images_per_prompt – The number of images to generate per prompt.
+ # generator – One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic.
+ # latents – Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random generator.
+ # prompt_embeds – Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
+ # pooled_prompt_embeds – Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, pooled text embeddings will be generated from the prompt input argument.
+ # output_type – The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): PIL.Image.Image or np.array.
+ # return_dict – Whether or not to return a [~pipelines.flux.FluxPipelineOutput] instead of a plain tuple.
+ # joint_attention_kwargs – A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ # callback_on_step_end – A function called at the end of each denoising step during inference, with the arguments callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include a list of all tensors as specified by callback_on_step_end_tensor_inputs.
+ # callback_on_step_end_tensor_inputs – The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You will only be able to include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
+ # max_sequence_length – Maximum sequence length to use with the prompt.
+ # Returns:
+ # [~pipelines.flux.FluxPipelineOutput] if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.
+ images = pipeline(
+     prompt=PROMPT,
+     prompt_2=PROMPT2,
+     guidance_scale=CFG,
+     num_inference_steps=NUM_STEPS,
+     height=IMAGE_HEIGHT,
+     width=IMAGE_WIDTH,
+     max_sequence_length=512,
+     generator=torch.manual_seed(SEED),
+     num_images_per_prompt=NUM_IMAGES,
+ ).images
+ for i, image in enumerate(images):
+     print("Saving image...")
+     path = os.path.join(IMAGE_OUTPUT_DIR, f"{IMAGE_PREFIX}_{i}.png")
+     image.save(path)
+ print("Done.")
merge_compare.py ADDED
@@ -0,0 +1,263 @@
+ import os
+ import gc
+ import glob
+ from multiprocessing import Pool
+ import time
+ from tqdm import tqdm
+ import torch
+ from safetensors.torch import load_file
+ from diffusers import FluxTransformer2DModel, FluxPipeline
+ from huggingface_hub import snapshot_download
+ from PIL import Image
+
+ # Configuration
+ DEVICE = torch.device("cpu")
+ # If True, uses pipeline.enable_sequential_cpu_offload(). Make sure device is CPU.
+ USE_CPU_OFFLOAD = True
+ DTYPE = torch.bfloat16
+ NUM_WORKERS = 1
+ SEED = 0
+ IMAGE_WIDTH = 880  # 688
+ IMAGE_HEIGHT = 656  # 512
+
+ PROMPTS = [
+     "a tiny astronaut hatching from an egg on the moon",
+     # "photo of a female cyberpunk hacker, plugged in and hacking, far future, neon lights"
+     'photo of a man on a beach holding a sign that says "Premature optimization is the root of all evil - test your shit!"'
+ ]
+ STEP_COUNTS = [4, 8, 16, 32, 50]
+ MERGE_RATIOS = [
+     # (1, 0), (4, 1), (3, 1), (2, 1), (1, 1), (1, 2), (1, 3), (1, 4), (0, 1)
+     (1, 0), (12, 1), (10, 1), (7, 1), (5.5, 1), (4, 1), (3.5, 1), (3, 1), (2.5, 1), (2, 1), (1.5, 1), (0, 1)
+ ]
+ MERGE_LABELS = [
+     # "Pure Schnell", "4:1", "3:1", "2:1", "1:1 Merge", "1:2", "1:3", "1:4", "Pure Dev"
+     "Pure Schnell", "12:1", "10:1", "7:1", "5.5:1", "4:1", "3.5:1", "3:1", "2.5:1", "2:1", "1.5:1", "Pure Dev"
+ ]
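+ # Each ratio is (schnell_parts, dev_parts), matching the labels above: (2.5, 1)
+ # is 2.5 parts Schnell to 1 part Dev; (1, 0) and (0, 1) select the unmerged models.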
+ assert len(MERGE_RATIOS) == len(MERGE_LABELS)
+
+ # Output directories
+ IMAGE_OUTPUT_DIR = "./outputs"
+ MODEL_OUTPUT_DIR = "./merged_models"
+ SAVE_MODELS = False
+ os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
+
+
+ # Utility function for cleanup
+ def cleanup():
+     gc.collect()
+     torch.cuda.empty_cache()
+
+
+ # Start timing
+ start_time = time.time()
+
+
+ def merge_models(dev_shards, schnell_shards, ratio):
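+     # Shard-by-shard weighted average of the two checkpoints. This assumes the
+     # dev and schnell shard lists align one-to-one with matching keys. For
+     # (schnell_weight, dev_weight) = (2.5, 1), each tensor becomes
+     # (2.5 * schnell + 1 * dev) / 3.5; "guidance" tensors are copied from dev
+     # unchanged (presumably because Schnell ships without guidance weights).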
+     schnell_weight, dev_weight = ratio
+     total_weight = schnell_weight + dev_weight
+
+     merged_state_dict = {}
+     guidance_state_dict = {}
+
+     for i in tqdm(range(len(dev_shards)), "Processing shards...", dynamic_ncols=True):
+         state_dict_dev = load_file(dev_shards[i])
+         state_dict_schnell = load_file(schnell_shards[i])
+
+         keys = list(state_dict_dev.keys())
+         for k in tqdm(keys, f"\tProcessing keys of shard {i}...", dynamic_ncols=True):
+             if "guidance" not in k:
+                 merged_state_dict[k] = (
+                     state_dict_schnell[k] * schnell_weight +
+                     state_dict_dev[k] * dev_weight
+                 ) / total_weight
+             else:
+                 guidance_state_dict[k] = state_dict_dev[k]
+
+     merged_state_dict.update(guidance_state_dict)
+     return merged_state_dict
+
+
+ # Function to create merged model
+ def create_merged_model(dev_ckpt, schnell_ckpt, ratio):
+     config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
+     model = FluxTransformer2DModel.from_config(config)
+
+     dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
+     schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
+
+     merged_state_dict = merge_models(dev_shards, schnell_shards, ratio)
+     model.load_state_dict(merged_state_dict)
+     del merged_state_dict
+     cleanup()
+
+     return model.to(DTYPE)
+
+
+ def generate_image(pipeline, prompt, num_steps, output_path):
+     if not os.path.exists(output_path):
+         # Params: see the FluxPipeline.__call__ parameter notes in infer.py.
+         image = pipeline(
+             prompt=prompt,
+             guidance_scale=3.5,
+             num_inference_steps=num_steps,
+             height=IMAGE_HEIGHT,
+             width=IMAGE_WIDTH,
+             max_sequence_length=512,
+             generator=torch.manual_seed(SEED),
+         ).images[0]
+         image.save(output_path)
+     else:
+         print("Image already exists, skipping...")
+
+
+ def process_model(ratio, label, dev_ckpt, schnell_ckpt):
+     image_output_dir = os.path.join(IMAGE_OUTPUT_DIR, label.replace(":", "_"))
+     os.makedirs(image_output_dir, exist_ok=True)
+     existing_images = len([name for name in os.listdir(image_output_dir) if os.path.isfile(os.path.join(image_output_dir, name))])
+     if existing_images == len(PROMPTS) * len(STEP_COUNTS):
+         print(f"\nModel {label} already complete, skipping...")
+         return
+     else:
+         print(f"\nProcessing {label} model...")
+
+     if ratio == (1, 0):  # Pure Schnell
+         model = FluxTransformer2DModel.from_pretrained(schnell_ckpt, subfolder="transformer", torch_dtype=DTYPE)
+     elif ratio == (0, 1):  # Pure Dev
+         model = FluxTransformer2DModel.from_pretrained(dev_ckpt, subfolder="transformer", torch_dtype=DTYPE)
+     else:
+         model = create_merged_model(dev_ckpt, schnell_ckpt, ratio)
+
+     if SAVE_MODELS:
+         model_output_dir = os.path.join(MODEL_OUTPUT_DIR, label.replace(":", "_"))
+         print(f"Saving model to {model_output_dir}...")
+         model.save_pretrained(model_output_dir, max_shard_size="50GB", safe_serialization=True)
+
+     pipeline = FluxPipeline.from_pretrained(
+         dev_ckpt,
+         transformer=model,
+         torch_dtype=DTYPE,
+     ).to(DEVICE)
+     if USE_CPU_OFFLOAD:
+         pipeline.enable_sequential_cpu_offload()
+     # pipeline.enable_xformers_memory_efficient_attention()
+
+     for prompt_idx, prompt in enumerate(PROMPTS):
+         for step_count in STEP_COUNTS:
+             output_path = os.path.join(
+                 image_output_dir,
+                 f"prompt{prompt_idx + 1}_steps{step_count}.png"
+             )
+             generate_image(pipeline, prompt, step_count, output_path)
+
+     del pipeline
+     del model
+     cleanup()
+
+
+ def main():
+     dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", ignore_patterns=["flux1-dev.sft", "flux1-dev.safetensors"],
+                                  local_dir="./models/dev/")
+     schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*",
+                                      local_dir="./models/schnell/")
+
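+     # With NUM_WORKERS = 1 the merges run one at a time; each worker holds a
+     # full transformer and pipeline in memory, so raising it multiplies peak RAM.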
+     with Pool(NUM_WORKERS) as pool:
+         results = [
+             pool.apply_async(
+                 process_model,
+                 (ratio, label, dev_ckpt, schnell_ckpt)
+             )
+             for ratio, label in zip(MERGE_RATIOS, MERGE_LABELS)
+         ]
+
+         for result in tqdm(results):
+             result.get()  # This will block until the result is ready
+
+         pool.close()
+         pool.join()
+
+
+ def create_image_grid(image_paths, output_path, padding=10):
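+     # One column per merge ratio, one row per step count, images at half
+     # resolution. top_pad/left_pad leave blank margins (labels are not drawn
+     # here). Note: this assumes every path in image_paths exists on disk.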
+     width = IMAGE_WIDTH // 2
+     height = IMAGE_HEIGHT // 2
+     images = [Image.open(path).resize((width, height)) for path in image_paths]
+
+     grid_cols = len(MERGE_RATIOS)
+     grid_rows = len(STEP_COUNTS)
+     top_pad = 250
+     left_pad = 200
+     grid_width = (width * grid_cols) + (padding * (grid_cols + 1)) + left_pad
+     grid_height = (height * grid_rows) + (padding * (grid_rows + 1)) + top_pad
+
+     grid_image = Image.new('RGB', (grid_width, grid_height), color=(255, 255, 255))
+
+     for idx, img in enumerate(images):
+         row = idx // grid_cols
+         col = idx % grid_cols
+         x_position = (col * width) + (padding * (col + 1)) + left_pad
+         y_position = (row * height) + (padding * (row + 1)) + top_pad
+         grid_image.paste(img, (x_position, y_position))
+
+     grid_image.save(output_path)
+
+
+ if __name__ == "__main__":  # Guard so multiprocessing workers don't re-run this on import
+     # Run the main process
+     main()
+
+     # Create the image grids
+     print("Creating image comparison grid...")
+     # Reconstruct the image paths
+     all_image_paths = [
+         os.path.join(
+             IMAGE_OUTPUT_DIR,
+             label.replace(":", "_"),
+             f"prompt{prompt_idx + 1}_steps{step_count}.png"
+         )
+         for prompt_idx in range(len(PROMPTS))
+         for step_count in STEP_COUNTS
+         for label in MERGE_LABELS
+     ]
+     missing_images = [path for path in all_image_paths if not os.path.exists(path)]
+     if missing_images:
+         print(f"Warning: {len(missing_images)} images were not generated:")
+         for path in missing_images[:5]:  # Show first 5
+             print(f"  • {path}")
+         if len(missing_images) > 5:
+             print(f"  (and {len(missing_images) - 5} more...)")
+
+     # Create grid images
+     for prompt_idx in range(len(PROMPTS)):
+         prompt_images = [path for path in all_image_paths if f"prompt{prompt_idx + 1}" in path]
+         grid_output_path = os.path.join(IMAGE_OUTPUT_DIR, f"grid_prompt{prompt_idx + 1}.png")
+         create_image_grid(prompt_images, grid_output_path)
+
+     # Final report
+     end_time = time.time()
+     total_time = end_time - start_time
+     num_images = len(all_image_paths)
+
+     print(f"\nProcessing complete!")
+     print(f"Total time: {total_time:.2f} seconds")
+     print(f"Total images generated: {num_images}")
+     print(f"Average time per image: {total_time / num_images:.2f} seconds")
+     print(f"Output directory: {IMAGE_OUTPUT_DIR}")