import inspect from typing import Any, Callable, Dict, List, Optional, Union import math import einops import torch from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer, LlamaForCausalLM, PreTrainedTokenizerFast ) from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FromSingleFileMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from .pipeline_output import HiDreamImagePipelineOutput from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15, ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin): model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae" _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_encoder_2: CLIPTextModelWithProjection, tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5Tokenizer, text_encoder_4: LlamaForCausalLM, tokenizer_4: PreTrainedTokenizerFast, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, text_encoder_3=text_encoder_3, text_encoder_4=text_encoder_4, tokenizer=tokenizer, tokenizer_2=tokenizer_2, tokenizer_3=tokenizer_3, tokenizer_4=tokenizer_4, scheduler=scheduler, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.default_sample_size = 128 self.tokenizer_4.pad_token = self.tokenizer_4.eos_token def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 128, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder_3.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer_3( prompt, padding="max_length", max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def _get_clip_prompt_embeds( self, tokenizer, text_encoder, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, max_sequence_length: int = 128, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( prompt, padding="max_length", max_length=min(max_sequence_length, 218), truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {218} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def _get_llama3_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 128, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder_4.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer_4( prompt, padding="max_length", max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" ) outputs = self.text_encoder_4( text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True, output_attentions=True ) prompt_embeds = outputs.hidden_states[1:] prompt_embeds = torch.stack(prompt_embeds, dim=0) _, _, seq_len, dim = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) return prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], prompt_3: Union[str, List[str]], prompt_4: Union[str, List[str]], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, negative_prompt_4: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[List[torch.FloatTensor]] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 128, lora_scale: Optional[float] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] prompt_embeds, pooled_prompt_embeds = self._encode_prompt( prompt = prompt, prompt_2 = prompt_2, prompt_3 = prompt_3, prompt_4 = prompt_4, device = device, dtype = dtype, num_images_per_prompt = num_images_per_prompt, prompt_embeds = prompt_embeds, pooled_prompt_embeds = pooled_prompt_embeds, max_sequence_length = max_sequence_length, ) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt negative_prompt_3 = negative_prompt_3 or negative_prompt negative_prompt_4 = negative_prompt_4 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) negative_prompt_4 = ( batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt( prompt = negative_prompt, prompt_2 = negative_prompt_2, prompt_3 = negative_prompt_3, prompt_4 = negative_prompt_4, device = device, dtype = dtype, num_images_per_prompt = num_images_per_prompt, prompt_embeds = negative_prompt_embeds, pooled_prompt_embeds = negative_pooled_prompt_embeds, max_sequence_length = max_sequence_length, ) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds def _encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], prompt_3: Union[str, List[str]], prompt_4: Union[str, List[str]], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[List[torch.FloatTensor]] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 128, ): device = device or self._execution_device if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_4 = prompt_4 or prompt prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( self.tokenizer, self.text_encoder, prompt = prompt, num_images_per_prompt = num_images_per_prompt, max_sequence_length = max_sequence_length, device = device, dtype = dtype, ) pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( self.tokenizer_2, self.text_encoder_2, prompt = prompt_2, num_images_per_prompt = num_images_per_prompt, max_sequence_length = max_sequence_length, device = device, dtype = dtype, ) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) t5_prompt_embeds = self._get_t5_prompt_embeds( prompt = prompt_3, num_images_per_prompt = num_images_per_prompt, max_sequence_length = max_sequence_length, device = device, dtype = dtype ) llama3_prompt_embeds = self._get_llama3_prompt_embeds( prompt = prompt_4, num_images_per_prompt = num_images_per_prompt, max_sequence_length = max_sequence_length, device = device, dtype = dtype ) prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] return prompt_embeds, pooled_prompt_embeds def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. """ self.vae.enable_tiling() def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def prepare_latents( self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, ): # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) return latents @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def joint_attention_kwargs(self): return self._joint_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, prompt_3: Optional[Union[str, List[str]]] = None, prompt_4: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, sigmas: Optional[List[float]] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, negative_prompt_4: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 128, ): height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor division = self.vae_scale_factor * 2 S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 scale = S_max / (width * height) scale = math.sqrt(scale) width, height = int(width * scale // division * division), int(height * scale // division * division) self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_3=prompt_3, prompt_4=prompt_4, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, negative_prompt_3=negative_prompt_3, negative_prompt_4=negative_prompt_4, do_classifier_free_guidance=self.do_classifier_free_guidance, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) if self.do_classifier_free_guidance: prompt_embeds_arr = [] for n, p in zip(negative_prompt_embeds, prompt_embeds): if len(n.shape) == 3: prompt_embeds_arr.append(torch.cat([n, p], dim=0)) else: prompt_embeds_arr.append(torch.cat([n, p], dim=1)) prompt_embeds = prompt_embeds_arr pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, pooled_prompt_embeds.dtype, device, generator, latents, ) if latents.shape[-2] != latents.shape[-1]: B, C, H, W = latents.shape pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) img_ids = torch.zeros(pH, pW, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] img_ids = img_ids.reshape(pH * pW, -1) img_ids_pad = torch.zeros(self.transformer.max_seq, 3) img_ids_pad[:pH*pW, :] = img_ids img_sizes = img_sizes.unsqueeze(0).to(latents.device) img_ids = img_ids_pad.unsqueeze(0).to(latents.device) if self.do_classifier_free_guidance: img_sizes = img_sizes.repeat(2 * B, 1) img_ids = img_ids.repeat(2 * B, 1, 1) else: img_sizes = img_ids = None # 5. Prepare timesteps mu = calculate_shift(self.transformer.max_seq) scheduler_kwargs = {"mu": mu} if isinstance(self.scheduler, FlowUniPCMultistepScheduler): self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu)) timesteps = self.scheduler.timesteps else: timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) if latent_model_input.shape[-2] != latent_model_input.shape[-1]: B, C, H, W = latent_model_input.shape patch_size = self.transformer.config.patch_size pH, pW = H // patch_size, W // patch_size out = torch.zeros( (B, C, self.transformer.max_seq, patch_size * patch_size), dtype=latent_model_input.dtype, device=latent_model_input.device ) latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size) out[:, :, 0:pH*pW] = latent_model_input latent_model_input = out noise_pred = self.transformer( hidden_states = latent_model_input, timesteps = timestep, encoder_hidden_states = prompt_embeds, pooled_embeds = pooled_prompt_embeds, img_sizes = img_sizes, img_ids = img_ids, return_dict = False, )[0] noise_pred = -noise_pred # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if output_type == "latent": image = latents else: latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return HiDreamImagePipelineOutput(images=image)