from typing import List, Union import numpy as np import onnxruntime import torch from PIL import Image from transformers import CLIPTokenizer, CLIPTextModel, PreTrainedTokenizer, CLIPTextModelWithProjection # import axengine as axe import time import os import argparse from diffusers.utils import load_image import PIL.Image from typing import List, Optional, Tuple, Union from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.utils.torch_utils import randn_tensor from diffusers.utils import make_image_grid, load_image ########## Img2Img PipelineImageInput = Union[ PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor], ] PipelineDepthInput = PipelineImageInput # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls # self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) # Convert betas to alphas_bar_sqrt beta_start = 0.00085 beta_end = 0.012 num_train_timesteps = 1000 betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = alphas_cumprod.to(device=original_samples.device) alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") def numpy_to_pt(images: np.ndarray) -> torch.Tensor: r""" Convert a NumPy image to a PyTorch tensor. Args: images (`np.ndarray`): The NumPy image array to convert to PyTorch format. Returns: `torch.Tensor`: A PyTorch tensor representation of the images. """ if images.ndim == 3: images = images[..., None] images = torch.from_numpy(images.transpose(0, 3, 1, 2)) return images def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: r""" Convert a PIL image or a list of PIL images to NumPy arrays. Args: images (`PIL.Image.Image` or `List[PIL.Image.Image]`): The PIL image or list of images to convert to NumPy format. Returns: `np.ndarray`: A NumPy array representation of the images. """ if not isinstance(images, list): images = [images] images = [np.array(image).astype(np.float32) / 255.0 for image in images] images = np.stack(images, axis=0) return images def is_valid_image(image) -> bool: r""" Checks if the input is a valid image. A valid image can be: - A `PIL.Image.Image`. - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image). Args: image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): The image to validate. It can be a PIL image, a NumPy array, or a torch tensor. Returns: `bool`: `True` if the input is a valid image, `False` otherwise. """ return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) def is_valid_image_imagelist(images): r""" Checks if the input is a valid image or list of images. The input can be one of the following formats: - A 4D tensor or numpy array (batch of images). - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or `torch.Tensor`. - A list of valid images. Args: images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`): The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid images. Returns: `bool`: `True` if the input is valid, `False` otherwise. """ if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: return True elif is_valid_image(images): return True elif isinstance(images, list): return all(is_valid_image(image) for image in images) return False def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: r""" Normalize an image array to [-1,1]. Args: images (`np.ndarray` or `torch.Tensor`): The image array to normalize. Returns: `np.ndarray` or `torch.Tensor`: The normalized image array. """ return 2.0 * images - 1.0 # Copy from: /home/baiyongqiang/miniforge-pypy3/envs/hf/lib/python3.9/site-packages/diffusers/image_processor.py#607 def preprocess( image: PipelineImageInput, height: Optional[int] = None, width: Optional[int] = None, resize_mode: str = "default", # "default", "fill", "crop" crops_coords: Optional[Tuple[int, int, int, int]] = None, ) -> torch.Tensor: """ Preprocess the image input. Args: image (`PipelineImageInput`): The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats. height (`int`, *optional*): The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height. width (`int`, *optional*): The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only supported for PIL image input. crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. Returns: `torch.Tensor`: The preprocessed image. """ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) # # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image # if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: # if isinstance(image, torch.Tensor): # # if image is a pytorch tensor could have 2 possible shapes: # # 1. batch x height x width: we should insert the channel dimension at position 1 # # 2. channel x height x width: we should insert batch dimension at position 0, # # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 # # for simplicity, we insert a dimension of size 1 at position 1 for both cases # image = image.unsqueeze(1) # else: # # if it is a numpy array, it could have 2 possible shapes: # # 1. batch x height x width: insert channel dimension on last position # # 2. height x width x channel: insert batch dimension on first position # if image.shape[-1] == 1: # image = np.expand_dims(image, axis=0) # else: # image = np.expand_dims(image, axis=-1) if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: warnings.warn( "Passing `image` as a list of 4d np.ndarray is deprecated." "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", FutureWarning, ) image = np.concatenate(image, axis=0) if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: warnings.warn( "Passing `image` as a list of 4d torch.Tensor is deprecated." "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", FutureWarning, ) image = torch.cat(image, axis=0) if not is_valid_image_imagelist(image): raise ValueError( f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}" ) if not isinstance(image, list): image = [image] if isinstance(image[0], PIL.Image.Image): if crops_coords is not None: image = [i.crop(crops_coords) for i in image] # if self.config.do_resize: # height, width = self.get_default_height_width(image[0], height, width) # image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image] # if self.config.do_convert_rgb: # image = [self.convert_to_rgb(i) for i in image] # elif self.config.do_convert_grayscale: # image = [self.convert_to_grayscale(i) for i in image] image = pil_to_numpy(image) # to np image = numpy_to_pt(image) # to pt elif isinstance(image[0], np.ndarray): image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) # image = self.numpy_to_pt(image) # height, width = self.get_default_height_width(image, height, width) # if self.config.do_resize: # image = self.resize(image, height, width) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) # if self.config.do_convert_grayscale and image.ndim == 3: # image = image.unsqueeze(1) channel = image.shape[1] # don't need any preprocess if the image is latents # if channel == self.config.vae_latent_channels: # return image # height, width = self.get_default_height_width(image, height, width) # if self.config.do_resize: # image = self.resize(image, height, width) # expected range [0,1], normalize to [-1,1] do_normalize = True # self.config.do_normalize if do_normalize and image.min() < 0: warnings.warn( "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", FutureWarning, ) do_normalize = False if do_normalize: image = normalize(image) # if self.config.do_binarize: # image = self.binarize(image) return image ########## def get_args(): parser = argparse.ArgumentParser( prog="StableDiffusion", description="Generate picture with the input prompt" ) parser.add_argument("--prompt", type=str, required=False, default="Astronauts in a jungle, cold color palette, muted colors, detailed, 8k", help="the input text prompt") parser.add_argument("--text_model_dir", type=str, required=False, default="./models/", help="Path to text encoder and tokenizer files") parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.onnx", help="Path to unet ONNX model") parser.add_argument("--vae_encoder_model", type=str, required=False, default="./models/vae_encoder.onnx", help="Path to vae encoder ONNX model") parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.onnx", help="Path to vae decoder ONNX model") parser.add_argument("--time_input", type=str, required=False, default="./models/time_input_img2img.npy", help="Path to time input file") parser.add_argument("--init_image", type=str, required=False, default="./models/img2img-init.png", help="Path to initial image file") parser.add_argument("--save_dir", type=str, required=False, default="./img2img_output_onnx.png", help="Path to the output image file") return parser.parse_args() def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821 if not isinstance(prompt, List): prompts = [prompt] else: prompts = prompt prompts = [_maybe_convert_prompt(p, tokenizer) for p in prompts] if not isinstance(prompt, List): return prompts[0] return prompts def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821 tokens = tokenizer.tokenize(prompt) unique_tokens = set(tokens) for token in unique_tokens: if token in tokenizer.added_tokens_encoder: replacement = token i = 1 while f"{token}_{i}" in tokenizer.added_tokens_encoder: replacement += f" {token}_{i}" i += 1 prompt = prompt.replace(token, replacement) return prompt def get_embeds(prompt = "Portrait of a pretty girl", tokenizer_dir = "./models/tokenizer", text_encoder_dir = "./models/text_encoder"): tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir) text_inputs = tokenizer( prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_encoder = onnxruntime.InferenceSession( os.path.join( text_encoder_dir, "sd15_text_encoder_sim.onnx" ), providers=["CPUExecutionProvider"] ) text_encoder_onnx_out = text_encoder.run(None, {"input_ids": text_input_ids.to("cpu").numpy()})[0] prompt_embeds_npy = text_encoder_onnx_out return prompt_embeds_npy def get_alphas_cumprod(): betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32) ** 2 alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0).detach().numpy() final_alphas_cumprod = alphas_cumprod[0] self_timesteps = np.arange(0, 1000)[::-1].copy().astype(np.int64) return alphas_cumprod, final_alphas_cumprod, self_timesteps def resize_and_rgb(image: PIL.Image.Image) -> PIL.Image.Image: """ Resize the image to 512x512 and convert it to RGB. """ return image.resize((512, 512)).convert("RGB") if __name__ == '__main__': """ Usage: - python3 run_img2img_onnx_infer.py --prompt "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" --unet_model output_onnx/unet_sim.onnx --vae_encoder_model output_onnx/vae_encoder_sim.onnx --vae_decoder_model output_onnx/vae_decoder_sim.onnx --time_input ./output_onnx/time_input.npy --save_dir ./img2img_output.png """ args = get_args() prompt = args.prompt tokenizer_dir = args.text_model_dir + 'tokenizer' text_encoder_dir = args.text_model_dir + 'text_encoder' unet_model = args.unet_model vae_decoder_model = args.vae_decoder_model vae_encoder_model = args.vae_encoder_model init_image = args.init_image time_input = args.time_input save_dir = args.save_dir print(f"prompt: {prompt}") print(f"text_tokenizer: {tokenizer_dir}") print(f"text_encoder: {text_encoder_dir}") print(f"unet_model: {unet_model}") print(f"vae_encoder_model: {vae_encoder_model}") print(f"vae_decoder_model: {vae_decoder_model}") print(f"init image: {init_image}") print(f"time_input: {time_input}") print(f"save_dir: {save_dir}") # timesteps = np.array([999, 759, 499, 259]).astype(np.int64) # text encoder start = time.time() # prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" # prompt = "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k" # prompt = "Caricature, a beautiful girl with black hair, 8k" prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_dir) print(f"text encoder take {(1000 * (time.time() - start)):.1f}ms") prompt_name = prompt.replace(" ", "_") latents_shape = [1, 4, 64, 64] # latent = torch.randn(latents_shape, generator=None, device="cpu", dtype=torch.float32, # layout=torch.strided).detach().numpy() alphas_cumprod, final_alphas_cumprod, self_timesteps = get_alphas_cumprod() # load unet model and vae model start = time.time() vae_encoder = onnxruntime.InferenceSession(vae_encoder_model, providers=["CPUExecutionProvider"]) unet_session_main = onnxruntime.InferenceSession(unet_model, providers=["CPUExecutionProvider"]) vae_decoder = onnxruntime.InferenceSession(vae_decoder_model, providers=["CPUExecutionProvider"]) print(f"load models take {(1000 * (time.time() - start)):.1f}ms") # load time input file time_input = np.load(time_input) # load image # url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" url = init_image init_image = load_image(url, convert_method=resize_and_rgb) # U8, (512, 512, 3), RGB init_image_show = init_image # vae encoder inference vae_start = time.time() init_image = preprocess(init_image) # torch.Size([1, 3, 512, 512]) if isinstance(init_image, torch.Tensor): init_image = init_image.detach().numpy() vae_encoder_onnx_inp_name = vae_encoder.get_inputs()[0].name vae_encoder_onnx_out_name = vae_encoder.get_outputs()[0].name # vae_encoder_out.shape (1, 8, 64, 64) vae_encoder_out = vae_encoder.run(None, {vae_encoder_onnx_inp_name: init_image})[0] # encoder out: torch.Size([1, 8, 64, 64]) print(f"vae encoder inference take {(1000 * (time.time() - vae_start)):.1f}ms") # vae encoder inference device = torch.device("cpu") vae_encoder_out = torch.from_numpy(vae_encoder_out).to(torch.float32) posterior = DiagonalGaussianDistribution(vae_encoder_out) # 数值基本对的上 vae_encode_info = AutoencoderKLOutput(latent_dist=posterior) generator = torch.manual_seed(0) init_latents = retrieve_latents(vae_encode_info, generator=generator) # 数值基本对的上 init_latents = init_latents * 0.18215 # 数值基本对的上 init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape dtype = torch.float16 noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # dtype 不同, 随机值不同 # get latents timestep = torch.tensor([499]).to(device) init_latents = add_noise(init_latents.to(device), noise, timestep) latents = init_latents latents = latents.detach().cpu().numpy() latent = latents # unet inference loop unet_loop_start = time.time() timesteps = np.array([499, 259]).astype(np.int64) self_timesteps = np.array([999, 759, 499, 259]).astype(np.int64) step_index = [2, 3] for i, timestep in enumerate(timesteps): unet_start = time.time() noise_pred = unet_session_main.run(None, {"sample": latent, \ "/down_blocks.0/resnets.0/act_1/Mul_output_0": np.expand_dims(time_input[i], axis=0), \ "encoder_hidden_states": prompt_embeds_npy})[0] print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms") sample = latent model_output = noise_pred # 1. get previous step value prev_step_index = step_index[i] + 1 if prev_step_index < len(self_timesteps): prev_timestep = self_timesteps[prev_step_index] else: prev_timestep = timestep alpha_prod_t = alphas_cumprod[timestep] alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # 3. Get scalings for boundary conditions scaled_timestep = timestep * 10 c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2) c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5 predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5) # 数值基本对齐 denoised = c_out * predicted_original_sample + c_skip * sample if step_index[i] != 3: device = torch.device("cpu") noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=torch.float16).numpy() prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise else: prev_sample = denoised latent = prev_sample print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms") # vae decoder inference vae_start = time.time() latent = latent / 0.18215 image = vae_decoder.run(None, {"x": latent})[0] # ['784'] print(f"vae decoder inference take {(1000 * (time.time() - vae_start)):.1f}ms") # save result save_start = time.time() image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0) image_denorm = np.clip(image / 2 + 0.5, 0, 1) image = (image_denorm * 255).round().astype("uint8") pil_image = Image.fromarray(image[:, :, :3]) pil_image.save(save_dir) grid_img = make_image_grid([init_image_show, pil_image], rows=1, cols=2) grid_img.save(f"./lcm_lora_sdv1-5_imgGrid_output.png") print(f"grid image saved in ./lcm_lora_sdv1-5_imgGrid_output.png") print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")