Spaces: Running on Zero

Alex Ergasti committed · Commit b89c182
1 Parent(s): 9ca6da9
Init

- app.py +140 -0
- common_parser.py +50 -0
- converter.py +368 -0
- dataset.py +206 -0
- download.py +50 -0
- models.py +751 -0
- utils.py +267 -0
app.py
ADDED
@@ -0,0 +1,140 @@
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os

from diffusers.models import AutoencoderKL
from models import FLAV_models

from diffusion.rectified_flow import RectifiedFlow

from diffusers.training_utils import EMAModel
from converter import Generator
from utils import *

import tempfile
import gradio as gr
from huggingface_hub import hf_hub_download

AUDIO_T_PER_FRAME = 1600 // 160

#################################################################################
#                              Global Model Setup                               #
#################################################################################

# These variables will be initialized in setup_models() and used in main()
vae = None
model = None
vocoder = None
audio_scale = 3.50


def setup_models():
    global vae, model, vocoder

    device = "cpu"
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")

    model = FLAV_models["FLAV-B/1"](
        latent_size=256 // 8,
        in_channels=4,
        num_classes=0,
        predict_frames=10,
        causal_attn=True,
    )

    ckpt_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="aist-ema.pth")

    state_dict = torch.load(ckpt_path)

    ema = EMAModel(model.parameters())
    ema.load_state_dict(state_dict)
    ema.copy_to(model.parameters())

    hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/config.json")
    vocoder_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/vocoder.pt")

    vocoder_path = vocoder_path.replace("vocoder.pt", "")
    vocoder = Generator.from_pretrained(vocoder_path)

    vae.to(device)
    model.to(device)
    vocoder.to(device)


def generate_video(num_frames=10, steps=2, seed=42):
    global vae, model, vocoder
    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)

    # Set up generation parameters
    video_latent_size = (1, 10, 4, 256 // 8, 256 // 8)
    audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)

    rectified_flow = RectifiedFlow(num_timesteps=steps,
                                   warmup_timesteps=10,
                                   window_size=10)

    # Generate sample
    video, audio = generate_sample(
        vae=vae,  # These globals are set by setup_models
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=num_frames,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,
        cfg_scale=None,
        device=device
    )

    # Convert to wav
    wavs = get_wavs(audio, vocoder, audio_scale, device)

    # Save to temporary files
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "video", "generated_video.mp4")

    # Use the first video and wav
    vid, wav = video[0], wavs[0]
    save_multimodal(vid, wav, temp_dir, "generated")

    return video_path

def ui_generate_video(num_frames, steps, seed):
    try:
        return generate_video(int(num_frames), int(steps), int(seed))
    except Exception:
        return None

# Create Gradio interface
with gr.Blocks(title="FLAV Video Generator") as demo:
    gr.Markdown("# FLAV Video Generator")
    gr.Markdown("Generate videos using the FLAV model")

    num_frames = None
    steps = None
    seed = None

    video_output = None
    with gr.Row():
        with gr.Column():
            num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
            steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (multiplied by a factor of 10)")
            seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
            generate_btn = gr.Button("Generate Video")

        with gr.Column():
            video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)

    generate_btn.click(
        fn=ui_generate_video,
        inputs=[num_frames, steps, seed],
        outputs=[video_output]
    )

if __name__ == "__main__":
    setup_models()
    demo.launch()
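
For readers trying the Space outside of the Gradio UI, a minimal driver sketch (assuming the Hub downloads in setup_models() succeed; the printed path follows the temporary-directory convention used in generate_video above):

# Hypothetical local driver; run from the repository root so `from app import ...` resolves.
from app import setup_models, generate_video

if __name__ == "__main__":
    setup_models()                                    # load VAE, FLAV model and vocoder once (CPU)
    out = generate_video(num_frames=10, steps=2, seed=0)
    print("expected output location:", out)           # mp4 written under a tempfile.mkdtemp() folder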
common_parser.py
ADDED
@@ -0,0 +1,50 @@
import argparse
from models import FLAV_models

class CommonParser:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # Datasets
        self.parser.add_argument("--data-path", type=str, required=True)
        self.parser.add_argument("--load-latents", action="store_true")
        self.parser.add_argument("--num-classes", type=int, default=9)
        self.parser.add_argument("--image-size", type=int, choices=[64, 256, 512, 1024], default=256)
        self.parser.add_argument("--target-video-fps", type=int, default=10)
        self.parser.add_argument("--ignore-cache", action="store_true")
        self.parser.add_argument("--audio-scale", type=float, default=3.5009668382765917)

        # Results
        self.parser.add_argument("--video-length", type=int, default=1)
        self.parser.add_argument("--predict-frames", type=int, default=10)
        self.parser.add_argument("--results-dir", type=str, default="results")
        self.parser.add_argument("--experiment-dir", type=str, default="")
        self.parser.add_argument("--checkpoint-dir", type=str, default="checkpoints")
        self.parser.add_argument("--ckpt-every", type=int, default=5_000)

        # Models
        self.parser.add_argument("--seed", type=int, default=42)
        self.parser.add_argument("--model", type=str, choices=list(FLAV_models.keys()), default="FLAV-XL/2")
        self.parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
        self.parser.add_argument("--use_sd_vae", action="store_true")
        self.parser.add_argument("--vocoder-ckpt", type=str, default="vocoder/")
        self.parser.add_argument("--optimizer-wd", type=float, default=0.02)

        # Resources
        self.parser.add_argument("--batch-size", type=int, default=4)
        self.parser.add_argument("--num-workers", type=int, default=32)
        self.parser.add_argument("--log-every", type=int, default=100)

        # Config
        self.parser.add_argument("--load-config", action="store_true")
        self.parser.add_argument("--config-no-save", action="store_true")
        self.parser.add_argument("--config-path", type=str, default="")
        self.parser.add_argument("--config-name", type=str, default="config.json")

        # Architecture
        self.parser.add_argument("--causal-attn", action="store_true")

        # RF
        self.parser.add_argument("--num_timesteps", type=int, default=2)

    def get_parser(self):
        return self.parser
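
A minimal usage sketch of CommonParser, as a training or sampling script might consume it (the command-line values below are illustrative placeholders, not repository defaults):

# Hypothetical invocation of the shared argument parser.
from common_parser import CommonParser

parser = CommonParser().get_parser()
args = parser.parse_args([
    "--data-path", "/data/aist",     # the only required flag
    "--model", "FLAV-B/1",           # must be a key of FLAV_models
    "--causal-attn",                 # enables causal temporal attention
])
print(args.model, args.predict_frames, args.num_timesteps)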
converter.py
ADDED
@@ -0,0 +1,368 @@
import numpy as np
from PIL import Image

import math
import os
import random
import torch
import json
import torch.utils.data
import numpy as np
import librosa
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm


def normalize(images):
    """
    Normalize an image array to [-1,1].
    """
    if images.min() >= 0:
        return 2.0 * images - 1.0
    else:
        return images

def denormalize(images):
    """
    Denormalize an image array to [0,1].
    """
    if images.min() < 0:
        return (images / 2 + 0.5).clamp(0, 1)
    else:
        return images.clamp(0, 1)


MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    # if torch.min(y) < -1.:
    #     print('min value is ', torch.min(y))
    # if torch.max(y) > 1.:
    #     print('max value is ', torch.max(y))

    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y, (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    global hann_window
    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    return spec


def normalize_spectrogram(
    spectrogram: torch.Tensor,
    max_value: float = 200,
    min_value: float = 1e-5,
    power: float = 1.,
    inverse: bool = False
) -> torch.Tensor:

    # Rescale to 0-1
    max_value = np.log(max_value)  # 5.298317366548036
    min_value = np.log(min_value)  # -11.512925464970229

    assert spectrogram.max() <= max_value and spectrogram.min() >= min_value

    data = (spectrogram - min_value) / (max_value - min_value)

    # Invert
    if inverse:
        data = 1 - data

    # Apply the power curve
    data = torch.pow(data, power)

    return data


def denormalize_spectrogram(
    data: torch.Tensor,
    max_value: float = 200,
    min_value: float = 1e-5,
    power: float = 1,
) -> torch.Tensor:

    max_value = np.log(max_value)
    min_value = np.log(min_value)

    # Reverse the power curve
    data = torch.pow(data, 1 / power)

    # Rescale to max value
    spectrogram = data * (max_value - min_value) + min_value

    return spectrogram


def get_mel_spectrogram_from_audio(audio, device="cuda"):
    audio = audio / MAX_WAV_VALUE
    audio = librosa.util.normalize(audio) * 0.95

    audio = torch.FloatTensor(audio)
    audio = audio.unsqueeze(0)

    waveform = audio
    spec = mel_spectrogram(waveform, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
    return audio, spec


LRELU_SLOPE = 0.1
MAX_WAV_VALUE = 32768.0


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def get_config(config_path):
    config = json.loads(open(config_path).read())
    config = AttrDict(config)
    return config

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))  # change: 80 --> 512
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            if (k-u) % 2 == 0:
                self.ups.append(weight_norm(
                    ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                    k, u, padding=(k-u)//2)))
            else:
                self.ups.append(weight_norm(
                    ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                    k, u, padding=(k-u)//2+1, output_padding=1)))

            # self.ups.append(weight_norm(
            #     ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
            #                     k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
        if subfolder is not None:
            pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "vocoder.pt")

        config = get_config(config_path)
        vocoder = cls(config)

        state_dict_g = torch.load(ckpt_path)
        vocoder.load_state_dict(state_dict_g["generator"])
        vocoder.eval()
        vocoder.remove_weight_norm()
        return vocoder

    @torch.no_grad()
    def inference(self, mels, lengths=None):
        self.eval()
        with torch.no_grad():
            wavs = self(mels).squeeze(1)

        wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype("int16")

        if lengths is not None:
            wavs = wavs[:, :lengths]

        return wavs
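
A rough round-trip sketch of how these pieces fit together: mel extraction with the 16 kHz / 256-mel / 160-hop settings hard-coded in get_mel_spectrogram_from_audio, followed by the HiFi-GAN-style Generator mapping the mel back to a waveform. The vocoder folder below is a placeholder and must contain config.json and vocoder.pt, as from_pretrained expects:

# Illustrative only; the audio is random noise and the vocoder path is hypothetical.
import numpy as np
from converter import get_mel_spectrogram_from_audio, Generator

wave = np.random.randn(16000).astype(np.float32)      # 1 s of placeholder 16 kHz audio
audio, mel = get_mel_spectrogram_from_audio(wave)     # mel: (1, 256, frames)

vocoder = Generator.from_pretrained("vocoder-aist/")  # hypothetical local folder
wav_int16 = vocoder.inference(mel)                    # (1, samples) int16 numpy array
print(mel.shape, wav_int16.shape)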
dataset.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import os.path as osp
import math
import pickle
import warnings

import glob

import torch.utils.data as data
import torch.nn.functional as F
from torchvision.datasets.video_utils import VideoClips
from converter import normalize, normalize_spectrogram, get_mel_spectrogram_from_audio
from torchaudio import transforms as Ta
from torchvision import transforms as Tv
from torchvision.io.video import read_video
import torch
from torchvision.transforms import InterpolationMode

class LatentDataset(data.Dataset):
    """ Generic dataset for latents pregenerated from a dataset
    Returns a dictionary of latents encoded from the original dataset """
    exts = ['pt']

    def __init__(self, data_folder, train=True):
        """
        Args:
            data_folder: path to the folder with videos. The folder
                should contain a 'train' and a 'test' directory,
                each with corresponding videos stored
        """
        super().__init__()
        self.train = train

        folder = osp.join(data_folder, 'train' if train else 'test')
        self.files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
                          for ext in self.exts], [])

        warnings.filterwarnings('ignore')

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        while True:
            try:
                latents = torch.load(self.files[idx], map_location="cpu")
            except Exception as e:
                print(f"Dataset Exception: {e}")
                idx = (idx + 1) % len(self.files)
                continue
            break

        return latents["video"], latents["audio"], latents["y"]

class AudioVideoDataset(data.Dataset):
    """ Generic dataset for videos files stored in folders
    Returns BCTHW videos in the range [-0.5, 0.5] """
    exts = ['avi', 'mp4', 'webm']

    def __init__(self, data_folder, train=True, resolution=64, sample_every_n_frames=1, sequence_length=8, audio_channels=1, sample_rate=16000, min_length=1, ignore_cache=False, labeled=True, target_video_fps=10):
        """
        Args:
            data_folder: path to the folder with videos. The folder
                should contain a 'train' and a 'test' directory,
                each with corresponding videos stored
            sequence_length: length of extracted video sequences
        """
        super().__init__()
        self.train = train
        self.sequence_length = sequence_length
        self.resolution = resolution
        self.sample_every_n_frames = sample_every_n_frames
        self.audio_channels = audio_channels
        self.sample_rate = sample_rate
        self.min_length = min_length
        self.labeled = labeled

        folder = osp.join(data_folder, 'train' if train else 'test')
        files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
                     for ext in self.exts], [])

        # hacky way to compute # of classes (count # of unique parent directories)
        self.classes = list(set([get_parent_dir(f) for f in files]))
        self.classes.sort()
        self.class_to_label = {c: i for i, c in enumerate(self.classes)}

        warnings.filterwarnings('ignore')
        cache_file = osp.join(folder, f"metadata_{self.sequence_length}.pkl")
        if not osp.exists(cache_file) or ignore_cache or True:
            clips = VideoClips(files, self.sequence_length, num_workers=32, frame_rate=target_video_fps)
            # pickle.dump(clips.metadata, open(cache_file, 'wb'))
        else:
            metadata = pickle.load(open(cache_file, 'rb'))
            clips = VideoClips(files, self.sequence_length,
                               _precomputed_metadata=metadata)

        # self._clips = clips.subset(np.arange(24))
        self._clips = clips

    @property
    def n_classes(self):
        return len(self.classes)

    def __len__(self):
        return self._clips.num_clips()

    def __getitem__(self, idx):
        resolution = self.resolution
        while True:
            try:
                video, _, info, _ = self._clips.get_clip(idx)
            except Exception:
                idx = (idx + 1) % self._clips.num_clips()
                continue
            break

        return preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames), self.get_audio(info, idx), self.get_label(idx)

    def get_label(self, idx):
        if not self.labeled:
            return -1
        video_idx, clip_idx = self._clips.get_clip_location(idx)
        class_name = get_parent_dir(self._clips.video_paths[video_idx])
        label = self.class_to_label[class_name]
        return label

    def get_audio(self, info, idx):
        video_idx, clip_idx = self._clips.get_clip_location(idx)

        video_path = self._clips.video_paths[video_idx]
        video_fps = self._clips.video_fps[video_idx]

        duration_per_frame = self._clips.video_pts[video_idx][1] - self._clips.video_pts[video_idx][0]
        clip_pts = self._clips.clips[video_idx][clip_idx]
        clip_pid = clip_pts // duration_per_frame

        start_t = (clip_pid[0] / video_fps * 1.).item()
        end_t = ((clip_pid[-1] + 1) / video_fps * 1.).item()

        _, raw_audio, _ = read_video(video_path, start_t, end_t, pts_unit='sec')
        raw_audio = prepare_audio(raw_audio, info["audio_fps"], self.sample_rate, self.audio_channels, self.sequence_length, self.min_length)

        _, spec = get_mel_spectrogram_from_audio(raw_audio[0].numpy())
        norm_spec = normalize_spectrogram(spec)
        norm_spec = normalize(norm_spec)  # normalize to [-1, 1], because pipeline do not normalize for torch.Tensor input
        norm_spec.unsqueeze(1)  # add channel dimension
        return norm_spec
        # return raw_audio[0]


def get_parent_dir(path):
    return osp.basename(osp.dirname(path))

def preprocess(video, resolution, sample_every_n_frames=1):
    video = video.permute(0, 3, 1, 2).float() / 255.  # TCHW

    old_size = video.shape[2:4]
    ratio = min(float(resolution)/(old_size[0]), float(resolution)/(old_size[1]))
    new_size = tuple([int(i*ratio) for i in old_size])
    pad_w = resolution - new_size[1]
    pad_h = resolution - new_size[0]
    top, bottom = pad_h//2, pad_h-(pad_h//2)
    left, right = pad_w//2, pad_w-(pad_w//2)
    transform = Tv.Compose([Tv.Resize(new_size, interpolation=InterpolationMode.BICUBIC), Tv.Pad((left, top, right, bottom))])
    video_new = transform(video)

    video_new = video_new*2-1

    return video_new

def pad_crop_audio(audio, target_length):
    target_length = int(target_length)
    n, s = audio.shape
    start = 0
    end = start + target_length
    output = audio.new_zeros([n, target_length])
    output[:, :min(s, target_length)] = audio[:, start:end]
    return output

def prepare_audio(audio, in_sr, target_sr, target_channels, sequence_length, min_length):
    if in_sr != target_sr:
        resample_tf = Ta.Resample(in_sr, target_sr)
        audio = resample_tf(audio)

    max_length = target_sr/10*sequence_length
    target_length = max_length + (min_length - (max_length % min_length)) % min_length

    audio = pad_crop_audio(audio, target_length)

    audio = set_audio_channels(audio, target_channels)

    return audio

def set_audio_channels(audio, target_channels):
    if target_channels == 1:
        # Convert to mono
        # audio = audio.mean(0, keepdim=True)
        audio = audio[:1, :]
    elif target_channels == 2:
        # Convert to stereo
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]
    return audio
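
A minimal sketch of wrapping AudioVideoDataset in a DataLoader (the dataset root is a placeholder and must contain train/ and test/ subfolders of video files):

# Illustrative only; "/data/my_videos" is a hypothetical path.
from torch.utils.data import DataLoader
from dataset import AudioVideoDataset

ds = AudioVideoDataset("/data/my_videos", train=True,
                       resolution=256, sequence_length=10, target_video_fps=10)
loader = DataLoader(ds, batch_size=2, shuffle=True, num_workers=4)

video, spec, label = next(iter(loader))
# video: (B, T, C, H, W) scaled to [-1, 1]; spec: normalized mel spectrogram; label: class index
print(video.shape, spec.shape, label)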
download.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Functions for downloading pre-trained DiT models
"""
from torchvision.datasets.utils import download_url
import torch
import os


pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}


def find_model(model_name):
    """
    Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
    """
    if model_name in pretrained_models:  # Find/download our pre-trained DiT checkpoints
        return download_model(model_name)
    else:  # Load a custom DiT checkpoint:
        assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
        checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
        if "ema" in checkpoint:  # supports checkpoints from train.py
            checkpoint = checkpoint["ema"]
        return checkpoint


def download_model(model_name):
    """
    Downloads a pre-trained DiT model from the web.
    """
    assert model_name in pretrained_models
    local_path = f'pretrained_models/{model_name}'
    if not os.path.isfile(local_path):
        os.makedirs('pretrained_models', exist_ok=True)
        web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
        download_url(web_path, 'pretrained_models')
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


if __name__ == "__main__":
    # Download all DiT checkpoints
    for model in pretrained_models:
        download_model(model)
    print('Done.')
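
For context, find_model follows the original DiT loader convention; a short sketch of both code paths (the local checkpoint path in the comment is hypothetical):

# Downloads an official DiT checkpoint into ./pretrained_models/ on first use.
from download import find_model

state_dict = find_model("DiT-XL-2-256x256.pt")
print(type(state_dict))

# Alternatively, pass a path to a locally saved checkpoint (hypothetical example);
# if the file contains an "ema" entry, only that sub-dictionary is returned:
# state_dict = find_model("results/checkpoints/0005000.pt")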
models.py
ADDED
@@ -0,0 +1,751 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
import einops

import torch.utils.checkpoint as checkpoint

from transformers import PreTrainedModel
import random

class MelPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, n_mels, n_frames, patch_size=16, in_chans=1, embed_dim=768):
        super().__init__()
        num_patches = (n_mels // patch_size) * (n_frames // patch_size)
        self.patch_size = patch_size
        self.num_patches = int(num_patches)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class SelfAttention(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            is_causal: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.is_causal = is_causal
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
            is_causal=self.is_causal
        )

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class CrossAttention(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            mask_attn=False,
    ):
        super().__init__()
        self.mask_attn = mask_attn
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = head_dim ** -0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wkv = nn.Linear(dim, dim*2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond):
        B, N, C = x.shape

        q = self.wq(x)
        q = einops.rearrange(q, 'B N (H D) -> B H N D', H=self.num_heads)

        kv = self.wkv(cond)  # BMD
        kv = einops.rearrange(kv, 'B N (K H D) -> K B H N D', H=self.num_heads, K=2)
        k = kv[0]
        v = kv[1]

        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        x = einops.rearrange(x, 'B H N D -> B N (H D)')
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def temporalModulate(x, shift, scale):
    """
    Modulate the input tensor x with the given shift and scale tensors.
    :param x: the input tensor to modulate with shape (B, T, L, D).
    :param shift: the shift tensor with shape (B, T, D).
    :param scale: the scale tensor with shape (B, T, D).
    """
    return x * (1 + scale.unsqueeze(2)) + shift.unsqueeze(2)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class AudioEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, n_mels, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_mels, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

        # TODO: Activation?

    def forward(self, a):
        a = self.mlp(a)
        return a

    def init_weights(self):
        nn.init.xavier_uniform_(self.mlp[0].weight)
        nn.init.constant_(self.mlp[0].bias, 0)
        nn.init.xavier_uniform_(self.mlp[2].weight)
        nn.init.constant_(self.mlp[2].bias, 0)

class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


#################################################################################
#                                Core FLAV Model                                #
#################################################################################

class FLAVBlock(nn.Module):
    """
    A FLAV block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, grad_ckpt=False, causal_attn=False, **block_kwargs):
        super().__init__()

        self.video_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # self.video_audio_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.video_spatial_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.video_temporal_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, is_causal=causal_attn, **block_kwargs)

        self.audio_spatial_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, is_causal=causal_attn, **block_kwargs)
        # self.audio_temporal_attn = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, is_causal=causal_attn, **block_kwargs)

        self.video_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.video_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.video_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.audio_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)

        self.video_adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

        self.audio_adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        )

        self.video_scale = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        )

        self.audio_scale = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        )

        self.v_avg_proj = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.a_avg_proj = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

        self.grad_ckpt = grad_ckpt

    def forward(self, v, a, v_c, a_c):
        if self.grad_ckpt:
            return checkpoint.checkpoint(self._forward, v, a, v_c, a_c, use_reentrant=False)
        else:
            return self._forward(v, a, v_c, a_c)

    def _forward(self, v, a, v_c, a_c):
        """
        v: Size of (B, T, Lv, D)
        a: Size of (B, T, La, D)
        v_c: Size of (B, T, D)
        a_c: Size of (B, T, D)
        """

        video_shift_msa, video_scale_msa, video_gate_msa, video_shift_tmsa, video_scale_tmsa, video_gate_tmsa = self.video_adaLN_modulation(v_c).chunk(6, dim=-1)
        # audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_tmsa, audio_scale_tmsa, audio_gate_tmsa = self.audio_adaLN_modulation(a_c).chunk(6, dim=-1)
        audio_shift_msa, audio_scale_msa, audio_gate_msa = self.audio_adaLN_modulation(a_c).chunk(3, dim=-1)
        B, T, L, D = v.shape

        v_att = temporalModulate(self.video_norm1(v), video_shift_msa, video_scale_msa)
        v_att = einops.rearrange(v_att, 'B T L D -> (B T) L D')
        v_att = v + video_gate_msa.unsqueeze(2)*(self.video_spatial_attn(v_att).view(B, T, L, D))

        v = v_att

        v_att = temporalModulate(self.video_norm2(v_att), video_shift_tmsa, video_scale_tmsa)
        v_att = einops.rearrange(v_att, 'B T L D -> (B L) T D', T=T)
        v_att = einops.rearrange(self.video_temporal_attn(v_att), "(B L) T D -> B T L D", B=B)
        v = v + video_gate_tmsa.unsqueeze(2)*v_att

        a_att = temporalModulate(self.audio_norm1(a), audio_shift_msa, audio_scale_msa)
        a_att = einops.rearrange(a_att, 'B T L D -> B (T L) D')
        a_att = a + audio_gate_msa.unsqueeze(2)*(self.audio_spatial_attn(a_att).view(B, T, -1, D))

        a = a_att

        a_avg = self.a_avg_proj(a.mean(dim=2))  # B T D
        v_avg = self.v_avg_proj(v.mean(dim=2))  # B T D

        v_avg += a_c
        a_avg += v_c

        scale_v, shift_v, gate_v = self.video_scale(a_avg).chunk(3, dim=-1)
        scale_a, shift_a, gate_a = self.audio_scale(v_avg).chunk(3, dim=-1)

        v = v + gate_v.unsqueeze(2) * self.video_mlp(temporalModulate(self.video_norm3(v), shift_v, scale_v))
        a = a + gate_a.unsqueeze(2) * self.audio_mlp(temporalModulate(self.audio_norm3(a), shift_a, scale_a))

        return v, a

    def _spatial_attn(self, x, b_size, attn_func):
        x = einops.rearrange(x, "(B N) T D -> (B T) N D", B=b_size)
        x = attn_func(x)
        x = einops.rearrange(x, "(B T) N D -> (B N) T D", B=b_size)
        return x


class FinalLayer(nn.Module):
    """
    The final layer of FLAV.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = temporalModulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class FLAV(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        latent_size=None,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        predict_frames=1,
        grad_ckpt=False,
        n_mels=256,
        audio_fr=16000,
        causal_attn=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.predict_frames = predict_frames
        self.grad_ckpt = grad_ckpt
        self.n_mels = n_mels
        self.audio_fr = audio_fr
        self.latent_size = latent_size  # T H W

        self.num_classes = num_classes

        self.v_embedder = PatchEmbed(latent_size, patch_size, in_channels, hidden_size, bias=True)
        self.a_embedder = nn.Linear(n_mels, hidden_size, bias=True)

        self.video_t_embedder = TimestepEmbedder(hidden_size)
        self.audio_t_embedder = TimestepEmbedder(hidden_size)

        if self.num_classes > 0:
            self.video_y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
            self.audio_y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)

        num_patches = self.v_embedder.num_patches
        self.video_spatial_pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches, hidden_size), requires_grad=True)
        self.video_temporal_pos_embed = nn.Parameter(torch.zeros(1, self.predict_frames, 1, hidden_size), requires_grad=True)

        self.audio_spatial_pos_embed = nn.Parameter(torch.zeros(1, 1, 10, hidden_size), requires_grad=True)
        self.audio_temporal_pos_embed = nn.Parameter(torch.zeros(1, self.predict_frames, 1, hidden_size), requires_grad=True)

        self.blocks = nn.ModuleList([
            FLAVBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, grad_ckpt=grad_ckpt, causal_attn=causal_attn) for _ in range(depth)
        ])

        self.video_final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.audio_final_layer = FinalLayer(hidden_size, 1, n_mels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.v_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.v_embedder.proj.bias, 0)

        if self.num_classes > 0:
            nn.init.normal_(self.video_y_embedder.embedding_table.weight, std=0.02)
            nn.init.normal_(self.audio_y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.video_t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.video_t_embedder.mlp[2].weight, std=0.02)

        nn.init.normal_(self.audio_t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.audio_t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in FLAV blocks:
        for block in self.blocks:
            nn.init.constant_(block.video_adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.video_adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.audio_adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.audio_adaLN_modulation[-1].bias, 0)

            nn.init.constant_(block.video_scale[-1].weight, 0)
            nn.init.constant_(block.video_scale[-1].bias, 0)

            nn.init.constant_(block.audio_scale[-1].weight, 0)
            nn.init.constant_(block.audio_scale[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.video_final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.video_final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.video_final_layer.linear.weight, 0)
        nn.init.constant_(self.video_final_layer.linear.bias, 0)

        nn.init.constant_(self.audio_final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.audio_final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.audio_final_layer.linear.weight, 0)
        nn.init.constant_(self.audio_final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.v_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def _apply_rnd_mask(self, input, mask, device="cuda"):
        input_rnd = torch.rand(input[0].shape).unsqueeze(0).to(device=device)*2 - 1
        return self._apply_mask(input, mask, input_rnd)

    def _apply_zero_mask(self, input, mask, device="cuda"):
        input_zero = torch.zeros(input[0].shape).unsqueeze(0).to(device=device)
        return self._apply_mask(input, mask, input_zero)

    def _get_frames_mask(self, bs):
        """
        bs: batch size

        returns a boolean mask to be applied to condition frames
        to mask a selected number of random frames
        """
        fmask = np.full(self.cond_frames*bs, False)
        frames = list(range(self.cond_frames))
        for b in range(bs):
            if random.randint(0, 100) < self.mask_freq:
                sub_frames = random.sample(frames, min(self.cond_frames, self.frames_to_mask))
                idxs = [f+(b*self.cond_frames) for f in sub_frames]
                fmask[idxs] = True
        return fmask

    def _get_batch_mask(self, bs):
        """
        bs: batch size

        returns a boolean mask to be applied to condition frames
|
528 |
+
to mask a random number of condition sequences in a batch
|
529 |
+
"""
|
530 |
+
rnd = np.random.rand(bs)
|
531 |
+
bmask= rnd < self.batch_mask_freq/100
|
532 |
+
bmask = np.repeat(bmask, self.cond_frames)
|
533 |
+
return bmask
|
534 |
+
|
535 |
+
def _apply_mask(self, input, mask, values):
|
536 |
+
input[mask] = values
|
537 |
+
return input
|
538 |
+
|
539 |
+
def audio_unpatchify(self, x):
|
540 |
+
"""
|
541 |
+
x: (N, T, patch_size * C)
|
542 |
+
audio: (N, N_mels, frames)
|
543 |
+
"""
|
544 |
+
c = 1
|
545 |
+
p = self.audio_patch_size
|
546 |
+
h = int(self.n_mels//p)
|
547 |
+
w = int((self.audio_fr/1600)/p)
|
548 |
+
|
549 |
+
|
550 |
+
assert h * w == x.shape[1]
|
551 |
+
|
552 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
553 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
554 |
+
audio = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
555 |
+
return audio
|
556 |
+
|
557 |
+
def forward(self, v, a, t, y):
|
558 |
+
"""
|
559 |
+
Forward pass of FLAV.
|
560 |
+
v: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
561 |
+
a: (B, 1, n_bins, T) # mel spectrogram of audio
|
562 |
+
t: (B, T) tensor of diffusion timesteps
|
563 |
+
y: (B,) tensor of class labels
|
564 |
+
"""
|
565 |
+
|
566 |
+
### Video
|
567 |
+
B, T, C, H, W = v.shape
|
568 |
+
v = einops.rearrange(v, 'B T C H W -> (B T) C H W')
|
569 |
+
v = self.v_embedder(v)
|
570 |
+
v = einops.rearrange(v, '(B T) L D -> B T L D', T=T)
|
571 |
+
v = v + self.video_temporal_pos_embed + self.video_spatial_pos_embed
|
572 |
+
|
573 |
+
|
574 |
+
### Audio
|
575 |
+
a = einops.rearrange(a, "B T C N F -> B T C F N").squeeze(2)
|
576 |
+
a = self.a_embedder(a)
|
577 |
+
a = a + self.audio_temporal_pos_embed + self.audio_spatial_pos_embed
|
578 |
+
|
579 |
+
### Conditioning
|
580 |
+
t = t.view(-1) # B T -> (B T)
|
581 |
+
v_t = self.video_t_embedder(t) # (B, T, D)
|
582 |
+
v_t = v_t.view(B, T, -1) # (B T) D -> B T D
|
583 |
+
|
584 |
+
if self.num_classes > 0:
|
585 |
+
v_y = self.video_y_embedder(y, self.training) # (B, D)
|
586 |
+
v_y = v_y.unsqueeze(1).expand(-1, T, -1) # (B, T, D)
|
587 |
+
|
588 |
+
v_c = (v_t + v_y) if self.num_classes > 0 else v_t # (B, T, D)
|
589 |
+
|
590 |
+
a_t = self.audio_t_embedder(t) # (B, T, D)
|
591 |
+
a_t = a_t.view(B, T, -1)
|
592 |
+
|
593 |
+
if self.num_classes > 0:
|
594 |
+
a_y = self.audio_y_embedder(y, self.training)
|
595 |
+
a_y = a_y.unsqueeze(1).expand(-1, T, -1)
|
596 |
+
|
597 |
+
a_c = (a_t + a_y) if self.num_classes > 0 else a_t # (B, T, D)
|
598 |
+
|
599 |
+
for block in self.blocks:
|
600 |
+
v, a = block(v, a, v_c, a_c) # (B, T, D)
|
601 |
+
|
602 |
+
v = self.video_final_layer(v, v_c) # (B, T, patch_size ** 2 * out_channels), (B, T, L)
|
603 |
+
a = self.audio_final_layer(a, a_c)
|
604 |
+
|
605 |
+
v = einops.rearrange(v, 'B T L D -> (B T) L D', T = T)
|
606 |
+
v = self.unpatchify(v) # (B, out_channels, H, W)
|
607 |
+
v = einops.rearrange(v, '(B T) C H W -> B T C H W', T = T)
|
608 |
+
|
609 |
+
a = einops.rearrange(a, 'B T F N -> B T N F', T = T).unsqueeze(2)
|
610 |
+
return v, a
|
611 |
+
|
612 |
+
def forward_with_cfg(self, v, a, t, y, cfg_scale):
|
613 |
+
"""
|
614 |
+
Forward pass of FLAV, but also batches the unconditional forward pass for classifier-free guidance.
|
615 |
+
"""
|
616 |
+
v_combined = torch.cat([v, v], dim=0)
|
617 |
+
|
618 |
+
a_combined = torch.cat([a, a], dim=0)
|
619 |
+
|
620 |
+
y_null = torch.tensor([self.num_classes]*v.shape[0], device=v.device)
|
621 |
+
y = torch.cat([y, y_null], dim=0)
|
622 |
+
|
623 |
+
t = torch.cat([t, t], dim=0)
|
624 |
+
|
625 |
+
v_model_out, a_model_out = self.forward(v_combined, a_combined, t, y)
|
626 |
+
v_eps = v_model_out
|
627 |
+
a_eps = a_model_out
|
628 |
+
|
629 |
+
v_cond_eps, v_uncond_eps = torch.split(v_eps, len(v_eps) // 2, dim=0)
|
630 |
+
v_eps = v_uncond_eps + cfg_scale * (v_cond_eps - v_uncond_eps)
|
631 |
+
|
632 |
+
a_cond_eps, a_uncond_eps = torch.split(a_eps, len(a_eps) // 2, dim=0)
|
633 |
+
a_eps = a_uncond_eps + cfg_scale * (a_cond_eps - a_uncond_eps)
|
634 |
+
|
635 |
+
return v_eps, a_eps
|
636 |
+
|
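# Note on forward_with_cfg: the input batch is duplicated so that the first half keeps the
# real labels y and the second half uses the null label (index == num_classes). For both
# the video and audio streams the two halves of the prediction are then blended as
#   eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond),
# so cfg_scale == 1.0 reduces to the purely conditional prediction.
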
#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/video_pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    video_pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    video_pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        video_pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), video_pos_embed], axis=0)
    return video_pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

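# Usage note (illustrative): get_2d_sincos_pos_embed(embed_dim=768, grid_size=16) returns a
# (16 * 16, 768) numpy array; with cls_token=True and extra_tokens=1 an all-zero row is
# prepended, giving (1 + 16 * 16, 768). FLAV itself uses the learnable positional parameters
# defined in __init__, so these helpers are kept for fixed sin/cos embeddings.
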
#################################################################################
#                                  FLAV Configs                                 #
#################################################################################

def FLAV_XL_2(**kwargs):
    return FLAV(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def FLAV_XL_4(**kwargs):
    return FLAV(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def FLAV_XL_8(**kwargs):
    return FLAV(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

# def FLAV_L_2(**kwargs):
#     return FLAV(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def FLAV_L_1(**kwargs):
    return FLAV(depth=24, hidden_size=1024, patch_size=1, num_heads=16, **kwargs)

def FLAV_L_2(**kwargs):
    return FLAV(depth=20, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def FLAV_L_4(**kwargs):
    return FLAV(depth=20, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def FLAV_L_8(**kwargs):
    return FLAV(depth=20, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

# def FLAV_B_2(**kwargs):
#     return FLAV(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def FLAV_B_1(**kwargs):
    return FLAV(depth=12, hidden_size=768, patch_size=1, num_heads=12, **kwargs)

def FLAV_B_2(**kwargs):
    return FLAV(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def FLAV_B_4(**kwargs):
    return FLAV(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def FLAV_B_8(**kwargs):
    return FLAV(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def FLAV_S_2(**kwargs):
    return FLAV(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def FLAV_S_4(**kwargs):
    return FLAV(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)

def FLAV_S_8(**kwargs):
    return FLAV(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


FLAV_models = {
    'FLAV-XL/2': FLAV_XL_2, 'FLAV-XL/4': FLAV_XL_4, 'FLAV-XL/8': FLAV_XL_8,
    'FLAV-L/1': FLAV_L_1, 'FLAV-L/2': FLAV_L_2, 'FLAV-L/4': FLAV_L_4, 'FLAV-L/8': FLAV_L_8,
    'FLAV-B/1': FLAV_B_1, 'FLAV-B/2': FLAV_B_2, 'FLAV-B/4': FLAV_B_4, 'FLAV-B/8': FLAV_B_8,
    'FLAV-S/2': FLAV_S_2, 'FLAV-S/4': FLAV_S_4, 'FLAV-S/8': FLAV_S_8,
}
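# Minimal usage sketch (a guarded smoke test, not invoked on import), assuming the
# FLAVBlock / PatchEmbed / TimestepEmbedder definitions earlier in this file behave as in
# standard DiT-style implementations. Shapes follow the forward() docstring: each frame
# carries n_mels=256 mel bins over 10 spectrogram columns.
if __name__ == "__main__":
    B, T = 2, 4                          # batch size and number of frames to predict
    model = FLAV_models['FLAV-S/2'](
        latent_size=32,                  # 256 px frames through an 8x downsampling VAE
        in_channels=4,
        num_classes=0,                   # unconditional: y is ignored in forward()
        predict_frames=T,
        causal_attn=True,
    )
    v = torch.randn(B, T, 4, 32, 32)     # video latents
    a = torch.randn(B, T, 1, 256, 10)    # per-frame mel-spectrogram chunks
    t = torch.rand(B, T)                 # one flow timestep per frame
    v_out, a_out = model(v, a, t, y=None)
    print(v_out.shape, a_out.shape)      # expected: (B, T, 4, 32, 32) and (B, T, 1, 256, 10)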
utils.py
ADDED
@@ -0,0 +1,267 @@
# from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
# # from moviepy.audio.AudioClip import AudioArrayClip
# from moviepy.audio.io.AudioFileClip import AudioFileClip
from torch.utils.data import DataLoader
from dataset import AudioVideoDataset, LatentDataset
import torch as th
import numpy as np

import einops
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from diffusers.models import AutoencoderKL

from converter import denormalize, denormalize_spectrogram

import soundfile as sf
import os
import json
import torch
from tqdm import tqdm
#################################################################################
#                                  Video Utils                                  #
#################################################################################


def preprocess_video(video):
    # video = 255*(video+1)/2.0 # [-1,1] -> [0,1] -> [0,255]
    # video = th.clamp(video, 0, 255).to(dtype=th.uint8, device="cuda")
    video = out2img(video)
    video = einops.rearrange(video, 't c h w -> t h w c').cpu().numpy()
    return video

def preprocess_video_batch(videos):
    B = videos.shape[0]
    videos_prep = np.empty(B, dtype=np.ndarray)
    for b in range(B):
        videos_prep[b] = preprocess_video(videos[b])
    videos_prep = np.stack(videos_prep, axis=0)
    return videos_prep

def save_latents(video, audio, y, output_path, name_prefix, ext=".pt"):
    os.makedirs(output_path, exist_ok=True)
    th.save(
        {
            "video": video,
            "audio": audio,
            "y": y
        }, os.path.join(output_path, name_prefix + ext))

def save_multimodal(video, audio, output_path, name_prefix, video_fps=10, audio_fps=16000, audio_dir=None):
    if not audio_dir:
        audio_dir = output_path

    # prepare folders
    audio_dir = os.path.join(audio_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = os.path.join(audio_dir, name_prefix + "_audio.wav")

    video_dir = os.path.join(output_path, "video")
    os.makedirs(video_dir, exist_ok=True)
    video_path = os.path.join(video_dir, name_prefix + "_video.mp4")

    # save audio
    sf.write(audio_path, audio, samplerate=audio_fps)

    # save video
    video = preprocess_video(video)

    imgs = [img for img in video]
    video_clip = ImageSequenceClip(imgs, fps=video_fps)
    audio_clip = AudioFileClip(audio_path)
    video_clip = video_clip.with_audio(audio_clip)
    video_clip.write_videofile(video_path, video_fps, audio=True, audio_fps=audio_fps)

def get_dataloader(args, logger, sequence_length, train, latents=False):
    if latents:
        train_set = LatentDataset(args.data_path, train=train)
    else:
        train_set = AudioVideoDataset(
            args.data_path,
            train=train,
            sample_every_n_frames=1,
            resolution=args.image_size,
            sequence_length=sequence_length,
            audio_channels=1,
            sample_rate=16000,
            min_length=1,
            ignore_cache=args.ignore_cache,
            labeled=args.num_classes > 0,
            target_video_fps=args.target_video_fps,
        )
    loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if logger is not None:
        logger.info(f'{"Train" if train else "Test"} dataset contains {len(train_set)} samples ({args.data_path})')
    else:
        print(f'{"Train" if train else "Test"} dataset contains {len(train_set)} samples ({args.data_path})')
    return loader

@torch.no_grad()
def encode_video(video, vae, use_sd_vae=False):
    b, t, c, h, w = video.shape
    video = einops.rearrange(video, "b t c h w -> (b t) c h w")
    if use_sd_vae:
        video = vae.encode(video).latent_dist.sample().mul_(0.18215)
    else:
        video = vae.encode(video) * vae.cfg.scaling_factor
    video = einops.rearrange(video, "(b t) c h w -> b t c h w", t=t)
    return video

@torch.no_grad()
def decode_video(video, vae):
    b = video.shape[0]
    video_decoded = []
    video = einops.rearrange(video, "b t c h w -> (b t) c h w")

    # use minibatch to avoid memory error
    for i in range(0, video.shape[0], b):
        if isinstance(vae, AutoencoderKL):
            video_decoded.append(vae.decode(video[i:i+b] / 0.18215).sample.detach().cpu())
        else:
            video_decoded.append(vae.decode(video[i:i+b] / vae.cfg.scaling_factor).detach().cpu())

    video = torch.cat(video_decoded, dim=0)
    video = einops.rearrange(video, "(b t) c h w -> b t c h w", b=b)
    return video

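# Note: encode_video and decode_video must agree on the latent scaling. For the SD VAE the
# latents are multiplied by 0.18215 at encode time and divided by the same constant before
# decoding; for the alternative VAE the symmetric vae.cfg.scaling_factor is used instead.
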
def generate_sample(vae,
                    rectified_flow,
                    forward_fn,
                    video_length,
                    video_latent_size,
                    audio_latent_size,
                    y,
                    cfg_scale,
                    device):

    with torch.no_grad():
        v_z = torch.randn(video_latent_size, device=device) * rectified_flow.noise_scale
        a_z = torch.randn(audio_latent_size, device=device) * rectified_flow.noise_scale

        model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)

        sample_fn = rectified_flow.sample(
            forward_fn, v_z, a_z, model_kwargs=model_kwargs, progress=True)()

        video = []
        audio = []
        for _ in tqdm(range(video_length), desc="Generating frames"):
            video_samples, audio_samples = next(sample_fn)

            video.append(video_samples)
            audio.append(audio_samples)

        video = torch.stack(video, dim=1)
        audio = torch.stack(audio, dim=1)

        video = decode_video(video, vae)
        audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")

        return video, audio

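# Note: rectified_flow.sample(...) (and the *_a2v / *_v2a variants below) returns a callable
# whose generator yields one frame-aligned chunk per next() call; generate_sample therefore
# collects video_length chunks, stacks them along a new time dimension (dim=1), and finally
# flattens the audio chunks back into a single spectrogram with
# einops.rearrange(audio, "B T C N F -> B C N (T F)").
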
def generate_sample_a2v(vae,
                        rectified_flow,
                        forward_fn,
                        video_length,
                        video_latent_size,
                        audio,
                        y,
                        device,
                        cfg_scale=1,
                        scale=1):

    v_z = torch.randn(video_latent_size, device=device) * rectified_flow.noise_scale

    model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)

    sample_fn = rectified_flow.sample_a2v(
        forward_fn, v_z, audio, model_kwargs=model_kwargs, scale=scale, progress=True)()

    video = []
    for i in tqdm(range(video_length), desc="Generating frames"):
        video_samples = next(sample_fn)

        video.append(video_samples)

    video = torch.stack(video, dim=1)

    video = decode_video(video, vae)
    audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")

    return video, audio

def generate_sample_v2a(vae,
                        rectified_flow,
                        forward_fn,
                        video_length,
                        video,
                        audio_latent_size,
                        y,
                        device,
                        cfg_scale=1,
                        scale=1):

    a_z = torch.randn(audio_latent_size, device=device) * rectified_flow.noise_scale

    model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)

    sample_fn = rectified_flow.sample_v2a(
        forward_fn, video, a_z, model_kwargs=model_kwargs, scale=scale, progress=True)()

    audio = []
    for i in tqdm(range(video_length), desc="Generating frames"):
        audio_samples = next(sample_fn)

        audio.append(audio_samples)

    audio = torch.stack(audio, dim=1)

    video = decode_video(video, vae)
    audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")

    return video, audio

def dict_to_json(path, args):
    with open(path, 'w') as f:
        json.dump(args.__dict__, f, indent=2)

def json_to_dict(path, args):
    with open(path, 'r') as f:
        args.__dict__ = json.load(f)
    return args

def log_args(args, logger):
    text = ""
    for k, v in vars(args).items():
        text += f'{k}={v}\n'
    logger.info(f"##### ARGS #####\n{text}")

def out2img(samples):
    return th.clamp(127.5 * samples + 128.0, 0, 255).to(
        dtype=th.uint8
    ).cuda()

def get_gpu_usage():
    device = th.device('cuda:0')
    free, total = th.cuda.mem_get_info(device)
    mem_used_MB = (total - free) / 1024 ** 2
    return mem_used_MB

def get_wavs(norm_spec, vocoder, audio_scale, device):
    norm_spec = norm_spec.squeeze(1)
    norm_spec = norm_spec / audio_scale
    post_norm_spec = denormalize(norm_spec).to(device)
    raw_chunk_spec = denormalize_spectrogram(post_norm_spec)
    wavs = vocoder.inference(raw_chunk_spec)
    return wavs
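# Shape sketch (illustrative, with dummy tensors and hypothetical sizes): how the per-frame
# audio chunks produced during sampling are reassembled into one spectrogram, mirroring the
# rearrange used in generate_sample above. Guarded so it only runs when executed directly.
if __name__ == "__main__":
    B, T, C, N, F = 1, 10, 1, 256, 10
    chunks = [torch.randn(B, C, N, F) for _ in range(T)]   # one chunk per generated frame
    audio = torch.stack(chunks, dim=1)                     # (B, T, C, N, F)
    audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
    print(audio.shape)                                     # torch.Size([1, 1, 256, 100])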