# cora/model/modules/register_attention.py
import torch
import torch.nn as nn
from einops import rearrange, repeat
from typing import Any, Optional

from model.directional_attentions import DirectionalAttentionControl, AttentionBase
from utils.utils import find_smallest_key_with_suffix

def register_attention_editor_diffusers(model: Any, editor: AttentionBase):
    """Hook every Attention module in model.unet so its forward pass is routed
    through `editor`, which can inspect or modify the attention computation."""

    def ca_forward(self, place_in_unet):
        def forward(
            x: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            context: Optional[torch.Tensor] = None,
            mask: Optional[torch.Tensor] = None,
        ):
            # diffusers passes the text embeddings as `encoder_hidden_states` and
            # the padding mask as `attention_mask`; map them onto the local names.
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            h = self.heads
            is_cross = context is not None
            context = context if is_cross else x

            # Project to queries, keys and values, then fold the heads into the batch dim.
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            # Scaled dot-product attention scores.
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                # Broadcast the mask over heads and fill masked positions with the
                # most negative representable value before the softmax.
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)
            # Fetch the cached DIFT features from the UNet's latent store (if present)
            # so the editor can use them for correspondence-aware attention control.
            dift_features_dict = getattr(model.unet.latent_store, 'dift_features', {})
            dift_features_key = find_smallest_key_with_suffix(dift_features_dict, suffix='_1')
            dift_features = dift_features_dict.get(dift_features_key, None)

            attn = sim.softmax(dim=-1)

            # Delegate the attention computation to the editor, which may modify the
            # queries, keys, values or the attention map itself.
            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads,
                scale=self.scale,
                dift_features=dift_features,
            )
            # In diffusers, `to_out` is a ModuleList([Linear, Dropout]); apply only
            # the linear output projection here.
            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                to_out = self.to_out[0]
            return to_out(out)

        return forward
    def register_editor(net, count, place_in_unet):
        # Recursively walk the module tree and hook every Attention layer.
        for name, subnet in net.named_children():
            if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
                net.forward = ca_forward(net, place_in_unet)
                return count + 1
            elif hasattr(net, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count
    # Hook the down, mid and up blocks of the UNet and record how many
    # attention layers were patched.
    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count
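

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): this assumes a
# diffusers-style pipeline whose `unet` carries the `latent_store` attribute
# read above, and uses the AttentionBase editor imported from
# model.directional_attentions. The checkpoint name below is a placeholder.
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   editor = AttentionBase()
#   register_attention_editor_diffusers(pipe, editor)
#   # Every Attention layer in pipe.unet now routes its computation through
#   # `editor`, and editor.num_att_layers reports how many layers were hooked.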