Upload 13 files
Browse files- .gitattributes +5 -0
- Disclaimer.md +6 -0
- LICENSE +28 -0
- README.md +7 -16
- asserts/img2img-init.png +3 -0
- asserts/img2img_output_axe.png +3 -0
- asserts/lcm_lora_sdv1-5_imgGrid_output.png +3 -0
- asserts/lcm_lora_sdv1_5_axmodel.png +3 -0
- asserts/txt2img_output_axe.png +3 -0
- config.json.txt +0 -0
- run_img2img_axe_infer.py +546 -0
- run_img2img_onnx_infer.py +546 -0
- run_txt2img_axe_infer.py +183 -0
- run_txt2img_onnx_infer.py +183 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
asserts/img2img_output_axe.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
asserts/img2img-init.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
asserts/lcm_lora_sdv1_5_axmodel.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
asserts/lcm_lora_sdv1-5_imgGrid_output.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
asserts/txt2img_output_axe.png filter=lfs diff=lfs merge=lfs -text
|
Disclaimer.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 免责声明
|
2 |
+
所有由 SD1.5-LCM.Axera 导出的模型均存在固有的局限性,可能产生错误的、有害的、冒犯性的或其他不良的输出。用户在关键或高风险场景中应谨慎行事,不要使用这些模型,以免导致人身伤害、财产损失或重大损失。此类场景的例子包括但不限于医疗领域、可能导致伤害的软硬件系统的控制以及进行重要的财务或法律决策。
|
3 |
+
|
4 |
+
SD1.5-LCM.Axera 按对于开源项目“原样”提供,不附带任何种类的明示或暗示的保证,包括但不限于适销性、特定目的的适用性和非侵权的暗示保证。在任何情况下,作者、贡献者或版权所有者均不对因软件或使用或其他软件交易而产生的任何索赔、损害赔偿或其他责任(无论是合同、侵权还是其他原因)承担责任。
|
5 |
+
|
6 |
+
使用 SD1.5-LCM.Axera 即表示您同意这些条款和条件,并承认您了解其使用可能带来的潜在风险。您还同意赔偿并使作者、贡献者和版权所有者免受因您使用 SD1.5-LCM.Axera 而产生的任何索赔、损害赔偿或责任的影响。
|
LICENSE
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2024, BUG1989
|
4 |
+
|
5 |
+
Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
list of conditions and the following disclaimer.
|
10 |
+
|
11 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
this list of conditions and the following disclaimer in the documentation
|
13 |
+
and/or other materials provided with the distribution.
|
14 |
+
|
15 |
+
3. Neither the name of the copyright holder nor the names of its
|
16 |
+
contributors may be used to endorse or promote products derived from
|
17 |
+
this software without specific prior written permission.
|
18 |
+
|
19 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
CHANGED
@@ -1,23 +1,14 @@
|
|
1 |
-
|
2 |
-
license: bsd-3-clause
|
3 |
-
base_model:
|
4 |
-
- latent-consistency/lcm-lora-sdv1-5
|
5 |
-
- Lykon/dreamshaper-7
|
6 |
-
tags:
|
7 |
-
- RaspberryPi5
|
8 |
-
- StableDiffusion1.5
|
9 |
-
---
|
10 |
-
|
11 |
-
# SD1.5-LCM
|
12 |
|
13 |
基于 StableDiffusion 1.5 LCM 项目,展示该项目 **文生图**、**图生图** 在基于 AX650N 的产品上部署的流程。
|
14 |
|
15 |
-
|
|
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
|
22 |
原始模型请参考
|
23 |
- [Latent Consistency Model (LCM) LoRA: SDv1-5](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5)
|
|
|
1 |
+
# SD1.5-LCM.Axera
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
基于 StableDiffusion 1.5 LCM 项目,展示该项目 **文生图**、**图生图** 在基于 AX650N 的产品上部署的流程。
|
4 |
|
5 |
+
支持芯片:
|
6 |
+
- AX650N
|
7 |
|
8 |
+
支持硬件
|
9 |
+
|
10 |
+
- 爱芯派Pro
|
11 |
+
- M.2 算力卡
|
12 |
|
13 |
原始模型请参考
|
14 |
- [Latent Consistency Model (LCM) LoRA: SDv1-5](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5)
|
asserts/img2img-init.png
ADDED
![]() |
Git LFS Details
|
asserts/img2img_output_axe.png
ADDED
![]() |
Git LFS Details
|
asserts/lcm_lora_sdv1-5_imgGrid_output.png
ADDED
![]() |
Git LFS Details
|
asserts/lcm_lora_sdv1_5_axmodel.png
ADDED
![]() |
Git LFS Details
|
asserts/txt2img_output_axe.png
ADDED
![]() |
Git LFS Details
|
config.json.txt
ADDED
File without changes
|
run_img2img_axe_infer.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
import numpy as np
|
3 |
+
# import onnxruntime
|
4 |
+
import axengine
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
|
8 |
+
|
9 |
+
import time
|
10 |
+
import argparse
|
11 |
+
from diffusers.utils import load_image
|
12 |
+
import PIL.Image
|
13 |
+
from typing import List, Optional, Tuple, Union
|
14 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
15 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
17 |
+
from diffusers.utils import make_image_grid, load_image
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
########## Img2Img
|
22 |
+
PipelineImageInput = Union[
|
23 |
+
PIL.Image.Image,
|
24 |
+
np.ndarray,
|
25 |
+
torch.Tensor,
|
26 |
+
List[PIL.Image.Image],
|
27 |
+
List[np.ndarray],
|
28 |
+
List[torch.Tensor],
|
29 |
+
]
|
30 |
+
|
31 |
+
PipelineDepthInput = PipelineImageInput
|
32 |
+
|
33 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
34 |
+
def add_noise(
|
35 |
+
original_samples: torch.Tensor,
|
36 |
+
noise: torch.Tensor,
|
37 |
+
timesteps: torch.IntTensor,
|
38 |
+
) -> torch.Tensor:
|
39 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
40 |
+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
41 |
+
# for the subsequent add_noise calls
|
42 |
+
# self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
|
43 |
+
# Convert betas to alphas_bar_sqrt
|
44 |
+
beta_start = 0.00085
|
45 |
+
beta_end = 0.012
|
46 |
+
num_train_timesteps = 1000
|
47 |
+
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
48 |
+
alphas = 1.0 - betas
|
49 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
50 |
+
alphas_cumprod = alphas_cumprod.to(device=original_samples.device)
|
51 |
+
alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype)
|
52 |
+
timesteps = timesteps.to(original_samples.device)
|
53 |
+
|
54 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
55 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
56 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
57 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
58 |
+
|
59 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
60 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
61 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
62 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
63 |
+
|
64 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
65 |
+
return noisy_samples
|
66 |
+
|
67 |
+
def retrieve_latents(
|
68 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
69 |
+
):
|
70 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
71 |
+
return encoder_output.latent_dist.sample(generator)
|
72 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
73 |
+
return encoder_output.latent_dist.mode()
|
74 |
+
elif hasattr(encoder_output, "latents"):
|
75 |
+
return encoder_output.latents
|
76 |
+
else:
|
77 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
78 |
+
|
79 |
+
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
|
80 |
+
r"""
|
81 |
+
Convert a NumPy image to a PyTorch tensor.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
images (`np.ndarray`):
|
85 |
+
The NumPy image array to convert to PyTorch format.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
`torch.Tensor`:
|
89 |
+
A PyTorch tensor representation of the images.
|
90 |
+
"""
|
91 |
+
if images.ndim == 3:
|
92 |
+
images = images[..., None]
|
93 |
+
|
94 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
95 |
+
return images
|
96 |
+
|
97 |
+
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
98 |
+
r"""
|
99 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
|
103 |
+
The PIL image or list of images to convert to NumPy format.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
`np.ndarray`:
|
107 |
+
A NumPy array representation of the images.
|
108 |
+
"""
|
109 |
+
if not isinstance(images, list):
|
110 |
+
images = [images]
|
111 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
112 |
+
images = np.stack(images, axis=0)
|
113 |
+
|
114 |
+
return images
|
115 |
+
|
116 |
+
def is_valid_image(image) -> bool:
|
117 |
+
r"""
|
118 |
+
Checks if the input is a valid image.
|
119 |
+
|
120 |
+
A valid image can be:
|
121 |
+
- A `PIL.Image.Image`.
|
122 |
+
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
|
123 |
+
|
124 |
+
Args:
|
125 |
+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
126 |
+
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
`bool`:
|
130 |
+
`True` if the input is a valid image, `False` otherwise.
|
131 |
+
"""
|
132 |
+
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
|
133 |
+
|
134 |
+
def is_valid_image_imagelist(images):
|
135 |
+
r"""
|
136 |
+
Checks if the input is a valid image or list of images.
|
137 |
+
|
138 |
+
The input can be one of the following formats:
|
139 |
+
- A 4D tensor or numpy array (batch of images).
|
140 |
+
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
|
141 |
+
`torch.Tensor`.
|
142 |
+
- A list of valid images.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
|
146 |
+
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
|
147 |
+
images.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
`bool`:
|
151 |
+
`True` if the input is valid, `False` otherwise.
|
152 |
+
"""
|
153 |
+
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
|
154 |
+
return True
|
155 |
+
elif is_valid_image(images):
|
156 |
+
return True
|
157 |
+
elif isinstance(images, list):
|
158 |
+
return all(is_valid_image(image) for image in images)
|
159 |
+
return False
|
160 |
+
|
161 |
+
|
162 |
+
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
163 |
+
r"""
|
164 |
+
Normalize an image array to [-1,1].
|
165 |
+
|
166 |
+
Args:
|
167 |
+
images (`np.ndarray` or `torch.Tensor`):
|
168 |
+
The image array to normalize.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
`np.ndarray` or `torch.Tensor`:
|
172 |
+
The normalized image array.
|
173 |
+
"""
|
174 |
+
return 2.0 * images - 1.0
|
175 |
+
|
176 |
+
# Copy from: /home/baiyongqiang/miniforge-pypy3/envs/hf/lib/python3.9/site-packages/diffusers/image_processor.py#607
|
177 |
+
def preprocess(
|
178 |
+
image: PipelineImageInput,
|
179 |
+
height: Optional[int] = None,
|
180 |
+
width: Optional[int] = None,
|
181 |
+
resize_mode: str = "default", # "default", "fill", "crop"
|
182 |
+
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
183 |
+
) -> torch.Tensor:
|
184 |
+
"""
|
185 |
+
Preprocess the image input.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
image (`PipelineImageInput`):
|
189 |
+
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
190 |
+
supported formats.
|
191 |
+
height (`int`, *optional*):
|
192 |
+
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
|
193 |
+
height.
|
194 |
+
width (`int`, *optional*):
|
195 |
+
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
196 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
197 |
+
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
|
198 |
+
the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
|
199 |
+
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
200 |
+
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
|
201 |
+
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
202 |
+
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
203 |
+
supported for PIL image input.
|
204 |
+
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
205 |
+
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
`torch.Tensor`:
|
209 |
+
The preprocessed image.
|
210 |
+
"""
|
211 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
212 |
+
|
213 |
+
# # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
214 |
+
# if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
215 |
+
# if isinstance(image, torch.Tensor):
|
216 |
+
# # if image is a pytorch tensor could have 2 possible shapes:
|
217 |
+
# # 1. batch x height x width: we should insert the channel dimension at position 1
|
218 |
+
# # 2. channel x height x width: we should insert batch dimension at position 0,
|
219 |
+
# # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
220 |
+
# # for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
221 |
+
# image = image.unsqueeze(1)
|
222 |
+
# else:
|
223 |
+
# # if it is a numpy array, it could have 2 possible shapes:
|
224 |
+
# # 1. batch x height x width: insert channel dimension on last position
|
225 |
+
# # 2. height x width x channel: insert batch dimension on first position
|
226 |
+
# if image.shape[-1] == 1:
|
227 |
+
# image = np.expand_dims(image, axis=0)
|
228 |
+
# else:
|
229 |
+
# image = np.expand_dims(image, axis=-1)
|
230 |
+
|
231 |
+
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
232 |
+
warnings.warn(
|
233 |
+
"Passing `image` as a list of 4d np.ndarray is deprecated."
|
234 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
235 |
+
FutureWarning,
|
236 |
+
)
|
237 |
+
image = np.concatenate(image, axis=0)
|
238 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
239 |
+
warnings.warn(
|
240 |
+
"Passing `image` as a list of 4d torch.Tensor is deprecated."
|
241 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
242 |
+
FutureWarning,
|
243 |
+
)
|
244 |
+
image = torch.cat(image, axis=0)
|
245 |
+
|
246 |
+
if not is_valid_image_imagelist(image):
|
247 |
+
raise ValueError(
|
248 |
+
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
|
249 |
+
)
|
250 |
+
if not isinstance(image, list):
|
251 |
+
image = [image]
|
252 |
+
|
253 |
+
if isinstance(image[0], PIL.Image.Image):
|
254 |
+
if crops_coords is not None:
|
255 |
+
image = [i.crop(crops_coords) for i in image]
|
256 |
+
# if self.config.do_resize:
|
257 |
+
# height, width = self.get_default_height_width(image[0], height, width)
|
258 |
+
# image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
259 |
+
# if self.config.do_convert_rgb:
|
260 |
+
# image = [self.convert_to_rgb(i) for i in image]
|
261 |
+
# elif self.config.do_convert_grayscale:
|
262 |
+
# image = [self.convert_to_grayscale(i) for i in image]
|
263 |
+
image = pil_to_numpy(image) # to np
|
264 |
+
image = numpy_to_pt(image) # to pt
|
265 |
+
|
266 |
+
elif isinstance(image[0], np.ndarray):
|
267 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
268 |
+
|
269 |
+
# image = self.numpy_to_pt(image)
|
270 |
+
|
271 |
+
# height, width = self.get_default_height_width(image, height, width)
|
272 |
+
# if self.config.do_resize:
|
273 |
+
# image = self.resize(image, height, width)
|
274 |
+
|
275 |
+
elif isinstance(image[0], torch.Tensor):
|
276 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
277 |
+
|
278 |
+
# if self.config.do_convert_grayscale and image.ndim == 3:
|
279 |
+
# image = image.unsqueeze(1)
|
280 |
+
|
281 |
+
channel = image.shape[1]
|
282 |
+
# don't need any preprocess if the image is latents
|
283 |
+
# if channel == self.config.vae_latent_channels:
|
284 |
+
# return image
|
285 |
+
|
286 |
+
# height, width = self.get_default_height_width(image, height, width)
|
287 |
+
# if self.config.do_resize:
|
288 |
+
# image = self.resize(image, height, width)
|
289 |
+
|
290 |
+
# expected range [0,1], normalize to [-1,1]
|
291 |
+
do_normalize = True # self.config.do_normalize
|
292 |
+
if do_normalize and image.min() < 0:
|
293 |
+
warnings.warn(
|
294 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
295 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
296 |
+
FutureWarning,
|
297 |
+
)
|
298 |
+
do_normalize = False
|
299 |
+
if do_normalize:
|
300 |
+
image = normalize(image)
|
301 |
+
|
302 |
+
# if self.config.do_binarize:
|
303 |
+
# image = self.binarize(image)
|
304 |
+
|
305 |
+
return image
|
306 |
+
##########
|
307 |
+
|
308 |
+
|
309 |
+
def get_args():
|
310 |
+
parser = argparse.ArgumentParser(
|
311 |
+
prog="StableDiffusion",
|
312 |
+
description="Generate picture with the input prompt"
|
313 |
+
)
|
314 |
+
parser.add_argument("--prompt", type=str, required=False, default="Astronauts in a jungle, cold color palette, muted colors, detailed, 8k", help="the input text prompt")
|
315 |
+
parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
|
316 |
+
parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel", help="Path to unet axmodel model")
|
317 |
+
parser.add_argument("--vae_encoder_model", type=str, required=False, default="./models/vae_encoder.axmodel", help="Path to vae encoder axmodel model")
|
318 |
+
parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel", help="Path to vae decoder axmodel model")
|
319 |
+
parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_img2img.npy", help="Path to time input file")
|
320 |
+
parser.add_argument("--init_image", type=str, required=False, default="./models/img2img-init.png", help="Path to initial image file")
|
321 |
+
parser.add_argument("--save_dir", type=str, required=False, default="./img2img_output_axe.png", help="Path to the output image file")
|
322 |
+
return parser.parse_args()
|
323 |
+
|
324 |
+
def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
325 |
+
if not isinstance(prompt, List):
|
326 |
+
prompts = [prompt]
|
327 |
+
else:
|
328 |
+
prompts = prompt
|
329 |
+
|
330 |
+
prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
|
331 |
+
|
332 |
+
if not isinstance(prompt, List):
|
333 |
+
return prompts[0]
|
334 |
+
|
335 |
+
return prompts
|
336 |
+
|
337 |
+
|
338 |
+
def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
339 |
+
tokens = tokenizer.tokenize(prompt)
|
340 |
+
unique_tokens = set(tokens)
|
341 |
+
for token in unique_tokens:
|
342 |
+
if token in tokenizer.added_tokens_encoder:
|
343 |
+
replacement = token
|
344 |
+
i = 1
|
345 |
+
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
346 |
+
replacement += f" {token}_{i}"
|
347 |
+
i += 1
|
348 |
+
|
349 |
+
prompt = prompt.replace(token, replacement)
|
350 |
+
|
351 |
+
return prompt
|
352 |
+
|
353 |
+
|
354 |
+
def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
|
355 |
+
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
|
356 |
+
text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
|
357 |
+
torch_dtype=torch.float32,
|
358 |
+
variant="fp16")
|
359 |
+
text_inputs = tokenizer(
|
360 |
+
prompt,
|
361 |
+
padding="max_length",
|
362 |
+
max_length=77,
|
363 |
+
truncation=True,
|
364 |
+
return_tensors="pt",
|
365 |
+
)
|
366 |
+
text_input_ids = text_inputs.input_ids
|
367 |
+
prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
|
368 |
+
|
369 |
+
prompt_embeds_npy = prompt_embeds[0].detach().numpy()
|
370 |
+
return prompt_embeds_npy
|
371 |
+
|
372 |
+
|
373 |
+
def get_alphas_cumprod():
|
374 |
+
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
|
375 |
+
alphas = 1.0 - betas
|
376 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
|
377 |
+
final_alphas_cumprod = alphas_cumprod[0]
|
378 |
+
self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
|
379 |
+
return alphas_cumprod, final_alphas_cumprod, self_timesteps
|
380 |
+
|
381 |
+
def resize_and_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
382 |
+
"""
|
383 |
+
Resize the image to 512x512 and convert it to RGB.
|
384 |
+
"""
|
385 |
+
return image.resize((512, 512)).convert("RGB")
|
386 |
+
|
387 |
+
|
388 |
+
if __name__ == '__main__':
|
389 |
+
|
390 |
+
"""
|
391 |
+
Usage:
|
392 |
+
- python3 run_img2img_axmodel_infer.py --prompt "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" --unet_model output_onnx/unet_sim.onnx --vae_encoder_model output_onnx/vae_encoder_sim.onnx --vae_decoder_model output_onnx/vae_decoder_sim.onnx --time_input ./output_onnx/time_input.npy --save_dir ./img2img_output.png
|
393 |
+
"""
|
394 |
+
args = get_args()
|
395 |
+
prompt = args.prompt
|
396 |
+
tokenizer_dir = args.text_model_dir + 'tokenizer'
|
397 |
+
text_encoder_dir = args.text_model_dir + 'text_encoder'
|
398 |
+
unet_model = args.unet_model
|
399 |
+
vae_decoder_model = args.vae_decoder_model
|
400 |
+
vae_encoder_model = args.vae_encoder_model
|
401 |
+
init_image = args.init_image
|
402 |
+
time_input = args.time_input
|
403 |
+
save_dir = args.save_dir
|
404 |
+
|
405 |
+
print(f"prompt: {prompt}")
|
406 |
+
print(f"text_tokenizer: {tokenizer_dir}")
|
407 |
+
print(f"text_encoder: {text_encoder_dir}")
|
408 |
+
print(f"unet_model: {unet_model}")
|
409 |
+
print(f"vae_encoder_model: {vae_encoder_model}")
|
410 |
+
print(f"vae_decoder_model: {vae_decoder_model}")
|
411 |
+
print(f"init image: {init_image}")
|
412 |
+
print(f"time_input: {time_input}")
|
413 |
+
print(f"save_dir: {save_dir}")
|
414 |
+
|
415 |
+
# timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
|
416 |
+
|
417 |
+
# text encoder
|
418 |
+
start = time.time()
|
419 |
+
# prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
420 |
+
# prompt = "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k"
|
421 |
+
# prompt = "Caricature, a beautiful girl with black hair, 8k"
|
422 |
+
prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
|
423 |
+
print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
|
424 |
+
|
425 |
+
prompt_name = prompt.replace(" ", "_")
|
426 |
+
latents_shape = [1, 4, 64, 64]
|
427 |
+
# latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
|
428 |
+
# layout=torch.strided).detach().numpy()
|
429 |
+
|
430 |
+
alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
|
431 |
+
|
432 |
+
# load unet model and vae model
|
433 |
+
start = time.time()
|
434 |
+
vae_encoder = axengine.InferenceSession(vae_encoder_model)
|
435 |
+
unet_session_main = axengine.InferenceSession(unet_model)
|
436 |
+
vae_decoder = axengine.InferenceSession(vae_decoder_model)
|
437 |
+
print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
|
438 |
+
|
439 |
+
# load time input file
|
440 |
+
time_input = np.load(time_input)
|
441 |
+
|
442 |
+
# load image
|
443 |
+
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
444 |
+
url = init_image
|
445 |
+
init_image = load_image(url, convert_method=resize_and_rgb) # U8, (512, 512, 3), RGB
|
446 |
+
init_image_show = init_image
|
447 |
+
|
448 |
+
# vae encoder inference
|
449 |
+
vae_start = time.time()
|
450 |
+
|
451 |
+
init_image = preprocess(init_image) # torch.Size([1, 3, 512, 512])
|
452 |
+
if isinstance(init_image, torch.Tensor):
|
453 |
+
init_image = init_image.detach().numpy()
|
454 |
+
|
455 |
+
vae_encoder_onnx_inp_name = vae_encoder.get_inputs()[0].name
|
456 |
+
vae_encoder_onnx_out_name = vae_encoder.get_outputs()[0].name
|
457 |
+
|
458 |
+
# vae_encoder_out.shape (1, 8, 64, 64)
|
459 |
+
vae_encoder_out = vae_encoder.run(None, {vae_encoder_onnx_inp_name: init_image})[0] # encoder out: torch.Size([1, 8, 64, 64])
|
460 |
+
print(f"vae encoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
|
461 |
+
|
462 |
+
# vae encoder inference
|
463 |
+
device = torch.device("cpu")
|
464 |
+
vae_encoder_out = torch.from_numpy(vae_encoder_out).to(torch.float32)
|
465 |
+
posterior = DiagonalGaussianDistribution(vae_encoder_out) # 数值基本对的上
|
466 |
+
vae_encode_info = AutoencoderKLOutput(latent_dist=posterior)
|
467 |
+
generator = torch.manual_seed(0)
|
468 |
+
init_latents = retrieve_latents(vae_encode_info, generator=generator) # 数值基本对的上
|
469 |
+
init_latents = init_latents * 0.18215 # 数值基本对的��
|
470 |
+
init_latents = torch.cat([init_latents], dim=0)
|
471 |
+
shape = init_latents.shape
|
472 |
+
dtype = torch.float16
|
473 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # dtype 不同, 随机值不同
|
474 |
+
# get latents
|
475 |
+
timestep = torch.tensor([499]).to(device)
|
476 |
+
init_latents = add_noise(init_latents.to(device), noise, timestep)
|
477 |
+
latents = init_latents
|
478 |
+
|
479 |
+
latents = latents.detach().cpu().numpy()
|
480 |
+
latent = latents
|
481 |
+
|
482 |
+
# unet inference loop
|
483 |
+
unet_loop_start = time.time()
|
484 |
+
timesteps = np.array([499, 259]).astype(np.int64)
|
485 |
+
self_timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
|
486 |
+
step_index = [2, 3]
|
487 |
+
for i, timestep in enumerate(timesteps):
|
488 |
+
unet_start = time.time()
|
489 |
+
noise_pred = unet_session_main.run(None, {"sample": latent, \
|
490 |
+
"/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
|
491 |
+
"encoder_hidden_states": prompt_embeds_npy})[0]
|
492 |
+
|
493 |
+
print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
|
494 |
+
|
495 |
+
sample = latent
|
496 |
+
model_output = noise_pred
|
497 |
+
|
498 |
+
# 1. get previous step value
|
499 |
+
prev_step_index = step_index[i] + 1
|
500 |
+
if prev_step_index < len(self_timesteps):
|
501 |
+
prev_timestep = self_timesteps[prev_step_index]
|
502 |
+
else:
|
503 |
+
prev_timestep = timestep
|
504 |
+
|
505 |
+
alpha_prod_t = alphas_cumprod[timestep]
|
506 |
+
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
|
507 |
+
beta_prod_t = 1 - alpha_prod_t
|
508 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
509 |
+
|
510 |
+
# 3. Get scalings for boundary conditions
|
511 |
+
scaled_timestep = timestep * 10
|
512 |
+
c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
|
513 |
+
c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
|
514 |
+
predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5) # 数值基本对齐
|
515 |
+
|
516 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
517 |
+
if step_index[i] != 3:
|
518 |
+
device = torch.device("cpu")
|
519 |
+
noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=torch.float16).numpy()
|
520 |
+
prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
|
521 |
+
else:
|
522 |
+
prev_sample = denoised
|
523 |
+
|
524 |
+
latent = prev_sample
|
525 |
+
|
526 |
+
print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
|
527 |
+
|
528 |
+
# vae decoder inference
|
529 |
+
vae_start = time.time()
|
530 |
+
latent = latent / 0.18215
|
531 |
+
image = vae_decoder.run(None, {"x": latent})[0] # ['784']
|
532 |
+
print(f"vae decoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
|
533 |
+
|
534 |
+
# save result
|
535 |
+
save_start = time.time()
|
536 |
+
image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
|
537 |
+
image_denorm = np.clip(image / 2 + 0.5, 0, 1)
|
538 |
+
image = (image_denorm * 255).round().astype("uint8")
|
539 |
+
pil_image = Image.fromarray(image[:, :, :3])
|
540 |
+
pil_image.save(save_dir)
|
541 |
+
|
542 |
+
grid_img = make_image_grid([init_image_show, pil_image], rows=1, cols=2)
|
543 |
+
grid_img.save(f"./lcm_lora_sdv1-5_imgGrid_output.png")
|
544 |
+
|
545 |
+
print(f"grid image saved in ./lcm_lora_sdv1-5_imgGrid_output.png")
|
546 |
+
print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
|
run_img2img_onnx_infer.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
import numpy as np
|
3 |
+
import onnxruntime
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
|
7 |
+
|
8 |
+
# import axengine as axe
|
9 |
+
import time
|
10 |
+
import argparse
|
11 |
+
from diffusers.utils import load_image
|
12 |
+
import PIL.Image
|
13 |
+
from typing import List, Optional, Tuple, Union
|
14 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
15 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
17 |
+
from diffusers.utils import make_image_grid, load_image
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
########## Img2Img
|
22 |
+
PipelineImageInput = Union[
|
23 |
+
PIL.Image.Image,
|
24 |
+
np.ndarray,
|
25 |
+
torch.Tensor,
|
26 |
+
List[PIL.Image.Image],
|
27 |
+
List[np.ndarray],
|
28 |
+
List[torch.Tensor],
|
29 |
+
]
|
30 |
+
|
31 |
+
PipelineDepthInput = PipelineImageInput
|
32 |
+
|
33 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
34 |
+
def add_noise(
|
35 |
+
original_samples: torch.Tensor,
|
36 |
+
noise: torch.Tensor,
|
37 |
+
timesteps: torch.IntTensor,
|
38 |
+
) -> torch.Tensor:
|
39 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
40 |
+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
41 |
+
# for the subsequent add_noise calls
|
42 |
+
# self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
|
43 |
+
# Convert betas to alphas_bar_sqrt
|
44 |
+
beta_start = 0.00085
|
45 |
+
beta_end = 0.012
|
46 |
+
num_train_timesteps = 1000
|
47 |
+
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
48 |
+
alphas = 1.0 - betas
|
49 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
50 |
+
alphas_cumprod = alphas_cumprod.to(device=original_samples.device)
|
51 |
+
alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype)
|
52 |
+
timesteps = timesteps.to(original_samples.device)
|
53 |
+
|
54 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
55 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
56 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
57 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
58 |
+
|
59 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
60 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
61 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
62 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
63 |
+
|
64 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
65 |
+
return noisy_samples
|
66 |
+
|
67 |
+
def retrieve_latents(
|
68 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
69 |
+
):
|
70 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
71 |
+
return encoder_output.latent_dist.sample(generator)
|
72 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
73 |
+
return encoder_output.latent_dist.mode()
|
74 |
+
elif hasattr(encoder_output, "latents"):
|
75 |
+
return encoder_output.latents
|
76 |
+
else:
|
77 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
78 |
+
|
79 |
+
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
|
80 |
+
r"""
|
81 |
+
Convert a NumPy image to a PyTorch tensor.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
images (`np.ndarray`):
|
85 |
+
The NumPy image array to convert to PyTorch format.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
`torch.Tensor`:
|
89 |
+
A PyTorch tensor representation of the images.
|
90 |
+
"""
|
91 |
+
if images.ndim == 3:
|
92 |
+
images = images[..., None]
|
93 |
+
|
94 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
95 |
+
return images
|
96 |
+
|
97 |
+
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
98 |
+
r"""
|
99 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
|
103 |
+
The PIL image or list of images to convert to NumPy format.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
`np.ndarray`:
|
107 |
+
A NumPy array representation of the images.
|
108 |
+
"""
|
109 |
+
if not isinstance(images, list):
|
110 |
+
images = [images]
|
111 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
112 |
+
images = np.stack(images, axis=0)
|
113 |
+
|
114 |
+
return images
|
115 |
+
|
116 |
+
def is_valid_image(image) -> bool:
|
117 |
+
r"""
|
118 |
+
Checks if the input is a valid image.
|
119 |
+
|
120 |
+
A valid image can be:
|
121 |
+
- A `PIL.Image.Image`.
|
122 |
+
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
|
123 |
+
|
124 |
+
Args:
|
125 |
+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
126 |
+
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
`bool`:
|
130 |
+
`True` if the input is a valid image, `False` otherwise.
|
131 |
+
"""
|
132 |
+
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
|
133 |
+
|
134 |
+
def is_valid_image_imagelist(images):
|
135 |
+
r"""
|
136 |
+
Checks if the input is a valid image or list of images.
|
137 |
+
|
138 |
+
The input can be one of the following formats:
|
139 |
+
- A 4D tensor or numpy array (batch of images).
|
140 |
+
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
|
141 |
+
`torch.Tensor`.
|
142 |
+
- A list of valid images.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
|
146 |
+
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
|
147 |
+
images.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
`bool`:
|
151 |
+
`True` if the input is valid, `False` otherwise.
|
152 |
+
"""
|
153 |
+
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
|
154 |
+
return True
|
155 |
+
elif is_valid_image(images):
|
156 |
+
return True
|
157 |
+
elif isinstance(images, list):
|
158 |
+
return all(is_valid_image(image) for image in images)
|
159 |
+
return False
|
160 |
+
|
161 |
+
|
162 |
+
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
163 |
+
r"""
|
164 |
+
Normalize an image array to [-1,1].
|
165 |
+
|
166 |
+
Args:
|
167 |
+
images (`np.ndarray` or `torch.Tensor`):
|
168 |
+
The image array to normalize.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
`np.ndarray` or `torch.Tensor`:
|
172 |
+
The normalized image array.
|
173 |
+
"""
|
174 |
+
return 2.0 * images - 1.0
|
175 |
+
|
176 |
+
# Copy from: /home/baiyongqiang/miniforge-pypy3/envs/hf/lib/python3.9/site-packages/diffusers/image_processor.py#607
|
177 |
+
def preprocess(
|
178 |
+
image: PipelineImageInput,
|
179 |
+
height: Optional[int] = None,
|
180 |
+
width: Optional[int] = None,
|
181 |
+
resize_mode: str = "default", # "default", "fill", "crop"
|
182 |
+
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
183 |
+
) -> torch.Tensor:
|
184 |
+
"""
|
185 |
+
Preprocess the image input.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
image (`PipelineImageInput`):
|
189 |
+
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
190 |
+
supported formats.
|
191 |
+
height (`int`, *optional*):
|
192 |
+
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
|
193 |
+
height.
|
194 |
+
width (`int`, *optional*):
|
195 |
+
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
196 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
197 |
+
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
|
198 |
+
the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
|
199 |
+
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
200 |
+
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
|
201 |
+
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
202 |
+
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
203 |
+
supported for PIL image input.
|
204 |
+
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
205 |
+
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
`torch.Tensor`:
|
209 |
+
The preprocessed image.
|
210 |
+
"""
|
211 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
212 |
+
|
213 |
+
# # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
214 |
+
# if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
215 |
+
# if isinstance(image, torch.Tensor):
|
216 |
+
# # if image is a pytorch tensor could have 2 possible shapes:
|
217 |
+
# # 1. batch x height x width: we should insert the channel dimension at position 1
|
218 |
+
# # 2. channel x height x width: we should insert batch dimension at position 0,
|
219 |
+
# # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
220 |
+
# # for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
221 |
+
# image = image.unsqueeze(1)
|
222 |
+
# else:
|
223 |
+
# # if it is a numpy array, it could have 2 possible shapes:
|
224 |
+
# # 1. batch x height x width: insert channel dimension on last position
|
225 |
+
# # 2. height x width x channel: insert batch dimension on first position
|
226 |
+
# if image.shape[-1] == 1:
|
227 |
+
# image = np.expand_dims(image, axis=0)
|
228 |
+
# else:
|
229 |
+
# image = np.expand_dims(image, axis=-1)
|
230 |
+
|
231 |
+
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
232 |
+
warnings.warn(
|
233 |
+
"Passing `image` as a list of 4d np.ndarray is deprecated."
|
234 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
235 |
+
FutureWarning,
|
236 |
+
)
|
237 |
+
image = np.concatenate(image, axis=0)
|
238 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
239 |
+
warnings.warn(
|
240 |
+
"Passing `image` as a list of 4d torch.Tensor is deprecated."
|
241 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
242 |
+
FutureWarning,
|
243 |
+
)
|
244 |
+
image = torch.cat(image, axis=0)
|
245 |
+
|
246 |
+
if not is_valid_image_imagelist(image):
|
247 |
+
raise ValueError(
|
248 |
+
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
|
249 |
+
)
|
250 |
+
if not isinstance(image, list):
|
251 |
+
image = [image]
|
252 |
+
|
253 |
+
if isinstance(image[0], PIL.Image.Image):
|
254 |
+
if crops_coords is not None:
|
255 |
+
image = [i.crop(crops_coords) for i in image]
|
256 |
+
# if self.config.do_resize:
|
257 |
+
# height, width = self.get_default_height_width(image[0], height, width)
|
258 |
+
# image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
259 |
+
# if self.config.do_convert_rgb:
|
260 |
+
# image = [self.convert_to_rgb(i) for i in image]
|
261 |
+
# elif self.config.do_convert_grayscale:
|
262 |
+
# image = [self.convert_to_grayscale(i) for i in image]
|
263 |
+
image = pil_to_numpy(image) # to np
|
264 |
+
image = numpy_to_pt(image) # to pt
|
265 |
+
|
266 |
+
elif isinstance(image[0], np.ndarray):
|
267 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
268 |
+
|
269 |
+
# image = self.numpy_to_pt(image)
|
270 |
+
|
271 |
+
# height, width = self.get_default_height_width(image, height, width)
|
272 |
+
# if self.config.do_resize:
|
273 |
+
# image = self.resize(image, height, width)
|
274 |
+
|
275 |
+
elif isinstance(image[0], torch.Tensor):
|
276 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
277 |
+
|
278 |
+
# if self.config.do_convert_grayscale and image.ndim == 3:
|
279 |
+
# image = image.unsqueeze(1)
|
280 |
+
|
281 |
+
channel = image.shape[1]
|
282 |
+
# don't need any preprocess if the image is latents
|
283 |
+
# if channel == self.config.vae_latent_channels:
|
284 |
+
# return image
|
285 |
+
|
286 |
+
# height, width = self.get_default_height_width(image, height, width)
|
287 |
+
# if self.config.do_resize:
|
288 |
+
# image = self.resize(image, height, width)
|
289 |
+
|
290 |
+
# expected range [0,1], normalize to [-1,1]
|
291 |
+
do_normalize = True # self.config.do_normalize
|
292 |
+
if do_normalize and image.min() < 0:
|
293 |
+
warnings.warn(
|
294 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
295 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
296 |
+
FutureWarning,
|
297 |
+
)
|
298 |
+
do_normalize = False
|
299 |
+
if do_normalize:
|
300 |
+
image = normalize(image)
|
301 |
+
|
302 |
+
# if self.config.do_binarize:
|
303 |
+
# image = self.binarize(image)
|
304 |
+
|
305 |
+
return image
|
306 |
+
##########
|
307 |
+
|
308 |
+
|
309 |
+
def get_args():
|
310 |
+
parser = argparse.ArgumentParser(
|
311 |
+
prog="StableDiffusion",
|
312 |
+
description="Generate picture with the input prompt"
|
313 |
+
)
|
314 |
+
parser.add_argument("--prompt", type=str, required=False, default="Astronauts in a jungle, cold color palette, muted colors, detailed, 8k", help="the input text prompt")
|
315 |
+
parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
|
316 |
+
parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.onnx", help="Path to unet ONNX model")
|
317 |
+
parser.add_argument("--vae_encoder_model", type=str, required=False, default="./models/vae_encoder.onnx", help="Path to vae encoder ONNX model")
|
318 |
+
parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.onnx", help="Path to vae decoder ONNX model")
|
319 |
+
parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_img2img.npy", help="Path to time input file")
|
320 |
+
parser.add_argument("--init_image", type=str, required=False, default="./models/img2img-init.png", help="Path to initial image file")
|
321 |
+
parser.add_argument("--save_dir", type=str, required=False, default="./img2img_output_onnx.png", help="Path to the output image file")
|
322 |
+
return parser.parse_args()
|
323 |
+
|
324 |
+
def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
325 |
+
if not isinstance(prompt, List):
|
326 |
+
prompts = [prompt]
|
327 |
+
else:
|
328 |
+
prompts = prompt
|
329 |
+
|
330 |
+
prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
|
331 |
+
|
332 |
+
if not isinstance(prompt, List):
|
333 |
+
return prompts[0]
|
334 |
+
|
335 |
+
return prompts
|
336 |
+
|
337 |
+
|
338 |
+
def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
339 |
+
tokens = tokenizer.tokenize(prompt)
|
340 |
+
unique_tokens = set(tokens)
|
341 |
+
for token in unique_tokens:
|
342 |
+
if token in tokenizer.added_tokens_encoder:
|
343 |
+
replacement = token
|
344 |
+
i = 1
|
345 |
+
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
346 |
+
replacement += f" {token}_{i}"
|
347 |
+
i += 1
|
348 |
+
|
349 |
+
prompt = prompt.replace(token, replacement)
|
350 |
+
|
351 |
+
return prompt
|
352 |
+
|
353 |
+
|
354 |
+
def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
|
355 |
+
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
|
356 |
+
text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
|
357 |
+
torch_dtype=torch.float32,
|
358 |
+
variant="fp16")
|
359 |
+
text_inputs = tokenizer(
|
360 |
+
prompt,
|
361 |
+
padding="max_length",
|
362 |
+
max_length=77,
|
363 |
+
truncation=True,
|
364 |
+
return_tensors="pt",
|
365 |
+
)
|
366 |
+
text_input_ids = text_inputs.input_ids
|
367 |
+
prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
|
368 |
+
|
369 |
+
prompt_embeds_npy = prompt_embeds[0].detach().numpy()
|
370 |
+
return prompt_embeds_npy
|
371 |
+
|
372 |
+
|
373 |
+
def get_alphas_cumprod():
|
374 |
+
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
|
375 |
+
alphas = 1.0 - betas
|
376 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
|
377 |
+
final_alphas_cumprod = alphas_cumprod[0]
|
378 |
+
self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
|
379 |
+
return alphas_cumprod, final_alphas_cumprod, self_timesteps
|
380 |
+
|
381 |
+
def resize_and_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
382 |
+
"""
|
383 |
+
Resize the image to 512x512 and convert it to RGB.
|
384 |
+
"""
|
385 |
+
return image.resize((512, 512)).convert("RGB")
|
386 |
+
|
387 |
+
|
388 |
+
if __name__ == '__main__':
|
389 |
+
|
390 |
+
"""
|
391 |
+
Usage:
|
392 |
+
- python3 run_img2img_onnx_infer.py --prompt "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" --unet_model output_onnx/unet_sim.onnx --vae_encoder_model output_onnx/vae_encoder_sim.onnx --vae_decoder_model output_onnx/vae_decoder_sim.onnx --time_input ./output_onnx/time_input.npy --save_dir ./img2img_output.png
|
393 |
+
"""
|
394 |
+
args = get_args()
|
395 |
+
prompt = args.prompt
|
396 |
+
tokenizer_dir = args.text_model_dir + 'tokenizer'
|
397 |
+
text_encoder_dir = args.text_model_dir + 'text_encoder'
|
398 |
+
unet_model = args.unet_model
|
399 |
+
vae_decoder_model = args.vae_decoder_model
|
400 |
+
vae_encoder_model = args.vae_encoder_model
|
401 |
+
init_image = args.init_image
|
402 |
+
time_input = args.time_input
|
403 |
+
save_dir = args.save_dir
|
404 |
+
|
405 |
+
print(f"prompt: {prompt}")
|
406 |
+
print(f"text_tokenizer: {tokenizer_dir}")
|
407 |
+
print(f"text_encoder: {text_encoder_dir}")
|
408 |
+
print(f"unet_model: {unet_model}")
|
409 |
+
print(f"vae_encoder_model: {vae_encoder_model}")
|
410 |
+
print(f"vae_decoder_model: {vae_decoder_model}")
|
411 |
+
print(f"init image: {init_image}")
|
412 |
+
print(f"time_input: {time_input}")
|
413 |
+
print(f"save_dir: {save_dir}")
|
414 |
+
|
415 |
+
# timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
|
416 |
+
|
417 |
+
# text encoder
|
418 |
+
start = time.time()
|
419 |
+
# prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
420 |
+
# prompt = "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k"
|
421 |
+
# prompt = "Caricature, a beautiful girl with black hair, 8k"
|
422 |
+
prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
|
423 |
+
print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
|
424 |
+
|
425 |
+
prompt_name = prompt.replace(" ", "_")
|
426 |
+
latents_shape = [1, 4, 64, 64]
|
427 |
+
# latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
|
428 |
+
# layout=torch.strided).detach().numpy()
|
429 |
+
|
430 |
+
alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
|
431 |
+
|
432 |
+
# load unet model and vae model
|
433 |
+
start = time.time()
|
434 |
+
vae_encoder = onnxruntime.InferenceSession(vae_encoder_model)
|
435 |
+
unet_session_main = onnxruntime.InferenceSession(unet_model)
|
436 |
+
vae_decoder = onnxruntime.InferenceSession(vae_decoder_model)
|
437 |
+
print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
|
438 |
+
|
439 |
+
# load time input file
|
440 |
+
time_input = np.load(time_input)
|
441 |
+
|
442 |
+
# load image
|
443 |
+
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
|
444 |
+
url = init_image
|
445 |
+
init_image = load_image(url, convert_method=resize_and_rgb) # U8, (512, 512, 3), RGB
|
446 |
+
init_image_show = init_image
|
447 |
+
|
448 |
+
# vae encoder inference
|
449 |
+
vae_start = time.time()
|
450 |
+
|
451 |
+
init_image = preprocess(init_image) # torch.Size([1, 3, 512, 512])
|
452 |
+
if isinstance(init_image, torch.Tensor):
|
453 |
+
init_image = init_image.detach().numpy()
|
454 |
+
|
455 |
+
vae_encoder_onnx_inp_name = vae_encoder.get_inputs()[0].name
|
456 |
+
vae_encoder_onnx_out_name = vae_encoder.get_outputs()[0].name
|
457 |
+
|
458 |
+
# vae_encoder_out.shape (1, 8, 64, 64)
|
459 |
+
vae_encoder_out = vae_encoder.run(None, {vae_encoder_onnx_inp_name: init_image})[0] # encoder out: torch.Size([1, 8, 64, 64])
|
460 |
+
print(f"vae encoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
|
461 |
+
|
462 |
+
# vae encoder inference
|
463 |
+
device = torch.device("cpu")
|
464 |
+
vae_encoder_out = torch.from_numpy(vae_encoder_out).to(torch.float32)
|
465 |
+
posterior = DiagonalGaussianDistribution(vae_encoder_out) # 数值基本对的上
|
466 |
+
vae_encode_info = AutoencoderKLOutput(latent_dist=posterior)
|
467 |
+
generator = torch.manual_seed(0)
|
468 |
+
init_latents = retrieve_latents(vae_encode_info, generator=generator) # 数值基本对的上
|
469 |
+
init_latents = init_latents * 0.18215 # 数值基本对的上
|
470 |
+
init_latents = torch.cat([init_latents], dim=0)
|
471 |
+
shape = init_latents.shape
|
472 |
+
dtype = torch.float16
|
473 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # dtype 不同, 随机值不同
|
474 |
+
# get latents
|
475 |
+
timestep = torch.tensor([499]).to(device)
|
476 |
+
init_latents = add_noise(init_latents.to(device), noise, timestep)
|
477 |
+
latents = init_latents
|
478 |
+
|
479 |
+
latents = latents.detach().cpu().numpy()
|
480 |
+
latent = latents
|
481 |
+
|
482 |
+
# unet inference loop
|
483 |
+
unet_loop_start = time.time()
|
484 |
+
timesteps = np.array([499, 259]).astype(np.int64)
|
485 |
+
self_timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
|
486 |
+
step_index = [2, 3]
|
487 |
+
for i, timestep in enumerate(timesteps):
|
488 |
+
unet_start = time.time()
|
489 |
+
noise_pred = unet_session_main.run(None, {"sample": latent, \
|
490 |
+
"/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
|
491 |
+
"encoder_hidden_states": prompt_embeds_npy})[0]
|
492 |
+
|
493 |
+
print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
|
494 |
+
|
495 |
+
sample = latent
|
496 |
+
model_output = noise_pred
|
497 |
+
|
498 |
+
# 1. get previous step value
|
499 |
+
prev_step_index = step_index[i] + 1
|
500 |
+
if prev_step_index < len(self_timesteps):
|
501 |
+
prev_timestep = self_timesteps[prev_step_index]
|
502 |
+
else:
|
503 |
+
prev_timestep = timestep
|
504 |
+
|
505 |
+
alpha_prod_t = alphas_cumprod[timestep]
|
506 |
+
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
|
507 |
+
beta_prod_t = 1 - alpha_prod_t
|
508 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
509 |
+
|
510 |
+
# 3. Get scalings for boundary conditions
|
511 |
+
scaled_timestep = timestep * 10
|
512 |
+
c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
|
513 |
+
c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
|
514 |
+
predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5) # 数值基本对齐
|
515 |
+
|
516 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
517 |
+
if step_index[i] != 3:
|
518 |
+
device = torch.device("cpu")
|
519 |
+
noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=torch.float16).numpy()
|
520 |
+
prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
|
521 |
+
else:
|
522 |
+
prev_sample = denoised
|
523 |
+
|
524 |
+
latent = prev_sample
|
525 |
+
|
526 |
+
print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
|
527 |
+
|
528 |
+
# vae decoder inference
|
529 |
+
vae_start = time.time()
|
530 |
+
latent = latent / 0.18215
|
531 |
+
image = vae_decoder.run(None, {"x": latent})[0] # ['784']
|
532 |
+
print(f"vae decoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
|
533 |
+
|
534 |
+
# save result
|
535 |
+
save_start = time.time()
|
536 |
+
image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
|
537 |
+
image_denorm = np.clip(image / 2 + 0.5, 0, 1)
|
538 |
+
image = (image_denorm * 255).round().astype("uint8")
|
539 |
+
pil_image = Image.fromarray(image[:, :, :3])
|
540 |
+
pil_image.save(save_dir)
|
541 |
+
|
542 |
+
grid_img = make_image_grid([init_image_show, pil_image], rows=1, cols=2)
|
543 |
+
grid_img.save(f"./lcm_lora_sdv1-5_imgGrid_output.png")
|
544 |
+
|
545 |
+
print(f"grid image saved in ./lcm_lora_sdv1-5_imgGrid_output.png")
|
546 |
+
print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
|
run_txt2img_axe_infer.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
import numpy as np
|
3 |
+
# import onnxruntime
|
4 |
+
import axengine
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
|
8 |
+
|
9 |
+
import time
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
|
13 |
+
def get_args():
|
14 |
+
parser = argparse.ArgumentParser(
|
15 |
+
prog="StableDiffusion",
|
16 |
+
description="Generate picture with the input prompt"
|
17 |
+
)
|
18 |
+
parser.add_argument("--prompt", type=str, required=False, default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", help="the input text prompt")
|
19 |
+
parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
|
20 |
+
parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel", help="Path to unet axmodel model")
|
21 |
+
parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel", help="Path to vae decoder axmodel model")
|
22 |
+
parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_txt2img.npy", help="Path to time input file")
|
23 |
+
parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_axe.png", help="Path to the output image file")
|
24 |
+
return parser.parse_args()
|
25 |
+
|
26 |
+
def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
27 |
+
if not isinstance(prompt, List):
|
28 |
+
prompts = [prompt]
|
29 |
+
else:
|
30 |
+
prompts = prompt
|
31 |
+
|
32 |
+
prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
|
33 |
+
|
34 |
+
if not isinstance(prompt, List):
|
35 |
+
return prompts[0]
|
36 |
+
|
37 |
+
return prompts
|
38 |
+
|
39 |
+
|
40 |
+
def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
41 |
+
tokens = tokenizer.tokenize(prompt)
|
42 |
+
unique_tokens = set(tokens)
|
43 |
+
for token in unique_tokens:
|
44 |
+
if token in tokenizer.added_tokens_encoder:
|
45 |
+
replacement = token
|
46 |
+
i = 1
|
47 |
+
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
48 |
+
replacement += f" {token}_{i}"
|
49 |
+
i += 1
|
50 |
+
|
51 |
+
prompt = prompt.replace(token, replacement)
|
52 |
+
|
53 |
+
return prompt
|
54 |
+
|
55 |
+
|
56 |
+
def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
|
57 |
+
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
|
58 |
+
text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
|
59 |
+
torch_dtype=torch.float32,
|
60 |
+
variant="fp16")
|
61 |
+
text_inputs = tokenizer(
|
62 |
+
prompt,
|
63 |
+
padding="max_length",
|
64 |
+
max_length=77,
|
65 |
+
truncation=True,
|
66 |
+
return_tensors="pt",
|
67 |
+
)
|
68 |
+
text_input_ids = text_inputs.input_ids
|
69 |
+
prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
|
70 |
+
|
71 |
+
prompt_embeds_npy = prompt_embeds[0].detach().numpy()
|
72 |
+
return prompt_embeds_npy
|
73 |
+
|
74 |
+
|
75 |
+
def get_alphas_cumprod():
|
76 |
+
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
|
77 |
+
alphas = 1.0 - betas
|
78 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
|
79 |
+
final_alphas_cumprod = alphas_cumprod[0]
|
80 |
+
self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
|
81 |
+
return alphas_cumprod, final_alphas_cumprod, self_timesteps
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == '__main__':
|
85 |
+
args = get_args()
|
86 |
+
prompt = args.prompt
|
87 |
+
tokenizer_dir = args.text_model_dir + 'tokenizer'
|
88 |
+
text_encoder_dir = args.text_model_dir + 'text_encoder'
|
89 |
+
unet_model = args.unet_model
|
90 |
+
vae_decoder_model = args.vae_decoder_model
|
91 |
+
time_input = args.time_input
|
92 |
+
save_dir = args.save_dir
|
93 |
+
|
94 |
+
print(f"prompt: {prompt}")
|
95 |
+
print(f"text_tokenizer: {tokenizer_dir}")
|
96 |
+
print(f"text_encoder: {text_encoder_dir}")
|
97 |
+
print(f"unet_model: {unet_model}")
|
98 |
+
print(f"vae_decoder_model: {vae_decoder_model}")
|
99 |
+
print(f"time_input: {time_input}")
|
100 |
+
print(f"save_dir: {save_dir}")
|
101 |
+
|
102 |
+
timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
|
103 |
+
|
104 |
+
# text encoder
|
105 |
+
start = time.time()
|
106 |
+
# prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
107 |
+
prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
|
108 |
+
print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
|
109 |
+
|
110 |
+
prompt_name = prompt.replace(" ", "_")
|
111 |
+
latents_shape = [1, 4, 64, 64]
|
112 |
+
latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
|
113 |
+
layout=torch.strided).detach().numpy()
|
114 |
+
|
115 |
+
alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
|
116 |
+
|
117 |
+
# load unet model and vae model
|
118 |
+
start = time.time()
|
119 |
+
unet_session_main = axengine.InferenceSession(unet_model)
|
120 |
+
vae_decoder = axengine.InferenceSession(vae_decoder_model)
|
121 |
+
print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
|
122 |
+
|
123 |
+
# load time input file
|
124 |
+
time_input = np.load(time_input)
|
125 |
+
|
126 |
+
# unet inference loop
|
127 |
+
unet_loop_start = time.time()
|
128 |
+
for i, timestep in enumerate(timesteps):
|
129 |
+
# print(i, timestep)
|
130 |
+
|
131 |
+
unet_start = time.time()
|
132 |
+
noise_pred = unet_session_main.run(None, {"sample": latent, \
|
133 |
+
"/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
|
134 |
+
"encoder_hidden_states": prompt_embeds_npy})[0]
|
135 |
+
|
136 |
+
print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
|
137 |
+
|
138 |
+
sample = latent
|
139 |
+
model_output = noise_pred
|
140 |
+
if i < 3:
|
141 |
+
prev_timestep = timesteps[i + 1]
|
142 |
+
else:
|
143 |
+
prev_timestep = timestep
|
144 |
+
|
145 |
+
alpha_prod_t = alphas_cumprod[timestep]
|
146 |
+
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
|
147 |
+
|
148 |
+
beta_prod_t = 1 - alpha_prod_t
|
149 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
150 |
+
|
151 |
+
# 3. Get scalings for boundary conditions
|
152 |
+
scaled_timestep = timestep * 10
|
153 |
+
c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
|
154 |
+
c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
|
155 |
+
predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)
|
156 |
+
|
157 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
158 |
+
|
159 |
+
if i != 3:
|
160 |
+
noise = torch.randn(model_output.shape, generator=None, device="cpu", dtype=torch.float32,
|
161 |
+
layout=torch.strided).to("cpu").detach().numpy()
|
162 |
+
prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
|
163 |
+
else:
|
164 |
+
prev_sample = denoised
|
165 |
+
|
166 |
+
latent = prev_sample
|
167 |
+
|
168 |
+
print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
|
169 |
+
|
170 |
+
# vae inference
|
171 |
+
vae_start = time.time()
|
172 |
+
latent = latent / 0.18215
|
173 |
+
image = vae_decoder.run(None, {"x": latent})[0]
|
174 |
+
print(f"vae inference take {(1000 * (time.time() - vae_start)):.1f}ms")
|
175 |
+
|
176 |
+
# save result
|
177 |
+
save_start = time.time()
|
178 |
+
image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
|
179 |
+
image_denorm = np.clip(image / 2 + 0.5, 0, 1)
|
180 |
+
image = (image_denorm * 255).round().astype("uint8")
|
181 |
+
pil_image = Image.fromarray(image[:, :, :3])
|
182 |
+
pil_image.save(save_dir)
|
183 |
+
print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
|
run_txt2img_onnx_infer.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
import numpy as np
|
3 |
+
import onnxruntime
|
4 |
+
# import axengine
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection
|
8 |
+
|
9 |
+
import time
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
|
13 |
+
def get_args():
|
14 |
+
parser = argparse.ArgumentParser(
|
15 |
+
prog="StableDiffusion",
|
16 |
+
description="Generate picture with the input prompt"
|
17 |
+
)
|
18 |
+
parser.add_argument("--prompt", type=str, required=False, default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", help="the input text prompt")
|
19 |
+
parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
|
20 |
+
parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.onnx", help="Path to unet ONNX model")
|
21 |
+
parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.onnx", help="Path to vae decoder ONNX model")
|
22 |
+
parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_txt2img.npy", help="Path to time input file")
|
23 |
+
parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_onnx.png", help="Path to the output image file")
|
24 |
+
return parser.parse_args()
|
25 |
+
|
26 |
+
def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
27 |
+
if not isinstance(prompt, List):
|
28 |
+
prompts = [prompt]
|
29 |
+
else:
|
30 |
+
prompts = prompt
|
31 |
+
|
32 |
+
prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]
|
33 |
+
|
34 |
+
if not isinstance(prompt, List):
|
35 |
+
return prompts[0]
|
36 |
+
|
37 |
+
return prompts
|
38 |
+
|
39 |
+
|
40 |
+
def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
41 |
+
tokens = tokenizer.tokenize(prompt)
|
42 |
+
unique_tokens = set(tokens)
|
43 |
+
for token in unique_tokens:
|
44 |
+
if token in tokenizer.added_tokens_encoder:
|
45 |
+
replacement = token
|
46 |
+
i = 1
|
47 |
+
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
48 |
+
replacement += f" {token}_{i}"
|
49 |
+
i += 1
|
50 |
+
|
51 |
+
prompt = prompt.replace(token, replacement)
|
52 |
+
|
53 |
+
return prompt
|
54 |
+
|
55 |
+
|
56 |
+
def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"):
|
57 |
+
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
|
58 |
+
text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir,
|
59 |
+
torch_dtype=torch.float32,
|
60 |
+
variant="fp16")
|
61 |
+
text_inputs = tokenizer(
|
62 |
+
prompt,
|
63 |
+
padding="max_length",
|
64 |
+
max_length=77,
|
65 |
+
truncation=True,
|
66 |
+
return_tensors="pt",
|
67 |
+
)
|
68 |
+
text_input_ids = text_inputs.input_ids
|
69 |
+
prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=None)
|
70 |
+
|
71 |
+
prompt_embeds_npy = prompt_embeds[0].detach().numpy()
|
72 |
+
return prompt_embeds_npy
|
73 |
+
|
74 |
+
|
75 |
+
def get_alphas_cumprod():
|
76 |
+
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
|
77 |
+
alphas = 1.0 - betas
|
78 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
|
79 |
+
final_alphas_cumprod = alphas_cumprod[0]
|
80 |
+
self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
|
81 |
+
return alphas_cumprod, final_alphas_cumprod, self_timesteps
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == '__main__':
|
85 |
+
args = get_args()
|
86 |
+
prompt = args.prompt
|
87 |
+
tokenizer_dir = args.text_model_dir + 'tokenizer'
|
88 |
+
text_encoder_dir = args.text_model_dir + 'text_encoder'
|
89 |
+
unet_model = args.unet_model
|
90 |
+
vae_decoder_model = args.vae_decoder_model
|
91 |
+
time_input = args.time_input
|
92 |
+
save_dir = args.save_dir
|
93 |
+
|
94 |
+
print(f"prompt: {prompt}")
|
95 |
+
print(f"text_tokenizer: {tokenizer_dir}")
|
96 |
+
print(f"text_encoder: {text_encoder_dir}")
|
97 |
+
print(f"unet_model: {unet_model}")
|
98 |
+
print(f"vae_decoder_model: {vae_decoder_model}")
|
99 |
+
print(f"time_input: {time_input}")
|
100 |
+
print(f"save_dir: {save_dir}")
|
101 |
+
|
102 |
+
timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
|
103 |
+
|
104 |
+
# text encoder
|
105 |
+
start = time.time()
|
106 |
+
# prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
107 |
+
prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
|
108 |
+
print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms")
|
109 |
+
|
110 |
+
prompt_name = prompt.replace(" ", "_")
|
111 |
+
latents_shape = [1, 4, 64, 64]
|
112 |
+
latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32,
|
113 |
+
layout=torch.strided).detach().numpy()
|
114 |
+
|
115 |
+
alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()
|
116 |
+
|
117 |
+
# load unet model and vae model
|
118 |
+
start = time.time()
|
119 |
+
unet_session_main = onnxruntime.InferenceSession(unet_model)
|
120 |
+
vae_decoder = onnxruntime.InferenceSession(vae_decoder_model)
|
121 |
+
print(f"load models take {(1000 * (time.time() - start)):.1f}ms")
|
122 |
+
|
123 |
+
# load time input file
|
124 |
+
time_input = np.load(time_input)
|
125 |
+
|
126 |
+
# unet inference loop
|
127 |
+
unet_loop_start = time.time()
|
128 |
+
for i, timestep in enumerate(timesteps):
|
129 |
+
# print(i, timestep)
|
130 |
+
|
131 |
+
unet_start = time.time()
|
132 |
+
noise_pred = unet_session_main.run(None, {"sample": latent, \
|
133 |
+
"/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \
|
134 |
+
"encoder_hidden_states": prompt_embeds_npy})[0]
|
135 |
+
|
136 |
+
print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
|
137 |
+
|
138 |
+
sample = latent
|
139 |
+
model_output = noise_pred
|
140 |
+
if i < 3:
|
141 |
+
prev_timestep = timesteps[i + 1]
|
142 |
+
else:
|
143 |
+
prev_timestep = timestep
|
144 |
+
|
145 |
+
alpha_prod_t = alphas_cumprod[timestep]
|
146 |
+
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
|
147 |
+
|
148 |
+
beta_prod_t = 1 - alpha_prod_t
|
149 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
150 |
+
|
151 |
+
# 3. Get scalings for boundary conditions
|
152 |
+
scaled_timestep = timestep * 10
|
153 |
+
c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
|
154 |
+
c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
|
155 |
+
predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)
|
156 |
+
|
157 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
158 |
+
|
159 |
+
if i != 3:
|
160 |
+
noise = torch.randn(model_output.shape, generator=None, device="cpu", dtype=torch.float32,
|
161 |
+
layout=torch.strided).to("cpu").detach().numpy()
|
162 |
+
prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
|
163 |
+
else:
|
164 |
+
prev_sample = denoised
|
165 |
+
|
166 |
+
latent = prev_sample
|
167 |
+
|
168 |
+
print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
|
169 |
+
|
170 |
+
# vae inference
|
171 |
+
vae_start = time.time()
|
172 |
+
latent = latent / 0.18215
|
173 |
+
image = vae_decoder.run(None, {"x": latent})[0]
|
174 |
+
print(f"vae inference take {(1000 * (time.time() - vae_start)):.1f}ms")
|
175 |
+
|
176 |
+
# save result
|
177 |
+
save_start = time.time()
|
178 |
+
image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
|
179 |
+
image_denorm = np.clip(image / 2 + 0.5, 0, 1)
|
180 |
+
image = (image_denorm * 255).round().astype("uint8")
|
181 |
+
pil_image = Image.fromarray(image[:, :, :3])
|
182 |
+
pil_image.save(save_dir)
|
183 |
+
print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
|