qqc1989 committed
Commit e86efa6 · verified · 1 Parent(s): 6086185

Upload run_txt2img_axe_infer_new.py


Load the models only once; support multiple rounds of prompt input, generating an image for each prompt and saving it to the specified folder.
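A typical invocation, using the script's default model paths (adjust --text_model_dir, --unet_model, --vae_decoder_model, and --time_input to match your setup):

    python run_txt2img_axe_infer_new.py --save_dir ./txt2img_output_axe

The script then reads prompts interactively and writes one uniquely named PNG per prompt into --save_dir; type 'exit' to quit.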

Files changed (1)
  1. run_txt2img_axe_infer_new.py +197 -0
run_txt2img_axe_infer_new.py ADDED
@@ -0,0 +1,197 @@
+ from typing import List, Union
+ import numpy as np
+ # import onnxruntime
+ import axengine
+ import torch
+ from PIL import Image
+ from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer
+
+ import time
+ import argparse
+ import uuid  # used to generate unique output filenames
+ import os
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(
+         prog="StableDiffusion",
+         description="Generate a picture from the input prompt"
+     )
+     parser.add_argument("--prompt", type=str, required=False, default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", help="the input text prompt")
+     parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to the text encoder and tokenizer files")
+     parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel", help="Path to the unet axmodel")
+     parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel", help="Path to the vae decoder axmodel")
+     parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_txt2img.npy", help="Path to the time input file")
+     parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_axe", help="Directory for the output image files")
+     return parser.parse_args()
+
+ def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):  # noqa: F821
+     if not isinstance(prompt, List):
+         prompts = [prompt]
+     else:
+         prompts = prompt
+
+     prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+     if not isinstance(prompt, List):
+         return prompts[0]
+
+     return prompts
+
+
+ def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"):  # noqa: F821
+     tokens = tokenizer.tokenize(prompt)
+     unique_tokens = set(tokens)
+     for token in unique_tokens:
+         if token in tokenizer.added_tokens_encoder:
+             replacement = token
+             i = 1
+             while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                 replacement += f" {token}_{i}"
+                 i += 1
+
+             prompt = prompt.replace(token, replacement)
+
+     return prompt
+
+ def get_embeds(prompt="Portrait of a pretty girl", tokenizer_dir="./models/tokenizer", text_encoder_dir="./models/text_encoder"):
+     tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
+     text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
+                                                  torch_dtype=torch.float32,
+                                                  variant="fp16")
+     text_inputs = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="pt",
+     )
+     text_input_ids = text_inputs.input_ids
+     prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
+
+     prompt_embeds_npy = prompt_embeds[0].detach().numpy()
+     return prompt_embeds_npy
+
+
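+ # Note (added): get_alphas_cumprod() below reproduces Stable Diffusion's
+ # "scaled_linear" beta schedule (beta_start=0.00085, beta_end=0.012,
+ # 1000 training timesteps).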
+ def get_alphas_cumprod():
+     betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
+     alphas = 1.0 - betas
+     alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
+     final_alphas_cumprod = alphas_cumprod[0]
+     self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
+     return alphas_cumprod, final_alphas_cumprod, self_timesteps
+
+
+ if __name__ == '__main__':
+     args = get_args()
+
+     tokenizer_dir = args.text_model_dir + 'tokenizer'
+     text_encoder_dir = args.text_model_dir + 'text_encoder'
+     unet_model = args.unet_model
+     vae_decoder_model = args.vae_decoder_model
+     time_input = args.time_input
+     save_dir = args.save_dir
+
+     # Make sure the save directory exists
+     os.makedirs(save_dir, exist_ok=True)
+
+     print(f"tokenizer: {tokenizer_dir}")
+     print(f"text_encoder: {text_encoder_dir}")
+     print(f"unet_model: {unet_model}")
+     print(f"vae_decoder_model: {vae_decoder_model}")
+     print(f"time_input: {time_input}")
+     print(f"save_dir: {save_dir}")
+
+     # Load the models (only once)
+     start = time.time()
+     unet_session_main = axengine.InferenceSession(unet_model)
+     vae_decoder = axengine.InferenceSession(vae_decoder_model)
+     print(f"load models took {(1000 * (time.time() - start)):.1f}ms")
+
+     # Main loop: accept multiple prompts
+     while True:
+         # Read the user's prompt
+         prompt = input("\nEnter a prompt to generate an image (or type 'exit' to quit): ")
+         if prompt.lower() == 'exit':
+             print("Exiting...")
+             break
+
+         # Text encoder
+         start = time.time()
+         prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
+         print(f"text encoder took {(1000 * (time.time() - start)):.1f}ms")
+
+         # Initialize the latent with random noise
+         latents_shape = [1, 4, 64, 64]
+         latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
+                              layout=torch.strided).detach().numpy()
+
+         alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
+
+         # Load the time_input file
+         time_input_data = np.load(time_input)
+
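+         # Note (added): the four timesteps below form a fixed 4-step LCM-style
+         # schedule; time_input_txt2img.npy is expected to hold one precomputed
+         # timestep embedding per step, fed into the UNet mid-graph through the
+         # "/down_blocks.0/resnets.0/act_1/Mul_output_0" input.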
+         # U-Net inference loop
+         timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
+         unet_loop_start = time.time()
+         for i, timestep in enumerate(timesteps):
+             unet_start = time.time()
+             noise_pred = unet_session_main.run(None, {
+                 "sample": latent,
+                 "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input_data[i], axis=0),
+                 "encoder_hidden_states": prompt_embeds_npy
+             })[0]
+             print(f"one unet step took {(1000 * (time.time() - unet_start)):.1f}ms")
+
+             sample = latent
+             model_output = noise_pred
+             if i < 3:
+                 prev_timestep = timesteps[i + 1]
+             else:
+                 prev_timestep = timestep
+
+             alpha_prod_t = alphas_cumprod[timestep]
+             alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
+
+             beta_prod_t = 1 - alpha_prod_t
+             beta_prod_t_prev = 1 - alpha_prod_t_prev
+
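+             # Note (added): the scalings below match the boundary-condition
+             # coefficients of diffusers' LCMScheduler (sigma_data = 0.5,
+             # timestep_scaling = 10).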
+             scaled_timestep = timestep * 10
+             c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
+             c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
+             predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)
+
+             denoised = c_out * predicted_original_sample + c_skip * sample
+
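+             # Note (added): as in LCM sampling, fresh noise is injected between
+             # steps; the final step returns the denoised prediction directly.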
+             if i != 3:
+                 noise = torch.randn(model_output.shape, generator=None, device="cpu", dtype=torch.float32,
+                                     layout=torch.strided).to("cpu").detach().numpy()
+                 prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
+             else:
+                 prev_sample = denoised
+
+             latent = prev_sample
+
+         print(f"unet loop took {(1000 * (time.time() - unet_loop_start)):.1f}ms")
+
+         # VAE inference
+         vae_start = time.time()
+         latent = latent / 0.18215  # Stable Diffusion 1.x VAE scaling factor
+         image = vae_decoder.run(None, {"x": latent})[0]
+         print(f"vae inference took {(1000 * (time.time() - vae_start)):.1f}ms")
+
+         # Save the result
+         save_start = time.time()
+         image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
+         image_denorm = np.clip(image / 2 + 0.5, 0, 1)
+         image = (image_denorm * 255).round().astype("uint8")
+         pil_image = Image.fromarray(image[:, :, :3])
+
+         # Use a UUID to ensure the filename is unique
+         unique_filename = f"{uuid.uuid4()}.png"
+         save_path = os.path.join(save_dir, unique_filename)
+         pil_image.save(save_path)
+         print(f"Image saved to {save_path}")
+         print(f"saving the image took {(1000 * (time.time() - save_start)):.1f}ms")