import torch
import torch.nn as nn
from einops import rearrange, repeat
from typing import Any, Optional

from model.directional_attentions import DirectionalAttentionControl, AttentionBase
from utils.utils import find_smallest_key_with_suffix
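
# Replaces the forward pass of every Attention module in model.unet with a
# closure that routes queries, keys, values and attention maps through an
# external AttentionBase editor, which can observe or rewrite them.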
def register_attention_editor_diffusers(model: Any, editor: AttentionBase):
    def ca_forward(self, place_in_unet):
        # `self` is the Attention module being patched; the returned closure
        # keeps a reference to it and to its location in the UNet.
        def forward(
            x: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            context: Optional[torch.Tensor] = None,
            mask: Optional[torch.Tensor] = None,
        ):
            # diffusers passes encoder_hidden_states/attention_mask; map them
            # onto the context/mask arguments used below.
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            h = self.heads
            is_cross = context is not None  # cross-attention iff a text context is given
            context = context if is_cross else x

            # Project to queries/keys/values and fold the heads into the batch dim.
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            # Scaled dot-product similarity between queries and keys.
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                # Flatten the mask, broadcast it across heads, and push the
                # masked positions to the most negative representable value so
                # they vanish after the softmax.
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            # Fetch the cached DIFT features (smallest key ending in '_1') from
            # the UNet's latent store so the editor can use them downstream.
            dift_features_dict = getattr(model.unet.latent_store, 'dift_features', {})
            dift_features_key = find_smallest_key_with_suffix(dift_features_dict, suffix='_1')
            dift_features = dift_features_dict.get(dift_features_key, None)

            attn = sim.softmax(dim=-1)

            # Delegate the attention computation to the editor, which may record
            # or rewrite the attention maps before producing the output.
            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads,
                scale=self.scale,
                dift_features=dift_features,
            )

            # In diffusers, to_out is ModuleList([Linear, Dropout]); apply only
            # the linear projection here.
            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                to_out = self.to_out[0]

            return to_out(out)

        return forward

    def register_editor(net, count, place_in_unet):
        # Recursively walk the module tree; once the walk enters an Attention
        # module, patch its forward pass and count it.
        for name, subnet in net.named_children():
            if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
                net.forward = ca_forward(net, place_in_unet)
                return count + 1
            elif hasattr(net, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count

    # Patch each branch of the UNet and record how many attention layers the
    # editor now controls.
    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count
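

# --- Illustrative usage (a sketch, not part of the original module) ---------
# `NoopEditor` below is hypothetical, written only to show the call signature
# the patched forward expects; the real editors (e.g. DirectionalAttentionControl)
# live in model.directional_attentions, and their constructor arguments are not
# shown in this file.
class NoopEditor(AttentionBase):
    """Hypothetical pass-through editor: plain attention, no rewriting."""

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet,
                 num_heads, scale=None, dift_features=None):
        # Vanilla attention output, with the heads merged back into channels.
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        return rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)

# Assuming `model` is the project's pipeline wrapper exposing `.unet`:
#   editor = NoopEditor()
#   register_attention_editor_diffusers(model, editor)
#   editor.num_att_layers  # number of patched attention layers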