diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c6df4d10fa29381847f91a955bea86d955cb61
--- /dev/null
+++ b/app.py
@@ -0,0 +1,238 @@
+import gradio as gr
+import torch
+import math
+from PIL import Image
+from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
+from flux.transformer_flux import FluxTransformer2DModel
+from flux.pipeline_flux_chameleon import FluxPipeline
+from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel  # assumed module path for the simplified Qwen2-VL encoder used below
+import torch.nn as nn
+
+MODEL_ID = "Djrango/Qwen2vl-Flux"
+
+class Qwen2Connector(nn.Module):
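+    """Linear projection from Qwen2-VL hidden states (3584-d) into the FLUX conditioning space (4096-d)."""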
+ def __init__(self, input_dim=3584, output_dim=4096):
+ super().__init__()
+ self.linear = nn.Linear(input_dim, output_dim)
+
+ def forward(self, x):
+ return self.linear(x)
+
+class FluxInterface:
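+    """Wraps Qwen2-VL and FLUX: the reference image is encoded by Qwen2-VL and its hidden states condition FLUX to generate variations."""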
+ def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
+ self.device = device
+ self.dtype = torch.bfloat16
+ self.models = None
+ self.MODEL_ID = "Djrango/Qwen2vl-Flux"
+
+ def load_models(self):
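+        """Lazily load all model components from the Hugging Face repository on first use."""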
+ if self.models is not None:
+ return
+
+ # Load FLUX components
+ tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder")
+ text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2")
+ tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
+
+        # Load VAE, transformer and scheduler from their flux subfolders
+        # (nested per-component layout, assumed to match the tokenizer/text_encoder paths above)
+        vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux/vae")
+        transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux/transformer")
+        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
+
+ # Load Qwen2VL components from qwen2-vl folder
+ qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl")
+
+ # Load connector and t5 embedder from qwen2-vl folder
+ connector = Qwen2Connector()
+ connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
+ connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location=self.device)
+ connector.load_state_dict(connector_state)
+
+ # Load T5 embedder
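+        # Maps T5 encoder outputs (4096-d) down to the 3072-d hidden size expected by the transformer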
+ self.t5_context_embedder = nn.Linear(4096, 3072)
+ t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
+ t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location=self.device)
+ self.t5_context_embedder.load_state_dict(t5_embedder_state)
+
+ # Move models to device and set dtype
+ models = [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]
+ for model in models:
+ model.to(self.device).to(self.dtype)
+ model.eval()
+
+ self.models = {
+ 'tokenizer': tokenizer,
+ 'text_encoder': text_encoder,
+ 'text_encoder_two': text_encoder_two,
+ 'tokenizer_two': tokenizer_two,
+ 'vae': vae,
+ 'transformer': transformer,
+ 'scheduler': scheduler,
+ 'qwen2vl': qwen2vl,
+ 'connector': connector
+ }
+
+ # Initialize processor and pipeline
+ self.qwen2vl_processor = AutoProcessor.from_pretrained(
+ self.MODEL_ID,
+ subfolder="qwen2-vl",
+ min_pixels=256*28*28,
+ max_pixels=256*28*28
+ )
+
+ self.pipeline = FluxPipeline(
+ transformer=transformer,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ def resize_image(self, img, max_pixels=1050000):
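+        """Downscale the image to at most ~1.05 megapixels and round both sides down to multiples of 8."""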
+ if not isinstance(img, Image.Image):
+ img = Image.fromarray(img)
+
+ width, height = img.size
+ num_pixels = width * height
+
+ if num_pixels > max_pixels:
+ scale = math.sqrt(max_pixels / num_pixels)
+ new_width = int(width * scale)
+ new_height = int(height * scale)
+ new_width = new_width - (new_width % 8)
+ new_height = new_height - (new_height % 8)
+ img = img.resize((new_width, new_height), Image.LANCZOS)
+
+ return img
+
+ def process_image(self, image):
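+        """Run the image through Qwen2-VL with a generic captioning prompt and project the image-token hidden states through the connector."""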
+ message = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": image},
+ {"type": "text", "text": "Describe this image."},
+ ]
+ }
+ ]
+ text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
+
+ with torch.no_grad():
+ inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
+ output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
+ image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
+ image_hidden_state = self.models['connector'](image_hidden_state)
+
+ return image_hidden_state, image_grid_thw
+
+ def compute_t5_text_embeddings(self, prompt):
+ """Compute T5 embeddings for text prompt"""
+ if prompt == "":
+ return None
+
+ text_inputs = self.models['tokenizer_two'](
+ prompt,
+ padding="max_length",
+ max_length=256,
+ truncation=True,
+ return_tensors="pt"
+ ).to(self.device)
+
+        with torch.no_grad():
+            prompt_embeds = self.models['text_encoder_two'](text_inputs.input_ids)[0]
+            prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
+            prompt_embeds = self.t5_context_embedder(prompt_embeds)
+
+ return prompt_embeds
+
+ def compute_text_embeddings(self, prompt=""):
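+        """Compute pooled CLIP text embeddings, passed to the pipeline as `pooled_prompt_embeds`."""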
+ with torch.no_grad():
+ text_inputs = self.models['tokenizer'](
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt"
+ ).to(self.device)
+
+ prompt_embeds = self.models['text_encoder'](
+ text_inputs.input_ids,
+ output_hidden_states=False
+ )
+ pooled_prompt_embeds = prompt_embeds.pooler_output.to(self.dtype)
+
+ return pooled_prompt_embeds
+
+ def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None):
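+        """Generate `num_images` variations of `input_image`, optionally guided by a text prompt."""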
+ try:
+ if seed is not None:
+ torch.manual_seed(seed)
+
+ self.load_models()
+
+ # Process input image
+ input_image = self.resize_image(input_image)
+ qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
+ pooled_prompt_embeds = self.compute_text_embeddings("")
+
+ # Get T5 embeddings if prompt is provided
+ t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
+
+ # Generate images
+ output_images = self.pipeline(
+ prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ ).images
+
+ return output_images
+
+ except Exception as e:
+ print(f"Error during generation: {str(e)}")
+ raise gr.Error(f"Generation failed: {str(e)}")
+
+# Initialize the interface
+interface = FluxInterface()
+
+# Create Gradio interface
+with gr.Blocks(title="Qwen2vl-Flux Demo") as demo:
+ gr.Markdown("""
+    # Qwen2vl-Flux Image Variation Demo
+ Upload an image and get AI-generated variations. You can optionally add a text prompt to guide the generation.
+ """)
+
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(label="Upload Image", type="pil")
+            prompt = gr.Textbox(label="Optional Text Prompt (longer, more detailed prompts work best)", placeholder="Enter text prompt here (optional)")
+
+ with gr.Row():
+ guidance = gr.Slider(minimum=1, maximum=10, value=3.5, label="Guidance Scale")
+                steps = gr.Slider(minimum=1, maximum=50, value=28, step=1, label="Number of Steps")
+ num_images = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Number of Images")
+
+ seed = gr.Number(label="Random Seed (optional)", precision=0)
+ submit_btn = gr.Button("Generate Variations", variant="primary")
+
+ with gr.Column():
+ output_gallery = gr.Gallery(label="Generated Variations", columns=2, show_label=True)
+
+ # Set up the generation function
+ submit_btn.click(
+ fn=interface.generate,
+ inputs=[input_image, prompt, guidance, steps, num_images, seed],
+ outputs=output_gallery,
+ )
+
+ gr.Markdown("""
+ ### Notes:
+ - Higher guidance scale values result in outputs that more closely follow the prompt
+ - More steps generally produce better quality but take longer
+ - Set a seed for reproducible results
+ """)
+
+# Launch the app
+if __name__ == "__main__":
+ demo.launch()
\ No newline at end of file
diff --git a/flux/activations.py b/flux/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6676b887d47e01d9520fd519bec5bd953385be
--- /dev/null
+++ b/flux/activations.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate
+from diffusers.utils.import_utils import is_torch_npu_available
+
+
+if is_torch_npu_available():
+ import torch_npu
+
+ACTIVATION_FUNCTIONS = {
+ "swish": nn.SiLU(),
+ "silu": nn.SiLU(),
+ "mish": nn.Mish(),
+ "gelu": nn.GELU(),
+ "relu": nn.ReLU(),
+}
+
+
+def get_activation(act_fn: str) -> nn.Module:
+ """Helper function to get activation function from string.
+
+ Args:
+ act_fn (str): Name of activation function.
+
+ Returns:
+ nn.Module: Activation function.
+ """
+
+ act_fn = act_fn.lower()
+ if act_fn in ACTIVATION_FUNCTIONS:
+ return ACTIVATION_FUNCTIONS[act_fn]
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
+
+
+class FP32SiLU(nn.Module):
+ r"""
+ SiLU activation function with input upcasted to torch.float32.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ r"""
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states, *args, **kwargs):
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+ hidden_states = self.proj(hidden_states)
+ if is_torch_npu_available():
+ # using torch_npu.npu_geglu can run faster and save memory on NPU.
+ return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+ else:
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class SwiGLU(nn.Module):
+ r"""
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
+ but uses SiLU / Swish instead of GeLU.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+ self.activation = nn.SiLU()
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ return hidden_states * self.activation(gate)
+
+
+class ApproximateGELU(nn.Module):
+ r"""
+ The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+ [paper](https://arxiv.org/abs/1606.08415).
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
\ No newline at end of file
diff --git a/flux/attention.py b/flux/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..071fa156c4b26aae196c2b1e4219ad5cefa24c7b
--- /dev/null
+++ b/flux/attention.py
@@ -0,0 +1,843 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate, logging
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .attention_processor import Attention, JointAttnProcessor2_0
+from .embeddings import SinusoidalPositionalEmbedding
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+
+
+logger = logging.get_logger(__name__)
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+ # "feed_forward_chunk_size" can be used to save memory
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+ ff_output = torch.cat(
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+ return ff_output
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+@maybe_allow_in_graph
+class JointTransformerBlock(nn.Module):
+ r"""
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+ processing of `context` conditions.
+ """
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+ super().__init__()
+
+ self.context_pre_only = context_pre_only
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+ self.norm1 = AdaLayerNormZero(dim)
+
+ if context_norm_type == "ada_norm_continous":
+ self.norm1_context = AdaLayerNormContinuous(
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+ )
+ elif context_norm_type == "ada_norm_zero":
+ self.norm1_context = AdaLayerNormZero(dim)
+ else:
+ raise ValueError(
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+ )
+ if hasattr(F, "scaled_dot_product_attention"):
+ processor = JointAttnProcessor2_0()
+ else:
+ raise ValueError(
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+ )
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=context_pre_only,
+ bias=True,
+ processor=processor,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ if not context_pre_only:
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+ else:
+ self.norm2_context = None
+ self.ff_context = None
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+ ):
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ if self.context_pre_only:
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+ else:
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # Attention.
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ )
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ if self.context_pre_only:
+ encoder_hidden_states = None
+ else:
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ context_ff_output = _chunked_feed_forward(
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+ )
+ else:
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return encoder_hidden_states, hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
+ ada_norm_bias: Optional[int] = None,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ # We keep these boolean flags for backward-compatibility.
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ self.norm_type = norm_type
+ self.num_embeds_ada_norm = num_embeds_ada_norm
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if norm_type == "ada_norm":
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_zero":
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm1 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ if norm_type == "ada_norm":
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm2 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ if norm_type == "ada_norm_single": # For Latte
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if norm_type == "ada_norm_continuous":
+ self.norm3 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "layer_norm",
+ )
+
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ elif norm_type == "layer_norm_i2vgen":
+ self.norm3 = None
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if norm_type == "ada_norm_single":
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.norm_type == "ada_norm_zero":
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif self.norm_type == "ada_norm_single":
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # 1. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.norm_type == "ada_norm_zero":
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.norm_type == "ada_norm_single":
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 1.2 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.norm_type == "ada_norm_single":
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+        # i2vgen doesn't have this norm
+ if self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif not self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.norm_type == "ada_norm_single":
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+class LuminaFeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ hidden_size (`int`):
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+ hidden representations.
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
+ of this value.
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
+ dimension. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ inner_dim: int,
+ multiple_of: Optional[int] = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ ):
+ super().__init__()
+ inner_dim = int(2 * inner_dim / 3)
+ # custom hidden_size factor multiplier
+ if ffn_dim_multiplier is not None:
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
+
+ self.linear_1 = nn.Linear(
+ dim,
+ inner_dim,
+ bias=False,
+ )
+ self.linear_2 = nn.Linear(
+ inner_dim,
+ dim,
+ bias=False,
+ )
+ self.linear_3 = nn.Linear(
+ dim,
+ inner_dim,
+ bias=False,
+ )
+ self.silu = FP32SiLU()
+
+ def forward(self, x):
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
+
+
+@maybe_allow_in_graph
+class TemporalBasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block for video like data.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ time_mix_inner_dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.is_res = dim == time_mix_inner_dim
+
+ self.norm_in = nn.LayerNorm(dim)
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ self.ff_in = FeedForward(
+ dim,
+ dim_out=time_mix_inner_dim,
+ activation_fn="geglu",
+ )
+
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn1 = Attention(
+ query_dim=time_mix_inner_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ cross_attention_dim=None,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn2 = Attention(
+ query_dim=time_mix_inner_dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = None
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+ self._chunk_dim = 1
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ num_frames: int,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ batch_frames, seq_length, channels = hidden_states.shape
+ batch_size = batch_frames // num_frames
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+ residual = hidden_states
+ hidden_states = self.norm_in(hidden_states)
+
+ if self._chunk_size is not None:
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ hidden_states = self.ff_in(hidden_states)
+
+ if self.is_res:
+ hidden_states = hidden_states + residual
+
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+ hidden_states = attn_output + hidden_states
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = self.norm2(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self._chunk_size is not None:
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.is_res:
+ hidden_states = ff_output + hidden_states
+ else:
+ hidden_states = ff_output
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+ return hidden_states
+
+
+class SkipFFTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ kv_input_dim: int,
+ kv_input_dim_proj_use_bias: bool,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ if kv_input_dim != dim:
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
+ else:
+ self.kv_mapper = None
+
+ self.norm1 = RMSNorm(dim, 1e-06)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim,
+ out_bias=attention_out_bias,
+ )
+
+ self.norm2 = RMSNorm(dim, 1e-06)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ )
+
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+ if self.kv_mapper is not None:
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
+
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ norm_hidden_states = self.norm2(hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ inner_dim=None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "swiglu":
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
\ No newline at end of file
diff --git a/flux/attention_processor.py b/flux/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f61e9a6ebf728388e7ced51eb0e6688ce3015a
--- /dev/null
+++ b/flux/attention_processor.py
@@ -0,0 +1,2606 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import math
+from typing import Callable, List, Optional, Tuple, Union
+import numbers
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+#from ..image_processor import IPAdapterMaskProcessor
+from diffusers.utils import deprecate, logging
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_torch_npu_available():
+ import torch_npu
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ out_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ ):
+ super().__init__()
+
+ # To prevent circular import.
+ from .normalization import FP32LayerNorm, RMSNorm
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+
+ # we use this private variable to know whether this class was loaded
+ # with a deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "layer_norm_across_heads":
+ # Lumina applies qk norm across all heads
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ r"""
+ Set whether to use npu flash attention from `torch_npu` or not.
+
+ """
+ if use_npu_flash_attention:
+ processor = AttnProcessorNPU()
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ r"""
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
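+
+ Example (illustrative; requires `xformers` to be installed and a CUDA device):
+
+ >>> attn.set_use_memory_efficient_attention_xformers(True)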
+ """
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
+ )
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+ XFormersAttnAddedKVProcessor,
+ ),
+ )
+
+ if use_memory_efficient_attention_xformers:
+ if is_added_kv_processor and is_custom_diffusion:
+ raise NotImplementedError(
+ f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
+ )
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ if is_custom_diffusion:
+ processor = CustomDiffusionXFormersAttnProcessor(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ elif is_added_kv_processor:
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+ # which uses this type of cross attention ONLY because the attention mask of format
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
+ # throw warning
+ logger.info(
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+ )
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+ else:
+ processor = XFormersAttnProcessor(attention_op=attention_op)
+ else:
+ if is_custom_diffusion:
+ attn_processor_class = (
+ CustomDiffusionAttnProcessor2_0
+ if hasattr(F, "scaled_dot_product_attention")
+ else CustomDiffusionAttnProcessor
+ )
+ processor = attn_processor_class(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0()
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ r"""
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
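+
+ Example (illustrative): compute attention over the `batch * heads` axis in slices of size 2 to save memory.
+
+ >>> attn.set_attention_slice(2)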
+ """
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ if slice_size is not None and self.added_kv_proj_dim is not None:
+ processor = SlicedAttnAddedKVProcessor(slice_size)
+ elif slice_size is not None:
+ processor = SlicedAttnProcessor(slice_size)
+ elif self.added_kv_proj_dim is not None:
+ processor = AttnAddedKVProcessor()
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ r"""
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+ is the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, heads, seq_len, dim // heads]` (or
+ `[batch_size * heads, seq_len, dim // heads]` when `out_dim=3`). `heads` is the number of heads initialized
+ while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
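+
+ Example (illustrative; assumes `attn.heads == 8`):
+
+ >>> x = torch.randn(2, 64, 320)
+ >>> attn.head_to_batch_dim(x).shape
+ torch.Size([16, 64, 40])
+ >>> attn.batch_to_head_dim(attn.head_to_batch_dim(x)).shape
+ torch.Size([2, 64, 320])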
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ r"""
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
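+
+ The scores are computed in a single `torch.baddbmm` call as
+ `softmax(attention_mask + scale * query @ key.transpose(-1, -2))`, where the mask term is dropped when
+ `attention_mask` is `None`.
+
+ Example (illustrative; shapes follow the `[batch * heads, seq_len, dim_head]` layout):
+
+ >>> q = torch.randn(16, 64, 40) # [batch * heads, query_len, dim_head]
+ >>> k = torch.randn(16, 77, 40) # [batch * heads, key_len, dim_head]
+ >>> attn.get_attention_scores(q, k).shape
+ torch.Size([16, 64, 77])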
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
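+
+ Example (illustrative; assumes `attn.heads == 8`): a `[2, 1, 77]` mask is repeated per head so it can be
+ broadcast against `[batch * heads, query_len, key_len]` attention scores.
+
+ >>> mask = torch.zeros(2, 1, 77)
+ >>> attn.prepare_attention_mask(mask, target_length=77, batch_size=2).shape
+ torch.Size([16, 1, 77])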
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ r"""
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+ `Attention` class.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+ @torch.no_grad()
+ def fuse_projections(self, fuse=True):
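+ # Concatenate the separate q/k/v (and, when present, added q/k/v) projection weights into single fused
+ # linear layers: `to_qkv` for self-attention, `to_kv` for cross-attention, plus `to_added_qkv` when the
+ # added projections exist. A fused processor can then compute all projections with one matmul. The
+ # original `to_q`/`to_k`/`to_v` modules are left in place.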
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if not self.is_cross_attention:
+ # fetch weight matrices.
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ # create a new single projection layer and copy over the weights.
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+
+ # handle added projections for SD3 and others.
+ if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = fuse
+
+
+class AttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
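+
+ Example (illustrative): force the plain (non-SDPA) implementation on an existing `Attention` module.
+
+ >>> attn.set_processor(AttnProcessor())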
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CustomDiffusionAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing attention for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include the bias parameter in `train_q_out`.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = True,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+ else:
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
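+ # When cross-attending, detach the key/value of the first text token (position 0) so that gradients
+ # only flow through the projections of the remaining text tokens during fine-tuning.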
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor:
+ r"""
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
+ encoder.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor2_0:
+ r"""
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+ learnable key and value matrices for the text encoder.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query, out_dim=4)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key, out_dim=4)
+ value = attn.head_to_batch_dim(value, out_dim=4)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class JointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+
+class FusedJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ # `context` projections.
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+
+class AuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FusedAuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow with fused projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+# YiYi to-do: refactor rope related functions/classes
+def apply_rope(xq, xk, freqs_cis):
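+ # Interpret the last dimension of the queries/keys as 2D pairs and apply the per-position 2x2 rotation
+ # matrices stored in `freqs_cis` (rotary positional embedding), then restore the original shape and dtype.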
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+class FluxAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+ super().__init__()
+
+
+class XFormersAttnAddedKVProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
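+ # "Added KV" attention: the text conditioning receives its own key/value projections
+ # (add_k_proj / add_v_proj) and is concatenated with the latent self-attention keys/values below.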
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class XFormersAttnProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, key_tokens, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+ if attention_mask is not None:
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AttnProcessorNPU:
+ r"""
+ Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+ fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+ not significant.
+
+ """
+
+ def __init__(self):
+ if not is_torch_npu_available():
+ raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ if query.dtype in (torch.float16, torch.bfloat16):
+ hidden_states = torch_npu.npu_fusion_attention(
+ query,
+ key,
+ value,
+ attn.heads,
+ input_layout="BNSD",
+ pse=None,
+ atten_mask=attention_mask,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ else:
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+class FusedAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+ fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+ For cross-attention modules, key and value projection matrices are fused.
+
+ This API is currently 🧪 experimental in nature and can change in future.
+
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedAttnProcessor2_0 requires at least PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ query = attn.to_q(hidden_states)
+
+ kv = attn.to_kv(encoder_hidden_states)
+ split_size = kv.shape[-1] // 2
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include the bias parameter in `train_q_out`.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+ as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = False,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ attention_op: Optional[Callable] = None,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.attention_op = attention_op
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+ else:
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
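+ # Stop gradients through the key/value of the first context token while keeping the
+ # remaining tokens trainable (detach is zeroed only at token index 0).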
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class CustomDiffusionAttnProcessor2_0(nn.Module):
+ r"""
+ Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
+ dot-product attention.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include the bias parameter in `train_q_out`.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = True,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states)
+ else:
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
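+ # Same trick as in the xFormers variant above: block gradient flow through the first
+ # context token's key/value and leave the remaining tokens trainable.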
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ inner_dim = hidden_states.shape[-1]
+
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class SlicedAttnProcessor:
+ r"""
+ Processor for implementing sliced attention.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+ `attention_head_dim` must be a multiple of the `slice_size`.
+ """
+
+ def __init__(self, slice_size: int):
+ self.slice_size = slice_size
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
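+ # Compute attention in chunks of `slice_size` along the (batch * heads) dimension,
+ # trading a Python loop for a lower peak memory footprint.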
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SlicedAttnAddedKVProcessor:
+ r"""
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+ `attention_head_dim` must be a multiple of the `slice_size`.
+ """
+
+ def __init__(self, slice_size):
+ self.slice_size = slice_size
+
+ def __call__(
+ self,
+ attn: "Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class SpatialNorm(nn.Module):
+ """
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ """
+
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
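+ # Resize the conditioning zq to f's spatial size, then modulate the group-normalized
+ # features with a learned per-channel scale (conv_y) and shift (conv_b).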
+ f_size = f.shape[-2:]
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+class LoRAAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnProcessor2_0:
+ def __init__(self):
+ pass
+
+
+class LoRAXFormersAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnAddedKVProcessor:
+ def __init__(self):
+ pass
+
+
+ADDED_KV_ATTENTION_PROCESSORS = (
+ AttnAddedKVProcessor,
+ SlicedAttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ XFormersAttnAddedKVProcessor,
+)
+
+CROSS_ATTENTION_PROCESSORS = (
+ AttnProcessor,
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ SlicedAttnProcessor,
+)
+
+AttentionProcessor = Union[
+ AttnProcessor,
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ XFormersAttnProcessor,
+ SlicedAttnProcessor,
+ AttnAddedKVProcessor,
+ SlicedAttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ XFormersAttnAddedKVProcessor,
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionXFormersAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+]
\ No newline at end of file
diff --git a/flux/controlnet_flux.py b/flux/controlnet_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde87346ab0ff03f292172b73d1e6ec255824b6c
--- /dev/null
+++ b/flux/controlnet_flux.py
@@ -0,0 +1,617 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from .lora.peft import PeftAdapterMixin
+from diffusers.models.attention_processor import AttentionProcessor
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+#from .controlnet import BaseOutput, zero_module
+from diffusers.utils import BaseOutput
+from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from .transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+import numpy as np
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
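+ # Note: the linear layers wrapped in `zero_module` below start with all-zero weights, so the
+ # ControlNet branch initially contributes nothing to the base model's output and only
+ # learns to contribute during training.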
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[np.ndarray, int],
+ theta: float = 10000.0,
+ use_real=False,
+ linear_factor=1.0,
+ ntk_factor=1.0,
+ repeat_interleave_real=True,
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in
+ complex64 data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ linear_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the context extrapolation. Defaults to 1.0.
+ ntk_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+ Otherwise, they are concatenated with themselves.
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+ the dtype of the frequency tensor.
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = torch.arange(pos)
+ if isinstance(pos, np.ndarray):
+ pos = torch.from_numpy(pos) # type: ignore # [S]
+
+ theta = theta * ntk_factor
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
+ freqs = freqs.to(pos.device)
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ if use_real and repeat_interleave_real:
+ # flux, hunyuan-dit, cogvideox
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ elif use_real:
+ # stable audio
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ # lumina
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
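+ # Illustrative shapes (editorial sketch, arbitrary values): on the FLUX path
+ # (use_real=True, repeat_interleave_real=True), 64 positions and dim=16 give
+ #   cos, sin = get_1d_rotary_pos_embed(16, 64, use_real=True)
+ #   cos.shape, sin.shape  # -> (torch.Size([64, 16]), torch.Size([64, 16]))
+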
+class FluxPosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.squeeze().float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
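+ # With the defaults used below (theta=10000, axes_dim=[16, 56, 56]) the per-axis rotary
+ # dimensions sum to 128, i.e. one frequency band per attention head dimension; each entry
+ # of `axes_dim` covers one column of the packed `ids` tensor.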
+
+@dataclass
+class FluxControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+ controlnet_single_block_samples: Tuple[torch.Tensor]
+
+
+class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ num_mode: int = None,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+ text_time_guidance_cls = (
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+ )
+ self.time_text_embed = text_time_guidance_cls(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ FluxTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ FluxSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for i in range(num_single_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+ self.controlnet_single_blocks = nn.ModuleList([])
+ for _ in range(len(self.single_transformer_blocks)):
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+ self.union = num_mode is not None
+ if self.union:
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
+
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self):
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
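+ # Illustrative usage (editorial sketch; `AttnProcessor2_0` stands in for any compatible processor):
+ #   controlnet.set_attn_processor(AttnProcessor2_0())
+ #   controlnet.set_attn_processor({name: AttnProcessor2_0() for name in controlnet.attn_processors})
+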
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_transformer(
+ cls,
+ transformer,
+ num_layers: int = 4,
+ num_single_layers: int = 10,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ load_weights_from_transformer=True,
+ ):
+ config = transformer.config
+ config["num_layers"] = num_layers
+ config["num_single_layers"] = num_single_layers
+ config["attention_head_dim"] = attention_head_dim
+ config["num_attention_heads"] = num_attention_heads
+
+ controlnet = cls(**config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+ controlnet.single_transformer_blocks.load_state_dict(
+ transformer.single_transformer_blocks.state_dict(), strict=False
+ )
+
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
+
+ return controlnet
+
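+ # Illustrative usage (editorial sketch; the repo id is a placeholder):
+ #   base = FluxTransformer2DModel.from_pretrained("<repo-id>", subfolder="flux")
+ #   controlnet = FluxControlNetModel.from_transformer(base, num_layers=4, num_single_layers=10)
+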
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ controlnet_mode: torch.Tensor = None,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.Tensor = None,
+ t5_encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`FluxControlNetModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ controlnet_mode (`torch.Tensor`):
+ The mode tensor of shape `(batch_size, 1)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ t5_encoder_hidden_states (`torch.FloatTensor`, *optional*):
+ Additional T5 text embeddings that, when provided, are concatenated with `encoder_hidden_states`.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ # add the zero-initialized projection of the ControlNet conditioning to the latent embeddings
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+ else:
+ guidance = None
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+ if t5_encoder_hidden_states is not None:
+ encoder_hidden_states = torch.cat([encoder_hidden_states, t5_encoder_hidden_states], dim=1)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+
+ if self.union:
+ # union mode
+ if controlnet_mode is None:
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
+ # union mode emb
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
+
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ block_samples = ()
+ for index_block, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ block_samples = block_samples + (hidden_states,)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ single_block_samples = ()
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+
+ # controlnet block
+ controlnet_block_samples = ()
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
+ block_sample = controlnet_block(block_sample)
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
+
+ controlnet_single_block_samples = ()
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
+ single_block_sample = controlnet_block(single_block_sample)
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
+
+ # scaling
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
+
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
+ controlnet_single_block_samples = (
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
+ )
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (controlnet_block_samples, controlnet_single_block_samples)
+
+ return FluxControlNetOutput(
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ )
+
+
+class FluxMultiControlNetModel(ModelMixin):
+ r"""
+ `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
+
+ This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
+ compatible with `FluxControlNetModel`.
+
+ Args:
+ controlnets (`List[FluxControlNetModel]`):
+ Provides additional conditioning to the transformer during the denoising process. You must set multiple
+ `FluxControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.Tensor],
+ controlnet_mode: List[torch.Tensor],
+ conditioning_scale: List[float],
+ encoder_hidden_states: torch.Tensor = None,
+ t5_encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[FluxControlNetOutput, Tuple]:
+ # ControlNet-Union with multiple conditions
+ # only one ControlNet needs to be loaded, to save memory
+ if len(self.nets) == 1 and self.nets[0].union:
+ controlnet = self.nets[0]
+
+ for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
+ block_samples, single_block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ controlnet_mode=mode[:, None],
+ conditioning_scale=scale,
+ timestep=timestep,
+ guidance=guidance,
+ pooled_projections=pooled_projections,
+ encoder_hidden_states=encoder_hidden_states,
+ t5_encoder_hidden_states=t5_encoder_hidden_states,
+ txt_ids=txt_ids,
+ img_ids=img_ids,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ control_single_block_samples = single_block_samples
+ else:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+
+ control_single_block_samples = [
+ control_single_block_sample + block_sample
+ for control_single_block_sample, block_sample in zip(
+ control_single_block_samples, single_block_samples
+ )
+ ]
+
+ # Regular Multi-ControlNets
+ # all ControlNets are loaded into memory
+ else:
+ for i, (image, mode, scale, controlnet) in enumerate(
+ zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
+ ):
+ block_samples, single_block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ controlnet_mode=mode[:, None],
+ conditioning_scale=scale,
+ timestep=timestep,
+ guidance=guidance,
+ pooled_projections=pooled_projections,
+ encoder_hidden_states=encoder_hidden_states,
+ txt_ids=txt_ids,
+ img_ids=img_ids,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ control_single_block_samples = single_block_samples
+ else:
+ if block_samples is not None and control_block_samples is not None:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+ if single_block_samples is not None and control_single_block_samples is not None:
+ control_single_block_samples = [
+ control_single_block_sample + block_sample
+ for control_single_block_sample, block_sample in zip(
+ control_single_block_samples, single_block_samples
+ )
+ ]
+
+ return control_block_samples, control_single_block_samples
\ No newline at end of file
diff --git a/flux/embeddings.py b/flux/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd202d697f51228345472035c1835ed21ce84371
--- /dev/null
+++ b/flux/embeddings.py
@@ -0,0 +1,1469 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate
+from .activations import FP32SiLU, get_activation
+from .attention_processor import Attention
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
+
+ Args:
+ timesteps (torch.Tensor):
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ embedding_dim (int):
+ the dimension of the output.
+ flip_sin_to_cos (bool):
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+ downscale_freq_shift (float):
+ Controls the delta between frequencies between dimensions
+ scale (float):
+ Scaling factor applied to the embeddings.
+ max_period (int):
+ Controls the maximum frequency of the embeddings
+ Returns:
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
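+# Usage sketch for get_timestep_embedding (above); added for illustration only, sizes are arbitrary:
+#
+#     t = torch.tensor([0.0, 10.0, 999.0])
+#     emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True)
+#     # emb.shape == (3, 256): one sin/cos feature vector per timestep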
+
+def get_2d_sincos_pos_embed(
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+ """
+ grid_size: int of the grid height and width.
+
+ Returns:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded, of size (M,)
+ out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
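+# Usage sketch for get_2d_sincos_pos_embed (above); grid size and width are illustrative:
+#
+#     pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
+#     # pos.shape == (256, 768): one 768-d embedding per cell of the 16x16 grid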
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for SD3 cropping."""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ pos_embed_type="sincos",
+ pos_embed_max_size=None, # For SD3 cropping
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+ self.pos_embed_max_size = pos_embed_max_size
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+
+ # Calculate positional embeddings based on max size or default
+ if pos_embed_max_size:
+ grid_size = pos_embed_max_size
+ else:
+ grid_size = int(num_patches**0.5)
+
+ if pos_embed_type is None:
+ self.pos_embed = None
+ elif pos_embed_type == "sincos":
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ persistent = True if pos_embed_max_size else False
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+ else:
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
+
+ def cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+ def forward(self, latent):
+ if self.pos_embed_max_size is not None:
+ height, width = latent.shape[-2:]
+ else:
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ if self.pos_embed is None:
+ return latent.to(latent.dtype)
+ # Interpolate or crop positional embeddings as needed
+ if self.pos_embed_max_size:
+ pos_embed = self.cropped_pos_embed(height, width)
+ else:
+ if self.height != height or self.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ return (latent + pos_embed).to(latent.dtype)
+
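+# Usage sketch for PatchEmbed (above); the latent size, patch size and width are illustrative:
+#
+#     patcher = PatchEmbed(height=64, width=64, patch_size=2, in_channels=4, embed_dim=768)
+#     tokens = patcher(torch.randn(1, 4, 64, 64))
+#     # tokens.shape == (1, 1024, 768): (64/2) * (64/2) = 1024 patch tokens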
+
+class LuminaPatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for Lumina-T2X"""
+
+ def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = nn.Linear(
+ in_features=patch_size * patch_size * in_channels,
+ out_features=embed_dim,
+ bias=bias,
+ )
+
+ def forward(self, x, freqs_cis):
+ """
+ Patchifies and embeds the input tensor(s).
+
+ Args:
+ x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
+ and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
+ frequency tensor(s).
+ """
+ freqs_cis = freqs_cis.to(x[0].device)
+ patch_height = patch_width = self.patch_size
+ batch_size, channel, height, width = x.size()
+ height_tokens, width_tokens = height // patch_height, width // patch_width
+
+ x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
+ 0, 2, 4, 1, 3, 5
+ )
+ x = x.flatten(3)
+ x = self.proj(x)
+ x = x.flatten(1, 2)
+
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
+
+ return (
+ x,
+ mask,
+ [(height, width)] * batch_size,
+ freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
+ )
+
+
+def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+ """
+ RoPE for image tokens with 2d structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+ """
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0) # [2, W, H]
+
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+ assert embed_dim % 4 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_rotary_pos_embed(
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
+ emb_w = get_1d_rotary_pos_embed(
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
+
+ if use_real:
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
+ return cos, sin
+ else:
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
+ return emb
+
+
+def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+ assert embed_dim % 4 == 0
+
+ emb_h = get_1d_rotary_pos_embed(
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
+ ) # (H, D/4)
+ emb_w = get_1d_rotary_pos_embed(
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
+ ) # (W, D/4)
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
+
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
+ return emb
+
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[np.ndarray, int],
+ theta: float = 10000.0,
+ use_real=False,
+ linear_factor=1.0,
+ ntk_factor=1.0,
+ repeat_interleave_real=True,
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
+ in complex64 data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ linear_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the context extrapolation. Defaults to 1.0.
+ ntk_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+ Otherwise, they are concatenated with themselves.
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = np.arange(pos)
+ theta = theta * ntk_factor
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
+ if use_real and repeat_interleave_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ elif use_real:
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
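+# Usage sketch for get_1d_rotary_pos_embed (above); dim and sequence length are illustrative:
+#
+#     cos, sin = get_1d_rotary_pos_embed(dim=64, pos=128, use_real=True)   # each of shape [128, 64]
+#     freqs_cis = get_1d_rotary_pos_embed(dim=64, pos=128)                 # complex64, shape [128, 32]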
+
+def apply_rotary_emb(
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ use_real: bool = True,
+ use_real_unbind_dim: int = -1,
+) -> torch.Tensor:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+ tensors contain rotary embeddings and are returned as real tensors.
+
+ Args:
+ x (`torch.Tensor`):
+ Query or key tensor to which rotary embeddings are applied, of shape [B, H, S, D].
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+ Returns:
+ torch.Tensor: The query or key tensor with rotary embeddings applied.
+ """
+ if use_real:
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+ if use_real_unbind_dim == -1:
+ # Used, for example, in Lumina
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ elif use_real_unbind_dim == -2:
+ # Used, for example, in Stable Audio
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+ else:
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+ else:
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+ return x_out.type_as(x)
+
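+# Usage sketch for apply_rotary_emb (above), paired with get_1d_rotary_pos_embed;
+# batch/head/sequence sizes are illustrative:
+#
+#     q = torch.randn(2, 24, 128, 64)                                # [B, H, S, D]
+#     freqs = get_1d_rotary_pos_embed(dim=64, pos=128, use_real=True)
+#     q_rotated = apply_rotary_emb(q, freqs)                         # same shape as q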
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ sample_proj_bias=True,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ )
+ return t_emb
+
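+# Usage sketch showing how Timesteps and TimestepEmbedding (above) are chained;
+# the embedding width is illustrative:
+#
+#     time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+#     time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=1024)
+#     emb = time_embed(time_proj(torch.tensor([999.0])))
+#     # emb.shape == (1, 1024)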
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ del self.weight
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.weight = self.W
+ del self.W
+
+ def forward(self, x):
+ if self.log:
+ x = torch.log(x)
+
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """Apply positional information to a sequence of embeddings.
+
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
+ them
+
+ Args:
+ embed_dim: (int): Dimension of the positional embedding.
+ max_seq_length: Maximum sequence length to apply positional embeddings
+
+ """
+
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
+ super().__init__()
+ position = torch.arange(max_seq_length).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
+ pe = torch.zeros(1, max_seq_length, embed_dim)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ _, seq_length, _ = x.shape
+ x = x + self.pe[:, :seq_length]
+ return x
+
+
+class ImagePositionalEmbeddings(nn.Module):
+ """
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
+ height and width of the latent space.
+
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+
+ For VQ-diffusion:
+
+ Output vector embeddings are used as input for the transformer.
+
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
+
+ Args:
+ num_embed (`int`):
+ Number of embeddings for the latent pixels embeddings.
+ height (`int`):
+ Height of the latent image i.e. the number of height embeddings.
+ width (`int`):
+ Width of the latent image i.e. the number of width embeddings.
+ embed_dim (`int`):
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
+ """
+
+ def __init__(
+ self,
+ num_embed: int,
+ height: int,
+ width: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+
+ self.height = height
+ self.width = width
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
+ self.height_emb = nn.Embedding(self.height, embed_dim)
+ self.width_emb = nn.Embedding(self.width, embed_dim)
+
+ def forward(self, index):
+ emb = self.emb(index)
+
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
+
+ # 1 x H x D -> 1 x H x 1 x D
+ height_emb = height_emb.unsqueeze(2)
+
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
+
+ # 1 x W x D -> 1 x 1 x W x D
+ width_emb = width_emb.unsqueeze(1)
+
+ pos_emb = height_emb + width_emb
+
+ # 1 x H x W x D -> 1 x L x D
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
+
+ emb = emb + pos_emb[:, : emb.shape[1], :]
+
+ return emb
+
+
+class LabelEmbedding(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+
+ Args:
+ num_classes (`int`): The number of classes.
+ hidden_size (`int`): The size of the vector embeddings.
+ dropout_prob (`float`): The probability of dropping a label.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = torch.tensor(force_drop_ids == 1)
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (self.training and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
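+# Usage sketch for LabelEmbedding (above). During training, each label is replaced with the
+# extra "null" class index with probability `dropout_prob`, which is what enables
+# classifier-free guidance; the class count and width below are illustrative:
+#
+#     label_embed = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
+#     label_embed.eval()                                  # no label dropout outside training
+#     emb = label_embed(torch.tensor([3, 7]))
+#     # emb.shape == (2, 768)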
+
+class TextImageProjection(nn.Module):
+ def __init__(
+ self,
+ text_embed_dim: int = 1024,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 10,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ batch_size = text_embeds.shape[0]
+
+ # image
+ image_text_embeds = self.image_embeds(image_embeds)
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+
+ # text
+ text_embeds = self.text_proj(text_embeds)
+
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
+
+
+class ImageProjection(nn.Module):
+ def __init__(
+ self,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 32,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ batch_size = image_embeds.shape[0]
+
+ # image
+ image_embeds = self.image_embeds(image_embeds)
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+ image_embeds = self.norm(image_embeds)
+ return image_embeds
+
+
+class IPAdapterFullImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ return self.norm(self.ff(image_embeds))
+
+
+class IPAdapterFaceIDImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.cross_attention_dim = cross_attention_dim
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ x = self.ff(image_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ return self.norm(x)
+
+
+class CombinedTimestepLabelEmbeddings(nn.Module):
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+
+ def forward(self, timestep, class_labels, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ class_labels = self.class_embedder(class_labels) # (N, D)
+
+ conditioning = timesteps_emb + class_labels # (N, D)
+
+ return conditioning
+
+
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ pooled_projections = self.text_embedder(pooled_projection)
+
+ conditioning = timesteps_emb + pooled_projections
+
+ return conditioning
+
+
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, guidance, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ time_guidance_emb = timesteps_emb + guidance_emb
+
+ pooled_projections = self.text_embedder(pooled_projection)
+ conditioning = time_guidance_emb + pooled_projections
+
+ return conditioning
+
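+# Usage sketch for CombinedTimestepGuidanceTextProjEmbeddings (above); the dimensions are
+# illustrative and simply have to match the transformer's inner dim and the pooled text dim:
+#
+#     embedder = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=3072, pooled_projection_dim=768)
+#     cond = embedder(torch.tensor([500.0]), torch.tensor([3.5]), torch.randn(1, 768))
+#     # cond.shape == (1, 3072): timestep, guidance and pooled-text signals summed into one vector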
+
+class HunyuanDiTAttentionPool(nn.Module):
+ # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.permute(1, 0, 2) # NLC -> LNC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1],
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False,
+ )
+ return x.squeeze(0)
+
+
+class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim,
+ pooled_projection_dim=1024,
+ seq_len=256,
+ cross_attention_dim=2048,
+ use_style_cond_and_image_meta_size=True,
+ ):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+ self.pooler = HunyuanDiTAttentionPool(
+ seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+ )
+
+ # Here we use a default learned embedder layer for future extension.
+ self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+ if use_style_cond_and_image_meta_size:
+ self.style_embedder = nn.Embedding(1, embedding_dim)
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+ else:
+ extra_in_dim = pooled_projection_dim
+
+ self.extra_embedder = PixArtAlphaTextProjection(
+ in_features=extra_in_dim,
+ hidden_size=embedding_dim * 4,
+ out_features=embedding_dim,
+ act_fn="silu_fp32",
+ )
+
+ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256)
+
+ # extra condition1: text
+ pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
+
+ if self.use_style_cond_and_image_meta_size:
+ # extra condition2: image meta size embedding
+ image_meta_size = self.size_proj(image_meta_size.view(-1))
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
+
+ # extra condition3: style embedding
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
+
+ # Concatenate all extra vectors
+ extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+ else:
+ extra_cond = torch.cat([pooled_projections], dim=1)
+
+ conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
+
+ return conditioning
+
+
+class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
+ super().__init__()
+ self.time_proj = Timesteps(
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+ )
+
+ self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+ self.caption_embedder = nn.Sequential(
+ nn.LayerNorm(cross_attention_dim),
+ nn.Linear(
+ cross_attention_dim,
+ hidden_size,
+ bias=True,
+ ),
+ )
+
+ def forward(self, timestep, caption_feat, caption_mask):
+ # timestep embedding:
+ time_freq = self.time_proj(timestep)
+ time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+
+ # caption condition embedding:
+ caption_mask_float = caption_mask.float().unsqueeze(-1)
+ caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
+ caption_feats_pool = caption_feats_pool.to(caption_feat)
+ caption_embed = self.caption_embedder(caption_feats_pool)
+
+ conditioning = time_embed + caption_embed
+
+ return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(encoder_dim)
+ self.pool = AttentionPooling(num_heads, encoder_dim)
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
+ self.norm2 = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, hidden_states):
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.pool(hidden_states)
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class TextImageTimeEmbedding(nn.Module):
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
+ self.text_norm = nn.LayerNorm(time_embed_dim)
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ # text
+ time_text_embeds = self.text_proj(text_embeds)
+ time_text_embeds = self.text_norm(time_text_embeds)
+
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+
+ return time_image_embeds + time_text_embeds
+
+
+class ImageTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ return time_image_embeds
+
+
+class ImageHintTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+ self.input_hint_block = nn.Sequential(
+ nn.Conv2d(3, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(96, 96, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(256, 4, 3, padding=1),
+ )
+
+ def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ hint = self.input_hint_block(hint)
+ return time_image_embeds, hint
+
+
+class AttentionPooling(nn.Module):
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+ def __init__(self, num_heads, embed_dim, dtype=None):
+ super().__init__()
+ self.dtype = dtype
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.num_heads = num_heads
+ self.dim_per_head = embed_dim // self.num_heads
+
+ def forward(self, x):
+ bs, length, width = x.size()
+
+ def shape(x):
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+ x = x.transpose(1, 2)
+ return x
+
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
+
+ # (bs*n_heads, class_token_length, dim_per_head)
+ q = shape(self.q_proj(class_token))
+ # (bs*n_heads, length+class_token_length, dim_per_head)
+ k = shape(self.k_proj(x))
+ v = shape(self.v_proj(x))
+
+ # (bs*n_heads, class_token_length, length+class_token_length):
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+ # (bs*n_heads, dim_per_head, class_token_length)
+ a = torch.einsum("bts,bcs->bct", weight, v)
+
+ # (bs, 1, width)
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+ return a[:, 0, :] # cls_token
+
+
+def get_fourier_embeds_from_boundingbox(embed_dim, box):
+ """
+ Args:
+ embed_dim: int
+ box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
+ Returns:
+ [B x N x embed_dim] tensor of positional embeddings
+ """
+
+ batch_size, num_boxes = box.shape[:2]
+
+ emb = 100 ** (torch.arange(embed_dim) / embed_dim)
+ emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
+ emb = emb * box.unsqueeze(-1)
+
+ emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
+ emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
+
+ return emb
+
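+# Usage sketch for get_fourier_embeds_from_boundingbox (above); the box count is illustrative:
+#
+#     boxes = torch.rand(1, 5, 4)                              # [B, N, 4] in xyxy form
+#     emb = get_fourier_embeds_from_boundingbox(8, boxes)
+#     # emb.shape == (1, 5, 64): embed_dim * 2 (sin/cos) * 4 (coords)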
+
+class GLIGENTextBoundingboxProjection(nn.Module):
+ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
+ super().__init__()
+ self.positive_len = positive_len
+ self.out_dim = out_dim
+
+ self.fourier_embedder_dim = fourier_freqs
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
+
+ if isinstance(out_dim, tuple):
+ out_dim = out_dim[0]
+
+ if feature_type == "text-only":
+ self.linears = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+ elif feature_type == "text-image":
+ self.linears_text = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.linears_image = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+ def forward(
+ self,
+ boxes,
+ masks,
+ positive_embeddings=None,
+ phrases_masks=None,
+ image_masks=None,
+ phrases_embeddings=None,
+ image_embeddings=None,
+ ):
+ masks = masks.unsqueeze(-1)
+
+ # embed the box positions (the boxes may include padding used as placeholders)
+ xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
+
+ # learnable null embedding
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+ # positionet with text only information
+ if positive_embeddings is not None:
+ # learnable null embedding
+ positive_null = self.null_positive_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+
+ # positionet with text and image information
+ else:
+ phrases_masks = phrases_masks.unsqueeze(-1)
+ image_masks = image_masks.unsqueeze(-1)
+
+ # learnable null embedding
+ text_null = self.null_text_feature.view(1, 1, -1)
+ image_null = self.null_image_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
+ image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
+
+ objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
+ objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
+ objs = torch.cat([objs_text, objs_image], dim=1)
+
+ return objs
+
+
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
+ """
+ For PixArt-Alpha.
+
+ Reference:
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+ """
+
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if use_additional_conditions:
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+ aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+ aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+ conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class PixArtAlphaTextProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
+ super().__init__()
+ if out_features is None:
+ out_features = hidden_size
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+ if act_fn == "gelu_tanh":
+ self.act_1 = nn.GELU(approximate="tanh")
+ elif act_fn == "silu":
+ self.act_1 = nn.SiLU()
+ elif act_fn == "silu_fp32":
+ self.act_1 = FP32SiLU()
+ else:
+ raise ValueError(f"Unknown activation function: {act_fn}")
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ dim_head: int = 64,
+ heads: int = 16,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ln0 = nn.LayerNorm(embed_dims)
+ self.ln1 = nn.LayerNorm(embed_dims)
+ self.attn = Attention(
+ query_dim=embed_dims,
+ dim_head=dim_head,
+ heads=heads,
+ out_bias=False,
+ )
+ self.ff = nn.Sequential(
+ nn.LayerNorm(embed_dims),
+ FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+ )
+
+ def forward(self, x, latents, residual):
+ encoder_hidden_states = self.ln0(x)
+ latents = self.ln1(latents)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+ latents = self.attn(latents, encoder_hidden_states) + residual
+ latents = self.ff(latents) + latents
+ return latents
+
+
+class IPAdapterPlusImageProjection(nn.Module):
+ """Resampler of IP-Adapter Plus.
+
+ Args:
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, which is the same as the number of channels in
+ `unet.config.cross_attention_dim`. Defaults to 1024.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ depth (int): The number of blocks. Defaults to 4.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 1024,
+ hidden_dims: int = 1280,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
+
+ self.proj_in = nn.Linear(embed_dims, hidden_dims)
+
+ self.proj_out = nn.Linear(hidden_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList(
+ [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ x (torch.Tensor): Input Tensor.
+ Returns:
+ torch.Tensor: Output Tensor.
+ """
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ for block in self.layers:
+ residual = latents
+ latents = block(x, latents, residual)
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
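+# Usage sketch for IPAdapterPlusImageProjection (above) with its default configuration; the
+# input is assumed to be a sequence of image-encoder features (e.g. CLIP patch tokens):
+#
+#     resampler = IPAdapterPlusImageProjection()
+#     image_feats = torch.randn(2, 257, 768)                   # [B, seq_len, embed_dims]
+#     ip_tokens = resampler(image_feats)
+#     # ip_tokens.shape == (2, 8, 1024): num_queries tokens of width output_dims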
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+ """FacePerceiverResampler of IP-Adapter Plus.
+
+ Args:
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, which is the same as the number of channels in
+ `unet.config.cross_attention_dim`. Defaults to 768.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ id_embeddings_dim (int): The dimension of the ID embeddings. Defaults to 512.
+ depth (int): The number of blocks. Defaults to 4.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_tokens (int): The number of ID tokens. Defaults to 4.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4.
+ ffproj_ratio (int): The expansion ratio of feedforward network hidden layer channels (for ID embeddings).
+ Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 768,
+ hidden_dims: int = 1280,
+ id_embeddings_dim: int = 512,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_tokens: int = 4,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ffproj_ratio: int = 2,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.embed_dim = embed_dims
+ self.clip_embeds = None
+ self.shortcut = False
+ self.shortcut_scale = 1.0
+
+ self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+ self.norm = nn.LayerNorm(embed_dims)
+
+ self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+ self.proj_out = nn.Linear(embed_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList(
+ [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+
+ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ id_embeds (torch.Tensor): Input Tensor (ID embeds).
+ Returns:
+ torch.Tensor: Output Tensor.
+ """
+ id_embeds = id_embeds.to(self.clip_embeds.dtype)
+ id_embeds = self.proj(id_embeds)
+ id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+ id_embeds = self.norm(id_embeds)
+ latents = id_embeds
+
+ clip_embeds = self.proj_in(self.clip_embeds)
+ x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+ for block in self.layers:
+ residual = latents
+ latents = block(x, latents, residual)
+
+ latents = self.proj_out(latents)
+ out = self.norm_out(latents)
+ if self.shortcut:
+ out = id_embeds + self.shortcut_scale * out
+ return out
+
+
+class MultiIPAdapterImageProjection(nn.Module):
+ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
+ super().__init__()
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
+
+ def forward(self, image_embeds: List[torch.Tensor]):
+ projected_image_embeds = []
+
+ # currently, we accept `image_embeds` as
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
+ # 2. a list of `n` tensors, where `n` is the number of IP-Adapters; each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
+ if not isinstance(image_embeds, list):
+ deprecation_message = (
+ "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
+ )
+ deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
+ image_embeds = [image_embeds.unsqueeze(1)]
+
+ if len(image_embeds) != len(self.image_projection_layers):
+ raise ValueError(
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
+ )
+
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
+ image_embed = image_projection_layer(image_embed)
+ image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
+
+ projected_image_embeds.append(image_embed)
+
+ return projected_image_embeds
\ No newline at end of file
diff --git a/flux/flux_network.py b/flux/flux_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0f67afffc45da22c372d6db4d8a0f56e42d2ca8
--- /dev/null
+++ b/flux/flux_network.py
@@ -0,0 +1,183 @@
+import torch
+import torch.nn as nn
+from .transformer_flux import FluxTransformer2DModel
+
+class FluxNetwork(nn.Module):
+ TARGET_REPLACE_MODULE = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] # module classes that can be made trainable
+ FLUX_PREFIX = "flux"
+
+ def __init__(self, flux_model: FluxTransformer2DModel):
+ super().__init__()
+ self.flux_model = flux_model
+ self.trainable_component_names = [] # records the names of the trainable components
+
+ @staticmethod
+ def generate_trainable_components(layers, num_transformer_blocks=19):
+ transformer_components = [
+ "attn.to_q",
+ "attn.to_k",
+ "attn.to_v",
+ "attn.to_out",
+ "norm1",
+ "norm1_context",
+ ]
+
+ single_transformer_components = [
+ "attn.to_q",
+ "attn.to_k",
+ "attn.to_v",
+ "norm",
+ #"proj_mlp",
+ ]
+
+ components = ["context_embedder"] # always include the context_embedder
+ for layer in layers:
+ if layer < num_transformer_blocks:
+ prefix = f"transformer_blocks.{layer}"
+ base_components = transformer_components
+ else:
+ prefix = f"single_transformer_blocks.{layer - num_transformer_blocks}"
+ base_components = single_transformer_components
+ components.extend([f"{prefix}.{comp}" for comp in base_components])
+
+ return components
+
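+ # Usage sketch for generate_trainable_components (above); purely illustrative. Selecting
+ # layer 0 (a double block) and layer 19 (the first single block) yields the
+ # context_embedder plus the per-block attention/norm entries:
+ #
+ #     names = FluxNetwork.generate_trainable_components([0, 19])
+ #     # names[0] == "context_embedder"
+ #     # "transformer_blocks.0.attn.to_q" in names
+ #     # "single_transformer_blocks.0.attn.to_q" in names
+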
+ def apply_to(self, layers=None, additional_components=None):
+ if layers is None:
+ layers = list(range(57)) # default: make every block trainable
+
+ component_names = self.generate_trainable_components(layers)
+
+ if additional_components:
+ component_names.extend(additional_components)
+
+ self.trainable_component_names = [] # reset
+ for name in component_names:
+ try:
+ component = recursive_getattr(self.flux_model, name)
+ if isinstance(component, nn.Module):
+ component.requires_grad_(True)
+ self.trainable_component_names.append(name)
+ else:
+ print(f"Warning: {name} is not a Module, skipping.")
+ except AttributeError:
+ print(f"Warning: {name} not found in the model, skipping.")
+
+ def prepare_grad_etc(self):
+ # freeze the whole flux_model, then re-enable gradients for the trainable components
+ self.flux_model.requires_grad_(False)
+ for name in self.trainable_component_names:
+ recursive_getattr(self.flux_model, name).requires_grad_(True)
+
+ def get_trainable_params(self):
+ # return the parameters that require training
+ params = []
+ for name in self.trainable_component_names:
+ params.extend(recursive_getattr(self.flux_model, name).parameters())
+ return params
+
+ def print_trainable_params_info(self):
+ total_params = 0
+ for name in self.trainable_component_names:
+ module = recursive_getattr(self.flux_model, name)
+ module_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
+ total_params += module_params
+ #print(f'{name}: {module_params} trainable parameters')
+ print(f'Total trainable params: {total_params}')
+
+ def save_weights(self, file, dtype=None):
+ # save the state dicts of the trainable components
+ state_dict = {}
+ for name in self.trainable_component_names:
+ state_dict[name] = recursive_getattr(self.flux_model, name).state_dict()
+ if dtype is not None:
+ # actually cast the saved copies; the original loop rebound a local and had no effect
+ state_dict = {
+ name: {k: t.detach().clone().to("cpu").to(dtype) for k, t in v.items()}
+ for name, v in state_dict.items()
+ }
+ torch.save(state_dict, file)
+
+ def load_weights(self, file, device):
+ print(f"Loading weights from {file}")
+ try:
+ state_dict = torch.load(file, map_location=device, weights_only=True)
+ except Exception as e:
+ print(f"Failed to load weights from {file}: {str(e)}")
+ return False
+
+ successfully_loaded = []
+ failed_to_load = []
+
+ for name in state_dict:
+ try:
+ module = recursive_getattr(self.flux_model, name)
+ module_state_dict = module.state_dict()
+
+ # check that the state_dict keys match
+ if set(state_dict[name].keys()) != set(module_state_dict.keys()):
+ raise ValueError(f"State dict keys for {name} do not match")
+
+ # check that the tensor shapes match
+ for key in state_dict[name]:
+ if state_dict[name][key].shape != module_state_dict[key].shape:
+ raise ValueError(f"Shape mismatch for {name}.{key}")
+
+ module.load_state_dict(state_dict[name])
+ successfully_loaded.append(name)
+
+ except Exception as e:
+ print(f"Failed to load weights for {name}: {str(e)}")
+ failed_to_load.append(name)
+
+ if successfully_loaded:
+ print(f"Successfully loaded weights for: {', '.join(successfully_loaded)}")
+ if failed_to_load:
+ print(f"Failed to load weights for: {', '.join(failed_to_load)}")
+
+ return len(failed_to_load) == 0 # True only if every component loaded successfully
+
+# recursively fetch a nested attribute by dotted path, e.g. "transformer_blocks.0.attn.to_q"
+def recursive_getattr(obj, attr):
+ attrs = attr.split(".")
+ for i in range(len(attrs)):
+ obj = getattr(obj, attrs[i])
+ return obj
+
+# recursively set a nested attribute by dotted path
+def recursive_setattr(obj, attr, val):
+ attrs = attr.split(".")
+ for i in range(len(attrs)-1):
+ obj = getattr(obj, attrs[i])
+ setattr(obj, attrs[-1], val)
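+
+# End-to-end usage sketch (illustrative; assumes `flux_transformer` is a loaded
+# FluxTransformer2DModel and that a training loop exists elsewhere):
+#
+#     net = FluxNetwork(flux_transformer)
+#     net.apply_to(layers=[0, 1, 19, 20])        # mark a few double and single blocks as trainable
+#     net.prepare_grad_etc()                     # freeze everything else
+#     net.print_trainable_params_info()
+#     optimizer = torch.optim.AdamW(net.get_trainable_params(), lr=1e-5)
+#     ...                                        # training loop
+#     net.save_weights("flux_partial.pt", dtype=torch.bfloat16)
+#     net.load_weights("flux_partial.pt", device="cpu")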
\ No newline at end of file
diff --git a/flux/lora/lora_base.py b/flux/lora/lora_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcaad3af61bfff1225ad0271094a47a546ca1550
--- /dev/null
+++ b/flux/lora/lora_base.py
@@ -0,0 +1,752 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+import os
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+
+import safetensors
+import torch
+import torch.nn as nn
+from huggingface_hub import model_info
+from huggingface_hub.constants import HF_HUB_OFFLINE
+
+from diffusers.models.modeling_utils import ModelMixin, load_state_dict
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ _get_model_file,
+ delete_adapter_layers,
+ deprecate,
+ is_accelerate_available,
+ is_peft_available,
+ is_transformers_available,
+ logging,
+ recurse_remove_peft_layers,
+ set_adapter_layers,
+ set_weights_and_activate_adapters,
+)
+
+
+if is_transformers_available():
+ from transformers import PreTrainedModel
+
+if is_peft_available():
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+if is_accelerate_available():
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+
+def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+ """
+ Fuses LoRAs for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`):
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+ """
+ merge_kwargs = {"safe_merge": safe_fusing}
+
+ for module in text_encoder.modules():
+ if isinstance(module, BaseTunerLayer):
+ if lora_scale != 1.0:
+ module.scale_layer(lora_scale)
+
+ # For BC with previous PEFT versions, we need to check the signature
+ # of the `merge` method to see if it supports the `adapter_names` argument.
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+ if "adapter_names" in supported_merge_kwargs:
+ merge_kwargs["adapter_names"] = adapter_names
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+ raise ValueError(
+ "The `adapter_names` argument is not supported with your PEFT version. "
+ "Please upgrade to the latest version of PEFT. `pip install -U peft`"
+ )
+
+ module.merge(**merge_kwargs)
+
+
+def unfuse_text_encoder_lora(text_encoder):
+ """
+ Unfuses LoRAs for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`):
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ """
+ for module in text_encoder.modules():
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+
+def set_adapters_for_text_encoder(
+ adapter_names: Union[List[str], str],
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
+ text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
+):
+ """
+ Sets the adapter layers for the text encoder.
+
+ Args:
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ text_encoder_weights (`List[float]`, *optional*):
+ The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
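+
+ Example (a minimal sketch; `pipe` and the adapter name are placeholders):
+
+ ```py
+ # Activate the "style" adapter on the text encoder with a weight of 0.8
+ set_adapters_for_text_encoder(["style"], text_encoder=pipe.text_encoder, text_encoder_weights=[0.8])
+ ```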
+ """
+ if text_encoder is None:
+ raise ValueError(
+ "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
+ )
+
+ def process_weights(adapter_names, weights):
+ # Expand weights into a list, one entry per adapter
+ # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
+ if not isinstance(weights, list):
+ weights = [weights] * len(adapter_names)
+
+ if len(adapter_names) != len(weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+ )
+
+ # Set None values to default of 1.0
+ # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
+ weights = [w if w is not None else 1.0 for w in weights]
+
+ return weights
+
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+ text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+ set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+
+def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
+ """
+ Disables the LoRA layers for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ """
+ if text_encoder is None:
+ raise ValueError("Text Encoder not found.")
+ set_adapter_layers(text_encoder, enabled=False)
+
+
+def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
+ """
+ Enables the LoRA layers for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ """
+ if text_encoder is None:
+ raise ValueError("Text Encoder not found.")
+ set_adapter_layers(text_encoder, enabled=True)
+
+
+def _remove_text_encoder_monkey_patch(text_encoder):
+ recurse_remove_peft_layers(text_encoder)
+ if getattr(text_encoder, "peft_config", None) is not None:
+ del text_encoder.peft_config
+ text_encoder._hf_peft_config_loaded = None
+
+
+class LoraBaseMixin:
+ """Utility class for handling LoRAs."""
+
+ _lora_loadable_modules = []
+ num_fused_loras = 0
+
+ def load_lora_weights(self, **kwargs):
+ raise NotImplementedError("`load_lora_weights()` is not implemented.")
+
+ @classmethod
+ def save_lora_weights(cls, **kwargs):
+ raise NotImplementedError("`save_lora_weights()` not implemented.")
+
+ @classmethod
+ def lora_state_dict(cls, **kwargs):
+ raise NotImplementedError("`lora_state_dict()` is not implemented.")
+
+ @classmethod
+ def _optionally_disable_offloading(cls, _pipeline):
+ """
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+ Args:
+ _pipeline (`DiffusionPipeline`):
+ The pipeline to disable offloading for.
+
+ Returns:
+ tuple:
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+ """
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+
+ if _pipeline is not None and _pipeline.hf_device_map is None:
+ for _, component in _pipeline.components.items():
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+ if not is_model_cpu_offload:
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+ if not is_sequential_cpu_offload:
+ is_sequential_cpu_offload = (
+ isinstance(component._hf_hook, AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
+
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+ @classmethod
+ def _fetch_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict,
+ weight_name,
+ use_safetensors,
+ local_files_only,
+ cache_dir,
+ force_download,
+ proxies,
+ token,
+ revision,
+ subfolder,
+ user_agent,
+ allow_pickle,
+ ):
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ # Here we're relaxing the loading check to enable more Inference API
+ # friendliness where sometimes, it's not at all possible to automatically
+ # determine `weight_name`.
+ if weight_name is None:
+ weight_name = cls._best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict,
+ file_extension=".safetensors",
+ local_files_only=local_files_only,
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except (IOError, safetensors.SafetensorError) as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ model_file = None
+ pass
+
+ if model_file is None:
+ if weight_name is None:
+ weight_name = cls._best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ return state_dict
+
+ @classmethod
+ def _best_guess_weight_name(
+ cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
+ ):
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+ if local_files_only or HF_HUB_OFFLINE:
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
+
+ targeted_files = []
+
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
+ return
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+ targeted_files = [
+ f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
+ ]
+ else:
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+ if len(targeted_files) == 0:
+ return
+
+ # "scheduler" does not correspond to a LoRA checkpoint.
+ # "optimizer" does not correspond to a LoRA checkpoint
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
+ targeted_files = list(
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+ )
+
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
+ if len(targeted_files) > 1:
+ raise ValueError(
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+ )
+ weight_name = targeted_files[0]
+ return weight_name
+
+ def unload_lora_weights(self):
+ """
+ Unloads the LoRA parameters.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.unload_lora()
+ elif issubclass(model.__class__, PreTrainedModel):
+ _remove_text_encoder_monkey_patch(model)
+
+ def fuse_lora(
+ self,
+ components: List[str] = [],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ if "fuse_unet" in kwargs:
+ depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
+ deprecate(
+ "fuse_unet",
+ "1.0.0",
+ depr_message,
+ )
+ if "fuse_transformer" in kwargs:
+ depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
+ deprecate(
+ "fuse_transformer",
+ "1.0.0",
+ depr_message,
+ )
+ if "fuse_text_encoder" in kwargs:
+ depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
+ deprecate(
+ "fuse_text_encoder",
+ "1.0.0",
+ depr_message,
+ )
+
+ if len(components) == 0:
+ raise ValueError("`components` cannot be an empty list.")
+
+ for fuse_component in components:
+ if fuse_component not in self._lora_loadable_modules:
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
+
+ model = getattr(self, fuse_component, None)
+ if model is not None:
+ # check if diffusers model
+ if issubclass(model.__class__, ModelMixin):
+ model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+ # handle transformers models.
+ if issubclass(model.__class__, PreTrainedModel):
+ fuse_text_encoder_lora(
+ model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ self.num_fused_loras += 1
+
+ def unfuse_lora(self, components: List[str] = [], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ if "unfuse_unet" in kwargs:
+ depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
+ deprecate(
+ "unfuse_unet",
+ "1.0.0",
+ depr_message,
+ )
+ if "unfuse_transformer" in kwargs:
+ depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
+ deprecate(
+ "unfuse_transformer",
+ "1.0.0",
+ depr_message,
+ )
+ if "unfuse_text_encoder" in kwargs:
+ depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
+ deprecate(
+ "unfuse_text_encoder",
+ "1.0.0",
+ depr_message,
+ )
+
+ if len(components) == 0:
+ raise ValueError("`components` cannot be an empty list.")
+
+ for fuse_component in components:
+ if fuse_component not in self._lora_loadable_modules:
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
+
+ model = getattr(self, fuse_component, None)
+ if model is not None:
+ if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+ self.num_fused_loras -= 1
+
+ def set_adapters(
+ self,
+ adapter_names: Union[List[str], str],
+ adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
+ ):
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+ adapter_weights = copy.deepcopy(adapter_weights)
+
+ # Expand weights into a list, one entry per adapter
+ if not isinstance(adapter_weights, list):
+ adapter_weights = [adapter_weights] * len(adapter_names)
+
+ if len(adapter_names) != len(adapter_weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+ )
+
+ list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+ all_adapters = {
+ adapter for adapters in list_adapters.values() for adapter in adapters
+ } # eg ["adapter1", "adapter2"]
+ invert_list_adapters = {
+ adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+ for adapter in all_adapters
+ } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+
+ # Decompose weights into weights for denoiser and text encoders.
+ _component_adapter_weights = {}
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component)
+
+ for adapter_name, weights in zip(adapter_names, adapter_weights):
+ if isinstance(weights, dict):
+ component_adapter_weights = weights.pop(component, None)
+
+ if component_adapter_weights is not None and not hasattr(self, component):
+ logger.warning(
+ f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
+ )
+
+ if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
+ logger.warning(
+ (
+ f"Lora weight dict for adapter '{adapter_name}' contains {component},"
+ f"but this will be ignored because {adapter_name} does not contain weights for {component}."
+ f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+ )
+ )
+
+ else:
+ component_adapter_weights = weights
+
+ _component_adapter_weights.setdefault(component, [])
+ _component_adapter_weights[component].append(component_adapter_weights)
+
+ if issubclass(model.__class__, ModelMixin):
+ model.set_adapters(adapter_names, _component_adapter_weights[component])
+ elif issubclass(model.__class__, PreTrainedModel):
+ set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
+
+ def disable_lora(self):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.disable_lora()
+ elif issubclass(model.__class__, PreTrainedModel):
+ disable_lora_for_text_encoder(model)
+
+ def enable_lora(self):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.enable_lora()
+ elif issubclass(model.__class__, PreTrainedModel):
+ enable_lora_for_text_encoder(model)
+
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
+ """
+ Args:
+ Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
+ adapter_names (`Union[List[str], str]`):
+ The names of the adapter to delete. Can be a single string or a list of strings
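+
+ Example (a minimal sketch; the repository and adapter names are placeholders):
+
+ ```py
+ pipeline.load_lora_weights("some-user/some-lora-repo", adapter_name="style")
+ pipeline.delete_adapters("style")  # removes the "style" LoRA layers again
+ ```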
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ if isinstance(adapter_names, str):
+ adapter_names = [adapter_names]
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.delete_adapters(adapter_names)
+ elif issubclass(model.__class__, PreTrainedModel):
+ for adapter_name in adapter_names:
+ delete_adapter_layers(model, adapter_name)
+
+ def get_active_adapters(self) -> List[str]:
+ """
+ Gets the list of the current active adapters.
+
+ Example:
+
+ ```python
+ from diffusers import DiffusionPipeline
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ ).to("cuda")
+ pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+ pipeline.get_active_adapters()
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError(
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+ )
+
+ active_adapters = []
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None and issubclass(model.__class__, ModelMixin):
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ active_adapters = module.active_adapters
+ break
+
+ return active_adapters
+
+ def get_list_adapters(self) -> Dict[str, List[str]]:
+ """
+ Gets the current list of all available adapters in the pipeline.
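+
+ Example (a minimal sketch; the adapter name and the returned mapping are illustrative placeholders):
+
+ ```py
+ pipeline.load_lora_weights("some-user/some-lora-repo", adapter_name="style")
+ pipeline.get_list_adapters()
+ # e.g. {"unet": ["style"], "text_encoder": ["style"]}
+ ```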
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError(
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+ )
+
+ set_adapters = {}
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if (
+ model is not None
+ and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
+ and hasattr(model, "peft_config")
+ ):
+ set_adapters[component] = list(model.peft_config.keys())
+
+ return set_adapters
+
+ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
+ """
+ Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
+ you want to load multiple adapters and free some GPU memory.
+
+ Args:
+ adapter_names (`List[str]`):
+ List of adapters to move to the target device.
+ device (`Union[torch.device, str, int]`):
+ Device to send the adapters to. Can be either a torch device, a str or an integer.
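+
+ Example (a minimal sketch; the adapter name is a placeholder):
+
+ ```py
+ # Park the "style" adapter on the CPU to free GPU memory ...
+ pipeline.set_lora_device(["style"], device="cpu")
+ # ... and move it back before the next inference call
+ pipeline.set_lora_device(["style"], device="cuda")
+ ```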
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ for adapter_name in adapter_names:
+ module.lora_A[adapter_name].to(device)
+ module.lora_B[adapter_name].to(device)
+ # this is a param, not a module, so device placement is not in-place -> re-assign
+ if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
+ module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
+ adapter_name
+ ].to(device)
+
+ @staticmethod
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ @staticmethod
+ def write_lora_layers(
+ state_dict: Dict[str, torch.Tensor],
+ save_directory: str,
+ is_main_process: bool,
+ weight_name: str,
+ save_function: Callable,
+ safe_serialization: bool,
+ ):
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = LORA_WEIGHT_NAME
+
+ save_path = Path(save_directory, weight_name).as_posix()
+ save_function(state_dict, save_path)
+ logger.info(f"Model weights saved in {save_path}")
+
+ @property
+ def lora_scale(self) -> float:
+ # property function that returns the lora scale which can be set at run time by the pipeline.
+ # if _lora_scale has not been set, return 1
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
\ No newline at end of file
diff --git a/flux/lora/lora_conversion_utils.py b/flux/lora/lora_conversion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a17987097997541fec8f313b5dc02df6042e60d
--- /dev/null
+++ b/flux/lora/lora_conversion_utils.py
@@ -0,0 +1,328 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from diffusers.utils import is_peft_version, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
+ # 1. get all state_dict_keys
+ all_keys = list(state_dict.keys())
+ sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+
+ # 2. check if needs remapping, if not return original dict
+ is_in_sgm_format = False
+ for key in all_keys:
+ if any(p in key for p in sgm_patterns):
+ is_in_sgm_format = True
+ break
+
+ if not is_in_sgm_format:
+ return state_dict
+
+ # 3. Else remap from SGM patterns
+ new_state_dict = {}
+ inner_block_map = ["resnets", "attentions", "upsamplers"]
+
+ # Retrieves # of down, mid and up blocks
+ input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
+
+ for layer in all_keys:
+ if "text" in layer:
+ new_state_dict[layer] = state_dict.pop(layer)
+ else:
+ layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
+ if sgm_patterns[0] in layer:
+ input_block_ids.add(layer_id)
+ elif sgm_patterns[1] in layer:
+ middle_block_ids.add(layer_id)
+ elif sgm_patterns[2] in layer:
+ output_block_ids.add(layer_id)
+ else:
+ raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
+
+ input_blocks = {
+ layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
+ for layer_id in input_block_ids
+ }
+ middle_blocks = {
+ layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
+ for layer_id in middle_block_ids
+ }
+ output_blocks = {
+ layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
+ for layer_id in output_block_ids
+ }
+
+ # Rename keys accordingly
+ for i in input_block_ids:
+ block_id = (i - 1) // (unet_config.layers_per_block + 1)
+ layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
+
+ for key in input_blocks[i]:
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
+ inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
+ inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1]
+ + [str(block_id), inner_block_key, inner_layers_in_block]
+ + key.split(delimiter)[block_slice_pos + 1 :]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ for i in middle_block_ids:
+ key_part = None
+ if i == 0:
+ key_part = [inner_block_map[0], "0"]
+ elif i == 1:
+ key_part = [inner_block_map[1], "0"]
+ elif i == 2:
+ key_part = [inner_block_map[0], "1"]
+ else:
+ raise ValueError(f"Invalid middle block id {i}.")
+
+ for key in middle_blocks[i]:
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ for i in output_block_ids:
+ block_id = i // (unet_config.layers_per_block + 1)
+ layer_in_block_id = i % (unet_config.layers_per_block + 1)
+
+ for key in output_blocks[i]:
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
+ inner_block_key = inner_block_map[inner_block_id]
+ inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1]
+ + [str(block_id), inner_block_key, inner_layers_in_block]
+ + key.split(delimiter)[block_slice_pos + 1 :]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ if len(state_dict) > 0:
+ raise ValueError("At this point all state dict entries have to be converted.")
+
+ return new_state_dict
+
+
+def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+ """
+ Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
+
+ Args:
+ state_dict (`dict`): The state dict to convert.
+ unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
+ text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
+ "text_encoder".
+
+ Returns:
+ `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
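+
+ Example (an illustrative sketch of the renaming; the key shown is a hypothetical kohya-ss style key):
+
+ ```py
+ # "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
+ # becomes, roughly,
+ # "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
+ ```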
+ """
+ unet_state_dict = {}
+ te_state_dict = {}
+ te2_state_dict = {}
+ network_alphas = {}
+
+ # Check for DoRA-enabled LoRAs.
+ dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+ dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+ dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+ if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+
+ # Iterate over all LoRA weights.
+ all_lora_keys = list(state_dict.keys())
+ for key in all_lora_keys:
+ if not key.endswith("lora_down.weight"):
+ continue
+
+ # Extract LoRA name.
+ lora_name = key.split(".")[0]
+
+ # Find corresponding up weight and alpha.
+ lora_name_up = lora_name + ".lora_up.weight"
+ lora_name_alpha = lora_name + ".alpha"
+
+ # Handle U-Net LoRAs.
+ if lora_name.startswith("lora_unet_"):
+ diffusers_name = _convert_unet_lora_key(key)
+
+ # Store down and up weights.
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # Store DoRA scale if present.
+ if dora_present_in_unet:
+ dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+ unet_state_dict[
+ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
+ # Handle text encoder LoRAs.
+ elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+ # Store down and up weights for te or te2.
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ else:
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # Store DoRA scale if present.
+ if dora_present_in_te or dora_present_in_te2:
+ dora_scale_key_to_replace_te = (
+ "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+ )
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
+ te_state_dict[
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ elif lora_name.startswith("lora_te2_"):
+ te2_state_dict[
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
+ # Store alpha if present.
+ if lora_name_alpha in state_dict:
+ alpha = state_dict.pop(lora_name_alpha).item()
+ network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
+
+ # Check if any keys remain.
+ if len(state_dict) > 0:
+ raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
+
+ logger.info("Non-diffusers checkpoint detected.")
+
+ # Construct final state dict.
+ unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+ te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
+ te2_state_dict = (
+ {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
+ if len(te2_state_dict) > 0
+ else None
+ )
+ if te2_state_dict is not None:
+ te_state_dict.update(te2_state_dict)
+
+ new_state_dict = {**unet_state_dict, **te_state_dict}
+ return new_state_dict, network_alphas
+
+
+def _convert_unet_lora_key(key):
+ """
+ Converts a U-Net LoRA key to a Diffusers compatible key.
+ """
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+ # Replace common U-Net naming patterns.
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+ # SDXL specific conversions.
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
+ pattern = r"\.\d+(?=\D*$)"
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+ if ".in." in diffusers_name:
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+ if ".out." in diffusers_name:
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+ diffusers_name = diffusers_name.replace("op", "conv")
+ if "skip" in diffusers_name:
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+ # LyCORIS specific conversions.
+ if "time.emb.proj" in diffusers_name:
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+ if "conv.shortcut" in diffusers_name:
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
+ # General conversions.
+ if "transformer_blocks" in diffusers_name:
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+ elif "ff" in diffusers_name:
+ pass
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+ pass
+ else:
+ pass
+
+ return diffusers_name
+
+
+def _convert_text_encoder_lora_key(key, lora_name):
+ """
+ Converts a text encoder LoRA key to a Diffusers compatible key.
+ """
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+ else:
+ key_to_replace = "lora_te2_"
+
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
+ pass
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ return diffusers_name
+
+
+def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
+ """
+ Gets the correct alpha name for the Diffusers model.
+ """
+ if lora_name_alpha.startswith("lora_unet_"):
+ prefix = "unet."
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+ prefix = "text_encoder."
+ else:
+ prefix = "text_encoder_2."
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+ return {new_name: alpha}
\ No newline at end of file
diff --git a/flux/lora/lora_pipeline.py b/flux/lora/lora_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf7318ee6a56ebb998781c0f46b6b22851d4d26d
--- /dev/null
+++ b/flux/lora/lora_pipeline.py
@@ -0,0 +1,2236 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_peft,
+ convert_unet_state_dict_to_peft,
+ deprecate,
+ get_adapter_name,
+ get_peft_kwargs,
+ is_peft_version,
+ is_transformers_available,
+ logging,
+ scale_lora_layers,
+)
+from .lora_base import LoraBaseMixin
+from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+
+
+if is_transformers_available():
+ from diffusers.models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+
+logger = logging.get_logger(__name__)
+
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
+TRANSFORMER_NAME = "transformer"
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+
+class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+ """
+
+ _lora_loadable_modules = ["unet", "text_encoder"]
+ unet_name = UNET_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
+ loaded into `self.unet`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
+ dict is loaded into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
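+
+ Example (a minimal sketch; the repository, file and adapter names are placeholders):
+
+ ```py
+ pipeline.load_lora_weights(
+ "some-user/some-lora-repo", weight_name="pytorch_lora_weights.safetensors", adapter_name="style"
+ )
+ ```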
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_unet(
+ state_dict,
+ network_alphas=network_alphas,
+ unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=getattr(self, self.text_encoder_name)
+ if not hasattr(self, "text_encoder")
+ else self.text_encoder,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ weight_name (`str`, *optional*, defaults to None):
+ Name of the serialized state dict file.
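+
+ Example (a minimal sketch; the repository name is a placeholder):
+
+ ```py
+ state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict("some-user/some-lora-repo")
+ ```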
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # UNet and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ unet_config = kwargs.pop("unet_config", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = cls._fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ network_alphas = None
+ # TODO: replace it with a method from `state_dict_utils`
+ if all(
+ (
+ k.startswith("lora_te_")
+ or k.startswith("lora_unet_")
+ or k.startswith("lora_te1_")
+ or k.startswith("lora_te2_")
+ )
+ for k in state_dict.keys()
+ ):
+ # Map SDXL blocks correctly.
+ if unet_config is not None:
+ # use unet config to remap block numbers
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
+
+ return state_dict, network_alphas
+
+ @classmethod
+ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `unet`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ unet (`UNet2DConditionModel`):
+ The UNet model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
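+
+ Example (a minimal sketch; normally invoked indirectly via `load_lora_weights`, and the repository name is a placeholder):
+
+ ```py
+ state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict("some-user/some-lora-repo")
+ StableDiffusionLoraLoaderMixin.load_lora_into_unet(
+ state_dict, network_alphas=network_alphas, unet=pipeline.unet, _pipeline=pipeline
+ )
+ ```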
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
+ if not only_text_encoder:
+ # Load the layers corresponding to UNet.
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_attn_procs(
+ state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+ )
+
+ @classmethod
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
+ lora layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `unet`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
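+
+ Example (a minimal sketch; `unet_lora_state_dict` is assumed to come from the training loop, e.g. via `get_peft_model_state_dict(unet)`):
+
+ ```py
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ save_directory="./my-lora",
+ unet_lora_layers=unet_lora_state_dict,
+ weight_name="pytorch_lora_weights.safetensors",
+ )
+ ```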
+ """
+ state_dict = {}
+
+ if not (unet_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+
+ if unet_lora_layers:
+ state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["unet", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and to skip fusing weights that contain NaNs.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
+class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
+ [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
+ """
+
+ _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
+ unet_name = UNET_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
+ loaded into `self.unet`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
+ dict is loaded into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
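+
+ Example:
+
+ An illustrative call, using the same public SDXL base checkpoint and LoRA repository referenced in the
+ `fuse_lora` example of this class:
+
+ ```py
+ import torch
+ from diffusers import StableDiffusionXLPipeline
+
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ # Loads the LoRA into the UNet and, when the checkpoint contains them, into both text encoders.
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ image = pipeline("pixel art of a corgi astronaut").images[0]
+ ```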
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_unet(
+ state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+ )
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ weight_name (`str`, *optional*, defaults to None):
+ Name of the serialized state dict file.
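+
+ Example:
+
+ An illustrative sketch that only inspects the returned values (same placeholder LoRA repository as the
+ `fuse_lora` example of this class):
+
+ ```py
+ from diffusers import StableDiffusionXLPipeline
+
+ state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(
+ "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors"
+ )
+ print(f"{len(state_dict)} LoRA tensors loaded")
+ # `network_alphas` is populated for Kohya/A1111-style checkpoints and stays `None` for
+ # diffusers/PEFT-format checkpoints.
+ print(network_alphas)
+ ```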
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # UNet and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ unet_config = kwargs.pop("unet_config", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = cls._fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ network_alphas = None
+ # TODO: replace it with a method from `state_dict_utils`
+ if all(
+ (
+ k.startswith("lora_te_")
+ or k.startswith("lora_unet_")
+ or k.startswith("lora_te1_")
+ or k.startswith("lora_te2_")
+ )
+ for k in state_dict.keys()
+ ):
+ # Map SDXL blocks correctly.
+ if unet_config is not None:
+ # use unet config to remap block numbers
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
+
+ return state_dict, network_alphas
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
+ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `unet`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ unet (`UNet2DConditionModel`):
+ The UNet model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
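+
+ Example:
+
+ A sketch of the two-step flow this method supports, reusing the placeholder LoRA repository from the
+ other examples in this class:
+
+ ```py
+ from diffusers import StableDiffusionXLPipeline
+
+ pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ state_dict, network_alphas = pipe.lora_state_dict(
+ "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", unet_config=pipe.unet.config
+ )
+ pipe.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=pipe.unet)
+ ```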
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
+ if not only_text_encoder:
+ # Load the layers corresponding to UNet.
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_attn_procs(
+ state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
+ lora layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `unet`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
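+
+ Example:
+
+ A minimal sketch showing how the text encoder layer dicts can be collected; the adapters here are fresh
+ and untrained, purely to illustrate the expected inputs:
+
+ ```py
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict
+ from diffusers import StableDiffusionXLPipeline
+ from diffusers.utils import convert_state_dict_to_diffusers
+
+ pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ text_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.save_lora_weights(
+ "./my-sdxl-lora",
+ text_encoder_lora_layers=convert_state_dict_to_diffusers(get_peft_model_state_dict(pipe.text_encoder)),
+ text_encoder_2_lora_layers=convert_state_dict_to_diffusers(get_peft_model_state_dict(pipe.text_encoder_2)),
+ )
+ ```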
+ """
+ state_dict = {}
+
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+ )
+
+ if unet_lora_layers:
+ state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+ if text_encoder_2_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and to skip fusing weights that contain NaNs.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
+class SD3LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SD3Transformer2DModel`],
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
+ [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
+
+ Specific to [`StableDiffusion3Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = cls._fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ return state_dict
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.SD3LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
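+
+ Example:
+
+ An illustrative sketch; the LoRA path below is a placeholder for any SD3 LoRA saved in the
+ diffusers/PEFT layout (for example by `save_lora_weights`):
+
+ ```py
+ import torch
+ from diffusers import StableDiffusion3Pipeline
+
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("path/to/sd3-lora", weight_name="pytorch_lora_weights.safetensors")
+ image = pipeline("a watercolor fox in a misty forest").images[0]
+ ```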
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`SD3Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
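+
+ Example:
+
+ A sketch of loading a fetched state dict directly into the transformer of an SD3 pipeline; the LoRA
+ path is a placeholder:
+
+ ```py
+ import torch
+ from diffusers import StableDiffusion3Pipeline
+
+ pipe = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
+ )
+ state_dict = pipe.lora_state_dict("path/to/sd3-lora")
+ pipe.load_lora_into_transformer(state_dict, transformer=pipe.transformer, adapter_name="my_lora")
+ ```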
+ """
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+ keys = list(state_dict.keys())
+
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+ state_dict = {
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+ }
+
+ if len(state_dict.keys()) > 0:
+ # check with first key if is not in peft format
+ first_key = next(iter(state_dict.keys()))
+ if "lora_A" not in first_key:
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+ if adapter_name in getattr(transformer, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+ )
+
+ rank = {}
+ for key, val in state_dict.items():
+ if "lora_B" in key:
+ rank[key] = val.shape[1]
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(transformer)
+
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+ # otherwise loading LoRA weights will lead to an error
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
+ lora layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer and text encoders.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+ raise ValueError(
+ "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
+ )
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+ if text_encoder_2_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and to skip fusing weights that contain NaNs.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
+class FluxLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`FluxTransformer2DModel`],
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+
+ Specific to [`FluxPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer", "text_encoder"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = cls._fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ return state_dict
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.FluxLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
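+
+ Example:
+
+ An illustrative sketch; the base checkpoint and LoRA path are placeholders for any Flux checkpoint and
+ any Flux LoRA in the diffusers/PEFT layout:
+
+ ```py
+ import torch
+ from diffusers import FluxPipeline
+
+ pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+ pipeline.load_lora_weights("path/to/flux-lora", weight_name="pytorch_lora_weights.safetensors")
+ image = pipeline("a cyberpunk street at dusk", num_inference_steps=28).images[0]
+ ```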
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ # NOTE (lupengqi): debugging helpers, kept for reference but disabled by default.
+ # Uncomment to inspect the LoRA state_dict keys:
+ # print("LoRA state_dict keys:")
+ # for key in state_dict.keys():
+ #     print(key)
+
+ # Earlier experiment: retarget the LoRA weights to the "default_1" adapter by renaming the keys.
+ # new_state_dict = {}
+ # for key, value in state_dict.items():
+ #     if key.endswith((".lora_A.weight", ".lora_B.weight")):
+ #         new_key = key.replace(".weight", ".default_1.weight")
+ #     else:
+ #         new_key = key
+ #     new_state_dict[new_key] = value
+
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
+ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`FluxTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+ keys = list(state_dict.keys())
+
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+ state_dict = {
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+ }
+
+ if len(state_dict.keys()) > 0:
+ # check with first key if is not in peft format
+ first_key = next(iter(state_dict.keys()))
+ if "lora_A" not in first_key:
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+ if adapter_name in getattr(transformer, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+ )
+
+ rank = {}
+ for key, val in state_dict.items():
+ if "lora_B" in key:
+ rank[key] = val.shape[1]
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(transformer)
+
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+ # otherwise loading LoRA weights will lead to an error
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
+ lora layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        This is an experimental API.
+
+ Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+                Whether to check the fused weights for NaN values before fusing, and to skip fusing them if NaN
+                values are found.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        This is an experimental API.
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ """
+ super().unfuse_lora(components=components)
+
+
+# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
+# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
+class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+ _lora_loadable_modules = ["transformer", "text_encoder"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ @classmethod
+ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer` which can be used to distinguish between text
+ encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+            transformer (`torch.nn.Module`):
+                The transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+ keys = list(state_dict.keys())
+
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+ state_dict = {
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+ }
+
+ if network_alphas is not None:
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
+ network_alphas = {
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ if len(state_dict.keys()) > 0:
+ if adapter_name in getattr(transformer, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+ )
+
+ rank = {}
+ for key, val in state_dict.items():
+ if "lora_B" in key:
+ rank[key] = val.shape[1]
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(transformer)
+
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+ # otherwise loading LoRA weights will lead to an error
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added to the output of the
+                regular layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+        Save the LoRA parameters corresponding to the transformer and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes; in that case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+
+class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+        deprecation_message = "LoraLoaderMixin is deprecated and will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin` instead."
+ deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
+ super().__init__(*args, **kwargs)
\ No newline at end of file
diff --git a/flux/lora/peft.py b/flux/lora/peft.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf255aaac298b1f732f25412a02303a8355102fb
--- /dev/null
+++ b/flux/lora/peft.py
@@ -0,0 +1,395 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+from diffusers.utils import (
+ MIN_PEFT_VERSION,
+ USE_PEFT_BACKEND,
+ check_peft_version,
+ delete_adapter_layers,
+ is_peft_available,
+ set_adapter_layers,
+ set_weights_and_activate_adapters,
+)
+#from .unet_loader_utils import _maybe_expand_lora_scales
+
+
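+# For the transformer classes handled here the adapter weights are used as-is; the UNet-style per-block
+# scale expansion (`_maybe_expand_lora_scales`) is not needed, so identity functions are registered instead.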
+_SET_ADAPTER_SCALE_FN_MAPPING = {
+ #"UNet2DConditionModel": _maybe_expand_lora_scales,
+ #"UNetMotionModel": _maybe_expand_lora_scales,
+ "SD3Transformer2DModel": lambda model_cls, weights: weights,
+ "FluxTransformer2DModel": lambda model_cls, weights: weights,
+}
+
+
+class PeftAdapterMixin:
+ """
+    A class containing all functions for loading and using adapter weights that are supported in the PEFT library. For
+ more details about adapters and injecting them in a base model, check out the PEFT
+ [documentation](https://huggingface.co/docs/peft/index).
+
+ Install the latest version of PEFT, and use this mixin to:
+
+ - Attach new adapters in the model.
+ - Attach multiple adapters and iteratively activate/deactivate them.
+ - Activate/deactivate all adapters from the model.
+ - Get a list of the active adapters.
+ """
+
+ _hf_peft_config_loaded = False
+
+ def set_adapters(
+ self,
+ adapter_names: Union[List[str], str],
+ weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
+ ):
+ """
+        Set the currently active adapters for use in the model.
+
+ Args:
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+            weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the model. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for `set_adapters()`.")
+
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+ # Expand weights into a list, one entry per adapter
+ # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
+ if not isinstance(weights, list):
+ weights = [weights] * len(adapter_names)
+
+ if len(adapter_names) != len(weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+ )
+
+ # Set None values to default of 1.0
+ # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
+ weights = [w if w is not None else 1.0 for w in weights]
+
+ # e.g. [{...}, 7] -> [{expanded dict...}, 7]
+ scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
+ weights = scale_expansion_fn(self, weights)
+
+ set_weights_and_activate_adapters(self, adapter_names, weights)
+
+ def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
+ r"""
+ Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
+ to the adapter to follow the convention of the PEFT library.
+
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+ [documentation](https://huggingface.co/docs/peft).
+
+ Args:
+ adapter_config (`[~peft.PeftConfig]`):
+ The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
+ methods.
+ adapter_name (`str`, *optional*, defaults to `"default"`):
+ The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
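+
+        Example (a minimal sketch; assumes `transformer` is a model that inherits this mixin, and that the
+        `target_modules` names match its attention projections; the rank, alpha and adapter name are hypothetical):
+
+        ```py
+        from peft import LoraConfig
+
+        # hypothetical rank/alpha and module names, shown only to illustrate the call
+        lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
+        transformer.add_adapter(lora_config, adapter_name="my-lora")
+        ```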
+ """
+ check_peft_version(min_version=MIN_PEFT_VERSION)
+
+ if not is_peft_available():
+ raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
+
+ from peft import PeftConfig, inject_adapter_in_model
+
+ if not self._hf_peft_config_loaded:
+ self._hf_peft_config_loaded = True
+ elif adapter_name in self.peft_config:
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+ if not isinstance(adapter_config, PeftConfig):
+ raise ValueError(
+ f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+ )
+
+ # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
+ # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here.
+ adapter_config.base_model_name_or_path = None
+ inject_adapter_in_model(adapter_config, self, adapter_name)
+ self.set_adapter(adapter_name)
+
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+ """
+ Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
+
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ [documentation](https://huggingface.co/docs/peft).
+
+ Args:
+            adapter_name (`Union[str, List[str]]`):
+ The list of adapters to set or the adapter name in the case of a single adapter.
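+
+        Example (a minimal sketch; assumes adapters named "pixel" and "cinematic" were previously added or loaded;
+        both names are hypothetical):
+
+        ```py
+        transformer.set_adapter("pixel")  # activate a single adapter
+        transformer.set_adapter(["pixel", "cinematic"])  # or several at once (requires a recent PEFT release)
+        ```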
+ """
+ check_peft_version(min_version=MIN_PEFT_VERSION)
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ if isinstance(adapter_name, str):
+ adapter_name = [adapter_name]
+
+ missing = set(adapter_name) - set(self.peft_config)
+ if len(missing) > 0:
+ raise ValueError(
+ f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+ f" current loaded adapters are: {list(self.peft_config.keys())}"
+ )
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ _adapters_has_been_set = False
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ if hasattr(module, "set_adapter"):
+ module.set_adapter(adapter_name)
+                # Previous versions of PEFT do not support multi-adapter inference
+ elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
+ raise ValueError(
+ "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
+ " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
+ )
+ else:
+ module.active_adapter = adapter_name
+ _adapters_has_been_set = True
+
+ if not _adapters_has_been_set:
+ raise ValueError(
+                "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
+ )
+
+ def disable_adapters(self) -> None:
+ r"""
+        Disable all adapters attached to the model and fall back to inference with the base model only.
+
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ [documentation](https://huggingface.co/docs/peft).
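+
+        Example (a minimal sketch; assumes `transformer` inherits this mixin and has at least one adapter attached):
+
+        ```py
+        transformer.disable_adapters()  # run inference with the base weights only
+        transformer.enable_adapters()  # turn the attached adapters back on
+        ```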
+ """
+ check_peft_version(min_version=MIN_PEFT_VERSION)
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ if hasattr(module, "enable_adapters"):
+ module.enable_adapters(enabled=False)
+ else:
+ # support for older PEFT versions
+ module.disable_adapters = True
+
+ def enable_adapters(self) -> None:
+ """
+ Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
+ adapters to enable.
+
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ [documentation](https://huggingface.co/docs/peft).
+ """
+ check_peft_version(min_version=MIN_PEFT_VERSION)
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ if hasattr(module, "enable_adapters"):
+ module.enable_adapters(enabled=True)
+ else:
+ # support for older PEFT versions
+ module.disable_adapters = False
+
+ def active_adapters(self) -> List[str]:
+ """
+ Gets the current list of active adapters of the model.
+
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ [documentation](https://huggingface.co/docs/peft).
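+
+        Example (a minimal sketch; assumes `transformer` inherits this mixin and an adapter named "my-lora" was
+        loaded; the name is hypothetical):
+
+        ```py
+        print(transformer.active_adapters())  # e.g. ["my-lora"]
+        ```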
+ """
+ check_peft_version(min_version=MIN_PEFT_VERSION)
+
+ if not is_peft_available():
+ raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ return module.active_adapter
+
+ def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for `fuse_lora()`.")
+
+ self.lora_scale = lora_scale
+ self._safe_fusing = safe_fusing
+ self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
+
+ def _fuse_lora_apply(self, module, adapter_names=None):
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ merge_kwargs = {"safe_merge": self._safe_fusing}
+
+ if isinstance(module, BaseTunerLayer):
+ if self.lora_scale != 1.0:
+ module.scale_layer(self.lora_scale)
+
+            # For BC with previous PEFT versions, we need to check the signature
+ # of the `merge` method to see if it supports the `adapter_names` argument.
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+ if "adapter_names" in supported_merge_kwargs:
+ merge_kwargs["adapter_names"] = adapter_names
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+ raise ValueError(
+ "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+ " to the latest version of PEFT. `pip install -U peft`"
+ )
+
+ module.merge(**merge_kwargs)
+
+ def unfuse_lora(self):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for `unfuse_lora()`.")
+ self.apply(self._unfuse_lora_apply)
+
+ def _unfuse_lora_apply(self, module):
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+ def unload_lora(self):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for `unload_lora()`.")
+
+ from diffusers.utils import recurse_remove_peft_layers
+
+ recurse_remove_peft_layers(self)
+ if hasattr(self, "peft_config"):
+ del self.peft_config
+
+ def disable_lora(self):
+ """
+ Disables the active LoRA layers of the underlying model.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.disable_lora()
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+ set_adapter_layers(self, enabled=False)
+
+ def enable_lora(self):
+ """
+ Enables the active LoRA layers of the underlying model.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.enable_lora()
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+ set_adapter_layers(self, enabled=True)
+
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
+ """
+ Delete an adapter's LoRA layers from the underlying model.
+
+ Args:
+ adapter_names (`Union[List[str], str]`):
+ The names (single string or list of strings) of the adapter to delete.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.delete_adapters("cinematic")
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ if isinstance(adapter_names, str):
+ adapter_names = [adapter_names]
+
+ for adapter_name in adapter_names:
+ delete_adapter_layers(self, adapter_name)
+
+ # Pop also the corresponding adapter from the config
+ if hasattr(self, "peft_config"):
+ self.peft_config.pop(adapter_name, None)
\ No newline at end of file
diff --git a/flux/normalization.py b/flux/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db7fbe27692f6d5c1d67aa9cd772296b4701a27
--- /dev/null
+++ b/flux/normalization.py
@@ -0,0 +1,393 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numbers
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.utils import is_torch_version
+from .embeddings import get_activation
+from .embeddings import (
+ CombinedTimestepLabelEmbeddings,
+ PixArtAlphaCombinedTimestepSizeEmbeddings,
+)
+
+
+class AdaLayerNorm(nn.Module):
+ r"""
+ Norm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: int):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
+ emb = self.linear(self.silu(self.emb(timestep)))
+ scale, shift = torch.chunk(emb, 2)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class FP32LayerNorm(nn.LayerNorm):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ origin_dtype = inputs.dtype
+ return F.layer_norm(
+ inputs.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ ).to(origin_dtype)
+
+
+class AdaLayerNormZero(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+ super().__init__()
+ if num_embeddings is not None:
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ elif norm_type == "fp32_layer_norm":
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+ else:
+ raise ValueError(
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ if self.emb is not None:
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaLayerNormZeroSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ """
+
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa
+
+
+class LuminaRMSNormZero(nn.Module):
+ """
+ Norm layer adaptive RMS normalization zero.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ """
+
+ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(
+ min(embedding_dim, 1024),
+ 4 * embedding_dim,
+ bias=True,
+ )
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
+ emb = self.linear(self.silu(emb))
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None])
+
+ return x, gate_msa, scale_mlp, gate_mlp
+
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+        use_additional_conditions (`bool`): Whether to use additional conditions for normalization.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ batch_size: Optional[int] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+class AdaGroupNorm(nn.Module):
+ r"""
+ GroupNorm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ num_groups (`int`): The number of groups to separate the channels into.
+ act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
+ eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
+ """
+
+ def __init__(
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
+ ):
+ super().__init__()
+ self.num_groups = num_groups
+ self.eps = eps
+
+ if act_fn is None:
+ self.act = None
+ else:
+ self.act = get_activation(act_fn)
+
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ if self.act:
+ emb = self.act(emb)
+ emb = self.linear(emb)
+ emb = emb[:, :, None, None]
+ scale, shift = emb.chunk(2, dim=1)
+
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
+ x = x * (1 + scale) + shift
+ return x
+
+
+class AdaLayerNormContinuous(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+ # However, this is how it was implemented in the original code, and it's rather likely you should
+ # set `elementwise_affine` to False.
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ ):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
+class LuminaLayerNormContinuous(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+ # However, this is how it was implemented in the original code, and it's rather likely you should
+ # set `elementwise_affine` to False.
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ out_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ # AdaLN
+ self.silu = nn.SiLU()
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+        # linear_2 is optional; default to None so `forward` can safely check whether it was created
+        self.linear_2 = None
+        if out_dim is not None:
+ self.linear_2 = nn.Linear(
+ embedding_dim,
+ out_dim,
+ bias=bias,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ conditioning_embedding: torch.Tensor,
+ ) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+ scale = emb
+ x = self.norm(x) * (1 + scale)[:, None, :]
+
+ if self.linear_2 is not None:
+ x = self.linear_2(x)
+
+ return x
+
+
+if is_torch_version(">=", "2.1.0"):
+ LayerNorm = nn.LayerNorm
+else:
+ # Has optional bias parameter compared to torch layer norm
+ # TODO: replace with torch layernorm once min required torch version >= 2.1
+ class LayerNorm(nn.Module):
+ def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
+ super().__init__()
+
+ self.eps = eps
+
+ if isinstance(dim, numbers.Integral):
+ dim = (dim,)
+
+ self.dim = torch.Size(dim)
+
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+ self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
+ else:
+ self.weight = None
+ self.bias = None
+
+ def forward(self, input):
+ return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+ super().__init__()
+
+ self.eps = eps
+
+ if isinstance(dim, numbers.Integral):
+ dim = (dim,)
+
+ self.dim = torch.Size(dim)
+
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+ else:
+ self.weight = None
+
+ def forward(self, hidden_states):
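+        # RMSNorm: scale by the reciprocal root-mean-square of the last dimension (no mean subtraction),
+        # then apply the learned per-channel weight when `elementwise_affine` is enabled.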
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = hidden_states * self.weight
+ else:
+ hidden_states = hidden_states.to(input_dtype)
+
+ return hidden_states
+
+
+class GlobalResponseNorm(nn.Module):
+ # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
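+        # Global response normalization: take the per-channel L2 norm over the spatial dimensions,
+        # normalize it by its mean across channels, and use it to re-weight the input with a residual connection.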
+ gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * nx) + self.beta + x
\ No newline at end of file
diff --git a/flux/pipeline_flux.py b/flux/pipeline_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..527c8565e7816b5df973acbee74fa1f5f05d4472
--- /dev/null
+++ b/flux/pipeline_flux.py
@@ -0,0 +1,749 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from diffusers.image_processor import VaeImageProcessor
+from .lora.lora_pipeline import FluxLoraLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from .transformer_flux import FluxTransformer2DModel
+from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxPipeline
+
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
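+    # Linearly interpolate the flow-matching shift `mu` between `base_shift` (at `base_seq_len` image tokens)
+    # and `max_shift` (at `max_seq_len` image tokens), based on the current image sequence length.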
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+ r"""
+ The Flux pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
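+
+        Example (a minimal sketch; assumes `pipe` is an already loaded instance of this pipeline):
+
+        ```py
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
+            prompt="A cat holding a sign that says hello world",
+            prompt_2=None,
+            num_images_per_prompt=1,
+        )
+        ```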
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
+ latent_image_ids = latent_image_ids.reshape(
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
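+ # Illustrative note, not from the original patch: with typical Flux settings
+ # (vae_scale_factor = 16, num_channels_latents = transformer.config.in_channels // 4 = 16),
+ # a 1024x1024 request packs `latents` to shape (batch_size, 64 * 64, 64) and
+ # `latent_image_ids` to shape (batch_size, 64 * 64, 3).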
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.tensor([guidance_scale], device=device)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
\ No newline at end of file
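The 2x2 patch packing above is what turns a `(B, C, H, W)` latent grid into the `(B, (H/2)*(W/2), 4*C)` token sequence the transformer consumes, and `_unpack_latents` reverses it before VAE decoding. Below is a minimal, standalone sketch (plain torch code, not part of the patch) that mirrors the two helpers and checks that they invert each other; the shapes are illustrative.

```py
import torch

def pack_latents(latents, b, c, h, w):
    # (B, C, H, W) -> (B, (H/2)*(W/2), 4*C): each 2x2 patch becomes one token
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack_latents(latents, h, w, vae_scale_factor):
    # inverse of pack_latents; h and w are divided by vae_scale_factor first
    b, _, ch = latents.shape
    h, w = h // vae_scale_factor, w // vae_scale_factor
    latents = latents.view(b, h, w, ch // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, ch // 4, h * 2, w * 2)

x = torch.randn(2, 16, 128, 128)           # e.g. the latent grid behind a 1024x1024 image
tokens = pack_latents(x, 2, 16, 128, 128)  # -> (2, 4096, 64)
assert tokens.shape == (2, 4096, 64)
# pass latent-space height/width directly, so vae_scale_factor=1 in this standalone check
assert torch.equal(unpack_latents(tokens, 64, 64, 1), x)
```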
diff --git a/flux/pipeline_flux_chameleon.py b/flux/pipeline_flux_chameleon.py
new file mode 100644
index 0000000000000000000000000000000000000000..c01b57c4710d25cd6e214dd5c2fb3e51481813a2
--- /dev/null
+++ b/flux/pipeline_flux_chameleon.py
@@ -0,0 +1,758 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from diffusers.image_processor import VaeImageProcessor
+from .lora.lora_pipeline import FluxLoraLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from .transformer_flux import FluxTransformer2DModel
+from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxPipeline
+
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
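+
+ # Worked example (an illustrative note, not from the original patch): with the defaults
+ # above, m = (1.16 - 0.5) / (4096 - 256) and b = 0.5 - 256 * m, so a 1024x1024 generation
+ # (a 128x128 latent grid packed into 2x2 patches, i.e. image_seq_len = 64 * 64 = 4096)
+ # gets mu = max_shift = 1.16, while a 256-token sequence gets mu = base_shift = 0.5.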
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+ r"""
+ The Flux pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: FluxTransformer2DModel,
+ text_encoder_2: T5EncoderModel | None = None,
+ tokenizer_2: T5TokenizerFast | None = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ #text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ #tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ #if self.text_encoder_2 is not None:
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # # Retrieve the original scale by scaling back the LoRA layers
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ if t5_prompt_embeds is not None:
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ else:
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
+ latent_image_ids = latent_image_ids.reshape(
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
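+ # Illustrative note, not from the original patch: for a 4x4 latent grid (height = width = 4)
+ # each batch element gets ids [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]] -- channel 0 stays
+ # zero while channels 1 and 2 carry the patch row/column indices that are later passed to the
+ # transformer as `img_ids`.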
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ #@torch.inference_mode()
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ t5_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated additional T5 text embeddings. When provided, they are forwarded to the transformer as
+ `t5_encoder_hidden_states` alongside `prompt_embeds`.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ t5_prompt_embeds=t5_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.tensor([guidance_scale], device=device)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
\ No newline at end of file
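Since this chameleon variant registers neither `text_encoder_2` nor `tokenizer_2`, it appears intended to be driven with externally produced conditioning through `prompt_embeds`, `t5_prompt_embeds`, and `pooled_prompt_embeds` (note that `check_inputs` requires `pooled_prompt_embeds` whenever `prompt_embeds` is passed). A hedged sketch of such a call follows; it assumes `pipe` is an already-constructed instance of the `FluxPipeline` defined in this file and that the three tensors were precomputed by an external encoder, so none of the shapes or values below are prescribed by the patch.

```py
# Assumptions: `pipe` is a loaded instance of flux.pipeline_flux_chameleon.FluxPipeline and
# `prompt_embeds`, `t5_prompt_embeds`, `pooled_prompt_embeds` come from an external encoder.
out = pipe(
    prompt_embeds=prompt_embeds,                # sequence embeddings -> encoder_hidden_states
    t5_prompt_embeds=t5_prompt_embeds,          # optional extra context -> t5_encoder_hidden_states
    pooled_prompt_embeds=pooled_prompt_embeds,  # required whenever prompt_embeds is passed
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
)
out.images[0].save("flux_chameleon.png")
```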
diff --git a/flux/pipeline_flux_controlnet.py b/flux/pipeline_flux_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8734ad61beb9bc65d21081d2c5f7f58813667b0
--- /dev/null
+++ b/flux/pipeline_flux_controlnet.py
@@ -0,0 +1,945 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from .lora.lora_pipeline import FluxLoraLoaderMixin
+from diffusers.loaders import FromSingleFileMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from .transformer_flux import FluxTransformer2DModel
+from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+ >>> from diffusers import FluxControlNetPipeline
+ >>> from diffusers import FluxControlNetModel
+
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
+ >>> base_model = "black-forest-labs/FLUX.1-dev"
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
+ >>> image = pipe(
+ ... prompt,
+ ... control_image=control_image,
+ ... controlnet_conditioning_scale=0.6,
+ ... num_inference_steps=28,
+ ... guidance_scale=3.5,
+ ... ).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ The Flux pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: FluxTransformer2DModel,
+ controlnet: Union[
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
+ ],
+ text_encoder_2: T5EncoderModel | None = None,
+ tokenizer_2: T5TokenizerFast | None = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ #text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ #tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ t5_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated T5 text embeddings. When provided, the returned `text_ids` are extended so their
+ length covers both `prompt_embeds` and `t5_prompt_embeds`.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ #if self.text_encoder_2 is not None:
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # # Retrieve the original scale by scaling back the LoRA layers
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
+
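+ # Text tokens get all-zero 3D position ids; when extra `t5_prompt_embeds` are supplied, the id
+ # sequence is lengthened so it also covers those tokens.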
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ if t5_prompt_embeds is not None:
+ text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ else:
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
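+ # Build one id triplet per 2x2 latent patch: channel 1 holds the row index, channel 2 the column
+ # index, channel 0 stays zero.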
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
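+ # Rearrange (B, C, H, W) latents into a sequence of 2x2 patches with shape (B, H/2 * W/2, C * 4).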
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
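+ # Inverse of _pack_latents: reshapes the packed (B, num_patches, C) sequence back into a
+ # (B, C // 4, H, W) spatial latent.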
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ control_image: PipelineImageInput = None,
+ control_mode: Optional[Union[int, List[int]]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds_control: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
+ The control mode when applying ControlNet-Union.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ t5_prompt_embeds=t5_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 3. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, FluxControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image.shape[-2:]
+
+ # vae encode
+ control_image = self.vae.encode(control_image).latent_dist.sample()
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ # Here we ensure that `control_mode` has the same length as the control_image.
+ if control_mode is not None:
+ if not isinstance(control_mode, int):
+ raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
+
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image_.shape[-2:]
+
+ # vae encode
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image_.shape[2:]
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+
+ # Here we ensure that `control_mode` has the same length as the control_image.
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
+ raise ValueError(
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
+ + " length as the number of controlnets (control images) specified"
+ )
+ if not isinstance(control_mode, list):
+ control_mode = [control_mode] * len(control_image)
+ # set control mode
+ control_modes = []
+ for cmode in control_mode:
+ if cmode is None:
+ cmode = -1
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
+ control_modes.append(control_mode)
+ control_mode = control_modes
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
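+ # `mu` shifts the flow-matching sigma schedule according to how many packed latent tokens are
+ # being denoised (see `calculate_shift`).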
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
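+ # The guidance scale is passed as a conditioning embedding only to models configured with
+ # `guidance_embeds`; otherwise `guidance` stays None.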
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
+ else:
+ use_guidance = self.controlnet.config.guidance_embeds
+
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
+
+ # controlnet
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image,
+ controlnet_mode=control_mode,
+ conditioning_scale=controlnet_conditioning_scale,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds_control,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )
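+ # The controlnet block residuals computed above are injected into the matching transformer
+ # blocks below; `timestep / 1000` rescales the scheduler timestep to the [0, 1] range used by
+ # the transformer's time embedding.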
+
+ guidance = (
+ torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
+ )
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
\ No newline at end of file
diff --git a/flux/pipeline_flux_controlnet_img2img.py b/flux/pipeline_flux_controlnet_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c728ef2cc48b9e237eaa669bf8b1dcc1eea07e
--- /dev/null
+++ b/flux/pipeline_flux_controlnet_img2img.py
@@ -0,0 +1,1002 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from .lora.lora_pipeline import FluxLoraLoaderMixin
+from diffusers.loaders import FromSingleFileMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from .transformer_flux import FluxTransformer2DModel
+from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+ >>> from diffusers import FluxControlNetPipeline
+ >>> from diffusers import FluxControlNetModel
+
+ >>> base_model = "black-forest-labs/FLUX.1-dev"
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
+ >>> image = pipe(
+ ... prompt,
+ ... control_image=control_image,
+ ... controlnet_conditioning_scale=0.6,
+ ... num_inference_steps=28,
+ ... guidance_scale=3.5,
+ ... ).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ The Flux ControlNet pipeline for image-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: FluxTransformer2DModel,
+ controlnet: Union[
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
+ ],
+ text_encoder_2: T5EncoderModel | None = None,
+ tokenizer_2: T5TokenizerFast | None = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ #text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ #tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ t5_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated T5 text embeddings. When provided, the returned `text_ids` are extended so their
+ length covers both `prompt_embeds` and `t5_prompt_embeds`.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ #if self.text_encoder_2 is not None:
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # # Retrieve the original scale by scaling back the LoRA layers
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ if t5_prompt_embeds is not None:
+ text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ else:
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
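+ # Normalize with the VAE's shift/scaling factors so the encoded latents are on the scale the
+ # transformer operates on.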
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
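+ # e.g. strength=0.6 with num_inference_steps=28 gives init_timestep=16.8 and t_start=11, so only
+ # the last 17 timesteps of the schedule are run.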
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ # Adapted from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
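+ # NOTE: the encoded image is only noised into the initial latents when the passed `timestep`
+ # equals 28; for any other value the denoising starts from pure Gaussian noise.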
+ if timestep == 28:
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ latents = noise
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ control_image: PipelineImageInput = None,
+ control_mode: Optional[Union[int, List[int]]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds_control: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
+ image (`PipelineImageInput`, *optional*):
+ The image or images to use as the starting point for image-to-image generation.
+ strength (`float`, *optional*, defaults to 0.6):
+ How much to transform the reference `image`; it determines how many of the scheduled timesteps
+ are actually run (see `get_timesteps`).
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
+ The control mode when applying ControlNet-Union.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
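+        # 4. Encode input prompt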
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ t5_prompt_embeds=t5_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+        # 5. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, FluxControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image.shape[-2:]
+
+ # vae encode
+ control_image = self.vae.encode(control_image).latent_dist.sample()
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ # Here we ensure that `control_mode` has the same length as the control_image.
+ if control_mode is not None:
+ if not isinstance(control_mode, int):
+                    raise ValueError("For `FluxControlNet`, `control_mode` should be an `int` or `None`")
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
+
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image_.shape[-2:]
+
+ # vae encode
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image_.shape[2:]
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+
+ # Here we ensure that `control_mode` has the same length as the control_image.
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
+ raise ValueError(
+                    "For Multi-ControlNet, `control_mode` must be a list of the same "
+                    "length as the number of controlnets (control images) specified"
+ )
+ if not isinstance(control_mode, list):
+ control_mode = [control_mode] * len(control_image)
+ # set control mode
+ control_modes = []
+ for cmode in control_mode:
+ if cmode is None:
+ cmode = -1
+                cmode_tensor = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
+                control_modes.append(cmode_tensor)
+ control_mode = control_modes
+
+        # 6. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 7. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+        # 8. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
+ else:
+ use_guidance = self.controlnet.config.guidance_embeds
+
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
+
+ # controlnet
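+                # the ControlNet returns per-block residuals that are injected into the transformer call below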
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image,
+ controlnet_mode=control_mode,
+ conditioning_scale=controlnet_conditioning_scale,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds_control,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )
+
+ guidance = (
+ torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
+ )
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+                        # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
\ No newline at end of file
diff --git a/flux/pipeline_flux_controlnet_inpainting.py b/flux/pipeline_flux_controlnet_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eba25b22754aab84d38abf46533949a3055ee89
--- /dev/null
+++ b/flux/pipeline_flux_controlnet_inpainting.py
@@ -0,0 +1,1199 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from .transformer_flux import FluxTransformer2DModel
+from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxControlNetInpaintPipeline
+ >>> from diffusers.models import FluxControlNetModel
+ >>> from diffusers.utils import load_image
+
+ >>> controlnet = FluxControlNetModel.from_pretrained(
+ ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16
+ ... )
+ >>> pipe = FluxControlNetInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> control_image = load_image(
+ ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
+ ... )
+ >>> init_image = load_image(
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ ... )
+ >>> mask_image = load_image(
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ ... )
+
+ >>> prompt = "A girl holding a sign that says InstantX"
+ >>> image = pipe(
+ ... prompt,
+ ... image=init_image,
+ ... mask_image=mask_image,
+ ... control_image=control_image,
+ ... control_guidance_start=0.2,
+ ... control_guidance_end=0.8,
+ ... controlnet_conditioning_scale=0.7,
+ ... strength=0.7,
+ ... num_inference_steps=28,
+ ... guidance_scale=3.5,
+ ... ).images[0]
+ >>> image.save("flux_controlnet_inpaint.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
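+    # linearly interpolate the flow-matching shift parameter mu between base_shift and max_shift
+    # as the image token sequence length grows from base_seq_len to max_seq_len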
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ The Flux controlnet pipeline for inpainting.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: FluxTransformer2DModel,
+ controlnet: Union[
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
+ ],
+        text_encoder_2: Optional[T5EncoderModel] = None,
+        tokenizer_2: Optional[T5TokenizerFast] = None,
+ ):
+ super().__init__()
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = FluxMultiControlNetModel(controlnet)
+
+ self.register_modules(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ controlnet=controlnet,
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor,
+ vae_latent_channels=self.vae.config.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
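+        # Flux uses all-zero positional ids for the text tokens; their count must cover prompt_embeds plus the optional t5_prompt_embeds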
+ if t5_prompt_embeds is not None:
+ text_ids = torch.zeros(prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ else:
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
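+        # strength=1.0 keeps the full schedule; smaller values skip the earliest (noisiest) steps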
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
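+        # build a 3-component positional id (0, row, col) for every 2x2 latent patch, consumed by the transformer's rotary position embeddings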
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
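+        # pack the latent grid into 2x2 patches: (B, C, H, W) -> (B, (H/2)*(W/2), C*4)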
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
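+        # inverse of _pack_latents: unfold the token sequence back into a (B, C, H, W) latent grid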
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
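+        # when no latents are passed, start from the encoded image noised to the first retained timestep (flow-matching forward process)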
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ return latents, noise, image_latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ padding_mask_crop: Optional[int] = None,
+ timesteps: List[int] = None,
+ num_inference_steps: int = 28,
+ guidance_scale: float = 7.0,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_mode: Optional[Union[int, List[int]]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds_control: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The image(s) to inpaint.
+ mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels
+ will be preserved.
+ masked_image_latents (`torch.FloatTensor`, *optional*):
+ Pre-generated masked image latents.
+ control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The ControlNet input condition. Image to control the generation.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 0.6):
+                Indicates how much to transform the masked area of the reference `image`. Must be between 0 and 1,
+                where 1 adds the maximum amount of noise to the masked region.
+ padding_mask_crop (`int`, *optional*):
+ The size of the padding to use when cropping the mask.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]`, *optional*):
+ The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original transformer.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
+ make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ Additional keyword arguments to be passed to the joint attention mechanism.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+                Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ global_height = height
+ global_width = width
+
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ t5_prompt_embeds=t5_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(
+ mask_image, global_width, global_height, pad=padding_mask_crop
+ )
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 5. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, FluxControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+                width=width,
+                height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image.shape[-2:]
+
+ # vae encode
+ control_image = self.vae.encode(control_image).latent_dist.sample()
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ # set control mode
+ if control_mode is not None:
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
+ control_mode = control_mode.reshape([-1, 1])
+
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image_.shape[-2:]
+
+ # vae encode
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image_.shape[2:]
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+
+            # Here we ensure that `control_mode` has the same length as the control_image.
+            if isinstance(control_mode, list) and len(control_mode) != len(control_image):
+                raise ValueError(
+                    "For Multi-ControlNet, `control_mode` must be a list of the same "
+                    "length as the number of controlnets (control images) specified"
+                )
+            if not isinstance(control_mode, list):
+                control_mode = [control_mode] * len(control_image)
+            # set control mode
+ control_modes = []
+ for cmode in control_mode:
+ if cmode is None:
+ cmode = -1
+                cmode_tensor = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
+                control_modes.append(cmode_tensor)
+ control_mode = control_modes
+
+ # 6. Prepare timesteps
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 7. Prepare latent variables
+
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ global_height,
+ global_width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Prepare mask latents
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ global_height,
+ global_width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
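+        # per-step ControlNet conditioning schedule: 1.0 while the step falls inside the [control_guidance_start, control_guidance_end] window, 0.0 otherwise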
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
+
+ # 9. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                # predict the noise residual
+                # guidance for the ControlNet (passed regardless of the controlnet's guidance_embeds config)
+                guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+                guidance = guidance.expand(latents.shape[0])
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image,
+ controlnet_mode=control_mode,
+ conditioning_scale=cond_scale,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds_control,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )
+
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                # inpainting: keep the unmasked region from the (re-noised) original image latents; only the masked region is updated by the model
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+                        # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Post-processing
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/flux/pipeline_flux_img2img.py b/flux/pipeline_flux_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e2f5629c0db8a2ebfd988f69aef6f74c88cdd6b
--- /dev/null
+++ b/flux/pipeline_flux_img2img.py
@@ -0,0 +1,856 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from .lora.lora_pipeline import FluxLoraLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from .transformer_flux import FluxTransformer2DModel
+from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+
+ >>> from diffusers import FluxImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> device = "cuda"
+ >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to(device)
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> init_image = load_image(url).resize((1024, 1024))
+
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
+
+ >>> images = pipe(
+ ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
+ ... ).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
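+    # Works with VAE encoder outputs (sample from, or take the mode of, the latent
+    # distribution) as well as objects that already expose `latents`.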
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+ r"""
+    The Flux pipeline for image-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: FluxTransformer2DModel,
+        text_encoder_2: Optional[T5EncoderModel] = None,
+        tokenizer_2: Optional[T5TokenizerFast] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ #text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ #tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+        )
+        # `text_encoder_2` / `tokenizer_2` are deliberately not registered as pipeline modules
+        # (the T5 branch is typically replaced by precomputed `t5_prompt_embeds`), but keep them
+        # as plain attributes so `encode_prompt` can check for them without an AttributeError.
+        self.text_encoder_2 = text_encoder_2
+        self.tokenizer_2 = tokenizer_2
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ #if self.text_encoder_2 is not None:
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # # Retrieve the original scale by scaling back the LoRA layers
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
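+        # Flux uses all-zero 3-component position ids for text tokens; their length must match
+        # the full text sequence, i.e. the prompt embeddings plus any extra `t5_prompt_embeds`.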
+ if t5_prompt_embeds is not None:
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ else:
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ #text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
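+        # Encode with a per-sample generator when a list is given, then map the latents into the
+        # scaled space expected by the transformer (subtract shift_factor, multiply by scaling_factor).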
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
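+        # img2img-style schedule truncation: keep only the last `strength` fraction of the
+        # timesteps so denoising starts from a partially noised version of the input image.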
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
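+        # Per-patch position ids: channel 1 holds the row index and channel 2 the column index of
+        # each 2x2 latent patch (channel 0 stays zero), flattened to (num_patches, 3) for the
+        # transformer's rotary position embeddings.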
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
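+        # Rearrange (B, C, H, W) latents into a sequence of 2x2 patches: (B, (H/2)*(W/2), C*4).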
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
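+        # Inverse of `_pack_latents`: turn the packed patch sequence back into (B, C, H, W)
+        # latents for VAE decoding; `height`/`width` are the target image size in pixels.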
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
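+        # Noise the clean image latents up to the starting timestep (flow-matching interpolation
+        # between image latents and Gaussian noise), then pack them into the patch sequence.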
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+                be used instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a
+                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+                latents as `image`, but latents passed directly are not encoded again.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 0.6):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ t5_prompt_embeds=t5_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4.Prepare timesteps
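+        # A linearly spaced sigma schedule is shifted by `mu` (which grows with the latent
+        # sequence length) and then truncated by `strength` so denoising starts mid-schedule.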
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+
+ latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
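+        # Only guidance-distilled checkpoints (e.g. FLUX.1-dev) expect a guidance embedding;
+        # for others `guidance` stays None.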
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
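+                # The transformer expects timesteps normalized to [0, 1], hence `timestep / 1000`.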
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
\ No newline at end of file
diff --git a/flux/pipeline_flux_inpaint.py b/flux/pipeline_flux_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..987bb4145af36b718e2cf6fc8e289209d0f1dca1
--- /dev/null
+++ b/flux/pipeline_flux_inpaint.py
@@ -0,0 +1,1021 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from .lora.lora_pipeline import FluxLoraLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from .transformer_flux import FluxTransformer2DModel
+from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0]
+ >>> image.save("flux_inpainting.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+ r"""
+ The Flux pipeline for image inpainting.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: FluxTransformer2DModel,
+        text_encoder_2: Optional[T5EncoderModel] = None,
+        tokenizer_2: Optional[T5TokenizerFast] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ #text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ #tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+        )
+        # As in the img2img pipeline, `text_encoder_2` / `tokenizer_2` are not registered as
+        # modules but kept as plain attributes so `encode_prompt` can check for them safely.
+        self.text_encoder_2 = text_encoder_2
+        self.tokenizer_2 = tokenizer_2
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor,
+ vae_latent_channels=self.vae.config.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ #if self.text_encoder_2 is not None:
+ # if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # # Retrieve the original scale by scaling back the LoRA layers
+ # unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ if t5_prompt_embeds is not None:
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1] + t5_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ else:
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ #text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when using `padding_mask_crop`, but is of type {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when using `padding_mask_crop`, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when using `padding_mask_crop`, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
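+        # Either noise the clean image latents to the start timestep or treat user-supplied
+        # `latents` as the starting noise; packed noise and image latents are returned as well so
+        # the unmasked region can be re-blended during denoising.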
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ return latents, noise, image_latents, latent_image_ids
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
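+ # The single-channel mask is repeated across all latent channels so it can be packed into the
+ # same token layout as the image latents.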
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ t5_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass
+ `prompt_embeds` instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a
+ list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+ a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but latents passed directly are not encoded again.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+ `(B, H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`,
+ `(H, W, 1)`, or `(H, W)`.
+ masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing the VAE-encoded latents of the masked image batch. If not provided, the
+ masked-image latents tensor will be generated from `mask_image`.
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contains information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ t5_prompt_embeds=t5_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
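+ # `strength` truncates the schedule: only roughly the last `strength * num_inference_steps`
+ # timesteps are kept, so lower strength starts denoising closer to the original image.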
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ num_channels_transformer = self.transformer.config.in_channels
+
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ t5_encoder_hidden_states=t5_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # for 64 channel transformer only.
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
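+ # Keep the (re-noised) original latents wherever the mask is 0 and the freshly denoised
+ # latents where the mask is 1, so only the masked region is repainted.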
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
\ No newline at end of file
diff --git a/flux/pipeline_output.py b/flux/pipeline_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0754ded662d0c2101e76a4420459740cd5a72f8
--- /dev/null
+++ b/flux/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class FluxPipelineOutput(BaseOutput):
+ """
+ Output class for Flux pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
\ No newline at end of file
diff --git a/flux/scheduling_flow_match_euler_discrete.py b/flux/scheduling_flow_match_euler_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b8c98dafcf1ca8408bcc1822d46250e8364471
--- /dev/null
+++ b/flux/scheduling_flow_match_euler_discrete.py
@@ -0,0 +1,325 @@
+# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler scheduler.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ shift (`float`, defaults to 1.0):
+ The shift value for the timestep schedule.
+ """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ use_dynamic_shifting=False,
+ base_shift: Optional[float] = 0.5,
+ max_shift: Optional[float] = 1.15,
+ base_image_seq_len: Optional[int] = 256,
+ max_image_seq_len: Optional[int] = 4096,
+ ):
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+ sigmas = timesteps / num_train_timesteps
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+ self.timesteps = sigmas * num_train_timesteps
+
+ self._step_index = None
+ self._begin_index = None
+
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def scale_noise(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ Forward process in flow-matching
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
+
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
+ timestep = timestep.to(sample.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(sample.device)
+ timestep = timestep.to(sample.device)
+
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timestep.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timestep.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(sample.shape):
+ sigma = sigma.unsqueeze(-1)
+
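+ # Flow-matching forward process: linearly interpolate between the clean sample and noise,
+ # x_t = sigma * noise + (1 - sigma) * x_0.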
+ sample = sigma * noise + (1.0 - sigma) * sample
+
+ return sample
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
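+ # Shift the sigma schedule in logit space by `mu` (with `sigma` as an exponent); a larger
+ # `mu`, used for larger images, keeps more noise at every step of the schedule.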
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[List[float]] = None,
+ mu: Optional[float] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
+
+ if sigmas is None:
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.linspace(
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
+ )
+
+ sigmas = timesteps / self.config.num_train_timesteps
+
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift(mu, 1.0, sigmas)
+ else:
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+ timesteps = sigmas * self.config.num_train_timesteps
+
+ self.timesteps = timesteps.to(device=device)
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+ self._step_index = None
+ self._begin_index = None
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ s_churn: float = 0.0,
+ s_tmin: float = 0.0,
+ s_tmax: float = float("inf"),
+ s_noise: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ s_churn (`float`):
+ s_tmin (`float`):
+ s_tmax (`float`):
+ s_noise (`float`, defaults to 1.0):
+ Scaling factor for noise added to the sample.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
+ tuple.
+
+ Returns:
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ sigma = self.sigmas[self.step_index]
+ sigma_next = self.sigmas[self.step_index + 1]
+
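+ # Explicit Euler step along the flow: the model output is treated as the velocity dx/dsigma,
+ # so x_{sigma_next} = x_sigma + (sigma_next - sigma) * v.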
+ prev_sample = sample + (sigma_next - sigma) * model_output
+
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+ def step_to_x0(self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Compute the predicted x_0 given the model output and current sample at timestep t.
+ """
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ sigma = self.sigmas[self.step_index]
+ sigma_from = sigma
+ sigma_to = self.sigmas[-1] # This corresponds to x_0
+
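+ # Because `set_timesteps` appends a trailing zero to `sigmas`, this is a single Euler jump
+ # to sigma = 0, i.e. x_0 = x_sigma - sigma * v.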
+ x0 = sample + (sigma_to - sigma_from) * model_output
+ return x0
\ No newline at end of file
diff --git a/flux/transformer_flux.py b/flux/transformer_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..823b67b3fc31b36ae391f3b1b4df1471e66e88eb
--- /dev/null
+++ b/flux/transformer_flux.py
@@ -0,0 +1,572 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from .lora.peft import PeftAdapterMixin
+from diffusers.models.attention import FeedForward
+from .attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+from diffusers.models.modeling_utils import ModelMixin
+from .normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+import numpy as np
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[np.ndarray, int],
+ theta: float = 10000.0,
+ use_real=False,
+ linear_factor=1.0,
+ ntk_factor=1.0,
+ repeat_interleave_real=True,
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+ data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ linear_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the context extrapolation. Defaults to 1.0.
+ ntk_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+ Otherwise, they are concatenated with themselves.
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+ the dtype of the frequency tensor.
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = torch.arange(pos)
+ if isinstance(pos, np.ndarray):
+ pos = torch.from_numpy(pos) # type: ignore # [S]
+
+ theta = theta * ntk_factor
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
+ freqs = freqs.to(pos.device)
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ if use_real and repeat_interleave_real:
+ # flux, hunyuan-dit, cogvideox
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ elif use_real:
+ # stable audio
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ # lumina
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+class FluxPosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.squeeze().float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+# YiYi to-do: refactor rope related functions/classes
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ assert dim % 2 == 0, "The dimension must be even."
+
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+
+ batch_size, seq_length = pos.shape
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ cos_out = torch.cos(out)
+ sin_out = torch.sin(out)
+
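+ # Build a 2x2 rotation matrix [[cos, -sin], [sin, cos]] per position and frequency pair, the
+ # standard RoPE formulation.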
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+ return out.float()
+
+
+# YiYi to-do: refactor rope related functions/classes
+class EmbedND(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+ return emb.unsqueeze(1)
+
+
+@maybe_allow_in_graph
+class FluxSingleTransformerBlock(nn.Module):
+ r"""
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ mlp_ratio (`float`, *optional*, defaults to 4.0): The ratio of the MLP hidden dimension to `dim`.
+ """
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm = AdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ processor = FluxSingleAttnProcessor2_0()
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm="rms_norm",
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ image_rotary_emb=None,
+ ):
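+ # Single-stream block: attention and the MLP both read the same normed input, run in
+ # parallel, and are fused by one output projection before the gated residual connection.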
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class FluxTransformerBlock(nn.Module):
+ r"""
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ qk_norm (`str`, *optional*, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
+ eps (`float`, *optional*, defaults to 1e-6): The epsilon used by the normalization layers.
+ """
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+ super().__init__()
+
+ self.norm1 = AdaLayerNormZero(dim)
+
+ self.norm1_context = AdaLayerNormZero(dim)
+
+ if hasattr(F, "scaled_dot_product_attention"):
+ processor = FluxAttnProcessor2_0()
+ else:
+ raise ValueError(
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+ )
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=processor,
+ qk_norm=qk_norm,
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ image_rotary_emb=None,
+ ):
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # Attention.
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return encoder_hidden_states, hidden_states
+
+
+class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ The Transformer model introduced in Flux.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Parameters:
+ patch_size (`int`): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ #self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+ text_time_guidance_cls = (
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+ )
+ self.time_text_embed = text_time_guidance_cls(
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+ )
+
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ FluxTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ FluxSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_single_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = True
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ t5_encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+ else:
+ guidance = None
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+ if t5_encoder_hidden_states is not None:
+ encoder_hidden_states = torch.cat([encoder_hidden_states, t5_encoder_hidden_states], dim=1)
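+ # When T5 embeddings are supplied they are appended to the projected prompt embeddings along
+ # the sequence dimension, so both conditioning streams are attended to jointly.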
+
+ #ids = torch.cat((txt_ids, img_ids), dim=1)
+ if txt_ids.ndim == 3:
+ #logger.warning(
+ # "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ #)
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ #logger.warning(
+ # "Passing `img_ids` 3d torch.Tensor is deprecated."
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ #)
+ img_ids = img_ids[0]
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
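+ # For the single-stream blocks, text and image tokens are concatenated into one sequence;
+ # the text prefix is sliced off again after these blocks (see below).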
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
\ No newline at end of file
diff --git a/qwen2_vl/__pycache__/configuration_qwen2_vl.cpython-310.pyc b/qwen2_vl/__pycache__/configuration_qwen2_vl.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7250f58f7dd33c20771e7a38996e59b086aca29
Binary files /dev/null and b/qwen2_vl/__pycache__/configuration_qwen2_vl.cpython-310.pyc differ
diff --git a/qwen2_vl/__pycache__/modeling_qwen2_vl.cpython-310.pyc b/qwen2_vl/__pycache__/modeling_qwen2_vl.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf4f643bcad9fbbc76dfd9208c770bc7785b0e5f
Binary files /dev/null and b/qwen2_vl/__pycache__/modeling_qwen2_vl.cpython-310.pyc differ
diff --git a/qwen2_vl/configuration_qwen2_vl.py b/qwen2_vl/configuration_qwen2_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d3d0bb86c1242722cc3dae64d41a02132eb27f6
--- /dev/null
+++ b/qwen2_vl/configuration_qwen2_vl.py
@@ -0,0 +1,206 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen2VL model configuration"""
+
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2VLVisionConfig(PretrainedConfig):
+ model_type = "qwen2_vl"
+
+ def __init__(
+ self,
+ depth=32,
+ embed_dim=1280,
+ hidden_size=3584,
+ hidden_act="quick_gelu",
+ mlp_ratio=4,
+ num_heads=16,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.embed_dim = embed_dim
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if config_dict.get("model_type") == "qwen2_vl":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class Qwen2VLConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
+ Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 152064):
+ Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Qwen2VLModel`]
+ hidden_size (`int`, *optional*, defaults to 8192):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 29568):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 80):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
+ Whether to use sliding window attention.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+ max_window_layers (`int`, *optional*, defaults to 80):
+ The number of layers (counting from the bottom) that use full attention; layers at or above this index use
+ SWA (Sliding Window Attention), matching the check in `Qwen2VLFlashAttention2`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ vision_config (`Dict`, *optional*):
+ The config for the visual encoder initialization.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions. Note that the multimodal rotary
+ embedding in this implementation reads a `"mrope_section"` entry from this dictionary (see
+ `apply_multimodal_rotary_pos_emb` in `modeling_qwen2_vl.py`).
+
+ ```python
+ >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
+
+ >>> # Initializing a Qwen2VL style configuration
+ >>> configuration = Qwen2VLConfig()
+
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
+ >>> model = Qwen2VLForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2_vl"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=152064,
+ hidden_size=8192,
+ intermediate_size=29568,
+ num_hidden_layers=80,
+ num_attention_heads=64,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=80,
+ attention_dropout=0.0,
+ vision_config=None,
+ rope_scaling=None,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = Qwen2VLVisionConfig(**vision_config)
+ elif vision_config is None:
+ self.vision_config = Qwen2VLVisionConfig()
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
diff --git a/qwen2_vl/image_processing_qwen2_vl.py b/qwen2_vl/image_processing_qwen2_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc73129a0aad8ad8e60cf1fed2990666c5f4d147
--- /dev/null
+++ b/qwen2_vl/image_processing_qwen2_vl.py
@@ -0,0 +1,458 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Qwen2-VL."""
+
+import math
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_transforms import (
+ convert_to_rgb,
+ resize,
+ to_channel_dimension_format,
+)
+from transformers.image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ VideoInput,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_valid_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from transformers.utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+def make_batched_images(images) -> List[ImageInput]:
+ """
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
+
+ Args:
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
+ The input image.
+
+ Returns:
+ list: A list of images.
+ """
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
+ return [img for img_list in images for img in img_list]
+
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
+ return images
+
+ elif is_valid_image(images):
+ return [images]
+
+ raise ValueError(f"Could not make batched images from {images}")
+
+
+# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
+def make_batched_videos(videos) -> List[VideoInput]:
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+ return videos
+
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
+ if isinstance(videos[0], Image.Image):
+ return [videos]
+ elif len(videos[0].shape) == 4:
+ return [list(video) for video in videos]
+
+ elif is_valid_image(videos) and len(videos.shape) == 4:
+ return [list(videos)]
+
+ raise ValueError(f"Could not make batched video from {videos}")
+
+
+def smart_resize(
+ height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
+):
+ """Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+
+ """
+ if height < factor or width < factor:
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
+ elif max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = math.floor(height / beta / factor) * factor
+ w_bar = math.floor(width / beta / factor) * factor
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class Qwen2VLImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Qwen2-VL image processor that dynamically resizes images based on the original image resolution.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
+ The minimum number of pixels an image is resized to.
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
+ The maximum number of pixels an image is resized to.
+ patch_size (`int`, *optional*, defaults to 14):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+ The spatial merge size used when mapping vision-encoder patches to LLM tokens.
+ """
+
+ model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int = 56 * 56,
+ max_pixels: int = 28 * 28 * 1280,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
+ self.do_convert_rgb = do_convert_rgb
+
+ def _preprocess(
+ self,
+ images: Union[ImageInput, VideoInput],
+ do_resize: bool = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ images = make_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if is_scaled_image(images[0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
+ resized_height, resized_width = height, width
+ processed_images = []
+ for image in images:
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=self.patch_size * self.merge_size,
+ min_pixels=self.min_pixels,
+ max_pixels=self.max_pixels,
+ )
+ image = resize(
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
+ )
+
+ if do_rescale:
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ if data_format == ChannelDimension.LAST:
+ patches = patches.transpose(0, 3, 1, 2)
+ if patches.shape[0] == 1:
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // self.temporal_patch_size
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+ patches = patches.reshape(
+ grid_t,
+ self.temporal_patch_size,
+ channel,
+ grid_h // self.merge_size,
+ self.merge_size,
+ self.patch_size,
+ grid_w // self.merge_size,
+ self.merge_size,
+ self.patch_size,
+ )
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
+ flatten_patches = patches.reshape(
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
+ )
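+ # Shape walkthrough with the default parameters (patch_size=14, merge_size=2, temporal_patch_size=2),
+ # added here as an illustrative note: a single 224x224 RGB image keeps its size after smart_resize
+ # (224 is divisible by 28 and 224 * 224 pixels is within bounds), is tiled to 2 temporal frames, and yields
+ # grid_t=1, grid_h=grid_w=16, so `flatten_patches` has shape (256, 3 * 2 * 14 * 14) = (256, 1176).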
+
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ videos: VideoInput = None,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ videos (`VideoInput`):
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ A dictionary with `"min_pixels"` and `"max_pixels"` keys bounding the total number of pixels of the
+ resized image; the aspect ratio is preserved as closely as possible (see `smart_resize`).
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ if images is not None:
+ images = make_batched_images(images)
+ if videos is not None:
+ videos = make_batched_videos(videos)
+
+ if images is not None and not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if images is not None:
+ pixel_values, vision_grid_thws = [], []
+ for image in images:
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
+
+ if videos is not None:
+ pixel_values, vision_grid_thws = [], []
+ for images in videos:
+ patches, video_grid_thw = self._preprocess(
+ images,
+ do_resize=do_resize,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(video_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
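+
+# Minimal usage sketch (illustrative only; the example file name is hypothetical). `BaseImageProcessor.__call__`
+# in transformers forwards to `preprocess`, so the processor can be called directly:
+#
+#   from PIL import Image
+#   processor = Qwen2VLImageProcessor()
+#   outputs = processor(images=Image.open("example.jpg"), return_tensors="pt")
+#   # outputs["pixel_values"]: (num_patches, 1176) flattened patches
+#   # outputs["image_grid_thw"]: (1, 3) counts of (temporal, height, width) patches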
diff --git a/qwen2_vl/modeling_qwen2_vl.py b/qwen2_vl/modeling_qwen2_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..99e130bda22cdf72c8a21c44128253070d84d4d3
--- /dev/null
+++ b/qwen2_vl/modeling_qwen2_vl.py
@@ -0,0 +1,1952 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2-VL model."""
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss, LayerNorm
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, StaticCache
+from transformers.modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ ModelOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
+
+import traceback
+
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_varlen_func
+
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+else:
+ flash_attn_varlen_func = None
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Qwen2VLConfig"
+
+
+@dataclass
+class Qwen2VLCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for Qwen2VL causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding
+class Qwen2RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension of 1D rotary position embedding. The input embedding
+ sequence contains vision (image/video) embeddings and text embeddings, or only text embeddings. For the
+ vision part, rotary position embedding is applied to the temporal, height and width dimensions separately:
+ the channel dimension is split into 3 chunks, one for each of these dimensions. For the text part, a plain 1D
+ rotary position embedding is applied. The three rotary position indices (temporal, height and width) of a text
+ token are always identical, so the text rotary position embedding is the same as in standard LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offset position ids when working with a KV-cache.
+ mrope_section (`List[int]`):
+ The multimodal rope section describes how the channel dimension is split between the temporal, height and
+ width rotary embeddings.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos[position_ids]
+ sin = sin[position_ids]
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+ orig_dtype = tensor.dtype
+ tensor = tensor.float()
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
+ output = output.to(orig_dtype)
+ return output
+
+
+class VisionRotaryEmbedding(nn.Module):
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class PatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class PatchMerger(nn.Module):
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ self.ln_q = LayerNorm(context_dim, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(self.hidden_size, self.hidden_size),
+ nn.GELU(),
+ nn.Linear(self.hidden_size, dim),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
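+ # Shape note (illustrative): groups of spatial_merge_size**2 consecutive patch embeddings of size `context_dim`
+ # are concatenated into vectors of size `self.hidden_size` and projected to `dim`, so with the default
+ # spatial_merge_size=2 an input of shape (seq_len, context_dim) becomes (seq_len // 4, dim).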
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+ return x
+
+
+class VisionMlp(nn.Module):
+ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
+ super().__init__()
+ self.fc1 = nn.Linear(dim, hidden_dim)
+ self.act = ACT2FN[hidden_act]
+ self.fc2 = nn.Linear(hidden_dim, dim)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.fc2(self.act(self.fc1(x)))
+
+
+class VisionAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ attention_mask = torch.full(
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+ )
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+
+ q = q.transpose(0, 1)
+ k = k.transpose(0, 1)
+ v = v.transpose(0, 1)
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+ attn_weights = attn_weights + attention_mask
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+ attn_output = torch.matmul(attn_weights, v)
+ attn_output = attn_output.transpose(0, 1)
+ attn_output = attn_output.reshape(seq_length, -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class VisionFlashAttention2(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+ seq_length, -1
+ )
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class VisionSdpaAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+ q = q.transpose(0, 1)
+ k = k.transpose(0, 1)
+ v = v.transpose(0, 1)
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+ attn_output = attn_output.transpose(0, 1)
+ attn_output = attn_output.reshape(seq_length, -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+QWEN2_VL_VISION_ATTENTION_CLASSES = {
+ "eager": VisionAttention,
+ "flash_attention_2": VisionFlashAttention2,
+ "sdpa": VisionSdpaAttention,
+}
+
+
+class Qwen2VLVisionBlock(nn.Module):
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+ super().__init__()
+ self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
+ self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
+
+ self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
+ config.embed_dim, num_heads=config.num_heads
+ )
+ self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
+
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
+def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ min_dtype: float,
+ cache_position: torch.Tensor,
+ batch_size: int,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ min_dtype (`float`):
+ The minimum value representable with the dtype `dtype`.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
+class Qwen2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
+class Qwen2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Qwen2VLAttention(nn.Module):
+ """
+ Multi-headed attention from the 'Attention Is All You Need' paper, modified to optionally use sliding window
+ attention as in Longformer and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.rope_scaling = config.rope_scaling
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.rotary_emb = Qwen2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Qwen2VLFlashAttention2(Qwen2VLAttention):
+ """
+ Qwen2VL flash attention module, following the Qwen2VL attention module. This module inherits from
+ `Qwen2VLAttention`, as the weights of the module stay untouched. The only required change is in the forward pass,
+ where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input
+ contains any of them. Additionally, for sliding window attention, SWA is applied only to layers indexed at or above
+ `config.max_window_layers` (see the `sliding_window` selection in `forward`).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = (
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
+ )
+
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_value is not None:
+ # Activate cache slicing only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states can get silently cast to float32. We therefore need to
+ # cast them back to the original dtype to make sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
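+ # Note: with this file's config defaults (max_window_layers == num_hidden_layers == 80), `sliding_window`
+ # stays None for every layer; only layers indexed at or above `max_window_layers` would use SWA.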
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Qwen2VLSdpaAttention(Qwen2VLAttention):
+ """
+ Qwen2VL attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2VLAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
+ to the SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+QWEN2_VL_ATTENTION_CLASSES = {
+ "eager": Qwen2VLAttention,
+ "flash_attention_2": Qwen2VLFlashAttention2,
+ "sdpa": Qwen2VLSdpaAttention,
+}
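+# The decoder layer below selects one of these classes via `config._attn_implementation`.
+# A minimal sketch (the checkpoint path is hypothetical; `attn_implementation` is the usual
+# `from_pretrained` keyword):
+#
+#     model = Qwen2VLSimplifiedModel.from_pretrained("path/to/qwen2-vl", attn_implementation="sdpa")
+#
+# would route every `Qwen2VLDecoderLayer.self_attn` through `Qwen2VLSdpaAttention`.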
+
+
+class Qwen2VLDecoderLayer(nn.Module):
+ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.mlp = Qwen2MLP(config)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+QWEN2VL_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Qwen2VLConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
+ QWEN2VL_START_DOCSTRING,
+)
+class Qwen2VLPreTrainedModel(PreTrainedModel):
+ config_class = Qwen2VLConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
+ config_class = Qwen2VLVisionConfig
+ _no_split_modules = ["Qwen2VLVisionBlock"]
+
+ def __init__(self, config) -> None:
+ super().__init__(config)
+ self.spatial_merge_size = config.spatial_merge_size
+
+ self.patch_embed = PatchEmbed(
+ patch_size=config.patch_size,
+ temporal_patch_size=config.temporal_patch_size,
+ in_channels=config.in_channels,
+ embed_dim=config.embed_dim,
+ )
+
+ head_dim = config.embed_dim // config.num_heads
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList(
+ [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
+ )
+ self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
+
+ def get_dtype(self) -> torch.dtype:
+ return self.blocks[0].mlp.fc2.weight.dtype
+
+ def get_device(self) -> torch.device:
+ return self.blocks[0].mlp.fc2.weight.device
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.patch_embed(hidden_states)
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
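+ # `cu_seqlens` holds cumulative patch counts per frame: for each image/video the h*w patch count
+ # is repeated t times, so the vision blocks can limit attention to patches of the same frame.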
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0, dtype=torch.int32
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ for blk in self.blocks:
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+ return self.merger(hidden_states)
+
+
+@add_start_docstrings(
+ "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.",
+ QWEN2VL_START_DOCSTRING,
+)
+class Qwen2VLModel(Qwen2VLPreTrainedModel):
+ def __init__(self, config: Qwen2VLConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ # the hard coded `3` is for temporal, height and width.
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
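+ # The decoder layer returns (hidden_states, [attn_weights], [present_key_value]), so the cache
+ # sits at index 2 when attentions are also returned and at index 1 otherwise.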
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_length()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+
+QWEN2_VL_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)`):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+ [`Qwen2VLImageProcessor`] for processing images.
+ pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
+ The tensors corresponding to the input videos. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+ [`Qwen2VLImageProcessor`] for processing videos.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+"""
+
+
+class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
+ config.vision_config, attn_implementation=config._attn_implementation
+ )
+ self.model = Qwen2VLModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def get_rope_index(
+ self,
+ input_ids: torch.LongTensor,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For a pure text embedding sequence, the rotary position embedding is no different from modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For a mixed vision and text embedding sequence, we calculate 3D rotary position embeddings for the vision part
+ and 1D rotary position embeddings for the text part.
+ Examples:
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [3, 4, 5, 6, 7]
+ text height position_ids: [3, 4, 5, 6, 7]
+ text width position_ids: [3, 4, 5, 6, 7]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if image_grid_thw is not None or video_grid_thw is not None:
+ total_input_ids = input_ids
+ position_ids = torch.ones(
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
+ )
+ image_index, video_index = 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ if attention_mask is not None:
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ num_new_tokens: int = 1,
+ ) -> Dict[str, Any]:
+ model_kwargs = super()._update_model_kwargs_for_generation(
+ outputs=outputs,
+ model_kwargs=model_kwargs,
+ is_encoder_decoder=is_encoder_decoder,
+ num_new_tokens=num_new_tokens,
+ )
+
+ if getattr(outputs, "rope_deltas", None) is not None:
+ model_kwargs["rope_deltas"] = outputs.rope_deltas
+
+ return model_kwargs
+
+ @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+
+ >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.model.embed_tokens(input_ids)
+ if pixel_values is not None:
+ pixel_values = pixel_values.type(self.visual.get_dtype())
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+ image_mask = input_ids == self.config.image_token_id
+ if self.training:
+ inputs_embeds = inputs_embeds.clone()
+ inputs_embeds[image_mask] = image_embeds
+ if pixel_values_videos is not None:
+ pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
+ video_mask = input_ids == self.config.video_token_id
+ inputs_embeds[video_mask] = video_embeds
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(inputs_embeds.device)
+
+ outputs = self.model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Qwen2VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+ if past_key_values is not None:
+ if inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+
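+ # The 3D (m)RoPE position ids are computed once at prefill via `get_rope_index`; on later decode
+ # steps a plain 1D arange of positions is shifted by the cached per-sample `rope_deltas`
+ # (propagated through `_update_model_kwargs_for_generation`) instead of being recomputed.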
+ rope_deltas = kwargs.get("rope_deltas", None)
+ if attention_mask is not None and position_ids is None:
+ if cache_position is None or (cache_position is not None and cache_position[0] == 0):
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids, image_grid_thw, video_grid_thw, attention_mask
+ )
+ else:
+ batch_size, seq_length = input_ids.shape
+ delta = (
+ cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
+ )
+ position_ids = torch.arange(seq_length, device=input_ids.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ if cache_position[0] != 0:
+ pixel_values = None
+ pixel_values_videos = None
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and cache_position[0] == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+ if inputs_embeds is not None:
+ batch_size, sequence_length, _ = inputs_embeds.shape
+ device = inputs_embeds.device
+ else:
+ batch_size, sequence_length = input_ids.shape
+ device = input_ids.device
+
+ dtype = self.lm_head.weight.dtype
+ min_dtype = torch.finfo(dtype).min
+
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=past_key_values.get_max_length(),
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=batch_size,
+ )
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "pixel_values_videos": pixel_values_videos,
+ "image_grid_thw": image_grid_thw,
+ "video_grid_thw": video_grid_thw,
+ "rope_deltas": rope_deltas,
+ }
+ )
+ return model_inputs
+
+
+class Qwen2VLSimplifiedModel(Qwen2VLPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
+ config.vision_config, attn_implementation=config._attn_implementation
+ )
+ self.model = Qwen2VLModel(config)
+ self.hidden_size = config.hidden_size
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1):
+ # The generation-related kwargs update logic is intentionally removed; return model_kwargs unchanged
+ return model_kwargs
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.model.embed_tokens(input_ids)
+ if pixel_values is not None:
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+ image_mask = input_ids == self.config.image_token_id
+ inputs_embeds[image_mask] = image_embeds
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(inputs_embeds.device)
+
+ if position_ids is None:
+ position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+
+ outputs = self.model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+
+ return hidden_states, image_mask, image_grid_thw
+
+ def get_rope_index(
+ self,
+ input_ids: torch.LongTensor,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For a pure text embedding sequence, the rotary position embedding is no different from modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For a mixed vision and text embedding sequence, we calculate 3D rotary position embeddings for the vision part
+ and 1D rotary position embeddings for the text part.
+ Examples:
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [3, 4, 5, 6, 7]
+ text height position_ids: [3, 4, 5, 6, 7]
+ text width position_ids: [3, 4, 5, 6, 7]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if image_grid_thw is not None or video_grid_thw is not None:
+ total_input_ids = input_ids
+ position_ids = torch.ones(
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
+ )
+ image_index, video_index = 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ if attention_mask is not None:
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
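+
+
+# A minimal usage sketch for `Qwen2VLSimplifiedModel` (illustrative only: the checkpoint path and
+# the `processor` object are assumptions, not part of this module):
+#
+#     model = Qwen2VLSimplifiedModel.from_pretrained("path/to/qwen2-vl", torch_dtype=torch.bfloat16)
+#     batch = processor(text=[prompt], images=[image], return_tensors="pt")
+#     hidden_states, image_mask, image_grid_thw = model(
+#         input_ids=batch["input_ids"],
+#         attention_mask=batch["attention_mask"],
+#         pixel_values=batch["pixel_values"].to(model.visual.get_dtype()),
+#         image_grid_thw=batch["image_grid_thw"],
+#     )
+#
+# `hidden_states` are the final decoder states and `image_mask` marks the image-token positions in
+# the sequence, which downstream code can use to pool the vision features.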
diff --git a/qwen2_vl/processing_qwen2_vl.py b/qwen2_vl/processing_qwen2_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0e79be2bd5e3369ea32456f0a1bc2730badba6
--- /dev/null
+++ b/qwen2_vl/processing_qwen2_vl.py
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Qwen2-VL.
+"""
+
+from typing import List, Optional, Union
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput, VideoInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2VLProcessor(ProcessorMixin):
+ r"""
+ Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.
+
+ [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "Qwen2VLImageProcessor"
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ videos: VideoInput = None,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: int = None,
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the vision inputs, this method forwards the `images`/`videos` and `kwargs` arguments to
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `images`/`videos` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
+ """
+ if images is not None:
+ image_inputs = self.image_processor(images=images, videos=None, return_tensors=return_tensors)
+ image_grid_thw = image_inputs["image_grid_thw"]
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if videos is not None:
+ videos_inputs = self.image_processor(images=None, videos=videos, return_tensors=return_tensors)
+ video_grid_thw = videos_inputs["video_grid_thw"]
+ else:
+ videos_inputs = {}
+ video_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
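+ # Each `<|image_pad|>` placeholder is expanded to the number of visual tokens the vision tower
+ # emits for that image, i.e. prod(grid_thw) // merge_size**2. The temporary `<|placeholder|>`
+ # marker prevents the freshly inserted pad tokens from being matched again.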
+ if image_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while "<|image_pad|>" in text[i]:
+ text[i] = text[i].replace(
+ "<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
+
+ if video_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while "<|video_pad|>" in text[i]:
+ text[i] = text[i].replace(
+ "<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
+
+ text_inputs = self.tokenizer(
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
+ )
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
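+
+# Worked example of the placeholder expansion above (values are made up for illustration): with
+# image_grid_thw = [[1, 4, 4]] and merge_size = 2, a single "<|image_pad|>" in the prompt is
+# replaced by 1 * 4 * 4 // 2**2 = 4 image pad tokens before tokenization.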
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e5f4d859e5ceaad275c309c0ad54eede51885eaf
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+# Core dependencies
+torch>=2.4.1
+torchvision>=0.15.0
+transformers==4.45.0
+diffusers==0.30.0
+accelerate==0.33.0
+
+# Image processing
+Pillow>=10.0.0
+opencv-python-headless>=4.8.0
+numpy>=1.24.0
+
+# Utilities
+protobuf==4.23.4
+sentencepiece==0.2.0
+
+gradio==4.36.0
\ No newline at end of file