#####
# Modified from https://github.com/huggingface/diffusers/blob/v0.29.1/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
# PhotoMaker v2 @ TencentARC and MCG-NKU
# Author: Zhen Li
#####
# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    USE_PEFT_BACKEND,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.pipelines import StableDiffusionXLAdapterPipeline
from diffusers.utils import _get_model_file
from safetensors import safe_open
from huggingface_hub.utils import validate_hf_hub_args

from module.model_v2 import PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
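

# Illustrative note (not part of the original file): with `guidance_rescale=0.7`, the CFG
# output is first rescaled so its per-sample std matches the text-conditioned prediction,
# then blended 70/30 with the unrescaled CFG result, e.g. a hypothetical call such as
#
#   noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)
#
# counteracts the overexposure that plain classifier-free guidance can introduce.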

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
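

# Illustrative usage sketch (assumptions: `pipe` is an already-loaded pipeline; values are arbitrary):
#
#   # default spacing, 30 steps
#   timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=30, device="cuda")
#
#   # custom sigma schedule instead (only schedulers whose `set_timesteps` accepts `sigmas` support this)
#   timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, device="cuda", sigmas=[14.6, 7.0, 3.0, 1.0, 0.0])
#
# Passing both `timesteps` and `sigmas` raises a ValueError, as implemented above.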

def _preprocess_adapter_image(image, height, width):
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
        image = [
            i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
        ]  # expand [h, w] or [h, w, c] to [b, h, w, c]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        if image[0].ndim == 3:
            image = torch.stack(image, dim=0)
        elif image[0].ndim == 4:
            image = torch.cat(image, dim=0)
        else:
            raise ValueError(
                f"Invalid image tensor! Expected an image tensor with 3 or 4 dimensions, but received: {image[0].ndim}"
            )
    return image
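

# Illustrative sketch (hypothetical inputs, not part of the original file):
#
#   sketch = PIL.Image.open("sketch.png").convert("RGB")             # hypothetical path
#   cond = _preprocess_adapter_image(sketch, 1024, 1024)             # -> float tensor [1, 3, 1024, 1024] in [0, 1]
#   conds = _preprocess_adapter_image([sketch, sketch], 1024, 1024)  # -> [2, 3, 1024, 1024]
#
# Tensors are returned unchanged, so tensor inputs are expected to already be normalized
# and in channels-first layout.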


class PhotoMakerStableDiffusionXLAdapterPipeline(StableDiffusionXLAdapterPipeline):
    @validate_hf_hub_args
    def load_photomaker_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: str = '',
        trigger_word: str = 'img',
        pm_version: str = 'v2',
        **kwargs,
    ):
        """
        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            weight_name (`str`):
                The name of the weight file, not the path to it.
            subfolder (`str`, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            trigger_word (`str`, *optional*, defaults to `"img"`):
                The trigger word is used to identify the position of the class word in the text prompt,
                and it is recommended not to set it to a common word.
                This trigger word must be placed after the class word when used; otherwise, it will degrade the
                quality of the personalized generation.
            pm_version (`str`, *optional*, defaults to `"v2"`):
                The PhotoMaker version to load, either `"v1"` or `"v2"`.
        """
        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                # resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                state_dict = {"id_encoder": {}, "lora_weights": {}}
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key.startswith("id_encoder."):
                            state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key)
                        elif key.startswith("lora_weights."):
                            state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key)
            else:
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        keys = list(state_dict.keys())
        if keys != ["id_encoder", "lora_weights"]:
            raise ValueError("Required keys (`id_encoder` and `lora_weights`) are missing from the state dict.")

        self.num_tokens = 2
        self.trigger_word = trigger_word
        # load the finetuned CLIP image encoder and fuse module here if they have not been registered to the pipeline yet
        print(f"Loading PhotoMaker {pm_version} components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...")
        self.id_image_processor = CLIPImageProcessor()
        if pm_version == "v1":  # PhotoMaker v1
            # NOTE: the v1 encoder class (`PhotoMakerIDEncoder`) is not imported in this file; it must be made
            # available (e.g. from the PhotoMaker v1 module) before `pm_version="v1"` can be used.
            id_encoder = PhotoMakerIDEncoder()
        elif pm_version == "v2":  # PhotoMaker v2
            id_encoder = PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken()
        else:
            raise NotImplementedError(f"The PhotoMaker version [{pm_version}] is not supported.")

        id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
        id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
        self.id_encoder = id_encoder

        # load the LoRA weights into the models
        print(f"Loading PhotoMaker {pm_version} components [2] lora_weights from [{pretrained_model_name_or_path_or_dict}]")
        self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")

        # Add the trigger word token
        if self.tokenizer is not None:
            self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)

        self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)
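
    # Illustrative usage sketch (repo ids, filenames, and devices are examples only; assumes `import os`
    # and a suitable T2I-Adapter checkpoint):
    #
    #   from huggingface_hub import hf_hub_download
    #
    #   adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16)
    #   pipe = PhotoMakerStableDiffusionXLAdapterPipeline.from_pretrained(
    #       "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
    #   ).to("cuda")
    #   photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker-V2", filename="photomaker-v2.bin")
    #   pipe.load_photomaker_adapter(
    #       os.path.dirname(photomaker_ckpt),
    #       weight_name=os.path.basename(photomaker_ckpt),
    #       trigger_word="img",
    #       pm_version="v2",
    #   )
    #   pipe.fuse_lora()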

    def encode_prompt_with_trigger_word(
        self,
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        ### Added args
        num_id_images: int = 1,
        class_tokens_mask: Optional[torch.LongTensor] = None,
    ):
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

            if self.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Find the token id of the trigger word
        image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
        text_encoders = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    prompt = self.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    print(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                clean_index = 0
                clean_input_ids = []
                class_token_index = []
                # Find out the corresponding class word token based on the newly added trigger word token
                for i, token_id in enumerate(text_input_ids.tolist()[0]):
                    if token_id == image_token_id:
                        class_token_index.append(clean_index - 1)
                    else:
                        clean_input_ids.append(token_id)
                        clean_index += 1

                if len(class_token_index) != 1:
                    raise ValueError(
                        f"PhotoMaker currently does not support multiple trigger words in a single prompt. "
                        f"Trigger word: {self.trigger_word}, Prompt: {prompt}."
                    )
                class_token_index = class_token_index[0]

                # Expand the class word token and the corresponding mask
                class_token = clean_input_ids[class_token_index]
                clean_input_ids = (
                    clean_input_ids[:class_token_index]
                    + [class_token] * num_id_images * self.num_tokens
                    + clean_input_ids[class_token_index + 1 :]
                )

                # Truncation or padding
                max_len = tokenizer.model_max_length
                if len(clean_input_ids) > max_len:
                    clean_input_ids = clean_input_ids[:max_len]
                else:
                    clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
                        max_len - len(clean_input_ids)
                    )

                class_tokens_mask = [
                    class_token_index <= i < class_token_index + (num_id_images * self.num_tokens)
                    for i in range(len(clean_input_ids))
                ]

                clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
                class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)

                prompt_embeds = text_encoder(clean_input_ids.to(device), output_hidden_states=True)

                # We are only interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        class_tokens_mask = class_tokens_mask.to(device=device)  # TODO: ignoring two-prompt case

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are only interested in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if self.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, class_tokens_mask
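
    # Illustrative sketch of the trigger-word expansion performed by `encode_prompt_with_trigger_word`
    # above (prompt and counts are hypothetical): for the prompt "a photo of a man img, wearing a hat"
    # with trigger_word="img" and num_id_images=2, the "img" token is dropped and the class token for
    # "man" is repeated num_id_images * self.num_tokens = 4 times, so the cleaned sequence is
    # effectively "a photo of a man man man man, wearing a hat", and `class_tokens_mask` is True at
    # exactly those four positions. The ID encoder later fuses the stacked ID embedding into the
    # masked positions (see the `self.id_encoder(...)` call in `__call__`).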

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        adapter_conditioning_scale: Union[float, List[float]] = 1.0,
        adapter_conditioning_factor: float = 1.0,
        clip_skip: Optional[int] = None,
        # Added parameters (for PhotoMaker)
        input_id_images: PipelineImageInput = None,
        start_merge_step: int = 10,  # TODO: change to `style_strength_ratio` in the future
        class_tokens_mask: Optional[torch.LongTensor] = None,
        id_embeds: Optional[torch.FloatTensor] = None,
        prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Only the parameters introduced by PhotoMaker are documented here. For explanations of the remaining
        parameters, please refer to the underlying `StableDiffusionXLAdapterPipeline`:
        https://github.com/huggingface/diffusers/blob/v0.29.1/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py

        Args:
            input_id_images (`PipelineImageInput`, *optional*):
                Input ID image(s) to work with PhotoMaker.
            start_merge_step (`int`, *optional*, defaults to 10):
                The denoising step up to which only the text-only prompt embedding is used; after it, the
                ID-fused prompt embedding takes over.
            class_tokens_mask (`torch.LongTensor`, *optional*):
                Pre-generated class token mask. When the `prompt_embeds` parameter is provided in advance, it is
                necessary to prepare `class_tokens_mask` beforehand to mark out the positions of the class word.
            id_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed face ID embeddings (e.g. from InsightFace) used by the PhotoMaker v2 ID encoder.
            prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from the `prompt` input argument.

        Returns:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
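        # Illustrative usage sketch (assumptions: `pipe` was built and loaded as in the
        # `load_photomaker_adapter` example above; `load_image` is from `diffusers.utils`;
        # file paths and prompts are placeholders; the trigger word follows the class word):
        #
        #   id_images = [load_image("./person_1.jpg"), load_image("./person_2.jpg")]
        #   images = pipe(
        #       prompt="portrait photo of a man img, wearing sunglasses",
        #       input_id_images=id_images,
        #       negative_prompt="lowres, blurry",
        #       num_inference_steps=50,
        #       start_merge_step=10,
        #       guidance_scale=5.0,
        #   ).images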
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, image)
        device = self._execution_device

        use_adapter = image is not None
        print(f"Use adapter: {use_adapter} | output size: {(height, width)}")

        if use_adapter:
            if isinstance(self.adapter, MultiAdapter):
                adapter_input = []
                for one_image in image:
                    one_image = _preprocess_adapter_image(one_image, height, width)
                    one_image = one_image.to(device=device, dtype=self.adapter.dtype)
                    adapter_input.append(one_image)
            else:
                adapter_input = _preprocess_adapter_image(image, height, width)
                adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip

        if prompt_embeds is not None and class_tokens_mask is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `class_tokens_mask` also has to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
            )
        # check the input id images
        if input_id_images is None:
            raise ValueError(
                "Provide `input_id_images`. Cannot leave `input_id_images` undefined for the PhotoMaker pipeline."
            )
        if not isinstance(input_id_images, list):
            input_id_images = [input_id_images]

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )

        num_id_images = len(input_id_images)
        (
            prompt_embeds,
            _,
            pooled_prompt_embeds,
            _,
            class_tokens_mask,
        ) = self.encode_prompt_with_trigger_word(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_id_images=num_id_images,
            class_tokens_mask=class_tokens_mask,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self._clip_skip,
        )

        # 4. Encode input prompt without the trigger word for delayed conditioning
        # encode, remove the trigger word token, then decode
        tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False)
        trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word)
        tokens_text_only.remove(trigger_word_token)
        prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False)
        (
            prompt_embeds_text_only,
            negative_prompt_embeds,
            pooled_prompt_embeds_text_only,  # TODO: replace the pooled_prompt_embeds with text only prompt
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt_text_only,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds_text_only,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds_text_only,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self._clip_skip,
        )

        # 5. Prepare the input ID images
        dtype = next(self.id_encoder.parameters()).dtype
        if not isinstance(input_id_images[0], torch.Tensor):
            id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values

        id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype)  # TODO: multiple prompts

        # 6. Get the updated text embedding with the stacked ID embedding
        if id_embeds is not None:
            id_embeds = id_embeds.unsqueeze(0).to(device=device, dtype=dtype)
            prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds)
        else:
            prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # 6.1 Get the ip adapter embedding
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 7. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )

        # 8. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9.5 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 10. Prepare the T2I-Adapter features and the added time ids & embeddings
        if use_adapter:
            if isinstance(self.adapter, MultiAdapter):
                adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
                for k, v in enumerate(adapter_state):
                    adapter_state[k] = v
            else:
                adapter_state = self.adapter(adapter_input)
                for k, v in enumerate(adapter_state):
                    adapter_state[k] = v * adapter_conditioning_scale
            if num_images_per_prompt > 1:
                for k, v in enumerate(adapter_state):
                    adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
            if self.do_classifier_free_guidance:
                for k, v in enumerate(adapter_state):
                    adapter_state[k] = torch.cat([v] * 2, dim=0)

        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 11. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # Apply denoising_end
        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]
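
        # Illustrative arithmetic (not part of the original file): with num_train_timesteps=1000
        # and denoising_end=0.8, the cutoff is round(1000 - 0.8 * 1000) = 200, so only timesteps
        # >= 200 are kept and the final 20% of denoising is skipped (useful when handing the
        # latents off to a refiner model).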

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if i <= start_merge_step:
                    current_prompt_embeds = (
                        torch.cat([negative_prompt_embeds, prompt_embeds_text_only], dim=0)
                        if self.do_classifier_free_guidance
                        else prompt_embeds_text_only
                    )
                    add_text_embeds = (
                        torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
                        if self.do_classifier_free_guidance
                        else pooled_prompt_embeds_text_only
                    )
                else:
                    current_prompt_embeds = (
                        torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                        if self.do_classifier_free_guidance
                        else prompt_embeds
                    )
                    add_text_embeds = (
                        torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
                        if self.do_classifier_free_guidance
                        else pooled_prompt_embeds
                    )

                if i < int(num_inference_steps * adapter_conditioning_factor) and use_adapter:
                    down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                else:
                    down_intrablock_additional_residuals = None

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=current_prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents
            return StableDiffusionXLPipelineOutput(images=image)

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)