from typing import Optional, Union

import torch

from .attention import HiDreamAttention

try:
    # Prefer the FlashAttention-3 interface when it is installed.
    from flash_attn_interface import flash_attn_func
    USE_FLASH_ATTN3 = True
except ImportError:
    # Fall back to FlashAttention-2.
    from flash_attn import flash_attn_func
    USE_FLASH_ATTN3 = False


# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
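    """Apply rotary position embeddings (RoPE) to query and key tensors.

    ``freqs_cis`` is expected to broadcast against the queries/keys once they are
    reshaped into channel pairs, i.e. its trailing dimensions are
    ``(head_dim // 2, 2, 2)`` holding the cos/sin rotation coefficients per pair.
    """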
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
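    """Run flash attention and flatten the heads back into the hidden dimension.

    Expects ``query``/``key``/``value`` shaped ``[batch, seq_len, heads, head_dim]``
    and returns ``[batch, seq_len, heads * head_dim]`` in ``query``'s dtype.
    """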
    if USE_FLASH_ATTN3:
        # FlashAttention-3 returns (output, softmax_lse); keep only the output.
        hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
    else:
        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
    # Merge the head and head_dim axes: [B, S, H, D] -> [B, S, H * D].
    hidden_states = hidden_states.flatten(-2)
    hidden_states = hidden_states.to(query.dtype)
    return hidden_states
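
# Shape sketch (illustrative, not part of the original source; flash_attn_func
# expects fp16/bf16 CUDA tensors shaped [batch, seq_len, heads, head_dim]):
#   q = k = v = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
#   out = attention(q, k, v)  # -> [2, 16, 512]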


class HiDreamAttnProcessor_flashattn:
    """Attention processor typically used for SD3-like self-attention projections."""

    def __call__(
        self,
        attn: HiDreamAttention,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        rope: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
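        """Project, apply RoPE, and attend over image (and optionally text) tokens.

        Returns a single tensor in the single-stream case (``attn.single``), or an
        ``(image_hidden_states, text_hidden_states)`` tuple in the dual-stream case.
        """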
        dtype = image_tokens.dtype
        batch_size = image_tokens.shape[0]

        # Project the image tokens and RMS-normalize queries and keys.
        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
        value_i = attn.to_v(image_tokens)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if image_tokens_masks is not None:
            # Zero out the keys of masked (padding) image tokens.
            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            # Dual-stream block: project the text tokens with their own weights,
            # then attend over the concatenated image + text sequence.
            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
            value_t = attn.to_v_t(text_tokens)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        # Apply RoPE to the full head dimension when the rotary table covers it,
        # otherwise rotate only the first half of the channels (partial RoPE).
        if query.shape[-1] == rope.shape[-3] * 2:
            query, key = apply_rope(query, key, rope)
        else:
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, rope)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = attention(query, key, value)

        if not attn.single:
            # Split the joint sequence back into image and text streams and
            # project each through its own output layer.
            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
            hidden_states_i = attn.to_out(hidden_states_i)
            hidden_states_t = attn.to_out_t(hidden_states_t)
            return hidden_states_i, hidden_states_t
        else:
            hidden_states = attn.to_out(hidden_states)
            return hidden_states
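

# Usage sketch (hypothetical names; the real wiring lives in HiDreamAttention):
#   processor = HiDreamAttnProcessor_flashattn()
#   image_out, text_out = processor(attn, image_tokens, text_tokens=text_tokens, rope=rope)
# With attn.single == True only the image stream is projected and a single
# tensor is returned.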