import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
import math

from .udit import UDiT
from .utils.span_mask import compute_mask_indices


class EmbeddingCFG(nn.Module):
    """
    Handles label dropout for classifier-free guidance.
    """
    # todo: support 2D input
    def __init__(self, in_channels):
        super().__init__()
        self.cfg_embedding = nn.Parameter(
            torch.randn(in_channels) / in_channels ** 0.5)

    def token_drop(self, condition, condition_mask, cfg_prob):
        """
        Drops labels to enable classifier-free guidance.
        """
        b, t, device = condition.shape[0], condition.shape[1], condition.device
        drop_ids = torch.rand(b, device=device) < cfg_prob
        uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
        condition = torch.where(drop_ids[:, None, None], uncond, condition)
        if condition_mask is not None:
            condition_mask[drop_ids] = False
            condition_mask[drop_ids, 0] = True
        return condition, condition_mask

    def forward(self, condition, condition_mask, cfg_prob=0.0):
        if condition_mask is not None:
            condition_mask = condition_mask.clone()
        if cfg_prob > 0:
            condition, condition_mask = self.token_drop(condition,
                                                        condition_mask,
                                                        cfg_prob)
        return condition, condition_mask
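
# Illustrative usage sketch for EmbeddingCFG; shapes and cfg_prob below are
# example assumptions, not values from this repo:
#
#     cfg = EmbeddingCFG(in_channels=512)
#     condition = torch.randn(4, 100, 512)                  # (B, T, C)
#     condition_mask = torch.ones(4, 100, dtype=torch.bool)
#     cond, mask = cfg(condition, condition_mask, cfg_prob=0.1)
#
# Rows selected for dropout are replaced by the learned cfg_embedding, and
# their mask is reset so that only the first position stays valid.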


class DiscreteCFG(nn.Module):
    def __init__(self, replace_id=2):
        super(DiscreteCFG, self).__init__()
        self.replace_id = replace_id

    def forward(self, context, context_mask, cfg_prob):
        context = context.clone()
        if context_mask is not None:
            context_mask = context_mask.clone()
        if cfg_prob > 0:
            # draw the per-sample dropout mask on the same device as the token ids
            cfg_mask = torch.rand(len(context), device=context.device) < cfg_prob
            if torch.any(cfg_mask):
                # dropped sequences become [replace_id, 0, 0, ...]
                context[cfg_mask] = 0
                context[cfg_mask, 0] = self.replace_id
                if context_mask is not None:
                    context_mask[cfg_mask] = False
                    context_mask[cfg_mask, 0] = True
        return context, context_mask
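
# Illustrative usage sketch for DiscreteCFG; token range, shapes, and cfg_prob
# are example assumptions (replace_id=2 mirrors the class default):
#
#     cfg = DiscreteCFG(replace_id=2)
#     context = torch.randint(3, 1000, (4, 77))             # (B, T) token ids
#     context_mask = torch.ones(4, 77, dtype=torch.bool)
#     ctx, mask = cfg(context, context_mask, cfg_prob=0.1)
#
# Dropped rows are zeroed with replace_id written at position 0, mimicking an
# "empty" prompt for classifier-free guidance.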


class CFGModel(nn.Module):
    def __init__(self, context_dim, backbone):
        super().__init__()
        self.model = backbone
        self.context_cfg = EmbeddingCFG(context_dim)

    def forward(self, x, timesteps,
                context, x_mask=None, context_mask=None,
                cfg_prob=0.0):
        # EmbeddingCFG returns (condition, condition_mask); keep both so the
        # dropped rows also carry the updated context mask into the backbone
        context, context_mask = self.context_cfg(context, context_mask, cfg_prob)
        x = self.model(x=x, timesteps=timesteps,
                       context=context,
                       x_mask=x_mask, context_mask=context_mask)
        return x


class ConcatModel(nn.Module):
    def __init__(self, backbone, in_dim, stride=[]):
        super().__init__()
        self.model = backbone

        self.downsample_layers = nn.ModuleList()
        for i, s in enumerate(stride):
            downsample_layer = nn.Conv1d(
                in_dim,
                in_dim * 2,
                kernel_size=2 * s,
                stride=s,
                padding=math.ceil(s / 2),
            )
            self.downsample_layers.append(downsample_layer)
            in_dim = in_dim * 2

        self.context_cfg = EmbeddingCFG(in_dim)

    def forward(self, x, timesteps,
                context, x_mask=None,
                cfg=False, cfg_prob=0.0):
        # todo: support 2D input
        # x: B, C, L
        # context: B, C, L

        # downsample the context so its length matches x before concatenation
        for downsample_layer in self.downsample_layers:
            context = downsample_layer(context)

        # apply CFG dropout in (B, L, C) layout, as EmbeddingCFG expects
        context = context.transpose(1, 2)
        context, _ = self.context_cfg(context, None,
                                      cfg_prob=cfg_prob if cfg else 0.0)
        context = context.transpose(1, 2)

        assert context.shape[-1] == x.shape[-1]
        x = torch.cat([context, x], dim=1)
        x = self.model(x=x, timesteps=timesteps,
                       context=None, x_mask=x_mask, context_mask=None)
        return x
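
# Illustrative note on the downsampling arithmetic in ConcatModel: each
# nn.Conv1d uses kernel_size=2*s, stride=s, padding=ceil(s/2), so a context of
# length L comes out with length floor((L + 2*ceil(s/2) - 2*s) / s) + 1, i.e.
# roughly L / s (exactly L / s for s >= 2 when s divides L). Stacking the
# layers shrinks the context by the product of the strides, which is what lets
# the `context.shape[-1] == x.shape[-1]` assertion in forward() hold when that
# product matches the length ratio between the raw context and x.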


class MaskDiT(nn.Module):
    def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
        super().__init__()
        self.model = UDiT(**kwargs)
        self.mae = mae
        if self.mae:
            out_channel = kwargs.pop('out_chans', None)
            self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
            self.mae_prob = mae_prob
            self.mask_ratio = mask_ratio
            self.mask_span = mask_span

    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
        B, D, L = gt.shape
        if mae_mask_infer is None:
            # sample random span masks; compute_mask_indices is expected to
            # return a (B, L) boolean tensor
            # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
            mask_ratios = mask_ratios.cpu().numpy()
            mask = compute_mask_indices(shape=[B, L],
                                        padding_mask=None,
                                        mask_prob=mask_ratios,
                                        mask_length=self.mask_span,
                                        mask_type="static",
                                        mask_other=0.0,
                                        min_masks=1,
                                        no_overlap=False,
                                        min_space=0,)
            mask = mask.unsqueeze(1).expand_as(gt)
        else:
            mask = mae_mask_infer
            mask = mask.expand_as(gt)
        # replace masked positions with the learnable mask embedding
        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
        return gt, mask.type_as(gt)
    def forward(self, x, timesteps, context,
                x_mask=None, context_mask=None, cls_token=None,
                gt=None, mae_mask_infer=None,
                forward_model=True):
        # todo: handle controlnet inside
        mae_mask = torch.ones_like(x)
        if self.mae:
            if gt is not None:
                B, D, L = gt.shape
                mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
                gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
                # apply mae only to the selected batches
                if mae_mask_infer is None:
                    # determine mae batch
                    mae_batch = torch.rand(B) < self.mae_prob
                    gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
                    mae_mask[~mae_batch] = 1.0
            else:
                B, D, L = x.shape
                gt = self.mask_embed.view(1, D, 1).expand_as(x)
            # concatenate noisy input, (masked) ground truth, and one mask channel
            x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)

        if forward_model:
            x = self.model(x=x, timesteps=timesteps, context=context,
                           x_mask=x_mask, context_mask=context_mask,
                           cls_token=cls_token)
            # print(mae_mask[:, 0, :].sum(dim=-1))
        return x, mae_mask
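

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch; shapes and probabilities below are
# arbitrary example values). It only exercises the self-contained CFG helpers;
# MaskDiT is skipped here because building it requires a full UDiT
# configuration. Run with `python -m <package>.<module>` from the package root
# so the relative imports above resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Continuous conditioning: (B, T, C) embeddings plus a boolean mask.
    emb_cfg = EmbeddingCFG(in_channels=32)
    condition = torch.randn(4, 10, 32)
    condition_mask = torch.ones(4, 10, dtype=torch.bool)
    cond, cond_mask = emb_cfg(condition, condition_mask, cfg_prob=0.5)
    print(cond.shape, cond_mask.shape)

    # Discrete conditioning: (B, T) token ids; dropped rows become
    # [replace_id, 0, 0, ...] with only the first position left valid.
    disc_cfg = DiscreteCFG(replace_id=2)
    context = torch.randint(3, 100, (4, 10))
    context_mask = torch.ones(4, 10, dtype=torch.bool)
    ctx, ctx_mask = disc_cfg(context, context_mask, cfg_prob=0.5)
    print(ctx.shape, ctx_mask.shape)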