File size: 2,245 Bytes
3761354
 
 
 
bae23b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from diffusers import DiffusionPipeline
import torch
from PIL import Image


def _make_grid(images, cols):
    """Paste equally sized PIL images row-major into a grid with ``cols`` columns.

    The cell size is taken from the first image, and the row count is
    ``ceil(len(images) / cols)``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("no images to arrange into a grid")
    width, height = images[0].size
    rows = -(-len(images) // cols)  # ceiling division
    grid = Image.new("RGB", (cols * width, rows * height))
    for idx, img in enumerate(images):
        row, col = divmod(idx, cols)
        grid.paste(img, (col * width, row * height))
    return grid


def test_lora(lcm_speedup=False):
    """Generate a preview grid for the "pattern" LoRA on Stable Diffusion v1.5.

    Loads the base model in fp16 on CUDA and attaches
    ``1epoch_lora.safetensors`` as adapter ``"pattern"``. It then generates
    ``num_samples_per_prompt`` images per prompt. The images are saved as
    one grid (one row per prompt, one column per sample) to
    ``test_lora_grid_<steps>_steps.png``.

    Args:
        lcm_speedup: If True, also load ``lcm_lora.safetensors`` as adapter
            ``"lcm"``, switch the pipeline to ``LCMScheduler`` and sample with
            8 steps / guidance 2. Otherwise use 30 steps / guidance 7.5.
    """
    # Load the Stable Diffusion model (safety checker disabled).
    pipe = DiffusionPipeline.from_pretrained(
        "your_sd_dir/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.to("cuda")

    # Load the LoRA weights under the adapter name "pattern".
    lora_path = "your_lora_dir"
    pipe.load_lora_weights(
        pretrained_model_name_or_path_or_dict=lora_path,
        weight_name="1epoch_lora.safetensors",
        adapter_name="pattern",
    )

    # Sampling parameters.
    if lcm_speedup:
        # Imported locally so the non-LCM path does not depend on it.
        from diffusers import LCMScheduler

        pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict=lora_path,
            weight_name="lcm_lora.safetensors",
            adapter_name="lcm",
        )
        pipe.set_adapters(["pattern", "lcm"], adapter_weights=[1.0, 1.0])
        # LCM-LoRA is designed to be sampled with LCMScheduler; the default
        # scheduler produces poor images at such low step counts.
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        num_inference_steps = 8
        guidance_scale = 2
    else:
        num_inference_steps = 30
        guidance_scale = 7.5

    # Prompts to render.
    prompts = [
        "Tang Dynasty Phoenix bird pattern, multi-integrated color complex figurative embroidery animal pattern, white background, asymmetry, meaning good weather, good luck, happy life. A symbol of good peace, abundance of children, supreme power and dominion. Worship of auspicious gods",
        "Tang Dynasty Treasure Flower Pattern,flower,rotational,flower,rotational, radioactive arrangement,symmetry, solo, yellow theme"
    ]

    num_samples_per_prompt = 3

    # Generate num_samples_per_prompt images for each prompt, in prompt order.
    all_images = []
    for prompt in prompts:
        images = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_samples_per_prompt,
        ).images
        all_images.extend(images)

    # One row per prompt, one column per sample; cell size follows the images.
    grid_image = _make_grid(all_images, cols=num_samples_per_prompt)

    # Name the file after the step count actually used for sampling.
    grid_image.save(f"test_lora_grid_{num_inference_steps}_steps.png")


if __name__ == "__main__":
    # Baseline run with default arguments first, then the LCM-accelerated run.
    for run_kwargs in ({}, {"lcm_speedup": True}):
        test_lora(**run_kwargs)