# Listing metadata from the page this script was copied from:
# file size 2,245 bytes; revisions 3761354, bae23b5.
from diffusers import DiffusionPipeline
import torch
from PIL import Image
def _load_pipeline(lcm_speedup):
    """Build the fp16 Stable Diffusion pipeline on CUDA with the "pattern" LoRA loaded.

    When ``lcm_speedup`` is true, also switch to the LCM scheduler and stack the
    LCM-LoRA adapter ("lcm") on top of "pattern", both at weight 1.0.
    """
    # Load the Stable Diffusion model (path is a placeholder to be filled in).
    pipe = DiffusionPipeline.from_pretrained(
        "your_sd_dir/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.to("cuda")

    # Load the trained pattern LoRA weights.
    lora_path = "your_lora_dir"
    pipe.load_lora_weights(
        pretrained_model_name_or_path_or_dict=lora_path,
        weight_name="1epoch_lora.safetensors",
        adapter_name="pattern",
    )
    if lcm_speedup:
        # LCM-LoRA is distilled for the LCM sampler; keeping the pipeline's
        # default scheduler at only a few steps yields badly under-denoised images.
        from diffusers import LCMScheduler

        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict=lora_path,
            weight_name="lcm_lora.safetensors",
            adapter_name="lcm",
        )
        pipe.set_adapters(["pattern", "lcm"], adapter_weights=[1.0, 1.0])
    return pipe


def _make_grid(images, rows, cols):
    """Paste ``images`` row-major into a ``rows`` x ``cols`` RGB grid.

    The cell size is taken from the first image, so this works for any
    generation resolution rather than assuming 512x512.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("no images to arrange into a grid")
    cell_w, cell_h = images[0].size
    grid = Image.new("RGB", (cols * cell_w, rows * cell_h))
    for idx, img in enumerate(images):
        row, col = divmod(idx, cols)
        grid.paste(img, (col * cell_w, row * cell_h))
    return grid


def test_lora(lcm_speedup=False):
    """Generate sample images with the "pattern" LoRA and save them as one grid.

    Each prompt gets its own grid row with ``num_samples_per_prompt`` images.
    The result is written to ``test_lora_grid_<steps>_steps.png`` in the
    current working directory.

    Args:
        lcm_speedup: if true, also apply LCM-LoRA with the LCM scheduler and
            sample with 8 steps / guidance 2; otherwise 30 steps / guidance 7.5.
    """
    pipe = _load_pipeline(lcm_speedup)

    # Prompt list; each prompt becomes one row of the output grid.
    prompts = [
        "Tang Dynasty Phoenix bird pattern, multi-integrated color complex figurative embroidery animal pattern, white background, asymmetry, meaning good weather, good luck, happy life. A symbol of good peace, abundance of children, supreme power and dominion. Worship of auspicious gods",
        "Tang Dynasty Treasure Flower Pattern,flower,rotational,flower,rotational, radioactive arrangement,symmetry, solo, yellow theme"
    ]

    # Sampling parameters: LCM needs far fewer steps and a low guidance scale.
    if lcm_speedup:
        num_inference_steps = 8
        guidance_scale = 2
    else:
        num_inference_steps = 30
        guidance_scale = 7.5
    num_samples_per_prompt = 3

    # Generate num_samples_per_prompt images for every prompt, in prompt order.
    all_images = []
    for prompt in prompts:
        result = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_samples_per_prompt,
        )
        all_images.extend(result.images)

    grid_image = _make_grid(all_images, rows=len(prompts), cols=num_samples_per_prompt)
    # Name the file after the step count actually used (previously hard-coded
    # to 4 for the LCM run even though it samples with 8 steps).
    grid_image.save(f"test_lora_grid_{num_inference_steps}_steps.png")
if __name__ == "__main__":
    # Baseline run (30 steps), then the LCM-accelerated run (8 steps).
    test_lora()
    test_lora(lcm_speedup=True)