|
from diffusers import DiffusionPipeline |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
def _make_grid(images, cols):
    """Tile ``images`` left-to-right, top-to-bottom into one RGB image.

    Every cell is sized like the first image; images from a single pipeline
    run are assumed to share one resolution. The last row may be partially
    filled.

    Args:
        images: Non-empty sequence of PIL images.
        cols: Number of images per row (must be >= 1).

    Returns:
        A new RGB ``PIL.Image.Image`` containing the tiled images.

    Raises:
        ValueError: If ``images`` is empty or ``cols`` is less than 1.
    """
    if not images:
        raise ValueError("no images to arrange in a grid")
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    cell_w, cell_h = images[0].size
    rows = -(-len(images) // cols)  # ceiling division
    grid = Image.new("RGB", (cols * cell_w, rows * cell_h))
    for idx, img in enumerate(images):
        row, col = divmod(idx, cols)
        grid.paste(img, (col * cell_w, row * cell_h))
    return grid


def test_lora(
    lcm_speedup=False,
    *,
    model_path="your_sd_dir/stable-diffusion-v1-5",
    lora_path="your_lora_dir",
    num_samples_per_prompt=3,
):
    """Generate a sample grid with the "pattern" LoRA, optionally with LCM-LoRA.

    Loads a Stable Diffusion pipeline from ``model_path`` in float16 on CUDA,
    with the safety checker disabled. Attaches the ``1epoch_lora.safetensors``
    LoRA from ``lora_path`` as adapter "pattern". Then renders
    ``num_samples_per_prompt`` images for each built-in prompt. The results
    are saved as one grid (one row per prompt) to
    ``test_lora_grid_<steps>_steps.png`` in the current working directory.

    Args:
        lcm_speedup: If True, also load ``lcm_lora.safetensors`` as adapter
            "lcm", switch to ``LCMScheduler`` and sample with 8 steps at
            guidance 2. Otherwise sample with 30 steps at guidance 7.5.
        model_path: Directory or hub id of the base Stable Diffusion model.
        lora_path: Directory containing the LoRA ``.safetensors`` files.
        num_samples_per_prompt: Images generated per prompt; this is also the
            number of columns in the saved grid.
    """
    pipe = DiffusionPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.to("cuda")

    pipe.load_lora_weights(
        pretrained_model_name_or_path_or_dict=lora_path,
        weight_name="1epoch_lora.safetensors",
        adapter_name="pattern",
    )

    if lcm_speedup:
        # Local import keeps this edit self-contained; same package as DiffusionPipeline.
        from diffusers import LCMScheduler

        pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict=lora_path,
            weight_name="lcm_lora.safetensors",
            adapter_name="lcm",
        )
        pipe.set_adapters(["pattern", "lcm"], adapter_weights=[1.0, 1.0])
        # LCM-LoRA is designed to be sampled with the LCM scheduler; the
        # pipeline's default scheduler gives poor results at so few steps.
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        # Few-step, low-guidance sampling is what LCM distillation targets.
        num_inference_steps = 8
        guidance_scale = 2
    else:
        num_inference_steps = 30
        guidance_scale = 7.5

    prompts = [
        "Tang Dynasty Phoenix bird pattern, multi-integrated color complex figurative embroidery animal pattern, white background, asymmetry, meaning good weather, good luck, happy life. A symbol of good peace, abundance of children, supreme power and dominion. Worship of auspicious gods",
        "Tang Dynasty Treasure Flower Pattern,flower,rotational,flower,rotational, radioactive arrangement,symmetry, solo, yellow theme",
    ]

    all_images = []
    for prompt in prompts:
        result = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_samples_per_prompt,
        )
        all_images.extend(result.images)

    # One row per prompt, one column per sample; cell size follows the output.
    grid_image = _make_grid(all_images, cols=num_samples_per_prompt)

    # Name the file after the step count actually used, so the LCM and
    # baseline grids are labelled correctly.
    grid_image.save(f"test_lora_grid_{num_inference_steps}_steps.png")
|
|
|
|
|
if __name__ == "__main__":
    # Run the baseline configuration first, then the LCM-LoRA variant.
    for use_lcm in (False, True):
        test_lora(lcm_speedup=use_lcm)