import argparse
import os
import time
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import onnxruntime
import PIL.Image
import torch
from PIL import Image
from transformers import CLIPTokenizer, PreTrainedTokenizer

from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils import load_image, make_image_grid
from diffusers.utils.torch_utils import randn_tensor


PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.Tensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.Tensor],
]

PipelineDepthInput = PipelineImageInput


def add_noise(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
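    r"""
    Diffuse `original_samples` forward to the given `timesteps`:

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    using the "scaled_linear" beta schedule of Stable Diffusion v1.5
    (beta_start=0.00085, beta_end=0.012, 1000 training steps), mirroring the
    `add_noise` method of the diffusers schedulers.
    """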
    beta_start = 0.00085
    beta_end = 0.012
    num_train_timesteps = 1000
    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    return noisy_samples


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
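    r"""
    Extract latents from a VAE encoder output: draw a sample from the latent
    distribution (`sample_mode="sample"`), take its mode (`"argmax"`), or
    return precomputed `latents` directly. Mirrors the `retrieve_latents`
    helper used in diffusers pipelines.
    """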
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
    r"""
    Convert a NumPy image to a PyTorch tensor.

    Args:
        images (`np.ndarray`):
            The NumPy image array to convert to PyTorch format.

    Returns:
        `torch.Tensor`:
            A PyTorch tensor representation of the images.
    """
    if images.ndim == 3:
        images = images[..., None]

    images = torch.from_numpy(images.transpose(0, 3, 1, 2))
    return images


def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
    r"""
    Convert a PIL image or a list of PIL images to NumPy arrays.

    Args:
        images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
            The PIL image or list of images to convert to NumPy format.

    Returns:
        `np.ndarray`:
            A NumPy array representation of the images.
    """
    if not isinstance(images, list):
        images = [images]
    images = [np.array(image).astype(np.float32) / 255.0 for image in images]
    images = np.stack(images, axis=0)

    return images


def is_valid_image(image) -> bool:
    r"""
    Checks if the input is a valid image.

    A valid image can be:
    - A `PIL.Image.Image`.
    - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).

    Args:
        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
            The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.

    Returns:
        `bool`:
            `True` if the input is a valid image, `False` otherwise.
    """
    return isinstance(image, PIL.Image.Image) or (
        isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
    )


def is_valid_image_imagelist(images):
    r"""
    Checks if the input is a valid image or list of images.

    The input can be one of the following formats:
    - A 4D tensor or numpy array (batch of images).
    - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
      `torch.Tensor`.
    - A list of valid images.

    Args:
        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
            The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
            images.

    Returns:
        `bool`:
            `True` if the input is valid, `False` otherwise.
    """
    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
        return True
    elif is_valid_image(images):
        return True
    elif isinstance(images, list):
        return all(is_valid_image(image) for image in images)
    return False


def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    r"""
    Normalize an image array to [-1, 1].

    Args:
        images (`np.ndarray` or `torch.Tensor`):
            The image array to normalize.

    Returns:
        `np.ndarray` or `torch.Tensor`:
            The normalized image array.
    """
    return 2.0 * images - 1.0


def preprocess(
    image: PipelineImageInput,
    height: Optional[int] = None,
    width: Optional[int] = None,
    resize_mode: str = "default",
    crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> torch.Tensor:
    """
    Preprocess the image input.

    Args:
        image (`PipelineImageInput`):
            The image input; accepted formats are PIL images, NumPy arrays, and PyTorch tensors, as well as lists of
            the supported formats.
        height (`int`, *optional*):
            The target height of the preprocessed image.
        width (`int`, *optional*):
            The target width of the preprocessed image.
        resize_mode (`str`, *optional*, defaults to `default`):
            The resize mode, one of `default`, `fill`, or `crop`. If `default`, resizes the image to fit within the
            specified width and height, possibly without maintaining the original aspect ratio. If `fill`, resizes the
            image to fit within the specified dimensions while maintaining the aspect ratio, centers it, and fills the
            empty space with data from the image. If `crop`, resizes the image to fit within the specified dimensions
            while maintaining the aspect ratio, centers it, and crops the excess. Note that `fill` and `crop` are only
            supported for PIL image input.
        crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
            The crop coordinates applied to every image in the batch. If `None`, the images are not cropped.

    Returns:
        `torch.Tensor`:
            The preprocessed image.
    """
    supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

    if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
        warnings.warn(
            "Passing `image` as a list of 4d np.ndarray is deprecated. "
            "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray.",
            FutureWarning,
        )
        image = np.concatenate(image, axis=0)
    if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
        warnings.warn(
            "Passing `image` as a list of 4d torch.Tensor is deprecated. "
            "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor.",
            FutureWarning,
        )
        image = torch.cat(image, dim=0)

    if not is_valid_image_imagelist(image):
        raise ValueError(
            f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
        )
    if not isinstance(image, list):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        if crops_coords is not None:
            image = [i.crop(crops_coords) for i in image]
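        # Note: the resize step of the full diffusers VaeImageProcessor is
        # omitted in this standalone script; inputs are resized to 512x512
        # beforehand via `resize_and_rgb`, so `height`, `width`, and
        # `resize_mode` are unused here.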

        image = pil_to_numpy(image)
        image = numpy_to_pt(image)

    elif isinstance(image[0], np.ndarray):
        image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
        image = numpy_to_pt(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0) if image[0].ndim == 4 else torch.stack(image, dim=0)

    channel = image.shape[1]

    do_normalize = True
    if do_normalize and image.min() < 0:
        warnings.warn(
            "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
            f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
            FutureWarning,
        )
        do_normalize = False
    if do_normalize:
        image = normalize(image)

    return image


def get_args():
    parser = argparse.ArgumentParser(
        prog="StableDiffusion",
        description="Generate a picture from the input prompt"
    )
    parser.add_argument("--prompt", type=str, required=False, default="Astronauts in a jungle, cold color palette, muted colors, detailed, 8k", help="the input text prompt")
    parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files")
    parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.onnx", help="Path to unet ONNX model")
    parser.add_argument("--vae_encoder_model", type=str, required=False, default="./models/vae_encoder.onnx", help="Path to vae encoder ONNX model")
    parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.onnx", help="Path to vae decoder ONNX model")
    parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_img2img.npy", help="Path to time input file")
    parser.add_argument("--init_image", type=str, required=False, default="./models/img2img-init.png", help="Path to initial image file")
    parser.add_argument("--save_dir", type=str, required=False, default="./img2img_output_onnx.png", help="Path to the output image file")
    return parser.parse_args()


def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
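    r"""
    Expand multi-vector textual-inversion tokens in `prompt` (e.g. "<token>"
    becomes "<token> <token>_1 <token>_2 ...") so that every learned embedding
    vector is picked up by the tokenizer. Adapted from diffusers'
    `TextualInversionLoaderMixin.maybe_convert_prompt`.
    """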
    if not isinstance(prompt, list):
        prompts = [prompt]
    else:
        prompts = prompt

    prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts]

    if not isinstance(prompt, list):
        return prompts[0]

    return prompts


def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"):
    tokens = tokenizer.tokenize(prompt)
    unique_tokens = set(tokens)
    for token in unique_tokens:
        if token in tokenizer.added_tokens_encoder:
            replacement = token
            i = 1
            while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                replacement += f" {token}_{i}"
                i += 1

            prompt = prompt.replace(token, replacement)

    return prompt


def get_embeds(prompt="Portrait of a pretty girl", tokenizer_dir="./models/tokenizer", text_encoder_dir="./models/text_encoder"):
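    r"""
    Tokenize `prompt` with the CLIP tokenizer (padded/truncated to 77 tokens)
    and run the ONNX CLIP text encoder, returning the prompt embeddings as a
    NumPy array.
    """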
    tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    text_encoder = onnxruntime.InferenceSession(
        os.path.join(text_encoder_dir, "sd15_text_encoder_sim.onnx"),
        providers=["CPUExecutionProvider"],
    )
    text_encoder_onnx_out = text_encoder.run(None, {"input_ids": text_input_ids.to("cpu").numpy()})[0]

    prompt_embeds_npy = text_encoder_onnx_out
    return prompt_embeds_npy


def get_alphas_cumprod():
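    r"""
    Precompute the cumulative products of alphas for the SD v1.5
    "scaled_linear" schedule, the final (first-step) alpha product, and the
    reversed training timesteps, i.e. the quantities an LCM-style scheduler
    keeps internally.
    """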
    betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy()
    final_alphas_cumprod = alphas_cumprod[0]
    self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64)
    return alphas_cumprod, final_alphas_cumprod, self_timesteps


def resize_and_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
    """
    Resize the image to 512x512 and convert it to RGB.
    """
    return image.resize((512, 512)).convert("RGB")


if __name__ == '__main__':
    """
    Usage:
    - python3 run_img2img_onnx_infer.py --prompt "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" --unet_model output_onnx/unet_sim.onnx --vae_encoder_model output_onnx/vae_encoder_sim.onnx --vae_decoder_model output_onnx/vae_decoder_sim.onnx --time_input ./output_onnx/time_input.npy --save_dir ./img2img_output.png
    """
    args = get_args()
    prompt = args.prompt
    tokenizer_dir = os.path.join(args.text_model_dir, 'tokenizer')
    text_encoder_dir = os.path.join(args.text_model_dir, 'text_encoder')
    unet_model = args.unet_model
    vae_decoder_model = args.vae_decoder_model
    vae_encoder_model = args.vae_encoder_model
    init_image = args.init_image
    time_input = args.time_input
    save_dir = args.save_dir

    print(f"prompt: {prompt}")
    print(f"text_tokenizer: {tokenizer_dir}")
    print(f"text_encoder: {text_encoder_dir}")
    print(f"unet_model: {unet_model}")
    print(f"vae_encoder_model: {vae_encoder_model}")
    print(f"vae_decoder_model: {vae_decoder_model}")
    print(f"init image: {init_image}")
    print(f"time_input: {time_input}")
    print(f"save_dir: {save_dir}")

    start = time.time()
    prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir)
    print(f"text encoder takes {(1000 * (time.time() - start)):.1f}ms")

    prompt_name = prompt.replace(" ", "_")
    latents_shape = [1, 4, 64, 64]

    alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod()

    start = time.time()
    vae_encoder = onnxruntime.InferenceSession(vae_encoder_model, providers=["CPUExecutionProvider"])
    unet_session_main = onnxruntime.InferenceSession(unet_model, providers=["CPUExecutionProvider"])
    vae_decoder = onnxruntime.InferenceSession(vae_decoder_model, providers=["CPUExecutionProvider"])
    print(f"load models takes {(1000 * (time.time() - start)):.1f}ms")
    time_input = np.load(time_input)

    url = init_image
    init_image = load_image(url, convert_method=resize_and_rgb)
    init_image_show = init_image

    vae_start = time.time()

    init_image = preprocess(init_image)
    if isinstance(init_image, torch.Tensor):
        init_image = init_image.detach().numpy()

    vae_encoder_onnx_inp_name = vae_encoder.get_inputs()[0].name
    vae_encoder_onnx_out_name = vae_encoder.get_outputs()[0].name

    vae_encoder_out = vae_encoder.run(None, {vae_encoder_onnx_inp_name: init_image})[0]
    print(f"vae encoder inference takes {(1000 * (time.time() - vae_start)):.1f}ms")

    device = torch.device("cpu")
    vae_encoder_out = torch.from_numpy(vae_encoder_out).to(torch.float32)
    posterior = DiagonalGaussianDistribution(vae_encoder_out)
    vae_encode_info = AutoencoderKLOutput(latent_dist=posterior)
    generator = torch.manual_seed(0)
    init_latents = retrieve_latents(vae_encode_info, generator=generator)
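    # 0.18215 is the SD v1.x VAE latent scaling factor (vae.config.scaling_factor).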
    init_latents = init_latents * 0.18215
    init_latents = torch.cat([init_latents], dim=0)
    shape = init_latents.shape
    dtype = torch.float16
    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
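
    # Noise the clean latents to t=499: with the 4-step schedule
    # [999, 759, 499, 259] used below, starting at index 2 corresponds to
    # img2img strength 0.5 (the first two denoising steps are skipped).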
    timestep = torch.tensor([499]).to(device)
    init_latents = add_noise(init_latents.to(device), noise, timestep)
    latents = init_latents

    latents = latents.detach().cpu().numpy()
    latent = latents

    unet_loop_start = time.time()
    timesteps = np.array([499, 259]).astype(np.int64)
    self_timesteps = np.array([999, 759, 499, 259]).astype(np.int64)
    step_index = [2, 3]
    for i, timestep in enumerate(timesteps):
        unet_start = time.time()
        noise_pred = unet_session_main.run(None, {
            "sample": latent,
            "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0),
            "encoder_hidden_states": prompt_embeds_npy,
        })[0]

        print(f"unet once takes {(1000 * (time.time() - unet_start)):.1f}ms")

        sample = latent
        model_output = noise_pred

        prev_step_index = step_index[i] + 1
        if prev_step_index < len(self_timesteps):
            prev_timestep = self_timesteps[prev_step_index]
        else:
            prev_timestep = timestep

        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
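
        # LCM consistency-model step (cf. diffusers' LCMScheduler.step with
        # sigma_data=0.5 and timestep_scaling=10):
        #   c_skip = sigma_data^2 / ((s*t)^2 + sigma_data^2)
        #   c_out  = (s*t) / sqrt((s*t)^2 + sigma_data^2)
        #   denoised = c_out * x0_pred + c_skip * sample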
        scaled_timestep = timestep * 10
        c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
        c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
        predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)

        denoised = c_out * predicted_original_sample + c_skip * sample
        if step_index[i] != 3:
            device = torch.device("cpu")
            noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=torch.float16).numpy()
            prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
        else:
            prev_sample = denoised

        # Cast back to float32 for the next ONNX call (the scalar math above
        # promotes the latents to float64).
        latent = prev_sample.astype(np.float32)

    print(f"unet loop takes {(1000 * (time.time() - unet_loop_start)):.1f}ms")

    vae_start = time.time()
    latent = latent / 0.18215
    image = vae_decoder.run(None, {"x": latent})[0]
    print(f"vae decoder inference takes {(1000 * (time.time() - vae_start)):.1f}ms")

    save_start = time.time()
    image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
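    # Map the decoder output from [-1, 1] back to [0, 1], then to uint8 RGB.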
    image_denorm = np.clip(image / 2 + 0.5, 0, 1)
    image = (image_denorm * 255).round().astype("uint8")
    pil_image = Image.fromarray(image[:, :, :3])
    pil_image.save(save_dir)

    grid_img = make_image_grid([init_image_show, pil_image], rows=1, cols=2)
    grid_img.save("./lcm_lora_sdv1-5_imgGrid_output.png")

    print("grid image saved in ./lcm_lora_sdv1-5_imgGrid_output.png")
    print(f"save image takes {(1000 * (time.time() - save_start)):.1f}ms")