qqc1989 committed
Commit 3232548 · verified · 1 Parent(s): c5cba1d

Upload 13 files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ asserts/img2img_output_axe.png filter=lfs diff=lfs merge=lfs -text
37
+ asserts/img2img-init.png filter=lfs diff=lfs merge=lfs -text
38
+ asserts/lcm_lora_sdv1_5_axmodel.png filter=lfs diff=lfs merge=lfs -text
39
+ asserts/lcm_lora_sdv1-5_imgGrid_output.png filter=lfs diff=lfs merge=lfs -text
40
+ asserts/txt2img_output_axe.png filter=lfs diff=lfs merge=lfs -text
Disclaimer.md ADDED
@@ -0,0 +1,6 @@
1
+ # Disclaimer
2
+ All models exported by SD1.5-LCM.Axera have inherent limitations and may produce erroneous, harmful, offensive, or otherwise undesirable outputs. Users should exercise caution and should not use these models in critical or high-risk scenarios where doing so could lead to personal injury, property damage, or significant losses. Examples of such scenarios include, but are not limited to, the medical domain, the control of software or hardware systems that could cause harm, and making important financial or legal decisions.
3
+
4
+ SD1.5-LCM.Axera is provided "as is" as an open-source project, without warranties of any kind, express or implied, including but not limited to the implied warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the authors, contributors, or copyright holders be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the software or the use of or other dealings in the software.
5
+
6
+ By using SD1.5-LCM.Axera, you agree to these terms and conditions and acknowledge that you understand the potential risks of its use. You further agree to indemnify and hold harmless the authors, contributors, and copyright holders from any claims, damages, or liabilities arising from your use of SD1.5-LCM.Axera.
LICENSE ADDED
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2024, BUG1989
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,23 +1,14 @@
1
- ---
2
- license: bsd-3-clause
3
- base_model:
4
- - latent-consistency/lcm-lora-sdv1-5
5
- - Lykon/dreamshaper-7
6
- tags:
7
- - RaspberryPi5
8
- - StableDiffusion1.5
9
- ---
10
-
11
- # SD1.5-LCM
12
 
13
  Based on the StableDiffusion 1.5 LCM project, this repo shows how to deploy its **text-to-image** and **image-to-image** pipelines on AX650N-based products.
14
 
15
- ## Support Platform
 
16
 
17
- - AX650
18
- - AX650N DEMO Board
19
- - [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
20
- - [M.2 Accelerator card](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
21
 
22
  For the original model, please refer to
23
  - [Latent Consistency Model (LCM) LoRA: SDv1-5](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5)
 
1
+ # SD1.5-LCM.Axera
2
 
3
  Based on the StableDiffusion 1.5 LCM project, this repo shows how to deploy its **text-to-image** and **image-to-image** pipelines on AX650N-based products.
4

5
+ Supported SoCs:
6
+ - AX650N
7

8
+ Supported hardware:
9
+
10
+ - 爱芯派Pro (M4N-Dock)
11
+ - M.2 accelerator card
12

13
  For the original model, please refer to
14
  - [Latent Consistency Model (LCM) LoRA: SDv1-5](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5)
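For quick reference, the scripts added in this commit can be run with their built-in argparse defaults. This is a minimal sketch that assumes the tokenizer/text encoder folders, the `*.axmodel` files, and the `time_input_*.npy` files sit under the scripts' default `./models/` directory:

```bash
# Text-to-image on AX650N via axengine (all paths are the script defaults)
python3 run_txt2img_axe_infer.py \
    --prompt "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" \
    --unet_model ./models/unet.axmodel \
    --vae_decoder_model ./models/vae_decoder.axmodel \
    --time_input ./models/time_input_txt2img.npy \
    --save_dir ./txt2img_output_axe.png

# Image-to-image, starting from the bundled init image
python3 run_img2img_axe_infer.py \
    --prompt "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" \
    --init_image ./models/img2img-init.png \
    --save_dir ./img2img_output_axe.png

# CPU reference run with onnxruntime (expects exported ONNX models under ./models/)
python3 run_img2img_onnx_infer.py --save_dir ./img2img_output_onnx.png
```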
asserts/img2img-init.png ADDED

Git LFS Details

  • SHA256: 42f0ee242d8caaee1aea5506c8318c6a920d559a63c6db8d79f993eebaf7d790
  • Pointer size: 131 Bytes
  • Size of remote file: 253 kB
asserts/img2img_output_axe.png ADDED

Git LFS Details

  • SHA256: 7e10a3fb95ce3eb95079d584cafa2f7e55b373c3bdc0110fdf5f7a45e462df78
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
asserts/lcm_lora_sdv1-5_imgGrid_output.png ADDED

Git LFS Details

  • SHA256: 327f76a5e9b68642224224d7dbcae1828da695179d4b6d7c41507b8c1c369d8e
  • Pointer size: 131 Bytes
  • Size of remote file: 633 kB
asserts/lcm_lora_sdv1_5_axmodel.png ADDED

Git LFS Details

  • SHA256: 91dd62d810ba66e8142ef27285fe28af55746e764c10eafe3521a52cd92f9da0
  • Pointer size: 131 Bytes
  • Size of remote file: 395 kB
asserts/txt2img_output_axe.png ADDED

Git LFS Details

  • SHA256: 572a205faf1fb7c2538c06f5c2bbce175ef8630b3d276a201f89c7b6695efeb0
  • Pointer size: 131 Bytes
  • Size of remote file: 340 kB
config.json.txt ADDED
File without changes
run_img2img_axe_infer.py ADDED
@@ -0,0 +1,546 @@
1
+ from typing import List, Union
2
+ import numpy as np
3
+ # import onnxruntime
4
+ import axengine
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
8
+
9
+ import warnings
+ import time
10
+ import argparse
11
+ from diffusers.utils import load_image
12
+ import PIL.Image
13
+ from typing import List, Optional, Tuple, Union
14
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
15
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.utils import make_image_grid, load_image
18
+
19
+
20
+
21
+ ########## Img2Img
22
+ PipelineImageInput = Union[
23
+ PIL.Image.Image,
24
+ np.ndarray,
25
+ torch.Tensor,
26
+ List[PIL.Image.Image],
27
+ List[np.ndarray],
28
+ List[torch.Tensor],
29
+ ]
30
+
31
+ PipelineDepthInput = PipelineImageInput
32
+
33
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
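+ # What this computes (standard DDPM forward diffusion): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
+ # where alpha_bar_t is accumulated from the scaled-linear beta schedule (beta from 0.00085 to 0.012 over 1000 steps).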
34
+ def add_noise(
35
+ original_samples: torch.Tensor,
36
+ noise: torch.Tensor,
37
+ timesteps: torch.IntTensor,
38
+ ) -> torch.Tensor:
39
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
40
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
41
+ # for the subsequent add_noise calls
42
+ # self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
43
+ # Convert betas to alphas_bar_sqrt
44
+ beta_start = 0.00085
45
+ beta_end = 0.012
46
+ num_train_timesteps = 1000
47
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
48
+ alphas = 1.0 - betas
49
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
50
+ alphas_cumprod = alphas_cumprod.to(device=original_samples.device)
51
+ alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype)
52
+ timesteps = timesteps.to(original_samples.device)
53
+
54
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
55
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
56
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
57
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
58
+
59
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
60
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
61
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
62
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
63
+
64
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
65
+ return noisy_samples
66
+
67
+ def retrieve_latents(
68
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
69
+ ):
70
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
71
+ return encoder_output.latent_dist.sample(generator)
72
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
73
+ return encoder_output.latent_dist.mode()
74
+ elif hasattr(encoder_output, "latents"):
75
+ return encoder_output.latents
76
+ else:
77
+ raise AttributeError("Could not access latents of provided encoder_output")
78
+
79
+ def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
80
+ r"""
81
+ Convert a NumPy image to a PyTorch tensor.
82
+
83
+ Args:
84
+ images (`np.ndarray`):
85
+ The NumPy image array to convert to PyTorch format.
86
+
87
+ Returns:
88
+ `torch.Tensor`:
89
+ A PyTorch tensor representation of the images.
90
+ """
91
+ if images.ndim == 3:
92
+ images = images[..., None]
93
+
94
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
95
+ return images
96
+
97
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
98
+ r"""
99
+ Convert a PIL image or a list of PIL images to NumPy arrays.
100
+
101
+ Args:
102
+ images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
103
+ The PIL image or list of images to convert to NumPy format.
104
+
105
+ Returns:
106
+ `np.ndarray`:
107
+ A NumPy array representation of the images.
108
+ """
109
+ if not isinstance(images, list):
110
+ images = [images]
111
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
112
+ images = np.stack(images, axis=0)
113
+
114
+ return images
115
+
116
+ def is_valid_image(image) -> bool:
117
+ r"""
118
+ Checks if the input is a valid image.
119
+
120
+ A valid image can be:
121
+ - A `PIL.Image.Image`.
122
+ - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
123
+
124
+ Args:
125
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
126
+ The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
127
+
128
+ Returns:
129
+ `bool`:
130
+ `True` if the input is a valid image, `False` otherwise.
131
+ """
132
+ return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
133
+
134
+ def is_valid_image_imagelist(images):
135
+ r"""
136
+ Checks if the input is a valid image or list of images.
137
+
138
+ The input can be one of the following formats:
139
+ - A 4D tensor or numpy array (batch of images).
140
+ - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
141
+ `torch.Tensor`.
142
+ - A list of valid images.
143
+
144
+ Args:
145
+ images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
146
+ The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
147
+ images.
148
+
149
+ Returns:
150
+ `bool`:
151
+ `True` if the input is valid, `False` otherwise.
152
+ """
153
+ if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
154
+ return True
155
+ elif is_valid_image(images):
156
+ return True
157
+ elif isinstance(images, list):
158
+ return all(is_valid_image(image) for image in images)
159
+ return False
160
+
161
+
162
+ def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
163
+ r"""
164
+ Normalize an image array to [-1,1].
165
+
166
+ Args:
167
+ images (`np.ndarray` or `torch.Tensor`):
168
+ The image array to normalize.
169
+
170
+ Returns:
171
+ `np.ndarray` or `torch.Tensor`:
172
+ The normalized image array.
173
+ """
174
+ return 2.0 * images - 1.0
175
+
176
+ # Copied from diffusers' image_processor.py (VaeImageProcessor.preprocess)
177
+ def preprocess(
178
+ image: PipelineImageInput,
179
+ height: Optional[int] = None,
180
+ width: Optional[int] = None,
181
+ resize_mode: str = "default", # "default", "fill", "crop"
182
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
183
+ ) -> torch.Tensor:
184
+ """
185
+ Preprocess the image input.
186
+
187
+ Args:
188
+ image (`PipelineImageInput`):
189
+ The image input; accepted formats are PIL images, NumPy arrays, and PyTorch tensors, as well as lists
190
+ of these supported formats.
191
+ height (`int`, *optional*):
192
+ The height of the preprocessed image. If `None`, the default height from `get_default_height_width()`
193
+ is used.
194
+ width (`int`, *optional*):
195
+ The width of the preprocessed image. If `None`, the default width from `get_default_height_width()` is used.
196
+ resize_mode (`str`, *optional*, defaults to `default`):
197
+ The resize mode, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit within
198
+ the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
199
+ resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
200
+ center the image within the dimensions, filling the empty area with data from the image. If `crop`, will resize the
201
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
202
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
203
+ supported for PIL image input.
204
+ crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
205
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
206
+
207
+ Returns:
208
+ `torch.Tensor`:
209
+ The preprocessed image.
210
+ """
211
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
212
+
213
+ # # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
214
+ # if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
215
+ # if isinstance(image, torch.Tensor):
216
+ # # if image is a pytorch tensor could have 2 possible shapes:
217
+ # # 1. batch x height x width: we should insert the channel dimension at position 1
218
+ # # 2. channel x height x width: we should insert batch dimension at position 0,
219
+ # # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
220
+ # # for simplicity, we insert a dimension of size 1 at position 1 for both cases
221
+ # image = image.unsqueeze(1)
222
+ # else:
223
+ # # if it is a numpy array, it could have 2 possible shapes:
224
+ # # 1. batch x height x width: insert channel dimension on last position
225
+ # # 2. height x width x channel: insert batch dimension on first position
226
+ # if image.shape[-1] == 1:
227
+ # image = np.expand_dims(image, axis=0)
228
+ # else:
229
+ # image = np.expand_dims(image, axis=-1)
230
+
231
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
232
+ warnings.warn(
233
+ "Passing `image` as a list of 4d np.ndarray is deprecated."
234
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
235
+ FutureWarning,
236
+ )
237
+ image = np.concatenate(image, axis=0)
238
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
239
+ warnings.warn(
240
+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
241
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
242
+ FutureWarning,
243
+ )
244
+ image = torch.cat(image, axis=0)
245
+
246
+ if not is_valid_image_imagelist(image):
247
+ raise ValueError(
248
+ f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
249
+ )
250
+ if not isinstance(image, list):
251
+ image = [image]
252
+
253
+ if isinstance(image[0], PIL.Image.Image):
254
+ if crops_coords is not None:
255
+ image = [i.crop(crops_coords) for i in image]
256
+ # if self.config.do_resize:
257
+ # height, width = self.get_default_height_width(image[0], height, width)
258
+ # image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
259
+ # if self.config.do_convert_rgb:
260
+ # image = [self.convert_to_rgb(i) for i in image]
261
+ # elif self.config.do_convert_grayscale:
262
+ # image = [self.convert_to_grayscale(i) for i in image]
263
+ image = pil_to_numpy(image) # to np
264
+ image = numpy_to_pt(image) # to pt
265
+
266
+ elif isinstance(image[0], np.ndarray):
267
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
268
+
269
+ # image = self.numpy_to_pt(image)
270
+
271
+ # height, width = self.get_default_height_width(image, height, width)
272
+ # if self.config.do_resize:
273
+ # image = self.resize(image, height, width)
274
+
275
+ elif isinstance(image[0], torch.Tensor):
276
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
277
+
278
+ # if self.config.do_convert_grayscale and image.ndim == 3:
279
+ # image = image.unsqueeze(1)
280
+
281
+ channel = image.shape[1]
282
+ # don't need any preprocess if the image is latents
283
+ # if channel == self.config.vae_latent_channels:
284
+ # return image
285
+
286
+ # height, width = self.get_default_height_width(image, height, width)
287
+ # if self.config.do_resize:
288
+ # image = self.resize(image, height, width)
289
+
290
+ # expected range [0,1], normalize to [-1,1]
291
+ do_normalize = True # self.config.do_normalize
292
+ if do_normalize and image.min() < 0:
293
+ warnings.warn(
294
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
295
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
296
+ FutureWarning,
297
+ )
298
+ do_normalize = False
299
+ if do_normalize:
300
+ image = normalize(image)
301
+
302
+ # if self.config.do_binarize:
303
+ # image = self.binarize(image)
304
+
305
+ return image
306
+ ##########
307
+
308
+
309
+ def get_args():
310
+ parser = argparse.ArgumentParser(
311
+ prog="StableDiffusion",
312
+ description="Generate picture with the input prompt"
313
+ )
314
+ parser.add_argument("--prompt", type=str, required=False, default="Astronauts in a jungle, cold color palette, muted colors, detailed, 8k", help="the input text prompt")
315
+ parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
316
+ parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel", help="Path to unet axmodel model")
317
+ parser.add_argument("--vae_encoder_model", type=str, required=False, default="./models/vae_encoder.axmodel", help="Path to vae encoder axmodel model")
318
+ parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel", help="Path to vae decoder axmodel model")
319
+ parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_img2img.npy", help="Path to time input file")
320
+ parser.add_argument("--init_image", type=str, required=False, default="./models/img2img-init.png", help="Path to initial image file")
321
+ parser.add_argument("--save_dir", type=str, required=False, default="./img2img_output_axe.png", help="Path to the output image file")
322
+ return parser.parse_args()
323
+
324
+ def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
325
+ if not isinstance(prompt, List):
326
+ prompts = [prompt]
327
+ else:
328
+ prompts = prompt
329
+
330
+ prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
331
+
332
+ if not isinstance(prompt, List):
333
+ return prompts[0]
334
+
335
+ return prompts
336
+
337
+
338
+ def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
339
+ tokens = tokenizer.tokenize(prompt)
340
+ unique_tokens = set(tokens)
341
+ for token in unique_tokens:
342
+ if token in tokenizer.added_tokens_encoder:
343
+ replacement = token
344
+ i = 1
345
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
346
+ replacement += f" {token}_{i}"
347
+ i += 1
348
+
349
+ prompt = prompt.replace(token, replacement)
350
+
351
+ return prompt
352
+
353
+
354
+ def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
355
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
356
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
357
+ torch_dtype=torch.float32,
358
+ variant="fp16")
359
+ text_inputs = tokenizer(
360
+ prompt,
361
+ padding="max_length",
362
+ max_length=77,
363
+ truncation=True,
364
+ return_tensors="pt",
365
+ )
366
+ text_input_ids = text_inputs.input_ids
367
+ prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
368
+
369
+ prompt_embeds_npy = prompt_embeds[0].detach().numpy()
370
+ return prompt_embeds_npy
371
+
372
+
373
+ def get_alphas_cumprod():
374
+ betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
375
+ alphas = 1.0 - betas
376
+ alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
377
+ final_alphas_cumprod = alphas_cumprod[0]
378
+ self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
379
+ return alphas_cumprod, final_alphas_cumprod, self_timesteps
380
+
381
+ def resize_and_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
382
+ """
383
+ Resize the image to 512x512 and convert it to RGB.
384
+ """
385
+ return image.resize((512, 512)).convert("RGB")
386
+
387
+
388
+ if __name__ == '__main__':
389
+
390
+ """
391
+ Usage:
392
+ - python3 run_img2img_axe_infer.py --prompt "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" --unet_model ./models/unet.axmodel --vae_encoder_model ./models/vae_encoder.axmodel --vae_decoder_model ./models/vae_decoder.axmodel --time_input ./models/time_input_img2img.npy --init_image ./models/img2img-init.png --save_dir ./img2img_output_axe.png
393
+ """
394
+ args = get_args()
395
+ prompt = args.prompt
396
+ tokenizer_dir = args.text_model_dir + 'tokenizer'
397
+ text_encoder_dir = args.text_model_dir + 'text_encoder'
398
+ unet_model = args.unet_model
399
+ vae_decoder_model = args.vae_decoder_model
400
+ vae_encoder_model = args.vae_encoder_model
401
+ init_image = args.init_image
402
+ time_input = args.time_input
403
+ save_dir = args.save_dir
404
+
405
+ print(f"prompt: {prompt}")
406
+ print(f"text_tokenizer: {tokenizer_dir}")
407
+ print(f"text_encoder: {text_encoder_dir}")
408
+ print(f"unet_model: {unet_model}")
409
+ print(f"vae_encoder_model: {vae_encoder_model}")
410
+ print(f"vae_decoder_model: {vae_decoder_model}")
411
+ print(f"init image: {init_image}")
412
+ print(f"time_input: {time_input}")
413
+ print(f"save_dir: {save_dir}")
414
+
415
+ # timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
416
+
417
+ # text encoder
418
+ start = time.time()
419
+ # prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
420
+ # prompt = "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k"
421
+ # prompt = "Caricature, a beautiful girl with black hair, 8k"
422
+ prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
423
+ print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
424
+
425
+ prompt_name = prompt.replace(" ", "_")
426
+ latents_shape = [1, 4, 64, 64]
427
+ # latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
428
+ # layout=torch.strided).detach().numpy()
429
+
430
+ alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
431
+
432
+ # load unet model and vae model
433
+ start = time.time()
434
+ vae_encoder = axengine.InferenceSession(vae_encoder_model)
435
+ unet_session_main = axengine.InferenceSession(unet_model)
436
+ vae_decoder = axengine.InferenceSession(vae_decoder_model)
437
+ print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
438
+
439
+ # load time input file
440
+ time_input = np.load(time_input)
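+ # time_input_img2img.npy appears to hold the pre-computed UNet time embeddings, one entry per denoising step;
+ # each entry is fed directly into the exported graph input "/down_blocks.0/resnets.0/act_1/Mul_output_0"
+ # below, so the time-embedding layers do not need to run on device (assumption based on the input name and usage).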
441
+
442
+ # load image
443
+ # url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
444
+ url = init_image
445
+ init_image = load_image(url, convert_method=resize_and_rgb) # U8, (512, 512, 3), RGB
446
+ init_image_show = init_image
447
+
448
+ # vae encoder inference
449
+ vae_start = time.time()
450
+
451
+ init_image = preprocess(init_image) # torch.Size([1, 3, 512, 512])
452
+ if isinstance(init_image, torch.Tensor):
453
+ init_image = init_image.detach().numpy()
454
+
455
+ vae_encoder_onnx_inp_name = vae_encoder.get_inputs()[0].name
456
+ vae_encoder_onnx_out_name = vae_encoder.get_outputs()[0].name
457
+
458
+ # vae_encoder_out.shape (1, 8, 64, 64)
459
+ vae_encoder_out = vae_encoder.run(None, {vae_encoder_onnx_inp_name: init_image})[0] # encoder out: torch.Size([1, 8, 64, 64])
460
+ print(f"vae encoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
461
+
462
+ # build the initial latents from the VAE encoder output
463
+ device = torch.device("cpu")
464
+ vae_encoder_out = torch.from_numpy(vae_encoder_out).to(torch.float32)
465
+ posterior = DiagonalGaussianDistribution(vae_encoder_out) # values closely match the diffusers reference
466
+ vae_encode_info = AutoencoderKLOutput(latent_dist=posterior)
467
+ generator = torch.manual_seed(0)
468
+ init_latents = retrieve_latents(vae_encode_info, generator=generator) # values closely match the diffusers reference
469
+ init_latents = init_latents * 0.18215 # scale by the SD1.5 VAE scaling factor; values closely match the reference
470
+ init_latents = torch.cat([init_latents], dim=0)
471
+ shape = init_latents.shape
472
+ dtype = torch.float16
473
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # a different dtype yields different random noise values
474
+ # get latents
475
+ timestep = torch.tensor([499]).to(device)
476
+ init_latents = add_noise(init_latents.to(device), noise, timestep)
477
+ latents = init_latents
478
+
479
+ latents = latents.detach().cpu().numpy()
480
+ latent = latents
481
+
482
+ # unet inference loop
483
+ unet_loop_start = time.time()
484
+ timesteps = np.array([499, 259]).astype(np.int64)
485
+ self_timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
486
+ step_index = [2, 3]
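+ # Only the last 2 of the 4 LCM timesteps (999, 759, 499, 259) are run for img2img, starting from t=499,
+ # which matches the noise added at timestep 499 above (roughly equivalent to a diffusers strength of 0.5; assumption).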
487
+ for i, timestep in enumerate(timesteps):
488
+ unet_start = time.time()
489
+ noise_pred = unet_session_main.run(None, {"sample": latent, \
490
+ "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
491
+ "encoder_hidden_states": prompt_embeds_npy})[0]
492
+
493
+ print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
494
+
495
+ sample = latent
496
+ model_output = noise_pred
497
+
498
+ # 1. get previous step value
499
+ prev_step_index = step_index[i] + 1
500
+ if prev_step_index < len(self_timesteps):
501
+ prev_timestep = self_timesteps[prev_step_index]
502
+ else:
503
+ prev_timestep = timestep
504
+
505
+ alpha_prod_t = alphas_cumprod[timestep]
506
+ alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
507
+ beta_prod_t = 1 - alpha_prod_t
508
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
509
+
510
+ # 3. Get scalings for boundary conditions
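+ # c_skip and c_out below implement the consistency-model boundary condition used by LCM:
+ # c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2), c_out = scaled_timestep / sqrt(scaled_timestep**2 + sigma_data**2),
+ # with sigma_data = 0.5 and timestep_scaling = 10 (these appear to match diffusers' LCMScheduler defaults).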
511
+ scaled_timestep = timestep * 10
512
+ c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
513
+ c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
514
+ predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5) # values closely match the reference
515
+
516
+ denoised = c_out * predicted_original_sample + c_skip * sample
517
+ if step_index[i] != 3:
518
+ device = torch.device("cpu")
519
+ noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=torch.float16).numpy()
520
+ prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
521
+ else:
522
+ prev_sample = denoised
523
+
524
+ latent = prev_sample
525
+
526
+ print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
527
+
528
+ # vae decoder inference
529
+ vae_start = time.time()
530
+ latent = latent / 0.18215
531
+ image = vae_decoder.run(None, {"x": latent})[0] # ['784']
532
+ print(f"vae decoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
533
+
534
+ # save result
535
+ save_start = time.time()
536
+ image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
537
+ image_denorm = np.clip(image / 2 + 0.5, 0, 1)
538
+ image = (image_denorm * 255).round().astype("uint8")
539
+ pil_image = Image.fromarray(image[:, :, :3])
540
+ pil_image.save(save_dir)
541
+
542
+ grid_img = make_image_grid([init_image_show, pil_image], rows=1, cols=2)
543
+ grid_img.save(f"./lcm_lora_sdv1-5_imgGrid_output.png")
544
+
545
+ print(f"grid image saved in ./lcm_lora_sdv1-5_imgGrid_output.png")
546
+ print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
run_img2img_onnx_infer.py ADDED
@@ -0,0 +1,546 @@
1
+ from typing import List, Union
2
+ import numpy as np
3
+ import onnxruntime
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
7
+
8
+ # import axengine as axe
9
+ import warnings
+ import time
10
+ import argparse
11
+ from diffusers.utils import load_image
12
+ import PIL.Image
13
+ from typing import List, Optional, Tuple, Union
14
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
15
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.utils import make_image_grid, load_image
18
+
19
+
20
+
21
+ ########## Img2Img
22
+ PipelineImageInput = Union[
23
+ PIL.Image.Image,
24
+ np.ndarray,
25
+ torch.Tensor,
26
+ List[PIL.Image.Image],
27
+ List[np.ndarray],
28
+ List[torch.Tensor],
29
+ ]
30
+
31
+ PipelineDepthInput = PipelineImageInput
32
+
33
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
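+ # Same DDPM forward-diffusion step as in run_img2img_axe_infer.py: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.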
34
+ def add_noise(
35
+ original_samples: torch.Tensor,
36
+ noise: torch.Tensor,
37
+ timesteps: torch.IntTensor,
38
+ ) -> torch.Tensor:
39
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
40
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
41
+ # for the subsequent add_noise calls
42
+ # self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
43
+ # Convert betas to alphas_bar_sqrt
44
+ beta_start = 0.00085
45
+ beta_end = 0.012
46
+ num_train_timesteps = 1000
47
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
48
+ alphas = 1.0 - betas
49
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
50
+ alphas_cumprod = alphas_cumprod.to(device=original_samples.device)
51
+ alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype)
52
+ timesteps = timesteps.to(original_samples.device)
53
+
54
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
55
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
56
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
57
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
58
+
59
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
60
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
61
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
62
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
63
+
64
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
65
+ return noisy_samples
66
+
67
+ def retrieve_latents(
68
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
69
+ ):
70
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
71
+ return encoder_output.latent_dist.sample(generator)
72
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
73
+ return encoder_output.latent_dist.mode()
74
+ elif hasattr(encoder_output, "latents"):
75
+ return encoder_output.latents
76
+ else:
77
+ raise AttributeError("Could not access latents of provided encoder_output")
78
+
79
+ def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
80
+ r"""
81
+ Convert a NumPy image to a PyTorch tensor.
82
+
83
+ Args:
84
+ images (`np.ndarray`):
85
+ The NumPy image array to convert to PyTorch format.
86
+
87
+ Returns:
88
+ `torch.Tensor`:
89
+ A PyTorch tensor representation of the images.
90
+ """
91
+ if images.ndim == 3:
92
+ images = images[..., None]
93
+
94
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
95
+ return images
96
+
97
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
98
+ r"""
99
+ Convert a PIL image or a list of PIL images to NumPy arrays.
100
+
101
+ Args:
102
+ images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
103
+ The PIL image or list of images to convert to NumPy format.
104
+
105
+ Returns:
106
+ `np.ndarray`:
107
+ A NumPy array representation of the images.
108
+ """
109
+ if not isinstance(images, list):
110
+ images = [images]
111
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
112
+ images = np.stack(images, axis=0)
113
+
114
+ return images
115
+
116
+ def is_valid_image(image) -> bool:
117
+ r"""
118
+ Checks if the input is a valid image.
119
+
120
+ A valid image can be:
121
+ - A `PIL.Image.Image`.
122
+ - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
123
+
124
+ Args:
125
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
126
+ The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
127
+
128
+ Returns:
129
+ `bool`:
130
+ `True` if the input is a valid image, `False` otherwise.
131
+ """
132
+ return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
133
+
134
+ def is_valid_image_imagelist(images):
135
+ r"""
136
+ Checks if the input is a valid image or list of images.
137
+
138
+ The input can be one of the following formats:
139
+ - A 4D tensor or numpy array (batch of images).
140
+ - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
141
+ `torch.Tensor`.
142
+ - A list of valid images.
143
+
144
+ Args:
145
+ images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
146
+ The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
147
+ images.
148
+
149
+ Returns:
150
+ `bool`:
151
+ `True` if the input is valid, `False` otherwise.
152
+ """
153
+ if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
154
+ return True
155
+ elif is_valid_image(images):
156
+ return True
157
+ elif isinstance(images, list):
158
+ return all(is_valid_image(image) for image in images)
159
+ return False
160
+
161
+
162
+ def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
163
+ r"""
164
+ Normalize an image array to [-1,1].
165
+
166
+ Args:
167
+ images (`np.ndarray` or `torch.Tensor`):
168
+ The image array to normalize.
169
+
170
+ Returns:
171
+ `np.ndarray` or `torch.Tensor`:
172
+ The normalized image array.
173
+ """
174
+ return 2.0 * images - 1.0
175
+
176
+ # Copied from diffusers' image_processor.py (VaeImageProcessor.preprocess)
177
+ def preprocess(
178
+ image: PipelineImageInput,
179
+ height: Optional[int] = None,
180
+ width: Optional[int] = None,
181
+ resize_mode: str = "default", # "default", "fill", "crop"
182
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
183
+ ) -> torch.Tensor:
184
+ """
185
+ Preprocess the image input.
186
+
187
+ Args:
188
+ image (`PipelineImageInput`):
189
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
190
+ supported formats.
191
+ height (`int`, *optional*):
192
+ The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
193
+ height.
194
+ width (`int`, *optional*):
195
+ The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
196
+ resize_mode (`str`, *optional*, defaults to `default`):
197
+ The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
198
+ the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
199
+ resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
200
+ center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
201
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
202
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
203
+ supported for PIL image input.
204
+ crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
205
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
206
+
207
+ Returns:
208
+ `torch.Tensor`:
209
+ The preprocessed image.
210
+ """
211
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
212
+
213
+ # # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
214
+ # if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
215
+ # if isinstance(image, torch.Tensor):
216
+ # # if image is a pytorch tensor could have 2 possible shapes:
217
+ # # 1. batch x height x width: we should insert the channel dimension at position 1
218
+ # # 2. channel x height x width: we should insert batch dimension at position 0,
219
+ # # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
220
+ # # for simplicity, we insert a dimension of size 1 at position 1 for both cases
221
+ # image = image.unsqueeze(1)
222
+ # else:
223
+ # # if it is a numpy array, it could have 2 possible shapes:
224
+ # # 1. batch x height x width: insert channel dimension on last position
225
+ # # 2. height x width x channel: insert batch dimension on first position
226
+ # if image.shape[-1] == 1:
227
+ # image = np.expand_dims(image, axis=0)
228
+ # else:
229
+ # image = np.expand_dims(image, axis=-1)
230
+
231
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
232
+ warnings.warn(
233
+ "Passing `image` as a list of 4d np.ndarray is deprecated."
234
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
235
+ FutureWarning,
236
+ )
237
+ image = np.concatenate(image, axis=0)
238
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
239
+ warnings.warn(
240
+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
241
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
242
+ FutureWarning,
243
+ )
244
+ image = torch.cat(image, axis=0)
245
+
246
+ if not is_valid_image_imagelist(image):
247
+ raise ValueError(
248
+ f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
249
+ )
250
+ if not isinstance(image, list):
251
+ image = [image]
252
+
253
+ if isinstance(image[0], PIL.Image.Image):
254
+ if crops_coords is not None:
255
+ image = [i.crop(crops_coords) for i in image]
256
+ # if self.config.do_resize:
257
+ # height, width = self.get_default_height_width(image[0], height, width)
258
+ # image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
259
+ # if self.config.do_convert_rgb:
260
+ # image = [self.convert_to_rgb(i) for i in image]
261
+ # elif self.config.do_convert_grayscale:
262
+ # image = [self.convert_to_grayscale(i) for i in image]
263
+ image = pil_to_numpy(image) # to np
264
+ image = numpy_to_pt(image) # to pt
265
+
266
+ elif isinstance(image[0], np.ndarray):
267
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
268
+
269
+ # image = self.numpy_to_pt(image)
270
+
271
+ # height, width = self.get_default_height_width(image, height, width)
272
+ # if self.config.do_resize:
273
+ # image = self.resize(image, height, width)
274
+
275
+ elif isinstance(image[0], torch.Tensor):
276
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
277
+
278
+ # if self.config.do_convert_grayscale and image.ndim == 3:
279
+ # image = image.unsqueeze(1)
280
+
281
+ channel = image.shape[1]
282
+ # don't need any preprocess if the image is latents
283
+ # if channel == self.config.vae_latent_channels:
284
+ # return image
285
+
286
+ # height, width = self.get_default_height_width(image, height, width)
287
+ # if self.config.do_resize:
288
+ # image = self.resize(image, height, width)
289
+
290
+ # expected range [0,1], normalize to [-1,1]
291
+ do_normalize = True # self.config.do_normalize
292
+ if do_normalize and image.min() < 0:
293
+ warnings.warn(
294
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
295
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
296
+ FutureWarning,
297
+ )
298
+ do_normalize = False
299
+ if do_normalize:
300
+ image = normalize(image)
301
+
302
+ # if self.config.do_binarize:
303
+ # image = self.binarize(image)
304
+
305
+ return image
306
+ ##########
307
+
308
+
309
+ def get_args():
310
+ parser = argparse.ArgumentParser(
311
+ prog="StableDiffusion",
312
+ description="Generate picture with the input prompt"
313
+ )
314
+ parser.add_argument("--prompt", type=str, required=False, default="Astronauts in a jungle, cold color palette, muted colors, detailed, 8k", help="the input text prompt")
315
+ parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
316
+ parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.onnx", help="Path to unet ONNX model")
317
+ parser.add_argument("--vae_encoder_model", type=str, required=False, default="./models/vae_encoder.onnx", help="Path to vae encoder ONNX model")
318
+ parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.onnx", help="Path to vae decoder ONNX model")
319
+ parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_img2img.npy", help="Path to time input file")
320
+ parser.add_argument("--init_image", type=str, required=False, default="./models/img2img-init.png", help="Path to initial image file")
321
+ parser.add_argument("--save_dir", type=str, required=False, default="./img2img_output_onnx.png", help="Path to the output image file")
322
+ return parser.parse_args()
323
+
324
+ def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
325
+ if not isinstance(prompt, List):
326
+ prompts = [prompt]
327
+ else:
328
+ prompts = prompt
329
+
330
+ prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
331
+
332
+ if not isinstance(prompt, List):
333
+ return prompts[0]
334
+
335
+ return prompts
336
+
337
+
338
+ def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
339
+ tokens = tokenizer.tokenize(prompt)
340
+ unique_tokens = set(tokens)
341
+ for token in unique_tokens:
342
+ if token in tokenizer.added_tokens_encoder:
343
+ replacement = token
344
+ i = 1
345
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
346
+ replacement += f" {token}_{i}"
347
+ i += 1
348
+
349
+ prompt = prompt.replace(token, replacement)
350
+
351
+ return prompt
352
+
353
+
354
+ def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
355
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
356
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
357
+ torch_dtype=torch.float32,
358
+ variant="fp16")
359
+ text_inputs = tokenizer(
360
+ prompt,
361
+ padding="max_length",
362
+ max_length=77,
363
+ truncation=True,
364
+ return_tensors="pt",
365
+ )
366
+ text_input_ids = text_inputs.input_ids
367
+ prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
368
+
369
+ prompt_embeds_npy = prompt_embeds[0].detach().numpy()
370
+ return prompt_embeds_npy
371
+
372
+
373
+ def get_alphas_cumprod():
374
+ betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
375
+ alphas = 1.0 - betas
376
+ alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
377
+ final_alphas_cumprod = alphas_cumprod[0]
378
+ self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
379
+ return alphas_cumprod, final_alphas_cumprod, self_timesteps
380
+
381
+ def resize_and_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
382
+ """
383
+ Resize the image to 512x512 and convert it to RGB.
384
+ """
385
+ return image.resize((512, 512)).convert("RGB")
386
+
387
+
388
+ if __name__ == '__main__':
389
+
390
+ """
391
+ Usage:
392
+ - python3 run_img2img_onnx_infer.py --prompt "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" --unet_model output_onnx/unet_sim.onnx --vae_encoder_model output_onnx/vae_encoder_sim.onnx --vae_decoder_model output_onnx/vae_decoder_sim.onnx --time_input ./output_onnx/time_input.npy --save_dir ./img2img_output.png
393
+ """
394
+ args = get_args()
395
+ prompt = args.prompt
396
+ tokenizer_dir = args.text_model_dir + 'tokenizer'
397
+ text_encoder_dir = args.text_model_dir + 'text_encoder'
398
+ unet_model = args.unet_model
399
+ vae_decoder_model = args.vae_decoder_model
400
+ vae_encoder_model = args.vae_encoder_model
401
+ init_image = args.init_image
402
+ time_input = args.time_input
403
+ save_dir = args.save_dir
404
+
405
+ print(f"prompt: {prompt}")
406
+ print(f"text_tokenizer: {tokenizer_dir}")
407
+ print(f"text_encoder: {text_encoder_dir}")
408
+ print(f"unet_model: {unet_model}")
409
+ print(f"vae_encoder_model: {vae_encoder_model}")
410
+ print(f"vae_decoder_model: {vae_decoder_model}")
411
+ print(f"init image: {init_image}")
412
+ print(f"time_input: {time_input}")
413
+ print(f"save_dir: {save_dir}")
414
+
415
+ # timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
416
+
417
+ # text encoder
418
+ start = time.time()
419
+ # prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
420
+ # prompt = "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k"
421
+ # prompt = "Caricature, a beautiful girl with black hair, 8k"
422
+ prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
423
+ print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
424
+
425
+ prompt_name = prompt.replace(" ", "_")
426
+ latents_shape = [1, 4, 64, 64]
427
+ # latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
428
+ # layout=torch.strided).detach().numpy()
429
+
430
+ alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
431
+
432
+ # load unet model and vae model
433
+ start = time.time()
434
+ vae_encoder = onnxruntime.InferenceSession(vae_encoder_model)
435
+ unet_session_main = onnxruntime.InferenceSession(unet_model)
436
+ vae_decoder = onnxruntime.InferenceSession(vae_decoder_model)
437
+ print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
438
+
439
+ # load time input file
440
+ time_input = np.load(time_input)
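+ # As in run_img2img_axe_infer.py, this file appears to hold pre-computed UNet time embeddings that are fed
+ # straight into the graph input "/down_blocks.0/resnets.0/act_1/Mul_output_0" (assumption based on the input name).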
441
+
442
+ # load image
443
+ # url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
444
+ url = init_image
445
+ init_image = load_image(url, convert_method=resize_and_rgb) # U8, (512, 512, 3), RGB
446
+ init_image_show = init_image
447
+
448
+ # vae encoder inference
449
+ vae_start = time.time()
450
+
451
+ init_image = preprocess(init_image) # torch.Size([1, 3, 512, 512])
452
+ if isinstance(init_image, torch.Tensor):
453
+ init_image = init_image.detach().numpy()
454
+
455
+ vae_encoder_onnx_inp_name = vae_encoder.get_inputs()[0].name
456
+ vae_encoder_onnx_out_name = vae_encoder.get_outputs()[0].name
457
+
458
+ # vae_encoder_out.shape (1, 8, 64, 64)
459
+ vae_encoder_out = vae_encoder.run(None, {vae_encoder_onnx_inp_name: init_image})[0] # encoder out: torch.Size([1, 8, 64, 64])
460
+ print(f"vae encoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
461
+
462
+ # build the initial latents from the VAE encoder output
463
+ device = torch.device("cpu")
464
+ vae_encoder_out = torch.from_numpy(vae_encoder_out).to(torch.float32)
465
+ posterior = DiagonalGaussianDistribution(vae_encoder_out) # values closely match the diffusers reference
466
+ vae_encode_info = AutoencoderKLOutput(latent_dist=posterior)
467
+ generator = torch.manual_seed(0)
468
+ init_latents = retrieve_latents(vae_encode_info, generator=generator) # values closely match the diffusers reference
469
+ init_latents = init_latents * 0.18215 # scale by the SD1.5 VAE scaling factor; values closely match the reference
470
+ init_latents = torch.cat([init_latents], dim=0)
471
+ shape = init_latents.shape
472
+ dtype = torch.float16
473
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # a different dtype yields different random noise values
474
+ # get latents
475
+ timestep = torch.tensor([499]).to(device)
476
+ init_latents = add_noise(init_latents.to(device), noise, timestep)
477
+ latents = init_latents
478
+
479
+ latents = latents.detach().cpu().numpy()
480
+ latent = latents
481
+
482
+ # unet inference loop
483
+ unet_loop_start = time.time()
484
+ timesteps = np.array([499, 259]).astype(np.int64)
485
+ self_timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
486
+ step_index = [2, 3]
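+ # Only the last 2 of the 4 LCM timesteps are run for img2img, starting from t=499 (see run_img2img_axe_infer.py).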
487
+ for i, timestep in enumerate(timesteps):
488
+ unet_start = time.time()
489
+ noise_pred = unet_session_main.run(None, {"sample": latent, \
490
+ "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
491
+ "encoder_hidden_states": prompt_embeds_npy})[0]
492
+
493
+ print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
494
+
495
+ sample = latent
496
+ model_output = noise_pred
497
+
498
+ # 1. get previous step value
499
+ prev_step_index = step_index[i] + 1
500
+ if prev_step_index < len(self_timesteps):
501
+ prev_timestep = self_timesteps[prev_step_index]
502
+ else:
503
+ prev_timestep = timestep
504
+
505
+ alpha_prod_t = alphas_cumprod[timestep]
506
+ alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
507
+ beta_prod_t = 1 - alpha_prod_t
508
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
509
+
510
+ # 3. Get scalings for boundary conditions
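+ # LCM boundary-condition scalings with sigma_data = 0.5 and timestep_scaling = 10 (apparently the diffusers LCMScheduler defaults).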
511
+ scaled_timestep = timestep * 10
512
+ c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
513
+ c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
514
+ predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5) # values closely match the reference
515
+
516
+ denoised = c_out * predicted_original_sample + c_skip * sample
517
+ if step_index[i] != 3:
518
+ device = torch.device("cpu")
519
+ noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=torch.float16).numpy()
520
+ prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
521
+ else:
522
+ prev_sample = denoised
523
+
524
+ latent = prev_sample
525
+
526
+ print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
527
+
528
+ # vae decoder inference
529
+ vae_start = time.time()
530
+ latent = latent / 0.18215
531
+ image = vae_decoder.run(None, {"x": latent})[0] # ['784']
532
+ print(f"vae decoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
533
+
534
+ # save result
535
+ save_start = time.time()
536
+ image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
537
+ image_denorm = np.clip(image / 2 + 0.5, 0, 1)
538
+ image = (image_denorm * 255).round().astype("uint8")
539
+ pil_image = Image.fromarray(image[:, :, :3])
540
+ pil_image.save(save_dir)
541
+
542
+ grid_img = make_image_grid([init_image_show, pil_image], rows=1, cols=2)
543
+ grid_img.save(f"./lcm_lora_sdv1-5_imgGrid_output.png")
544
+
545
+ print(f"grid image saved in ./lcm_lora_sdv1-5_imgGrid_output.png")
546
+ print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
run_txt2img_axe_infer.py ADDED
@@ -0,0 +1,183 @@
+ from typing import List, Union
+ import numpy as np
+ # import onnxruntime
+ import axengine
+ import torch
+ from PIL import Image
+ from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
+
+ import time
+ import argparse
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(
+         prog="StableDiffusion",
+         description="Generate picture with the input prompt"
+     )
+     parser.add_argument("--prompt", type=str, required=False, default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", help="the input text prompt")
+     parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
+     parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel", help="Path to unet axmodel model")
+     parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel", help="Path to vae decoder axmodel model")
+     parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_txt2img.npy", help="Path to time input file")
+     parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_axe.png", help="Path to the output image file")
+     return parser.parse_args()
+
+ def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
+     if not isinstance(prompt, List):
+         prompts = [prompt]
+     else:
+         prompts = prompt
+
+     prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+     if not isinstance(prompt, List):
+         return prompts[0]
+
+     return prompts
+
+
+ def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
+     tokens = tokenizer.tokenize(prompt)
+     unique_tokens = set(tokens)
+     for token in unique_tokens:
+         if token in tokenizer.added_tokens_encoder:
+             replacement = token
+             i = 1
+             while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                 replacement += f" {token}_{i}"
+                 i += 1
+
+             prompt = prompt.replace(token, replacement)
+
+     return prompt
+
+
+ def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
+     tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
+     text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
+                                                  torch_dtype=torch.float32,
+                                                  variant="fp16")
+     text_inputs = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="pt",
+     )
+     text_input_ids = text_inputs.input_ids
+     prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
+
+     prompt_embeds_npy = prompt_embeds[0].detach().numpy()
+     return prompt_embeds_npy
+
+
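+ # Rebuild the SD1.5 scaled-linear beta schedule (beta from 0.00085 to 0.012 over 1000 steps)
+ # and its cumulative alpha products, so no diffusers scheduler object is needed at runtime.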
+ def get_alphas_cumprod():
+     betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
+     alphas = 1.0 - betas
+     alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
+     final_alphas_cumprod = alphas_cumprod[0]
+     self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
+     return alphas_cumprod, final_alphas_cumprod, self_timesteps
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     prompt = args.prompt
+     tokenizer_dir = args.text_model_dir + 'tokenizer'
+     text_encoder_dir = args.text_model_dir + 'text_encoder'
+     unet_model = args.unet_model
+     vae_decoder_model = args.vae_decoder_model
+     time_input = args.time_input
+     save_dir = args.save_dir
+
+     print(f"prompt: {prompt}")
+     print(f"text_tokenizer: {tokenizer_dir}")
+     print(f"text_encoder: {text_encoder_dir}")
+     print(f"unet_model: {unet_model}")
+     print(f"vae_decoder_model: {vae_decoder_model}")
+     print(f"time_input: {time_input}")
+     print(f"save_dir: {save_dir}")
+
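+     # LCM sampling uses only 4 denoising steps; these fixed timesteps correspond to the
+     # precomputed per-step time inputs in time_input_txt2img.npy (indexed as time_input[i] below).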
+     timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
+
+     # text encoder
+     start = time.time()
+     # prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+     prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
+     print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
+
+     prompt_name = prompt.replace(" ", "_")
+     latents_shape = [1, 4, 64, 64]
+     latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
+                          layout=torch.strided).detach().numpy()
+
+     alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
+
+     # load unet model and vae model
+     start = time.time()
+     unet_session_main = axengine.InferenceSession(unet_model)
+     vae_decoder = axengine.InferenceSession(vae_decoder_model)
+     print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
+
+     # load time input file
+     time_input = np.load(time_input)
+
+     # unet inference loop
+     unet_loop_start = time.time()
+     for i, timestep in enumerate(timesteps):
+         # print(i, timestep)
+
+         unet_start = time.time()
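+         # The per-step time input is precomputed offline and fed to an internal tensor of the
+         # exported UNet graph rather than passing a raw timestep value.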
+         noise_pred = unet_session_main.run(None, {"sample": latent, \
+             "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
+             "encoder_hidden_states": prompt_embeds_npy})[0]
+
+         print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
+
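+         # Inlined LCM scheduler step: predict x0 from the noise prediction, apply the
+         # consistency boundary conditions (c_skip / c_out), then re-noise on all but the last step.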
+         sample = latent
+         model_output = noise_pred
+         if i < 3:
+             prev_timestep = timesteps[i + 1]
+         else:
+             prev_timestep = timestep
+
+         alpha_prod_t = alphas_cumprod[timestep]
+         alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
+
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         # 3. Get scalings for boundary conditions
+         scaled_timestep = timestep * 10
+         c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
+         c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
+         predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)
+
+         denoised = c_out * predicted_original_sample + c_skip * sample
+
+         if i != 3:
+             noise = torch.randn(model_output.shape, generator=None, device="cpu", dtype=torch.float32,
+                                 layout=torch.strided).to("cpu").detach().numpy()
+             prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
+         else:
+             prev_sample = denoised
+
+         latent = prev_sample
+
+     print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
+
+     # vae inference
+     vae_start = time.time()
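+     # Undo the SD latent scaling factor (0.18215) before decoding with the VAE.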
+     latent = latent / 0.18215
+     image = vae_decoder.run(None, {"x": latent})[0]
+     print(f"vae inference take {(1000 * (time.time() - vae_start)):.1f}ms")
+
+     # save result
+     save_start = time.time()
+     image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
+     image_denorm = np.clip(image / 2 + 0.5, 0, 1)
+     image = (image_denorm * 255).round().astype("uint8")
+     pil_image = Image.fromarray(image[:, :, :3])
+     pil_image.save(save_dir)
+     print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
run_txt2img_onnx_infer.py ADDED
@@ -0,0 +1,183 @@
+ from typing import List, Union
+ import numpy as np
+ import onnxruntime
+ # import axengine
+ import torch
+ from PIL import Image
+ from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
+
+ import time
+ import argparse
+
+
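+ # ONNX Runtime counterpart of run_txt2img_axe_infer.py: the sampling logic is identical,
+ # only the inference backend and the model files (.onnx) differ.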
+ def get_args():
+     parser = argparse.ArgumentParser(
+         prog="StableDiffusion",
+         description="Generate picture with the input prompt"
+     )
+     parser.add_argument("--prompt", type=str, required=False, default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", help="the input text prompt")
+     parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
+     parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.onnx", help="Path to unet ONNX model")
+     parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.onnx", help="Path to vae decoder ONNX model")
+     parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_txt2img.npy", help="Path to time input file")
+     parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_onnx.png", help="Path to the output image file")
+     return parser.parse_args()
+
+ def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
+     if not isinstance(prompt, List):
+         prompts = [prompt]
+     else:
+         prompts = prompt
+
+     prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+     if not isinstance(prompt, List):
+         return prompts[0]
+
+     return prompts
+
+
+ def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
+     tokens = tokenizer.tokenize(prompt)
+     unique_tokens = set(tokens)
+     for token in unique_tokens:
+         if token in tokenizer.added_tokens_encoder:
+             replacement = token
+             i = 1
+             while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                 replacement += f" {token}_{i}"
+                 i += 1
+
+             prompt = prompt.replace(token, replacement)
+
+     return prompt
+
+
+ def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
+     tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
+     text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
+                                                  torch_dtype=torch.float32,
+                                                  variant="fp16")
+     text_inputs = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="pt",
+     )
+     text_input_ids = text_inputs.input_ids
+     prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
+
+     prompt_embeds_npy = prompt_embeds[0].detach().numpy()
+     return prompt_embeds_npy
+
+
+ def get_alphas_cumprod():
+     betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
+     alphas = 1.0 - betas
+     alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
+     final_alphas_cumprod = alphas_cumprod[0]
+     self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
+     return alphas_cumprod, final_alphas_cumprod, self_timesteps
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     prompt = args.prompt
+     tokenizer_dir = args.text_model_dir + 'tokenizer'
+     text_encoder_dir = args.text_model_dir + 'text_encoder'
+     unet_model = args.unet_model
+     vae_decoder_model = args.vae_decoder_model
+     time_input = args.time_input
+     save_dir = args.save_dir
+
+     print(f"prompt: {prompt}")
+     print(f"text_tokenizer: {tokenizer_dir}")
+     print(f"text_encoder: {text_encoder_dir}")
+     print(f"unet_model: {unet_model}")
+     print(f"vae_decoder_model: {vae_decoder_model}")
+     print(f"time_input: {time_input}")
+     print(f"save_dir: {save_dir}")
+
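+     # Same fixed 4-step LCM schedule as the axmodel script; it must match time_input_txt2img.npy.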
+     timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
+
+     # text encoder
+     start = time.time()
+     # prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+     prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
+     print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
+
+     prompt_name = prompt.replace(" ", "_")
+     latents_shape = [1, 4, 64, 64]
+     latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
+                          layout=torch.strided).detach().numpy()
+
+     alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
+
+     # load unet model and vae model
+     start = time.time()
+     unet_session_main = onnxruntime.InferenceSession(unet_model)
+     vae_decoder = onnxruntime.InferenceSession(vae_decoder_model)
+     print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
+
+     # load time input file
+     time_input = np.load(time_input)
+
+     # unet inference loop
+     unet_loop_start = time.time()
+     for i, timestep in enumerate(timesteps):
+         # print(i, timestep)
+
+         unet_start = time.time()
+         noise_pred = unet_session_main.run(None, {"sample": latent, \
+             "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
+             "encoder_hidden_states": prompt_embeds_npy})[0]
+
+         print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
+
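+         # Inlined LCM scheduler step, identical to run_txt2img_axe_infer.py.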
+         sample = latent
+         model_output = noise_pred
+         if i < 3:
+             prev_timestep = timesteps[i + 1]
+         else:
+             prev_timestep = timestep
+
+         alpha_prod_t = alphas_cumprod[timestep]
+         alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
+
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         # 3. Get scalings for boundary conditions
+         scaled_timestep = timestep * 10
+         c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
+         c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
+         predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)
+
+         denoised = c_out * predicted_original_sample + c_skip * sample
+
+         if i != 3:
+             noise = torch.randn(model_output.shape, generator=None, device="cpu", dtype=torch.float32,
+                                 layout=torch.strided).to("cpu").detach().numpy()
+             prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
+         else:
+             prev_sample = denoised
+
+         latent = prev_sample
+
+     print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
+
+     # vae inference
+     vae_start = time.time()
+     latent = latent / 0.18215
+     image = vae_decoder.run(None, {"x": latent})[0]
+     print(f"vae inference take {(1000 * (time.time() - vae_start)):.1f}ms")
+
+     # save result
+     save_start = time.time()
+     image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
+     image_denorm = np.clip(image / 2 + 0.5, 0, 1)
+     image = (image_denorm * 255).round().astype("uint8")
+     pil_image = Image.fromarray(image[:, :, :3])
+     pil_image.save(save_dir)
+     print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")