from typing import Optional

import torch

from .attention import HiDreamAttention

try:
    from flash_attn_interface import flash_attn_func
    USE_FLASH_ATTN3 = True
except ImportError:
    from flash_attn import flash_attn_func
    USE_FLASH_ATTN3 = False


# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Rotate each (even, odd) channel pair of the queries/keys by the
    # precomputed 2x2 rotation stored in freqs_cis, then restore shape/dtype.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    if USE_FLASH_ATTN3:
        # FlashAttention-3 returns (output, softmax_lse); keep only the output.
        hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
    else:
        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
    # Merge the (heads, head_dim) axes back into a single channel axis.
    hidden_states = hidden_states.flatten(-2)
    hidden_states = hidden_states.to(query.dtype)
    return hidden_states


class HiDreamAttnProcessor_flashattn:
    """Attention processor typically used for SD3-like self-attention projections."""

    def __call__(
        self,
        attn: HiDreamAttention,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        rope: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        dtype = image_tokens.dtype
        batch_size = image_tokens.shape[0]

        # Project the image tokens and RMS-normalize queries/keys; values stay unnormalized.
        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
        value_i = attn.to_v(image_tokens)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if image_tokens_masks is not None:
            # Zero out keys of padded image tokens so they attract no attention weight.
            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            # Dual-stream block: project the text tokens with their own weights,
            # then concatenate both streams along the sequence axis for joint attention.
            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
            value_t = attn.to_v_t(text_tokens)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        if query.shape[-1] == rope.shape[-3] * 2:
            # RoPE covers the full head dimension.
            query, key = apply_rope(query, key, rope)
        else:
            # RoPE covers only the first half of the head dimension; the second
            # half passes through unrotated.
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, rope)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = attention(query, key, value)

        if not attn.single:
            # Split the joint sequence back into image and text streams and apply
            # each stream's own output projection.
            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
            hidden_states_i = attn.to_out(hidden_states_i)
            hidden_states_t = attn.to_out_t(hidden_states_t)
            return hidden_states_i, hidden_states_t
        else:
            hidden_states = attn.to_out(hidden_states)
            return hidden_states
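

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): exercises
    # apply_rope on random CPU tensors. All shapes below are illustrative
    # assumptions; in the real pipeline `rope` is precomputed elsewhere from
    # the token positions, and running this file directly still requires
    # flash-attn installed and the package import context for `.attention`.
    B, L, H, D = 2, 16, 4, 64  # batch, sequence length, heads, head dim
    xq = torch.randn(B, L, H, D)
    xk = torch.randn(B, L, H, D)
    # Assumed freqs_cis layout: one 2x2 rotation matrix per position and
    # frequency pair, broadcast over the batch and head axes. Note that
    # rope.shape[-3] * 2 == D, i.e. the full-head-dimension branch above.
    freqs_cis = torch.randn(1, L, 1, D // 2, 2, 2)
    xq_out, xk_out = apply_rope(xq, xk, freqs_cis)
    assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
    print("apply_rope ok:", xq_out.shape, xk_out.shape)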