import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
import math
from .udit import UDiT
from .utils.span_mask import compute_mask_indices


class EmbeddingCFG(nn.Module):
    """
    Handles label dropout for classifier-free guidance.
    """
    # todo: support 2D input

    def __init__(self, in_channels):
        super().__init__()
        self.cfg_embedding = nn.Parameter(
            torch.randn(in_channels) / in_channels ** 0.5)

    def token_drop(self, condition, condition_mask, cfg_prob):
        """
        Drops labels to enable classifier-free guidance.
        """
        b, t, device = condition.shape[0], condition.shape[1], condition.device
        drop_ids = torch.rand(b, device=device) < cfg_prob
        uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
        condition = torch.where(drop_ids[:, None, None], uncond, condition)
        if condition_mask is not None:
            condition_mask[drop_ids] = False
            condition_mask[drop_ids, 0] = True

        return condition, condition_mask

    def forward(self, condition, condition_mask, cfg_prob=0.0):
        if condition_mask is not None:
            condition_mask = condition_mask.clone()
        if cfg_prob > 0:
            condition, condition_mask = self.token_drop(condition,
                                                        condition_mask,
                                                        cfg_prob)
        return condition, condition_mask
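
# Illustrative usage sketch (kept as comments; shapes and values below are
# assumptions, not part of the original module):
#
#   cfg = EmbeddingCFG(in_channels=512)
#   condition = torch.randn(4, 100, 512)                   # (B, T, C) conditioning tokens
#   condition_mask = torch.ones(4, 100, dtype=torch.bool)  # valid-token mask
#   cond, cond_mask = cfg(condition, condition_mask, cfg_prob=0.1)
#   # Roughly 10% of the rows now carry the learned unconditional embedding,
#   # and their mask keeps only position 0 valid.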


class DiscreteCFG(nn.Module):
    """
    Label dropout for discrete (token-id) conditions: dropped rows are zeroed
    and their first position is set to `replace_id`.
    """

    def __init__(self, replace_id=2):
        super().__init__()
        self.replace_id = replace_id

    def forward(self, context, context_mask, cfg_prob):
        context = context.clone()
        if context_mask is not None:
            context_mask = context_mask.clone()
        if cfg_prob > 0:
            cfg_mask = torch.rand(len(context)) < cfg_prob
            if torch.any(cfg_mask):
                context[cfg_mask] = 0
                context[cfg_mask, 0] = self.replace_id
                if context_mask is not None:
                    context_mask[cfg_mask] = False
                    context_mask[cfg_mask, 0] = True
        return context, context_mask
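
# Illustrative usage sketch (token ids and shapes below are assumptions):
#
#   dcfg = DiscreteCFG(replace_id=2)
#   context = torch.randint(3, 1000, (4, 77))                # (B, T) discrete token ids
#   context_mask = torch.ones(4, 77, dtype=torch.bool)
#   ctx, ctx_mask = dcfg(context, context_mask, cfg_prob=0.1)
#   # Dropped rows are zeroed except position 0, which is set to replace_id.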


class CFGModel(nn.Module):
    """
    Wraps a backbone and applies classifier-free guidance dropout (EmbeddingCFG)
    to the continuous context before the backbone forward pass.
    """

    def __init__(self, context_dim, backbone):
        super().__init__()
        self.model = backbone
        self.context_cfg = EmbeddingCFG(context_dim)

    def forward(self, x, timesteps,
                context, x_mask=None, context_mask=None,
                cfg_prob=0.0):
        context, context_mask = self.context_cfg(context, context_mask,
                                                 cfg_prob)
        x = self.model(x=x, timesteps=timesteps,
                       context=context,
                       x_mask=x_mask, context_mask=context_mask)
        return x
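
# Illustrative wiring sketch (the dummy backbone is a stand-in for UDiT and is
# not part of this repo; shapes are assumptions):
#
#   class _DummyBackbone(nn.Module):
#       def forward(self, x, timesteps, context, x_mask=None, context_mask=None):
#           return x
#
#   model = CFGModel(context_dim=512, backbone=_DummyBackbone())
#   x = torch.randn(2, 8, 100)                   # (B, C, L) noisy latent
#   t = torch.randint(0, 1000, (2,))
#   context = torch.randn(2, 77, 512)            # (B, T, C) conditioning tokens
#   out = model(x, t, context, cfg_prob=0.1)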


class ConcatModel(nn.Module):
    """
    Downsamples the context with strided Conv1d layers, applies CFG dropout,
    and concatenates the result with the input along the channel dimension.
    """

    def __init__(self, backbone, in_dim, stride=[]):
        super().__init__()
        self.model = backbone

        self.downsample_layers = nn.ModuleList()
        for i, s in enumerate(stride):
            downsample_layer = nn.Conv1d(
                in_dim,
                in_dim * 2,
                kernel_size=2 * s,
                stride=s,
                padding=math.ceil(s / 2),
            )
            self.downsample_layers.append(downsample_layer)
            in_dim = in_dim * 2

        self.context_cfg = EmbeddingCFG(in_dim)

    def forward(self, x, timesteps,
                context, x_mask=None,
                cfg=False, cfg_prob=0.0):

        # todo: support 2D input
        # x: B, C, L
        # context: B, C, L

        for downsample_layer in self.downsample_layers:
            context = downsample_layer(context)

        context = context.transpose(1, 2)
        context, _ = self.context_cfg(context, None,
                                      cfg_prob=cfg_prob if cfg else 0.0)
        context = context.transpose(1, 2)

        assert context.shape[-1] == x.shape[-1]
        x = torch.cat([context, x], dim=1)
        x = self.model(x=x, timesteps=timesteps,
                       context=None, x_mask=x_mask, context_mask=None)
        return x
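
# Illustrative shape sketch (stride values, dummy backbone, and shapes are
# assumptions). Each Conv1d uses kernel_size=2*s, stride=s, padding=ceil(s/2),
# so it divides the context length by roughly s while doubling its channels:
#
#   model = ConcatModel(_DummyBackbone(), in_dim=64, stride=[2, 2])
#   context = torch.randn(2, 64, 400)            # 400 -> 200 -> 100 frames, 64 -> 256 channels
#   x = torch.randn(2, 8, 100)                   # latent length must match the downsampled context
#   t = torch.randint(0, 1000, (2,))
#   out = model(x, t, context, cfg=True, cfg_prob=0.1)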


class MaskDiT(nn.Module):
    """
    UDiT backbone with an optional MAE-style objective: random spans of the
    ground-truth latent are replaced by a learned mask embedding and passed to
    the model together with the noisy input and a binary mask channel.
    """

    def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
        super().__init__()
        self.model = UDiT(**kwargs)
        self.mae = mae
        if self.mae:
            out_channel = kwargs.pop('out_chans', None)
            self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
            self.mae_prob = mae_prob
            self.mask_ratio = mask_ratio
            self.mask_span = mask_span

    def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
        B, D, L = gt.shape
        if mae_mask_infer is None:
            # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
            mask_ratios = mask_ratios.cpu().numpy()
            mask = compute_mask_indices(shape=[B, L],
                                        padding_mask=None,
                                        mask_prob=mask_ratios,
                                        mask_length=self.mask_span,
                                        mask_type="static",
                                        mask_other=0.0,
                                        min_masks=1,
                                        no_overlap=False,
                                        min_space=0,)
            mask = mask.unsqueeze(1).expand_as(gt)
        else:
            mask = mae_mask_infer
            mask = mask.expand_as(gt)
        gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
        return gt, mask.type_as(gt)

    def forward(self, x, timesteps, context,
                x_mask=None, context_mask=None, cls_token=None,
                gt=None, mae_mask_infer=None,
                forward_model=True):
        # todo: handle controlnet inside
        mae_mask = torch.ones_like(x)
        if self.mae:
            if gt is not None:
                B, D, L = gt.shape
                mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
                gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
                # apply mae only to the selected batches
                if mae_mask_infer is None:
                    # determine mae batch
                    mae_batch = torch.rand(B) < self.mae_prob
                    gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
                    mae_mask[~mae_batch] = 1.0
            else:
                B, D, L = x.shape
                gt = self.mask_embed.view(1, D, 1).expand_as(x)
            x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)

        if forward_model:
            x = self.model(x=x, timesteps=timesteps, context=context,
                           x_mask=x_mask, context_mask=context_mask,
                           cls_token=cls_token)
            # print(mae_mask[:, 0, :].sum(dim=-1))
        return x, mae_mask
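
# Note (an assumption derived from the concatenation in forward, not stated in the
# repo): when `mae=True` the backbone input is torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1),
# i.e. 2 * out_chans + 1 channels, so the wrapped UDiT must be configured for that width.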