Andranik Sargsyan
add demo code
bfd34e9
raw
history blame contribute delete
No virus
1.93 kB
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from ... import share
from ..attentionpatch import painta
# Module-level switch read by `forward` below: the PAIntA attention path is
# only taken while this is truthy.
# NOTE(review): despite the name, nothing in this file touches autograd based
# on this flag - confirm the intended meaning with the callers that set it.
use_grad = True
def forward(self, x, context=None):
    """Transformer-block forward pass with PAIntA-rescaled self-attention.

    Runs self-attention, rescales its output with ``painta.painta_rescale``
    using a cross-attention similarity map built from the second half of the
    batch, then applies the usual cross-attention and feed-forward residuals.

    Args:
        x: Input hidden states. Split into two equal halves along dim 0.
            Presumably laid out as ``[unconditional, conditional]`` for
            classifier-free guidance - TODO confirm against the caller.
        context: Conditioning tensor, split the same way along dim 0. Only the
            second half feeds the similarity map; the full tensor is passed to
            ``self.attn2``. Required even though the signature defaults to
            ``None``.

    Returns:
        The hidden states after the self-attention, cross-attention and
        feed-forward residual updates. When the module-level ``use_grad`` flag
        is falsy, ``x`` is returned unchanged.

    Raises:
        ValueError: If ``context`` is ``None`` while ``use_grad`` is truthy.
    """
    # Todo: add batch inference support
    # NOTE(review): the extent of this guard was reconstructed; confirm that a
    # falsy `use_grad` is meant to skip the whole block rather than only the
    # PAIntA rescale.
    if not use_grad:
        return x
    if context is None:
        # Fail early with a clear message instead of an AttributeError on
        # `None.chunk` further down.
        raise ValueError("PAIntA forward requires `context`, got None")

    # `self.attn1` must return a 3-tuple here: the attention output plus the
    # value tensor and similarity map that the rescale step consumes.
    y, self_v, self_sim = self.attn1(self.norm1(x), None)  # Self Attn.

    # Split every tensor into its two batch halves; only the second half of
    # `x` and `context` is needed below.
    _, x_cond = x.chunk(2)
    _, context_cond = context.chunk(2)
    y_uncond, y_cond = y.chunk(2)
    self_sim_uncond, self_sim_cond = self_sim.chunk(2)
    self_v_uncond, self_v_cond = self_v.chunk(2)

    # Calculate CA similarities with conditional context.
    # Only queries and keys are needed for the similarity map, so the value
    # projection of `self.attn2` is not computed.
    cross_h = self.attn2.heads
    head_split = "b n (h d) -> (b h) n d"
    cross_q = rearrange(
        self.attn2.to_q(self.norm2(x_cond + y_cond)), head_split, h=cross_h)
    cross_k = rearrange(self.attn2.to_k(context_cond), head_split, h=cross_h)

    # Compute the logits in float32 with autocast disabled.
    with torch.autocast(enabled=False, device_type='cuda'):
        cross_q, cross_k = cross_q.float(), cross_k.float()
        cross_sim = einsum('b i d, b j d -> b i j', cross_q, cross_k) * self.attn2.scale

    # Drop the projections before the softmax allocates its result.
    del cross_q, cross_k
    cross_sim = cross_sim.softmax(dim=-1)  # Up to this point cross_sim is regular cross_sim in CA layer
    # Mean over the merged (batch * heads) axis; this is a pure per-head mean
    # only when the conditional half holds a single sample (see Todo above).
    cross_sim = cross_sim.mean(dim=0)

    # PAIntA rescale: both halves are rescaled with the same similarity map,
    # which was derived from the conditional half.
    y_cond = painta.painta_rescale(
        y_cond, self_v_cond, self_sim_cond, cross_sim,
        self.attn1.heads, self.attn1.to_out)  # Rescale cond
    y_uncond = painta.painta_rescale(
        y_uncond, self_v_uncond, self_sim_uncond, cross_sim,
        self.attn1.heads, self.attn1.to_out)  # Rescale uncond

    # Reassemble the halves in their original order.
    y = torch.cat([y_uncond, y_cond], dim=0)

    x = x + y
    x = x + self.attn2(self.norm2(x), context=context)  # Cross Attn.
    x = x + self.ff(self.norm3(x))
    return x