import torch
import torch.nn as nn
from einops import rearrange, repeat

from typing import Any, Optional
from model.directional_attentions import DirectionalAttentionControl, AttentionBase
from utils.utils import find_smallest_key_with_suffix


def register_attention_editor_diffusers(model: Any, editor: AttentionBase):
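    """Hook `editor` into every Attention layer of `model.unet`.

    Each Attention module's forward pass is replaced with a wrapper that
    computes q/k/v and the attention map as usual, then delegates the output
    to `editor`. The number of hooked layers is stored on
    `editor.num_att_layers`.
    """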
    def ca_forward(self, place_in_unet):
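        """Build the replacement forward for a single Attention module (`self`)."""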
        def forward(
            x: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            context: Optional[torch.Tensor] = None,
            mask: Optional[torch.Tensor] = None
        ):
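            # Support both the newer diffusers argument names (encoder_hidden_states,
            # attention_mask) and the older ones (context, mask).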
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

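            # Cross-attention when a context/text embedding is given, otherwise self-attention on x.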
            h = self.heads
            is_cross = context is not None
            context = context if is_cross else x

            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)

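            # Split heads into the batch dimension and compute scaled dot-product scores.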
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

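            # Mask out invalid positions with a large negative value before the softmax.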
            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

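            # Fetch DIFT features cached on the UNet's latent store (the entry whose key ends in '_1'), if any.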
            dift_features_dict = getattr(model.unet.latent_store, 'dift_features', {})
            dift_features_key = find_smallest_key_with_suffix(dift_features_dict, suffix='_1')
            dift_features = dift_features_dict.get(dift_features_key, None)

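            # Delegate the attention output computation to the editor.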
            attn = sim.softmax(dim=-1)
            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads,
                scale=self.scale,
                dift_features=dift_features
            )

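            # diffusers wraps the output projection in a ModuleList([Linear, Dropout]); use the Linear part.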
            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                to_out = to_out[0]

            return to_out(out)
        return forward

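    # Recursively walk the module tree and hook every Attention layer, counting them.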
    def register_editor(net, count, place_in_unet):
        if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
            net.forward = ca_forward(net, place_in_unet)
            return count + 1
        for _, subnet in net.named_children():
            count = register_editor(subnet, count, place_in_unet)
        return count

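    # Hook the attention layers in the down, mid and up blocks of the UNet.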
    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count
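

# Usage sketch (assumes `model` is a Stable Diffusion pipeline whose UNet carries a
# `latent_store`, and that `editor` is an AttentionBase subclass such as
# DirectionalAttentionControl; constructor arguments are omitted):
#
#     editor = DirectionalAttentionControl(...)
#     register_attention_editor_diffusers(model, editor)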