import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
import math

from .udit import UDiT
from .utils.span_mask import compute_mask_indices


class EmbeddingCFG(nn.Module):
    """
    Handles condition dropout for classifier-free guidance.
    """
    # todo: support 2D input

    def __init__(self, in_channels):
        super().__init__()
        # learnable "unconditional" embedding that replaces dropped conditions
        self.cfg_embedding = nn.Parameter(
            torch.randn(in_channels) / in_channels ** 0.5)

    def token_drop(self, condition, condition_mask, cfg_prob):
        """
        Drops conditions with probability `cfg_prob` to enable
        classifier-free guidance.
        """
        b, t, device = condition.shape[0], condition.shape[1], condition.device
        drop_ids = torch.rand(b, device=device) < cfg_prob
        uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
        condition = torch.where(drop_ids[:, None, None], uncond, condition)
        if condition_mask is not None:
            # dropped samples attend only to the first (unconditional) token
            condition_mask[drop_ids] = False
            condition_mask[drop_ids, 0] = True
        return condition, condition_mask

    def forward(self, condition, condition_mask, cfg_prob=0.0):
        if condition_mask is not None:
            condition_mask = condition_mask.clone()
        if cfg_prob > 0:
            condition, condition_mask = self.token_drop(condition,
                                                        condition_mask,
                                                        cfg_prob)
        return condition, condition_mask


class DiscreteCFG(nn.Module):
    def __init__(self, replace_id=2):
        super().__init__()
        self.replace_id = replace_id

    def forward(self, context, context_mask, cfg_prob):
        context = context.clone()
        if context_mask is not None:
            context_mask = context_mask.clone()
        if cfg_prob > 0:
            cfg_mask = torch.rand(len(context)) < cfg_prob
            if torch.any(cfg_mask):
                # zero out dropped sequences and mark them with the replace id
                context[cfg_mask] = 0
                context[cfg_mask, 0] = self.replace_id
                if context_mask is not None:
                    context_mask[cfg_mask] = False
                    context_mask[cfg_mask, 0] = True
        return context, context_mask


class CFGModel(nn.Module):
    def __init__(self, context_dim, backbone):
        super().__init__()
        self.model = backbone
        self.context_cfg = EmbeddingCFG(context_dim)

    def forward(self, x, timesteps, context,
                x_mask=None, context_mask=None, cfg_prob=0.0):
        # EmbeddingCFG expects (condition, condition_mask, cfg_prob) and
        # returns both updated tensors
        context, context_mask = self.context_cfg(context, context_mask,
                                                 cfg_prob)
        x = self.model(x=x, timesteps=timesteps,
                       context=context,
                       x_mask=x_mask, context_mask=context_mask)
        return x


class ConcatModel(nn.Module):
    def __init__(self, backbone, in_dim, stride=()):
        super().__init__()
        self.model = backbone

        # strided conv stack that downsamples the context to match x in length
        self.downsample_layers = nn.ModuleList()
        for s in stride:
            downsample_layer = nn.Conv1d(
                in_dim,
                in_dim * 2,
                kernel_size=2 * s,
                stride=s,
                padding=math.ceil(s / 2),
            )
            self.downsample_layers.append(downsample_layer)
            in_dim = in_dim * 2

        self.context_cfg = EmbeddingCFG(in_dim)

    def forward(self, x, timesteps, context,
                x_mask=None, cfg=False, cfg_prob=0.0):
        # todo: support 2D input
        # x: B, C, L
        # context: B, C, L
        for downsample_layer in self.downsample_layers:
            context = downsample_layer(context)

        # EmbeddingCFG operates on (B, T, C), so transpose around the call;
        # condition dropout is applied only when cfg is enabled
        context = context.transpose(1, 2)
        context, _ = self.context_cfg(context, None,
                                      cfg_prob if cfg else 0.0)
        context = context.transpose(1, 2)

        assert context.shape[-1] == x.shape[-1]
        x = torch.cat([context, x], dim=1)
        x = self.model(x=x, timesteps=timesteps,
                       context=None,
                       x_mask=x_mask, context_mask=None)
        return x
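
# A minimal, illustrative sanity check for the CFG conditioners above, not
# part of the original module. It assumes a continuous context of shape
# (B, T, C) with a boolean padding mask; the helper name `_demo_cfg_dropout`
# and all tensor sizes are made up for the sketch.
def _demo_cfg_dropout():
    torch.manual_seed(0)
    b, t, c = 4, 8, 16
    condition = torch.randn(b, t, c)
    condition_mask = torch.ones(b, t, dtype=torch.bool)

    cfg = EmbeddingCFG(in_channels=c)
    cond, cond_mask = cfg(condition, condition_mask, cfg_prob=0.5)
    # dropped rows now hold the broadcast cfg_embedding and keep only a
    # single valid token in their mask
    dropped = cond_mask.sum(dim=1) == 1
    print("dropped samples:", dropped.nonzero(as_tuple=True)[0].tolist())

    # DiscreteCFG does the analogous replacement for token-id contexts:
    # dropped rows are zeroed and start with the sentinel replace_id
    token_ids = torch.randint(3, 100, (b, t))
    ids, ids_mask = DiscreteCFG(replace_id=2)(token_ids,
                                              condition_mask.clone(),
                                              cfg_prob=0.5)
    print("replaced rows start with id:", ids[ids_mask.sum(dim=1) == 1, 0])
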
class MaskDiT(nn.Module):
    def __init__(self, mae=False, mae_prob=0.5,
                 mask_ratio=(0.25, 1.0), mask_span=10, **kwargs):
        super().__init__()
        self.model = UDiT(**kwargs)
        self.mae = mae
        if self.mae:
            out_channel = kwargs.pop('out_chans', None)
            # learnable embedding that fills masked positions of the ground truth
            self.mask_embed = nn.Parameter(torch.zeros(out_channel))
            self.mae_prob = mae_prob
            self.mask_ratio = mask_ratio
            self.mask_span = mask_span

    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
        B, D, L = gt.shape
        if mae_mask_infer is None:
            # i.i.d. alternative:
            # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
            mask_ratios = mask_ratios.cpu().numpy()
            mask = compute_mask_indices(shape=[B, L],
                                        padding_mask=None,
                                        mask_prob=mask_ratios,
                                        mask_length=self.mask_span,
                                        mask_type="static",
                                        mask_other=0.0,
                                        min_masks=1,
                                        no_overlap=False,
                                        min_space=0)
            mask = mask.unsqueeze(1).expand_as(gt)
        else:
            mask = mae_mask_infer
            mask = mask.expand_as(gt)
        # replace masked positions with the learnable mask embedding
        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
        return gt, mask.type_as(gt)

    def forward(self, x, timesteps, context,
                x_mask=None, context_mask=None, cls_token=None,
                gt=None, mae_mask_infer=None,
                forward_model=True):
        # todo: handle controlnet inside
        mae_mask = torch.ones_like(x)
        if self.mae:
            if gt is not None:
                B, D, L = gt.shape
                mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
                gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
                # apply mae only to the selected batches
                if mae_mask_infer is None:
                    # determine the mae batch; non-mae samples get a fully
                    # masked ground truth and an all-ones mask
                    mae_batch = torch.rand(B) < self.mae_prob
                    gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
                    mae_mask[~mae_batch] = 1.0
            else:
                B, D, L = x.shape
                gt = self.mask_embed.view(1, D, 1).expand_as(x)
            # backbone input: noisy latent, masked ground truth, 1-channel mask
            x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)

        if forward_model:
            x = self.model(x=x, timesteps=timesteps, context=context,
                           x_mask=x_mask, context_mask=context_mask,
                           cls_token=cls_token)
        return x, mae_mask
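
# Illustrative sketch of the MAE-style input that MaskDiT assembles, not part
# of the original module and independent of UDiT. It reproduces the
# [x, masked gt, mask] concatenation with plain torch; the single contiguous
# random span below is a simplified stand-in for compute_mask_indices, and
# all names/sizes in `_demo_mae_input` are assumptions for the sketch.
def _demo_mae_input():
    torch.manual_seed(0)
    B, D, L, span = 2, 8, 32, 10
    x = torch.randn(B, D, L)     # noisy latent
    gt = torch.randn(B, D, L)    # clean latent to be partially masked
    mask_embed = torch.zeros(D)  # stands in for MaskDiT.mask_embed

    # sample one random span per batch element (simplified span masking)
    mask = torch.zeros(B, L, dtype=torch.bool)
    for b in range(B):
        start = torch.randint(0, L - span + 1, (1,)).item()
        mask[b, start:start + span] = True

    # fill masked positions of the ground truth with the mask embedding
    mask3d = mask.unsqueeze(1).expand_as(gt)
    gt = torch.where(mask3d, mask_embed.view(1, D, 1).expand_as(gt), gt)

    # the backbone sees x, the masked ground truth, and a 1-channel mask map
    model_in = torch.cat([x, gt, mask3d[:, 0:1, :].float()], dim=1)
    print(model_in.shape)  # torch.Size([2, 17, 32])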