Spaces:
Running
Running
WeixuanYuan
commited on
Commit
•
ae1bdf7
1
Parent(s):
39653fc
Upload 66 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +27 -0
- app.py +107 -0
- app_chat.py +7 -0
- metrics/FD.py +293 -0
- metrics/IS.py +218 -0
- metrics/P_C_T.py +12 -0
- metrics/get_reference_AST_features.py +63 -0
- metrics/pipelines.py +144 -0
- metrics/pipelines_STFT.py +100 -0
- metrics/precision_recall.py +204 -0
- metrics/visualizations.py +123 -0
- model/DiffSynthSampler.py +425 -0
- model/GAN.py +262 -0
- model/VQGAN.py +684 -0
- model/__pycache__/DiffSynthSampler.cpython-310.pyc +0 -0
- model/__pycache__/GAN.cpython-310.pyc +0 -0
- model/__pycache__/VQGAN.cpython-310.pyc +0 -0
- model/__pycache__/diffusion.cpython-310.pyc +0 -0
- model/__pycache__/diffusion_components.cpython-310.pyc +0 -0
- model/__pycache__/multimodal_model.cpython-310.pyc +0 -0
- model/__pycache__/perceptual_label_predictor.cpython-37.pyc +0 -0
- model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc +0 -0
- model/diffusion.py +371 -0
- model/diffusion_components.py +351 -0
- model/multimodal_model.py +274 -0
- model/timbre_encoder_pretrain.py +220 -0
- models/24_1_2024-52_4x_L_D_imageVQVAE.pth +3 -0
- models/24_1_2024-52_4x_L_D_imageVQVAE_discriminator.pth +3 -0
- models/24_1_2024_MMM.pth +3 -0
- models/24_1_2024_STFT_timbre_encoder.pth +3 -0
- models/history/28_1_2024_TE_STFT_300000_UNet.pth +3 -0
- requirements.txt +15 -0
- tools.py +344 -0
- webUI/__pycache__/app.cpython-310.pyc +0 -0
- webUI/deprecated/interpolationWithCondition.py +178 -0
- webUI/deprecated/interpolationWithXT.py +173 -0
- webUI/natural_language_guided/GAN.py +164 -0
- webUI/natural_language_guided/README.py +53 -0
- webUI/natural_language_guided/__pycache__/README.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc +0 -0
- webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc +0 -0
Dockerfile
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 使用 Python 3.9 作为基础镜像
|
2 |
+
FROM python:3.9
|
3 |
+
|
4 |
+
# 添加用户
|
5 |
+
RUN useradd -m -u 1000 user
|
6 |
+
|
7 |
+
# 设置工作目录
|
8 |
+
WORKDIR /app
|
9 |
+
|
10 |
+
# 切换到 root 用户以安装系统依赖
|
11 |
+
USER root
|
12 |
+
RUN apt-get update && apt-get install -y rubberband-cli
|
13 |
+
|
14 |
+
# 切回到普通用户
|
15 |
+
USER user
|
16 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
17 |
+
|
18 |
+
# 复制并安装 Python 依赖
|
19 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
20 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
21 |
+
|
22 |
+
# 复制应用文件
|
23 |
+
COPY --chown=user . /app
|
24 |
+
|
25 |
+
# 启动 Gradio 应用,假设应用入口文件为 app.py
|
26 |
+
# CMD ["python", "app.py"]
|
27 |
+
CMD ["python", "app_chat.py"]
|
app.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
from model.DiffSynthSampler import DiffSynthSampler
|
5 |
+
import soundfile as sf
|
6 |
+
# import pyrubberband as pyrb
|
7 |
+
from tqdm import tqdm
|
8 |
+
from model.VQGAN import get_VQGAN
|
9 |
+
from model.diffusion import get_diffusion_model
|
10 |
+
from transformers import AutoTokenizer, ClapModel
|
11 |
+
from model.diffusion_components import linear_beta_schedule
|
12 |
+
from model.timbre_encoder_pretrain import get_timbre_encoder
|
13 |
+
from model.multimodal_model import get_multi_modal_model
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
from webUI.natural_language_guided.gradio_webUI import GradioWebUI
|
19 |
+
from webUI.natural_language_guided.text2sound import get_text2sound_module
|
20 |
+
from webUI.natural_language_guided.sound2sound_with_text import get_sound2sound_with_text_module
|
21 |
+
from webUI.natural_language_guided.inpaint_with_text import get_inpaint_with_text_module
|
22 |
+
from webUI.natural_language_guided.build_instrument import get_build_instrument_module
|
23 |
+
from webUI.natural_language_guided.README import get_readme_module
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
+
use_pretrained_CLAP = False
|
29 |
+
|
30 |
+
# load VQ-GAN
|
31 |
+
VAE_model_name = "24_1_2024-52_4x_L_D"
|
32 |
+
modelConfig = {"in_channels": 3, "hidden_channels": [80, 160], "embedding_dim": 4, "out_channels": 3, "block_depth": 2,
|
33 |
+
"attn_pos": [80, 160], "attn_with_skip": True,
|
34 |
+
"num_embeddings": 8192, "commitment_cost": 0.25, "decay": 0.99,
|
35 |
+
"norm_type": "groupnorm", "act_type": "swish", "num_groups": 16}
|
36 |
+
VAE = get_VQGAN(modelConfig, load_pretrain=True, model_name=VAE_model_name, device=device)
|
37 |
+
|
38 |
+
# load U-Net
|
39 |
+
UNet_model_name = "history/28_1_2024_CLAP_STFT_180000" if use_pretrained_CLAP else "history/28_1_2024_TE_STFT_300000"
|
40 |
+
unetConfig = {"in_dim": 4, "down_dims": [96, 96, 192, 384], "up_dims": [384, 384, 192, 96], "attn_type": "linear_add", "condition_type": "natural_language_prompt", "label_emb_dim": 512}
|
41 |
+
uNet = get_diffusion_model(unetConfig, load_pretrain=True, model_name=UNet_model_name, device=device)
|
42 |
+
|
43 |
+
# load LM
|
44 |
+
CLAP_temp = ClapModel.from_pretrained("laion/clap-htsat-unfused") # 153,492,890
|
45 |
+
CLAP_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
|
46 |
+
|
47 |
+
timbre_encoder_name = "24_1_2024_STFT"
|
48 |
+
timbre_encoder_Config = {"input_dim": 512, "feature_dim": 512, "hidden_dim": 1024, "num_instrument_classes": 1006, "num_instrument_family_classes": 11, "num_velocity_classes": 128, "num_qualities": 10, "num_layers": 3}
|
49 |
+
timbre_encoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name, device=device)
|
50 |
+
|
51 |
+
if use_pretrained_CLAP:
|
52 |
+
text_encoder = CLAP_temp
|
53 |
+
else:
|
54 |
+
multimodalmodel_name = "24_1_2024"
|
55 |
+
multimodalmodel_config = {"text_feature_dim": 512, "spectrogram_feature_dim": 1024, "multi_modal_emb_dim": 512, "num_projection_layers": 2,
|
56 |
+
"temperature": 1.0, "dropout": 0.1, "freeze_text_encoder": False, "freeze_spectrogram_encoder": False}
|
57 |
+
mmm = get_multi_modal_model(timbre_encoder, CLAP_temp, multimodalmodel_config, load_pretrain=True, model_name=multimodalmodel_name, device=device)
|
58 |
+
|
59 |
+
text_encoder = mmm.to("cpu")
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
gradioWebUI = GradioWebUI(device, VAE, uNet, text_encoder, CLAP_tokenizer, freq_resolution=512, time_resolution=256, channels=4, timesteps=1000, squared=False,
|
67 |
+
VAE_scale=4, flexible_duration=True, noise_strategy="repeat", GAN_generator=None)
|
68 |
+
|
69 |
+
with gr.Blocks(theme=gr.themes.Soft(), mode="dark") as demo:
|
70 |
+
# with gr.Blocks(theme='WeixuanYuan/Soft_dark', mode="dark") as demo:
|
71 |
+
gr.Markdown("DiffuSynth v0.2")
|
72 |
+
|
73 |
+
reconstruction_state = gr.State(value={})
|
74 |
+
text2sound_state = gr.State(value={})
|
75 |
+
sound2sound_state = gr.State(value={})
|
76 |
+
inpaint_state = gr.State(value={})
|
77 |
+
super_resolution_state = gr.State(value={})
|
78 |
+
virtual_instruments_state = gr.State(value={"virtual_instruments": {}})
|
79 |
+
|
80 |
+
get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state)
|
81 |
+
get_sound2sound_with_text_module(gradioWebUI, sound2sound_state, virtual_instruments_state)
|
82 |
+
get_inpaint_with_text_module(gradioWebUI, inpaint_state, virtual_instruments_state)
|
83 |
+
get_build_instrument_module(gradioWebUI, virtual_instruments_state)
|
84 |
+
get_readme_module()
|
85 |
+
|
86 |
+
demo.launch(debug=True, share=True)
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
|
app_chat.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
def greet(name):
|
4 |
+
return "Hello " + name + "!!"
|
5 |
+
|
6 |
+
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
+
demo.launch()
|
metrics/FD.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
from scipy.linalg import sqrtm
|
9 |
+
|
10 |
+
from metrics.pipelines import sample_pipeline, sample_pipeline_GAN
|
11 |
+
from metrics.pipelines_STFT import sample_pipeline_STFT, sample_pipeline_GAN_STFT
|
12 |
+
from tools import rms_normalize
|
13 |
+
|
14 |
+
|
15 |
+
def ASTaudio2feature(device, signal, processor, AST, sampling_rate):
|
16 |
+
# audio file is decoded on the fly
|
17 |
+
inputs = processor(signal, sampling_rate=sampling_rate, return_tensors="pt").to(device)
|
18 |
+
with torch.no_grad():
|
19 |
+
outputs = AST(**inputs)
|
20 |
+
|
21 |
+
last_hidden_states = outputs.last_hidden_state[:, 0, :].to("cpu").detach().numpy()
|
22 |
+
return last_hidden_states
|
23 |
+
|
24 |
+
|
25 |
+
# 计算两个numpy数组的均值和协方差矩阵
|
26 |
+
def calculate_statistics(features):
|
27 |
+
mu = np.mean(features, axis=0)
|
28 |
+
sigma = np.cov(features, rowvar=False)
|
29 |
+
return mu, sigma
|
30 |
+
|
31 |
+
|
32 |
+
# 计算FID
|
33 |
+
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
34 |
+
# 在协方差矩阵对角线上添加一个小的正值
|
35 |
+
sigma1 += np.eye(sigma1.shape[0]) * eps
|
36 |
+
sigma2 += np.eye(sigma2.shape[0]) * eps
|
37 |
+
|
38 |
+
ssdiff = np.sum((mu1 - mu2) ** 2.0)
|
39 |
+
covmean = sqrtm(sigma1.dot(sigma2))
|
40 |
+
|
41 |
+
# 由于数值问题,有时可能会得到复数,只取实部
|
42 |
+
if np.iscomplexobj(covmean):
|
43 |
+
covmean = covmean.real
|
44 |
+
|
45 |
+
fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
|
46 |
+
return fid
|
47 |
+
|
48 |
+
|
49 |
+
# 计算FID
|
50 |
+
def calculate_fid_dict(dict1, dict2, eps=1e-6):
|
51 |
+
# 在协方差矩阵对角线上添加一个小的正值
|
52 |
+
mu1, sigma1 = dict1["mu"], dict1["sigma"]
|
53 |
+
mu2, sigma2 = dict2["mu"], dict2["sigma"]
|
54 |
+
sigma1 += np.eye(sigma1.shape[0]) * eps
|
55 |
+
sigma2 += np.eye(sigma2.shape[0]) * eps
|
56 |
+
|
57 |
+
ssdiff = np.sum((mu1 - mu2) ** 2.0)
|
58 |
+
covmean = sqrtm(sigma1.dot(sigma2))
|
59 |
+
|
60 |
+
# 由于数值问题,有时可能会得到复数,只取实部
|
61 |
+
if np.iscomplexobj(covmean):
|
62 |
+
covmean = covmean.real
|
63 |
+
|
64 |
+
fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
|
65 |
+
return fid
|
66 |
+
|
67 |
+
|
68 |
+
# Todo: AudioLDM
|
69 |
+
# def generate_features_with_AudioLDM_and_AST(device, processor, AST, AudioLDM_signals_directory_path, return_feature=False):
|
70 |
+
|
71 |
+
# diffuSynth_features = []
|
72 |
+
|
73 |
+
# # Step 1: Load all wav files in AudioLDM_signals_directory_path
|
74 |
+
# AudioLDM_signals = []
|
75 |
+
# signal_lengths = set()
|
76 |
+
|
77 |
+
# for file_name in os.listdir(AudioLDM_signals_directory_path):
|
78 |
+
# if file_name.endswith('.wav'):
|
79 |
+
# file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
|
80 |
+
# signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000
|
81 |
+
# # Normalize
|
82 |
+
# AudioLDM_signals.append(rms_normalize(signal))
|
83 |
+
# signal_lengths.add(len(signal))
|
84 |
+
|
85 |
+
# # Step 2: Check if all signals have the same length
|
86 |
+
# if len(signal_lengths) != 1:
|
87 |
+
# raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
|
88 |
+
|
89 |
+
# # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
|
90 |
+
# batch_size = 8
|
91 |
+
# signal_length = signal_lengths.pop() # All lengths are the same, get one of them
|
92 |
+
|
93 |
+
# # Create batches
|
94 |
+
# signal_batches = [AudioLDM_signals[i:i + batch_size] for i in range(0, len(AudioLDM_signals), batch_size)]
|
95 |
+
|
96 |
+
# for signal_batch in tqdm(signal_batches):
|
97 |
+
|
98 |
+
# features = ASTaudio2feature(device, signal_batch, processor, AST, sampling_rate=16000)
|
99 |
+
# diffuSynth_features.extend(features)
|
100 |
+
|
101 |
+
# if return_feature:
|
102 |
+
# return diffuSynth_features
|
103 |
+
# else:
|
104 |
+
# mu, sigma = calculate_statistics(diffuSynth_features)
|
105 |
+
# return {"mu": mu, "sigma": sigma}
|
106 |
+
|
107 |
+
def generate_features_with_AudioLDM_and_AST(device, processor, AST, AudioLDM_signals_directory_path, return_feature=False):
|
108 |
+
|
109 |
+
diffuSynth_features = []
|
110 |
+
|
111 |
+
# Step 1: Load all wav files in AudioLDM_signals_directory_path
|
112 |
+
AudioLDM_signals = []
|
113 |
+
signal_lengths = set()
|
114 |
+
target_length = 4 * 16000 # 4 seconds * 16000 samples per second
|
115 |
+
|
116 |
+
for file_name in os.listdir(AudioLDM_signals_directory_path):
|
117 |
+
if file_name.endswith('.wav') and not file_name.startswith('._'):
|
118 |
+
file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
|
119 |
+
try:
|
120 |
+
signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000
|
121 |
+
if len(signal) >= target_length:
|
122 |
+
signal = signal[:target_length] # Take only the first 4 seconds
|
123 |
+
else:
|
124 |
+
raise ValueError(f"The file {file_name} is shorter than 4 seconds.")
|
125 |
+
# Normalize
|
126 |
+
AudioLDM_signals.append(rms_normalize(signal))
|
127 |
+
signal_lengths.add(len(signal))
|
128 |
+
except Exception as e:
|
129 |
+
print(f"Error loading {file_name}: {e}")
|
130 |
+
|
131 |
+
# Step 2: Check if all signals have the same length
|
132 |
+
if len(signal_lengths) != 1:
|
133 |
+
raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
|
134 |
+
|
135 |
+
# Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
|
136 |
+
batch_size = 8
|
137 |
+
signal_length = signal_lengths.pop() # All lengths are the same, get one of them
|
138 |
+
|
139 |
+
# Create batches
|
140 |
+
signal_batches = [AudioLDM_signals[i:i + batch_size] for i in range(0, len(AudioLDM_signals), batch_size)]
|
141 |
+
|
142 |
+
for signal_batch in tqdm(signal_batches):
|
143 |
+
features = ASTaudio2feature(device, signal_batch, processor, AST, sampling_rate=16000)
|
144 |
+
diffuSynth_features.extend(features)
|
145 |
+
|
146 |
+
if return_feature:
|
147 |
+
return diffuSynth_features
|
148 |
+
else:
|
149 |
+
mu, sigma = calculate_statistics(diffuSynth_features)
|
150 |
+
return {"mu": mu, "sigma": sigma}
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
def generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches,
|
156 |
+
positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms", return_feature=False):
|
157 |
+
diffuSynth_features = []
|
158 |
+
|
159 |
+
if task == "spectrograms":
|
160 |
+
pipe = sample_pipeline
|
161 |
+
elif task == "STFT":
|
162 |
+
pipe = sample_pipeline_STFT
|
163 |
+
else:
|
164 |
+
raise NotImplementedError
|
165 |
+
|
166 |
+
for _ in tqdm(range(num_batches)):
|
167 |
+
quantized_latent_representations, reconstruction_batch, signals = pipe(device, uNet, VAE, mmm,
|
168 |
+
CLAP_tokenizer,
|
169 |
+
positive_prompts=positive_prompts,
|
170 |
+
negative_prompts=negative_prompts,
|
171 |
+
batchsize=8,
|
172 |
+
sample_steps=sample_steps,
|
173 |
+
CFG=CFG, seed=None,
|
174 |
+
return_latent=False)
|
175 |
+
|
176 |
+
features = ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000)
|
177 |
+
diffuSynth_features.extend(features)
|
178 |
+
|
179 |
+
if return_feature:
|
180 |
+
return diffuSynth_features
|
181 |
+
else:
|
182 |
+
mu, sigma = calculate_statistics(diffuSynth_features)
|
183 |
+
return {"mu": mu, "sigma": sigma}
|
184 |
+
|
185 |
+
|
186 |
+
def generate_features_with_GAN_and_AST(device, gan_generator, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches,
|
187 |
+
positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms", return_feature=False):
|
188 |
+
diffuSynth_features = []
|
189 |
+
|
190 |
+
if task == "spectrograms":
|
191 |
+
pipe = sample_pipeline_GAN
|
192 |
+
elif task == "STFT":
|
193 |
+
pipe = sample_pipeline_GAN_STFT
|
194 |
+
else:
|
195 |
+
raise NotImplementedError
|
196 |
+
|
197 |
+
for _ in tqdm(range(num_batches)):
|
198 |
+
quantized_latent_representations, reconstruction_batch, signals = pipe(device, gan_generator, VAE, mmm,
|
199 |
+
CLAP_tokenizer,
|
200 |
+
positive_prompts=positive_prompts,
|
201 |
+
negative_prompts=negative_prompts,
|
202 |
+
batchsize=8,
|
203 |
+
sample_steps=sample_steps,
|
204 |
+
CFG=CFG, seed=None,
|
205 |
+
return_latent=False)
|
206 |
+
|
207 |
+
features = ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000)
|
208 |
+
diffuSynth_features.extend(features)
|
209 |
+
|
210 |
+
if return_feature:
|
211 |
+
return diffuSynth_features
|
212 |
+
else:
|
213 |
+
mu, sigma = calculate_statistics(diffuSynth_features)
|
214 |
+
return {"mu": mu, "sigma": sigma}
|
215 |
+
|
216 |
+
|
217 |
+
def get_FD(train_features, device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches, positive_prompts,
|
218 |
+
negative_prompts="", CFG=1, sample_steps=10):
|
219 |
+
diffuSynth_features = generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor,
|
220 |
+
AST, num_batches, positive_prompts,
|
221 |
+
negative_prompts=negative_prompts, CFG=CFG,
|
222 |
+
sample_steps=sample_steps)
|
223 |
+
|
224 |
+
mu_real, sigma_real = calculate_statistics(train_features)
|
225 |
+
mu_gen, sigma_gen = calculate_statistics(diffuSynth_features)
|
226 |
+
|
227 |
+
fid_score = calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
|
228 |
+
print('FID score:', fid_score)
|
229 |
+
|
230 |
+
|
231 |
+
def get_fid_score(feature1, features2):
|
232 |
+
mu_real, sigma_real = calculate_statistics(feature1)
|
233 |
+
mu_gen, sigma_gen = calculate_statistics(features2)
|
234 |
+
|
235 |
+
fid_score = calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
|
236 |
+
# print('FID score:', fid_score)
|
237 |
+
return fid_score
|
238 |
+
|
239 |
+
|
240 |
+
def calculate_fid_matrix(features_list_1, features_list_2, get_fid_score):
|
241 |
+
# 初始化一个矩阵来存储FID分数
|
242 |
+
# 矩阵的大小为 len(features_list_1) x len(features_list_2)
|
243 |
+
fid_scores = [[0 for _ in range(len(features_list_2))] for _ in range(len(features_list_1))]
|
244 |
+
|
245 |
+
# 遍历两个列表,并计算每一对特征集合的FID分数
|
246 |
+
for i, feature1 in enumerate(features_list_1):
|
247 |
+
for j, feature2 in enumerate(features_list_2):
|
248 |
+
fid_scores[i][j] = get_fid_score(feature1, feature2)
|
249 |
+
|
250 |
+
return fid_scores
|
251 |
+
|
252 |
+
|
253 |
+
def save_AST_feature(key, mu, sigma, path='results/AST_metric/pre_calculated_features/AST_features.json'):
|
254 |
+
# 尝试打开并读取现有的JSON文件
|
255 |
+
try:
|
256 |
+
with open(path, 'r') as file:
|
257 |
+
data = json.load(file)
|
258 |
+
except FileNotFoundError:
|
259 |
+
# 如果文件不存在,创建一个新的字典
|
260 |
+
data = {}
|
261 |
+
|
262 |
+
if isinstance(mu, np.ndarray):
|
263 |
+
mu = mu.tolist()
|
264 |
+
if isinstance(sigma, np.ndarray):
|
265 |
+
sigma = sigma.tolist()
|
266 |
+
|
267 |
+
# 添加新数据
|
268 |
+
data[key] = {"mu": mu, "sigma": sigma}
|
269 |
+
|
270 |
+
# 将更新后的数据写回文件
|
271 |
+
with open(path, 'w') as file:
|
272 |
+
json.dump(data, file, indent=4)
|
273 |
+
|
274 |
+
|
275 |
+
def read_AST_features(path='results/AST_metric/pre_calculated_features/AST_features.json'):
|
276 |
+
try:
|
277 |
+
# 尝试打开并读取JSON文件
|
278 |
+
with open(path, 'r') as file:
|
279 |
+
AST_features = json.load(file)
|
280 |
+
|
281 |
+
for AST_feature_name in AST_features.keys():
|
282 |
+
AST_features[AST_feature_name]["mu"] = np.array(AST_features[AST_feature_name]["mu"])
|
283 |
+
AST_features[AST_feature_name]["sigma"] = np.array(AST_features[AST_feature_name]["sigma"])
|
284 |
+
|
285 |
+
return AST_features
|
286 |
+
except FileNotFoundError:
|
287 |
+
# 如果文件不存在,返回一个空字典
|
288 |
+
print(f"文件 {path} 未找到.")
|
289 |
+
return {}
|
290 |
+
except json.JSONDecodeError:
|
291 |
+
# 如果文件不是有效的JSON,返回一个空字典
|
292 |
+
print(f"文件 {path} 不是有效的JSON格式.")
|
293 |
+
return {}
|
metrics/IS.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import librosa
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from metrics.pipelines import sample_pipeline, inpaint_pipeline, sample_pipeline_GAN
|
9 |
+
from metrics.pipelines_STFT import sample_pipeline_STFT, sample_pipeline_GAN_STFT
|
10 |
+
from tools import rms_normalize, pad_STFT, encode_stft
|
11 |
+
from webUI.natural_language_guided.utils import InputBatch2Encode_STFT
|
12 |
+
|
13 |
+
def get_inception_score_for_AudioLDM(device, timbre_encoder, VAE, AudioLDM_signals_directory_path):
|
14 |
+
VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
|
15 |
+
|
16 |
+
diffuSynth_probabilities = []
|
17 |
+
|
18 |
+
# Step 1: Load all wav files in AudioLDM_signals_directory_path
|
19 |
+
AudioLDM_signals = []
|
20 |
+
signal_lengths = set()
|
21 |
+
target_length = 4 * 16000 # 4 seconds * 16000 samples per second
|
22 |
+
|
23 |
+
for file_name in os.listdir(AudioLDM_signals_directory_path):
|
24 |
+
if file_name.endswith('.wav') and not file_name.startswith('._'):
|
25 |
+
file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
|
26 |
+
signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000
|
27 |
+
if len(signal) >= target_length:
|
28 |
+
signal = signal[:target_length] # Take only the first 4 seconds
|
29 |
+
else:
|
30 |
+
raise ValueError(f"The file {file_name} is shorter than 4 seconds.")
|
31 |
+
# Normalize
|
32 |
+
AudioLDM_signals.append(rms_normalize(signal))
|
33 |
+
signal_lengths.add(len(signal))
|
34 |
+
|
35 |
+
# Step 2: Check if all signals have the same length
|
36 |
+
if len(signal_lengths) != 1:
|
37 |
+
raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
|
38 |
+
|
39 |
+
encoded_audios = []
|
40 |
+
for origin_audio in AudioLDM_signals:
|
41 |
+
D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
|
42 |
+
padded_D = pad_STFT(D)
|
43 |
+
encoded_D = encode_stft(padded_D)
|
44 |
+
encoded_audios.append(encoded_D)
|
45 |
+
encoded_audios_np = np.array(encoded_audios)
|
46 |
+
origin_spectrogram_batch_tensor = torch.from_numpy(encoded_audios_np).float().to(device)
|
47 |
+
|
48 |
+
# Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
|
49 |
+
batch_size = 8
|
50 |
+
num_batches = int(np.ceil(origin_spectrogram_batch_tensor.shape[0] / batch_size))
|
51 |
+
spectrogram_batches = []
|
52 |
+
for i in range(num_batches):
|
53 |
+
batch = origin_spectrogram_batch_tensor[i * batch_size:(i + 1) * batch_size]
|
54 |
+
spectrogram_batches.append(batch)
|
55 |
+
|
56 |
+
for spectrogram_batch in tqdm(spectrogram_batches):
|
57 |
+
spectrogram_batch = spectrogram_batch.to(device)
|
58 |
+
_, _, _, _, quantized_latent_representations = InputBatch2Encode_STFT(VAE_encoder, spectrogram_batch, quantizer=VAE_quantizer, squared=False)
|
59 |
+
quantized_latent_representations = quantized_latent_representations
|
60 |
+
feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
|
61 |
+
probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
|
62 |
+
|
63 |
+
diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
|
64 |
+
|
65 |
+
return inception_score(np.array(diffuSynth_probabilities))
|
66 |
+
|
67 |
+
|
68 |
+
# def get_inception_score_for_AudioLDM(device, timbre_encoder, VAE, AudioLDM_signals_directory_path):
|
69 |
+
# VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
|
70 |
+
#
|
71 |
+
# diffuSynth_probabilities = []
|
72 |
+
#
|
73 |
+
# # Step 1: Load all wav files in AudioLDM_signals_directory_path
|
74 |
+
# AudioLDM_signals = []
|
75 |
+
# signal_lengths = set()
|
76 |
+
#
|
77 |
+
# for file_name in os.listdir(AudioLDM_signals_directory_path):
|
78 |
+
# if file_name.endswith('.wav'):
|
79 |
+
# file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
|
80 |
+
# signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000
|
81 |
+
# # Normalize
|
82 |
+
# AudioLDM_signals.append(rms_normalize(signal))
|
83 |
+
# signal_lengths.add(len(signal))
|
84 |
+
#
|
85 |
+
# # Step 2: Check if all signals have the same length
|
86 |
+
# if len(signal_lengths) != 1:
|
87 |
+
# raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
|
88 |
+
#
|
89 |
+
# encoded_audios = []
|
90 |
+
# for origin_audio in AudioLDM_signals:
|
91 |
+
# D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
|
92 |
+
# padded_D = pad_STFT(D)
|
93 |
+
# encoded_D = encode_stft(padded_D)
|
94 |
+
# encoded_audios.append(encoded_D)
|
95 |
+
# encoded_audios_np = np.array(encoded_audios)
|
96 |
+
# origin_spectrogram_batch_tensor = torch.from_numpy(encoded_audios_np).float().to(device)
|
97 |
+
#
|
98 |
+
#
|
99 |
+
# # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
|
100 |
+
# batch_size = 8
|
101 |
+
# num_batches = int(np.ceil(origin_spectrogram_batch_tensor.shape[0] / batch_size))
|
102 |
+
# spectrogram_batches = []
|
103 |
+
# for i in range(num_batches):
|
104 |
+
# batch = origin_spectrogram_batch_tensor[i * batch_size:(i + 1) * batch_size]
|
105 |
+
# spectrogram_batches.append(batch)
|
106 |
+
#
|
107 |
+
#
|
108 |
+
# for spectrogram_batch in tqdm(spectrogram_batches):
|
109 |
+
# spectrogram_batch = spectrogram_batch.to(device)
|
110 |
+
# _, _, _, _, quantized_latent_representations = InputBatch2Encode_STFT(VAE_encoder, spectrogram_batch, quantizer=VAE_quantizer,squared=False)
|
111 |
+
# quantized_latent_representations = quantized_latent_representations
|
112 |
+
# feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
|
113 |
+
# probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
|
114 |
+
#
|
115 |
+
# diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
|
116 |
+
#
|
117 |
+
# return inception_score(np.array(diffuSynth_probabilities))
|
118 |
+
|
119 |
+
|
120 |
+
def get_inception_score(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms"):
|
121 |
+
diffuSynth_probabilities = []
|
122 |
+
|
123 |
+
if task == "spectrograms":
|
124 |
+
pipe = sample_pipeline
|
125 |
+
elif task == "STFT":
|
126 |
+
pipe = sample_pipeline_STFT
|
127 |
+
else:
|
128 |
+
raise NotImplementedError
|
129 |
+
|
130 |
+
for _ in tqdm(range(num_batches)):
|
131 |
+
quantized_latent_representations = pipe(device, uNet, VAE, MMM, CLAP_tokenizer,
|
132 |
+
positive_prompts=positive_prompts, negative_prompts=negative_prompts,
|
133 |
+
batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None)
|
134 |
+
|
135 |
+
quantized_latent_representations = quantized_latent_representations.to(device)
|
136 |
+
feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
|
137 |
+
probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
|
138 |
+
|
139 |
+
diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
|
140 |
+
|
141 |
+
return inception_score(np.array(diffuSynth_probabilities))
|
142 |
+
|
143 |
+
|
144 |
+
def get_inception_score_GAN(device, gan_generator, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms"):
|
145 |
+
diffuSynth_probabilities = []
|
146 |
+
|
147 |
+
if task == "spectrograms":
|
148 |
+
pipe = sample_pipeline_GAN
|
149 |
+
elif task == "STFT":
|
150 |
+
pipe = sample_pipeline_GAN_STFT
|
151 |
+
else:
|
152 |
+
raise NotImplementedError
|
153 |
+
|
154 |
+
for _ in tqdm(range(num_batches)):
|
155 |
+
quantized_latent_representations = pipe(device, gan_generator, VAE, MMM, CLAP_tokenizer,
|
156 |
+
positive_prompts=positive_prompts, negative_prompts=negative_prompts,
|
157 |
+
batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None)
|
158 |
+
|
159 |
+
quantized_latent_representations = quantized_latent_representations.to(device)
|
160 |
+
feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
|
161 |
+
probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
|
162 |
+
|
163 |
+
diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
|
164 |
+
|
165 |
+
return inception_score(np.array(diffuSynth_probabilities))
|
166 |
+
|
167 |
+
|
168 |
+
def predict_qualities_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
|
169 |
+
diffuSynth_qualities = []
|
170 |
+
for _ in tqdm(range(num_batches)):
|
171 |
+
quantized_latent_representations = sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
|
172 |
+
positive_prompts=positive_prompts, negative_prompts=negative_prompts,
|
173 |
+
batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None)
|
174 |
+
|
175 |
+
quantized_latent_representations = quantized_latent_representations.to(device)
|
176 |
+
feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
|
177 |
+
qualities = qualities.to("cpu").detach().numpy()
|
178 |
+
# qualities = np.where(qualities > 0.5, 1, 0)
|
179 |
+
|
180 |
+
diffuSynth_qualities.extend(qualities)
|
181 |
+
|
182 |
+
return np.mean(diffuSynth_qualities, axis=0)
|
183 |
+
|
184 |
+
|
185 |
+
def generate_probabilities_with_diffuSynth_inpaint(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, guidance, duration, use_dynamic_mask, noising_strength, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
|
186 |
+
|
187 |
+
inpaint_probabilities, signals = [], []
|
188 |
+
for _ in tqdm(range(num_batches)):
|
189 |
+
quantized_latent_representations, _, rec_signals = inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
|
190 |
+
use_dynamic_mask=use_dynamic_mask, noising_strength=noising_strength, guidance=guidance,
|
191 |
+
positive_prompts=positive_prompts, negative_prompts=negative_prompts, batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None, duration=duration, mask_flexivity=0.999,
|
192 |
+
return_latent=False)
|
193 |
+
|
194 |
+
quantized_latent_representations = quantized_latent_representations.to(device)
|
195 |
+
feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
|
196 |
+
probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
|
197 |
+
|
198 |
+
inpaint_probabilities.extend(probabilities.to("cpu").detach().numpy())
|
199 |
+
signals.extend(rec_signals)
|
200 |
+
|
201 |
+
return np.array(inpaint_probabilities), signals
|
202 |
+
|
203 |
+
|
204 |
+
def inception_score(pred):
|
205 |
+
|
206 |
+
# 计算每个图像的条件概率分布 P(y|x)
|
207 |
+
pyx = pred / np.sum(pred, axis=1, keepdims=True)
|
208 |
+
|
209 |
+
# 计算整个数据集的边缘概率分布 P(y)
|
210 |
+
py = np.mean(pyx, axis=0, keepdims=True)
|
211 |
+
|
212 |
+
# 计算KL散度
|
213 |
+
kl_div = pyx * (np.log(pyx + 1e-11) - np.log(py + 1e-11))
|
214 |
+
|
215 |
+
# 对所有图像求和并平均
|
216 |
+
kl_div_sum = np.sum(kl_div, axis=1)
|
217 |
+
score = np.exp(np.mean(kl_div_sum))
|
218 |
+
return score
|
metrics/P_C_T.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from metrics.precision_recall import knn_precision_recall_features
|
3 |
+
|
4 |
+
|
5 |
+
# 生成样本
|
6 |
+
real_features = np.random.normal(0, 1, size=(1600, 512))
|
7 |
+
generated_features = np.random.normal(0, 1, size=(1600, 512))
|
8 |
+
|
9 |
+
state = knn_precision_recall_features(real_features, generated_features, nhood_sizes=[1, 2, 3, 4, 5, 10],
|
10 |
+
row_batch_size=16, col_batch_size=16)
|
11 |
+
|
12 |
+
print(state)
|
metrics/get_reference_AST_features.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import librosa
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
from metrics.FD import ASTaudio2feature, calculate_statistics, save_AST_feature
|
6 |
+
from tools import rms_normalize
|
7 |
+
from transformers import AutoProcessor, ASTModel
|
8 |
+
|
9 |
+
device = "cpu"
|
10 |
+
processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
|
11 |
+
AST = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").to(device)
|
12 |
+
|
13 |
+
|
14 |
+
data_split = "train"
|
15 |
+
with open(f'data/NSynth/{data_split}_examples.json') as f:
|
16 |
+
data = json.load(f)
|
17 |
+
|
18 |
+
def read_signal(note_str):
|
19 |
+
y, sr = librosa.load(f"data/NSynth/nsynth-{data_split}-52/audio/{note_str}.wav", sr=16000)
|
20 |
+
if len(y) >= 64000:
|
21 |
+
y = y[:64000]
|
22 |
+
else:
|
23 |
+
y_extend = [0.0] * 64000
|
24 |
+
y_extend[:len(y)] = y
|
25 |
+
y = y_extend
|
26 |
+
|
27 |
+
return rms_normalize(y)
|
28 |
+
|
29 |
+
for quality in ["bright", "dark", "distortion", "fast_decay", "long_release", "multiphonic", "nonlinear_env", "percussive", "reverb", "tempo-synced"]:
|
30 |
+
features = []
|
31 |
+
for i, (note_str, attributes) in tqdm(enumerate(data.items())):
|
32 |
+
if not attributes["pitch"] == 52:
|
33 |
+
continue
|
34 |
+
if not (quality in attributes['qualities_str']):
|
35 |
+
continue
|
36 |
+
|
37 |
+
signal = read_signal(note_str)
|
38 |
+
feature_for_one_signal = ASTaudio2feature(device, [signal], processor, AST, sampling_rate=16000)[0]
|
39 |
+
features.append(feature_for_one_signal)
|
40 |
+
|
41 |
+
mu, sigma = calculate_statistics(features)
|
42 |
+
print(np.shape(mu))
|
43 |
+
print(np.shape(sigma))
|
44 |
+
|
45 |
+
save_AST_feature(f'{data_split}_{quality}', mu.tolist(), sigma.tolist())
|
46 |
+
|
47 |
+
for instrument_name in ["bass", "brass", "flute", "guitar", "keyboard", "mallet", "organ", "reed", "string", "synth_lead", "vocal"]:
|
48 |
+
features = []
|
49 |
+
for i, (note_str, attributes) in tqdm(enumerate(data.items())):
|
50 |
+
if not attributes["pitch"] == 52:
|
51 |
+
continue
|
52 |
+
if not (attributes["instrument_family_str"] == instrument_name):
|
53 |
+
continue
|
54 |
+
|
55 |
+
signal = read_signal(note_str)
|
56 |
+
feature_for_one_signal = ASTaudio2feature(device, [signal], processor, AST, sampling_rate=16000)[0]
|
57 |
+
features.append(feature_for_one_signal)
|
58 |
+
|
59 |
+
mu, sigma = calculate_statistics(features)
|
60 |
+
print(np.shape(mu))
|
61 |
+
print(np.shape(sigma))
|
62 |
+
|
63 |
+
save_AST_feature(f'{data_split}_{instrument_name}', mu.tolist(), sigma.tolist())
|
metrics/pipelines.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from tools import VAE_out_put_to_spc, rms_normalize, nnData2Audio
|
7 |
+
from model.DiffSynthSampler import DiffSynthSampler
|
8 |
+
|
9 |
+
def sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
|
10 |
+
positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0,
|
11 |
+
freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
|
12 |
+
|
13 |
+
height = int(freq_resolution/VAE_scale)
|
14 |
+
width = int(time_resolution/VAE_scale)
|
15 |
+
VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
|
16 |
+
|
17 |
+
text2sound_embedding = \
|
18 |
+
MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
|
19 |
+
negative_condition = \
|
20 |
+
MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0].to(device)
|
21 |
+
|
22 |
+
mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
|
23 |
+
mySampler.activate_classifier_free_guidance(CFG, negative_condition)
|
24 |
+
|
25 |
+
mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))
|
26 |
+
|
27 |
+
condition = text2sound_embedding.repeat(batchsize, 1)
|
28 |
+
|
29 |
+
latent_representations, initial_noise = \
|
30 |
+
mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed,
|
31 |
+
return_tensor=True, condition=condition, sampler=sampler)
|
32 |
+
|
33 |
+
latent_representations = latent_representations[-1]
|
34 |
+
|
35 |
+
quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
|
36 |
+
|
37 |
+
if return_latent:
|
38 |
+
return quantized_latent_representations.detach()
|
39 |
+
reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
|
40 |
+
time_resolution = int(time_resolution * ((duration+1) / 4))
|
41 |
+
|
42 |
+
rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution))
|
43 |
+
rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals]
|
44 |
+
|
45 |
+
return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
|
46 |
+
|
47 |
+
def sample_pipeline_GAN(device, gan_generator, VAE, MMM, CLAP_tokenizer,
|
48 |
+
positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0,
|
49 |
+
freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
|
50 |
+
|
51 |
+
height = int(freq_resolution/VAE_scale)
|
52 |
+
width = int(time_resolution/VAE_scale)
|
53 |
+
VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
|
54 |
+
|
55 |
+
text2sound_embedding = \
|
56 |
+
MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
|
57 |
+
|
58 |
+
condition = text2sound_embedding.repeat(batchsize, 1)
|
59 |
+
|
60 |
+
noise = torch.randn(batchsize, channels, height, width).to(device)
|
61 |
+
latent_representations = gan_generator(noise, condition)
|
62 |
+
|
63 |
+
quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
|
64 |
+
|
65 |
+
if return_latent:
|
66 |
+
return quantized_latent_representations.detach()
|
67 |
+
reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
|
68 |
+
time_resolution = int(time_resolution * ((duration+1) / 4))
|
69 |
+
|
70 |
+
rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution))
|
71 |
+
rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals]
|
72 |
+
|
73 |
+
return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
|
74 |
+
|
75 |
+
def inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, use_dynamic_mask, noising_strength, guidance,
|
76 |
+
positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0, mask_flexivity=0.99,
|
77 |
+
freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
|
78 |
+
|
79 |
+
height = int(freq_resolution/VAE_scale)
|
80 |
+
width = int(time_resolution * ((duration + 1) / 4) / VAE_scale)
|
81 |
+
VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
|
82 |
+
|
83 |
+
|
84 |
+
text2sound_embedding = \
|
85 |
+
MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0]
|
86 |
+
negative_condition = \
|
87 |
+
MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0]
|
88 |
+
|
89 |
+
|
90 |
+
mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
|
91 |
+
mySampler.activate_classifier_free_guidance(CFG, negative_condition)
|
92 |
+
mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))
|
93 |
+
|
94 |
+
condition = text2sound_embedding.repeat(batchsize, 1)
|
95 |
+
guidance = guidance.repeat(batchsize, 1, 1, 1).to(device)
|
96 |
+
|
97 |
+
# mask = 1, freeze
|
98 |
+
latent_mask = torch.zeros((batchsize, 1, height, width), dtype=torch.float32).to(device)
|
99 |
+
latent_mask[:, :, :, -int(time_resolution * (1 / 4) / VAE_scale):] = 1.0
|
100 |
+
|
101 |
+
latent_representations, initial_noise = \
|
102 |
+
mySampler.inpaint_sample(model=uNet, shape=(batchsize, channels, height, width),
|
103 |
+
noising_strength=noising_strength,
|
104 |
+
guide_img=guidance, mask=latent_mask, return_tensor=True,
|
105 |
+
condition=condition, sampler=sampler,
|
106 |
+
use_dynamic_mask=use_dynamic_mask,
|
107 |
+
end_noise_level_ratio=0.0,
|
108 |
+
mask_flexivity=mask_flexivity)
|
109 |
+
|
110 |
+
latent_representations = latent_representations[-1]
|
111 |
+
|
112 |
+
quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
|
113 |
+
|
114 |
+
if return_latent:
|
115 |
+
return quantized_latent_representations.detach()
|
116 |
+
reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
|
117 |
+
time_resolution = int(time_resolution * ((duration+1) / 4))
|
118 |
+
|
119 |
+
rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution))
|
120 |
+
rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals]
|
121 |
+
|
122 |
+
return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
|
123 |
+
|
124 |
+
|
125 |
+
def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
|
126 |
+
diffuSynth_signals = []
|
127 |
+
for _ in tqdm(range(num_batches)):
|
128 |
+
_, _, signals = sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
|
129 |
+
positive_prompts=positive_prompts, negative_prompts=negative_prompts,
|
130 |
+
batchsize=16, sample_steps=sample_steps, CFG=CFG, seed=None, return_latent=False)
|
131 |
+
diffuSynth_signals.extend(signals)
|
132 |
+
return np.array(diffuSynth_signals)
|
133 |
+
|
134 |
+
|
135 |
+
def generate_audios_with_diffuSynth_inpaint(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, guidance, duration, use_dynamic_mask, noising_strength, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
|
136 |
+
|
137 |
+
diffuSynth_signals = []
|
138 |
+
for _ in tqdm(range(num_batches)):
|
139 |
+
_, _, signals = inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
|
140 |
+
use_dynamic_mask=use_dynamic_mask, noising_strength=noising_strength, guidance=guidance,
|
141 |
+
positive_prompts=positive_prompts, negative_prompts=negative_prompts, batchsize=16, sample_steps=sample_steps, CFG=CFG, seed=None, duration=duration, mask_flexivity=0.999,
|
142 |
+
return_latent=False)
|
143 |
+
diffuSynth_signals.extend(signals)
|
144 |
+
return np.array(diffuSynth_signals)
|
metrics/pipelines_STFT.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from tools import rms_normalize, decode_stft, depad_STFT
|
7 |
+
from model.DiffSynthSampler import DiffSynthSampler
|
8 |
+
|
9 |
+
def sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
|
10 |
+
positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
|
11 |
+
freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
|
12 |
+
"Sample a fix-length audio using a diffusion model, including 'ISTFT+' post-processing."
|
13 |
+
|
14 |
+
height = int(freq_resolution/VAE_scale)
|
15 |
+
width = int(time_resolution/VAE_scale)
|
16 |
+
VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
|
17 |
+
|
18 |
+
text2sound_embedding = \
|
19 |
+
MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
|
20 |
+
negative_condition = \
|
21 |
+
MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[
|
22 |
+
0].to(device)
|
23 |
+
|
24 |
+
mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
|
25 |
+
mySampler.activate_classifier_free_guidance(CFG, negative_condition)
|
26 |
+
|
27 |
+
mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))
|
28 |
+
|
29 |
+
condition = text2sound_embedding.repeat(batchsize, 1)
|
30 |
+
|
31 |
+
latent_representations, initial_noise = \
|
32 |
+
mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed,
|
33 |
+
return_tensor=True, condition=condition, sampler=sampler)
|
34 |
+
|
35 |
+
latent_representations = latent_representations[-1]
|
36 |
+
|
37 |
+
quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
|
38 |
+
|
39 |
+
if return_latent:
|
40 |
+
return quantized_latent_representations.detach()
|
41 |
+
|
42 |
+
reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
|
43 |
+
|
44 |
+
rec_signals = []
|
45 |
+
|
46 |
+
for index, STFT in enumerate(reconstruction_batch):
|
47 |
+
padded_D_rec = decode_stft(STFT)
|
48 |
+
D_rec = depad_STFT(padded_D_rec)
|
49 |
+
# get_audio
|
50 |
+
rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
|
51 |
+
rec_signals.append(rms_normalize(rec_signal))
|
52 |
+
|
53 |
+
return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
|
54 |
+
|
55 |
+
def sample_pipeline_GAN_STFT(device, gan_generator, VAE, MMM, CLAP_tokenizer,
|
56 |
+
positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
|
57 |
+
freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
|
58 |
+
"Sample fix-length audio using a GAN, including 'ISTFT+' post-processing."
|
59 |
+
|
60 |
+
height = int(freq_resolution/VAE_scale)
|
61 |
+
width = int(time_resolution/VAE_scale)
|
62 |
+
VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
|
63 |
+
|
64 |
+
text2sound_embedding = \
|
65 |
+
MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
|
66 |
+
|
67 |
+
condition = text2sound_embedding.repeat(batchsize, 1)
|
68 |
+
|
69 |
+
noise = torch.randn(batchsize, channels, height, width).to(device)
|
70 |
+
latent_representations = gan_generator(noise, condition)
|
71 |
+
|
72 |
+
quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
|
73 |
+
|
74 |
+
if return_latent:
|
75 |
+
return quantized_latent_representations.detach()
|
76 |
+
reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
|
77 |
+
|
78 |
+
rec_signals = []
|
79 |
+
|
80 |
+
for index, STFT in enumerate(reconstruction_batch):
|
81 |
+
padded_D_rec = decode_stft(STFT)
|
82 |
+
D_rec = depad_STFT(padded_D_rec)
|
83 |
+
# get_audio
|
84 |
+
rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
|
85 |
+
rec_signals.append(rms_normalize(rec_signal))
|
86 |
+
|
87 |
+
return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
|
88 |
+
|
89 |
+
|
90 |
+
def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
|
91 |
+
"Sample audios using a diffusion model, including 'ISTFT+' post-processing."
|
92 |
+
|
93 |
+
diffuSynth_signals = []
|
94 |
+
for _ in tqdm(range(num_batches)):
|
95 |
+
_, _, signals = sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
|
96 |
+
positive_prompts=positive_prompts, negative_prompts=negative_prompts,
|
97 |
+
batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None, return_latent=False)
|
98 |
+
diffuSynth_signals.extend(signals)
|
99 |
+
return np.array(diffuSynth_signals)
|
100 |
+
|
metrics/precision_recall.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""k-NN precision and recall."""
|
9 |
+
|
10 |
+
from time import time
|
11 |
+
|
12 |
+
|
13 |
+
# ----------------------------------------------------------------------------
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
def batch_pairwise_distances(U, V):
|
20 |
+
"""Compute pair-wise distance in a batch of feature."""
|
21 |
+
|
22 |
+
norm_u = np.sum(np.square(U), axis=1)
|
23 |
+
norm_v = np.sum(np.square(V), axis=1)
|
24 |
+
|
25 |
+
norm_u = np.reshape(norm_u, [-1, 1])
|
26 |
+
norm_v = np.reshape(norm_v, [1, -1])
|
27 |
+
|
28 |
+
D = np.maximum(norm_u - 2 * np.dot(U, V.T) + norm_v, 0.0)
|
29 |
+
return D
|
30 |
+
|
31 |
+
|
32 |
+
# ----------------------------------------------------------------------------
|
33 |
+
|
34 |
+
class DistanceBlock():
|
35 |
+
"""Compute pair-wise distance in a batch of feature."""
|
36 |
+
|
37 |
+
def __init__(self, num_features):
|
38 |
+
self.num_features = num_features
|
39 |
+
|
40 |
+
def pairwise_distances(self, U, V):
|
41 |
+
return batch_pairwise_distances(U, V)
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
# ----------------------------------------------------------------------------
|
46 |
+
|
47 |
+
class ManifoldEstimator():
|
48 |
+
"""Estimates the manifold of given feature vectors."""
|
49 |
+
|
50 |
+
def __init__(self, distance_block, features, row_batch_size=16, col_batch_size=16,
|
51 |
+
nhood_sizes=[3], clamp_to_percentile=None, eps=1e-5, mute=False):
|
52 |
+
"""Estimate the manifold of given feature vectors.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
distance_block: DistanceBlock object that distributes pairwise distance
|
56 |
+
calculation to multiple GPUs.
|
57 |
+
features (np.array/tf.Tensor): Matrix of feature vectors to estimate their manifold.
|
58 |
+
row_batch_size (int): Row batch size to compute pairwise distances
|
59 |
+
(parameter to trade-off between memory usage and performance).
|
60 |
+
col_batch_size (int): Column batch size to compute pairwise distances.
|
61 |
+
nhood_sizes (list): Number of neighbors used to estimate the manifold.
|
62 |
+
clamp_to_percentile (float): Prune hyperspheres that have radius larger than
|
63 |
+
the given percentile.
|
64 |
+
eps (float): Small number for numerical stability.
|
65 |
+
"""
|
66 |
+
num_images = features.shape[0]
|
67 |
+
self.nhood_sizes = nhood_sizes
|
68 |
+
self.num_nhoods = len(nhood_sizes)
|
69 |
+
self.eps = eps
|
70 |
+
self.row_batch_size = row_batch_size
|
71 |
+
self.col_batch_size = col_batch_size
|
72 |
+
self._ref_features = features
|
73 |
+
self._distance_block = distance_block
|
74 |
+
self.mute = mute
|
75 |
+
|
76 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
77 |
+
self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
78 |
+
distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float32)
|
79 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
80 |
+
|
81 |
+
if mute:
|
82 |
+
for begin1 in range(0, num_images, row_batch_size):
|
83 |
+
end1 = min(begin1 + row_batch_size, num_images)
|
84 |
+
row_batch = features[begin1:end1]
|
85 |
+
|
86 |
+
for begin2 in range(0, num_images, col_batch_size):
|
87 |
+
end2 = min(begin2 + col_batch_size, num_images)
|
88 |
+
col_batch = features[begin2:end2]
|
89 |
+
|
90 |
+
# Compute distances between batches.
|
91 |
+
distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch,
|
92 |
+
col_batch)
|
93 |
+
|
94 |
+
# Find the k-nearest neighbor from the current batch.
|
95 |
+
self.D[begin1:end1, :] = np.partition(distance_batch[0:end1 - begin1, :], seq, axis=1)[:, self.nhood_sizes]
|
96 |
+
else:
|
97 |
+
for begin1 in tqdm(range(0, num_images, row_batch_size)):
|
98 |
+
end1 = min(begin1 + row_batch_size, num_images)
|
99 |
+
row_batch = features[begin1:end1]
|
100 |
+
|
101 |
+
for begin2 in range(0, num_images, col_batch_size):
|
102 |
+
end2 = min(begin2 + col_batch_size, num_images)
|
103 |
+
col_batch = features[begin2:end2]
|
104 |
+
|
105 |
+
# Compute distances between batches.
|
106 |
+
distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch,
|
107 |
+
col_batch)
|
108 |
+
|
109 |
+
# Find the k-nearest neighbor from the current batch.
|
110 |
+
self.D[begin1:end1, :] = np.partition(distance_batch[0:end1 - begin1, :], seq, axis=1)[:, self.nhood_sizes]
|
111 |
+
|
112 |
+
if clamp_to_percentile is not None:
|
113 |
+
max_distances = np.percentile(self.D, clamp_to_percentile, axis=0)
|
114 |
+
self.D[self.D > max_distances] = 0
|
115 |
+
|
116 |
+
def evaluate(self, eval_features, return_realism=False, return_neighbors=False):
|
117 |
+
"""Evaluate if new feature vectors are at the manifold."""
|
118 |
+
num_eval_images = eval_features.shape[0]
|
119 |
+
num_ref_images = self.D.shape[0]
|
120 |
+
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
|
121 |
+
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
122 |
+
max_realism_score = np.zeros([num_eval_images, ], dtype=np.float32)
|
123 |
+
nearest_indices = np.zeros([num_eval_images, ], dtype=np.int32)
|
124 |
+
|
125 |
+
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
126 |
+
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
127 |
+
feature_batch = eval_features[begin1:end1]
|
128 |
+
|
129 |
+
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
130 |
+
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
131 |
+
ref_batch = self._ref_features[begin2:end2]
|
132 |
+
|
133 |
+
distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(feature_batch,
|
134 |
+
ref_batch)
|
135 |
+
|
136 |
+
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
137 |
+
# If a feature vector is inside a hypersphere of some reference sample, then
|
138 |
+
# the new sample lies at the estimated manifold.
|
139 |
+
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
140 |
+
samples_in_manifold = distance_batch[0:end1 - begin1, :, None] <= self.D
|
141 |
+
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
|
142 |
+
|
143 |
+
max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distance_batch[0:end1 - begin1, :] + self.eps),
|
144 |
+
axis=1)
|
145 |
+
nearest_indices[begin1:end1] = np.argmin(distance_batch[0:end1 - begin1, :], axis=1)
|
146 |
+
|
147 |
+
if return_realism and return_neighbors:
|
148 |
+
return batch_predictions, max_realism_score, nearest_indices
|
149 |
+
elif return_realism:
|
150 |
+
return batch_predictions, max_realism_score
|
151 |
+
elif return_neighbors:
|
152 |
+
return batch_predictions, nearest_indices
|
153 |
+
|
154 |
+
return batch_predictions
|
155 |
+
|
156 |
+
|
157 |
+
# ----------------------------------------------------------------------------
|
158 |
+
|
159 |
+
def knn_precision_recall_features(ref_features, eval_features, nhood_sizes=[3],
|
160 |
+
row_batch_size=10000, col_batch_size=50000, mute=False):
|
161 |
+
"""Calculates k-NN precision and recall for two sets of feature vectors.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
ref_features (np.array/tf.Tensor): Feature vectors of reference images.
|
165 |
+
eval_features (np.array/tf.Tensor): Feature vectors of generated images.
|
166 |
+
nhood_sizes (list): Number of neighbors used to estimate the manifold.
|
167 |
+
row_batch_size (int): Row batch size to compute pairwise distances
|
168 |
+
(parameter to trade-off between memory usage and performance).
|
169 |
+
col_batch_size (int): Column batch size to compute pairwise distances.
|
170 |
+
num_gpus (int): Number of GPUs used to evaluate precision and recall.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
State (dict): Dict that contains precision and recall calculated from
|
174 |
+
ref_features and eval_features.
|
175 |
+
"""
|
176 |
+
state = dict()
|
177 |
+
num_images = ref_features.shape[0]
|
178 |
+
num_features = ref_features.shape[1]
|
179 |
+
|
180 |
+
# Initialize DistanceBlock and ManifoldEstimators.
|
181 |
+
distance_block = DistanceBlock(num_features)
|
182 |
+
ref_manifold = ManifoldEstimator(distance_block, ref_features, row_batch_size, col_batch_size, nhood_sizes, mute=mute)
|
183 |
+
eval_manifold = ManifoldEstimator(distance_block, eval_features, row_batch_size, col_batch_size, nhood_sizes, mute=mute)
|
184 |
+
|
185 |
+
# Evaluate precision and recall using k-nearest neighbors.
|
186 |
+
if not mute:
|
187 |
+
print('Evaluating k-NN precision and recall with %i samples...' % num_images)
|
188 |
+
start = time()
|
189 |
+
|
190 |
+
# Precision: How many points from eval_features are in ref_features manifold.
|
191 |
+
precision = ref_manifold.evaluate(eval_features)
|
192 |
+
state['precision'] = precision.mean(axis=0)
|
193 |
+
|
194 |
+
# Recall: How many points from ref_features are in eval_features manifold.
|
195 |
+
recall = eval_manifold.evaluate(ref_features)
|
196 |
+
state['recall'] = recall.mean(axis=0)
|
197 |
+
|
198 |
+
if not mute:
|
199 |
+
print('Evaluated k-NN precision and recall in: %gs' % (time() - start))
|
200 |
+
|
201 |
+
return state
|
202 |
+
|
203 |
+
# ----------------------------------------------------------------------------
|
204 |
+
|
metrics/visualizations.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from matplotlib import pyplot as plt
|
3 |
+
from scipy.fft import fft
|
4 |
+
from scipy.signal import savgol_filter
|
5 |
+
from tools import rms_normalize
|
6 |
+
|
7 |
+
colors = [
|
8 |
+
# (0, 0, 0), # Black
|
9 |
+
# (86, 180, 233), # Sky blue
|
10 |
+
# (240, 228, 66), # Yellow
|
11 |
+
# (204, 121, 167), # Reddish purple
|
12 |
+
(213, 94, 0), # Vermilion
|
13 |
+
(0, 114, 178), # Blue
|
14 |
+
(230, 159, 0), # Orange
|
15 |
+
(0, 158, 115), # Bluish green
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
+
def plot_psd_multiple_signals(signals_list, labels_list, sample_rate=16000, window_size=500,
|
20 |
+
figsize=(10, 6), save_path=None, normalize=False):
|
21 |
+
"""
|
22 |
+
在同一张图上绘制多组音频信号的功率谱密度比较图,使用对数刻度的响度轴(以2为底),并应用平滑处理。
|
23 |
+
|
24 |
+
参数:
|
25 |
+
signals_list: 包含多组音频信号的列表,每组信号形状为 [sample_number, sample_length] 的numpy array
|
26 |
+
labels_list: 每组音频信号对应的标签字符串列表
|
27 |
+
sample_rate: 音频的采样率
|
28 |
+
"""
|
29 |
+
|
30 |
+
# 确保传入的signals_list和labels_list长度相同
|
31 |
+
assert len(signals_list) == len(labels_list), "每组信号必须有一个对应的标签。"
|
32 |
+
|
33 |
+
signals_list = [np.array([rms_normalize(signal) for signal in signals]) for signals in signals_list]
|
34 |
+
|
35 |
+
# 绘图准备
|
36 |
+
plt.figure(figsize=figsize)
|
37 |
+
|
38 |
+
# 遍历所有的音频信号
|
39 |
+
i = 0
|
40 |
+
for signal, label in zip(signals_list, labels_list):
|
41 |
+
# 计算FFT
|
42 |
+
fft_signal = fft(signal, axis=1)
|
43 |
+
|
44 |
+
# 计算平均功率谱密度
|
45 |
+
psd_signal = np.mean(np.abs(fft_signal)**2, axis=0)
|
46 |
+
|
47 |
+
# 计算频率轴
|
48 |
+
freqs = np.fft.fftfreq(signal.shape[1], 1/sample_rate)
|
49 |
+
|
50 |
+
# 应用Savitzky-Golay滤波器进行平滑
|
51 |
+
psd_smoothed = savgol_filter(np.log2(psd_signal[:signal.shape[1] // 2] + 1), window_size, 3) # 窗口大小51, 多项式阶数3
|
52 |
+
|
53 |
+
# Normalize each curve if normalize is True
|
54 |
+
if normalize:
|
55 |
+
psd_smoothed /= np.mean(psd_smoothed)
|
56 |
+
|
57 |
+
# 绘制每组信号的功率谱密度
|
58 |
+
plt.plot(freqs[:signal.shape[1] // 2], psd_smoothed, label=label, color=[x/255.0 for x in colors[i % len(colors)]], linewidth=1)
|
59 |
+
i += 1
|
60 |
+
|
61 |
+
# 设置图表元素
|
62 |
+
plt.xlabel('Frequency (Hz)')
|
63 |
+
plt.ylabel('Mean Log-Amplitude')
|
64 |
+
plt.legend()
|
65 |
+
|
66 |
+
# 根据save_path参数决定保存图像还是直接显示
|
67 |
+
if save_path:
|
68 |
+
plt.savefig(save_path)
|
69 |
+
else:
|
70 |
+
plt.show()
|
71 |
+
|
72 |
+
|
73 |
+
def plot_amplitude_over_time(signals_list, labels_list, sample_rate=16000, window_size=500,
|
74 |
+
figsize=(10, 6), save_path=None, normalize=False, start_time=0):
|
75 |
+
"""
|
76 |
+
Plot the loudness of multiple sets of audio signals over time on the same graph,
|
77 |
+
using a logarithmic scale for the loudness axis (base 2), with smoothing applied.
|
78 |
+
|
79 |
+
Parameters:
|
80 |
+
signals_list: List of sets of audio signals, each set is a numpy array with shape [sample_number, sample_length]
|
81 |
+
labels_list: List of labels corresponding to each set of audio signals
|
82 |
+
sample_rate: Sampling rate of the audio
|
83 |
+
window_size: Window size for the Savitzky-Golay filter
|
84 |
+
figsize: Figure size
|
85 |
+
save_path: Path to save the figure, if None, the figure will be displayed
|
86 |
+
normalize: Whether to normalize each curve so that the sum of each curve is the same
|
87 |
+
start_time: Time (in seconds) to start plotting, only data after this time will be retained
|
88 |
+
"""
|
89 |
+
assert len(signals_list) == len(labels_list), f"len(signals_list) != len(labels_list) for " \
|
90 |
+
f"len(signals_list) = {len(signals_list)} and len(labels_list) = {len(labels_list)}"
|
91 |
+
|
92 |
+
# Compute starting sample index
|
93 |
+
start_sample = int(start_time * sample_rate)
|
94 |
+
|
95 |
+
# Normalize signals and truncate data
|
96 |
+
signals_list = [np.array([rms_normalize(signal)[start_sample:] for signal in signals]) for signals in signals_list]
|
97 |
+
time_axis = np.arange(start_sample, start_sample + signals_list[0].shape[1]) / sample_rate
|
98 |
+
|
99 |
+
plt.figure(figsize=figsize)
|
100 |
+
|
101 |
+
i = 0
|
102 |
+
for signal, label in zip(signals_list, labels_list):
|
103 |
+
amplitude_mean = np.mean(np.abs(signal), axis=0)
|
104 |
+
|
105 |
+
amplitude_smoothed = savgol_filter(np.log2(amplitude_mean + 1), window_size, 3)
|
106 |
+
|
107 |
+
# Normalize each curve if normalize is True
|
108 |
+
if normalize:
|
109 |
+
amplitude_smoothed /= np.mean(amplitude_smoothed)
|
110 |
+
|
111 |
+
plt.plot(time_axis, amplitude_smoothed, label=label, color=[x/255.0 for x in colors[i % len(colors)]], linewidth=1)
|
112 |
+
i += 1
|
113 |
+
|
114 |
+
plt.xlabel('Time (seconds)')
|
115 |
+
plt.ylabel('Mean Log-Amplitude')
|
116 |
+
plt.legend()
|
117 |
+
|
118 |
+
# Save or show the figure based on save_path parameter
|
119 |
+
if save_path:
|
120 |
+
plt.savefig(save_path)
|
121 |
+
else:
|
122 |
+
plt.show()
|
123 |
+
|
model/DiffSynthSampler.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
7 |
+
"""
|
8 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
9 |
+
|
10 |
+
:param arr: the 1-D numpy array.
|
11 |
+
:param timesteps: a tensor of indices into the array to extract.
|
12 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
13 |
+
dimension equal to the length of timesteps.
|
14 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
15 |
+
"""
|
16 |
+
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
17 |
+
while len(res.shape) < len(broadcast_shape):
|
18 |
+
res = res[..., None]
|
19 |
+
return res.expand(broadcast_shape)
|
20 |
+
|
21 |
+
|
22 |
+
class DiffSynthSampler:
|
23 |
+
|
24 |
+
def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False,
|
25 |
+
height=128, max_batchsize=16, max_width=256, channels=4, train_width=64, noise_strategy="repeat"):
|
26 |
+
if device is None:
|
27 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
+
else:
|
29 |
+
self.device = device
|
30 |
+
self.height = height
|
31 |
+
self.train_width = train_width
|
32 |
+
self.max_batchsize = max_batchsize
|
33 |
+
self.max_width = max_width
|
34 |
+
self.channels = channels
|
35 |
+
self.num_timesteps = timesteps
|
36 |
+
self.timestep_map = list(range(self.num_timesteps))
|
37 |
+
self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64)
|
38 |
+
self.respaced = False
|
39 |
+
self.define_beta_schedule()
|
40 |
+
self.CFG = 1.0
|
41 |
+
self.mute = mute
|
42 |
+
self.noise_strategy = noise_strategy
|
43 |
+
|
44 |
+
def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None):
|
45 |
+
if reference_noise is None:
|
46 |
+
large_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.max_width), device=self.device)
|
47 |
+
else:
|
48 |
+
assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), "reference_noise shape mismatch"
|
49 |
+
large_noise_tensor = reference_noise
|
50 |
+
return large_noise_tensor[:batchsize, :, :, :width], None
|
51 |
+
|
52 |
+
def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None):
|
53 |
+
if self.noise_strategy == "repeat":
|
54 |
+
noise, concat_points = self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise)
|
55 |
+
return noise, concat_points
|
56 |
+
else:
|
57 |
+
noise, concat_points = self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise)
|
58 |
+
return noise, concat_points
|
59 |
+
|
60 |
+
|
61 |
+
def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None):
|
62 |
+
# 生成与训练数据长度相等的噪音
|
63 |
+
if reference_noise is None:
|
64 |
+
train_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.train_width), device=self.device)
|
65 |
+
else:
|
66 |
+
assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), "reference_noise shape mismatch"
|
67 |
+
train_noise_tensor = reference_noise
|
68 |
+
|
69 |
+
release_width = int(self.train_width * 1.0 / 4)
|
70 |
+
first_part_width = self.train_width - release_width
|
71 |
+
|
72 |
+
first_part = train_noise_tensor[:batchsize, :, :, :first_part_width]
|
73 |
+
release_part = train_noise_tensor[:batchsize, :, :, -release_width:]
|
74 |
+
|
75 |
+
# 如果所需 length 小于等于 origin length,去掉 first_part 的中间部分
|
76 |
+
if width <= self.train_width:
|
77 |
+
_first_part_head_width = int((width - release_width) / 2)
|
78 |
+
_first_part_tail_width = width - release_width - _first_part_head_width
|
79 |
+
all_parts = [first_part[:, :, :, :_first_part_head_width], first_part[:, :, :, -_first_part_tail_width:], release_part]
|
80 |
+
|
81 |
+
# 沿第四维度拼接张量
|
82 |
+
noise_tensor = torch.cat(all_parts, dim=3)
|
83 |
+
|
84 |
+
# 记录拼接点的位置
|
85 |
+
concat_points = [0]
|
86 |
+
for part in all_parts[:-1]:
|
87 |
+
next_point = concat_points[-1] + part.size(3)
|
88 |
+
concat_points.append(next_point)
|
89 |
+
|
90 |
+
return noise_tensor, concat_points
|
91 |
+
|
92 |
+
# 如果所需 length 大于 origin length,不断地从中间插入 first_part 的中间部分
|
93 |
+
else:
|
94 |
+
# 计算需要重复front_width的次数
|
95 |
+
repeats = (width - release_width) // first_part_width
|
96 |
+
extra = (width - release_width) % first_part_width
|
97 |
+
|
98 |
+
_repeat_first_part_head_width = int(first_part_width / 2)
|
99 |
+
_repeat_first_part_tail_width = first_part_width - _repeat_first_part_head_width
|
100 |
+
|
101 |
+
repeated_first_head_parts = [first_part[:, :, :, :_repeat_first_part_head_width] for _ in range(repeats)]
|
102 |
+
repeated_first_tail_parts = [first_part[:, :, :, -_repeat_first_part_tail_width:] for _ in range(repeats)]
|
103 |
+
|
104 |
+
# 计算起始索引
|
105 |
+
_middle_part_start_index = (first_part_width - extra) // 2
|
106 |
+
# 切片张量以获取中间部分
|
107 |
+
middle_part = first_part[:, :, :, _middle_part_start_index: _middle_part_start_index + extra]
|
108 |
+
|
109 |
+
all_parts = repeated_first_head_parts + [middle_part] + repeated_first_tail_parts + [release_part]
|
110 |
+
|
111 |
+
# 沿第四维度拼接张量
|
112 |
+
noise_tensor = torch.cat(all_parts, dim=3)
|
113 |
+
|
114 |
+
# 记录拼接点的位置
|
115 |
+
concat_points = [0]
|
116 |
+
for part in all_parts[:-1]:
|
117 |
+
next_point = concat_points[-1] + part.size(3)
|
118 |
+
concat_points.append(next_point)
|
119 |
+
|
120 |
+
return noise_tensor, concat_points
|
121 |
+
|
122 |
+
def define_beta_schedule(self):
|
123 |
+
assert self.respaced == False, "This schedule has already been respaced!"
|
124 |
+
# define alphas
|
125 |
+
self.alphas = 1.0 - self.betas
|
126 |
+
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
127 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
128 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
129 |
+
|
130 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
131 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
132 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
133 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
134 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
135 |
+
self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas)
|
136 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
137 |
+
|
138 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
139 |
+
self.posterior_variance = (self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
|
140 |
+
|
141 |
+
def activate_classifier_free_guidance(self, CFG, unconditional_condition):
|
142 |
+
assert (
|
143 |
+
not unconditional_condition is None) or CFG == 1.0, "For CFG != 1.0, unconditional_condition must be available"
|
144 |
+
self.CFG = CFG
|
145 |
+
self.unconditional_condition = unconditional_condition
|
146 |
+
|
147 |
+
def respace(self, use_timesteps=None):
|
148 |
+
if not use_timesteps is None:
|
149 |
+
last_alpha_cumprod = 1.0
|
150 |
+
new_betas = []
|
151 |
+
self.timestep_map = []
|
152 |
+
for i, _alpha_cumprod in enumerate(self.alphas_cumprod):
|
153 |
+
if i in use_timesteps:
|
154 |
+
new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod)
|
155 |
+
last_alpha_cumprod = _alpha_cumprod
|
156 |
+
self.timestep_map.append(i)
|
157 |
+
self.num_timesteps = len(use_timesteps)
|
158 |
+
self.betas = np.array(new_betas)
|
159 |
+
self.define_beta_schedule()
|
160 |
+
self.respaced = True
|
161 |
+
|
162 |
+
def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None):
|
163 |
+
assert shape[1] == self.channels, "shape[1] != self.channels"
|
164 |
+
assert shape[2] == self.height, "shape[2] != self.height"
|
165 |
+
noise = torch.empty(*shape, device=self.device)
|
166 |
+
|
167 |
+
# 第三种情况:两个端点都不是None,进行线性插值
|
168 |
+
if first_endpoint is not None and second_endpoint is not None:
|
169 |
+
for i in range(shape[0]):
|
170 |
+
alpha = i / (shape[0] - 1) # 插值系数
|
171 |
+
noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint
|
172 |
+
return noise # 返回插值后的结果,不需要进行后续的均值和方差调整
|
173 |
+
else:
|
174 |
+
# 第一个端点不是None
|
175 |
+
if first_endpoint is not None:
|
176 |
+
noise[0] = first_endpoint
|
177 |
+
if shape[0] > 1:
|
178 |
+
noise[1], _ = self.get_deterministic_noise_tensor(1, shape[3])[0]
|
179 |
+
else:
|
180 |
+
noise[0], _ = self.get_deterministic_noise_tensor(1, shape[3])[0]
|
181 |
+
if shape[0] > 1:
|
182 |
+
noise[1], _ = self.get_deterministic_noise_tensor(1, shape[3])[0]
|
183 |
+
|
184 |
+
# 生成其他的噪声点
|
185 |
+
for i in range(2, shape[0]):
|
186 |
+
noise[i] = 2 * noise[i - 1] - noise[i - 2]
|
187 |
+
|
188 |
+
# 当只有一个端点被指定时
|
189 |
+
current_var = noise.var()
|
190 |
+
stddev_ratio = torch.sqrt(variance / current_var)
|
191 |
+
noise = noise * stddev_ratio
|
192 |
+
|
193 |
+
# 如果第一个端点被指定,进行平移调整
|
194 |
+
if first_endpoint is not None:
|
195 |
+
shift = first_endpoint - noise[0]
|
196 |
+
noise += shift
|
197 |
+
|
198 |
+
return noise
|
199 |
+
|
200 |
+
def q_sample(self, x_start, t, noise=None):
|
201 |
+
"""
|
202 |
+
Diffuse the data for a given number of diffusion steps.
|
203 |
+
|
204 |
+
In other words, sample from q(x_t | x_0).
|
205 |
+
|
206 |
+
:param x_start: the initial data batch.
|
207 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
208 |
+
:param noise: if specified, the split-out normal noise.
|
209 |
+
:return: A noisy version of x_start.
|
210 |
+
"""
|
211 |
+
assert x_start.shape[1] == self.channels, "shape[1] != self.channels"
|
212 |
+
assert x_start.shape[2] == self.height, "shape[2] != self.height"
|
213 |
+
|
214 |
+
if noise is None:
|
215 |
+
# noise = torch.randn_like(x_start)
|
216 |
+
noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3])
|
217 |
+
|
218 |
+
assert noise.shape == x_start.shape
|
219 |
+
return (
|
220 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
221 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
222 |
+
* noise
|
223 |
+
)
|
224 |
+
|
225 |
+
@torch.no_grad()
|
226 |
+
def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0):
|
227 |
+
map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype)
|
228 |
+
mapped_t = map_tensor[t]
|
229 |
+
|
230 |
+
# Todo: add CFG
|
231 |
+
|
232 |
+
if self.CFG == 1.0:
|
233 |
+
pred_noise = model(x, mapped_t, condition)
|
234 |
+
else:
|
235 |
+
unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat(
|
236 |
+
*([x.shape[0]] + [1] * len(self.unconditional_condition.shape)))
|
237 |
+
x_in = torch.cat([x] * 2)
|
238 |
+
t_in = torch.cat([mapped_t] * 2)
|
239 |
+
c_in = torch.cat([unconditional_condition, condition])
|
240 |
+
noise_uncond, noise = model(x_in, t_in, c_in).chunk(2)
|
241 |
+
pred_noise = noise_uncond + self.CFG * (noise - noise_uncond)
|
242 |
+
|
243 |
+
# Todo: END
|
244 |
+
|
245 |
+
alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
246 |
+
alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
247 |
+
|
248 |
+
pred_x0 = (x - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)
|
249 |
+
|
250 |
+
sigmas_t = (
|
251 |
+
ddim_eta
|
252 |
+
* torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
|
253 |
+
* torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev)
|
254 |
+
)
|
255 |
+
|
256 |
+
pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise
|
257 |
+
|
258 |
+
|
259 |
+
step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3])
|
260 |
+
|
261 |
+
|
262 |
+
x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise
|
263 |
+
|
264 |
+
return x_prev
|
265 |
+
|
266 |
+
def p_sample(self, model, x, t, condition=None, sampler="ddim"):
|
267 |
+
if sampler == "ddim":
|
268 |
+
return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0)
|
269 |
+
elif sampler == "ddpm":
|
270 |
+
return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0)
|
271 |
+
else:
|
272 |
+
raise NotImplementedError()
|
273 |
+
|
274 |
+
def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8):
|
275 |
+
release_length = int(self.train_width / 4)
|
276 |
+
assert shape[3] == (concat_points[-1] + release_length), "shape[3] != (concat_points[-1] + release_length)"
|
277 |
+
|
278 |
+
fraction_lengths = [concat_points[i + 1] - concat_points[i] for i in range(len(concat_points) - 1)]
|
279 |
+
|
280 |
+
# Todo: remove hard-coding
|
281 |
+
n_guidance_steps = int(n_masks * mask_flexivity)
|
282 |
+
n_free_steps = n_masks - n_guidance_steps
|
283 |
+
|
284 |
+
masks = []
|
285 |
+
# Todo: 在一半的 steps 内收缩 mask。也就是说,在后程对 release 以外的区域不做inpaint,而是 img2img
|
286 |
+
for i in range(n_guidance_steps):
|
287 |
+
# mask = 1, freeze
|
288 |
+
step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
|
289 |
+
step_i_mask[:, :, :, -release_length:] = 1.0
|
290 |
+
|
291 |
+
for fraction_index in range(len(fraction_lengths)):
|
292 |
+
|
293 |
+
_fraction_mask_length = int((n_guidance_steps - 1 - i) / (n_guidance_steps - 1) * fraction_lengths[fraction_index])
|
294 |
+
|
295 |
+
if fraction_index == 0:
|
296 |
+
step_i_mask[:, :, :, :_fraction_mask_length] = 1.0
|
297 |
+
elif fraction_index == len(fraction_lengths) - 1:
|
298 |
+
if not _fraction_mask_length == 0:
|
299 |
+
step_i_mask[:, :, :, -_fraction_mask_length - release_length:] = 1.0
|
300 |
+
else:
|
301 |
+
fraction_mask_start_position = int((fraction_lengths[fraction_index] - _fraction_mask_length) / 2)
|
302 |
+
|
303 |
+
step_i_mask[:, :, :,
|
304 |
+
concat_points[fraction_index] + fraction_mask_start_position:concat_points[
|
305 |
+
fraction_index] + fraction_mask_start_position + _fraction_mask_length] = 1.0
|
306 |
+
masks.append(step_i_mask)
|
307 |
+
|
308 |
+
for i in range(n_free_steps):
|
309 |
+
step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
|
310 |
+
step_i_mask[:, :, :, -release_length:] = 1.0
|
311 |
+
masks.append(step_i_mask)
|
312 |
+
|
313 |
+
masks.reverse()
|
314 |
+
return masks
|
315 |
+
|
316 |
+
@torch.no_grad()
|
317 |
+
def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
|
318 |
+
return_tensor=False, condition=None, guide_img=None,
|
319 |
+
mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8):
|
320 |
+
|
321 |
+
assert shape[1] == self.channels, "shape[1] != self.channels"
|
322 |
+
assert shape[2] == self.height, "shape[2] != self.height"
|
323 |
+
|
324 |
+
initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise)
|
325 |
+
assert initial_noise.shape == shape, "initial_noise.shape != shape"
|
326 |
+
|
327 |
+
start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio) # not included!!!
|
328 |
+
end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio)
|
329 |
+
|
330 |
+
timesteps = reversed(range(end_noise_level_index, start_noise_level_index))
|
331 |
+
|
332 |
+
# configure initial img
|
333 |
+
assert (start_noise_level_ratio == 1.0) or (
|
334 |
+
not guide_img is None), "A guide_img must be given to sample from a non-pure-noise."
|
335 |
+
|
336 |
+
if guide_img is None:
|
337 |
+
img = initial_noise
|
338 |
+
else:
|
339 |
+
guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(shape[0], shape[3], reference_noise=guide_img)
|
340 |
+
assert guide_img.shape == shape, "guide_img.shape != shape"
|
341 |
+
|
342 |
+
if start_noise_level_index > 0:
|
343 |
+
t = torch.full((shape[0],), start_noise_level_index-1, device=self.device).long() # -1 for start_noise_level_index not included
|
344 |
+
img = self.q_sample(guide_img, t, noise=initial_noise)
|
345 |
+
else:
|
346 |
+
print("Zero noise added to the guidance latent representation.")
|
347 |
+
img = guide_img
|
348 |
+
|
349 |
+
# get masks
|
350 |
+
n_masks = start_noise_level_index - end_noise_level_index
|
351 |
+
if use_dynamic_mask:
|
352 |
+
masks = self.get_dynamic_masks(n_masks, shape, concat_points, mask_flexivity)
|
353 |
+
else:
|
354 |
+
masks = [mask for _ in range(n_masks)]
|
355 |
+
|
356 |
+
imgs = [img]
|
357 |
+
current_mask = None
|
358 |
+
|
359 |
+
|
360 |
+
for i in tqdm(timesteps, total=start_noise_level_index - end_noise_level_index, disable=self.mute):
|
361 |
+
|
362 |
+
# if i == 3:
|
363 |
+
# return [img], initial_noise # 第1排,第1列
|
364 |
+
|
365 |
+
img = self.p_sample(model, img, torch.full((shape[0],), i, device=self.device, dtype=torch.long),
|
366 |
+
condition=condition,
|
367 |
+
sampler=sampler)
|
368 |
+
# if i == 3:
|
369 |
+
# return [img], initial_noise # 第1排,第2列
|
370 |
+
|
371 |
+
if inpaint:
|
372 |
+
if i > 0:
|
373 |
+
t = torch.full((shape[0],), int(i-1), device=self.device).long()
|
374 |
+
img_noise_t = self.q_sample(guide_img, t, noise=initial_noise)
|
375 |
+
# if i == 3:
|
376 |
+
# return [img_noise_t], initial_noise # 第2排,第2列
|
377 |
+
current_mask = masks.pop()
|
378 |
+
img = current_mask * img_noise_t + (1 - current_mask) * img
|
379 |
+
# if i == 3:
|
380 |
+
# return [img], initial_noise # 第1.5排,最后1列
|
381 |
+
else:
|
382 |
+
img = current_mask * guide_img + (1 - current_mask) * img
|
383 |
+
|
384 |
+
if return_tensor:
|
385 |
+
imgs.append(img)
|
386 |
+
else:
|
387 |
+
imgs.append(img.cpu().numpy())
|
388 |
+
|
389 |
+
return imgs, initial_noise
|
390 |
+
|
391 |
+
|
392 |
+
def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None, seed=None):
|
393 |
+
if not seed is None:
|
394 |
+
torch.manual_seed(seed)
|
395 |
+
return self.p_sample_loop(model, shape, initial_noise=initial_noise, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
|
396 |
+
return_tensor=return_tensor, condition=condition, sampler=sampler)
|
397 |
+
|
398 |
+
def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None, return_tensor=False,
|
399 |
+
condition=None, sampler="ddim", seed=None):
|
400 |
+
if not seed is None:
|
401 |
+
torch.manual_seed(seed)
|
402 |
+
linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint,
|
403 |
+
second_endpoint=second_endpoint)
|
404 |
+
return self.p_sample_loop(model, shape, initial_noise=linear_noise, start_noise_level_ratio=1.0,
|
405 |
+
end_noise_level_ratio=0.0,
|
406 |
+
return_tensor=return_tensor, condition=condition, sampler=sampler)
|
407 |
+
|
408 |
+
def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False, condition=None,
|
409 |
+
sampler="ddim", initial_noise=None, seed=None):
|
410 |
+
if not seed is None:
|
411 |
+
torch.manual_seed(seed)
|
412 |
+
assert guide_img.shape[-1] == shape[-1], "guide_img.shape[:-1] != shape[:-1]"
|
413 |
+
return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=0.0,
|
414 |
+
return_tensor=return_tensor, condition=condition, sampler=sampler,
|
415 |
+
guide_img=guide_img, initial_noise=initial_noise)
|
416 |
+
|
417 |
+
def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False, condition=None,
|
418 |
+
sampler="ddim", initial_noise=None, use_dynamic_mask=False, end_noise_level_ratio=0.0, seed=None,
|
419 |
+
mask_flexivity=0.8):
|
420 |
+
if not seed is None:
|
421 |
+
torch.manual_seed(seed)
|
422 |
+
return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=end_noise_level_ratio,
|
423 |
+
return_tensor=return_tensor, condition=condition, guide_img=guide_img, mask=mask,
|
424 |
+
sampler=sampler, inpaint=True, initial_noise=initial_noise, use_dynamic_mask=use_dynamic_mask,
|
425 |
+
mask_flexivity=mask_flexivity)
|
model/GAN.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from six.moves import xrange
|
6 |
+
from torch.utils.tensorboard import SummaryWriter
|
7 |
+
import random
|
8 |
+
|
9 |
+
from model.diffusion import ConditionedUnet
|
10 |
+
from tools import create_key
|
11 |
+
|
12 |
+
class Discriminator(nn.Module):
|
13 |
+
def __init__(self, label_emb_dim):
|
14 |
+
super(Discriminator, self).__init__()
|
15 |
+
# 特征图卷积层
|
16 |
+
self.conv_layers = nn.Sequential(
|
17 |
+
nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
|
18 |
+
nn.LeakyReLU(0.2, inplace=True),
|
19 |
+
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
|
20 |
+
nn.BatchNorm2d(128),
|
21 |
+
nn.LeakyReLU(0.2, inplace=True),
|
22 |
+
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
|
23 |
+
nn.BatchNorm2d(256),
|
24 |
+
nn.LeakyReLU(0.2, inplace=True),
|
25 |
+
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
|
26 |
+
nn.BatchNorm2d(512),
|
27 |
+
nn.LeakyReLU(0.2, inplace=True),
|
28 |
+
nn.AdaptiveAvgPool2d(1), # 添加适应性池化层
|
29 |
+
nn.Flatten()
|
30 |
+
)
|
31 |
+
|
32 |
+
# 文本嵌入处理
|
33 |
+
self.text_embedding = nn.Sequential(
|
34 |
+
nn.Linear(label_emb_dim, 512),
|
35 |
+
nn.LeakyReLU(0.2, inplace=True)
|
36 |
+
)
|
37 |
+
|
38 |
+
# 判别器最后的全连接层
|
39 |
+
self.fc = nn.Linear(512 + 512, 1) # 两个512分别来自特征图和文本嵌入
|
40 |
+
|
41 |
+
def forward(self, x, text_emb):
|
42 |
+
x = self.conv_layers(x)
|
43 |
+
text_emb = self.text_embedding(text_emb)
|
44 |
+
combined = torch.cat((x, text_emb), dim=1)
|
45 |
+
output = self.fc(combined)
|
46 |
+
return output
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
def evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping):
|
51 |
+
generator.to(device)
|
52 |
+
discriminator.to(device)
|
53 |
+
generator.eval()
|
54 |
+
discriminator.eval()
|
55 |
+
|
56 |
+
real_accs = []
|
57 |
+
fake_accs = []
|
58 |
+
|
59 |
+
with torch.no_grad():
|
60 |
+
for i in range(100):
|
61 |
+
data, attributes = next(iter(iterator))
|
62 |
+
data = data.to(device)
|
63 |
+
|
64 |
+
conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
|
65 |
+
selected_conditions = [random.choice(conditions_of_one_sample) for conditions_of_one_sample in conditions]
|
66 |
+
selected_conditions = torch.stack(selected_conditions).float().to(device)
|
67 |
+
|
68 |
+
# 将数据和标签移至设备
|
69 |
+
real_images = data.to(device)
|
70 |
+
labels = selected_conditions.to(device)
|
71 |
+
|
72 |
+
# 生成噪声和假图像
|
73 |
+
noise = torch.randn_like(real_images).to(device)
|
74 |
+
fake_images = generator(noise)
|
75 |
+
|
76 |
+
# 评估鉴别器的性能
|
77 |
+
real_preds = discriminator(real_images, labels).reshape(-1)
|
78 |
+
fake_preds = discriminator(fake_images, labels).reshape(-1)
|
79 |
+
real_acc = (real_preds > 0.5).float().mean().item() # 真实图像的准确率
|
80 |
+
fake_acc = (fake_preds < 0.5).float().mean().item() # 生成图像的准确率
|
81 |
+
|
82 |
+
real_accs.append(real_acc)
|
83 |
+
fake_accs.append(fake_acc)
|
84 |
+
|
85 |
+
|
86 |
+
# 计算平均准确率
|
87 |
+
average_real_acc = sum(real_accs) / len(real_accs)
|
88 |
+
average_fake_acc = sum(fake_accs) / len(fake_accs)
|
89 |
+
|
90 |
+
return average_real_acc, average_fake_acc
|
91 |
+
|
92 |
+
|
93 |
+
def get_Generator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
|
94 |
+
generator = ConditionedUnet(**model_Config)
|
95 |
+
print(f"Model intialized, size: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
|
96 |
+
generator.to(device)
|
97 |
+
|
98 |
+
if load_pretrain:
|
99 |
+
print(f"Loading weights from models/{model_name}_generator.pth")
|
100 |
+
checkpoint = torch.load(f'models/{model_name}_generator.pth', map_location=device)
|
101 |
+
generator.load_state_dict(checkpoint['model_state_dict'])
|
102 |
+
generator.eval()
|
103 |
+
return generator
|
104 |
+
|
105 |
+
|
106 |
+
def get_Discriminator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
|
107 |
+
discriminator = Discriminator(**model_Config)
|
108 |
+
print(f"Model intialized, size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
|
109 |
+
discriminator.to(device)
|
110 |
+
|
111 |
+
if load_pretrain:
|
112 |
+
print(f"Loading weights from models/{model_name}_discriminator.pth")
|
113 |
+
checkpoint = torch.load(f'models/{model_name}_discriminator.pth', map_location=device)
|
114 |
+
discriminator.load_state_dict(checkpoint['model_state_dict'])
|
115 |
+
discriminator.eval()
|
116 |
+
return discriminator
|
117 |
+
|
118 |
+
|
119 |
+
def train_GAN(device, init_model_name, unetConfig, BATCH_SIZE, lr_G, lr_D, max_iter, iterator, load_pretrain,
|
120 |
+
encodes2embeddings_mapping, save_steps, unconditional_condition, uncondition_rate, save_model_name=None):
|
121 |
+
|
122 |
+
if save_model_name is None:
|
123 |
+
save_model_name = init_model_name
|
124 |
+
|
125 |
+
def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, model_size, current_iter, current_loss):
|
126 |
+
model_hyperparameter = unetConfig
|
127 |
+
model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
|
128 |
+
model_hyperparameter["lr_G"] = lr_G
|
129 |
+
model_hyperparameter["lr_D"] = lr_D
|
130 |
+
model_hyperparameter["model_size"] = model_size
|
131 |
+
model_hyperparameter["current_iter"] = current_iter
|
132 |
+
model_hyperparameter["current_loss"] = current_loss
|
133 |
+
with open(f"models/hyperparameters/{model_name}_GAN.json", "w") as json_file:
|
134 |
+
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
|
135 |
+
|
136 |
+
generator = ConditionedUnet(**unetConfig)
|
137 |
+
discriminator = Discriminator(unetConfig["label_emb_dim"])
|
138 |
+
generator_size = sum(p.numel() for p in generator.parameters() if p.requires_grad)
|
139 |
+
discriminator_size = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
|
140 |
+
|
141 |
+
print(f"Generator trainable parameters: {generator_size}, discriminator trainable parameters: {discriminator_size}")
|
142 |
+
generator.to(device)
|
143 |
+
discriminator.to(device)
|
144 |
+
optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=lr_G, amsgrad=False)
|
145 |
+
optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr_D, amsgrad=False)
|
146 |
+
|
147 |
+
if load_pretrain:
|
148 |
+
print(f"Loading weights from models/{init_model_name}_generator.pt")
|
149 |
+
checkpoint = torch.load(f'models/{init_model_name}_generator.pth')
|
150 |
+
generator.load_state_dict(checkpoint['model_state_dict'])
|
151 |
+
optimizer_G.load_state_dict(checkpoint['optimizer_state_dict'])
|
152 |
+
print(f"Loading weights from models/{init_model_name}_discriminator.pt")
|
153 |
+
checkpoint = torch.load(f'models/{init_model_name}_discriminator.pth')
|
154 |
+
discriminator.load_state_dict(checkpoint['model_state_dict'])
|
155 |
+
optimizer_D.load_state_dict(checkpoint['optimizer_state_dict'])
|
156 |
+
else:
|
157 |
+
print("Model initialized.")
|
158 |
+
if max_iter == 0:
|
159 |
+
print("Return model directly.")
|
160 |
+
return generator, discriminator, optimizer_G, optimizer_D
|
161 |
+
|
162 |
+
|
163 |
+
train_loss_G, train_loss_D = [], []
|
164 |
+
writer = SummaryWriter(f'runs/{save_model_name}_GAN')
|
165 |
+
|
166 |
+
# average_real_acc, average_fake_acc = evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping)
|
167 |
+
# print(f"average_real_acc, average_fake_acc: {average_real_acc, average_fake_acc}")
|
168 |
+
|
169 |
+
criterion = nn.BCEWithLogitsLoss()
|
170 |
+
generator.train()
|
171 |
+
for i in xrange(max_iter):
|
172 |
+
data, attributes = next(iter(iterator))
|
173 |
+
data = data.to(device)
|
174 |
+
|
175 |
+
conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
|
176 |
+
unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
|
177 |
+
selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
|
178 |
+
conditions_of_one_sample) for conditions_of_one_sample in conditions]
|
179 |
+
batch_size = len(selected_conditions)
|
180 |
+
selected_conditions = torch.stack(selected_conditions).float().to(device)
|
181 |
+
|
182 |
+
# 将数据和标签移至设备
|
183 |
+
real_images = data.to(device)
|
184 |
+
labels = selected_conditions.to(device)
|
185 |
+
|
186 |
+
# 真实和假的标签
|
187 |
+
real_labels = torch.ones(batch_size, 1).to(device)
|
188 |
+
fake_labels = torch.zeros(batch_size, 1).to(device)
|
189 |
+
|
190 |
+
# ========== 训练鉴别器 ==========
|
191 |
+
optimizer_D.zero_grad()
|
192 |
+
|
193 |
+
# 计算鉴别器对真实图像的损失
|
194 |
+
outputs_real = discriminator(real_images, labels)
|
195 |
+
loss_D_real = criterion(outputs_real, real_labels)
|
196 |
+
|
197 |
+
# 生成假图像
|
198 |
+
noise = torch.randn_like(real_images).to(device)
|
199 |
+
fake_images = generator(noise, labels)
|
200 |
+
|
201 |
+
# 计算鉴别器对假图像的损失
|
202 |
+
outputs_fake = discriminator(fake_images.detach(), labels)
|
203 |
+
loss_D_fake = criterion(outputs_fake, fake_labels)
|
204 |
+
|
205 |
+
# 反向传播和优化
|
206 |
+
loss_D = loss_D_real + loss_D_fake
|
207 |
+
loss_D.backward()
|
208 |
+
optimizer_D.step()
|
209 |
+
|
210 |
+
# ========== 训练生成器 ==========
|
211 |
+
optimizer_G.zero_grad()
|
212 |
+
|
213 |
+
# 计算生成器的损失
|
214 |
+
outputs_fake = discriminator(fake_images, labels)
|
215 |
+
loss_G = criterion(outputs_fake, real_labels)
|
216 |
+
|
217 |
+
# 反向传播和优化
|
218 |
+
loss_G.backward()
|
219 |
+
optimizer_G.step()
|
220 |
+
|
221 |
+
|
222 |
+
train_loss_G.append(loss_G.item())
|
223 |
+
train_loss_D.append(loss_D.item())
|
224 |
+
step = int(optimizer_G.state_dict()['state'][list(optimizer_G.state_dict()['state'].keys())[0]]['step'].numpy())
|
225 |
+
|
226 |
+
if (i + 1) % 100 == 0:
|
227 |
+
print('%d step' % (step))
|
228 |
+
|
229 |
+
if (i + 1) % save_steps == 0:
|
230 |
+
current_loss_D = np.mean(train_loss_D[-save_steps:])
|
231 |
+
current_loss_G = np.mean(train_loss_G[-save_steps:])
|
232 |
+
print('current_loss_G: %.5f' % current_loss_G)
|
233 |
+
print('current_loss_D: %.5f' % current_loss_D)
|
234 |
+
|
235 |
+
writer.add_scalar(f"current_loss_G", current_loss_G, step)
|
236 |
+
writer.add_scalar(f"current_loss_D", current_loss_D, step)
|
237 |
+
|
238 |
+
|
239 |
+
torch.save({
|
240 |
+
'model_state_dict': generator.state_dict(),
|
241 |
+
'optimizer_state_dict': optimizer_G.state_dict(),
|
242 |
+
}, f'models/{save_model_name}_generator.pth')
|
243 |
+
save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, generator_size, step, current_loss_G)
|
244 |
+
torch.save({
|
245 |
+
'model_state_dict': discriminator.state_dict(),
|
246 |
+
'optimizer_state_dict': optimizer_D.state_dict(),
|
247 |
+
}, f'models/{save_model_name}_discriminator.pth')
|
248 |
+
save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, discriminator_size, step, current_loss_D)
|
249 |
+
|
250 |
+
if step % 10000 == 0:
|
251 |
+
torch.save({
|
252 |
+
'model_state_dict': generator.state_dict(),
|
253 |
+
'optimizer_state_dict': optimizer_G.state_dict(),
|
254 |
+
}, f'models/history/{save_model_name}_{step}_generator.pth')
|
255 |
+
torch.save({
|
256 |
+
'model_state_dict': discriminator.state_dict(),
|
257 |
+
'optimizer_state_dict': optimizer_D.state_dict(),
|
258 |
+
}, f'models/history/{save_model_name}_{step}_discriminator.pth')
|
259 |
+
|
260 |
+
return generator, discriminator, optimizer_G, optimizer_D
|
261 |
+
|
262 |
+
|
model/VQGAN.py
ADDED
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from torch.utils.tensorboard import SummaryWriter
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from six.moves import xrange
|
8 |
+
from einops import rearrange
|
9 |
+
from torchvision import models
|
10 |
+
|
11 |
+
|
12 |
+
def Normalize(in_channels, num_groups=32, norm_type="groupnorm"):
|
13 |
+
"""Normalization layer"""
|
14 |
+
|
15 |
+
if norm_type == "batchnorm":
|
16 |
+
return torch.nn.BatchNorm2d(in_channels)
|
17 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
18 |
+
|
19 |
+
|
20 |
+
def nonlinearity(x, act_type="relu"):
|
21 |
+
"""Nonlinear activation function"""
|
22 |
+
|
23 |
+
if act_type == "relu":
|
24 |
+
return F.relu(x)
|
25 |
+
else:
|
26 |
+
# swish
|
27 |
+
return x * torch.sigmoid(x)
|
28 |
+
|
29 |
+
|
30 |
+
class VectorQuantizer(nn.Module):
|
31 |
+
"""Vector quantization layer"""
|
32 |
+
|
33 |
+
def __init__(self, num_embeddings, embedding_dim, commitment_cost):
|
34 |
+
super(VectorQuantizer, self).__init__()
|
35 |
+
|
36 |
+
self._embedding_dim = embedding_dim
|
37 |
+
self._num_embeddings = num_embeddings
|
38 |
+
|
39 |
+
self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
|
40 |
+
self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
|
41 |
+
self._commitment_cost = commitment_cost
|
42 |
+
|
43 |
+
def forward(self, inputs):
|
44 |
+
# convert inputs from BCHW -> BHWC
|
45 |
+
inputs = inputs.permute(0, 2, 3, 1).contiguous()
|
46 |
+
input_shape = inputs.shape
|
47 |
+
|
48 |
+
# Flatten input BCHW -> (BHW)C
|
49 |
+
flat_input = inputs.view(-1, self._embedding_dim)
|
50 |
+
|
51 |
+
# Calculate distances (input-embedding)^2
|
52 |
+
distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
|
53 |
+
+ torch.sum(self._embedding.weight ** 2, dim=1)
|
54 |
+
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))
|
55 |
+
|
56 |
+
# Encoding (one-hot-encoding matrix)
|
57 |
+
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
|
58 |
+
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
|
59 |
+
encodings.scatter_(1, encoding_indices, 1)
|
60 |
+
|
61 |
+
# Quantize and unflatten
|
62 |
+
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
|
63 |
+
|
64 |
+
# Loss
|
65 |
+
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
|
66 |
+
q_latent_loss = F.mse_loss(quantized, inputs.detach())
|
67 |
+
loss = q_latent_loss + self._commitment_cost * e_latent_loss
|
68 |
+
|
69 |
+
quantized = inputs + (quantized - inputs).detach()
|
70 |
+
avg_probs = torch.mean(encodings, dim=0)
|
71 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
72 |
+
|
73 |
+
# convert quantized from BHWC -> BCHW
|
74 |
+
min_encodings, min_encoding_indices = None, None
|
75 |
+
return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
|
76 |
+
|
77 |
+
|
78 |
+
class VectorQuantizerEMA(nn.Module):
|
79 |
+
"""Vector quantization layer based on exponential moving average"""
|
80 |
+
|
81 |
+
def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
|
82 |
+
super(VectorQuantizerEMA, self).__init__()
|
83 |
+
|
84 |
+
self._embedding_dim = embedding_dim
|
85 |
+
self._num_embeddings = num_embeddings
|
86 |
+
|
87 |
+
self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
|
88 |
+
self._embedding.weight.data.normal_()
|
89 |
+
self._commitment_cost = commitment_cost
|
90 |
+
|
91 |
+
self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
|
92 |
+
self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
|
93 |
+
self._ema_w.data.normal_()
|
94 |
+
|
95 |
+
self._decay = decay
|
96 |
+
self._epsilon = epsilon
|
97 |
+
|
98 |
+
def forward(self, inputs):
|
99 |
+
# convert inputs from BCHW -> BHWC
|
100 |
+
inputs = inputs.permute(0, 2, 3, 1).contiguous()
|
101 |
+
input_shape = inputs.shape
|
102 |
+
|
103 |
+
# Flatten input
|
104 |
+
flat_input = inputs.view(-1, self._embedding_dim)
|
105 |
+
|
106 |
+
# Calculate distances
|
107 |
+
distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
|
108 |
+
+ torch.sum(self._embedding.weight ** 2, dim=1)
|
109 |
+
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))
|
110 |
+
|
111 |
+
# Encoding
|
112 |
+
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
|
113 |
+
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
|
114 |
+
encodings.scatter_(1, encoding_indices, 1)
|
115 |
+
|
116 |
+
# Quantize and unflatten
|
117 |
+
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
|
118 |
+
|
119 |
+
# Use EMA to update the embedding vectors
|
120 |
+
if self.training:
|
121 |
+
self._ema_cluster_size = self._ema_cluster_size * self._decay + \
|
122 |
+
(1 - self._decay) * torch.sum(encodings, 0)
|
123 |
+
|
124 |
+
# Laplace smoothing of the cluster size
|
125 |
+
n = torch.sum(self._ema_cluster_size.data)
|
126 |
+
self._ema_cluster_size = (
|
127 |
+
(self._ema_cluster_size + self._epsilon)
|
128 |
+
/ (n + self._num_embeddings * self._epsilon) * n)
|
129 |
+
|
130 |
+
dw = torch.matmul(encodings.t(), flat_input)
|
131 |
+
self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
|
132 |
+
|
133 |
+
self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
|
134 |
+
|
135 |
+
# Loss
|
136 |
+
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
|
137 |
+
loss = self._commitment_cost * e_latent_loss
|
138 |
+
|
139 |
+
# Straight Through Estimator
|
140 |
+
quantized = inputs + (quantized - inputs).detach()
|
141 |
+
avg_probs = torch.mean(encodings, dim=0)
|
142 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
143 |
+
|
144 |
+
# convert quantized from BHWC -> BCHW
|
145 |
+
min_encodings, min_encoding_indices = None, None
|
146 |
+
return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
|
147 |
+
|
148 |
+
|
149 |
+
class DownSample(nn.Module):
|
150 |
+
"""DownSample layer"""
|
151 |
+
|
152 |
+
def __init__(self, in_channels, out_channels):
|
153 |
+
super(DownSample, self).__init__()
|
154 |
+
self._conv2d = nn.Conv2d(in_channels=in_channels,
|
155 |
+
out_channels=out_channels,
|
156 |
+
kernel_size=4,
|
157 |
+
stride=2, padding=1)
|
158 |
+
|
159 |
+
def forward(self, x):
|
160 |
+
return self._conv2d(x)
|
161 |
+
|
162 |
+
|
163 |
+
class UpSample(nn.Module):
|
164 |
+
"""UpSample layer"""
|
165 |
+
|
166 |
+
def __init__(self, in_channels, out_channels):
|
167 |
+
super(UpSample, self).__init__()
|
168 |
+
self._conv2d = nn.ConvTranspose2d(in_channels=in_channels,
|
169 |
+
out_channels=out_channels,
|
170 |
+
kernel_size=4,
|
171 |
+
stride=2, padding=1)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
return self._conv2d(x)
|
175 |
+
|
176 |
+
|
177 |
+
class ResnetBlock(nn.Module):
|
178 |
+
"""ResnetBlock is a combination of non-linearity, convolution, and normalization"""
|
179 |
+
|
180 |
+
def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False,
|
181 |
+
dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32):
|
182 |
+
super().__init__()
|
183 |
+
self.in_channels = in_channels
|
184 |
+
out_channels = in_channels if out_channels is None else out_channels
|
185 |
+
self.out_channels = out_channels
|
186 |
+
self.use_conv_shortcut = conv_shortcut
|
187 |
+
self.act_type = act_type
|
188 |
+
|
189 |
+
self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups)
|
190 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
191 |
+
out_channels,
|
192 |
+
kernel_size=3,
|
193 |
+
stride=1,
|
194 |
+
padding=1)
|
195 |
+
if temb_channels > 0:
|
196 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
197 |
+
out_channels)
|
198 |
+
|
199 |
+
self.double_conv = double_conv
|
200 |
+
if self.double_conv:
|
201 |
+
self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups)
|
202 |
+
self.dropout = torch.nn.Dropout(dropout)
|
203 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
204 |
+
out_channels,
|
205 |
+
kernel_size=3,
|
206 |
+
stride=1,
|
207 |
+
padding=1)
|
208 |
+
|
209 |
+
if self.in_channels != self.out_channels:
|
210 |
+
if self.use_conv_shortcut:
|
211 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
212 |
+
out_channels,
|
213 |
+
kernel_size=3,
|
214 |
+
stride=1,
|
215 |
+
padding=1)
|
216 |
+
else:
|
217 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
218 |
+
out_channels,
|
219 |
+
kernel_size=1,
|
220 |
+
stride=1,
|
221 |
+
padding=0)
|
222 |
+
|
223 |
+
def forward(self, x, temb=None):
|
224 |
+
h = x
|
225 |
+
h = self.norm1(h)
|
226 |
+
h = nonlinearity(h, act_type=self.act_type)
|
227 |
+
h = self.conv1(h)
|
228 |
+
|
229 |
+
if temb is not None:
|
230 |
+
h = h + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None]
|
231 |
+
|
232 |
+
if self.double_conv:
|
233 |
+
h = self.norm2(h)
|
234 |
+
h = nonlinearity(h, act_type=self.act_type)
|
235 |
+
h = self.dropout(h)
|
236 |
+
h = self.conv2(h)
|
237 |
+
|
238 |
+
if self.in_channels != self.out_channels:
|
239 |
+
if self.use_conv_shortcut:
|
240 |
+
x = self.conv_shortcut(x)
|
241 |
+
else:
|
242 |
+
x = self.nin_shortcut(x)
|
243 |
+
|
244 |
+
return x + h
|
245 |
+
|
246 |
+
|
247 |
+
class LinearAttention(nn.Module):
|
248 |
+
"""Efficient attention block based on <https://proceedings.mlr.press/v119/katharopoulos20a.html>"""
|
249 |
+
|
250 |
+
def __init__(self, dim, heads=4, dim_head=32, with_skip=True):
|
251 |
+
super().__init__()
|
252 |
+
self.heads = heads
|
253 |
+
hidden_dim = dim_head * heads
|
254 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
255 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
256 |
+
|
257 |
+
self.with_skip = with_skip
|
258 |
+
if self.with_skip:
|
259 |
+
self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
|
260 |
+
|
261 |
+
def forward(self, x):
|
262 |
+
b, c, h, w = x.shape
|
263 |
+
qkv = self.to_qkv(x)
|
264 |
+
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
|
265 |
+
k = k.softmax(dim=-1)
|
266 |
+
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
267 |
+
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
268 |
+
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
269 |
+
|
270 |
+
if self.with_skip:
|
271 |
+
return self.to_out(out) + self.nin_shortcut(x)
|
272 |
+
return self.to_out(out)
|
273 |
+
|
274 |
+
|
275 |
+
class Encoder(nn.Module):
|
276 |
+
"""The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers."""
|
277 |
+
|
278 |
+
def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2,
|
279 |
+
attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32):
|
280 |
+
super(Encoder, self).__init__()
|
281 |
+
|
282 |
+
if attn_pos is None:
|
283 |
+
attn_pos = []
|
284 |
+
self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])])
|
285 |
+
current_channel = hidden_channels[0]
|
286 |
+
|
287 |
+
for i in range(1, len(hidden_channels)):
|
288 |
+
for _ in range(block_depth - 1):
|
289 |
+
self._layers.append(ResnetBlock(in_channels=current_channel,
|
290 |
+
out_channels=current_channel,
|
291 |
+
double_conv=False,
|
292 |
+
conv_shortcut=False,
|
293 |
+
norm_type=norm_type,
|
294 |
+
act_type=act_type,
|
295 |
+
num_groups=num_groups))
|
296 |
+
if current_channel in attn_pos:
|
297 |
+
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
|
298 |
+
|
299 |
+
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
|
300 |
+
self._layers.append(nn.ReLU())
|
301 |
+
self._layers.append(DownSample(current_channel, hidden_channels[i]))
|
302 |
+
current_channel = hidden_channels[i]
|
303 |
+
|
304 |
+
for _ in range(block_depth - 1):
|
305 |
+
self._layers.append(ResnetBlock(in_channels=current_channel,
|
306 |
+
out_channels=current_channel,
|
307 |
+
double_conv=False,
|
308 |
+
conv_shortcut=False,
|
309 |
+
norm_type=norm_type,
|
310 |
+
act_type=act_type,
|
311 |
+
num_groups=num_groups))
|
312 |
+
if current_channel in attn_pos:
|
313 |
+
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
|
314 |
+
|
315 |
+
# Conv1x1: hidden_channels[-1] -> embedding_dim
|
316 |
+
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
|
317 |
+
self._layers.append(nn.ReLU())
|
318 |
+
self._layers.append(nn.Conv2d(in_channels=current_channel,
|
319 |
+
out_channels=embedding_dim,
|
320 |
+
kernel_size=1,
|
321 |
+
stride=1))
|
322 |
+
|
323 |
+
def forward(self, x):
|
324 |
+
for layer in self._layers:
|
325 |
+
x = layer(x)
|
326 |
+
return x
|
327 |
+
|
328 |
+
|
329 |
+
class Decoder(nn.Module):
|
330 |
+
"""The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers."""
|
331 |
+
|
332 |
+
def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2,
|
333 |
+
attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
|
334 |
+
num_groups=32):
|
335 |
+
super(Decoder, self).__init__()
|
336 |
+
|
337 |
+
if attn_pos is None:
|
338 |
+
attn_pos = []
|
339 |
+
reversed_hidden_channels = list(reversed(hidden_channels))
|
340 |
+
|
341 |
+
# Conv1x1: hidden_channels[-1] -> embedding_dim
|
342 |
+
self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim,
|
343 |
+
out_channels=reversed_hidden_channels[0],
|
344 |
+
kernel_size=1, stride=1, bias=False)])
|
345 |
+
|
346 |
+
current_channel = reversed_hidden_channels[0]
|
347 |
+
|
348 |
+
for _ in range(block_depth - 1):
|
349 |
+
if current_channel in attn_pos:
|
350 |
+
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
|
351 |
+
self._layers.append(ResnetBlock(in_channels=current_channel,
|
352 |
+
out_channels=current_channel,
|
353 |
+
double_conv=False,
|
354 |
+
conv_shortcut=False,
|
355 |
+
norm_type=norm_type,
|
356 |
+
act_type=act_type,
|
357 |
+
num_groups=num_groups))
|
358 |
+
|
359 |
+
for i in range(1, len(reversed_hidden_channels)):
|
360 |
+
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
|
361 |
+
self._layers.append(nn.ReLU())
|
362 |
+
self._layers.append(UpSample(current_channel, reversed_hidden_channels[i]))
|
363 |
+
current_channel = reversed_hidden_channels[i]
|
364 |
+
|
365 |
+
for _ in range(block_depth - 1):
|
366 |
+
if current_channel in attn_pos:
|
367 |
+
self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
|
368 |
+
self._layers.append(ResnetBlock(in_channels=current_channel,
|
369 |
+
out_channels=current_channel,
|
370 |
+
double_conv=False,
|
371 |
+
conv_shortcut=False,
|
372 |
+
norm_type=norm_type,
|
373 |
+
act_type=act_type,
|
374 |
+
num_groups=num_groups))
|
375 |
+
|
376 |
+
self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
|
377 |
+
self._layers.append(nn.ReLU())
|
378 |
+
self._layers.append(UpSample(current_channel, current_channel))
|
379 |
+
|
380 |
+
# final layers
|
381 |
+
self._layers.append(ResnetBlock(in_channels=current_channel,
|
382 |
+
out_channels=out_channels,
|
383 |
+
double_conv=False,
|
384 |
+
conv_shortcut=False,
|
385 |
+
norm_type=norm_type,
|
386 |
+
act_type=act_type,
|
387 |
+
num_groups=num_groups))
|
388 |
+
|
389 |
+
|
390 |
+
def forward(self, x):
|
391 |
+
for layer in self._layers:
|
392 |
+
x = layer(x)
|
393 |
+
|
394 |
+
log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :])
|
395 |
+
|
396 |
+
cos_phase = torch.tanh(x[:, 1, :, :])
|
397 |
+
sin_phase = torch.tanh(x[:, 2, :, :])
|
398 |
+
x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1)
|
399 |
+
|
400 |
+
return x
|
401 |
+
|
402 |
+
|
403 |
+
class VQGAN_Discriminator(nn.Module):
|
404 |
+
"""The discriminator employs an 18-layer-ResNet architecture , with the first layer replaced by a 2D convolutional
|
405 |
+
layer that accommodates spectral representation inputs and the last two layers replaced by a binary classifier
|
406 |
+
layer."""
|
407 |
+
|
408 |
+
def __init__(self, in_channels=1):
|
409 |
+
super(VQGAN_Discriminator, self).__init__()
|
410 |
+
resnet = models.resnet18(pretrained=True)
|
411 |
+
|
412 |
+
# 修改第一层以接受单通道(黑白)图像
|
413 |
+
resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
414 |
+
|
415 |
+
# 使用ResNet的特征提取部分
|
416 |
+
self.features = nn.Sequential(*list(resnet.children())[:-2])
|
417 |
+
|
418 |
+
# 添加判别器的额外层
|
419 |
+
self.classifier = nn.Sequential(
|
420 |
+
nn.Linear(512, 1),
|
421 |
+
nn.Sigmoid()
|
422 |
+
)
|
423 |
+
|
424 |
+
def forward(self, x):
|
425 |
+
x = self.features(x)
|
426 |
+
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
|
427 |
+
x = torch.flatten(x, 1)
|
428 |
+
x = self.classifier(x)
|
429 |
+
return x
|
430 |
+
|
431 |
+
|
432 |
+
class VQGAN(nn.Module):
|
433 |
+
"""The VQ-GAN model. <https://openaccess.thecvf.com/content/CVPR2021/html/Esser_Taming_Transformers_for_High-Resolution_Image_Synthesis_CVPR_2021_paper.html?ref=>"""
|
434 |
+
|
435 |
+
def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2,
|
436 |
+
attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
|
437 |
+
num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32):
|
438 |
+
super(VQGAN, self).__init__()
|
439 |
+
|
440 |
+
self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth,
|
441 |
+
attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type, act_type="act_type", num_groups=num_groups)
|
442 |
+
|
443 |
+
if decay > 0.0:
|
444 |
+
self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
|
445 |
+
commitment_cost, decay)
|
446 |
+
else:
|
447 |
+
self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
|
448 |
+
commitment_cost)
|
449 |
+
self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth,
|
450 |
+
attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
|
451 |
+
act_type=act_type, num_groups=num_groups)
|
452 |
+
|
453 |
+
def forward(self, x):
|
454 |
+
z = self._encoder(x)
|
455 |
+
quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z)
|
456 |
+
x_recon = self._decoder(quantized)
|
457 |
+
|
458 |
+
return vq_loss, x_recon, perplexity
|
459 |
+
|
460 |
+
|
461 |
+
class ReconstructionLoss(nn.Module):
|
462 |
+
def __init__(self, w1, w2, epsilon=1e-3):
|
463 |
+
super(ReconstructionLoss, self).__init__()
|
464 |
+
self.w1 = w1
|
465 |
+
self.w2 = w2
|
466 |
+
self.epsilon = epsilon
|
467 |
+
|
468 |
+
def weighted_mae_loss(self, y_true, y_pred):
|
469 |
+
# avoid divide by zero
|
470 |
+
y_true_safe = torch.clamp(y_true, min=self.epsilon)
|
471 |
+
|
472 |
+
# compute weighted MAE
|
473 |
+
loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe)
|
474 |
+
return loss
|
475 |
+
|
476 |
+
def mae_loss(self, y_true, y_pred):
|
477 |
+
loss = torch.mean(torch.abs(y_pred - y_true))
|
478 |
+
return loss
|
479 |
+
|
480 |
+
def forward(self, y_pred, y_true):
|
481 |
+
# loss for magnitude channel
|
482 |
+
log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_pred[:, 0, :, :], y_true[:, 0, :, :])
|
483 |
+
|
484 |
+
# loss for phase channels
|
485 |
+
phase_loss = self.w2 * self.mae_loss(y_pred[:, 1:, :, :], y_true[:, 1:, :, :])
|
486 |
+
|
487 |
+
# sum up
|
488 |
+
rec_loss = log_magnitude_loss + phase_loss
|
489 |
+
return log_magnitude_loss, phase_loss, rec_loss
|
490 |
+
|
491 |
+
|
492 |
+
def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig):
|
493 |
+
model.to(trainingConfig["device"])
|
494 |
+
model.eval()
|
495 |
+
train_res_error = []
|
496 |
+
for i in xrange(100):
|
497 |
+
data = next(iter(iterator))
|
498 |
+
data = data.to(trainingConfig["device"])
|
499 |
+
|
500 |
+
# true/fake labels
|
501 |
+
real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
|
502 |
+
|
503 |
+
vq_loss, data_recon, perplexity = model(data)
|
504 |
+
|
505 |
+
|
506 |
+
fake_preds = discriminator(data_recon)
|
507 |
+
adver_loss = adversarial_loss(fake_preds, real_labels)
|
508 |
+
|
509 |
+
log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
|
510 |
+
loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
|
511 |
+
|
512 |
+
train_res_error.append(loss.item())
|
513 |
+
initial_loss = np.mean(train_res_error)
|
514 |
+
return initial_loss
|
515 |
+
|
516 |
+
|
517 |
+
def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"):
|
518 |
+
VQVAE = VQGAN(**model_Config)
|
519 |
+
print(f"Model intialized, size: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}")
|
520 |
+
VQVAE.to(device)
|
521 |
+
|
522 |
+
if load_pretrain:
|
523 |
+
print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
|
524 |
+
checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device)
|
525 |
+
VQVAE.load_state_dict(checkpoint['model_state_dict'])
|
526 |
+
VQVAE.eval()
|
527 |
+
return VQVAE
|
528 |
+
|
529 |
+
|
530 |
+
def train_VQGAN(model_Config, trainingConfig, iterator):
|
531 |
+
|
532 |
+
def save_model_hyperparameter(model_Config, trainingConfig, current_iter,
|
533 |
+
log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss,
|
534 |
+
current_loss):
|
535 |
+
model_name = trainingConfig["model_name"]
|
536 |
+
model_hyperparameter = model_Config
|
537 |
+
model_hyperparameter.update(trainingConfig)
|
538 |
+
model_hyperparameter["current_iter"] = current_iter
|
539 |
+
model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss
|
540 |
+
model_hyperparameter["phase_loss"] = phase_loss
|
541 |
+
model_hyperparameter["erplexity"] = current_perplexity
|
542 |
+
model_hyperparameter["vq_loss"] = current_vq_loss
|
543 |
+
model_hyperparameter["total_loss"] = current_loss
|
544 |
+
|
545 |
+
with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file:
|
546 |
+
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
|
547 |
+
|
548 |
+
# initialize VAE
|
549 |
+
model = VQGAN(**model_Config)
|
550 |
+
print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
551 |
+
model.to(trainingConfig["device"])
|
552 |
+
|
553 |
+
VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False)
|
554 |
+
model_name = trainingConfig["model_name"]
|
555 |
+
|
556 |
+
if trainingConfig["load_pretrain"]:
|
557 |
+
print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
|
558 |
+
checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"])
|
559 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
560 |
+
VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
561 |
+
else:
|
562 |
+
print("VAE initialized.")
|
563 |
+
if trainingConfig["max_iter"] == 0:
|
564 |
+
print("Return VAE directly.")
|
565 |
+
return model
|
566 |
+
|
567 |
+
# initialize discriminator
|
568 |
+
discriminator = VQGAN_Discriminator(model_Config["in_channels"])
|
569 |
+
print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
|
570 |
+
discriminator.to(trainingConfig["device"])
|
571 |
+
|
572 |
+
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False)
|
573 |
+
|
574 |
+
if trainingConfig["load_pretrain"]:
|
575 |
+
print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth")
|
576 |
+
checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"])
|
577 |
+
discriminator.load_state_dict(checkpoint['model_state_dict'])
|
578 |
+
discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
579 |
+
else:
|
580 |
+
print("Discriminator initialized.")
|
581 |
+
|
582 |
+
# Training
|
583 |
+
|
584 |
+
train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], []
|
585 |
+
train_discriminator_loss, train_adverserial_loss = [], []
|
586 |
+
|
587 |
+
reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"])
|
588 |
+
|
589 |
+
adversarial_loss = nn.BCEWithLogitsLoss()
|
590 |
+
writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4')
|
591 |
+
|
592 |
+
previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator,
|
593 |
+
reconstructionLoss, adversarial_loss, trainingConfig)
|
594 |
+
print(f"initial_loss: {previous_lowest_loss}")
|
595 |
+
|
596 |
+
model.train()
|
597 |
+
for i in xrange(trainingConfig["max_iter"]):
|
598 |
+
data = next(iter(iterator))
|
599 |
+
data = data.to(trainingConfig["device"])
|
600 |
+
|
601 |
+
# true/fake labels
|
602 |
+
real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
|
603 |
+
fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"])
|
604 |
+
|
605 |
+
# update discriminator
|
606 |
+
discriminator_optimizer.zero_grad()
|
607 |
+
|
608 |
+
vq_loss, data_recon, perplexity = model(data)
|
609 |
+
|
610 |
+
real_preds = discriminator(data)
|
611 |
+
fake_preds = discriminator(data_recon.detach())
|
612 |
+
|
613 |
+
loss_real = adversarial_loss(real_preds, real_labels)
|
614 |
+
loss_fake = adversarial_loss(fake_preds, fake_labels)
|
615 |
+
|
616 |
+
loss_D = loss_real + loss_fake
|
617 |
+
loss_D.backward()
|
618 |
+
discriminator_optimizer.step()
|
619 |
+
|
620 |
+
|
621 |
+
# update VQVAE
|
622 |
+
VAE_optimizer.zero_grad()
|
623 |
+
|
624 |
+
fake_preds = discriminator(data_recon)
|
625 |
+
adver_loss = adversarial_loss(fake_preds, real_labels)
|
626 |
+
|
627 |
+
log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
|
628 |
+
|
629 |
+
loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
|
630 |
+
loss.backward()
|
631 |
+
VAE_optimizer.step()
|
632 |
+
|
633 |
+
train_discriminator_loss.append(loss_D.item())
|
634 |
+
train_adverserial_loss.append(trainingConfig["adver_weight"] * adver_loss.item())
|
635 |
+
train_res_log_magnitude_loss.append(log_magnitude_loss.item())
|
636 |
+
train_res_phase_loss.append(phase_loss.item())
|
637 |
+
train_res_perplexity.append(perplexity.item())
|
638 |
+
train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item())
|
639 |
+
train_res_loss.append(loss.item())
|
640 |
+
step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
|
641 |
+
|
642 |
+
save_steps = trainingConfig["save_steps"]
|
643 |
+
if (i + 1) % 100 == 0:
|
644 |
+
print('%d step' % (step))
|
645 |
+
|
646 |
+
if (i + 1) % save_steps == 0:
|
647 |
+
current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:])
|
648 |
+
current_adverserial_loss = np.mean(train_adverserial_loss[-save_steps:])
|
649 |
+
current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:])
|
650 |
+
current_phase_loss = np.mean(train_res_phase_loss[-save_steps:])
|
651 |
+
current_perplexity = np.mean(train_res_perplexity[-save_steps:])
|
652 |
+
current_vq_loss = np.mean(train_res_vq_loss[-save_steps:])
|
653 |
+
current_loss = np.mean(train_res_loss[-save_steps:])
|
654 |
+
|
655 |
+
print('discriminator_loss: %.3f' % current_discriminator_loss)
|
656 |
+
print('adverserial_loss: %.3f' % current_adverserial_loss)
|
657 |
+
print('log_magnitude_loss: %.3f' % current_log_magnitude_loss)
|
658 |
+
print('phase_loss: %.3f' % current_phase_loss)
|
659 |
+
print('perplexity: %.3f' % current_perplexity)
|
660 |
+
print('vq_loss: %.3f' % current_vq_loss)
|
661 |
+
print('total_loss: %.3f' % current_loss)
|
662 |
+
writer.add_scalar(f"log_magnitude_loss", current_log_magnitude_loss, step)
|
663 |
+
writer.add_scalar(f"phase_loss", current_phase_loss, step)
|
664 |
+
writer.add_scalar(f"perplexity", current_perplexity, step)
|
665 |
+
writer.add_scalar(f"vq_loss", current_vq_loss, step)
|
666 |
+
writer.add_scalar(f"total_loss", current_loss, step)
|
667 |
+
if current_loss < previous_lowest_loss:
|
668 |
+
previous_lowest_loss = current_loss
|
669 |
+
|
670 |
+
torch.save({
|
671 |
+
'model_state_dict': model.state_dict(),
|
672 |
+
'optimizer_state_dict': VAE_optimizer.state_dict(),
|
673 |
+
}, f'models/{model_name}_imageVQVAE.pth')
|
674 |
+
|
675 |
+
torch.save({
|
676 |
+
'model_state_dict': discriminator.state_dict(),
|
677 |
+
'optimizer_state_dict': discriminator_optimizer.state_dict(),
|
678 |
+
}, f'models/{model_name}_imageVQVAE_discriminator.pth')
|
679 |
+
|
680 |
+
save_model_hyperparameter(model_Config, trainingConfig, step,
|
681 |
+
current_log_magnitude_loss, current_phase_loss, current_perplexity, current_vq_loss,
|
682 |
+
current_loss)
|
683 |
+
|
684 |
+
return model
|
model/__pycache__/DiffSynthSampler.cpython-310.pyc
ADDED
Binary file (12.8 kB). View file
|
|
model/__pycache__/GAN.cpython-310.pyc
ADDED
Binary file (7.49 kB). View file
|
|
model/__pycache__/VQGAN.cpython-310.pyc
ADDED
Binary file (18.8 kB). View file
|
|
model/__pycache__/diffusion.cpython-310.pyc
ADDED
Binary file (10.1 kB). View file
|
|
model/__pycache__/diffusion_components.cpython-310.pyc
ADDED
Binary file (11.6 kB). View file
|
|
model/__pycache__/multimodal_model.cpython-310.pyc
ADDED
Binary file (9.88 kB). View file
|
|
model/__pycache__/perceptual_label_predictor.cpython-37.pyc
ADDED
Binary file (1.67 kB). View file
|
|
model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc
ADDED
Binary file (7.96 kB). View file
|
|
model/diffusion.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from six.moves import xrange
|
9 |
+
from torch.utils.tensorboard import SummaryWriter
|
10 |
+
import random
|
11 |
+
|
12 |
+
from metrics.IS import get_inception_score
|
13 |
+
from tools import create_key
|
14 |
+
|
15 |
+
from model.diffusion_components import default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, \
|
16 |
+
PreNorm, \
|
17 |
+
Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding, \
|
18 |
+
LinearCrossAttention, LinearCrossAttentionAdd
|
19 |
+
|
20 |
+
|
21 |
+
class ConditionedUnet(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
in_dim,
|
25 |
+
out_dim=None,
|
26 |
+
down_dims=None,
|
27 |
+
up_dims=None,
|
28 |
+
mid_depth=3,
|
29 |
+
with_time_emb=True,
|
30 |
+
time_dim=None,
|
31 |
+
resnet_block_groups=8,
|
32 |
+
use_convnext=True,
|
33 |
+
convnext_mult=2,
|
34 |
+
attn_type="linear_cat",
|
35 |
+
n_label_class=11,
|
36 |
+
condition_type="instrument_family",
|
37 |
+
label_emb_dim=128,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)
|
42 |
+
|
43 |
+
if up_dims is None:
|
44 |
+
up_dims = [128, 128, 64, 32]
|
45 |
+
if down_dims is None:
|
46 |
+
down_dims = [32, 32, 64, 128]
|
47 |
+
|
48 |
+
out_dim = default(out_dim, in_dim)
|
49 |
+
assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
|
50 |
+
assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
|
51 |
+
assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"
|
52 |
+
down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
|
53 |
+
up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
|
54 |
+
print(f"down_in_out: {down_in_out}")
|
55 |
+
print(f"up_in_out: {up_in_out}")
|
56 |
+
time_dim = default(time_dim, int(down_dims[0] * 4))
|
57 |
+
|
58 |
+
self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)
|
59 |
+
|
60 |
+
if use_convnext:
|
61 |
+
block_klass = partial(ConvNextBlock, mult=convnext_mult)
|
62 |
+
else:
|
63 |
+
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
|
64 |
+
|
65 |
+
if attn_type == "linear_cat":
|
66 |
+
attn_klass = partial(LinearCrossAttention)
|
67 |
+
elif attn_type == "linear_add":
|
68 |
+
attn_klass = partial(LinearCrossAttentionAdd)
|
69 |
+
else:
|
70 |
+
raise NotImplementedError()
|
71 |
+
|
72 |
+
# time embeddings
|
73 |
+
if with_time_emb:
|
74 |
+
self.time_mlp = nn.Sequential(
|
75 |
+
SinusoidalPositionEmbeddings(down_dims[0]),
|
76 |
+
nn.Linear(down_dims[0], time_dim),
|
77 |
+
nn.GELU(),
|
78 |
+
nn.Linear(time_dim, time_dim),
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
time_dim = None
|
82 |
+
self.time_mlp = None
|
83 |
+
|
84 |
+
# left layers
|
85 |
+
self.downs = nn.ModuleList([])
|
86 |
+
self.ups = nn.ModuleList([])
|
87 |
+
skip_dims = []
|
88 |
+
|
89 |
+
for down_dim_in, down_dim_out in down_in_out:
|
90 |
+
self.downs.append(
|
91 |
+
nn.ModuleList(
|
92 |
+
[
|
93 |
+
block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),
|
94 |
+
|
95 |
+
Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
|
96 |
+
block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
|
97 |
+
Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
|
98 |
+
Downsample(down_dim_out),
|
99 |
+
]
|
100 |
+
)
|
101 |
+
)
|
102 |
+
skip_dims.append(down_dim_out)
|
103 |
+
|
104 |
+
# bottleneck
|
105 |
+
mid_dim = down_dims[-1]
|
106 |
+
self.mid_left = nn.ModuleList([])
|
107 |
+
self.mid_right = nn.ModuleList([])
|
108 |
+
for _ in range(mid_depth - 1):
|
109 |
+
self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
|
110 |
+
self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
|
111 |
+
self.mid_mid = nn.ModuleList(
|
112 |
+
[
|
113 |
+
block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
|
114 |
+
Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim, ))),
|
115 |
+
block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
|
116 |
+
]
|
117 |
+
)
|
118 |
+
|
119 |
+
# right layers
|
120 |
+
for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
|
121 |
+
skip_dim = skip_dims.pop() # down_dim_out
|
122 |
+
self.ups.append(
|
123 |
+
nn.ModuleList(
|
124 |
+
[
|
125 |
+
# pop&cat (h/2, w/2, down_dim_out)
|
126 |
+
block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
|
127 |
+
Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim, ))),
|
128 |
+
Upsample(up_dim_in),
|
129 |
+
# pop&cat (h, w, down_dim_out)
|
130 |
+
block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
|
131 |
+
Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
|
132 |
+
# pop&cat (h, w, down_dim_out)
|
133 |
+
block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
|
134 |
+
Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
|
135 |
+
]
|
136 |
+
)
|
137 |
+
)
|
138 |
+
|
139 |
+
self.final_conv = nn.Sequential(
|
140 |
+
block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1)
|
141 |
+
)
|
142 |
+
|
143 |
+
def size(self):
|
144 |
+
total_params = sum(p.numel() for p in self.parameters())
|
145 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
146 |
+
print(f"Total parameters: {total_params}")
|
147 |
+
print(f"Trainable parameters: {trainable_params}")
|
148 |
+
|
149 |
+
|
150 |
+
def forward(self, x, time, condition=None):
|
151 |
+
|
152 |
+
if condition is not None:
|
153 |
+
condition_emb = self.label_embedding(condition)
|
154 |
+
else:
|
155 |
+
condition_emb = None
|
156 |
+
|
157 |
+
h = []
|
158 |
+
|
159 |
+
x = self.init_conv(x)
|
160 |
+
h.append(x)
|
161 |
+
|
162 |
+
time_emb = self.time_mlp(time) if exists(self.time_mlp) else None
|
163 |
+
|
164 |
+
# downsample
|
165 |
+
for block1, attn1, block2, attn2, downsample in self.downs:
|
166 |
+
x = block1(x, time_emb)
|
167 |
+
x = attn1(x, condition_emb)
|
168 |
+
h.append(x)
|
169 |
+
x = block2(x, time_emb)
|
170 |
+
x = attn2(x, condition_emb)
|
171 |
+
h.append(x)
|
172 |
+
x = downsample(x)
|
173 |
+
h.append(x)
|
174 |
+
|
175 |
+
# bottleneck
|
176 |
+
|
177 |
+
for block in self.mid_left:
|
178 |
+
x = block(x, time_emb)
|
179 |
+
h.append(x)
|
180 |
+
|
181 |
+
(block1, attn, block2) = self.mid_mid
|
182 |
+
x = block1(x, time_emb)
|
183 |
+
x = attn(x, condition_emb)
|
184 |
+
x = block2(x, time_emb)
|
185 |
+
|
186 |
+
for block in self.mid_right:
|
187 |
+
# This is U-Net!!!
|
188 |
+
x = pad_and_concat(h.pop(), x)
|
189 |
+
x = block(x, time_emb)
|
190 |
+
|
191 |
+
# upsample
|
192 |
+
for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
|
193 |
+
x = pad_and_concat(h.pop(), x)
|
194 |
+
x = block1(x, time_emb)
|
195 |
+
x = attn1(x, condition_emb)
|
196 |
+
x = upsample(x)
|
197 |
+
|
198 |
+
x = pad_and_concat(h.pop(), x)
|
199 |
+
x = block2(x, time_emb)
|
200 |
+
x = attn2(x, condition_emb)
|
201 |
+
|
202 |
+
x = pad_and_concat(h.pop(), x)
|
203 |
+
x = block3(x, time_emb)
|
204 |
+
x = attn3(x, condition_emb)
|
205 |
+
|
206 |
+
x = pad_and_concat(h.pop(), x)
|
207 |
+
x = self.final_conv(x)
|
208 |
+
return x
|
209 |
+
|
210 |
+
|
211 |
+
def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
|
212 |
+
noise=None, loss_type="l1"):
|
213 |
+
if noise is None:
|
214 |
+
noise = torch.randn_like(x_start)
|
215 |
+
|
216 |
+
x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
|
217 |
+
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
|
218 |
+
predicted_noise = denoise_model(x_noisy, t, condition)
|
219 |
+
|
220 |
+
if loss_type == 'l1':
|
221 |
+
loss = F.l1_loss(noise, predicted_noise)
|
222 |
+
elif loss_type == 'l2':
|
223 |
+
loss = F.mse_loss(noise, predicted_noise)
|
224 |
+
elif loss_type == "huber":
|
225 |
+
loss = F.smooth_l1_loss(noise, predicted_noise)
|
226 |
+
else:
|
227 |
+
raise NotImplementedError()
|
228 |
+
|
229 |
+
return loss
|
230 |
+
|
231 |
+
|
232 |
+
def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
|
233 |
+
uncondition_rate, unconditional_condition):
|
234 |
+
model.to(device)
|
235 |
+
model.eval()
|
236 |
+
eva_loss = []
|
237 |
+
sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
|
238 |
+
for i in xrange(500):
|
239 |
+
data, attributes = next(iter(iterator))
|
240 |
+
data = data.to(device)
|
241 |
+
|
242 |
+
conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
|
243 |
+
selected_conditions = [
|
244 |
+
unconditional_condition if random.random() < uncondition_rate else random.choice(conditions_of_one_sample)
|
245 |
+
for conditions_of_one_sample in conditions]
|
246 |
+
|
247 |
+
selected_conditions = torch.stack(selected_conditions).float().to(device)
|
248 |
+
|
249 |
+
t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
|
250 |
+
loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
|
251 |
+
sqrt_alphas_cumprod=sqrt_alphas_cumprod,
|
252 |
+
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
|
253 |
+
|
254 |
+
eva_loss.append(loss.item())
|
255 |
+
initial_loss = np.mean(eva_loss)
|
256 |
+
return initial_loss
|
257 |
+
|
258 |
+
|
259 |
+
def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
|
260 |
+
UNet = ConditionedUnet(**model_Config)
|
261 |
+
print(f"Model intialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}")
|
262 |
+
UNet.to(device)
|
263 |
+
|
264 |
+
if load_pretrain:
|
265 |
+
print(f"Loading weights from models/{model_name}_UNet.pth")
|
266 |
+
checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
|
267 |
+
UNet.load_state_dict(checkpoint['model_state_dict'])
|
268 |
+
UNet.eval()
|
269 |
+
return UNet
|
270 |
+
|
271 |
+
|
272 |
+
def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig, BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain,
|
273 |
+
encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None, save_model_name=None,
|
274 |
+
n_IS_batches=50):
|
275 |
+
|
276 |
+
if save_model_name is None:
|
277 |
+
save_model_name = init_model_name
|
278 |
+
|
279 |
+
def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
|
280 |
+
model_hyperparameter = unetConfig
|
281 |
+
model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
|
282 |
+
model_hyperparameter["lr"] = lr
|
283 |
+
model_hyperparameter["model_size"] = model_size
|
284 |
+
model_hyperparameter["current_iter"] = current_iter
|
285 |
+
model_hyperparameter["current_loss"] = current_loss
|
286 |
+
with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
|
287 |
+
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
|
288 |
+
|
289 |
+
model = ConditionedUnet(**unetConfig)
|
290 |
+
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
291 |
+
print(f"Trainable parameters: {model_size}")
|
292 |
+
model.to(device)
|
293 |
+
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)
|
294 |
+
|
295 |
+
if load_pretrain:
|
296 |
+
print(f"Loading weights from models/{init_model_name}_UNet.pt")
|
297 |
+
checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
|
298 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
299 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
300 |
+
else:
|
301 |
+
print("Model initialized.")
|
302 |
+
if max_iter == 0:
|
303 |
+
print("Return model directly.")
|
304 |
+
return model, optimizer
|
305 |
+
|
306 |
+
|
307 |
+
train_loss = []
|
308 |
+
writer = SummaryWriter(f'runs/{save_model_name}_UNet')
|
309 |
+
if init_loss is None:
|
310 |
+
previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
|
311 |
+
uncondition_rate, unconditional_condition)
|
312 |
+
else:
|
313 |
+
previous_loss = init_loss
|
314 |
+
print(f"initial_IS: {previous_loss}")
|
315 |
+
sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
|
316 |
+
|
317 |
+
model.train()
|
318 |
+
for i in xrange(max_iter):
|
319 |
+
data, attributes = next(iter(iterator))
|
320 |
+
data = data.to(device)
|
321 |
+
|
322 |
+
conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
|
323 |
+
unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
|
324 |
+
selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
|
325 |
+
conditions_of_one_sample) for conditions_of_one_sample in conditions]
|
326 |
+
|
327 |
+
selected_conditions = torch.stack(selected_conditions).float().to(device)
|
328 |
+
|
329 |
+
optimizer.zero_grad()
|
330 |
+
|
331 |
+
t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
|
332 |
+
loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
|
333 |
+
sqrt_alphas_cumprod=sqrt_alphas_cumprod,
|
334 |
+
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
|
335 |
+
|
336 |
+
loss.backward()
|
337 |
+
optimizer.step()
|
338 |
+
|
339 |
+
train_loss.append(loss.item())
|
340 |
+
step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy())
|
341 |
+
|
342 |
+
if step % 100 == 0:
|
343 |
+
print('%d step' % (step))
|
344 |
+
|
345 |
+
if step % save_steps == 0:
|
346 |
+
current_loss = np.mean(train_loss[-save_steps:])
|
347 |
+
print(f"current_loss = {current_loss}")
|
348 |
+
torch.save({
|
349 |
+
'model_state_dict': model.state_dict(),
|
350 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
351 |
+
}, f'models/{save_model_name}_UNet.pth')
|
352 |
+
save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
|
353 |
+
|
354 |
+
|
355 |
+
if step % 20000 == 0:
|
356 |
+
current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder, n_IS_batches,
|
357 |
+
positive_prompts="", negative_prompts="", CFG=1, sample_steps=20, task="STFT")
|
358 |
+
print('current_IS: %.5f' % current_IS)
|
359 |
+
current_loss = np.mean(train_loss[-save_steps:])
|
360 |
+
|
361 |
+
writer.add_scalar(f"current_IS", current_IS, step)
|
362 |
+
|
363 |
+
torch.save({
|
364 |
+
'model_state_dict': model.state_dict(),
|
365 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
366 |
+
}, f'models/history/{save_model_name}_{step}_UNet.pth')
|
367 |
+
save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
|
368 |
+
|
369 |
+
return model, optimizer
|
370 |
+
|
371 |
+
|
model/diffusion_components.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from einops import rearrange
|
5 |
+
from inspect import isfunction
|
6 |
+
import math
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
def exists(x):
|
11 |
+
"""Return true for x is not None."""
|
12 |
+
return x is not None
|
13 |
+
|
14 |
+
|
15 |
+
def default(val, d):
|
16 |
+
"""Helper function"""
|
17 |
+
if exists(val):
|
18 |
+
return val
|
19 |
+
return d() if isfunction(d) else d
|
20 |
+
|
21 |
+
|
22 |
+
class Residual(nn.Module):
|
23 |
+
"""Skip connection"""
|
24 |
+
def __init__(self, fn):
|
25 |
+
super().__init__()
|
26 |
+
self.fn = fn
|
27 |
+
|
28 |
+
def forward(self, x, *args, **kwargs):
|
29 |
+
return self.fn(x, *args, **kwargs) + x
|
30 |
+
|
31 |
+
|
32 |
+
def Upsample(dim):
|
33 |
+
"""Upsample layer, a transposed convolution layer with stride=2"""
|
34 |
+
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
35 |
+
|
36 |
+
|
37 |
+
def Downsample(dim):
|
38 |
+
"""Downsample layer, a convolution layer with stride=2"""
|
39 |
+
return nn.Conv2d(dim, dim, 4, 2, 1)
|
40 |
+
|
41 |
+
|
42 |
+
class SinusoidalPositionEmbeddings(nn.Module):
|
43 |
+
"""Return sinusoidal embedding for integer time step."""
|
44 |
+
|
45 |
+
def __init__(self, dim):
|
46 |
+
super().__init__()
|
47 |
+
self.dim = dim
|
48 |
+
|
49 |
+
def forward(self, time):
|
50 |
+
device = time.device
|
51 |
+
half_dim = self.dim // 2
|
52 |
+
embeddings = math.log(10000) / (half_dim - 1)
|
53 |
+
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
|
54 |
+
embeddings = time[:, None] * embeddings[None, :]
|
55 |
+
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
|
56 |
+
return embeddings
|
57 |
+
|
58 |
+
|
59 |
+
class Block(nn.Module):
|
60 |
+
"""Stack of convolution, normalization, and non-linear activation"""
|
61 |
+
|
62 |
+
def __init__(self, dim, dim_out, groups=8):
|
63 |
+
super().__init__()
|
64 |
+
self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
|
65 |
+
self.norm = nn.GroupNorm(groups, dim_out)
|
66 |
+
self.act = nn.SiLU()
|
67 |
+
|
68 |
+
def forward(self, x, scale_shift=None):
|
69 |
+
x = self.proj(x)
|
70 |
+
x = self.norm(x)
|
71 |
+
|
72 |
+
if exists(scale_shift):
|
73 |
+
scale, shift = scale_shift
|
74 |
+
x = x * (scale + 1) + shift
|
75 |
+
|
76 |
+
x = self.act(x)
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
class ResnetBlock(nn.Module):
|
81 |
+
"""Stack of [conv + norm + act (+ scale&shift)], with positional embedding inserted <https://arxiv.org/abs/1512.03385>"""
|
82 |
+
|
83 |
+
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
|
84 |
+
super().__init__()
|
85 |
+
self.mlp = (
|
86 |
+
nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
|
87 |
+
if exists(time_emb_dim)
|
88 |
+
else None
|
89 |
+
)
|
90 |
+
|
91 |
+
self.block1 = Block(dim, dim_out, groups=groups)
|
92 |
+
self.block2 = Block(dim_out, dim_out, groups=groups)
|
93 |
+
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
94 |
+
|
95 |
+
def forward(self, x, time_emb=None):
|
96 |
+
h = self.block1(x)
|
97 |
+
|
98 |
+
if exists(self.mlp) and exists(time_emb):
|
99 |
+
time_emb = self.mlp(time_emb)
|
100 |
+
# Adding positional embedding to intermediate layer (by broadcasting along spatial dimension)
|
101 |
+
h = rearrange(time_emb, "b c -> b c 1 1") + h
|
102 |
+
|
103 |
+
h = self.block2(h)
|
104 |
+
return h + self.res_conv(x)
|
105 |
+
|
106 |
+
|
107 |
+
class ConvNextBlock(nn.Module):
|
108 |
+
"""Stack of [conv7x7 (+ condition(pos)) + norm + conv3x3 + act + norm + conv3x3 + res1x1],with positional embedding inserted"""
|
109 |
+
|
110 |
+
def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
|
111 |
+
super().__init__()
|
112 |
+
self.mlp = (
|
113 |
+
nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
|
114 |
+
if exists(time_emb_dim)
|
115 |
+
else None
|
116 |
+
)
|
117 |
+
|
118 |
+
self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
|
119 |
+
|
120 |
+
self.net = nn.Sequential(
|
121 |
+
nn.GroupNorm(1, dim) if norm else nn.Identity(),
|
122 |
+
nn.Conv2d(dim, dim_out * mult, 3, padding=1),
|
123 |
+
nn.GELU(),
|
124 |
+
nn.GroupNorm(1, dim_out * mult),
|
125 |
+
nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
|
126 |
+
)
|
127 |
+
|
128 |
+
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
129 |
+
|
130 |
+
def forward(self, x, time_emb=None):
|
131 |
+
h = self.ds_conv(x)
|
132 |
+
|
133 |
+
if exists(self.mlp) and exists(time_emb):
|
134 |
+
assert exists(time_emb), "time embedding must be passed in"
|
135 |
+
condition = self.mlp(time_emb)
|
136 |
+
h = h + rearrange(condition, "b c -> b c 1 1")
|
137 |
+
|
138 |
+
h = self.net(h)
|
139 |
+
return h + self.res_conv(x)
|
140 |
+
|
141 |
+
|
142 |
+
class PreNorm(nn.Module):
|
143 |
+
"""Apply normalization before 'fn'"""
|
144 |
+
|
145 |
+
def __init__(self, dim, fn):
|
146 |
+
super().__init__()
|
147 |
+
self.fn = fn
|
148 |
+
self.norm = nn.GroupNorm(1, dim)
|
149 |
+
|
150 |
+
def forward(self, x, *args, **kwargs):
|
151 |
+
x = self.norm(x)
|
152 |
+
return self.fn(x, *args, **kwargs)
|
153 |
+
|
154 |
+
|
155 |
+
class ConditionalEmbedding(nn.Module):
|
156 |
+
"""Return embedding for label and projection for text embedding"""
|
157 |
+
|
158 |
+
def __init__(self, num_labels, embedding_dim, condition_type="instrument_family"):
|
159 |
+
super(ConditionalEmbedding, self).__init__()
|
160 |
+
if condition_type == "instrument_family":
|
161 |
+
self.embedding = nn.Embedding(num_labels, embedding_dim)
|
162 |
+
elif condition_type == "natural_language_prompt":
|
163 |
+
self.embedding = nn.Linear(embedding_dim, embedding_dim, bias=True)
|
164 |
+
else:
|
165 |
+
raise NotImplementedError()
|
166 |
+
|
167 |
+
def forward(self, labels):
|
168 |
+
return self.embedding(labels)
|
169 |
+
|
170 |
+
|
171 |
+
class LinearCrossAttention(nn.Module):
|
172 |
+
"""Combination of efficient attention and cross attention."""
|
173 |
+
|
174 |
+
def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
|
175 |
+
super().__init__()
|
176 |
+
self.dim_head = dim_head
|
177 |
+
self.scale = dim_head ** -0.5
|
178 |
+
self.heads = heads
|
179 |
+
hidden_dim = dim_head * heads
|
180 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
181 |
+
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
|
182 |
+
|
183 |
+
# embedding for key and value
|
184 |
+
self.label_key = nn.Linear(label_emb_dim, hidden_dim)
|
185 |
+
self.label_value = nn.Linear(label_emb_dim, hidden_dim)
|
186 |
+
|
187 |
+
def forward(self, x, label_embedding=None):
|
188 |
+
b, c, h, w = x.shape
|
189 |
+
qkv = self.to_qkv(x).chunk(3, dim=1)
|
190 |
+
q, k, v = map(
|
191 |
+
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
|
192 |
+
)
|
193 |
+
|
194 |
+
if label_embedding is not None:
|
195 |
+
label_k = self.label_key(label_embedding).view(b, self.heads, self.dim_head, 1)
|
196 |
+
label_v = self.label_value(label_embedding).view(b, self.heads, self.dim_head, 1)
|
197 |
+
|
198 |
+
k = torch.cat([k, label_k], dim=-1)
|
199 |
+
v = torch.cat([v, label_v], dim=-1)
|
200 |
+
|
201 |
+
q = q.softmax(dim=-2)
|
202 |
+
k = k.softmax(dim=-1)
|
203 |
+
q = q * self.scale
|
204 |
+
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
|
205 |
+
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
|
206 |
+
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
|
207 |
+
return self.to_out(out)
|
208 |
+
|
209 |
+
|
210 |
+
def pad_to_match(encoder_tensor, decoder_tensor):
|
211 |
+
"""
|
212 |
+
Pads the decoder_tensor to match the spatial dimensions of encoder_tensor.
|
213 |
+
|
214 |
+
:param encoder_tensor: The feature map from the encoder.
|
215 |
+
:param decoder_tensor: The feature map from the decoder that needs to be upsampled.
|
216 |
+
:return: Padded decoder_tensor with the same spatial dimensions as encoder_tensor.
|
217 |
+
"""
|
218 |
+
|
219 |
+
enc_shape = encoder_tensor.shape[2:] # spatial dimensions are at index 2 and 3
|
220 |
+
dec_shape = decoder_tensor.shape[2:]
|
221 |
+
|
222 |
+
# assume enc_shape >= dec_shape
|
223 |
+
delta_w = enc_shape[1] - dec_shape[1]
|
224 |
+
delta_h = enc_shape[0] - dec_shape[0]
|
225 |
+
|
226 |
+
# padding
|
227 |
+
padding_left = delta_w // 2
|
228 |
+
padding_right = delta_w - padding_left
|
229 |
+
padding_top = delta_h // 2
|
230 |
+
padding_bottom = delta_h - padding_top
|
231 |
+
decoder_tensor_padded = F.pad(decoder_tensor, (padding_left, padding_right, padding_top, padding_bottom))
|
232 |
+
|
233 |
+
return decoder_tensor_padded
|
234 |
+
|
235 |
+
|
236 |
+
def pad_and_concat(encoder_tensor, decoder_tensor):
|
237 |
+
"""
|
238 |
+
Pads the decoder_tensor and concatenates it with the encoder_tensor along the channel dimension.
|
239 |
+
|
240 |
+
:param encoder_tensor: The feature map from the encoder.
|
241 |
+
:param decoder_tensor: The feature map from the decoder that needs to be concatenated with encoder_tensor.
|
242 |
+
:return: Concatenated tensor.
|
243 |
+
"""
|
244 |
+
|
245 |
+
# pad decoder_tensor
|
246 |
+
decoder_tensor_padded = pad_to_match(encoder_tensor, decoder_tensor)
|
247 |
+
# concat encoder_tensor and decoder_tensor_padded
|
248 |
+
concatenated_tensor = torch.cat((encoder_tensor, decoder_tensor_padded), dim=1)
|
249 |
+
return concatenated_tensor
|
250 |
+
|
251 |
+
|
252 |
+
class LinearCrossAttentionAdd(nn.Module):
|
253 |
+
def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
|
254 |
+
super().__init__()
|
255 |
+
self.dim = dim
|
256 |
+
self.dim_head = dim_head
|
257 |
+
self.scale = dim_head ** -0.5
|
258 |
+
self.heads = heads
|
259 |
+
self.label_emb_dim = label_emb_dim
|
260 |
+
self.dim_head = dim_head
|
261 |
+
|
262 |
+
self.hidden_dim = dim_head * heads
|
263 |
+
self.to_qkv = nn.Conv2d(self.dim, self.hidden_dim * 3, 1, bias=False)
|
264 |
+
self.to_out = nn.Sequential(nn.Conv2d(self.hidden_dim, dim, 1), nn.GroupNorm(1, dim))
|
265 |
+
|
266 |
+
# embedding for key and value
|
267 |
+
self.label_key = nn.Linear(label_emb_dim, self.hidden_dim)
|
268 |
+
self.label_query = nn.Linear(label_emb_dim, self.hidden_dim)
|
269 |
+
|
270 |
+
|
271 |
+
def forward(self, x, condition=None):
|
272 |
+
b, c, h, w = x.shape
|
273 |
+
|
274 |
+
qkv = self.to_qkv(x).chunk(3, dim=1)
|
275 |
+
|
276 |
+
q, k, v = map(
|
277 |
+
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
|
278 |
+
)
|
279 |
+
|
280 |
+
# if condition exists,concat its key and value with origin
|
281 |
+
if condition is not None:
|
282 |
+
label_k = self.label_key(condition).view(b, self.heads, self.dim_head, 1)
|
283 |
+
label_q = self.label_query(condition).view(b, self.heads, self.dim_head, 1)
|
284 |
+
k = k + label_k
|
285 |
+
q = q + label_q
|
286 |
+
|
287 |
+
q = q.softmax(dim=-2)
|
288 |
+
k = k.softmax(dim=-1)
|
289 |
+
q = q * self.scale
|
290 |
+
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
|
291 |
+
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
|
292 |
+
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
|
293 |
+
return self.to_out(out)
|
294 |
+
|
295 |
+
|
296 |
+
|
297 |
+
def linear_beta_schedule(timesteps):
|
298 |
+
beta_start = 0.0001
|
299 |
+
beta_end = 0.02
|
300 |
+
return torch.linspace(beta_start, beta_end, timesteps)
|
301 |
+
|
302 |
+
|
303 |
+
def get_beta_schedule(timesteps):
|
304 |
+
betas = linear_beta_schedule(timesteps=timesteps)
|
305 |
+
|
306 |
+
# define alphas
|
307 |
+
alphas = 1. - betas
|
308 |
+
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
309 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
310 |
+
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
|
311 |
+
|
312 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
313 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
314 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
|
315 |
+
|
316 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
317 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
318 |
+
return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, posterior_variance, sqrt_recip_alphas
|
319 |
+
|
320 |
+
|
321 |
+
def extract(a, t, x_shape):
|
322 |
+
batch_size = t.shape[0]
|
323 |
+
out = a.gather(-1, t.cpu())
|
324 |
+
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
325 |
+
|
326 |
+
|
327 |
+
# forward diffusion
|
328 |
+
def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
|
329 |
+
if noise is None:
|
330 |
+
noise = torch.randn_like(x_start)
|
331 |
+
|
332 |
+
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
|
333 |
+
sqrt_one_minus_alphas_cumprod_t = extract(
|
334 |
+
sqrt_one_minus_alphas_cumprod, t, x_start.shape
|
335 |
+
)
|
336 |
+
|
337 |
+
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
|
338 |
+
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
|
343 |
+
|
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
+
|
350 |
+
|
351 |
+
|
model/multimodal_model.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from tools import create_key
|
11 |
+
from model.timbre_encoder_pretrain import get_timbre_encoder
|
12 |
+
|
13 |
+
|
14 |
+
class ProjectionLayer(nn.Module):
|
15 |
+
"""Single-layer Linear projection with dropout, layer norm, and Gelu activation"""
|
16 |
+
|
17 |
+
def __init__(self, input_dim, output_dim, dropout):
|
18 |
+
super(ProjectionLayer, self).__init__()
|
19 |
+
self.projection = nn.Linear(input_dim, output_dim)
|
20 |
+
self.gelu = nn.GELU()
|
21 |
+
self.fc = nn.Linear(output_dim, output_dim)
|
22 |
+
self.dropout = nn.Dropout(dropout)
|
23 |
+
self.layer_norm = nn.LayerNorm(output_dim)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
projected = self.projection(x)
|
27 |
+
x = self.gelu(projected)
|
28 |
+
x = self.fc(x)
|
29 |
+
x = self.dropout(x)
|
30 |
+
x = x + projected
|
31 |
+
x = self.layer_norm(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class ProjectionHead(nn.Module):
|
36 |
+
"""Stack of 'ProjectionLayer'"""
|
37 |
+
|
38 |
+
def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
|
39 |
+
super(ProjectionHead, self).__init__()
|
40 |
+
self.layers = nn.ModuleList([ProjectionLayer(embedding_dim if i == 0 else projection_dim,
|
41 |
+
projection_dim,
|
42 |
+
dropout) for i in range(num_layers)])
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
for layer in self.layers:
|
46 |
+
x = layer(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
|
50 |
+
class multi_modal_model(nn.Module):
|
51 |
+
"""The multi-modal model for contrastive learning"""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
timbre_encoder,
|
56 |
+
text_encoder,
|
57 |
+
spectrogram_feature_dim,
|
58 |
+
text_feature_dim,
|
59 |
+
multi_modal_emb_dim,
|
60 |
+
temperature,
|
61 |
+
dropout,
|
62 |
+
num_projection_layers=1,
|
63 |
+
freeze_spectrogram_encoder=True,
|
64 |
+
freeze_text_encoder=True,
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
self.timbre_encoder = timbre_encoder
|
68 |
+
self.text_encoder = text_encoder
|
69 |
+
|
70 |
+
self.multi_modal_emb_dim = multi_modal_emb_dim
|
71 |
+
|
72 |
+
self.text_projection = ProjectionHead(embedding_dim=text_feature_dim,
|
73 |
+
projection_dim=self.multi_modal_emb_dim, dropout=dropout,
|
74 |
+
num_layers=num_projection_layers)
|
75 |
+
|
76 |
+
self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim,
|
77 |
+
projection_dim=self.multi_modal_emb_dim, dropout=dropout,
|
78 |
+
num_layers=num_projection_layers)
|
79 |
+
|
80 |
+
self.temperature = temperature
|
81 |
+
|
82 |
+
# Make spectrogram_encoder parameters non-trainable
|
83 |
+
for param in self.timbre_encoder.parameters():
|
84 |
+
param.requires_grad = not freeze_spectrogram_encoder
|
85 |
+
|
86 |
+
# Make text_encoder parameters non-trainable
|
87 |
+
for param in self.text_encoder.parameters():
|
88 |
+
param.requires_grad = not freeze_text_encoder
|
89 |
+
|
90 |
+
def forward(self, spectrogram_batch, tokenized_text_batch):
|
91 |
+
# Getting Image and Text Embeddings (with same dimension)
|
92 |
+
spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
|
93 |
+
text_features = self.text_encoder.get_text_features(**tokenized_text_batch)
|
94 |
+
|
95 |
+
# Concat and apply projection
|
96 |
+
spectrogram_embeddings = self.spectrogram_projection(spectrogram_features)
|
97 |
+
text_embeddings = self.text_projection(text_features)
|
98 |
+
|
99 |
+
# Calculating the Loss
|
100 |
+
logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature
|
101 |
+
images_similarity = spectrogram_embeddings @ spectrogram_embeddings.T
|
102 |
+
texts_similarity = text_embeddings @ text_embeddings.T
|
103 |
+
targets = F.softmax(
|
104 |
+
(images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
|
105 |
+
)
|
106 |
+
texts_loss = cross_entropy(logits, targets, reduction='none')
|
107 |
+
images_loss = cross_entropy(logits.T, targets.T, reduction='none')
|
108 |
+
contrastive_loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
|
109 |
+
contrastive_loss = contrastive_loss.mean()
|
110 |
+
|
111 |
+
return contrastive_loss
|
112 |
+
|
113 |
+
|
114 |
+
def get_text_features(self, input_ids, attention_mask):
|
115 |
+
text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
|
116 |
+
return self.text_projection(text_features)
|
117 |
+
|
118 |
+
|
119 |
+
def get_timbre_features(self, spectrogram_batch):
|
120 |
+
spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
|
121 |
+
return self.spectrogram_projection(spectrogram_features)
|
122 |
+
|
123 |
+
|
124 |
+
def cross_entropy(preds, targets, reduction='none'):
|
125 |
+
log_softmax = nn.LogSoftmax(dim=-1)
|
126 |
+
loss = (-targets * log_softmax(preds)).sum(1)
|
127 |
+
if reduction == "none":
|
128 |
+
return loss
|
129 |
+
elif reduction == "mean":
|
130 |
+
return loss.mean()
|
131 |
+
|
132 |
+
|
133 |
+
def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"):
|
134 |
+
mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
|
135 |
+
print(f"Model intialized, size: {sum(p.numel() for p in mmm.parameters() if p.requires_grad)}")
|
136 |
+
mmm.to(device)
|
137 |
+
|
138 |
+
if load_pretrain:
|
139 |
+
print(f"Loading weights from models/{model_name}_MMM.pth")
|
140 |
+
checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
|
141 |
+
mmm.load_state_dict(checkpoint['model_state_dict'])
|
142 |
+
mmm.eval()
|
143 |
+
return mmm
|
144 |
+
|
145 |
+
|
146 |
+
def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
|
147 |
+
(data, attributes) = next(iter(train_loader))
|
148 |
+
keys = [create_key(attribute) for attribute in attributes]
|
149 |
+
|
150 |
+
while(len(set(keys)) != len(keys)):
|
151 |
+
(data, attributes) = next(iter(train_loader))
|
152 |
+
keys = [create_key(attribute) for attribute in attributes]
|
153 |
+
|
154 |
+
data = data.to(device)
|
155 |
+
|
156 |
+
texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
|
157 |
+
selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
|
158 |
+
|
159 |
+
tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
|
160 |
+
|
161 |
+
loss = model(data, tokenized_text)
|
162 |
+
optimizer.zero_grad()
|
163 |
+
loss.backward()
|
164 |
+
optimizer.step()
|
165 |
+
|
166 |
+
return loss.item()
|
167 |
+
|
168 |
+
|
169 |
+
def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
|
170 |
+
(data, attributes) = next(iter(valid_loader))
|
171 |
+
keys = [create_key(attribute) for attribute in attributes]
|
172 |
+
|
173 |
+
while(len(set(keys)) != len(keys)):
|
174 |
+
(data, attributes) = next(iter(valid_loader))
|
175 |
+
keys = [create_key(attribute) for attribute in attributes]
|
176 |
+
|
177 |
+
data = data.to(device)
|
178 |
+
texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
|
179 |
+
selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
|
180 |
+
|
181 |
+
tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
|
182 |
+
|
183 |
+
loss = model(data, tokenized_text)
|
184 |
+
return loss.item()
|
185 |
+
|
186 |
+
|
187 |
+
def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
|
188 |
+
timbre_encoder_Config, MMM_config, MMM_training_config,
|
189 |
+
mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True,
|
190 |
+
timbre_encoder_name=None, init_loss=None, save_steps=2000):
|
191 |
+
|
192 |
+
def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter,
|
193 |
+
current_loss):
|
194 |
+
|
195 |
+
model_hyperparameter = MMM_config
|
196 |
+
model_hyperparameter.update(MMM_training_config)
|
197 |
+
model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
|
198 |
+
model_hyperparameter["model_size"] = model_size
|
199 |
+
model_hyperparameter["current_iter"] = current_iter
|
200 |
+
model_hyperparameter["current_loss"] = current_loss
|
201 |
+
with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
|
202 |
+
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
|
203 |
+
|
204 |
+
timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
|
205 |
+
device=device)
|
206 |
+
|
207 |
+
mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)
|
208 |
+
|
209 |
+
print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
|
210 |
+
print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
|
211 |
+
print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
|
212 |
+
print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}")
|
213 |
+
total_parameters = sum(p.numel() for p in mmm.parameters())
|
214 |
+
trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
|
215 |
+
print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}")
|
216 |
+
|
217 |
+
params = [
|
218 |
+
{"params": itertools.chain(
|
219 |
+
mmm.spectrogram_projection.parameters(),
|
220 |
+
mmm.text_projection.parameters(),
|
221 |
+
), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]},
|
222 |
+
]
|
223 |
+
if not MMM_config["freeze_text_encoder"]:
|
224 |
+
params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"],
|
225 |
+
"weight_decay": MMM_training_config["text_encoder_weight_decay"]})
|
226 |
+
if not MMM_config["freeze_spectrogram_encoder"]:
|
227 |
+
params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"],
|
228 |
+
"weight_decay": MMM_training_config["timbre_encoder_weight_decay"]})
|
229 |
+
|
230 |
+
optimizer = torch.optim.AdamW(params, weight_decay=0.)
|
231 |
+
|
232 |
+
if load_pretrain:
|
233 |
+
print(f"Loading weights from models/{mmm_name}_MMM.pt")
|
234 |
+
checkpoint = torch.load(f'models/{mmm_name}_MMM.pth')
|
235 |
+
mmm.load_state_dict(checkpoint['model_state_dict'])
|
236 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
237 |
+
else:
|
238 |
+
print("Model initialized.")
|
239 |
+
|
240 |
+
if max_iter == 0:
|
241 |
+
print("Return model directly.")
|
242 |
+
return mmm, optimizer
|
243 |
+
|
244 |
+
if init_loss is None:
|
245 |
+
previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device)
|
246 |
+
else:
|
247 |
+
previous_lowest_loss = init_loss
|
248 |
+
print(f"Initial total loss: {previous_lowest_loss}")
|
249 |
+
|
250 |
+
train_loss_list = []
|
251 |
+
for i in range(max_iter):
|
252 |
+
|
253 |
+
mmm.train()
|
254 |
+
train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device)
|
255 |
+
train_loss_list.append(train_loss)
|
256 |
+
|
257 |
+
step = int(
|
258 |
+
optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
|
259 |
+
if (i + 1) % 100 == 0:
|
260 |
+
print('%d step' % (step))
|
261 |
+
|
262 |
+
if (i + 1) % save_steps == 0:
|
263 |
+
current_loss = np.mean(train_loss_list[-save_steps:])
|
264 |
+
print(f"train_total_loss: {current_loss}")
|
265 |
+
if current_loss < previous_lowest_loss:
|
266 |
+
previous_lowest_loss = current_loss
|
267 |
+
torch.save({
|
268 |
+
'model_state_dict': mmm.state_dict(),
|
269 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
270 |
+
}, f'models/{mmm_name}_MMM.pth')
|
271 |
+
save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step,
|
272 |
+
current_loss)
|
273 |
+
|
274 |
+
return mmm, optimizer
|
model/timbre_encoder_pretrain.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.utils.tensorboard import SummaryWriter
|
6 |
+
from tools import create_key
|
7 |
+
|
8 |
+
|
9 |
+
class TimbreEncoder(nn.Module):
|
10 |
+
def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes, num_instrument_family_classes, num_velocity_classes, num_qualities, num_layers=1):
|
11 |
+
super(TimbreEncoder, self).__init__()
|
12 |
+
|
13 |
+
# Input layer
|
14 |
+
self.input_layer = nn.Linear(input_dim, feature_dim)
|
15 |
+
|
16 |
+
# LSTM Layer
|
17 |
+
self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True)
|
18 |
+
|
19 |
+
# Fully Connected Layers for classification
|
20 |
+
self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes)
|
21 |
+
self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes)
|
22 |
+
self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes)
|
23 |
+
self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities)
|
24 |
+
|
25 |
+
# Softmax for converting output to probabilities
|
26 |
+
self.softmax = nn.LogSoftmax(dim=1)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
# # Merge first two dimensions
|
30 |
+
batch_size, _, _, seq_len = x.shape
|
31 |
+
x = x.view(batch_size, -1, seq_len) # [batch_size, input_dim, seq_len]
|
32 |
+
|
33 |
+
# Forward propagate LSTM
|
34 |
+
x = x.permute(0, 2, 1)
|
35 |
+
x = self.input_layer(x)
|
36 |
+
feature, _ = self.lstm(x)
|
37 |
+
feature = feature[:, -1, :]
|
38 |
+
|
39 |
+
# Apply classification layers
|
40 |
+
instrument_logits = self.instrument_classifier_layer(feature)
|
41 |
+
instrument_family_logits = self.instrument_family_classifier_layer(feature)
|
42 |
+
velocity_logits = self.velocity_classifier_layer(feature)
|
43 |
+
qualities = self.qualities_classifier_layer(feature)
|
44 |
+
|
45 |
+
# Apply Softmax
|
46 |
+
instrument_logits = self.softmax(instrument_logits)
|
47 |
+
instrument_family_logits= self.softmax(instrument_family_logits)
|
48 |
+
velocity_logits = self.softmax(velocity_logits)
|
49 |
+
qualities = torch.sigmoid(qualities)
|
50 |
+
|
51 |
+
return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities
|
52 |
+
|
53 |
+
|
54 |
+
def get_multiclass_acc(outputs, ground_truth):
|
55 |
+
_, predicted = torch.max(outputs.data, 1)
|
56 |
+
total = ground_truth.size(0)
|
57 |
+
correct = (predicted == ground_truth).sum().item()
|
58 |
+
accuracy = 100 * correct / total
|
59 |
+
return accuracy
|
60 |
+
|
61 |
+
def get_binary_accuracy(y_pred, y_true):
|
62 |
+
predictions = (y_pred > 0.5).int()
|
63 |
+
|
64 |
+
correct_predictions = (predictions == y_true).float()
|
65 |
+
|
66 |
+
accuracy = correct_predictions.mean()
|
67 |
+
|
68 |
+
return accuracy.item() * 100.0
|
69 |
+
|
70 |
+
|
71 |
+
def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"):
|
72 |
+
timbreEncoder = TimbreEncoder(**model_Config)
|
73 |
+
print(f"Model intialized, size: {sum(p.numel() for p in timbreEncoder.parameters() if p.requires_grad)}")
|
74 |
+
timbreEncoder.to(device)
|
75 |
+
|
76 |
+
if load_pretrain:
|
77 |
+
print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
|
78 |
+
checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth', map_location=device)
|
79 |
+
timbreEncoder.load_state_dict(checkpoint['model_state_dict'])
|
80 |
+
timbreEncoder.eval()
|
81 |
+
return timbreEncoder
|
82 |
+
|
83 |
+
|
84 |
+
def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100):
|
85 |
+
model.to(device)
|
86 |
+
model.eval()
|
87 |
+
|
88 |
+
eva_loss = []
|
89 |
+
for i in range(n_sample):
|
90 |
+
representation, attributes = next(iter(iterator))
|
91 |
+
|
92 |
+
instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
|
93 |
+
instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
|
94 |
+
velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
|
95 |
+
qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)
|
96 |
+
|
97 |
+
_, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))
|
98 |
+
|
99 |
+
# compute loss
|
100 |
+
instrument_loss = nll_Loss(instrument_logits, instrument)
|
101 |
+
instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
|
102 |
+
velocity_loss = nll_Loss(velocity_logits, velocity)
|
103 |
+
qualities_loss = bce_Loss(qualities_pred, qualities)
|
104 |
+
|
105 |
+
loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
|
106 |
+
|
107 |
+
eva_loss.append(loss.item())
|
108 |
+
|
109 |
+
eva_loss = np.mean(eva_loss)
|
110 |
+
return eva_loss
|
111 |
+
|
112 |
+
|
113 |
+
def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter, training_iterator, load_pretrain):
|
114 |
+
def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, current_iter,
|
115 |
+
current_loss):
|
116 |
+
model_hyperparameter = timbre_encoder_Config
|
117 |
+
model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
|
118 |
+
model_hyperparameter["lr"] = lr
|
119 |
+
model_hyperparameter["model_size"] = model_size
|
120 |
+
model_hyperparameter["current_iter"] = current_iter
|
121 |
+
model_hyperparameter["current_loss"] = current_loss
|
122 |
+
with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w") as json_file:
|
123 |
+
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
|
124 |
+
|
125 |
+
model = TimbreEncoder(**timbre_encoder_Config)
|
126 |
+
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
127 |
+
print(f"Model size: {model_size}")
|
128 |
+
model.to(device)
|
129 |
+
nll_Loss = torch.nn.NLLLoss()
|
130 |
+
bce_Loss = torch.nn.BCELoss()
|
131 |
+
|
132 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)
|
133 |
+
|
134 |
+
if load_pretrain:
|
135 |
+
print(f"Loading weights from models/{model_name}_timbre_encoder.pt")
|
136 |
+
checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth')
|
137 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
138 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
139 |
+
else:
|
140 |
+
print("Model initialized.")
|
141 |
+
if max_iter == 0:
|
142 |
+
print("Return model directly.")
|
143 |
+
return model, model
|
144 |
+
|
145 |
+
train_loss, training_instrument_acc, training_instrument_family_acc, training_velocity_acc, training_qualities_acc = [], [], [], [], []
|
146 |
+
writer = SummaryWriter(f'runs/{model_name}_timbre_encoder')
|
147 |
+
current_best_model = model
|
148 |
+
previous_lowest_loss = 100.0
|
149 |
+
print(f"initial__loss: {previous_lowest_loss}")
|
150 |
+
|
151 |
+
for i in range(max_iter):
|
152 |
+
model.train()
|
153 |
+
|
154 |
+
representation, attributes = next(iter(training_iterator))
|
155 |
+
|
156 |
+
instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
|
157 |
+
instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
|
158 |
+
velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
|
159 |
+
qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)
|
160 |
+
|
161 |
+
optimizer.zero_grad()
|
162 |
+
|
163 |
+
_, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))
|
164 |
+
|
165 |
+
# compute loss
|
166 |
+
instrument_loss = nll_Loss(instrument_logits, instrument)
|
167 |
+
instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
|
168 |
+
velocity_loss = nll_Loss(velocity_logits, velocity)
|
169 |
+
qualities_loss = bce_Loss(qualities_pred, qualities)
|
170 |
+
|
171 |
+
loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
|
172 |
+
|
173 |
+
loss.backward()
|
174 |
+
optimizer.step()
|
175 |
+
instrument_acc = get_multiclass_acc(instrument_logits, instrument)
|
176 |
+
instrument_family_acc = get_multiclass_acc(instrument_family_logits, instrument_family)
|
177 |
+
velocity_acc = get_multiclass_acc(velocity_logits, velocity)
|
178 |
+
qualities_acc = get_binary_accuracy(qualities_pred, qualities)
|
179 |
+
|
180 |
+
train_loss.append(loss.item())
|
181 |
+
training_instrument_acc.append(instrument_acc)
|
182 |
+
training_instrument_family_acc.append(instrument_family_acc)
|
183 |
+
training_velocity_acc.append(velocity_acc)
|
184 |
+
training_qualities_acc.append(qualities_acc)
|
185 |
+
step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy())
|
186 |
+
|
187 |
+
if (i + 1) % 100 == 0:
|
188 |
+
print('%d step' % (step))
|
189 |
+
|
190 |
+
save_steps = 500
|
191 |
+
if (i + 1) % save_steps == 0:
|
192 |
+
current_loss = np.mean(train_loss[-save_steps:])
|
193 |
+
current_instrument_acc = np.mean(training_instrument_acc[-save_steps:])
|
194 |
+
current_instrument_family_acc = np.mean(training_instrument_family_acc[-save_steps:])
|
195 |
+
current_velocity_acc = np.mean(training_velocity_acc[-save_steps:])
|
196 |
+
current_qualities_acc = np.mean(training_qualities_acc[-save_steps:])
|
197 |
+
print('train_loss: %.5f' % current_loss)
|
198 |
+
print('current_instrument_acc: %.5f' % current_instrument_acc)
|
199 |
+
print('current_instrument_family_acc: %.5f' % current_instrument_family_acc)
|
200 |
+
print('current_velocity_acc: %.5f' % current_velocity_acc)
|
201 |
+
print('current_qualities_acc: %.5f' % current_qualities_acc)
|
202 |
+
writer.add_scalar(f"train_loss", current_loss, step)
|
203 |
+
writer.add_scalar(f"current_instrument_acc", current_instrument_acc, step)
|
204 |
+
writer.add_scalar(f"current_instrument_family_acc", current_instrument_family_acc, step)
|
205 |
+
writer.add_scalar(f"current_velocity_acc", current_velocity_acc, step)
|
206 |
+
writer.add_scalar(f"current_qualities_acc", current_qualities_acc, step)
|
207 |
+
|
208 |
+
if current_loss < previous_lowest_loss:
|
209 |
+
previous_lowest_loss = current_loss
|
210 |
+
current_best_model = model
|
211 |
+
torch.save({
|
212 |
+
'model_state_dict': model.state_dict(),
|
213 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
214 |
+
}, f'models/{model_name}_timbre_encoder.pth')
|
215 |
+
save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, step,
|
216 |
+
current_loss)
|
217 |
+
|
218 |
+
return model, current_best_model
|
219 |
+
|
220 |
+
|
models/24_1_2024-52_4x_L_D_imageVQVAE.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e5feec46219e25e6f95bfa453a4bddb3ec7bc26d29f2e01748defa4901762c9f
|
3 |
+
size 16069859
|
models/24_1_2024-52_4x_L_D_imageVQVAE_discriminator.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:69080ff5094f82bba6e69e4310cda27864420dfea9b08d3a001dafe46bbf6808
|
3 |
+
size 134268962
|
models/24_1_2024_MMM.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:494f7fa1f9874ebd2cd870b6da993cb796e907e857a28550e62be8c335fb9f5a
|
3 |
+
size 1930637291
|
models/24_1_2024_STFT_timbre_encoder.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b9288a7fc4a8cdc4b0db5803fd0a5e88e6bcf07bdde22f49ebb8dec12fb33e6
|
3 |
+
size 294502949
|
models/history/28_1_2024_TE_STFT_300000_UNet.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:edcdd3c65275badad3233eabdcb7a3ffa7adbf95c673172ca2e35338097d1a1c
|
3 |
+
size 1284015362
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torchmetrics==0.7.0
|
2 |
+
torchsynth==1.0.2
|
3 |
+
torchaudio
|
4 |
+
soundfile
|
5 |
+
einops
|
6 |
+
pytorch-ssim
|
7 |
+
piqa
|
8 |
+
torchinfo
|
9 |
+
mido
|
10 |
+
tensorboard
|
11 |
+
librosa
|
12 |
+
transformers
|
13 |
+
matplotlib
|
14 |
+
gradio==3.50.2
|
15 |
+
|
tools.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import matplotlib
|
4 |
+
import librosa
|
5 |
+
from scipy.io.wavfile import write
|
6 |
+
import torch
|
7 |
+
|
8 |
+
k = 1e-16
|
9 |
+
|
10 |
+
def np_log10(x):
|
11 |
+
"""Safe log function with base 10."""
|
12 |
+
numerator = np.log(x + 1e-16)
|
13 |
+
denominator = np.log(10)
|
14 |
+
return numerator / denominator
|
15 |
+
|
16 |
+
|
17 |
+
def sigmoid(x):
|
18 |
+
"""Safe log function with base 10."""
|
19 |
+
s = 1 / (1 + np.exp(-x))
|
20 |
+
return s
|
21 |
+
|
22 |
+
|
23 |
+
def inv_sigmoid(s):
|
24 |
+
"""Safe inverse sigmoid function."""
|
25 |
+
x = np.log((s / (1 - s)) + 1e-16)
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
def spc_to_VAE_input(spc):
|
30 |
+
"""Restrict value range from [0, infinite] to [0, 1]. (deprecated )"""
|
31 |
+
return spc / (1 + spc)
|
32 |
+
|
33 |
+
|
34 |
+
def VAE_out_put_to_spc(o):
|
35 |
+
"""Inverse transform of function 'spc_to_VAE_input'. (deprecated )"""
|
36 |
+
return o / (1 - o + k)
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
def np_power_to_db(S, amin=1e-16, top_db=80.0):
|
41 |
+
"""Helper method for numpy data scaling. (deprecated )"""
|
42 |
+
ref = S.max()
|
43 |
+
|
44 |
+
log_spec = 10.0 * np_log10(np.maximum(amin, S))
|
45 |
+
log_spec -= 10.0 * np_log10(np.maximum(amin, ref))
|
46 |
+
|
47 |
+
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
|
48 |
+
|
49 |
+
return log_spec
|
50 |
+
|
51 |
+
|
52 |
+
def show_spc(spc):
|
53 |
+
"""Show a spectrogram. (deprecated )"""
|
54 |
+
s = np.shape(spc)
|
55 |
+
spc = np.reshape(spc, (s[0], s[1]))
|
56 |
+
magnitude_spectrum = np.abs(spc)
|
57 |
+
log_spectrum = np_power_to_db(magnitude_spectrum)
|
58 |
+
plt.imshow(np.flipud(log_spectrum))
|
59 |
+
plt.show()
|
60 |
+
|
61 |
+
|
62 |
+
def save_results(spectrogram, spectrogram_image_path, waveform_path):
|
63 |
+
"""Save the input 'spectrogram' and its waveform (reconstructed by Griffin Lim)
|
64 |
+
to path provided by 'spectrogram_image_path' and 'waveform_path'."""
|
65 |
+
magnitude_spectrum = np.abs(spectrogram)
|
66 |
+
log_spc = np_power_to_db(magnitude_spectrum)
|
67 |
+
log_spc = np.reshape(log_spc, (512, 256))
|
68 |
+
matplotlib.pyplot.imsave(spectrogram_image_path, log_spc, vmin=-100, vmax=0,
|
69 |
+
origin='lower')
|
70 |
+
|
71 |
+
# save waveform
|
72 |
+
abs_spec = np.zeros((513, 256))
|
73 |
+
abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spectrogram, (512, 256)))
|
74 |
+
rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
|
75 |
+
write(waveform_path, 16000, rec_signal)
|
76 |
+
|
77 |
+
|
78 |
+
def plot_log_spectrogram(signal: np.ndarray,
|
79 |
+
path: str,
|
80 |
+
n_fft=2048,
|
81 |
+
frame_length=1024,
|
82 |
+
frame_step=256):
|
83 |
+
"""Save spectrogram."""
|
84 |
+
stft = librosa.stft(signal, n_fft=n_fft, hop_length=frame_step, win_length=frame_length)
|
85 |
+
amp = np.square(np.real(stft)) + np.square(np.imag(stft))
|
86 |
+
magnitude_spectrum = np.abs(amp)
|
87 |
+
log_mel = np_power_to_db(magnitude_spectrum)
|
88 |
+
matplotlib.pyplot.imsave(path, log_mel, vmin=-100, vmax=0, origin='lower')
|
89 |
+
|
90 |
+
|
91 |
+
def visualize_feature_maps(device, model, inputs, channel_indices=[0, 3,]):
|
92 |
+
"""
|
93 |
+
Visualize feature maps before and after quantization for given input.
|
94 |
+
|
95 |
+
Parameters:
|
96 |
+
- model: Your VQ-VAE model.
|
97 |
+
- inputs: A batch of input data.
|
98 |
+
- channel_indices: Indices of feature map channels to visualize.
|
99 |
+
"""
|
100 |
+
model.eval()
|
101 |
+
inputs = inputs.to(device)
|
102 |
+
|
103 |
+
with torch.no_grad():
|
104 |
+
z_e = model._encoder(inputs)
|
105 |
+
z_q, loss, (perplexity, min_encodings, min_encoding_indices) = model._vq_vae(z_e)
|
106 |
+
|
107 |
+
# Assuming inputs have shape [batch_size, channels, height, width]
|
108 |
+
batch_size = z_e.size(0)
|
109 |
+
|
110 |
+
for idx in range(batch_size):
|
111 |
+
fig, axs = plt.subplots(1, len(channel_indices)*2, figsize=(15, 5))
|
112 |
+
|
113 |
+
for i, channel_idx in enumerate(channel_indices):
|
114 |
+
# Plot encoder output
|
115 |
+
axs[2*i].imshow(z_e[idx][channel_idx].cpu().numpy(), cmap='viridis')
|
116 |
+
axs[2*i].set_title(f"Encoder Output - Channel {channel_idx}")
|
117 |
+
|
118 |
+
# Plot quantized output
|
119 |
+
axs[2*i+1].imshow(z_q[idx][channel_idx].cpu().numpy(), cmap='viridis')
|
120 |
+
axs[2*i+1].set_title(f"Quantized Output - Channel {channel_idx}")
|
121 |
+
|
122 |
+
plt.show()
|
123 |
+
|
124 |
+
|
125 |
+
def adjust_audio_length(audio, desired_length, original_sample_rate, target_sample_rate):
|
126 |
+
"""
|
127 |
+
Adjust the audio length to the desired length and resample to target sample rate.
|
128 |
+
|
129 |
+
Parameters:
|
130 |
+
- audio (np.array): The input audio signal
|
131 |
+
- desired_length (int): The desired length of the output audio
|
132 |
+
- original_sample_rate (int): The original sample rate of the audio
|
133 |
+
- target_sample_rate (int): The target sample rate for the output audio
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
- np.array: The adjusted and resampled audio
|
137 |
+
"""
|
138 |
+
|
139 |
+
if not (original_sample_rate == target_sample_rate):
|
140 |
+
audio = librosa.core.resample(audio, orig_sr=original_sample_rate, target_sr=target_sample_rate)
|
141 |
+
|
142 |
+
if len(audio) > desired_length:
|
143 |
+
return audio[:desired_length]
|
144 |
+
|
145 |
+
elif len(audio) < desired_length:
|
146 |
+
padded_audio = np.zeros(desired_length)
|
147 |
+
padded_audio[:len(audio)] = audio
|
148 |
+
return padded_audio
|
149 |
+
else:
|
150 |
+
return audio
|
151 |
+
|
152 |
+
|
153 |
+
def safe_int(s, default=0):
|
154 |
+
try:
|
155 |
+
return int(s)
|
156 |
+
except ValueError:
|
157 |
+
return default
|
158 |
+
|
159 |
+
|
160 |
+
def pad_spectrogram(D):
|
161 |
+
"""Resize spectrogram to (512, 256). (deprecated )"""
|
162 |
+
D = D[1:, :]
|
163 |
+
|
164 |
+
padding_length = 256 - D.shape[1]
|
165 |
+
D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant')
|
166 |
+
return D_padded
|
167 |
+
|
168 |
+
|
169 |
+
def pad_STFT(D, time_resolution=256):
|
170 |
+
"""Resize spectral matrix by padding and cropping"""
|
171 |
+
D = D[1:, :]
|
172 |
+
|
173 |
+
if time_resolution is None:
|
174 |
+
return D
|
175 |
+
|
176 |
+
padding_length = time_resolution - D.shape[1]
|
177 |
+
if padding_length > 0:
|
178 |
+
D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant')
|
179 |
+
return D_padded
|
180 |
+
else:
|
181 |
+
return D
|
182 |
+
|
183 |
+
|
184 |
+
def depad_STFT(D_padded):
|
185 |
+
"""Inverse function of 'pad_STFT'"""
|
186 |
+
zero_row = np.zeros((1, D_padded.shape[1]))
|
187 |
+
|
188 |
+
D_restored = np.concatenate([zero_row, D_padded], axis=0)
|
189 |
+
|
190 |
+
return D_restored
|
191 |
+
|
192 |
+
|
193 |
+
def nnData2Audio(spectrogram_batch, resolution=(512, 256), squared=False):
|
194 |
+
"""Transform batch of numpy spectrogram into signals and encodings."""
|
195 |
+
# Todo: remove resolution hard-coding
|
196 |
+
frequency_resolution, time_resolution = resolution
|
197 |
+
|
198 |
+
if isinstance(spectrogram_batch, torch.Tensor):
|
199 |
+
spectrogram_batch = spectrogram_batch.to("cpu").detach().numpy()
|
200 |
+
|
201 |
+
origin_signals = []
|
202 |
+
for spectrogram in spectrogram_batch:
|
203 |
+
spc = VAE_out_put_to_spc(spectrogram)
|
204 |
+
|
205 |
+
# get_audio
|
206 |
+
abs_spec = np.zeros((frequency_resolution+1, time_resolution))
|
207 |
+
|
208 |
+
if squared:
|
209 |
+
abs_spec[1:, :] = abs_spec[1:, :] + np.sqrt(np.reshape(spc, (frequency_resolution, time_resolution)))
|
210 |
+
else:
|
211 |
+
abs_spec[1:, :] = abs_spec[1:, :] + np.reshape(spc, (frequency_resolution, time_resolution))
|
212 |
+
|
213 |
+
origin_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
|
214 |
+
origin_signals.append(origin_signal)
|
215 |
+
|
216 |
+
return origin_signals
|
217 |
+
|
218 |
+
|
219 |
+
def amp_to_audio(amp, n_iter=50):
|
220 |
+
"""The Griffin-Lim algorithm."""
|
221 |
+
y_reconstructed = librosa.griffinlim(amp, n_iter=n_iter, hop_length=256, win_length=1024)
|
222 |
+
return y_reconstructed
|
223 |
+
|
224 |
+
|
225 |
+
def rescale(amp, method="log1p"):
|
226 |
+
"""Rescale function."""
|
227 |
+
if method == "log1p":
|
228 |
+
return np.log1p(amp)
|
229 |
+
elif method == "NormalizedLogisticCompression":
|
230 |
+
return amp / (1.0 + amp)
|
231 |
+
else:
|
232 |
+
raise NotImplementedError()
|
233 |
+
|
234 |
+
|
235 |
+
def unrescale(scaled_amp, method="NormalizedLogisticCompression"):
|
236 |
+
"""Inverse function of 'rescale'"""
|
237 |
+
if method == "log1p":
|
238 |
+
return np.expm1(scaled_amp)
|
239 |
+
elif method == "NormalizedLogisticCompression":
|
240 |
+
return scaled_amp / (1.0 - scaled_amp + 1e-10)
|
241 |
+
else:
|
242 |
+
raise NotImplementedError()
|
243 |
+
|
244 |
+
|
245 |
+
def create_key(attributes):
|
246 |
+
"""Create unique key for each multi-label."""
|
247 |
+
qualities_str = ''.join(map(str, attributes["qualities"]))
|
248 |
+
instrument_source_str = attributes["instrument_source_str"]
|
249 |
+
instrument_family = attributes["instrument_family_str"]
|
250 |
+
key = f"{instrument_source_str}_{instrument_family}_{qualities_str}"
|
251 |
+
return key
|
252 |
+
|
253 |
+
|
254 |
+
def merge_dictionaries(dicts):
|
255 |
+
"""Merge dictionaries."""
|
256 |
+
merged_dict = {}
|
257 |
+
for dictionary in dicts:
|
258 |
+
for key, value in dictionary.items():
|
259 |
+
if key in merged_dict:
|
260 |
+
merged_dict[key] += value
|
261 |
+
else:
|
262 |
+
merged_dict[key] = value
|
263 |
+
return merged_dict
|
264 |
+
|
265 |
+
|
266 |
+
def adsr_envelope(signal, sample_rate, duration, attack_time, decay_time, sustain_level, release_time):
|
267 |
+
"""
|
268 |
+
Apply an ADSR envelope to an audio signal.
|
269 |
+
|
270 |
+
:param signal: The original audio signal (numpy array).
|
271 |
+
:param sample_rate: The sample rate of the audio signal.
|
272 |
+
:param attack_time: Attack time in seconds.
|
273 |
+
:param decay_time: Decay time in seconds.
|
274 |
+
:param sustain_level: Sustain level as a fraction of the peak (0 to 1).
|
275 |
+
:param release_time: Release time in seconds.
|
276 |
+
:return: The audio signal with the ADSR envelope applied.
|
277 |
+
"""
|
278 |
+
# Calculate the number of samples for each ADSR phase
|
279 |
+
duration_samples = int(duration * sample_rate)
|
280 |
+
|
281 |
+
# assert (duration_samples + int(1.0 * sample_rate)) <= len(signal), "(duration_samples + sample_rate) > len(signal)"
|
282 |
+
assert release_time <= 1.0, "release_time > 1.0"
|
283 |
+
|
284 |
+
attack_samples = int(attack_time * sample_rate)
|
285 |
+
decay_samples = int(decay_time * sample_rate)
|
286 |
+
release_samples = int(release_time * sample_rate)
|
287 |
+
sustain_samples = max(0, duration_samples - attack_samples - decay_samples)
|
288 |
+
|
289 |
+
# Create ADSR envelope
|
290 |
+
attack_env = np.linspace(0, 1, attack_samples)
|
291 |
+
decay_env = np.linspace(1, sustain_level, decay_samples)
|
292 |
+
sustain_env = np.full(sustain_samples, sustain_level)
|
293 |
+
release_env = np.linspace(sustain_level, 0, release_samples)
|
294 |
+
release_env_expand = np.zeros(int(1.0 * sample_rate))
|
295 |
+
release_env_expand[:len(release_env)] = release_env
|
296 |
+
|
297 |
+
# Concatenate all phases to create the complete envelope
|
298 |
+
envelope = np.concatenate([attack_env, decay_env, sustain_env, release_env_expand])
|
299 |
+
|
300 |
+
# Apply the envelope to the signal
|
301 |
+
if len(envelope) <= len(signal):
|
302 |
+
applied_signal = signal[:len(envelope)] * envelope
|
303 |
+
else:
|
304 |
+
signal_expanded = np.zeros(len(envelope))
|
305 |
+
signal_expanded[:len(signal)] = signal
|
306 |
+
applied_signal = signal_expanded * envelope
|
307 |
+
|
308 |
+
return applied_signal
|
309 |
+
|
310 |
+
|
311 |
+
def rms_normalize(audio, target_rms=0.1):
|
312 |
+
"""Normalize the RMS value."""
|
313 |
+
current_rms = np.sqrt(np.mean(audio**2))
|
314 |
+
scaling_factor = target_rms / current_rms
|
315 |
+
normalized_audio = audio * scaling_factor
|
316 |
+
return normalized_audio
|
317 |
+
|
318 |
+
|
319 |
+
def encode_stft(D):
|
320 |
+
"""'STFT+' function that transform spectral matrix into spectral representation."""
|
321 |
+
magnitude = np.abs(D)
|
322 |
+
phase = np.angle(D)
|
323 |
+
|
324 |
+
log_magnitude = np.log1p(magnitude)
|
325 |
+
|
326 |
+
cos_phase = np.cos(phase)
|
327 |
+
sin_phase = np.sin(phase)
|
328 |
+
|
329 |
+
encoded_D = np.stack([log_magnitude, cos_phase, sin_phase], axis=0)
|
330 |
+
return encoded_D
|
331 |
+
|
332 |
+
|
333 |
+
def decode_stft(encoded_D):
|
334 |
+
"""'ISTFT+' function that reconstructs spectral matrix from spectral representation."""
|
335 |
+
log_magnitude = encoded_D[0, ...]
|
336 |
+
cos_phase = encoded_D[1, ...]
|
337 |
+
sin_phase = encoded_D[2, ...]
|
338 |
+
|
339 |
+
magnitude = np.expm1(log_magnitude)
|
340 |
+
|
341 |
+
phase = np.arctan2(sin_phase, cos_phase)
|
342 |
+
|
343 |
+
D = magnitude * (np.cos(phase) + 1j * np.sin(phase))
|
344 |
+
return D
|
webUI/__pycache__/app.cpython-310.pyc
ADDED
Binary file (10.4 kB). View file
|
|
webUI/deprecated/interpolationWithCondition.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from model.DiffSynthSampler import DiffSynthSampler
|
6 |
+
from tools import safe_int
|
7 |
+
from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image
|
8 |
+
|
9 |
+
|
10 |
+
def get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state):
|
11 |
+
# Load configurations
|
12 |
+
uNet = gradioWebUI.uNet
|
13 |
+
freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
|
14 |
+
VAE_scale = gradioWebUI.VAE_scale
|
15 |
+
height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
|
16 |
+
timesteps = gradioWebUI.timesteps
|
17 |
+
VAE_quantizer = gradioWebUI.VAE_quantizer
|
18 |
+
VAE_decoder = gradioWebUI.VAE_decoder
|
19 |
+
CLAP = gradioWebUI.CLAP
|
20 |
+
CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
|
21 |
+
device = gradioWebUI.device
|
22 |
+
squared = gradioWebUI.squared
|
23 |
+
sample_rate = gradioWebUI.sample_rate
|
24 |
+
noise_strategy = gradioWebUI.noise_strategy
|
25 |
+
|
26 |
+
def diffusion_random_sample(text2sound_prompts_1, text2sound_prompts_2, text2sound_negative_prompts, text2sound_batchsize,
|
27 |
+
text2sound_duration,
|
28 |
+
text2sound_guidance_scale, text2sound_sampler,
|
29 |
+
text2sound_sample_steps, text2sound_seed,
|
30 |
+
interpolation_with_text_dict):
|
31 |
+
text2sound_sample_steps = int(text2sound_sample_steps)
|
32 |
+
text2sound_seed = safe_int(text2sound_seed, 12345678)
|
33 |
+
# Todo: take care of text2sound_time_resolution/width
|
34 |
+
width = int(time_resolution*((text2sound_duration+1)/4) / VAE_scale)
|
35 |
+
text2sound_batchsize = int(text2sound_batchsize)
|
36 |
+
|
37 |
+
text2sound_embedding_1 = \
|
38 |
+
CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_1], padding=True, return_tensors="pt"))[0].to(device)
|
39 |
+
text2sound_embedding_2 = \
|
40 |
+
CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_2], padding=True, return_tensors="pt"))[0].to(device)
|
41 |
+
|
42 |
+
CFG = int(text2sound_guidance_scale)
|
43 |
+
|
44 |
+
mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
|
45 |
+
unconditional_condition = \
|
46 |
+
CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
|
47 |
+
mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))
|
48 |
+
|
49 |
+
mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))
|
50 |
+
|
51 |
+
condition = torch.linspace(1, 0, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_1 + \
|
52 |
+
torch.linspace(0, 1, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_2
|
53 |
+
|
54 |
+
# Todo: move this code
|
55 |
+
torch.manual_seed(text2sound_seed)
|
56 |
+
initial_noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)
|
57 |
+
|
58 |
+
latent_representations, initial_noise = \
|
59 |
+
mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
|
60 |
+
return_tensor=True, condition=condition, sampler=text2sound_sampler, initial_noise=initial_noise)
|
61 |
+
|
62 |
+
latent_representations = latent_representations[-1]
|
63 |
+
|
64 |
+
interpolation_with_text_dict["latent_representations"] = latent_representations
|
65 |
+
|
66 |
+
latent_representation_gradio_images = []
|
67 |
+
quantized_latent_representation_gradio_images = []
|
68 |
+
new_sound_spectrogram_gradio_images = []
|
69 |
+
new_sound_rec_signals_gradio = []
|
70 |
+
|
71 |
+
quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
|
72 |
+
# Todo: remove hard-coding
|
73 |
+
flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
|
74 |
+
resolution=(512, width * VAE_scale), centralized=False,
|
75 |
+
squared=squared)
|
76 |
+
|
77 |
+
for i in range(text2sound_batchsize):
|
78 |
+
latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
|
79 |
+
quantized_latent_representation_gradio_images.append(
|
80 |
+
latent_representation_to_Gradio_image(quantized_latent_representations[i]))
|
81 |
+
new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
|
82 |
+
new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
|
83 |
+
|
84 |
+
def concatenate_arrays(arrays_list):
|
85 |
+
return np.concatenate(arrays_list, axis=1)
|
86 |
+
|
87 |
+
concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)
|
88 |
+
|
89 |
+
interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
|
90 |
+
interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
|
91 |
+
interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
|
92 |
+
interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
|
93 |
+
|
94 |
+
return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
|
95 |
+
text2sound_quantized_latent_representation_image:
|
96 |
+
interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
|
97 |
+
text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
|
98 |
+
text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
|
99 |
+
text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
|
100 |
+
text2sound_seed_textbox: text2sound_seed,
|
101 |
+
interpolation_with_text_state: interpolation_with_text_dict,
|
102 |
+
text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
|
103 |
+
visible=True,
|
104 |
+
label="Sample index.",
|
105 |
+
info="Swipe to view other samples")}
|
106 |
+
|
107 |
+
def show_random_sample(sample_index, text2sound_dict):
|
108 |
+
sample_index = int(sample_index)
|
109 |
+
return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
|
110 |
+
sample_index],
|
111 |
+
text2sound_quantized_latent_representation_image:
|
112 |
+
text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
|
113 |
+
text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
|
114 |
+
text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}
|
115 |
+
|
116 |
+
with gr.Tab("InterpolationCond."):
|
117 |
+
gr.Markdown("Use interpolation to generate a gradient sound sequence.")
|
118 |
+
with gr.Row(variant="panel"):
|
119 |
+
with gr.Column(scale=3):
|
120 |
+
text2sound_prompts_1_textbox = gr.Textbox(label="Positive prompt 1", lines=2, value="organ")
|
121 |
+
text2sound_prompts_2_textbox = gr.Textbox(label="Positive prompt 2", lines=2, value="string")
|
122 |
+
text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
|
123 |
+
|
124 |
+
with gr.Column(scale=1):
|
125 |
+
text2sound_sampling_button = gr.Button(variant="primary",
|
126 |
+
value="Generate a batch of samples and show "
|
127 |
+
"the first one",
|
128 |
+
scale=1)
|
129 |
+
text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
|
130 |
+
label="Sample index",
|
131 |
+
info="Swipe to view other samples")
|
132 |
+
with gr.Row(variant="panel"):
|
133 |
+
with gr.Column(scale=1, variant="panel"):
|
134 |
+
text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
|
135 |
+
text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
|
136 |
+
text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
|
137 |
+
text2sound_duration_slider = gradioWebUI.get_duration_slider()
|
138 |
+
text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
|
139 |
+
text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
|
140 |
+
|
141 |
+
with gr.Column(scale=1):
|
142 |
+
with gr.Row(variant="panel"):
|
143 |
+
text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
|
144 |
+
height=420, scale=8)
|
145 |
+
text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
|
146 |
+
height=420, scale=1)
|
147 |
+
text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
|
148 |
+
|
149 |
+
with gr.Row(variant="panel"):
|
150 |
+
text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
|
151 |
+
height=200, width=100)
|
152 |
+
text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
|
153 |
+
type="numpy", height=200, width=100)
|
154 |
+
|
155 |
+
text2sound_sampling_button.click(diffusion_random_sample,
|
156 |
+
inputs=[text2sound_prompts_1_textbox,
|
157 |
+
text2sound_prompts_2_textbox,
|
158 |
+
text2sound_negative_prompts_textbox,
|
159 |
+
text2sound_batchsize_slider,
|
160 |
+
text2sound_duration_slider,
|
161 |
+
text2sound_guidance_scale_slider, text2sound_sampler_radio,
|
162 |
+
text2sound_sample_steps_slider,
|
163 |
+
text2sound_seed_textbox,
|
164 |
+
interpolation_with_text_state],
|
165 |
+
outputs=[text2sound_latent_representation_image,
|
166 |
+
text2sound_quantized_latent_representation_image,
|
167 |
+
text2sound_sampled_concatenated_spectrogram_image,
|
168 |
+
text2sound_sampled_spectrogram_image,
|
169 |
+
text2sound_sampled_audio,
|
170 |
+
text2sound_seed_textbox,
|
171 |
+
interpolation_with_text_state,
|
172 |
+
text2sound_sample_index_slider])
|
173 |
+
text2sound_sample_index_slider.change(show_random_sample,
|
174 |
+
inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
|
175 |
+
outputs=[text2sound_latent_representation_image,
|
176 |
+
text2sound_quantized_latent_representation_image,
|
177 |
+
text2sound_sampled_spectrogram_image,
|
178 |
+
text2sound_sampled_audio])
|
webUI/deprecated/interpolationWithXT.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from model.DiffSynthSampler import DiffSynthSampler
|
6 |
+
from tools import safe_int
|
7 |
+
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image
|
8 |
+
|
9 |
+
|
10 |
+
def get_interpolation_with_xT_module(gradioWebUI, interpolation_with_text_state):
|
11 |
+
# Load configurations
|
12 |
+
uNet = gradioWebUI.uNet
|
13 |
+
freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
|
14 |
+
VAE_scale = gradioWebUI.VAE_scale
|
15 |
+
height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
|
16 |
+
timesteps = gradioWebUI.timesteps
|
17 |
+
VAE_quantizer = gradioWebUI.VAE_quantizer
|
18 |
+
VAE_decoder = gradioWebUI.VAE_decoder
|
19 |
+
CLAP = gradioWebUI.CLAP
|
20 |
+
CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
|
21 |
+
device = gradioWebUI.device
|
22 |
+
squared = gradioWebUI.squared
|
23 |
+
sample_rate = gradioWebUI.sample_rate
|
24 |
+
noise_strategy = gradioWebUI.noise_strategy
|
25 |
+
|
26 |
+
def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
|
27 |
+
text2sound_duration,
|
28 |
+
text2sound_noise_variance, text2sound_guidance_scale, text2sound_sampler,
|
29 |
+
text2sound_sample_steps, text2sound_seed,
|
30 |
+
interpolation_with_text_dict):
|
31 |
+
text2sound_sample_steps = int(text2sound_sample_steps)
|
32 |
+
text2sound_seed = safe_int(text2sound_seed, 12345678)
|
33 |
+
# Todo: take care of text2sound_time_resolution/width
|
34 |
+
width = int(time_resolution*((text2sound_duration+1)/4) / VAE_scale)
|
35 |
+
text2sound_batchsize = int(text2sound_batchsize)
|
36 |
+
|
37 |
+
text2sound_embedding = \
|
38 |
+
CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)
|
39 |
+
|
40 |
+
CFG = int(text2sound_guidance_scale)
|
41 |
+
|
42 |
+
mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
|
43 |
+
unconditional_condition = \
|
44 |
+
CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
|
45 |
+
mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))
|
46 |
+
|
47 |
+
mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))
|
48 |
+
|
49 |
+
condition = text2sound_embedding.repeat(text2sound_batchsize, 1)
|
50 |
+
latent_representations, initial_noise = \
|
51 |
+
mySampler.interpolate(model=uNet, shape=(text2sound_batchsize, channels, height, width),
|
52 |
+
seed=text2sound_seed,
|
53 |
+
variance=text2sound_noise_variance,
|
54 |
+
return_tensor=True, condition=condition, sampler=text2sound_sampler)
|
55 |
+
|
56 |
+
latent_representations = latent_representations[-1]
|
57 |
+
|
58 |
+
interpolation_with_text_dict["latent_representations"] = latent_representations
|
59 |
+
|
60 |
+
latent_representation_gradio_images = []
|
61 |
+
quantized_latent_representation_gradio_images = []
|
62 |
+
new_sound_spectrogram_gradio_images = []
|
63 |
+
new_sound_rec_signals_gradio = []
|
64 |
+
|
65 |
+
quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
|
66 |
+
# Todo: remove hard-coding
|
67 |
+
flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
|
68 |
+
resolution=(512, width * VAE_scale), centralized=False,
|
69 |
+
squared=squared)
|
70 |
+
|
71 |
+
for i in range(text2sound_batchsize):
|
72 |
+
latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
|
73 |
+
quantized_latent_representation_gradio_images.append(
|
74 |
+
latent_representation_to_Gradio_image(quantized_latent_representations[i]))
|
75 |
+
new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
|
76 |
+
new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
|
77 |
+
|
78 |
+
def concatenate_arrays(arrays_list):
|
79 |
+
return np.concatenate(arrays_list, axis=1)
|
80 |
+
|
81 |
+
concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)
|
82 |
+
|
83 |
+
interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
|
84 |
+
interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
|
85 |
+
interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
|
86 |
+
interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
|
87 |
+
|
88 |
+
return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
|
89 |
+
text2sound_quantized_latent_representation_image:
|
90 |
+
interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
|
91 |
+
text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
|
92 |
+
text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
|
93 |
+
text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
|
94 |
+
text2sound_seed_textbox: text2sound_seed,
|
95 |
+
interpolation_with_text_state: interpolation_with_text_dict,
|
96 |
+
text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
|
97 |
+
visible=True,
|
98 |
+
label="Sample index.",
|
99 |
+
info="Swipe to view other samples")}
|
100 |
+
|
101 |
+
def show_random_sample(sample_index, text2sound_dict):
|
102 |
+
sample_index = int(sample_index)
|
103 |
+
return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
|
104 |
+
sample_index],
|
105 |
+
text2sound_quantized_latent_representation_image:
|
106 |
+
text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
|
107 |
+
text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
|
108 |
+
text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}
|
109 |
+
|
110 |
+
with gr.Tab("InterpolationXT"):
|
111 |
+
gr.Markdown("Use interpolation to generate a gradient sound sequence.")
|
112 |
+
with gr.Row(variant="panel"):
|
113 |
+
with gr.Column(scale=3):
|
114 |
+
text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
|
115 |
+
text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
|
116 |
+
|
117 |
+
with gr.Column(scale=1):
|
118 |
+
text2sound_sampling_button = gr.Button(variant="primary",
|
119 |
+
value="Generate a batch of samples and show "
|
120 |
+
"the first one",
|
121 |
+
scale=1)
|
122 |
+
text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
|
123 |
+
label="Sample index",
|
124 |
+
info="Swipe to view other samples")
|
125 |
+
with gr.Row(variant="panel"):
|
126 |
+
with gr.Column(scale=1, variant="panel"):
|
127 |
+
text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
|
128 |
+
text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
|
129 |
+
text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
|
130 |
+
text2sound_duration_slider = gradioWebUI.get_duration_slider()
|
131 |
+
text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
|
132 |
+
text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
|
133 |
+
text2sound_noise_variance_slider = gr.Slider(minimum=0., maximum=5., value=1., step=0.01,
|
134 |
+
label="Noise variance",
|
135 |
+
info="The larger this value, the more diversity the interpolation has.")
|
136 |
+
|
137 |
+
with gr.Column(scale=1):
|
138 |
+
with gr.Row(variant="panel"):
|
139 |
+
text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
|
140 |
+
height=420, scale=8)
|
141 |
+
text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
|
142 |
+
height=420, scale=1)
|
143 |
+
text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
|
144 |
+
|
145 |
+
with gr.Row(variant="panel"):
|
146 |
+
text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
|
147 |
+
height=200, width=100)
|
148 |
+
text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
|
149 |
+
type="numpy", height=200, width=100)
|
150 |
+
|
151 |
+
text2sound_sampling_button.click(diffusion_random_sample,
|
152 |
+
inputs=[text2sound_prompts_textbox, text2sound_negative_prompts_textbox,
|
153 |
+
text2sound_batchsize_slider,
|
154 |
+
text2sound_duration_slider,
|
155 |
+
text2sound_noise_variance_slider,
|
156 |
+
text2sound_guidance_scale_slider, text2sound_sampler_radio,
|
157 |
+
text2sound_sample_steps_slider,
|
158 |
+
text2sound_seed_textbox,
|
159 |
+
interpolation_with_text_state],
|
160 |
+
outputs=[text2sound_latent_representation_image,
|
161 |
+
text2sound_quantized_latent_representation_image,
|
162 |
+
text2sound_sampled_concatenated_spectrogram_image,
|
163 |
+
text2sound_sampled_spectrogram_image,
|
164 |
+
text2sound_sampled_audio,
|
165 |
+
text2sound_seed_textbox,
|
166 |
+
interpolation_with_text_state,
|
167 |
+
text2sound_sample_index_slider])
|
168 |
+
text2sound_sample_index_slider.change(show_random_sample,
|
169 |
+
inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
|
170 |
+
outputs=[text2sound_latent_representation_image,
|
171 |
+
text2sound_quantized_latent_representation_image,
|
172 |
+
text2sound_sampled_spectrogram_image,
|
173 |
+
text2sound_sampled_audio])
|
webUI/natural_language_guided/GAN.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from tools import safe_int
|
6 |
+
from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image, \
|
7 |
+
add_instrument
|
8 |
+
|
9 |
+
|
10 |
+
def get_testGAN(gradioWebUI, text2sound_state, virtual_instruments_state):
|
11 |
+
# Load configurations
|
12 |
+
gan_generator = gradioWebUI.GAN_generator
|
13 |
+
freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
|
14 |
+
VAE_scale = gradioWebUI.VAE_scale
|
15 |
+
height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
|
16 |
+
|
17 |
+
timesteps = gradioWebUI.timesteps
|
18 |
+
VAE_quantizer = gradioWebUI.VAE_quantizer
|
19 |
+
VAE_decoder = gradioWebUI.VAE_decoder
|
20 |
+
CLAP = gradioWebUI.CLAP
|
21 |
+
CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
|
22 |
+
device = gradioWebUI.device
|
23 |
+
squared = gradioWebUI.squared
|
24 |
+
sample_rate = gradioWebUI.sample_rate
|
25 |
+
noise_strategy = gradioWebUI.noise_strategy
|
26 |
+
|
27 |
+
def gan_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
|
28 |
+
text2sound_duration,
|
29 |
+
text2sound_guidance_scale, text2sound_sampler,
|
30 |
+
text2sound_sample_steps, text2sound_seed,
|
31 |
+
text2sound_dict):
|
32 |
+
text2sound_seed = safe_int(text2sound_seed, 12345678)
|
33 |
+
|
34 |
+
width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
|
35 |
+
|
36 |
+
text2sound_batchsize = int(text2sound_batchsize)
|
37 |
+
|
38 |
+
text2sound_embedding = \
|
39 |
+
CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
|
40 |
+
device)
|
41 |
+
|
42 |
+
CFG = int(text2sound_guidance_scale)
|
43 |
+
|
44 |
+
condition = text2sound_embedding.repeat(text2sound_batchsize, 1)
|
45 |
+
|
46 |
+
noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)
|
47 |
+
latent_representations = gan_generator(noise, condition)
|
48 |
+
|
49 |
+
print(latent_representations[0, 0, :3, :3])
|
50 |
+
|
51 |
+
latent_representation_gradio_images = []
|
52 |
+
quantized_latent_representation_gradio_images = []
|
53 |
+
new_sound_spectrogram_gradio_images = []
|
54 |
+
new_sound_rec_signals_gradio = []
|
55 |
+
|
56 |
+
quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
|
57 |
+
# Todo: remove hard-coding
|
58 |
+
flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
|
59 |
+
resolution=(512, width * VAE_scale),
|
60 |
+
centralized=False,
|
61 |
+
squared=squared)
|
62 |
+
|
63 |
+
for i in range(text2sound_batchsize):
|
64 |
+
latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
|
65 |
+
quantized_latent_representation_gradio_images.append(
|
66 |
+
latent_representation_to_Gradio_image(quantized_latent_representations[i]))
|
67 |
+
new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
|
68 |
+
new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
|
69 |
+
|
70 |
+
text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
|
71 |
+
text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy()
|
72 |
+
text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
|
73 |
+
text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
|
74 |
+
text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
|
75 |
+
text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
|
76 |
+
|
77 |
+
text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
|
78 |
+
# text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy()
|
79 |
+
text2sound_dict["guidance_scale"] = CFG
|
80 |
+
text2sound_dict["sampler"] = text2sound_sampler
|
81 |
+
|
82 |
+
return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
|
83 |
+
text2sound_quantized_latent_representation_image:
|
84 |
+
text2sound_dict["quantized_latent_representation_gradio_images"][0],
|
85 |
+
text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0],
|
86 |
+
text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
|
87 |
+
text2sound_seed_textbox: text2sound_seed,
|
88 |
+
text2sound_state: text2sound_dict,
|
89 |
+
text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
|
90 |
+
visible=True,
|
91 |
+
label="Sample index.",
|
92 |
+
info="Swipe to view other samples")}
|
93 |
+
|
94 |
+
def show_random_sample(sample_index, text2sound_dict):
|
95 |
+
sample_index = int(sample_index)
|
96 |
+
text2sound_dict["sample_index"] = sample_index
|
97 |
+
return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
|
98 |
+
sample_index],
|
99 |
+
text2sound_quantized_latent_representation_image:
|
100 |
+
text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
|
101 |
+
text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][
|
102 |
+
sample_index],
|
103 |
+
text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}
|
104 |
+
|
105 |
+
|
106 |
+
with gr.Tab("Text2sound_GAN"):
|
107 |
+
gr.Markdown("Use neural networks to select random sounds using your favorite instrument!")
|
108 |
+
with gr.Row(variant="panel"):
|
109 |
+
with gr.Column(scale=3):
|
110 |
+
text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
|
111 |
+
text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
|
112 |
+
|
113 |
+
with gr.Column(scale=1):
|
114 |
+
text2sound_sampling_button = gr.Button(variant="primary",
|
115 |
+
value="Generate a batch of samples and show "
|
116 |
+
"the first one",
|
117 |
+
scale=1)
|
118 |
+
text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
|
119 |
+
label="Sample index",
|
120 |
+
info="Swipe to view other samples")
|
121 |
+
with gr.Row(variant="panel"):
|
122 |
+
with gr.Column(scale=1, variant="panel"):
|
123 |
+
text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
|
124 |
+
text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
|
125 |
+
text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
|
126 |
+
text2sound_duration_slider = gradioWebUI.get_duration_slider()
|
127 |
+
text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
|
128 |
+
text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
|
129 |
+
|
130 |
+
with gr.Column(scale=1):
|
131 |
+
text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=420)
|
132 |
+
text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
|
133 |
+
|
134 |
+
|
135 |
+
with gr.Row(variant="panel"):
|
136 |
+
text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
|
137 |
+
height=200, width=100)
|
138 |
+
text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
|
139 |
+
type="numpy", height=200, width=100)
|
140 |
+
|
141 |
+
text2sound_sampling_button.click(gan_random_sample,
|
142 |
+
inputs=[text2sound_prompts_textbox,
|
143 |
+
text2sound_negative_prompts_textbox,
|
144 |
+
text2sound_batchsize_slider,
|
145 |
+
text2sound_duration_slider,
|
146 |
+
text2sound_guidance_scale_slider, text2sound_sampler_radio,
|
147 |
+
text2sound_sample_steps_slider,
|
148 |
+
text2sound_seed_textbox,
|
149 |
+
text2sound_state],
|
150 |
+
outputs=[text2sound_latent_representation_image,
|
151 |
+
text2sound_quantized_latent_representation_image,
|
152 |
+
text2sound_sampled_spectrogram_image,
|
153 |
+
text2sound_sampled_audio,
|
154 |
+
text2sound_seed_textbox,
|
155 |
+
text2sound_state,
|
156 |
+
text2sound_sample_index_slider])
|
157 |
+
|
158 |
+
|
159 |
+
text2sound_sample_index_slider.change(show_random_sample,
|
160 |
+
inputs=[text2sound_sample_index_slider, text2sound_state],
|
161 |
+
outputs=[text2sound_latent_representation_image,
|
162 |
+
text2sound_quantized_latent_representation_image,
|
163 |
+
text2sound_sampled_spectrogram_image,
|
164 |
+
text2sound_sampled_audio])
|
webUI/natural_language_guided/README.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
readme_content = """## Stable Diffusion for Sound Generation
|
4 |
+
|
5 |
+
This project applies stable diffusion[1] to sound generation. Inspired by the work of AUTOMATIC1111, 2022[2], we have implemented a preliminary version of text2sound, sound2sound, inpaint, as well as an additional interpolation feature, all accessible through a web UI.
|
6 |
+
|
7 |
+
### Neural Network Training Data:
|
8 |
+
The neural network is trained using the filtered NSynth dataset[3], which is a large-scale and high-quality collection of annotated musical notes, comprising 305,979 musical notes. However, for this project, only samples with a pitch set to E3 were used, resulting in an actual training sample size of 4,096, making it a low-resource project.
|
9 |
+
|
10 |
+
The training took place on an NVIDIA Tesla T4 GPU and spanned approximately 10 hours.
|
11 |
+
|
12 |
+
### Natural Language Guidance:
|
13 |
+
Natural language guidance is derived from the multi-label annotations of the NSynth dataset. The labels included in the training are:
|
14 |
+
|
15 |
+
- **Instrument Families**: bass, brass, flute, guitar, keyboard, mallet, organ, reed, string, synth lead, vocal.
|
16 |
+
|
17 |
+
- **Instrument Sources**: acoustic, electronic, synthetic.
|
18 |
+
|
19 |
+
- **Note Qualities**: bright, dark, distortion, fast decay, long release, multiphonic, nonlinear env, percussive, reverb, tempo-synced.
|
20 |
+
|
21 |
+
### Usage Hints:
|
22 |
+
|
23 |
+
1. **Prompt Format**: It's recommended to use the format “label1, label2, label3“, e.g., ”organ, dark, long release“.
|
24 |
+
|
25 |
+
2. **Unique Sounds**: If you keep generating the same sound, try setting a different seed!
|
26 |
+
|
27 |
+
3. **Sample Indexing**: Drag the "Sample index slider" to view other samples within the generated batch.
|
28 |
+
|
29 |
+
4. **Running on CPU**: Be cautious with the settings for 'batchsize' and 'sample_steps' when running on CPU to avoid timeouts. Recommended settings are batchsize ≤ 4 and sample_steps = 15.
|
30 |
+
|
31 |
+
5. **Editing Sounds**: Generated audio can be downloaded and then re-uploaded for further editing at the sound2sound/inpaint sections.
|
32 |
+
|
33 |
+
6. **Guidance Scale**: A higher 'guidance_scale' intensifies the influence of natural language conditioning on the generation[4]. It's recommended to set it between 3 and 10.
|
34 |
+
|
35 |
+
7. **Noising Strength**: A smaller 'noising_strength' value makes the generated sound closer to the input sound.
|
36 |
+
|
37 |
+
References:
|
38 |
+
|
39 |
+
[1] Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (pp. 10684-10695).
|
40 |
+
|
41 |
+
[2] AUTOMATIC1111. (2022). Stable Diffusion Web UI [Computer software]. Retrieved from https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
42 |
+
|
43 |
+
[3] Engel, J., Resnick, C., Roberts, A., Dieleman, S., Eck, D., Simonyan, K., & Norouzi, M. (2017). Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders.
|
44 |
+
|
45 |
+
[4] Ho, J., & Salimans, T. (2022). Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def get_readme_module():
|
49 |
+
|
50 |
+
with gr.Tab("README"):
|
51 |
+
# gr.Markdown("Use interpolation to generate a gradient sound sequence.")
|
52 |
+
with gr.Column(scale=3):
|
53 |
+
readme_textbox = gr.Textbox(label="readme", lines=40, value=readme_content, interactive=False)
|
webUI/natural_language_guided/__pycache__/README.cpython-310.pyc
ADDED
Binary file (3.51 kB). View file
|
|
webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc
ADDED
Binary file (3.51 kB). View file
|
|
webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc
ADDED
Binary file (8.26 kB). View file
|
|
webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc
ADDED
Binary file (8.08 kB). View file
|
|
webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc
ADDED
Binary file (3.61 kB). View file
|
|
webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc
ADDED
Binary file (3.62 kB). View file
|
|
webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc
ADDED
Binary file (3.61 kB). View file
|
|
webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc
ADDED
Binary file (12.5 kB). View file
|
|
webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc
ADDED
Binary file (11.6 kB). View file
|
|
webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc
ADDED
Binary file (12.5 kB). View file
|
|
webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc
ADDED
Binary file (6.11 kB). View file
|
|
webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc
ADDED
Binary file (6.11 kB). View file
|
|