Alex Ergasti commited on
Commit
b89c182
·
1 Parent(s): 9ca6da9
Files changed (7) hide show
  1. app.py +140 -0
  2. common_parser.py +50 -0
  3. converter.py +368 -0
  4. dataset.py +206 -0
  5. download.py +50 -0
  6. models.py +751 -0
  7. utils.py +267 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # the first flag below was False when we tested this script but True makes A100 training a lot faster:
3
+ torch.backends.cuda.matmul.allow_tf32 = True
4
+ torch.backends.cudnn.allow_tf32 = True
5
+
6
+ import os
7
+
8
+ from diffusers.models import AutoencoderKL
9
+ from models import FLAV_models
10
+
11
+ from diffusion.rectified_flow import RectifiedFlow
12
+
13
+ from diffusers.training_utils import EMAModel
14
+ from converter import Generator
15
+ from utils import *
16
+
17
+ import tempfile
18
+ import gradio as gr
19
+ from huggingface_hub import hf_hub_download
20
+ AUDIO_T_PER_FRAME = 1600 // 160
21
+
22
+ #################################################################################
23
+ # Global Model Setup #
24
+ #################################################################################
25
+
26
+ # These variables will be initialized in setup_models() and used in main()
27
+ vae = None
28
+ model = None
29
+ vocoder = None
30
+ audio_scale = 3.50
31
+
32
+
33
+ def setup_models():
34
+ global vae, model, vocoder
35
+
36
+ device = "cpu"
37
+ vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
38
+
39
+ model = FLAV_models["FLAV-B/1"](
40
+ latent_size= 256//8,
41
+ in_channels = 4,
42
+ num_classes = 0,
43
+ predict_frames = 10,
44
+ causal_attn = True,
45
+ )
46
+
47
+ ckpt_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="aist-ema.pth")
48
+
49
+ state_dict = torch.load(ckpt_path)
50
+
51
+ ema = EMAModel(model.parameters())
52
+ ema.load_state_dict(state_dict)
53
+ ema.copy_to(model.parameters())
54
+
55
+ hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/config.json")
56
+ vocoder_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/vocoder.pt")
57
+
58
+ vocoder_path = vocoder_path.replace("vocoder.pt", "")
59
+ vocoder = Generator.from_pretrained(vocoder_path)
60
+
61
+ vae.to(device)
62
+ model.to(device)
63
+ vocoder.to(device)
64
+
65
+
66
+
67
+ def generate_video(num_frames=10, steps=2, seed=42):
68
+ global vae, model, vocoder
69
+ # Setup device
70
+ device = "cuda" if torch.cuda.is_available() else "cpu"
71
+ torch.manual_seed(seed)
72
+
73
+ # Set up generation parameters
74
+ video_latent_size = (1, 10, 4, 256//8, 256//8)
75
+ audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)
76
+
77
+ rectified_flow = RectifiedFlow(num_timesteps=steps,
78
+ warmup_timesteps=10,
79
+ window_size=10)
80
+
81
+ # Generate sample
82
+ video, audio = generate_sample(
83
+ vae=vae, # These globals are set by setup_models
84
+ rectified_flow=rectified_flow,
85
+ forward_fn=model.forward,
86
+ video_length=num_frames,
87
+ video_latent_size=video_latent_size,
88
+ audio_latent_size=audio_latent_size,
89
+ y=None,
90
+ cfg_scale=None,
91
+ device=device
92
+ )
93
+
94
+ # Convert to wav
95
+ wavs = get_wavs(audio, vocoder, audio_scale, device)
96
+
97
+ # Save to temporary files
98
+ temp_dir = tempfile.mkdtemp()
99
+ video_path = os.path.join(temp_dir, "video", "generated_video.mp4")
100
+
101
+ # Use the first video and wav
102
+ vid, wav = video[0], wavs[0]
103
+ save_multimodal(vid, wav, temp_dir, "generated")
104
+
105
+ return video_path
106
+
107
+ def ui_generate_video(num_frames, steps, seed):
108
+ try:
109
+ return generate_video(int(num_frames), int(steps), int(seed))
110
+ except Exception as e:
111
+ return None
112
+
113
+ # Create Gradio interface
114
+ with gr.Blocks(title="FLAV Video Generator") as demo:
115
+ gr.Markdown("# FLAV Video Generator")
116
+ gr.Markdown("Generate videos using the FLAV model")
117
+
118
+ num_frames = None
119
+ steps = None
120
+ seed = None
121
+
122
+ video_output = None
123
+ with gr.Row():
124
+ with gr.Column():
125
+ num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
126
+ steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (multiplied by a factor of 10)")
127
+ seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
128
+ generate_btn = gr.Button("Generate Video")
129
+
130
+ with gr.Column():
131
+ video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)
132
+ generate_btn.click(
133
+ fn=ui_generate_video,
134
+ inputs=[num_frames, steps, seed],
135
+ outputs=[video_output]
136
+ )
137
+
138
+ if __name__ == "__main__":
139
+ setup_models()
140
+ demo.launch()
common_parser.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from models import FLAV_models
3
+
4
+ class CommonParser:
5
+ def __init__(self):
6
+ self.parser = argparse.ArgumentParser()
7
+ # Datasets
8
+ self.parser.add_argument("--data-path", type=str, required=True)
9
+ self.parser.add_argument("--load-latents", action="store_true")
10
+ self.parser.add_argument("--num-classes", type=int, default=9)
11
+ self.parser.add_argument("--image-size", type=int, choices=[64, 256, 512, 1024], default=256)
12
+ self.parser.add_argument("--target-video-fps", type=int, default=10)
13
+ self.parser.add_argument("--ignore-cache", action="store_true")
14
+ self.parser.add_argument("--audio-scale", type=float, default=3.5009668382765917)
15
+
16
+ # Results
17
+ self.parser.add_argument("--video-length", type=int, default=1)
18
+ self.parser.add_argument("--predict-frames", type=int, default=10)
19
+ self.parser.add_argument("--results-dir", type=str, default="results")
20
+ self.parser.add_argument("--experiment-dir", type=str, default="")
21
+ self.parser.add_argument("--checkpoint-dir", type=str, default="checkpoints")
22
+ self.parser.add_argument("--ckpt-every", type=int, default=5_000)
23
+
24
+ # Models
25
+ self.parser.add_argument("--seed", type=int, default=42)
26
+ self.parser.add_argument("--model", type=str, choices=list(FLAV_models.keys()), default="FLAV-XL/2")
27
+ self.parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
28
+ self.parser.add_argument("--use_sd_vae", action="store_true")
29
+ self.parser.add_argument("--vocoder-ckpt", type=str, default="vocoder/")
30
+ self.parser.add_argument("--optimizer-wd", type=float, default=0.02)
31
+
32
+ # Resources
33
+ self.parser.add_argument("--batch-size", type=int, default=4)
34
+ self.parser.add_argument("--num-workers", type=int, default=32)
35
+ self.parser.add_argument("--log-every", type=int, default=100)
36
+
37
+ # Config
38
+ self.parser.add_argument("--load-config", action="store_true")
39
+ self.parser.add_argument("--config-no-save", action="store_true")
40
+ self.parser.add_argument("--config-path", type=str, default="")
41
+ self.parser.add_argument("--config-name", type=str, default="config.json")
42
+
43
+ # Architecture
44
+ self.parser.add_argument("--causal-attn", action="store_true")
45
+
46
+ #RF
47
+ self.parser.add_argument("--num_timesteps", type=int, default=2)
48
+
49
+ def get_parser(self):
50
+ return self.parser
converter.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ import math
7
+ import os
8
+ import random
9
+ import torch
10
+ import json
11
+ import torch.utils.data
12
+ import numpy as np
13
+ import librosa
14
+ from librosa.util import normalize
15
+ from scipy.io.wavfile import read
16
+ from librosa.filters import mel as librosa_mel_fn
17
+
18
+ import torch.nn.functional as F
19
+ import torch.nn as nn
20
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
21
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
22
+
23
+
24
+ def normalize(images):
25
+ """
26
+ Normalize an image array to [-1,1].
27
+ """
28
+ if images.min() >= 0:
29
+ return 2.0 * images - 1.0
30
+ else:
31
+ return images
32
+
33
+ def denormalize(images):
34
+ """
35
+ Denormalize an image array to [0,1].
36
+ """
37
+ if images.min() < 0:
38
+ return (images / 2 + 0.5).clamp(0, 1)
39
+ else:
40
+ return images.clamp(0, 1)
41
+
42
+
43
+ MAX_WAV_VALUE = 32768.0
44
+
45
+
46
+ def load_wav(full_path):
47
+ sampling_rate, data = read(full_path)
48
+ return data, sampling_rate
49
+
50
+
51
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
52
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
53
+
54
+
55
+ def dynamic_range_decompression(x, C=1):
56
+ return np.exp(x) / C
57
+
58
+
59
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
60
+ return torch.log(torch.clamp(x, min=clip_val) * C)
61
+
62
+
63
+ def dynamic_range_decompression_torch(x, C=1):
64
+ return torch.exp(x) / C
65
+
66
+
67
+ def spectral_normalize_torch(magnitudes):
68
+ output = dynamic_range_compression_torch(magnitudes)
69
+ return output
70
+
71
+
72
+ def spectral_de_normalize_torch(magnitudes):
73
+ output = dynamic_range_decompression_torch(magnitudes)
74
+ return output
75
+
76
+
77
+ mel_basis = {}
78
+ hann_window = {}
79
+
80
+
81
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
82
+ # if torch.min(y) < -1.:
83
+ # print('min value is ', torch.min(y))
84
+ # if torch.max(y) > 1.:
85
+ # print('max value is ', torch.max(y))
86
+
87
+ global mel_basis, hann_window
88
+ if fmax not in mel_basis:
89
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
90
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
91
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
92
+
93
+ y = torch.nn.functional.pad(y, (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
94
+ y = y.squeeze(1)
95
+
96
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
97
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
98
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
99
+ spec = torch.view_as_real(spec)
100
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
101
+
102
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
103
+ spec = spectral_normalize_torch(spec)
104
+
105
+ return spec
106
+
107
+
108
+ def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
109
+ global hann_window
110
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
111
+
112
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
113
+ y = y.squeeze(1)
114
+
115
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
116
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
117
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
118
+ spec = torch.view_as_real(spec)
119
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
120
+
121
+ return spec
122
+
123
+
124
+ def normalize_spectrogram(
125
+ spectrogram: torch.Tensor,
126
+ max_value: float = 200,
127
+ min_value: float = 1e-5,
128
+ power: float = 1.,
129
+ inverse: bool = False
130
+ ) -> torch.Tensor:
131
+
132
+ # Rescale to 0-1
133
+ max_value = np.log(max_value) # 5.298317366548036
134
+ min_value = np.log(min_value) # -11.512925464970229
135
+
136
+ assert spectrogram.max() <= max_value and spectrogram.min() >= min_value
137
+
138
+ data = (spectrogram - min_value) / (max_value - min_value)
139
+
140
+ # Invert
141
+ if inverse:
142
+ data = 1 - data
143
+
144
+ # Apply the power curve
145
+ data = torch.pow(data, power)
146
+
147
+ return data
148
+
149
+
150
+
151
+ def denormalize_spectrogram(
152
+ data: torch.Tensor,
153
+ max_value: float = 200,
154
+ min_value: float = 1e-5,
155
+ power: float = 1,
156
+ ) -> torch.Tensor:
157
+
158
+ max_value = np.log(max_value)
159
+ min_value = np.log(min_value)
160
+
161
+ # Reverse the power curve
162
+ data = torch.pow(data, 1 / power)
163
+
164
+ # Rescale to max value
165
+ spectrogram = data * (max_value - min_value) + min_value
166
+
167
+ return spectrogram
168
+
169
+
170
+ def get_mel_spectrogram_from_audio(audio, device="cuda"):
171
+ audio = audio / MAX_WAV_VALUE
172
+ audio = librosa.util.normalize(audio) * 0.95
173
+
174
+ audio = torch.FloatTensor(audio)
175
+ audio = audio.unsqueeze(0)
176
+
177
+ waveform = audio
178
+ spec = mel_spectrogram(waveform, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
179
+ return audio, spec
180
+
181
+
182
+
183
+ LRELU_SLOPE = 0.1
184
+ MAX_WAV_VALUE = 32768.0
185
+
186
+
187
+ class AttrDict(dict):
188
+ def __init__(self, *args, **kwargs):
189
+ super(AttrDict, self).__init__(*args, **kwargs)
190
+ self.__dict__ = self
191
+
192
+
193
+ def get_config(config_path):
194
+ config = json.loads(open(config_path).read())
195
+ config = AttrDict(config)
196
+ return config
197
+
198
+ def init_weights(m, mean=0.0, std=0.01):
199
+ classname = m.__class__.__name__
200
+ if classname.find("Conv") != -1:
201
+ m.weight.data.normal_(mean, std)
202
+
203
+
204
+ def apply_weight_norm(m):
205
+ classname = m.__class__.__name__
206
+ if classname.find("Conv") != -1:
207
+ weight_norm(m)
208
+
209
+
210
+ def get_padding(kernel_size, dilation=1):
211
+ return int((kernel_size*dilation - dilation)/2)
212
+
213
+
214
+ class ResBlock1(torch.nn.Module):
215
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
216
+ super(ResBlock1, self).__init__()
217
+ self.h = h
218
+ self.convs1 = nn.ModuleList([
219
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
220
+ padding=get_padding(kernel_size, dilation[0]))),
221
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
222
+ padding=get_padding(kernel_size, dilation[1]))),
223
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
224
+ padding=get_padding(kernel_size, dilation[2])))
225
+ ])
226
+ self.convs1.apply(init_weights)
227
+
228
+ self.convs2 = nn.ModuleList([
229
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
230
+ padding=get_padding(kernel_size, 1))),
231
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
232
+ padding=get_padding(kernel_size, 1))),
233
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
234
+ padding=get_padding(kernel_size, 1)))
235
+ ])
236
+ self.convs2.apply(init_weights)
237
+
238
+ def forward(self, x):
239
+ for c1, c2 in zip(self.convs1, self.convs2):
240
+ xt = F.leaky_relu(x, LRELU_SLOPE)
241
+ xt = c1(xt)
242
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
243
+ xt = c2(xt)
244
+ x = xt + x
245
+ return x
246
+
247
+ def remove_weight_norm(self):
248
+ for l in self.convs1:
249
+ remove_weight_norm(l)
250
+ for l in self.convs2:
251
+ remove_weight_norm(l)
252
+
253
+
254
+ class ResBlock2(torch.nn.Module):
255
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
256
+ super(ResBlock2, self).__init__()
257
+ self.h = h
258
+ self.convs = nn.ModuleList([
259
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
260
+ padding=get_padding(kernel_size, dilation[0]))),
261
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
262
+ padding=get_padding(kernel_size, dilation[1])))
263
+ ])
264
+ self.convs.apply(init_weights)
265
+
266
+ def forward(self, x):
267
+ for c in self.convs:
268
+ xt = F.leaky_relu(x, LRELU_SLOPE)
269
+ xt = c(xt)
270
+ x = xt + x
271
+ return x
272
+
273
+ def remove_weight_norm(self):
274
+ for l in self.convs:
275
+ remove_weight_norm(l)
276
+
277
+
278
+
279
+ class Generator(torch.nn.Module):
280
+ def __init__(self, h):
281
+ super(Generator, self).__init__()
282
+ self.h = h
283
+ self.num_kernels = len(h.resblock_kernel_sizes)
284
+ self.num_upsamples = len(h.upsample_rates)
285
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) # change: 80 --> 512
286
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
287
+
288
+ self.ups = nn.ModuleList()
289
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
290
+ if (k-u) % 2 == 0:
291
+ self.ups.append(weight_norm(
292
+ ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
293
+ k, u, padding=(k-u)//2)))
294
+ else:
295
+ self.ups.append(weight_norm(
296
+ ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
297
+ k, u, padding=(k-u)//2+1, output_padding=1)))
298
+
299
+ # self.ups.append(weight_norm(
300
+ # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
301
+ # k, u, padding=(k-u)//2)))
302
+
303
+
304
+ self.resblocks = nn.ModuleList()
305
+ for i in range(len(self.ups)):
306
+ ch = h.upsample_initial_channel//(2**(i+1))
307
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
308
+ self.resblocks.append(resblock(h, ch, k, d))
309
+
310
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
311
+ self.ups.apply(init_weights)
312
+ self.conv_post.apply(init_weights)
313
+
314
+ def forward(self, x):
315
+ x = self.conv_pre(x)
316
+ for i in range(self.num_upsamples):
317
+ x = F.leaky_relu(x, LRELU_SLOPE)
318
+ x = self.ups[i](x)
319
+ xs = None
320
+ for j in range(self.num_kernels):
321
+ if xs is None:
322
+ xs = self.resblocks[i*self.num_kernels+j](x)
323
+ else:
324
+ xs += self.resblocks[i*self.num_kernels+j](x)
325
+ x = xs / self.num_kernels
326
+ x = F.leaky_relu(x)
327
+ x = self.conv_post(x)
328
+ x = torch.tanh(x)
329
+
330
+ return x
331
+
332
+ def remove_weight_norm(self):
333
+ for l in self.ups:
334
+ remove_weight_norm(l)
335
+ for l in self.resblocks:
336
+ l.remove_weight_norm()
337
+ remove_weight_norm(self.conv_pre)
338
+ remove_weight_norm(self.conv_post)
339
+
340
+ @classmethod
341
+ def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
342
+ if subfolder is not None:
343
+ pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)
344
+ config_path = os.path.join(pretrained_model_name_or_path, "config.json")
345
+ ckpt_path = os.path.join(pretrained_model_name_or_path, "vocoder.pt")
346
+
347
+ config = get_config(config_path)
348
+ vocoder = cls(config)
349
+
350
+ state_dict_g = torch.load(ckpt_path)
351
+ vocoder.load_state_dict(state_dict_g["generator"])
352
+ vocoder.eval()
353
+ vocoder.remove_weight_norm()
354
+ return vocoder
355
+
356
+
357
+ @torch.no_grad()
358
+ def inference(self, mels, lengths=None):
359
+ self.eval()
360
+ with torch.no_grad():
361
+ wavs = self(mels).squeeze(1)
362
+
363
+ wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype("int16")
364
+
365
+ if lengths is not None:
366
+ wavs = wavs[:, :lengths]
367
+
368
+ return wavs
dataset.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
2
+
3
+ import os.path as osp
4
+ import math
5
+ import pickle
6
+ import warnings
7
+
8
+ import glob
9
+
10
+ import torch.utils.data as data
11
+ import torch.nn.functional as F
12
+ from torchvision.datasets.video_utils import VideoClips
13
+ from converter import normalize, normalize_spectrogram, get_mel_spectrogram_from_audio
14
+ from torchaudio import transforms as Ta
15
+ from torchvision import transforms as Tv
16
+ from torchvision.io.video import read_video
17
+ import torch
18
+ from torchvision.transforms import InterpolationMode
19
+
20
+ class LatentDataset(data.Dataset):
21
+ """ Generic dataset for latents pregenerated from a dataset
22
+ Returns a dictionary of latents encoded from the original dataset """
23
+ exts = ['pt']
24
+
25
+ def __init__(self, data_folder, train=True):
26
+ """
27
+ Args:
28
+ data_folder: path to the folder with videos. The folder
29
+ should contain a 'train' and a 'test' directory,
30
+ each with corresponding videos stored
31
+ """
32
+ super().__init__()
33
+ self.train = train
34
+
35
+ folder = osp.join(data_folder, 'train' if train else 'test')
36
+ self.files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
37
+ for ext in self.exts], [])
38
+
39
+ warnings.filterwarnings('ignore')
40
+
41
+ def __len__(self):
42
+ return len(self.files)
43
+
44
+ def __getitem__(self, idx):
45
+ while True:
46
+ try:
47
+ latents = torch.load(self.files[idx], map_location="cpu")
48
+ except Exception as e:
49
+ print(f"Dataset Exception: {e}")
50
+ idx = (idx + 1) % len(self.files)
51
+ continue
52
+ break
53
+
54
+ return latents["video"], latents["audio"], latents["y"]
55
+ class AudioVideoDataset(data.Dataset):
56
+ """ Generic dataset for videos files stored in folders
57
+ Returns BCTHW videos in the range [-0.5, 0.5] """
58
+ exts = ['avi', 'mp4', 'webm']
59
+
60
+ def __init__(self, data_folder, train=True, resolution=64, sample_every_n_frames=1, sequence_length=8, audio_channels=1, sample_rate=16000, min_length=1, ignore_cache=False, labeled=True, target_video_fps=10):
61
+ """
62
+ Args:
63
+ data_folder: path to the folder with videos. The folder
64
+ should contain a 'train' and a 'test' directory,
65
+ each with corresponding videos stored
66
+ sequence_length: length of extracted video sequences
67
+ """
68
+ super().__init__()
69
+ self.train = train
70
+ self.sequence_length = sequence_length
71
+ self.resolution = resolution
72
+ self.sample_every_n_frames = sample_every_n_frames
73
+ self.audio_channels = audio_channels
74
+ self.sample_rate = sample_rate
75
+ self.min_length = min_length
76
+ self.labeled = labeled
77
+
78
+
79
+ folder = osp.join(data_folder, 'train' if train else 'test')
80
+ files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
81
+ for ext in self.exts], [])
82
+
83
+ # hacky way to compute # of classes (count # of unique parent directories)
84
+ self.classes = list(set([get_parent_dir(f) for f in files]))
85
+ self.classes.sort()
86
+ self.class_to_label = {c: i for i, c in enumerate(self.classes)}
87
+
88
+ warnings.filterwarnings('ignore')
89
+ cache_file = osp.join(folder, f"metadata_{self.sequence_length}.pkl")
90
+ if not osp.exists(cache_file) or ignore_cache or True:
91
+ clips = VideoClips(files, self.sequence_length, num_workers=32, frame_rate=target_video_fps)
92
+ # pickle.dump(clips.metadata, open(cache_file, 'wb'))
93
+ else:
94
+ metadata = pickle.load(open(cache_file, 'rb'))
95
+ clips = VideoClips(files, self.sequence_length,
96
+ _precomputed_metadata=metadata)
97
+
98
+ # self._clips = clips.subset(np.arange(24))
99
+ self._clips = clips
100
+
101
+ @property
102
+ def n_classes(self):
103
+ return len(self.classes)
104
+
105
+ def __len__(self):
106
+ return self._clips.num_clips()
107
+
108
+ def __getitem__(self, idx):
109
+ resolution = self.resolution
110
+ while True:
111
+ try:
112
+ video, _, info, _ = self._clips.get_clip(idx)
113
+ except Exception:
114
+ idx = (idx + 1) % self._clips.num_clips()
115
+ continue
116
+ break
117
+
118
+ return preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames), self.get_audio(info, idx), self.get_label(idx)
119
+
120
+ def get_label(self, idx):
121
+ if not self.labeled:
122
+ return -1
123
+ video_idx, clip_idx = self._clips.get_clip_location(idx)
124
+ class_name = get_parent_dir(self._clips.video_paths[video_idx])
125
+ label = self.class_to_label[class_name]
126
+ return label
127
+
128
+ def get_audio(self, info, idx):
129
+ video_idx, clip_idx = self._clips.get_clip_location(idx)
130
+
131
+ video_path = self._clips.video_paths[video_idx]
132
+ video_fps = self._clips.video_fps[video_idx]
133
+
134
+ duration_per_frame = self._clips.video_pts[video_idx][1] - self._clips.video_pts[video_idx][0]
135
+ clip_pts = self._clips.clips[video_idx][clip_idx]
136
+ clip_pid = clip_pts // duration_per_frame
137
+
138
+ start_t = (clip_pid[0] / video_fps * 1. ).item()
139
+ end_t = ((clip_pid[-1] + 1) / video_fps * 1. ).item()
140
+
141
+ _, raw_audio, _ = read_video(video_path,start_t, end_t, pts_unit='sec')
142
+ raw_audio = prepare_audio(raw_audio, info["audio_fps"], self.sample_rate, self.audio_channels, self.sequence_length, self.min_length)
143
+
144
+ _, spec = get_mel_spectrogram_from_audio(raw_audio[0].numpy())
145
+ norm_spec = normalize_spectrogram(spec)
146
+ norm_spec = normalize(norm_spec) # normalize to [-1, 1], because pipeline do not normalize for torch.Tensor input
147
+ norm_spec.unsqueeze(1) # add channel dimension
148
+ return norm_spec
149
+ #return raw_audio[0]
150
+
151
+
152
+ def get_parent_dir(path):
153
+ return osp.basename(osp.dirname(path))
154
+
155
+ def preprocess(video, resolution, sample_every_n_frames=1):
156
+ video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
157
+
158
+ old_size = video.shape[2:4]
159
+ ratio = min(float(resolution)/(old_size[0]), float(resolution)/(old_size[1]) )
160
+ new_size = tuple([int(i*ratio) for i in old_size])
161
+ pad_w = resolution - new_size[1]
162
+ pad_h = resolution- new_size[0]
163
+ top,bottom = pad_h//2, pad_h-(pad_h//2)
164
+ left,right = pad_w//2, pad_w -(pad_w//2)
165
+ transform = Tv.Compose([Tv.Resize(new_size, interpolation=InterpolationMode.BICUBIC), Tv.Pad((left, top, right, bottom))])
166
+ video_new = transform(video)
167
+
168
+ video_new = video_new*2-1
169
+
170
+ return video_new
171
+
172
+ def pad_crop_audio(audio, target_length):
173
+ target_length = int(target_length)
174
+ n, s = audio.shape
175
+ start = 0
176
+ end = start + target_length
177
+ output = audio.new_zeros([n, target_length])
178
+ output[:, :min(s, target_length)] = audio[:, start:end]
179
+ return output
180
+
181
+ def prepare_audio(audio, in_sr, target_sr, target_channels, sequence_length, min_length):
182
+ if in_sr != target_sr:
183
+ resample_tf = Ta.Resample(in_sr, target_sr)
184
+ audio = resample_tf(audio)
185
+
186
+ max_length = target_sr/10*sequence_length
187
+ target_length = max_length + (min_length - (max_length % min_length)) % min_length
188
+
189
+ audio = pad_crop_audio(audio, target_length)
190
+
191
+ audio = set_audio_channels(audio, target_channels)
192
+
193
+ return audio
194
+
195
+ def set_audio_channels(audio, target_channels):
196
+ if target_channels == 1:
197
+ # Convert to mono
198
+ # audio = audio.mean(0, keepdim=True)
199
+ audio = audio[:1, :]
200
+ elif target_channels == 2:
201
+ # Convert to stereo
202
+ if audio.shape[0] == 1:
203
+ audio = audio.repeat(2, 1)
204
+ elif audio.shape[0] > 2:
205
+ audio = audio[:2, :]
206
+ return audio
download.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Functions for downloading pre-trained DiT models
9
+ """
10
+ from torchvision.datasets.utils import download_url
11
+ import torch
12
+ import os
13
+
14
+
15
+ pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}
16
+
17
+
18
+ def find_model(model_name):
19
+ """
20
+ Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
21
+ """
22
+ if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
23
+ return download_model(model_name)
24
+ else: # Load a custom DiT checkpoint:
25
+ assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
26
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
27
+ if "ema" in checkpoint: # supports checkpoints from train.py
28
+ checkpoint = checkpoint["ema"]
29
+ return checkpoint
30
+
31
+
32
+ def download_model(model_name):
33
+ """
34
+ Downloads a pre-trained DiT model from the web.
35
+ """
36
+ assert model_name in pretrained_models
37
+ local_path = f'pretrained_models/{model_name}'
38
+ if not os.path.isfile(local_path):
39
+ os.makedirs('pretrained_models', exist_ok=True)
40
+ web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
41
+ download_url(web_path, 'pretrained_models')
42
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
43
+ return model
44
+
45
+
46
+ if __name__ == "__main__":
47
+ # Download all DiT checkpoints
48
+ for model in pretrained_models:
49
+ download_model(model)
50
+ print('Done.')
models.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ import math
16
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17
+ import einops
18
+
19
+ import torch.utils.checkpoint as checkpoint
20
+
21
+ from transformers import PreTrainedModel
22
+ import random
23
+
24
+ class MelPatchEmbed(nn.Module):
25
+ """ Image to Patch Embedding
26
+ """
27
+ def __init__(self, n_mels, n_frames, patch_size=16, in_chans=1, embed_dim=768):
28
+ super().__init__()
29
+ num_patches = (n_mels // patch_size) * (n_frames // patch_size)
30
+ self.patch_size = patch_size
31
+ self.num_patches = int(num_patches)
32
+
33
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
34
+
35
+ def forward(self, x):
36
+ x = self.proj(x).flatten(2).transpose(1, 2)
37
+ return x
38
+
39
+ class SelfAttention(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ num_heads: int = 8,
44
+ qkv_bias: bool = False,
45
+ qk_norm: bool = False,
46
+ attn_drop: float = 0.,
47
+ proj_drop: float = 0.,
48
+ norm_layer: nn.Module = nn.LayerNorm,
49
+ is_causal: bool = False,
50
+ ) -> None:
51
+ super().__init__()
52
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
53
+ self.is_causal = is_causal
54
+ self.num_heads = num_heads
55
+ self.head_dim = dim // num_heads
56
+ self.scale = self.head_dim ** -0.5
57
+
58
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
59
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
60
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
61
+ self.attn_drop = nn.Dropout(attn_drop)
62
+ self.proj = nn.Linear(dim, dim)
63
+ self.proj_drop = nn.Dropout(proj_drop)
64
+
65
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
66
+ B, N, C = x.shape
67
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
68
+ q, k, v = qkv.unbind(0)
69
+ q, k = self.q_norm(q), self.k_norm(k)
70
+ x = torch.nn.functional.scaled_dot_product_attention(
71
+ q, k, v,
72
+ dropout_p=self.attn_drop.p if self.training else 0.,
73
+ is_causal=self.is_causal
74
+ )
75
+
76
+ x = x.transpose(1, 2).reshape(B, N, C)
77
+ x = self.proj(x)
78
+ x = self.proj_drop(x)
79
+ return x
80
+
81
+ class CrossAttention(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim,
85
+ num_heads=8,
86
+ qkv_bias=False,
87
+ attn_drop=0.,
88
+ proj_drop=0.,
89
+ mask_attn=False,
90
+ ):
91
+ super().__init__()
92
+ self.mask_attn = mask_attn
93
+ self.num_heads = num_heads
94
+ head_dim = dim // num_heads
95
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
96
+ self.scale = head_dim ** -0.5
97
+
98
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
99
+ self.wkv = nn.Linear(dim, dim*2, bias=qkv_bias)
100
+ self.attn_drop = nn.Dropout(attn_drop)
101
+ self.proj = nn.Linear(dim, dim)
102
+ self.proj_drop = nn.Dropout(proj_drop)
103
+
104
+ def forward(self, x, cond):
105
+ B, N, C = x.shape
106
+
107
+ q = self.wq(x)
108
+ q = einops.rearrange(q, 'B N (H D) -> B H N D', H=self.num_heads)
109
+
110
+ kv = self.wkv(cond) # BMD
111
+ kv = einops.rearrange(kv, 'B N (K H D) ->K B H N D', H=self.num_heads, K=2)
112
+ k = kv[0]
113
+ v = kv[1]
114
+
115
+
116
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
117
+
118
+ x = einops.rearrange(x, 'B H N D -> B N (H D)')
119
+ x = self.proj(x)
120
+ x = self.proj_drop(x)
121
+ return x
122
+
123
+ def modulate(x, shift, scale):
124
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
125
+
126
+ def temporalModulate(x, shift, scale):
127
+ """
128
+ Modulate the input tensor x with the given shift and scale tensors.
129
+ :param x: the input tensor to modulate with shape (B, T, L, D).
130
+ :param shift: the shift tensor with shape (B, T, D).
131
+ :param scale: the scale tensor with shape (B, T, D).
132
+ """
133
+ return x * (1 + scale.unsqueeze(2)) + shift.unsqueeze(2)
134
+
135
+
136
+ #################################################################################
137
+ # Embedding Layers for Timesteps and Class Labels #
138
+ #################################################################################
139
+
140
+ class TimestepEmbedder(nn.Module):
141
+ """
142
+ Embeds scalar timesteps into vector representations.
143
+ """
144
+ def __init__(self, hidden_size, frequency_embedding_size=256):
145
+ super().__init__()
146
+ self.mlp = nn.Sequential(
147
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
148
+ nn.SiLU(),
149
+ nn.Linear(hidden_size, hidden_size, bias=True),
150
+ )
151
+ self.frequency_embedding_size = frequency_embedding_size
152
+
153
+ @staticmethod
154
+ def timestep_embedding(t, dim, max_period=10000):
155
+ """
156
+ Create sinusoidal timestep embeddings.
157
+ :param t: a 1-D Tensor of N indices, one per batch element.
158
+ These may be fractional.
159
+ :param dim: the dimension of the output.
160
+ :param max_period: controls the minimum frequency of the embeddings.
161
+ :return: an (N, D) Tensor of positional embeddings.
162
+ """
163
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
164
+ half = dim // 2
165
+ freqs = torch.exp(
166
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
167
+ ).to(device=t.device)
168
+ args = t[:, None].float() * freqs[None]
169
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
170
+ if dim % 2:
171
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
172
+ return embedding
173
+
174
+ def forward(self, t):
175
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
176
+ t_emb = self.mlp(t_freq)
177
+ return t_emb
178
+
179
+
180
+ class AudioEmbedder(nn.Module):
181
+ """
182
+ Embeds scalar timesteps into vector representations.
183
+ """
184
+ def __init__(self, n_mels, hidden_size):
185
+ super().__init__()
186
+ self.mlp = nn.Sequential(
187
+ nn.Linear(n_mels, hidden_size, bias=True),
188
+ nn.SiLU(),
189
+ nn.Linear(hidden_size, hidden_size, bias=True),
190
+ )
191
+
192
+ # TODO: Activation?
193
+
194
+ def forward(self, a):
195
+ a = self.mlp(a)
196
+ return a
197
+
198
+ def init_weights(self):
199
+ nn.init.xavier_uniform_(self.mlp[0].weight)
200
+ nn.init.constant_(self.mlp[0].bias, 0)
201
+ nn.init.xavier_uniform_(self.mlp[2].weight)
202
+ nn.init.constant_(self.mlp[2].bias, 0)
203
+
204
+ class LabelEmbedder(nn.Module):
205
+ """
206
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
207
+ """
208
+ def __init__(self, num_classes, hidden_size, dropout_prob):
209
+ super().__init__()
210
+ use_cfg_embedding = dropout_prob > 0
211
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
212
+ self.num_classes = num_classes
213
+ self.dropout_prob = dropout_prob
214
+
215
+ def token_drop(self, labels, force_drop_ids=None):
216
+ """
217
+ Drops labels to enable classifier-free guidance.
218
+ """
219
+ if force_drop_ids is None:
220
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
221
+ else:
222
+ drop_ids = force_drop_ids == 1
223
+ labels = torch.where(drop_ids, self.num_classes, labels)
224
+ return labels
225
+
226
+ def forward(self, labels, train, force_drop_ids=None):
227
+ use_dropout = self.dropout_prob > 0
228
+ if (train and use_dropout) or (force_drop_ids is not None):
229
+ labels = self.token_drop(labels, force_drop_ids)
230
+ embeddings = self.embedding_table(labels)
231
+ return embeddings
232
+
233
+
234
+ #################################################################################
235
+ # Core FLAV Model #
236
+ #################################################################################
237
+
238
+ class FLAVBlock(nn.Module):
239
+ """
240
+ A FLAV block with adaptive layer norm zero (adaLN-Zero) conditioning.
241
+ """
242
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, grad_ckpt=False, causal_attn=False, **block_kwargs):
243
+ super().__init__()
244
+
245
+ self.video_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
247
+ # self.video_audio_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
248
+ self.video_spatial_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
249
+ self.video_temporal_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, is_causal=causal_attn, **block_kwargs)
250
+
251
+ self.audio_spatial_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, is_causal=causal_attn, **block_kwargs)
252
+ # self.audio_temporal_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, is_causal=causal_attn, **block_kwargs)
253
+
254
+ self.video_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
255
+ self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
256
+
257
+ self.video_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
258
+ self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
259
+
260
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
261
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
262
+ self.video_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
263
+ self.audio_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
264
+
265
+ self.video_adaLN_modulation = nn.Sequential(
266
+ nn.SiLU(),
267
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
268
+ )
269
+
270
+ self.audio_adaLN_modulation = nn.Sequential(
271
+ nn.SiLU(),
272
+ nn.Linear(hidden_size, 3 * hidden_size, bias=True)
273
+ )
274
+
275
+ self.video_scale = nn.Sequential(
276
+ nn.SiLU(),
277
+ nn.Linear(hidden_size, 3 * hidden_size, bias=True)
278
+ )
279
+
280
+ self.audio_scale = nn.Sequential(
281
+ nn.SiLU(),
282
+ nn.Linear(hidden_size, 3 * hidden_size, bias=True)
283
+ )
284
+
285
+ self.v_avg_proj = nn.Sequential(
286
+ nn.Linear(hidden_size, hidden_size, bias=True),
287
+ )
288
+ self.a_avg_proj = nn.Sequential(
289
+ nn.Linear(hidden_size, hidden_size, bias=True),
290
+ )
291
+
292
+
293
+
294
+ self.grad_ckpt = grad_ckpt
295
+
296
+ def forward(self,v, a, v_c, a_c):
297
+ if self.grad_ckpt:
298
+ return checkpoint.checkpoint(self._forward, v, a, v_c, a_c, use_reentrant=False)
299
+ else:
300
+ return self._forward(v, a, v_c, a_c)
301
+
302
+ def _forward(self, v, a, v_c, a_c):
303
+ """
304
+ v: Size of (B, T, Lv, D)
305
+ a: Size of (B, T, La, D)
306
+ v_c: Size of (B, T, D)
307
+ a_c: Size of (B, T, D)
308
+ """
309
+
310
+ video_shift_msa, video_scale_msa, video_gate_msa, video_shift_tmsa, video_scale_tmsa, video_gate_tmsa = self.video_adaLN_modulation(v_c).chunk(6, dim=-1)
311
+ # audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_tmsa, audio_scale_tmsa, audio_gate_tmsa = self.audio_adaLN_modulation(a_c).chunk(6, dim=-1)
312
+ audio_shift_msa, audio_scale_msa, audio_gate_msa = self.audio_adaLN_modulation(a_c).chunk(3, dim=-1)
313
+ B, T, L, D = v.shape
314
+
315
+ v_att = temporalModulate(self.video_norm1(v), video_shift_msa, video_scale_msa)
316
+ v_att = einops.rearrange(v_att, 'B T L D -> (B T) L D')
317
+ v_att = v + video_gate_msa.unsqueeze(2)*(self.video_spatial_attn(v_att).view(B, T, L, D))
318
+
319
+ v = v_att
320
+
321
+ v_att = temporalModulate(self.video_norm2(v_att), video_shift_tmsa, video_scale_tmsa)
322
+ v_att = einops.rearrange(v_att, 'B T L D -> (B L) T D', T=T)
323
+ v_att = einops.rearrange(self.video_temporal_attn(v_att), "(B L) T D -> B T L D", B=B)
324
+ v = v + video_gate_tmsa.unsqueeze(2)*v_att
325
+
326
+ a_att = temporalModulate(self.audio_norm1(a), audio_shift_msa, audio_scale_msa)
327
+ a_att = einops.rearrange(a_att, 'B T L D -> B (T L) D')
328
+ a_att = a + audio_gate_msa.unsqueeze(2)*(self.audio_spatial_attn(a_att).view(B, T, -1, D))
329
+
330
+ a = a_att
331
+
332
+ a_avg = self.a_avg_proj(a.mean(dim=2)) # B T D
333
+ v_avg = self.v_avg_proj(v.mean(dim=2)) # B T D
334
+
335
+ v_avg += a_c
336
+ a_avg += v_c
337
+
338
+ scale_v, shift_v, gate_v = self.video_scale(a_avg).chunk(3, dim=-1)
339
+ scale_a, shift_a, gate_a = self.audio_scale(v_avg).chunk(3, dim=-1)
340
+
341
+
342
+ v = v + gate_v.unsqueeze(2) * self.video_mlp(temporalModulate(self.video_norm3(v), shift_v, scale_v))
343
+ a = a + gate_a.unsqueeze(2) * self.audio_mlp(temporalModulate(self.audio_norm3(a), shift_a, scale_a))
344
+
345
+ return v, a
346
+
347
+ def _spatial_attn(self, x, b_size, attn_func):
348
+ x = einops.rearrange(x, "(B N) T D -> (B T) N D", B=b_size)
349
+ x = attn_func(x)
350
+ x = einops.rearrange(x, "(B T) N D -> (B N) T D", B=b_size)
351
+ return x
352
+
353
+
354
+ class FinalLayer(nn.Module):
355
+ """
356
+ The final layer of FLAV.
357
+ """
358
+ def __init__(self, hidden_size, patch_size, out_channels):
359
+ super().__init__()
360
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
361
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
362
+ self.adaLN_modulation = nn.Sequential(
363
+ nn.SiLU(),
364
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
365
+ )
366
+
367
+ def forward(self, x, c):
368
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
369
+ x = temporalModulate(self.norm_final(x), shift, scale)
370
+ x = self.linear(x)
371
+ return x
372
+
373
+
374
+ class FLAV(nn.Module):
375
+ """
376
+ Diffusion model with a Transformer backbone.
377
+ """
378
+ def __init__(
379
+ self,
380
+ latent_size=None,
381
+ patch_size=2,
382
+ in_channels=4,
383
+ hidden_size=1152,
384
+ depth=28,
385
+ num_heads=16,
386
+ mlp_ratio=4.0,
387
+ class_dropout_prob=0.1,
388
+ num_classes=1000,
389
+ predict_frames = 1,
390
+ grad_ckpt = False,
391
+ n_mels=256,
392
+ audio_fr = 16000,
393
+ causal_attn = False,
394
+ ):
395
+ super().__init__()
396
+ self.in_channels = in_channels
397
+ self.out_channels = in_channels
398
+ self.patch_size = patch_size
399
+ self.num_heads = num_heads
400
+ self.predict_frames = predict_frames
401
+ self.grad_ckpt = grad_ckpt
402
+ self.n_mels = n_mels
403
+ self.audio_fr = audio_fr
404
+ self.latent_size = latent_size # T H W
405
+
406
+ self.num_classes = num_classes
407
+
408
+ self.v_embedder = PatchEmbed(latent_size, patch_size, in_channels, hidden_size, bias=True)
409
+ self.a_embedder = nn.Linear(n_mels, hidden_size, bias=True)
410
+
411
+ self.video_t_embedder = TimestepEmbedder(hidden_size)
412
+ self.audio_t_embedder = TimestepEmbedder(hidden_size)
413
+
414
+ if self.num_classes > 0:
415
+ self.video_y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
416
+ self.audio_y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
417
+
418
+ num_patches = self.v_embedder.num_patches
419
+ self.video_spatial_pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches, hidden_size), requires_grad=True)
420
+ self.video_temporal_pos_embed = nn.Parameter(torch.zeros(1, self.predict_frames, 1, hidden_size), requires_grad=True)
421
+
422
+ self.audio_spatial_pos_embed = nn.Parameter(torch.zeros(1, 1, 10, hidden_size), requires_grad=True)
423
+ self.audio_temporal_pos_embed = nn.Parameter(torch.zeros(1, self.predict_frames, 1, hidden_size), requires_grad=True)
424
+
425
+
426
+ self.blocks = nn.ModuleList([
427
+ FLAVBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, grad_ckpt=grad_ckpt, causal_attn=causal_attn) for _ in range(depth)
428
+ ])
429
+
430
+ self.video_final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
431
+ self.audio_final_layer = FinalLayer(hidden_size, 1, n_mels)
432
+ self.initialize_weights()
433
+
434
+ def initialize_weights(self):
435
+ # Initialize transformer layers:
436
+ def _basic_init(module):
437
+ if isinstance(module, nn.Linear):
438
+ torch.nn.init.xavier_uniform_(module.weight)
439
+ if module.bias is not None:
440
+ nn.init.constant_(module.bias, 0)
441
+ self.apply(_basic_init)
442
+
443
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
444
+ w = self.v_embedder.proj.weight.data
445
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
446
+ nn.init.constant_(self.v_embedder.proj.bias, 0)
447
+
448
+
449
+ if self.num_classes > 0:
450
+ nn.init.normal_(self.video_y_embedder.embedding_table.weight, std=0.02)
451
+ nn.init.normal_(self.audio_y_embedder.embedding_table.weight, std=0.02)
452
+
453
+ # Initialize timestep embedding MLP:
454
+ nn.init.normal_(self.video_t_embedder.mlp[0].weight, std=0.02)
455
+ nn.init.normal_(self.video_t_embedder.mlp[2].weight, std=0.02)
456
+
457
+ nn.init.normal_(self.audio_t_embedder.mlp[0].weight, std=0.02)
458
+ nn.init.normal_(self.audio_t_embedder.mlp[2].weight, std=0.02)
459
+
460
+ # Zero-out adaLN modulation layers in FLAV blocks:
461
+ for block in self.blocks:
462
+ nn.init.constant_(block.video_adaLN_modulation[-1].weight, 0)
463
+ nn.init.constant_(block.video_adaLN_modulation[-1].bias, 0)
464
+ nn.init.constant_(block.audio_adaLN_modulation[-1].weight, 0)
465
+ nn.init.constant_(block.audio_adaLN_modulation[-1].bias, 0)
466
+
467
+ nn.init.constant_(block.video_scale[-1].weight, 0)
468
+ nn.init.constant_(block.video_scale[-1].bias, 0)
469
+
470
+ nn.init.constant_(block.audio_scale[-1].weight, 0)
471
+ nn.init.constant_(block.audio_scale[-1].bias, 0)
472
+
473
+ # Zero-out output layers:
474
+ nn.init.constant_(self.video_final_layer.adaLN_modulation[-1].weight, 0)
475
+ nn.init.constant_(self.video_final_layer.adaLN_modulation[-1].bias, 0)
476
+ nn.init.constant_(self.video_final_layer.linear.weight, 0)
477
+ nn.init.constant_(self.video_final_layer.linear.bias, 0)
478
+
479
+ nn.init.constant_(self.audio_final_layer.adaLN_modulation[-1].weight, 0)
480
+ nn.init.constant_(self.audio_final_layer.adaLN_modulation[-1].bias, 0)
481
+ nn.init.constant_(self.audio_final_layer.linear.weight, 0)
482
+ nn.init.constant_(self.audio_final_layer.linear.bias, 0)
483
+
484
+ def unpatchify(self, x):
485
+ """
486
+ x: (N, T, patch_size**2 * C)
487
+ imgs: (N, C, H, W)
488
+ """
489
+ c = self.out_channels
490
+ p = self.v_embedder.patch_size[0]
491
+ h = w = int(x.shape[1] ** 0.5)
492
+ assert h * w == x.shape[1]
493
+
494
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
495
+ x = torch.einsum('nhwpqc->nchpwq', x)
496
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
497
+ return imgs
498
+
499
+ def _apply_rnd_mask(self, input, mask, device="cuda"):
500
+ input_rnd = torch.rand(input[0].shape).unsqueeze(0).to(device=device)*2 - 1
501
+ return self._apply_mask(input, mask, input_rnd)
502
+
503
+ def _apply_zero_mask(self, input, mask, device="cuda"):
504
+ input_zero= torch.zeros(input[0].shape).unsqueeze(0).to(device=device)
505
+ return self._apply_mask(input, mask, input_zero)
506
+
507
+ def _get_frames_mask(self, bs):
508
+ """
509
+ bs: batch size
510
+
511
+ returns a boolean mask to be applied to condition frames
512
+ to mask a selected number of random frames
513
+ """
514
+ fmask = np.full(self.cond_frames*bs, False)
515
+ frames = list(range(self.cond_frames))
516
+ for b in range(bs):
517
+ if random.randint(0, 100) < self.mask_freq:
518
+ sub_frames = random.sample(frames, min(self.cond_frames, self.frames_to_mask))
519
+ idxs = [f+(b*self.cond_frames) for f in sub_frames]
520
+ fmask[idxs] = True
521
+ return fmask
522
+
523
+ def _get_batch_mask(self, bs):
524
+ """
525
+ bs: batch size
526
+
527
+ returns a boolean mask to be applied to condition frames
528
+ to mask a random number of condition sequences in a batch
529
+ """
530
+ rnd = np.random.rand(bs)
531
+ bmask= rnd < self.batch_mask_freq/100
532
+ bmask = np.repeat(bmask, self.cond_frames)
533
+ return bmask
534
+
535
+ def _apply_mask(self, input, mask, values):
536
+ input[mask] = values
537
+ return input
538
+
539
+ def audio_unpatchify(self, x):
540
+ """
541
+ x: (N, T, patch_size * C)
542
+ audio: (N, N_mels, frames)
543
+ """
544
+ c = 1
545
+ p = self.audio_patch_size
546
+ h = int(self.n_mels//p)
547
+ w = int((self.audio_fr/1600)/p)
548
+
549
+
550
+ assert h * w == x.shape[1]
551
+
552
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
553
+ x = torch.einsum('nhwpqc->nchpwq', x)
554
+ audio = x.reshape(shape=(x.shape[0], c, h * p, w * p))
555
+ return audio
556
+
557
+ def forward(self, v, a, t, y):
558
+ """
559
+ Forward pass of FLAV.
560
+ v: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
561
+ a: (B, 1, n_bins, T) # mel spectrogram of audio
562
+ t: (B, T) tensor of diffusion timesteps
563
+ y: (B,) tensor of class labels
564
+ """
565
+
566
+ ### Video
567
+ B, T, C, H, W = v.shape
568
+ v = einops.rearrange(v, 'B T C H W -> (B T) C H W')
569
+ v = self.v_embedder(v)
570
+ v = einops.rearrange(v, '(B T) L D -> B T L D', T=T)
571
+ v = v + self.video_temporal_pos_embed + self.video_spatial_pos_embed
572
+
573
+
574
+ ### Audio
575
+ a = einops.rearrange(a, "B T C N F -> B T C F N").squeeze(2)
576
+ a = self.a_embedder(a)
577
+ a = a + self.audio_temporal_pos_embed + self.audio_spatial_pos_embed
578
+
579
+ ### Conditioning
580
+ t = t.view(-1) # B T -> (B T)
581
+ v_t = self.video_t_embedder(t) # (B, T, D)
582
+ v_t = v_t.view(B, T, -1) # (B T) D -> B T D
583
+
584
+ if self.num_classes > 0:
585
+ v_y = self.video_y_embedder(y, self.training) # (B, D)
586
+ v_y = v_y.unsqueeze(1).expand(-1, T, -1) # (B, T, D)
587
+
588
+ v_c = (v_t + v_y) if self.num_classes > 0 else v_t # (B, T, D)
589
+
590
+ a_t = self.audio_t_embedder(t) # (B, T, D)
591
+ a_t = a_t.view(B, T, -1)
592
+
593
+ if self.num_classes > 0:
594
+ a_y = self.audio_y_embedder(y, self.training)
595
+ a_y = a_y.unsqueeze(1).expand(-1, T, -1)
596
+
597
+ a_c = (a_t + a_y) if self.num_classes > 0 else a_t # (B, T, D)
598
+
599
+ for block in self.blocks:
600
+ v, a = block(v, a, v_c, a_c) # (B, T, D)
601
+
602
+ v = self.video_final_layer(v, v_c) # (B, T, patch_size ** 2 * out_channels), (B, T, L)
603
+ a = self.audio_final_layer(a, a_c)
604
+
605
+ v = einops.rearrange(v, 'B T L D -> (B T) L D', T = T)
606
+ v = self.unpatchify(v) # (B, out_channels, H, W)
607
+ v = einops.rearrange(v, '(B T) C H W -> B T C H W', T = T)
608
+
609
+ a = einops.rearrange(a, 'B T F N -> B T N F', T = T).unsqueeze(2)
610
+ return v, a
611
+
612
+ def forward_with_cfg(self, v, a, t, y, cfg_scale):
613
+ """
614
+ Forward pass of FLAV, but also batches the unconditional forward pass for classifier-free guidance.
615
+ """
616
+ v_combined = torch.cat([v, v], dim=0)
617
+
618
+ a_combined = torch.cat([a, a], dim=0)
619
+
620
+ y_null = torch.tensor([self.num_classes]*v.shape[0], device=v.device)
621
+ y = torch.cat([y, y_null], dim=0)
622
+
623
+ t = torch.cat([t, t], dim=0)
624
+
625
+ v_model_out, a_model_out = self.forward(v_combined, a_combined, t, y)
626
+ v_eps = v_model_out
627
+ a_eps = a_model_out
628
+
629
+ v_cond_eps, v_uncond_eps = torch.split(v_eps, len(v_eps) // 2, dim=0)
630
+ v_eps = v_uncond_eps + cfg_scale * (v_cond_eps - v_uncond_eps)
631
+
632
+ a_cond_eps, a_uncond_eps = torch.split(a_eps, len(a_eps) // 2, dim=0)
633
+ a_eps = a_uncond_eps + cfg_scale * (a_cond_eps - a_uncond_eps)
634
+
635
+ return v_eps, a_eps
636
+
637
+ #################################################################################
638
+ # Sine/Cosine Positional Embedding Functions #
639
+ #################################################################################
640
+ # https://github.com/facebookresearch/mae/blob/main/util/video_pos_embed.py
641
+
642
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
643
+ """
644
+ grid_size: int of the grid height and width
645
+ return:
646
+ video_pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
647
+ """
648
+ grid_h = np.arange(grid_size, dtype=np.float32)
649
+ grid_w = np.arange(grid_size, dtype=np.float32)
650
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
651
+ grid = np.stack(grid, axis=0)
652
+
653
+ grid = grid.reshape([2, 1, grid_size, grid_size])
654
+ video_pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
655
+ if cls_token and extra_tokens > 0:
656
+ video_pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), video_pos_embed], axis=0)
657
+ return video_pos_embed
658
+
659
+
660
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
661
+ assert embed_dim % 2 == 0
662
+
663
+ # use half of dimensions to encode grid_h
664
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
665
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
666
+
667
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
668
+ return emb
669
+
670
+
671
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
672
+ """
673
+ embed_dim: output dimension for each position
674
+ pos: a list of positions to be encoded: size (M,)
675
+ out: (M, D)
676
+ """
677
+ assert embed_dim % 2 == 0
678
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
679
+ omega /= embed_dim / 2.
680
+ omega = 1. / 10000**omega # (D/2,)
681
+
682
+ pos = pos.reshape(-1) # (M,)
683
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
684
+
685
+ emb_sin = np.sin(out) # (M, D/2)
686
+ emb_cos = np.cos(out) # (M, D/2)
687
+
688
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
689
+ return emb
690
+
691
+
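A quick sanity-check sketch for the sine/cosine helpers above; the grid size, embedding dimension, and frame count below are arbitrary illustration values, not taken from any FLAV config:

    import numpy as np

    # 2-D spatial table: one row per grid cell -> (grid_size * grid_size, embed_dim)
    pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
    assert pos_embed.shape == (8 * 8, 64)

    # 1-D temporal table over frame indices -> (num_frames, embed_dim)
    temporal_embed = get_1d_sincos_pos_embed_from_grid(64, np.arange(10))
    assert temporal_embed.shape == (10, 64)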
692
+ #################################################################################
693
+ # FLAV Configs #
694
+ #################################################################################
695
+
696
+ def FLAV_XL_2(**kwargs):
697
+ return FLAV(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
698
+
699
+ def FLAV_XL_4(**kwargs):
700
+ return FLAV(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
701
+
702
+ def FLAV_XL_8(**kwargs):
703
+ return FLAV(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
704
+
705
+ # def FLAV_L_2(**kwargs):
706
+ # return FLAV(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
707
+
708
+ def FLAV_L_1(**kwargs):
709
+ return FLAV(depth=24, hidden_size=1024, patch_size=1, num_heads=16, **kwargs)
710
+
711
+ def FLAV_L_2(**kwargs):
712
+ return FLAV(depth=20, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
713
+
714
+ def FLAV_L_4(**kwargs):
715
+ return FLAV(depth=20, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
716
+
717
+ def FLAV_L_8(**kwargs):
718
+ return FLAV(depth=20, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
719
+
720
+ # def FLAV_B_2(**kwargs):
721
+ # return FLAV(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
722
+
723
+ def FLAV_B_1(**kwargs):
724
+ return FLAV(depth=12, hidden_size=768, patch_size=1, num_heads=12, **kwargs)
725
+
726
+ def FLAV_B_2(**kwargs):
727
+ return FLAV(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
728
+
729
+ def FLAV_B_4(**kwargs):
730
+ return FLAV(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
731
+
732
+ def FLAV_B_8(**kwargs):
733
+ return FLAV(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
734
+
735
+ def FLAV_S_2(**kwargs):
736
+ return FLAV(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
737
+
738
+ def FLAV_S_4(**kwargs):
739
+ return FLAV(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
740
+
741
+ def FLAV_S_8(**kwargs):
742
+ return FLAV(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
743
+
744
+
745
+ FLAV_models = {
746
+ 'FLAV-XL/2': FLAV_XL_2, 'FLAV-XL/4': FLAV_XL_4, 'FLAV-XL/8': FLAV_XL_8,
747
+ 'FLAV-L/1' : FLAV_L_1, 'FLAV-L/2': FLAV_L_2, 'FLAV-L/4': FLAV_L_4, 'FLAV-L/8': FLAV_L_8,
748
+ 'FLAV-B/1' : FLAV_B_1, 'FLAV-B/2': FLAV_B_2, 'FLAV-B/4': FLAV_B_4, 'FLAV-B/8': FLAV_B_8,
749
+ 'FLAV-S/2' : FLAV_S_2, 'FLAV-S/4': FLAV_S_4, 'FLAV-S/8': FLAV_S_8,
750
+ }
751
+
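A minimal usage sketch for the config registry above: models are looked up by a "FLAV-<size>/<patch size>" key, and any extra keyword arguments are forwarded to the FLAV constructor via **kwargs. The keyword arguments shown here (latent_size, in_channels, num_classes, predict_frames, causal_attn) are assumptions about that constructor, included only for illustration:

    from models import FLAV_models

    # Base config, patch size 1; all kwargs below are assumed constructor options.
    model = FLAV_models["FLAV-B/1"](
        latent_size=32,       # assumed: spatial size of the VAE latent grid
        in_channels=4,        # assumed: number of VAE latent channels
        num_classes=0,        # assumed: 0 disables class conditioning
        predict_frames=10,    # assumed: frames generated per rolling window
        causal_attn=True,     # assumed: causal temporal attention
    )
    model.eval()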
utils.py ADDED
@@ -0,0 +1,267 @@
1
+ # from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
2
+ # # from moviepy.audio.AudioClip import AudioArrayClip
3
+ # from moviepy.audio.io.AudioFileClip import AudioFileClip
4
+ from torch.utils.data import DataLoader
5
+ from dataset import AudioVideoDataset, LatentDataset
6
+ import torch as th
7
+ import numpy as np
8
+
9
+ import einops
10
+ from moviepy.audio.io.AudioFileClip import AudioFileClip
11
+ from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
12
+ from diffusers.models import AutoencoderKL
13
+
14
+ from converter import denormalize, denormalize_spectrogram
15
+
16
+ import soundfile as sf
17
+ import os
18
+ import json
19
+ import torch
20
+ from tqdm import tqdm
21
+ #################################################################################
22
+ # Video Utils #
23
+ #################################################################################
24
+
25
+
26
+ def preprocess_video(video):
27
+ # video = 255*(video+1)/2.0 # [-1,1] -> [0,1] -> [0,255]
28
+ # video = th.clamp(video, 0, 255).to(dtype=th.uint8, device="cuda")
29
+ video = out2img(video)
30
+ video = einops.rearrange(video, 't c h w -> t h w c').cpu().numpy()
31
+ return video
32
+
33
+ def preprocess_video_batch(videos):
34
+ B = videos.shape[0]
35
+ videos_prep = np.empty(B, dtype=np.ndarray)
36
+ for b in range(B):
37
+ videos_prep[b] = preprocess_video(videos[b])
38
+ videos_prep = np.stack(videos_prep, axis=0)
39
+ return videos_prep
40
+
41
+ def save_latents(video, audio, y, output_path, name_prefix, ext=".pt"):
42
+ os.makedirs(output_path, exist_ok=True)
43
+ th.save(
44
+ {
45
+ "video":video,
46
+ "audio":audio,
47
+ "y":y
48
+ }, os.path.join(output_path, name_prefix + ext))
49
+
50
+ def save_multimodal(video, audio, output_path, name_prefix, video_fps=10, audio_fps=16000, audio_dir=None):
51
+ if not audio_dir:
52
+ audio_dir = output_path
53
+
54
+ #prepare folders
55
+ audio_dir = os.path.join(audio_dir, "audio")
56
+ os.makedirs(audio_dir, exist_ok=True)
57
+ audio_path = os.path.join(audio_dir, name_prefix + "_audio.wav")
58
+
59
+ video_dir = os.path.join(output_path, "video")
60
+ os.makedirs(video_dir, exist_ok=True)
61
+ video_path = os.path.join(video_dir, name_prefix + "_video.mp4")
62
+
63
+ #save audio
64
+ sf.write(audio_path, audio, samplerate=audio_fps)
65
+
66
+ #save video
67
+ video = preprocess_video(video)
68
+
69
+ imgs = [img for img in video]
70
+ video_clip = ImageSequenceClip(imgs, fps=video_fps)
71
+ audio_clip = AudioFileClip(audio_path)
72
+ video_clip = video_clip.with_audio(audio_clip)
73
+ video_clip.write_videofile(video_path, video_fps, audio=True, audio_fps=audio_fps)
74
+
75
+ def get_dataloader(args, logger, sequence_length, train, latents=False):
76
+ if latents:
77
+ train_set = LatentDataset(args.data_path, train=train)
78
+ else:
79
+ train_set = AudioVideoDataset(
80
+ args.data_path,
81
+ train=train,
82
+ sample_every_n_frames=1,
83
+ resolution=args.image_size,
84
+ sequence_length = sequence_length,
85
+ audio_channels = 1,
86
+ sample_rate=16000,
87
+ min_length=1,
88
+ ignore_cache=args.ignore_cache,
89
+ labeled=args.num_classes > 0,
90
+ target_video_fps=args.target_video_fps,
91
+ )
92
+ loader = DataLoader(
93
+ train_set,
94
+ batch_size=args.batch_size,
95
+ shuffle=True,
96
+ num_workers=args.num_workers,
97
+ pin_memory=True,
98
+ drop_last=True
99
+ )
100
+ if logger is not None:
101
+ logger.info(f'{"Train" if train else "Test"} Dataset contains {len(train_set)} images ({args.data_path})')
102
+ else:
103
+ print(f'{"Train" if train else "Test"} Dataset contains {len(train_set)} images ({args.data_path})')
104
+ return loader
105
+
106
+ @torch.no_grad()
107
+ def encode_video(video, vae, use_sd_vae = False):
108
+ b, t, c, h, w = video.shape
109
+ video = einops.rearrange(video, "b t c h w-> (b t) c h w")
110
+ if use_sd_vae:
111
+ video = vae.encode(video).latent_dist.sample().mul_(0.18215)
112
+ else:
113
+ video = vae.encode(video)*vae.cfg.scaling_factor
114
+ video = einops.rearrange(video, "(b t) c h w -> b t c h w", t=t)
115
+ return video
116
+
117
+ @torch.no_grad()
118
+ def decode_video(video, vae):
119
+ b = video.shape[0]
120
+ video_decoded = []
121
+ video = einops.rearrange(video, "b t c h w -> (b t) c h w")
122
+
123
+ #use minibatch to avoid memory error
124
+ for i in range(0, video.shape[0], b):
125
+ if isinstance(vae, AutoencoderKL):
126
+ video_decoded.append(vae.decode(video[i:i+b] / 0.18215).sample.detach().cpu())
127
+ else:
128
+ video_decoded.append(vae.decode(video[i:i+b] / vae.cfg.scaling_factor).detach().cpu())
129
+
130
+ video = torch.cat(video_decoded, dim=0)
131
+ video = einops.rearrange(video, "(b t) c h w ->b t c h w",b=b)
132
+ return video
133
+
134
+
135
+ def generate_sample(vae,
136
+ rectified_flow,
137
+ forward_fn,
138
+ video_length,
139
+ video_latent_size,
140
+ audio_latent_size,
141
+ y,
142
+ cfg_scale,
143
+ device):
144
+
145
+
146
+ with torch.no_grad():
147
+ v_z = torch.randn(video_latent_size, device=device)*rectified_flow.noise_scale
148
+ a_z = torch.randn(audio_latent_size, device=device)*rectified_flow.noise_scale
149
+
150
+ model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
151
+
152
+ sample_fn = rectified_flow.sample(
153
+ forward_fn, v_z, a_z, model_kwargs=model_kwargs, progress=True)()
154
+
155
+ video = []
156
+ audio = []
157
+ for _ in tqdm(range(video_length), desc="Generating frames"):
158
+ video_samples, audio_samples = next(sample_fn)
159
+
160
+ video.append(video_samples)
161
+ audio.append(audio_samples)
162
+
163
+ video = torch.stack(video, dim=1)
164
+ audio = torch.stack(audio, dim=1)
165
+
166
+ video = decode_video(video, vae)
167
+ audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
168
+
169
+ return video, audio
170
+
171
+ def generate_sample_a2v(vae,
172
+ rectified_flow,
173
+ forward_fn,
174
+ video_length,
175
+ video_latent_size,
176
+ audio,
177
+ y,
178
+ device,
179
+ cfg_scale=1,
180
+ scale=1):
181
+
182
+
183
+ v_z = torch.randn(video_latent_size, device=device)*rectified_flow.noise_scale
184
+
185
+ model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
186
+
187
+ sample_fn = rectified_flow.sample_a2v(
188
+ forward_fn, v_z, audio, model_kwargs=model_kwargs, scale=scale, progress=True)()
189
+
190
+ video = []
191
+ for i in tqdm(range(video_length), desc="Generating frames"):
192
+ video_samples = next(sample_fn)
193
+
194
+ video.append(video_samples)
195
+
196
+ video = torch.stack(video, dim=1)
197
+
198
+ video = decode_video(video, vae)
199
+ audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
200
+
201
+ return video, audio
202
+
203
+ def generate_sample_v2a(vae,
204
+ rectified_flow,
205
+ forward_fn,
206
+ video_length,
207
+ video,
208
+ audio_latent_size,
209
+ y,
210
+ device,
211
+ cfg_scale=1,
212
+ scale=1):
213
+
214
+
215
+ a_z = torch.randn(audio_latent_size, device=device)*rectified_flow.noise_scale
216
+
217
+ model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
218
+
219
+ sample_fn = rectified_flow.sample_v2a(
220
+ forward_fn, video, a_z, model_kwargs=model_kwargs, scale=scale, progress=True)()
221
+
222
+ audio = []
223
+ for i in tqdm(range(video_length), desc="Generating frames"):
224
+ audio_samples = next(sample_fn)
225
+
226
+ audio.append(audio_samples)
227
+
228
+ audio = torch.stack(audio, dim=1)
229
+
230
+ video = decode_video(video, vae)
231
+ audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
232
+
233
+ return video, audio
234
+
235
+ def dict_to_json(path, args):
236
+ with open(path, 'w') as f:
237
+ json.dump(args.__dict__, f, indent=2)
238
+
239
+ def json_to_dict(path, args):
240
+ with open(path, 'r') as f:
241
+ args.__dict__ = json.load(f)
242
+ return args
243
+
244
+ def log_args(args, logger):
245
+ text = ""
246
+ for k, v in vars(args).items():
247
+ text += f'{k}={v}\n'
248
+ logger.info(f"##### ARGS #####\n{text}")
249
+
250
+ def out2img(samples):
251
+ return th.clamp(127.5 * samples + 128.0, 0, 255).to(
252
+ dtype=th.uint8
253
+ ) # keep on the caller's device; an unconditional .cuda() breaks CPU-only runs
254
+
255
+ def get_gpu_usage():
256
+ device = th.device('cuda:0')
257
+ free, total = th.cuda.mem_get_info(device)
258
+ mem_used_MB = (total - free) / 1024 ** 2
259
+ return mem_used_MB
260
+
261
+ def get_wavs(norm_spec, vocoder, audio_scale, device):
262
+ norm_spec = norm_spec.squeeze(1)
263
+ norm_spec = norm_spec / audio_scale
264
+ post_norm_spec = denormalize(norm_spec).to(device)
265
+ raw_chunk_spec = denormalize_spectrogram(post_norm_spec)
266
+ wavs = vocoder.inference(raw_chunk_spec)
267
+ return wavs
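A minimal end-to-end sketch of how these utilities chain together for sampling, assuming a vae, model (FLAV), vocoder, and rectified_flow object have already been constructed and moved to the target device; the latent shapes, frame count, audio scale, and output paths are illustration values only:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Assumed latent layouts: video (B, T, C, H, W), audio (B, T, C, mel_bins, time_per_frame).
    video_latent_size = (1, 10, 4, 32, 32)
    audio_latent_size = (1, 10, 1, 256, 10)

    video, audio_spec = generate_sample(
        vae=vae,
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=10,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,
        cfg_scale=None,
        device=device,
    )

    # Spectrogram -> waveform via the vocoder, then mux audio and video to disk.
    wavs = get_wavs(audio_spec, vocoder, audio_scale=3.5, device=device)
    save_multimodal(video[0],               # (T, C, H, W) for a single sample
                    wavs[0].cpu().numpy(),  # assumed: vocoder returns a batch of waveform tensors
                    "samples", "demo_000",
                    video_fps=10, audio_fps=16000)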