Commit 8ab4de9 by Mehdi Cherti (1 parent: 3dcdf92)

add basic cross attention + global attention block

Files changed:
- score_sde/models/layers.py                 +1  -1
- score_sde/models/layerspp.py               +28 -0
- score_sde/models/ncsnpp_generator_adagn.py +42 -4
- train_ddgan.py                             +39 -25

score_sde/models/layers.py
CHANGED
@@ -583,7 +583,7 @@ class Identity(nn.Module):
   def forward(self, x, *args, **kwargs):
     return x
 
-
+
 class CrossAttention(nn.Module):
   def __init__(
     self,
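The hunk above only touches a blank line next to the CrossAttention stub; the body of layers.CrossAttention is not visible in this commit. For orientation, a minimal sketch of a standard multi-head cross-attention layer with the constructor signature this commit relies on (channels, context_dim, dim_head, heads); this is an illustration of the technique, not the repository's implementation:

    import torch
    import torch.nn as nn

    class MinimalCrossAttention(nn.Module):
        """Queries from x (B, N, dim); keys/values from context (B, M, context_dim)."""
        def __init__(self, dim, *, context_dim=None, dim_head=64, heads=8):
            super().__init__()
            context_dim = context_dim if context_dim is not None else dim
            inner = dim_head * heads
            self.heads, self.scale = heads, dim_head ** -0.5
            self.to_q = nn.Linear(dim, inner, bias=False)
            self.to_kv = nn.Linear(context_dim, inner * 2, bias=False)
            self.to_out = nn.Linear(inner, dim, bias=False)

        def forward(self, x, context, mask=None):
            B, N, _ = x.shape
            q = self.to_q(x)
            k, v = self.to_kv(context).chunk(2, dim=-1)
            # split heads: (B, seq, heads*d) -> (B, heads, seq, d)
            q, k, v = (t.view(B, -1, self.heads, t.shape[-1] // self.heads).transpose(1, 2)
                       for t in (q, k, v))
            sim = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
            if mask is not None:  # mask: (B, M) bool, True where context tokens are valid
                sim = sim.masked_fill(~mask[:, None, None, :], torch.finfo(sim.dtype).min)
            out = torch.einsum('bhij,bhjd->bhid', sim.softmax(dim=-1), v)
            return self.to_out(out.transpose(1, 2).reshape(B, N, -1))
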
score_sde/models/layerspp.py
CHANGED
@@ -123,6 +123,34 @@ class AttnBlockpp(nn.Module):
     else:
       return (x + h) / np.sqrt(2.)
 
+class AttnBlockppRaw(nn.Module):
+  """Channel-wise self-attention block. Modified from DDPM."""
+
+  def __init__(self, channels, skip_rescale=False, init_scale=0.):
+    super().__init__()
+    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
+                                    eps=1e-6)
+    self.NIN_0 = NIN(channels, channels)
+    self.NIN_1 = NIN(channels, channels)
+    self.NIN_2 = NIN(channels, channels)
+    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+    self.skip_rescale = skip_rescale
+
+  def forward(self, x):
+    B, C, H, W = x.shape
+    h = self.GroupNorm_0(x)
+    q = self.NIN_0(h)
+    k = self.NIN_1(h)
+    v = self.NIN_2(h)
+
+    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+    w = torch.reshape(w, (B, H, W, H * W))
+    w = F.softmax(w, dim=-1)
+    w = torch.reshape(w, (B, H, W, H, W))
+    h = torch.einsum('bhwij,bcij->bchw', w, v)
+    h = self.NIN_3(h)
+    return h
+
 
 class Upsample(nn.Module):
   def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
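Unlike AttnBlockpp, the new AttnBlockppRaw returns only the attention output: no residual is added inside the block, and the skip_rescale flag is stored but never used in forward, so the caller decides how to combine it with the input. A minimal usage sketch (shapes illustrative, module path assumed importable):

    import torch
    from score_sde.models.layerspp import AttnBlockppRaw

    x = torch.randn(4, 64, 16, 16)   # (B, C, H, W) feature map
    attn = AttnBlockppRaw(channels=64)
    h = attn(x)                      # raw attention output, same shape as x
    out = x + h                      # residual is added by the caller
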
score_sde/models/ncsnpp_generator_adagn.py
CHANGED
@@ -53,6 +53,36 @@ get_act = layers.get_act
 default_initializer = layers.default_init
 dense = dense_layer.dense
 
+class CrossAndGlobalAttnBlock(nn.Module):
+  """Cross-attention to a conditioning sequence plus a global self-attention block."""
+  def __init__(self, channels, *, context_dim=None, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False):
+    super().__init__()
+    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+    self.ca = layers.CrossAttention(
+      channels,
+      context_dim=context_dim,
+      dim_head=dim_head,
+      heads=heads,
+      norm_context=norm_context,
+      cosine_sim_attn=cosine_sim_attn,
+    )
+    self.attn = layerspp.AttnBlockppRaw(channels)
+
+  def forward(self, x, cond, mask=None):
+    B, C, H, W = x.shape
+    h = self.GroupNorm_0(x)
+    h = h.view(B, C, H*W)
+    h = h.permute(0,2,1)
+    h = h.contiguous()
+    h_new = self.ca(h, cond, mask=mask)
+    h_new = h_new.permute(0,2,1)
+    h_new = h_new.contiguous()
+    h_new = h_new.view(B, C, H, W)
+
+    h_global = self.attn(x)
+    h = h_new + h_global
+    return x + h
+
 class PixelNorm(nn.Module):
   def __init__(self):
     super().__init__()
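A sketch of how this block is driven during the forward pass: x is a feature map at an attention resolution, cond is the text-encoder output of width config.cond_size, and mask flags valid tokens. Note the channel count must be divisible by the 32 GroupNorm groups. Values here are illustrative, and the module path is assumed importable:

    import torch
    from score_sde.models.ncsnpp_generator_adagn import CrossAndGlobalAttnBlock

    B, C, H, W = 2, 128, 16, 16
    cond_size, seq_len = 768, 77                # e.g. T5-base width, caption length
    x = torch.randn(B, C, H, W)
    cond = torch.randn(B, seq_len, cond_size)
    cond_mask = torch.ones(B, seq_len, dtype=torch.bool)

    block = CrossAndGlobalAttnBlock(C, context_dim=cond_size)
    out = block(x, cond, mask=cond_mask)        # cross-attn + global self-attn + residual
    assert out.shape == x.shape
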
@@ -68,6 +98,7 @@ class NCSNpp(nn.Module):
   def __init__(self, config):
     super().__init__()
     self.config = config
+    self.cross_attention_block = config.cross_attention_block
     self.grad_checkpointing = config.grad_checkpointing if hasattr(config, "grad_checkpointing") else False
     self.not_use_tanh = config.not_use_tanh
     self.act = act = nn.SiLU()
@@ -124,7 +155,14 @@ class NCSNpp(nn.Module):
     modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
     nn.init.zeros_(modules[-1].bias)
     if config.cross_attention:
-      AttnBlock = functools.partial(layers.CondAttnBlock, context_dim=config.cond_size)
+
+      #block_name = config.cross_attention_block if hasattr(config, "cross_attention_block") else "basic"
+      block_name = config.cross_attention_block
+      if block_name == "basic":
+        AttnBlock = functools.partial(layers.CondAttnBlock, context_dim=config.cond_size)
+      elif block_name == "cross_and_global_attention":
+        AttnBlock = functools.partial(CrossAndGlobalAttnBlock, context_dim=config.cond_size)
+      print(AttnBlock)
     else:
       AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                     init_scale=init_scale,
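Note that the new selection logic has no fallback branch: any --cross_attention_block value other than "basic" or "cross_and_global_attention" leaves AttnBlock unbound and fails later with a NameError. A defensive guard (hypothetical, not part of the commit) could be:

    if block_name not in ("basic", "cross_and_global_attention"):
        raise ValueError(f"unknown cross_attention_block: {block_name}")

The leftover print(AttnBlock) also looks like a debugging aid.
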
@@ -342,7 +380,7 @@ class NCSNpp(nn.Module):
         h = modules[m_idx](hs[-1], temb, zemb)
         m_idx += 1
         if h.shape[-1] in self.attn_resolutions:
-          if type(modules[m_idx]) == layers.CondAttnBlock:
+          if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
             h = modules[m_idx](h, cond, cond_mask)
           else:
             h = modules[m_idx](h)
@@ -377,7 +415,7 @@ class NCSNpp(nn.Module):
     h = hs[-1]
     h = modules[m_idx](h, temb, zemb)
     m_idx += 1
-    if type(modules[m_idx]) == layers.CondAttnBlock:
+    if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
       h = modules[m_idx](h, cond, cond_mask)
     else:
       h = modules[m_idx](h)
@@ -394,7 +432,7 @@ class NCSNpp(nn.Module):
       m_idx += 1
 
       if h.shape[-1] in self.attn_resolutions:
-        if type(modules[m_idx]) == layers.CondAttnBlock:
+        if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
           h = modules[m_idx](h, cond, cond_mask)
         else:
           h = modules[m_idx](h)
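The three forward-pass hunks dispatch on exact classes, so a conditioned block that is subclassed or wrapped (for example by an activation-checkpointing wrapper) would silently fall through to the unconditional call. An equivalent but more tolerant check (a suggestion, not in the commit):

    if isinstance(modules[m_idx], (layers.CondAttnBlock, CrossAndGlobalAttnBlock)):
        h = modules[m_idx](h, cond, cond_mask)
    else:
        h = modules[m_idx](h)
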
train_ddgan.py
CHANGED
@@ -385,9 +385,10 @@ def train(rank, gpu, args):
         backbone_kwargs={"cond_size": text_encoder.output_size}
     )
     netD = netD.to(device)
-
-    broadcast_params(netG.parameters())
-    broadcast_params(netD.parameters())
+
+    if args.world_size > 1:
+        broadcast_params(netG.parameters())
+        broadcast_params(netD.parameters())
 
     if args.fsdp:
         from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
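broadcast_params is the repository's helper for synchronizing initial weights across workers; guarding it behind args.world_size > 1 lets single-GPU runs avoid touching torch.distributed at all. A minimal sketch of what such a helper typically does (illustrative stand-in, assuming an initialized process group):

    import torch.distributed as dist

    def broadcast_params(params):  # illustrative, not the repository's implementation
        # Push rank 0's initial weights to every worker so all replicas
        # start identical before gradient averaging takes over.
        for p in params:
            dist.broadcast(p.data, src=0)
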
@@ -410,8 +411,9 @@ def train(rank, gpu, args):
     if args.fsdp:
         netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
     else:
-        netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
-        netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
+        if args.world_size > 1:
+            netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
+            netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
         #if args.discr_type == "projected_gan":
         #    netD._set_static_graph()
 
@@ -652,7 +654,8 @@ def train(rank, gpu, args):
             torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)
 
         if args.save_content:
-            dist.barrier()
+            if args.world_size > 1:
+                dist.barrier()
             if rank == 0:
                 print('Saving content.')
                 def to_cpu(d):
@@ -709,20 +712,26 @@ def init_processes(rank, size, fn, args):
     """ Initialize the distributed environment. """
 
     import os
-
-    args.rank = int(os.environ['SLURM_PROCID'])
-    args.world_size = int(os.getenv("SLURM_NTASKS"))
-    args.local_rank = int(os.environ['SLURM_LOCALID'])
-    print(args.rank, args.world_size)
-    args.master_address = os.getenv("SLURM_LAUNCH_NODE_IPADDR")
-    os.environ['MASTER_ADDR'] = args.master_address
-    os.environ['MASTER_PORT'] = "12345"
-    torch.cuda.set_device(args.local_rank)
-    gpu = args.local_rank
-    dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=args.world_size)
-    fn(rank, gpu, args)
-    dist.barrier()
-    cleanup()
+
+    if size == 1:
+        args.rank = 0
+        args.world_size = 1
+        args.local_rank = 0
+        fn(rank, args.local_rank, args)
+    else:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.world_size = int(os.getenv("SLURM_NTASKS"))
+        args.local_rank = int(os.environ['SLURM_LOCALID'])
+        print(args.rank, args.world_size)
+        args.master_address = os.getenv("SLURM_LAUNCH_NODE_IPADDR")
+        os.environ['MASTER_ADDR'] = args.master_address
+        os.environ['MASTER_PORT'] = "12345"
+        torch.cuda.set_device(args.local_rank)
+        gpu = args.local_rank
+        dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=args.world_size)
+        fn(rank, gpu, args)
+        dist.barrier()
+        cleanup()
 
 def cleanup():
     dist.destroy_process_group()
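The multi-process branch of init_processes assumes a SLURM launch and reads everything it needs from the scheduler's environment (the master port is hardcoded to 12345):

    SLURM_PROCID              global rank of this task
    SLURM_NTASKS              total number of tasks (the world size)
    SLURM_LOCALID             local rank on the node, used to pick the GPU
    SLURM_LAUNCH_NODE_IPADDR  address exported as MASTER_ADDR
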
@@ -737,6 +746,8 @@ if __name__ == '__main__':
     parser.add_argument('--mismatch_loss', action='store_true',default=False, help="use mismatch loss")
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
     parser.add_argument('--cross_attention', action='store_true',default=False, help="use cross attention in generator")
+    parser.add_argument('--cross_attention_block', default="basic", help="cross attention block type")
+
     parser.add_argument('--fsdp', action='store_true',default=False, help='use FSDP')
     parser.add_argument('--grad_checkpointing', action='store_true',default=False, help='use grad checkpointing')
 
@@ -809,7 +820,7 @@ if __name__ == '__main__':
     parser.add_argument('--beta2', type=float, default=0.9,
                         help='beta2 for adam')
     parser.add_argument('--no_lr_decay',action='store_true', default=False)
-    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad
+    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad")
 
     parser.add_argument('--use_ema', action='store_true', default=False,
                         help='use EMA or not')
@@ -828,6 +839,7 @@ if __name__ == '__main__':
     parser.add_argument('--precision', type=str, default="fp32")
 
     ###ddp
+
     parser.add_argument('--num_proc_node', type=int, default=1,
                         help='The number of nodes in multi node env.')
     parser.add_argument('--num_process_per_node', type=int, default=1,
@@ -840,8 +852,10 @@ if __name__ == '__main__':
                         help='address for master')
 
     args = parser.parse_args()
-
-
-
-
+    if 'SLURM_NTASKS' in os.environ:
+        args.world_size = int(os.getenv("SLURM_NTASKS"))
+        args.rank = int(os.environ['SLURM_PROCID'])
+    else:
+        args.world_size = 1
+        args.rank = 0
     init_processes(args.rank, args.world_size, train, args)
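With the new size == 1 branch, a single-GPU run no longer needs a SLURM environment at all. An illustrative invocation using only the flags visible in this diff (the training script's many other required arguments are elided):

    python train_ddgan.py --cross_attention --cross_attention_block cross_and_global_attention ...
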