Upload run_txt2img_axe_infer_new.py
Load the model only once; support multiple rounds of prompts to generate images and save them to the specified folder
- run_txt2img_axe_infer_new.py +197 -0
run_txt2img_axe_infer_new.py
ADDED
@@ -0,0 +1,197 @@
from typing import List, Union
import numpy as np
# import onnxruntime
import axengine
import torch
from PIL import Image
from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection

import time
import argparse
import uuid  # used to generate unique file names
import os
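# Note: axengine is used in place of onnxruntime (see the commented-out import
# above); its InferenceSession exposes the same run(None, feed_dict) style of
# API, but loads AXERA-compiled .axmodel files.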

def get_args():
    parser = argparse.ArgumentParser(
        prog="StableDiffusion",
        description="Generate picture with the input prompt"
    )
    parser.add_argument("--prompt", type=str, required=False, default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", help="the input text prompt")
    parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
    parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel", help="Path to unet axmodel model")
    parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel", help="Path to vae decoder axmodel model")
    parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_txt2img.npy", help="Path to time input file")
    parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_axe", help="Path to the output image file")
    return parser.parse_args()
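# Example invocation (every path below is simply the argparse default above):
#   python3 run_txt2img_axe_infer_new.py \
#       --text_model_dir ./models/ \
#       --unet_model ./models/unet.axmodel \
#       --vae_decoder_model ./models/vae_decoder.axmodel \
#       --time_input ./models/time_input_txt2img.npy \
#       --save_dir ./txt2img_output_axe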
def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):  # noqa: F821
    if not isinstance(prompt, List):
        prompts = [prompt]
    else:
        prompts = prompt

    prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]

    if not isinstance(prompt, List):
        return prompts[0]

    return prompts


def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"):  # noqa: F821
    tokens = tokenizer.tokenize(prompt)
    unique_tokens = set(tokens)
    for token in unique_tokens:
        if token in tokenizer.added_tokens_encoder:
            replacement = token
            i = 1
            while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                replacement += f" {token}_{i}"
                i += 1

            prompt = prompt.replace(token, replacement)

    return prompt
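# Note: the two helpers above mirror diffusers' textual-inversion prompt
# expansion (a multi-vector token "<concept>" becomes "<concept> <concept>_1 ...").
# They are not called in the main flow below, but can wrap a prompt before
# get_embeds() when the tokenizer carries learned embeddings.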

def get_embeds(prompt="Portrait of a pretty girl", tokenizer_dir="./models/tokenizer", text_encoder_dir="./models/text_encoder"):
    tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
    text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
                                                 torch_dtype=torch.float32,
                                                 variant="fp16")
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)

    prompt_embeds_npy = prompt_embeds[0].detach().numpy()
    return prompt_embeds_npy
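# get_embeds() returns the CLIP last_hidden_state for a fixed 77-token sequence,
# i.e. shape (1, 77, hidden_size) -- (1, 77, 768) for the SD v1.x ViT-L/14 text
# encoder assumed here -- and feeds the UNet below as encoder_hidden_states.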

def get_alphas_cumprod():
    betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
    final_alphas_cumprod = alphas_cumprod[0]
    self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
    return alphas_cumprod, final_alphas_cumprod, self_timesteps
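# The betas above are Stable Diffusion's "scaled_linear" schedule
# (beta_start=0.00085, beta_end=0.012, 1000 training steps), so
# alphas_cumprod[t] = prod_{s<=t} (1 - beta_s) supplies the alpha_prod_t /
# beta_prod_t terms consumed by the denoising step in the main loop.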


if __name__ == '__main__':
    args = get_args()

    # Note: text_model_dir is concatenated directly, so it must end with '/'
    # (the default "./models/" does); os.path.join would be the safer choice.
    tokenizer_dir = args.text_model_dir + 'tokenizer'
    text_encoder_dir = args.text_model_dir + 'text_encoder'
    unet_model = args.unet_model
    vae_decoder_model = args.vae_decoder_model
    time_input = args.time_input
    save_dir = args.save_dir
    # Make sure the save directory exists
    os.makedirs(save_dir, exist_ok=True)
    print(f"tokenizer: {tokenizer_dir}")
    print(f"text_encoder: {text_encoder_dir}")
    print(f"unet_model: {unet_model}")
    print(f"vae_decoder_model: {vae_decoder_model}")
    print(f"time_input: {time_input}")
    print(f"save_dir: {save_dir}")
    # Load the models (only once, before the prompt loop)
    start = time.time()
    unet_session_main = axengine.InferenceSession(unet_model)
    vae_decoder = axengine.InferenceSession(vae_decoder_model)
    print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
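    # Both sessions stay alive across iterations, so the load cost above is paid
    # once; each subsequent prompt only pays for text encoding, four UNet steps,
    # and one VAE decode.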
    # Main loop: accept prompts repeatedly
    while True:
        # Read the prompt from the user
        prompt = input("\nEnter a prompt to generate an image (or type 'exit' to quit): ")
        if prompt.lower() == 'exit':
            print("Exiting...")
            break
        # Text Encoder
        start = time.time()
        prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
        print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
        # Initialize the latent
        latents_shape = [1, 4, 64, 64]
        latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
                             layout=torch.strided).detach().numpy()

        alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()

        # Load the time_input file
        time_input_data = np.load(time_input)
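        # The tensor name fed below ("/down_blocks.0/resnets.0/act_1/Mul_output_0")
        # suggests the UNet's time-embedding branch was folded out when the model
        # was compiled; time_input_txt2img.npy then holds one precomputed embedding
        # activation per timestep in the schedule.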
        # U-Net Inference Loop
        timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
        unet_loop_start = time.time()
        for i, timestep in enumerate(timesteps):
            unet_start = time.time()
            noise_pred = unet_session_main.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input_data[i], axis=0),
                "encoder_hidden_states": prompt_embeds_npy
            })[0]
            print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
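            # The 4-entry schedule [999, 759, 499, 259] together with the
            # c_skip / c_out scalings below matches the LCMScheduler step in
            # diffusers (sigma_data=0.5, timestep_scaling=10), i.e. this loop
            # performs 4-step latent-consistency sampling.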
            sample = latent
            model_output = noise_pred
            if i < 3:
                prev_timestep = timesteps[i + 1]
            else:
                prev_timestep = timestep

            alpha_prod_t = alphas_cumprod[timestep]
            alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod

            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
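            # Consistency-model boundary conditions: as the timestep approaches 0,
            # c_skip -> 1 and c_out -> 0, so the update degenerates to the identity
            # and the model output is trusted most at high noise levels.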
            scaled_timestep = timestep * 10
            c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
            c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
            predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)

            denoised = c_out * predicted_original_sample + c_skip * sample

            if i != 3:
                noise = torch.randn(model_output.shape, generator=None, device="cpu", dtype=torch.float32,
                                    layout=torch.strided).to("cpu").detach().numpy()
                prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
            else:
                prev_sample = denoised

            latent = prev_sample

        print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
        # VAE Inference
        vae_start = time.time()
        latent = latent / 0.18215
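        # 1/0.18215 undoes the SD v1.x VAE scaling factor that was applied when
        # images were encoded into latent space during training.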
        image = vae_decoder.run(None, {"x": latent})[0]
        print(f"vae inference take {(1000 * (time.time() - vae_start)):.1f}ms")
        # Save the result
        save_start = time.time()
        # NCHW -> HWC, then map [-1, 1] back to [0, 255]
        image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
        image_denorm = np.clip(image / 2 + 0.5, 0, 1)
        image = (image_denorm * 255).round().astype("uint8")
        pil_image = Image.fromarray(image[:, :, :3])

        # Use a UUID so file names stay unique across iterations
        unique_filename = f"{uuid.uuid4()}.png"
        save_path = os.path.join(save_dir, unique_filename)
        pil_image.save(save_path)
        print(f"Image saved to {save_path}")
        print(f"Save image take {(1000 * (time.time() - save_start)):.1f}ms")
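        # With 64x64 latents and the SD VAE's 8x upsampling, the saved PNG is
        # 512x512. Control then returns to the top of the while loop for the
        # next prompt.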