WeixuanYuan committed
Commit ae1bdf7
1 Parent(s): 39653fc

Upload 66 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +27 -0
  2. app.py +107 -0
  3. app_chat.py +7 -0
  4. metrics/FD.py +293 -0
  5. metrics/IS.py +218 -0
  6. metrics/P_C_T.py +12 -0
  7. metrics/get_reference_AST_features.py +63 -0
  8. metrics/pipelines.py +144 -0
  9. metrics/pipelines_STFT.py +100 -0
  10. metrics/precision_recall.py +204 -0
  11. metrics/visualizations.py +123 -0
  12. model/DiffSynthSampler.py +425 -0
  13. model/GAN.py +262 -0
  14. model/VQGAN.py +684 -0
  15. model/__pycache__/DiffSynthSampler.cpython-310.pyc +0 -0
  16. model/__pycache__/GAN.cpython-310.pyc +0 -0
  17. model/__pycache__/VQGAN.cpython-310.pyc +0 -0
  18. model/__pycache__/diffusion.cpython-310.pyc +0 -0
  19. model/__pycache__/diffusion_components.cpython-310.pyc +0 -0
  20. model/__pycache__/multimodal_model.cpython-310.pyc +0 -0
  21. model/__pycache__/perceptual_label_predictor.cpython-37.pyc +0 -0
  22. model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc +0 -0
  23. model/diffusion.py +371 -0
  24. model/diffusion_components.py +351 -0
  25. model/multimodal_model.py +274 -0
  26. model/timbre_encoder_pretrain.py +220 -0
  27. models/24_1_2024-52_4x_L_D_imageVQVAE.pth +3 -0
  28. models/24_1_2024-52_4x_L_D_imageVQVAE_discriminator.pth +3 -0
  29. models/24_1_2024_MMM.pth +3 -0
  30. models/24_1_2024_STFT_timbre_encoder.pth +3 -0
  31. models/history/28_1_2024_TE_STFT_300000_UNet.pth +3 -0
  32. requirements.txt +15 -0
  33. tools.py +344 -0
  34. webUI/__pycache__/app.cpython-310.pyc +0 -0
  35. webUI/deprecated/interpolationWithCondition.py +178 -0
  36. webUI/deprecated/interpolationWithXT.py +173 -0
  37. webUI/natural_language_guided/GAN.py +164 -0
  38. webUI/natural_language_guided/README.py +53 -0
  39. webUI/natural_language_guided/__pycache__/README.cpython-310.pyc +0 -0
  40. webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc +0 -0
  41. webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc +0 -0
  42. webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc +0 -0
  43. webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc +0 -0
  44. webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc +0 -0
  45. webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc +0 -0
  46. webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc +0 -0
  47. webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc +0 -0
  48. webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc +0 -0
  49. webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc +0 -0
  50. webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc +0 -0
Dockerfile ADDED
@@ -0,0 +1,27 @@
+ # Use Python 3.9 as the base image
+ FROM python:3.9
+
+ # Add a non-root user
+ RUN useradd -m -u 1000 user
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Switch to root to install system dependencies
+ USER root
+ RUN apt-get update && apt-get install -y rubberband-cli
+
+ # Switch back to the regular user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # Copy and install the Python dependencies
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # Copy the application files
+ COPY --chown=user . /app
+
+ # Launch the Gradio app; the intended entry point is app.py
+ # CMD ["python", "app.py"]
+ CMD ["python", "app_chat.py"]
app.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from model.DiffSynthSampler import DiffSynthSampler
+ import soundfile as sf
+ # import pyrubberband as pyrb
+ from tqdm import tqdm
+ from model.VQGAN import get_VQGAN
+ from model.diffusion import get_diffusion_model
+ from transformers import AutoTokenizer, ClapModel
+ from model.diffusion_components import linear_beta_schedule
+ from model.timbre_encoder_pretrain import get_timbre_encoder
+ from model.multimodal_model import get_multi_modal_model
+
+
+ import gradio as gr
+ from webUI.natural_language_guided.gradio_webUI import GradioWebUI
+ from webUI.natural_language_guided.text2sound import get_text2sound_module
+ from webUI.natural_language_guided.sound2sound_with_text import get_sound2sound_with_text_module
+ from webUI.natural_language_guided.inpaint_with_text import get_inpaint_with_text_module
+ from webUI.natural_language_guided.build_instrument import get_build_instrument_module
+ from webUI.natural_language_guided.README import get_readme_module
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ use_pretrained_CLAP = False
+
+ # load VQ-GAN
+ VAE_model_name = "24_1_2024-52_4x_L_D"
+ modelConfig = {"in_channels": 3, "hidden_channels": [80, 160], "embedding_dim": 4, "out_channels": 3, "block_depth": 2,
+                "attn_pos": [80, 160], "attn_with_skip": True,
+                "num_embeddings": 8192, "commitment_cost": 0.25, "decay": 0.99,
+                "norm_type": "groupnorm", "act_type": "swish", "num_groups": 16}
+ VAE = get_VQGAN(modelConfig, load_pretrain=True, model_name=VAE_model_name, device=device)
+
+ # load U-Net
+ UNet_model_name = "history/28_1_2024_CLAP_STFT_180000" if use_pretrained_CLAP else "history/28_1_2024_TE_STFT_300000"
+ unetConfig = {"in_dim": 4, "down_dims": [96, 96, 192, 384], "up_dims": [384, 384, 192, 96], "attn_type": "linear_add", "condition_type": "natural_language_prompt", "label_emb_dim": 512}
+ uNet = get_diffusion_model(unetConfig, load_pretrain=True, model_name=UNet_model_name, device=device)
+
+ # load LM
+ CLAP_temp = ClapModel.from_pretrained("laion/clap-htsat-unfused")  # 153,492,890
+ CLAP_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
+
+ timbre_encoder_name = "24_1_2024_STFT"
+ timbre_encoder_Config = {"input_dim": 512, "feature_dim": 512, "hidden_dim": 1024, "num_instrument_classes": 1006, "num_instrument_family_classes": 11, "num_velocity_classes": 128, "num_qualities": 10, "num_layers": 3}
+ timbre_encoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name, device=device)
+
+ if use_pretrained_CLAP:
+     text_encoder = CLAP_temp
+ else:
+     multimodalmodel_name = "24_1_2024"
+     multimodalmodel_config = {"text_feature_dim": 512, "spectrogram_feature_dim": 1024, "multi_modal_emb_dim": 512, "num_projection_layers": 2,
+                               "temperature": 1.0, "dropout": 0.1, "freeze_text_encoder": False, "freeze_spectrogram_encoder": False}
+     mmm = get_multi_modal_model(timbre_encoder, CLAP_temp, multimodalmodel_config, load_pretrain=True, model_name=multimodalmodel_name, device=device)
+
+     text_encoder = mmm.to("cpu")
+
+
+ gradioWebUI = GradioWebUI(device, VAE, uNet, text_encoder, CLAP_tokenizer, freq_resolution=512, time_resolution=256, channels=4, timesteps=1000, squared=False,
+                           VAE_scale=4, flexible_duration=True, noise_strategy="repeat", GAN_generator=None)
+
+ with gr.Blocks(theme=gr.themes.Soft(), mode="dark") as demo:
+     # with gr.Blocks(theme='WeixuanYuan/Soft_dark', mode="dark") as demo:
+     gr.Markdown("DiffuSynth v0.2")
+
+     reconstruction_state = gr.State(value={})
+     text2sound_state = gr.State(value={})
+     sound2sound_state = gr.State(value={})
+     inpaint_state = gr.State(value={})
+     super_resolution_state = gr.State(value={})
+     virtual_instruments_state = gr.State(value={"virtual_instruments": {}})
+
+     get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state)
+     get_sound2sound_with_text_module(gradioWebUI, sound2sound_state, virtual_instruments_state)
+     get_inpaint_with_text_module(gradioWebUI, inpaint_state, virtual_instruments_state)
+     get_build_instrument_module(gradioWebUI, virtual_instruments_state)
+     get_readme_module()
+
+ demo.launch(debug=True, share=True)
+
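Both the pretrained CLAP path and the project's fine-tuned multimodal model expose a get_text_features call, and the rest of the codebase (for example metrics/pipelines.py below) conditions generation on its output. A minimal, self-contained sketch using only the public transformers API; the prompt string is illustrative, not taken from this commit:

import torch
from transformers import AutoTokenizer, ClapModel

tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
clap = ClapModel.from_pretrained("laion/clap-htsat-unfused")

with torch.no_grad():
    inputs = tokenizer(["an organ with a bright attack"], padding=True, return_tensors="pt")
    text_embedding = clap.get_text_features(**inputs)  # tensor of shape (1, 512)
print(text_embedding.shape)

When use_pretrained_CLAP is False, app.py routes prompts through the fine-tuned multimodal model (mmm) instead, but the calling convention stays the same.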
app_chat.py ADDED
@@ -0,0 +1,7 @@
+ import gradio as gr
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ demo.launch()
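app_chat.py is the placeholder entry point referenced by the Dockerfile's CMD. One practical note: when a Gradio app is served from inside a container it typically has to bind to all interfaces rather than localhost, for example

demo.launch(server_name="0.0.0.0", server_port=7860)

Pointing the CMD back at app.py (the commented-out line in the Dockerfile) launches the full DiffuSynth interface instead of this hello-world demo.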
metrics/FD.py ADDED
@@ -0,0 +1,293 @@
+ import json
+ import os
+
+ import librosa
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from scipy.linalg import sqrtm
+
+ from metrics.pipelines import sample_pipeline, sample_pipeline_GAN
+ from metrics.pipelines_STFT import sample_pipeline_STFT, sample_pipeline_GAN_STFT
+ from tools import rms_normalize
+
+
+ def ASTaudio2feature(device, signal, processor, AST, sampling_rate):
+     # audio file is decoded on the fly
+     inputs = processor(signal, sampling_rate=sampling_rate, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = AST(**inputs)
+
+     last_hidden_states = outputs.last_hidden_state[:, 0, :].to("cpu").detach().numpy()
+     return last_hidden_states
+
+
+ # Compute the mean and covariance matrix of a set of feature vectors
+ def calculate_statistics(features):
+     mu = np.mean(features, axis=0)
+     sigma = np.cov(features, rowvar=False)
+     return mu, sigma
+
+
+ # Compute the FID from means and covariance matrices
+ def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
+     # Add a small positive value to the diagonals of the covariance matrices
+     sigma1 += np.eye(sigma1.shape[0]) * eps
+     sigma2 += np.eye(sigma2.shape[0]) * eps
+
+     ssdiff = np.sum((mu1 - mu2) ** 2.0)
+     covmean = sqrtm(sigma1.dot(sigma2))
+
+     # Numerical issues can produce complex values; keep only the real part
+     if np.iscomplexobj(covmean):
+         covmean = covmean.real
+
+     fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
+     return fid
+
+
+ # Compute the FID from {"mu": ..., "sigma": ...} dictionaries
+ def calculate_fid_dict(dict1, dict2, eps=1e-6):
+     # Add a small positive value to the diagonals of the covariance matrices
+     mu1, sigma1 = dict1["mu"], dict1["sigma"]
+     mu2, sigma2 = dict2["mu"], dict2["sigma"]
+     sigma1 += np.eye(sigma1.shape[0]) * eps
+     sigma2 += np.eye(sigma2.shape[0]) * eps
+
+     ssdiff = np.sum((mu1 - mu2) ** 2.0)
+     covmean = sqrtm(sigma1.dot(sigma2))
+
+     # Numerical issues can produce complex values; keep only the real part
+     if np.iscomplexobj(covmean):
+         covmean = covmean.real
+
+     fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
+     return fid
+
+
+ # Todo: AudioLDM
+ # def generate_features_with_AudioLDM_and_AST(device, processor, AST, AudioLDM_signals_directory_path, return_feature=False):
+
+ #     diffuSynth_features = []
+
+ #     # Step 1: Load all wav files in AudioLDM_signals_directory_path
+ #     AudioLDM_signals = []
+ #     signal_lengths = set()
+
+ #     for file_name in os.listdir(AudioLDM_signals_directory_path):
+ #         if file_name.endswith('.wav'):
+ #             file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
+ #             signal, sr = librosa.load(file_path, sr=16000)  # Load audio file with sampling rate 16000
+ #             # Normalize
+ #             AudioLDM_signals.append(rms_normalize(signal))
+ #             signal_lengths.add(len(signal))
+
+ #     # Step 2: Check if all signals have the same length
+ #     if len(signal_lengths) != 1:
+ #         raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
+
+ #     # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
+ #     batch_size = 8
+ #     signal_length = signal_lengths.pop()  # All lengths are the same, get one of them
+
+ #     # Create batches
+ #     signal_batches = [AudioLDM_signals[i:i + batch_size] for i in range(0, len(AudioLDM_signals), batch_size)]
+
+ #     for signal_batch in tqdm(signal_batches):
+
+ #         features = ASTaudio2feature(device, signal_batch, processor, AST, sampling_rate=16000)
+ #         diffuSynth_features.extend(features)
+
+ #     if return_feature:
+ #         return diffuSynth_features
+ #     else:
+ #         mu, sigma = calculate_statistics(diffuSynth_features)
+ #         return {"mu": mu, "sigma": sigma}
+
+ def generate_features_with_AudioLDM_and_AST(device, processor, AST, AudioLDM_signals_directory_path, return_feature=False):
+
+     diffuSynth_features = []
+
+     # Step 1: Load all wav files in AudioLDM_signals_directory_path
+     AudioLDM_signals = []
+     signal_lengths = set()
+     target_length = 4 * 16000  # 4 seconds * 16000 samples per second
+
+     for file_name in os.listdir(AudioLDM_signals_directory_path):
+         if file_name.endswith('.wav') and not file_name.startswith('._'):
+             file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
+             try:
+                 signal, sr = librosa.load(file_path, sr=16000)  # Load audio file with sampling rate 16000
+                 if len(signal) >= target_length:
+                     signal = signal[:target_length]  # Take only the first 4 seconds
+                 else:
+                     raise ValueError(f"The file {file_name} is shorter than 4 seconds.")
+                 # Normalize
+                 AudioLDM_signals.append(rms_normalize(signal))
+                 signal_lengths.add(len(signal))
+             except Exception as e:
+                 print(f"Error loading {file_name}: {e}")
+
+     # Step 2: Check if all signals have the same length
+     if len(signal_lengths) != 1:
+         raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
+
+     # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
+     batch_size = 8
+     signal_length = signal_lengths.pop()  # All lengths are the same, get one of them
+
+     # Create batches
+     signal_batches = [AudioLDM_signals[i:i + batch_size] for i in range(0, len(AudioLDM_signals), batch_size)]
+
+     for signal_batch in tqdm(signal_batches):
+         features = ASTaudio2feature(device, signal_batch, processor, AST, sampling_rate=16000)
+         diffuSynth_features.extend(features)
+
+     if return_feature:
+         return diffuSynth_features
+     else:
+         mu, sigma = calculate_statistics(diffuSynth_features)
+         return {"mu": mu, "sigma": sigma}
+
+
+ def generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches,
+                                               positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms", return_feature=False):
+     diffuSynth_features = []
+
+     if task == "spectrograms":
+         pipe = sample_pipeline
+     elif task == "STFT":
+         pipe = sample_pipeline_STFT
+     else:
+         raise NotImplementedError
+
+     for _ in tqdm(range(num_batches)):
+         quantized_latent_representations, reconstruction_batch, signals = pipe(device, uNet, VAE, mmm,
+                                                                                CLAP_tokenizer,
+                                                                                positive_prompts=positive_prompts,
+                                                                                negative_prompts=negative_prompts,
+                                                                                batchsize=8,
+                                                                                sample_steps=sample_steps,
+                                                                                CFG=CFG, seed=None,
+                                                                                return_latent=False)
+
+         features = ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000)
+         diffuSynth_features.extend(features)
+
+     if return_feature:
+         return diffuSynth_features
+     else:
+         mu, sigma = calculate_statistics(diffuSynth_features)
+         return {"mu": mu, "sigma": sigma}
+
+
+ def generate_features_with_GAN_and_AST(device, gan_generator, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches,
+                                        positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms", return_feature=False):
+     diffuSynth_features = []
+
+     if task == "spectrograms":
+         pipe = sample_pipeline_GAN
+     elif task == "STFT":
+         pipe = sample_pipeline_GAN_STFT
+     else:
+         raise NotImplementedError
+
+     for _ in tqdm(range(num_batches)):
+         quantized_latent_representations, reconstruction_batch, signals = pipe(device, gan_generator, VAE, mmm,
+                                                                                CLAP_tokenizer,
+                                                                                positive_prompts=positive_prompts,
+                                                                                negative_prompts=negative_prompts,
+                                                                                batchsize=8,
+                                                                                sample_steps=sample_steps,
+                                                                                CFG=CFG, seed=None,
+                                                                                return_latent=False)
+
+         features = ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000)
+         diffuSynth_features.extend(features)
+
+     if return_feature:
+         return diffuSynth_features
+     else:
+         mu, sigma = calculate_statistics(diffuSynth_features)
+         return {"mu": mu, "sigma": sigma}
+
+
+ def get_FD(train_features, device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches, positive_prompts,
+            negative_prompts="", CFG=1, sample_steps=10):
+     diffuSynth_features = generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor,
+                                                                     AST, num_batches, positive_prompts,
+                                                                     negative_prompts=negative_prompts, CFG=CFG,
+                                                                     sample_steps=sample_steps)
+
+     mu_real, sigma_real = calculate_statistics(train_features)
+     mu_gen, sigma_gen = calculate_statistics(diffuSynth_features)
+
+     fid_score = calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
+     print('FID score:', fid_score)
+
+
+ def get_fid_score(feature1, features2):
+     mu_real, sigma_real = calculate_statistics(feature1)
+     mu_gen, sigma_gen = calculate_statistics(features2)
+
+     fid_score = calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
+     # print('FID score:', fid_score)
+     return fid_score
+
+
+ def calculate_fid_matrix(features_list_1, features_list_2, get_fid_score):
+     # Initialize a matrix of size len(features_list_1) x len(features_list_2)
+     # to store the FID scores
+     fid_scores = [[0 for _ in range(len(features_list_2))] for _ in range(len(features_list_1))]
+
+     # Iterate over both lists and compute the FID score for every pair of feature sets
+     for i, feature1 in enumerate(features_list_1):
+         for j, feature2 in enumerate(features_list_2):
+             fid_scores[i][j] = get_fid_score(feature1, feature2)
+
+     return fid_scores
+
+
+ def save_AST_feature(key, mu, sigma, path='results/AST_metric/pre_calculated_features/AST_features.json'):
+     # Try to open and read the existing JSON file
+     try:
+         with open(path, 'r') as file:
+             data = json.load(file)
+     except FileNotFoundError:
+         # If the file does not exist, start with a new dictionary
+         data = {}
+
+     if isinstance(mu, np.ndarray):
+         mu = mu.tolist()
+     if isinstance(sigma, np.ndarray):
+         sigma = sigma.tolist()
+
+     # Add the new entry
+     data[key] = {"mu": mu, "sigma": sigma}
+
+     # Write the updated data back to the file
+     with open(path, 'w') as file:
+         json.dump(data, file, indent=4)
+
+
+ def read_AST_features(path='results/AST_metric/pre_calculated_features/AST_features.json'):
+     try:
+         # Try to open and read the JSON file
+         with open(path, 'r') as file:
+             AST_features = json.load(file)
+
+         for AST_feature_name in AST_features.keys():
+             AST_features[AST_feature_name]["mu"] = np.array(AST_features[AST_feature_name]["mu"])
+             AST_features[AST_feature_name]["sigma"] = np.array(AST_features[AST_feature_name]["sigma"])
+
+         return AST_features
+     except FileNotFoundError:
+         # If the file does not exist, return an empty dictionary
+         print(f"File {path} not found.")
+         return {}
+     except json.JSONDecodeError:
+         # If the file is not valid JSON, return an empty dictionary
+         print(f"File {path} is not valid JSON.")
+         return {}
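For reference, calculate_fid and calculate_fid_dict implement the standard Fréchet distance between two Gaussians fitted to AST feature sets:

\mathrm{FD} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}\!\left( \Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2} \right)

where (\mu_i, \Sigma_i) come from calculate_statistics, eps is added to the covariance diagonals for numerical stability, and only the real part of the matrix square root is kept.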
metrics/IS.py ADDED
@@ -0,0 +1,218 @@
+ import os
+
+ import librosa
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from metrics.pipelines import sample_pipeline, inpaint_pipeline, sample_pipeline_GAN
+ from metrics.pipelines_STFT import sample_pipeline_STFT, sample_pipeline_GAN_STFT
+ from tools import rms_normalize, pad_STFT, encode_stft
+ from webUI.natural_language_guided.utils import InputBatch2Encode_STFT
+
+ def get_inception_score_for_AudioLDM(device, timbre_encoder, VAE, AudioLDM_signals_directory_path):
+     VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
+
+     diffuSynth_probabilities = []
+
+     # Step 1: Load all wav files in AudioLDM_signals_directory_path
+     AudioLDM_signals = []
+     signal_lengths = set()
+     target_length = 4 * 16000  # 4 seconds * 16000 samples per second
+
+     for file_name in os.listdir(AudioLDM_signals_directory_path):
+         if file_name.endswith('.wav') and not file_name.startswith('._'):
+             file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
+             signal, sr = librosa.load(file_path, sr=16000)  # Load audio file with sampling rate 16000
+             if len(signal) >= target_length:
+                 signal = signal[:target_length]  # Take only the first 4 seconds
+             else:
+                 raise ValueError(f"The file {file_name} is shorter than 4 seconds.")
+             # Normalize
+             AudioLDM_signals.append(rms_normalize(signal))
+             signal_lengths.add(len(signal))
+
+     # Step 2: Check if all signals have the same length
+     if len(signal_lengths) != 1:
+         raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
+
+     encoded_audios = []
+     for origin_audio in AudioLDM_signals:
+         D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
+         padded_D = pad_STFT(D)
+         encoded_D = encode_stft(padded_D)
+         encoded_audios.append(encoded_D)
+     encoded_audios_np = np.array(encoded_audios)
+     origin_spectrogram_batch_tensor = torch.from_numpy(encoded_audios_np).float().to(device)
+
+     # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
+     batch_size = 8
+     num_batches = int(np.ceil(origin_spectrogram_batch_tensor.shape[0] / batch_size))
+     spectrogram_batches = []
+     for i in range(num_batches):
+         batch = origin_spectrogram_batch_tensor[i * batch_size:(i + 1) * batch_size]
+         spectrogram_batches.append(batch)
+
+     for spectrogram_batch in tqdm(spectrogram_batches):
+         spectrogram_batch = spectrogram_batch.to(device)
+         _, _, _, _, quantized_latent_representations = InputBatch2Encode_STFT(VAE_encoder, spectrogram_batch, quantizer=VAE_quantizer, squared=False)
+         quantized_latent_representations = quantized_latent_representations
+         feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
+         probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
+
+         diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
+
+     return inception_score(np.array(diffuSynth_probabilities))
+
+
+ # def get_inception_score_for_AudioLDM(device, timbre_encoder, VAE, AudioLDM_signals_directory_path):
+ #     VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
+ #
+ #     diffuSynth_probabilities = []
+ #
+ #     # Step 1: Load all wav files in AudioLDM_signals_directory_path
+ #     AudioLDM_signals = []
+ #     signal_lengths = set()
+ #
+ #     for file_name in os.listdir(AudioLDM_signals_directory_path):
+ #         if file_name.endswith('.wav'):
+ #             file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
+ #             signal, sr = librosa.load(file_path, sr=16000)  # Load audio file with sampling rate 16000
+ #             # Normalize
+ #             AudioLDM_signals.append(rms_normalize(signal))
+ #             signal_lengths.add(len(signal))
+ #
+ #     # Step 2: Check if all signals have the same length
+ #     if len(signal_lengths) != 1:
+ #         raise ValueError("Not all signals have the same length. Please ensure all audio files are of the same length.")
+ #
+ #     encoded_audios = []
+ #     for origin_audio in AudioLDM_signals:
+ #         D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
+ #         padded_D = pad_STFT(D)
+ #         encoded_D = encode_stft(padded_D)
+ #         encoded_audios.append(encoded_D)
+ #     encoded_audios_np = np.array(encoded_audios)
+ #     origin_spectrogram_batch_tensor = torch.from_numpy(encoded_audios_np).float().to(device)
+ #
+ #
+ #     # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length]
+ #     batch_size = 8
+ #     num_batches = int(np.ceil(origin_spectrogram_batch_tensor.shape[0] / batch_size))
+ #     spectrogram_batches = []
+ #     for i in range(num_batches):
+ #         batch = origin_spectrogram_batch_tensor[i * batch_size:(i + 1) * batch_size]
+ #         spectrogram_batches.append(batch)
+ #
+ #
+ #     for spectrogram_batch in tqdm(spectrogram_batches):
+ #         spectrogram_batch = spectrogram_batch.to(device)
+ #         _, _, _, _, quantized_latent_representations = InputBatch2Encode_STFT(VAE_encoder, spectrogram_batch, quantizer=VAE_quantizer, squared=False)
+ #         quantized_latent_representations = quantized_latent_representations
+ #         feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
+ #         probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
+ #
+ #         diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
+ #
+ #     return inception_score(np.array(diffuSynth_probabilities))
+
+
+ def get_inception_score(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms"):
+     diffuSynth_probabilities = []
+
+     if task == "spectrograms":
+         pipe = sample_pipeline
+     elif task == "STFT":
+         pipe = sample_pipeline_STFT
+     else:
+         raise NotImplementedError
+
+     for _ in tqdm(range(num_batches)):
+         quantized_latent_representations = pipe(device, uNet, VAE, MMM, CLAP_tokenizer,
+                                                 positive_prompts=positive_prompts, negative_prompts=negative_prompts,
+                                                 batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None)
+
+         quantized_latent_representations = quantized_latent_representations.to(device)
+         feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
+         probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
+
+         diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
+
+     return inception_score(np.array(diffuSynth_probabilities))
+
+
+ def get_inception_score_GAN(device, gan_generator, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms"):
+     diffuSynth_probabilities = []
+
+     if task == "spectrograms":
+         pipe = sample_pipeline_GAN
+     elif task == "STFT":
+         pipe = sample_pipeline_GAN_STFT
+     else:
+         raise NotImplementedError
+
+     for _ in tqdm(range(num_batches)):
+         quantized_latent_representations = pipe(device, gan_generator, VAE, MMM, CLAP_tokenizer,
+                                                 positive_prompts=positive_prompts, negative_prompts=negative_prompts,
+                                                 batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None)
+
+         quantized_latent_representations = quantized_latent_representations.to(device)
+         feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
+         probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
+
+         diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy())
+
+     return inception_score(np.array(diffuSynth_probabilities))
+
+
+ def predict_qualities_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
+     diffuSynth_qualities = []
+     for _ in tqdm(range(num_batches)):
+         quantized_latent_representations = sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
+                                                            positive_prompts=positive_prompts, negative_prompts=negative_prompts,
+                                                            batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None)
+
+         quantized_latent_representations = quantized_latent_representations.to(device)
+         feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
+         qualities = qualities.to("cpu").detach().numpy()
+         # qualities = np.where(qualities > 0.5, 1, 0)
+
+         diffuSynth_qualities.extend(qualities)
+
+     return np.mean(diffuSynth_qualities, axis=0)
+
+
+ def generate_probabilities_with_diffuSynth_inpaint(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, guidance, duration, use_dynamic_mask, noising_strength, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
+
+     inpaint_probabilities, signals = [], []
+     for _ in tqdm(range(num_batches)):
+         quantized_latent_representations, _, rec_signals = inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
+                                                                             use_dynamic_mask=use_dynamic_mask, noising_strength=noising_strength, guidance=guidance,
+                                                                             positive_prompts=positive_prompts, negative_prompts=negative_prompts, batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None, duration=duration, mask_flexivity=0.999,
+                                                                             return_latent=False)
+
+         quantized_latent_representations = quantized_latent_representations.to(device)
+         feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations)
+         probabilities = torch.nn.functional.softmax(instrument_logits, dim=1)
+
+         inpaint_probabilities.extend(probabilities.to("cpu").detach().numpy())
+         signals.extend(rec_signals)
+
+     return np.array(inpaint_probabilities), signals
+
+
+ def inception_score(pred):
+
+     # Conditional class distribution P(y|x) for each sample
+     pyx = pred / np.sum(pred, axis=1, keepdims=True)
+
+     # Marginal class distribution P(y) over the whole set
+     py = np.mean(pyx, axis=0, keepdims=True)
+
+     # KL divergence between P(y|x) and P(y)
+     kl_div = pyx * (np.log(pyx + 1e-11) - np.log(py + 1e-11))
+
+     # Sum over classes, average over samples, and exponentiate
+     kl_div_sum = np.sum(kl_div, axis=1)
+     score = np.exp(np.mean(kl_div_sum))
+     return score
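For reference, inception_score computes the classic Inception Score over the timbre encoder's instrument-class probabilities:

\mathrm{IS} = \exp\!\left( \mathbb{E}_{x}\!\left[ D_{\mathrm{KL}}\!\left( p(y \mid x) \,\|\, p(y) \right) \right] \right)

with p(y|x) the per-sample softmax output (pyx) and p(y) its mean over all evaluated samples (py); the 1e-11 terms only guard against log(0).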
metrics/P_C_T.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+ from metrics.precision_recall import knn_precision_recall_features
+
+
+ # Generate random example features
+ real_features = np.random.normal(0, 1, size=(1600, 512))
+ generated_features = np.random.normal(0, 1, size=(1600, 512))
+
+ state = knn_precision_recall_features(real_features, generated_features, nhood_sizes=[1, 2, 3, 4, 5, 10],
+                                       row_batch_size=16, col_batch_size=16)
+
+ print(state)
metrics/get_reference_AST_features.py ADDED
@@ -0,0 +1,63 @@
+ import json
+ import librosa
+ import numpy as np
+ from tqdm import tqdm
+ from metrics.FD import ASTaudio2feature, calculate_statistics, save_AST_feature
+ from tools import rms_normalize
+ from transformers import AutoProcessor, ASTModel
+
+ device = "cpu"
+ processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
+ AST = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").to(device)
+
+
+ data_split = "train"
+ with open(f'data/NSynth/{data_split}_examples.json') as f:
+     data = json.load(f)
+
+ def read_signal(note_str):
+     y, sr = librosa.load(f"data/NSynth/nsynth-{data_split}-52/audio/{note_str}.wav", sr=16000)
+     if len(y) >= 64000:
+         y = y[:64000]
+     else:
+         y_extend = [0.0] * 64000
+         y_extend[:len(y)] = y
+         y = y_extend
+
+     return rms_normalize(y)
+
+ for quality in ["bright", "dark", "distortion", "fast_decay", "long_release", "multiphonic", "nonlinear_env", "percussive", "reverb", "tempo-synced"]:
+     features = []
+     for i, (note_str, attributes) in tqdm(enumerate(data.items())):
+         if not attributes["pitch"] == 52:
+             continue
+         if not (quality in attributes['qualities_str']):
+             continue
+
+         signal = read_signal(note_str)
+         feature_for_one_signal = ASTaudio2feature(device, [signal], processor, AST, sampling_rate=16000)[0]
+         features.append(feature_for_one_signal)
+
+     mu, sigma = calculate_statistics(features)
+     print(np.shape(mu))
+     print(np.shape(sigma))
+
+     save_AST_feature(f'{data_split}_{quality}', mu.tolist(), sigma.tolist())
+
+ for instrument_name in ["bass", "brass", "flute", "guitar", "keyboard", "mallet", "organ", "reed", "string", "synth_lead", "vocal"]:
+     features = []
+     for i, (note_str, attributes) in tqdm(enumerate(data.items())):
+         if not attributes["pitch"] == 52:
+             continue
+         if not (attributes["instrument_family_str"] == instrument_name):
+             continue
+
+         signal = read_signal(note_str)
+         feature_for_one_signal = ASTaudio2feature(device, [signal], processor, AST, sampling_rate=16000)[0]
+         features.append(feature_for_one_signal)
+
+     mu, sigma = calculate_statistics(features)
+     print(np.shape(mu))
+     print(np.shape(sigma))
+
+     save_AST_feature(f'{data_split}_{instrument_name}', mu.tolist(), sigma.tolist())
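The (mu, sigma) statistics written here are the reference side of the Fréchet-distance metric in metrics/FD.py. A hypothetical usage sketch: the key "train_bright" follows the f'{data_split}_{quality}' naming used above, and the random array merely stands in for AST features of generated audio (the AST CLS embedding is 768-dimensional):

import numpy as np
from metrics.FD import read_AST_features, calculate_statistics, calculate_fid_dict

generated_features = np.random.normal(size=(100, 768))  # placeholder for real AST features
AST_features = read_AST_features()                      # loads the JSON written by this script
mu_gen, sigma_gen = calculate_statistics(generated_features)
fd = calculate_fid_dict(AST_features["train_bright"], {"mu": mu_gen, "sigma": sigma_gen})
print(fd)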
metrics/pipelines.py ADDED
@@ -0,0 +1,144 @@
+ import librosa
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from tools import VAE_out_put_to_spc, rms_normalize, nnData2Audio
+ from model.DiffSynthSampler import DiffSynthSampler
+
+ def sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
+                     positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0,
+                     freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
+
+     height = int(freq_resolution/VAE_scale)
+     width = int(time_resolution/VAE_scale)
+     VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
+
+     text2sound_embedding = \
+         MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
+     negative_condition = \
+         MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0].to(device)
+
+     mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
+     mySampler.activate_classifier_free_guidance(CFG, negative_condition)
+
+     mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))
+
+     condition = text2sound_embedding.repeat(batchsize, 1)
+
+     latent_representations, initial_noise = \
+         mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed,
+                          return_tensor=True, condition=condition, sampler=sampler)
+
+     latent_representations = latent_representations[-1]
+
+     quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
+
+     if return_latent:
+         return quantized_latent_representations.detach()
+     reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
+     time_resolution = int(time_resolution * ((duration+1) / 4))
+
+     rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution))
+     rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals]
+
+     return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
+
+ def sample_pipeline_GAN(device, gan_generator, VAE, MMM, CLAP_tokenizer,
+                         positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0,
+                         freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
+
+     height = int(freq_resolution/VAE_scale)
+     width = int(time_resolution/VAE_scale)
+     VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
+
+     text2sound_embedding = \
+         MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
+
+     condition = text2sound_embedding.repeat(batchsize, 1)
+
+     noise = torch.randn(batchsize, channels, height, width).to(device)
+     latent_representations = gan_generator(noise, condition)
+
+     quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
+
+     if return_latent:
+         return quantized_latent_representations.detach()
+     reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
+     time_resolution = int(time_resolution * ((duration+1) / 4))
+
+     rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution))
+     rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals]
+
+     return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
+
+ def inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, use_dynamic_mask, noising_strength, guidance,
+                      positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0, mask_flexivity=0.99,
+                      freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
+
+     height = int(freq_resolution/VAE_scale)
+     width = int(time_resolution * ((duration + 1) / 4) / VAE_scale)
+     VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
+
+
+     text2sound_embedding = \
+         MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0]
+     negative_condition = \
+         MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0]
+
+
+     mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
+     mySampler.activate_classifier_free_guidance(CFG, negative_condition)
+     mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))
+
+     condition = text2sound_embedding.repeat(batchsize, 1)
+     guidance = guidance.repeat(batchsize, 1, 1, 1).to(device)
+
+     # mask = 1, freeze
+     latent_mask = torch.zeros((batchsize, 1, height, width), dtype=torch.float32).to(device)
+     latent_mask[:, :, :, -int(time_resolution * (1 / 4) / VAE_scale):] = 1.0
+
+     latent_representations, initial_noise = \
+         mySampler.inpaint_sample(model=uNet, shape=(batchsize, channels, height, width),
+                                  noising_strength=noising_strength,
+                                  guide_img=guidance, mask=latent_mask, return_tensor=True,
+                                  condition=condition, sampler=sampler,
+                                  use_dynamic_mask=use_dynamic_mask,
+                                  end_noise_level_ratio=0.0,
+                                  mask_flexivity=mask_flexivity)
+
+     latent_representations = latent_representations[-1]
+
+     quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
+
+     if return_latent:
+         return quantized_latent_representations.detach()
+     reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
+     time_resolution = int(time_resolution * ((duration+1) / 4))
+
+     rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution))
+     rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals]
+
+     return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
+
+
+ def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
+     diffuSynth_signals = []
+     for _ in tqdm(range(num_batches)):
+         _, _, signals = sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
+                                         positive_prompts=positive_prompts, negative_prompts=negative_prompts,
+                                         batchsize=16, sample_steps=sample_steps, CFG=CFG, seed=None, return_latent=False)
+         diffuSynth_signals.extend(signals)
+     return np.array(diffuSynth_signals)
+
+
+ def generate_audios_with_diffuSynth_inpaint(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, guidance, duration, use_dynamic_mask, noising_strength, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
+
+     diffuSynth_signals = []
+     for _ in tqdm(range(num_batches)):
+         _, _, signals = inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer,
+                                          use_dynamic_mask=use_dynamic_mask, noising_strength=noising_strength, guidance=guidance,
+                                          positive_prompts=positive_prompts, negative_prompts=negative_prompts, batchsize=16, sample_steps=sample_steps, CFG=CFG, seed=None, duration=duration, mask_flexivity=0.999,
+                                          return_latent=False)
+         diffuSynth_signals.extend(signals)
+     return np.array(diffuSynth_signals)
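The CFG argument passed to activate_classifier_free_guidance is a classifier-free guidance weight. Assuming DiffSynthSampler follows the standard formulation (its internals are not part of this diff), each denoising step blends the positive- and negative-prompt predictions as

\hat{\epsilon}_\theta(x_t) = \epsilon_\theta(x_t, c_{\mathrm{neg}}) + \mathrm{CFG} \cdot \left( \epsilon_\theta(x_t, c_{\mathrm{pos}}) - \epsilon_\theta(x_t, c_{\mathrm{neg}}) \right)

so CFG = 1 reduces to plain conditional sampling and larger values push generations toward the positive prompt. The respace call restricts the 1000 training timesteps to sample_steps evenly spaced DDIM steps.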
metrics/pipelines_STFT.py ADDED
@@ -0,0 +1,100 @@
+ import librosa
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from tools import rms_normalize, decode_stft, depad_STFT
+ from model.DiffSynthSampler import DiffSynthSampler
+
+ def sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
+                          positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
+                          freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
+     "Sample fixed-length audio with a diffusion model, including 'ISTFT+' post-processing."
+
+     height = int(freq_resolution/VAE_scale)
+     width = int(time_resolution/VAE_scale)
+     VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
+
+     text2sound_embedding = \
+         MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
+     negative_condition = \
+         MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0].to(device)
+
+     mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True)
+     mySampler.activate_classifier_free_guidance(CFG, negative_condition)
+
+     mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32)))
+
+     condition = text2sound_embedding.repeat(batchsize, 1)
+
+     latent_representations, initial_noise = \
+         mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed,
+                          return_tensor=True, condition=condition, sampler=sampler)
+
+     latent_representations = latent_representations[-1]
+
+     quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
+
+     if return_latent:
+         return quantized_latent_representations.detach()
+
+     reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
+
+     rec_signals = []
+
+     for index, STFT in enumerate(reconstruction_batch):
+         padded_D_rec = decode_stft(STFT)
+         D_rec = depad_STFT(padded_D_rec)
+         # get_audio
+         rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
+         rec_signals.append(rms_normalize(rec_signal))
+
+     return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
+
+ def sample_pipeline_GAN_STFT(device, gan_generator, VAE, MMM, CLAP_tokenizer,
+                              positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None,
+                              freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True):
+     "Sample fixed-length audio with a GAN, including 'ISTFT+' post-processing."
+
+     height = int(freq_resolution/VAE_scale)
+     width = int(time_resolution/VAE_scale)
+     VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder
+
+     text2sound_embedding = \
+         MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device)
+
+     condition = text2sound_embedding.repeat(batchsize, 1)
+
+     noise = torch.randn(batchsize, channels, height, width).to(device)
+     latent_representations = gan_generator(noise, condition)
+
+     quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations)
+
+     if return_latent:
+         return quantized_latent_representations.detach()
+     reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy()
+
+     rec_signals = []
+
+     for index, STFT in enumerate(reconstruction_batch):
+         padded_D_rec = decode_stft(STFT)
+         D_rec = depad_STFT(padded_D_rec)
+         # get_audio
+         rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
+         rec_signals.append(rms_normalize(rec_signal))
+
+     return quantized_latent_representations.detach(), reconstruction_batch, rec_signals
+
+
+ def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10):
+     "Sample audios with a diffusion model, including 'ISTFT+' post-processing."
+
+     diffuSynth_signals = []
+     for _ in tqdm(range(num_batches)):
+         _, _, signals = sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer,
+                                              positive_prompts=positive_prompts, negative_prompts=negative_prompts,
+                                              batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None, return_latent=False)
+         diffuSynth_signals.extend(signals)
+     return np.array(diffuSynth_signals)
+
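The 'ISTFT+' post-processing relies on the decode_stft/depad_STFT helpers imported from tools.py, whose encoding is not shown in this view. Independent of those helpers, the underlying librosa round trip with the same STFT parameters is straightforward; a minimal sketch on a synthetic signal:

import numpy as np
import librosa

sr = 16000
t = np.linspace(0, 4.0, 4 * sr, endpoint=False)
y = 0.5 * np.sin(2 * np.pi * 220.0 * t).astype(np.float32)        # stand-in for a generated signal

D = librosa.stft(y, n_fft=1024, hop_length=256, win_length=1024)  # complex spectrogram, 513 frequency bins
y_rec = librosa.istft(D, hop_length=256, win_length=1024)         # near-perfect reconstruction
print(D.shape, float(np.max(np.abs(y[:len(y_rec)] - y_rec))))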
metrics/precision_recall.py ADDED
@@ -0,0 +1,204 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ """k-NN precision and recall."""
+
+ from time import time
+
+
+ # ----------------------------------------------------------------------------
+
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def batch_pairwise_distances(U, V):
+     """Compute pairwise distances between two batches of feature vectors."""
+
+     norm_u = np.sum(np.square(U), axis=1)
+     norm_v = np.sum(np.square(V), axis=1)
+
+     norm_u = np.reshape(norm_u, [-1, 1])
+     norm_v = np.reshape(norm_v, [1, -1])
+
+     D = np.maximum(norm_u - 2 * np.dot(U, V.T) + norm_v, 0.0)
+     return D
+
+
+ # ----------------------------------------------------------------------------
+
+ class DistanceBlock():
+     """Compute pairwise distances between two batches of feature vectors."""
+
+     def __init__(self, num_features):
+         self.num_features = num_features
+
+     def pairwise_distances(self, U, V):
+         return batch_pairwise_distances(U, V)
+
+
+
+ # ----------------------------------------------------------------------------
+
+ class ManifoldEstimator():
+     """Estimates the manifold of given feature vectors."""
+
+     def __init__(self, distance_block, features, row_batch_size=16, col_batch_size=16,
+                  nhood_sizes=[3], clamp_to_percentile=None, eps=1e-5, mute=False):
+         """Estimate the manifold of given feature vectors.
+
+         Args:
+             distance_block: DistanceBlock object that distributes pairwise distance
+                 calculation to multiple GPUs.
+             features (np.array/tf.Tensor): Matrix of feature vectors to estimate their manifold.
+             row_batch_size (int): Row batch size to compute pairwise distances
+                 (parameter to trade off between memory usage and performance).
+             col_batch_size (int): Column batch size to compute pairwise distances.
+             nhood_sizes (list): Number of neighbors used to estimate the manifold.
+             clamp_to_percentile (float): Prune hyperspheres that have radius larger than
+                 the given percentile.
+             eps (float): Small number for numerical stability.
+         """
+         num_images = features.shape[0]
+         self.nhood_sizes = nhood_sizes
+         self.num_nhoods = len(nhood_sizes)
+         self.eps = eps
+         self.row_batch_size = row_batch_size
+         self.col_batch_size = col_batch_size
+         self._ref_features = features
+         self._distance_block = distance_block
+         self.mute = mute
+
+         # Estimate manifold of features by calculating distances to k-NN of each sample.
+         self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
+         distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float32)
+         seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
+
+         if mute:
+             for begin1 in range(0, num_images, row_batch_size):
+                 end1 = min(begin1 + row_batch_size, num_images)
+                 row_batch = features[begin1:end1]
+
+                 for begin2 in range(0, num_images, col_batch_size):
+                     end2 = min(begin2 + col_batch_size, num_images)
+                     col_batch = features[begin2:end2]
+
+                     # Compute distances between batches.
+                     distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch,
+                                                                                                            col_batch)
+
+                 # Find the k-nearest neighbors from the current batch.
+                 self.D[begin1:end1, :] = np.partition(distance_batch[0:end1 - begin1, :], seq, axis=1)[:, self.nhood_sizes]
+         else:
+             for begin1 in tqdm(range(0, num_images, row_batch_size)):
+                 end1 = min(begin1 + row_batch_size, num_images)
+                 row_batch = features[begin1:end1]
+
+                 for begin2 in range(0, num_images, col_batch_size):
+                     end2 = min(begin2 + col_batch_size, num_images)
+                     col_batch = features[begin2:end2]
+
+                     # Compute distances between batches.
+                     distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch,
+                                                                                                            col_batch)
+
+                 # Find the k-nearest neighbors from the current batch.
+                 self.D[begin1:end1, :] = np.partition(distance_batch[0:end1 - begin1, :], seq, axis=1)[:, self.nhood_sizes]
+
+         if clamp_to_percentile is not None:
+             max_distances = np.percentile(self.D, clamp_to_percentile, axis=0)
+             self.D[self.D > max_distances] = 0
+
+     def evaluate(self, eval_features, return_realism=False, return_neighbors=False):
+         """Evaluate whether new feature vectors lie on the estimated manifold."""
+         num_eval_images = eval_features.shape[0]
+         num_ref_images = self.D.shape[0]
+         distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
+         batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
+         max_realism_score = np.zeros([num_eval_images, ], dtype=np.float32)
+         nearest_indices = np.zeros([num_eval_images, ], dtype=np.int32)
+
+         for begin1 in range(0, num_eval_images, self.row_batch_size):
+             end1 = min(begin1 + self.row_batch_size, num_eval_images)
+             feature_batch = eval_features[begin1:end1]
+
+             for begin2 in range(0, num_ref_images, self.col_batch_size):
+                 end2 = min(begin2 + self.col_batch_size, num_ref_images)
+                 ref_batch = self._ref_features[begin2:end2]
+
+                 distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(feature_batch,
+                                                                                                        ref_batch)
+
+             # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
+             # If a feature vector is inside a hypersphere of some reference sample, then
+             # the new sample lies at the estimated manifold.
+             # The radii of the hyperspheres are determined from distances of neighborhood size k.
+             samples_in_manifold = distance_batch[0:end1 - begin1, :, None] <= self.D
+             batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
+
+             max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distance_batch[0:end1 - begin1, :] + self.eps),
+                                                     axis=1)
+             nearest_indices[begin1:end1] = np.argmin(distance_batch[0:end1 - begin1, :], axis=1)
+
+         if return_realism and return_neighbors:
+             return batch_predictions, max_realism_score, nearest_indices
+         elif return_realism:
+             return batch_predictions, max_realism_score
+         elif return_neighbors:
+             return batch_predictions, nearest_indices
+
+         return batch_predictions
+
+
+ # ----------------------------------------------------------------------------
+
+ def knn_precision_recall_features(ref_features, eval_features, nhood_sizes=[3],
+                                   row_batch_size=10000, col_batch_size=50000, mute=False):
+     """Calculates k-NN precision and recall for two sets of feature vectors.
+
+     Args:
+         ref_features (np.array/tf.Tensor): Feature vectors of reference images.
+         eval_features (np.array/tf.Tensor): Feature vectors of generated images.
+         nhood_sizes (list): Number of neighbors used to estimate the manifold.
+         row_batch_size (int): Row batch size to compute pairwise distances
+             (parameter to trade off between memory usage and performance).
+         col_batch_size (int): Column batch size to compute pairwise distances.
+
+     Returns:
+         State (dict): Dict that contains precision and recall calculated from
+         ref_features and eval_features.
+     """
+     state = dict()
+     num_images = ref_features.shape[0]
+     num_features = ref_features.shape[1]
+
+     # Initialize DistanceBlock and ManifoldEstimators.
+     distance_block = DistanceBlock(num_features)
+     ref_manifold = ManifoldEstimator(distance_block, ref_features, row_batch_size, col_batch_size, nhood_sizes, mute=mute)
+     eval_manifold = ManifoldEstimator(distance_block, eval_features, row_batch_size, col_batch_size, nhood_sizes, mute=mute)
+
+     # Evaluate precision and recall using k-nearest neighbors.
+     if not mute:
+         print('Evaluating k-NN precision and recall with %i samples...' % num_images)
+     start = time()
+
+     # Precision: how many points from eval_features lie in the ref_features manifold.
+     precision = ref_manifold.evaluate(eval_features)
+     state['precision'] = precision.mean(axis=0)
+
+     # Recall: how many points from ref_features lie in the eval_features manifold.
+     recall = eval_manifold.evaluate(ref_features)
+     state['recall'] = recall.mean(axis=0)
+
+     if not mute:
+         print('Evaluated k-NN precision and recall in: %gs' % (time() - start))
+
+     return state
+
+ # ----------------------------------------------------------------------------
+
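This module appears to follow the k-NN precision/recall estimator of Kynkäänniemi et al. (2019). With real features \Phi_r, generated features \Phi_g, and f(\phi, \Phi) = 1 when \phi falls inside the k-NN hypersphere of at least one element of \Phi (0 otherwise):

\mathrm{precision}(\Phi_r, \Phi_g) = \frac{1}{|\Phi_g|} \sum_{\phi_g \in \Phi_g} f(\phi_g, \Phi_r), \qquad \mathrm{recall}(\Phi_r, \Phi_g) = \frac{1}{|\Phi_r|} \sum_{\phi_r \in \Phi_r} f(\phi_r, \Phi_g)

which is what ref_manifold.evaluate(eval_features) and eval_manifold.evaluate(ref_features) average in knn_precision_recall_features; metrics/P_C_T.py above shows a minimal call.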
metrics/visualizations.py ADDED
@@ -0,0 +1,123 @@
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.fft import fft
4
+ from scipy.signal import savgol_filter
5
+ from tools import rms_normalize
6
+
7
+ colors = [
8
+ # (0, 0, 0), # Black
9
+ # (86, 180, 233), # Sky blue
10
+ # (240, 228, 66), # Yellow
11
+ # (204, 121, 167), # Reddish purple
12
+ (213, 94, 0), # Vermilion
13
+ (0, 114, 178), # Blue
14
+ (230, 159, 0), # Orange
15
+ (0, 158, 115), # Bluish green
16
+ ]
17
+
18
+
19
+ def plot_psd_multiple_signals(signals_list, labels_list, sample_rate=16000, window_size=500,
20
+ figsize=(10, 6), save_path=None, normalize=False):
21
+ """
22
+ Plot a power-spectral-density comparison of multiple sets of audio signals in one figure, using a log2-scaled loudness axis with smoothing applied.
23
+
24
+ Parameters:
25
+ signals_list: list of signal sets, each a numpy array of shape [sample_number, sample_length]
26
+ labels_list: list of label strings, one per signal set
27
+ sample_rate: sampling rate of the audio
28
+ """
29
+
30
+ # Make sure signals_list and labels_list have the same length
31
+ assert len(signals_list) == len(labels_list), "Each signal set must have a corresponding label."
32
+
33
+ signals_list = [np.array([rms_normalize(signal) for signal in signals]) for signals in signals_list]
34
+
35
+ # Prepare the figure
36
+ plt.figure(figsize=figsize)
37
+
38
+ # Iterate over all signal sets
39
+ i = 0
40
+ for signal, label in zip(signals_list, labels_list):
41
+ # Compute the FFT
42
+ fft_signal = fft(signal, axis=1)
43
+
44
+ # Compute the mean power spectral density
45
+ psd_signal = np.mean(np.abs(fft_signal)**2, axis=0)
46
+
47
+ # Compute the frequency axis
48
+ freqs = np.fft.fftfreq(signal.shape[1], 1/sample_rate)
49
+
50
+ # Smooth with a Savitzky-Golay filter
51
+ psd_smoothed = savgol_filter(np.log2(psd_signal[:signal.shape[1] // 2] + 1), window_size, 3) # window length = window_size, polynomial order 3
52
+
53
+ # Normalize each curve if normalize is True
54
+ if normalize:
55
+ psd_smoothed /= np.mean(psd_smoothed)
56
+
57
+ # Plot the PSD of each signal set
58
+ plt.plot(freqs[:signal.shape[1] // 2], psd_smoothed, label=label, color=[x/255.0 for x in colors[i % len(colors)]], linewidth=1)
59
+ i += 1
60
+
61
+ # Set up the figure elements
62
+ plt.xlabel('Frequency (Hz)')
63
+ plt.ylabel('Mean Log-Amplitude')
64
+ plt.legend()
65
+
66
+ # Save the figure or show it, depending on save_path
67
+ if save_path:
68
+ plt.savefig(save_path)
69
+ else:
70
+ plt.show()
71
+
72
+
73
+ def plot_amplitude_over_time(signals_list, labels_list, sample_rate=16000, window_size=500,
74
+ figsize=(10, 6), save_path=None, normalize=False, start_time=0):
75
+ """
76
+ Plot the loudness of multiple sets of audio signals over time on the same graph,
77
+ using a logarithmic scale for the loudness axis (base 2), with smoothing applied.
78
+
79
+ Parameters:
80
+ signals_list: List of sets of audio signals, each set is a numpy array with shape [sample_number, sample_length]
81
+ labels_list: List of labels corresponding to each set of audio signals
82
+ sample_rate: Sampling rate of the audio
83
+ window_size: Window size for the Savitzky-Golay filter
84
+ figsize: Figure size
85
+ save_path: Path to save the figure, if None, the figure will be displayed
86
+ normalize: Whether to normalize each curve so that the sum of each curve is the same
87
+ start_time: Time (in seconds) to start plotting, only data after this time will be retained
88
+ """
89
+ assert len(signals_list) == len(labels_list), f"len(signals_list) != len(labels_list) for " \
90
+ f"len(signals_list) = {len(signals_list)} and len(labels_list) = {len(labels_list)}"
91
+
92
+ # Compute starting sample index
93
+ start_sample = int(start_time * sample_rate)
94
+
95
+ # Normalize signals and truncate data
96
+ signals_list = [np.array([rms_normalize(signal)[start_sample:] for signal in signals]) for signals in signals_list]
97
+ time_axis = np.arange(start_sample, start_sample + signals_list[0].shape[1]) / sample_rate
98
+
99
+ plt.figure(figsize=figsize)
100
+
101
+ i = 0
102
+ for signal, label in zip(signals_list, labels_list):
103
+ amplitude_mean = np.mean(np.abs(signal), axis=0)
104
+
105
+ amplitude_smoothed = savgol_filter(np.log2(amplitude_mean + 1), window_size, 3)
106
+
107
+ # Normalize each curve if normalize is True
108
+ if normalize:
109
+ amplitude_smoothed /= np.mean(amplitude_smoothed)
110
+
111
+ plt.plot(time_axis, amplitude_smoothed, label=label, color=[x/255.0 for x in colors[i % len(colors)]], linewidth=1)
112
+ i += 1
113
+
114
+ plt.xlabel('Time (seconds)')
115
+ plt.ylabel('Mean Log-Amplitude')
116
+ plt.legend()
117
+
118
+ # Save or show the figure based on save_path parameter
119
+ if save_path:
120
+ plt.savefig(save_path)
121
+ else:
122
+ plt.show()
123
+
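Usage sketch for the two plotting helpers above, called with synthetic data; the shapes, labels and output paths are assumptions.

import numpy as np
from metrics.visualizations import plot_psd_multiple_signals, plot_amplitude_over_time  # assumed import path

real = np.random.randn(8, 16000)   # 8 one-second signals at 16 kHz
fake = np.random.randn(8, 16000)
plot_psd_multiple_signals([real, fake], ["reference", "generated"], sample_rate=16000, save_path="psd.png")
plot_amplitude_over_time([real, fake], ["reference", "generated"], sample_rate=16000, start_time=0.1, save_path="amplitude.png")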
model/DiffSynthSampler.py ADDED
@@ -0,0 +1,425 @@
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+
6
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
7
+ """
8
+ Extract values from a 1-D numpy array for a batch of indices.
9
+
10
+ :param arr: the 1-D numpy array.
11
+ :param timesteps: a tensor of indices into the array to extract.
12
+ :param broadcast_shape: a larger shape of K dimensions with the batch
13
+ dimension equal to the length of timesteps.
14
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
15
+ """
16
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
17
+ while len(res.shape) < len(broadcast_shape):
18
+ res = res[..., None]
19
+ return res.expand(broadcast_shape)
20
+
21
+
22
+ class DiffSynthSampler:
23
+
24
+ def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False,
25
+ height=128, max_batchsize=16, max_width=256, channels=4, train_width=64, noise_strategy="repeat"):
26
+ if device is None:
27
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ else:
29
+ self.device = device
30
+ self.height = height
31
+ self.train_width = train_width
32
+ self.max_batchsize = max_batchsize
33
+ self.max_width = max_width
34
+ self.channels = channels
35
+ self.num_timesteps = timesteps
36
+ self.timestep_map = list(range(self.num_timesteps))
37
+ self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64)
38
+ self.respaced = False
39
+ self.define_beta_schedule()
40
+ self.CFG = 1.0
41
+ self.mute = mute
42
+ self.noise_strategy = noise_strategy
43
+
44
+ def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None):
45
+ if reference_noise is None:
46
+ large_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.max_width), device=self.device)
47
+ else:
48
+ assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), "reference_noise shape mismatch"
49
+ large_noise_tensor = reference_noise
50
+ return large_noise_tensor[:batchsize, :, :, :width], None
51
+
52
+ def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None):
53
+ if self.noise_strategy == "repeat":
54
+ noise, concat_points = self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise)
55
+ return noise, concat_points
56
+ else:
57
+ noise, concat_points = self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise)
58
+ return noise, concat_points
59
+
60
+
61
+ def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None):
62
+ # Generate noise with the same length as the training data
63
+ if reference_noise is None:
64
+ train_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.train_width), device=self.device)
65
+ else:
66
+ assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), "reference_noise shape mismatch"
67
+ train_noise_tensor = reference_noise
68
+
69
+ release_width = int(self.train_width * 1.0 / 4)
70
+ first_part_width = self.train_width - release_width
71
+
72
+ first_part = train_noise_tensor[:batchsize, :, :, :first_part_width]
73
+ release_part = train_noise_tensor[:batchsize, :, :, -release_width:]
74
+
75
+ # If the required length is at most the original length, drop the middle of first_part
76
+ if width <= self.train_width:
77
+ _first_part_head_width = int((width - release_width) / 2)
78
+ _first_part_tail_width = width - release_width - _first_part_head_width
79
+ all_parts = [first_part[:, :, :, :_first_part_head_width], first_part[:, :, :, -_first_part_tail_width:], release_part]
80
+
81
+ # Concatenate the tensors along the fourth dimension
82
+ noise_tensor = torch.cat(all_parts, dim=3)
83
+
84
+ # Record the positions of the concatenation points
85
+ concat_points = [0]
86
+ for part in all_parts[:-1]:
87
+ next_point = concat_points[-1] + part.size(3)
88
+ concat_points.append(next_point)
89
+
90
+ return noise_tensor, concat_points
91
+
92
+ # If the required length exceeds the original length, repeatedly insert the middle of first_part
93
+ else:
94
+ # Compute how many times the front part needs to be repeated
95
+ repeats = (width - release_width) // first_part_width
96
+ extra = (width - release_width) % first_part_width
97
+
98
+ _repeat_first_part_head_width = int(first_part_width / 2)
99
+ _repeat_first_part_tail_width = first_part_width - _repeat_first_part_head_width
100
+
101
+ repeated_first_head_parts = [first_part[:, :, :, :_repeat_first_part_head_width] for _ in range(repeats)]
102
+ repeated_first_tail_parts = [first_part[:, :, :, -_repeat_first_part_tail_width:] for _ in range(repeats)]
103
+
104
+ # Compute the starting index
105
+ _middle_part_start_index = (first_part_width - extra) // 2
106
+ # Slice the tensor to get the middle part
107
+ middle_part = first_part[:, :, :, _middle_part_start_index: _middle_part_start_index + extra]
108
+
109
+ all_parts = repeated_first_head_parts + [middle_part] + repeated_first_tail_parts + [release_part]
110
+
111
+ # Concatenate the tensors along the fourth dimension
112
+ noise_tensor = torch.cat(all_parts, dim=3)
113
+
114
+ # Record the positions of the concatenation points
115
+ concat_points = [0]
116
+ for part in all_parts[:-1]:
117
+ next_point = concat_points[-1] + part.size(3)
118
+ concat_points.append(next_point)
119
+
120
+ return noise_tensor, concat_points
121
+
122
+ def define_beta_schedule(self):
123
+ assert self.respaced == False, "This schedule has already been respaced!"
124
+ # define alphas
125
+ self.alphas = 1.0 - self.betas
126
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
127
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
128
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
129
+
130
+ # calculations for diffusion q(x_t | x_{t-1}) and others
131
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
132
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
133
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
134
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
135
+ self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas)
136
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
137
+
138
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
139
+ self.posterior_variance = (self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
140
+
141
+ def activate_classifier_free_guidance(self, CFG, unconditional_condition):
142
+ assert (
143
+ not unconditional_condition is None) or CFG == 1.0, "For CFG != 1.0, unconditional_condition must be available"
144
+ self.CFG = CFG
145
+ self.unconditional_condition = unconditional_condition
146
+
147
+ def respace(self, use_timesteps=None):
148
+ if not use_timesteps is None:
149
+ last_alpha_cumprod = 1.0
150
+ new_betas = []
151
+ self.timestep_map = []
152
+ for i, _alpha_cumprod in enumerate(self.alphas_cumprod):
153
+ if i in use_timesteps:
154
+ new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod)
155
+ last_alpha_cumprod = _alpha_cumprod
156
+ self.timestep_map.append(i)
157
+ self.num_timesteps = len(use_timesteps)
158
+ self.betas = np.array(new_betas)
159
+ self.define_beta_schedule()
160
+ self.respaced = True
161
+
162
+ def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None):
163
+ assert shape[1] == self.channels, "shape[1] != self.channels"
164
+ assert shape[2] == self.height, "shape[2] != self.height"
165
+ noise = torch.empty(*shape, device=self.device)
166
+
167
+ # Case 3: both endpoints are given, so interpolate linearly between them
168
+ if first_endpoint is not None and second_endpoint is not None:
169
+ for i in range(shape[0]):
170
+ alpha = i / (shape[0] - 1) # interpolation coefficient
171
+ noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint
172
+ return noise # return the interpolated result; no further mean/variance adjustment is needed
173
+ else:
174
+ # The first endpoint is not None
175
+ if first_endpoint is not None:
176
+ noise[0] = first_endpoint
177
+ if shape[0] > 1:
178
+ noise[1] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]
179
+ else:
180
+ noise[0] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]
181
+ if shape[0] > 1:
182
+ noise[1] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]
183
+
184
+ # Generate the remaining noise points by linear extrapolation
185
+ for i in range(2, shape[0]):
186
+ noise[i] = 2 * noise[i - 1] - noise[i - 2]
187
+
188
+ # When only one endpoint (or none) is specified, rescale to the target variance
189
+ current_var = noise.var()
190
+ stddev_ratio = torch.sqrt(variance / current_var)
191
+ noise = noise * stddev_ratio
192
+
193
+ # If the first endpoint is specified, shift the noise to match it
194
+ if first_endpoint is not None:
195
+ shift = first_endpoint - noise[0]
196
+ noise += shift
197
+
198
+ return noise
199
+
200
+ def q_sample(self, x_start, t, noise=None):
201
+ """
202
+ Diffuse the data for a given number of diffusion steps.
203
+
204
+ In other words, sample from q(x_t | x_0).
205
+
206
+ :param x_start: the initial data batch.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :param noise: if specified, the split-out normal noise.
209
+ :return: A noisy version of x_start.
210
+ """
211
+ assert x_start.shape[1] == self.channels, "shape[1] != self.channels"
212
+ assert x_start.shape[2] == self.height, "shape[2] != self.height"
213
+
214
+ if noise is None:
215
+ # noise = torch.randn_like(x_start)
216
+ noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3])
217
+
218
+ assert noise.shape == x_start.shape
219
+ return (
220
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
221
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
222
+ * noise
223
+ )
224
+
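For orientation, q_sample above is the standard closed-form forward process (DDPM notation; self.alphas_cumprod stores \bar{\alpha}_t = \prod_{s \le t} \alpha_s):

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)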
225
+ @torch.no_grad()
226
+ def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0):
227
+ map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype)
228
+ mapped_t = map_tensor[t]
229
+
230
+ # Todo: add CFG
231
+
232
+ if self.CFG == 1.0:
233
+ pred_noise = model(x, mapped_t, condition)
234
+ else:
235
+ unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat(
236
+ *([x.shape[0]] + [1] * len(self.unconditional_condition.shape)))
237
+ x_in = torch.cat([x] * 2)
238
+ t_in = torch.cat([mapped_t] * 2)
239
+ c_in = torch.cat([unconditional_condition, condition])
240
+ noise_uncond, noise = model(x_in, t_in, c_in).chunk(2)
241
+ pred_noise = noise_uncond + self.CFG * (noise - noise_uncond)
242
+
243
+ # Todo: END
244
+
245
+ alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
246
+ alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
247
+
248
+ pred_x0 = (x - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)
249
+
250
+ sigmas_t = (
251
+ ddim_eta
252
+ * torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
253
+ * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev)
254
+ )
255
+
256
+ pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise
257
+
258
+
259
+ step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3])
260
+
261
+
262
+ x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise
263
+
264
+ return x_prev
265
+
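The update implemented above in ddim_sample is the usual DDIM step, with \epsilon_\theta the predicted noise (pred_noise) and \eta = ddim_eta (\eta = 0 gives deterministic DDIM; p_sample uses \eta = 1 for its "ddpm" option):

\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta}{\sqrt{\bar{\alpha}_t}}, \quad
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_\theta + \sigma_t z, \quad
\sigma_t = \eta \sqrt{\tfrac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}}\,\sqrt{1-\tfrac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}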
266
+ def p_sample(self, model, x, t, condition=None, sampler="ddim"):
267
+ if sampler == "ddim":
268
+ return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0)
269
+ elif sampler == "ddpm":
270
+ return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0)
271
+ else:
272
+ raise NotImplementedError()
273
+
274
+ def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8):
275
+ release_length = int(self.train_width / 4)
276
+ assert shape[3] == (concat_points[-1] + release_length), "shape[3] != (concat_points[-1] + release_length)"
277
+
278
+ fraction_lengths = [concat_points[i + 1] - concat_points[i] for i in range(len(concat_points) - 1)]
279
+
280
+ # Todo: remove hard-coding
281
+ n_guidance_steps = int(n_masks * mask_flexivity)
282
+ n_free_steps = n_masks - n_guidance_steps
283
+
284
+ masks = []
285
+ # Todo: shrink the mask over the guided steps; i.e., in the later steps do img2img instead of inpainting for the regions outside the release
286
+ for i in range(n_guidance_steps):
287
+ # mask = 1, freeze
288
+ step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
289
+ step_i_mask[:, :, :, -release_length:] = 1.0
290
+
291
+ for fraction_index in range(len(fraction_lengths)):
292
+
293
+ _fraction_mask_length = int((n_guidance_steps - 1 - i) / (n_guidance_steps - 1) * fraction_lengths[fraction_index])
294
+
295
+ if fraction_index == 0:
296
+ step_i_mask[:, :, :, :_fraction_mask_length] = 1.0
297
+ elif fraction_index == len(fraction_lengths) - 1:
298
+ if not _fraction_mask_length == 0:
299
+ step_i_mask[:, :, :, -_fraction_mask_length - release_length:] = 1.0
300
+ else:
301
+ fraction_mask_start_position = int((fraction_lengths[fraction_index] - _fraction_mask_length) / 2)
302
+
303
+ step_i_mask[:, :, :,
304
+ concat_points[fraction_index] + fraction_mask_start_position:concat_points[
305
+ fraction_index] + fraction_mask_start_position + _fraction_mask_length] = 1.0
306
+ masks.append(step_i_mask)
307
+
308
+ for i in range(n_free_steps):
309
+ step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
310
+ step_i_mask[:, :, :, -release_length:] = 1.0
311
+ masks.append(step_i_mask)
312
+
313
+ masks.reverse()
314
+ return masks
315
+
316
+ @torch.no_grad()
317
+ def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
318
+ return_tensor=False, condition=None, guide_img=None,
319
+ mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8):
320
+
321
+ assert shape[1] == self.channels, "shape[1] != self.channels"
322
+ assert shape[2] == self.height, "shape[2] != self.height"
323
+
324
+ initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise)
325
+ assert initial_noise.shape == shape, "initial_noise.shape != shape"
326
+
327
+ start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio) # not included!!!
328
+ end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio)
329
+
330
+ timesteps = reversed(range(end_noise_level_index, start_noise_level_index))
331
+
332
+ # configure initial img
333
+ assert (start_noise_level_ratio == 1.0) or (
334
+ not guide_img is None), "A guide_img must be given when sampling does not start from pure noise."
335
+
336
+ if guide_img is None:
337
+ img = initial_noise
338
+ else:
339
+ guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(shape[0], shape[3], reference_noise=guide_img)
340
+ assert guide_img.shape == shape, "guide_img.shape != shape"
341
+
342
+ if start_noise_level_index > 0:
343
+ t = torch.full((shape[0],), start_noise_level_index-1, device=self.device).long() # -1 for start_noise_level_index not included
344
+ img = self.q_sample(guide_img, t, noise=initial_noise)
345
+ else:
346
+ print("Zero noise added to the guidance latent representation.")
347
+ img = guide_img
348
+
349
+ # get masks
350
+ n_masks = start_noise_level_index - end_noise_level_index
351
+ if use_dynamic_mask:
352
+ masks = self.get_dynamic_masks(n_masks, shape, concat_points, mask_flexivity)
353
+ else:
354
+ masks = [mask for _ in range(n_masks)]
355
+
356
+ imgs = [img]
357
+ current_mask = None
358
+
359
+
360
+ for i in tqdm(timesteps, total=start_noise_level_index - end_noise_level_index, disable=self.mute):
361
+
362
+ # if i == 3:
363
+ # return [img], initial_noise # row 1, column 1
364
+
365
+ img = self.p_sample(model, img, torch.full((shape[0],), i, device=self.device, dtype=torch.long),
366
+ condition=condition,
367
+ sampler=sampler)
368
+ # if i == 3:
369
+ # return [img], initial_noise # row 1, column 2
370
+
371
+ if inpaint:
372
+ if i > 0:
373
+ t = torch.full((shape[0],), int(i-1), device=self.device).long()
374
+ img_noise_t = self.q_sample(guide_img, t, noise=initial_noise)
375
+ # if i == 3:
376
+ # return [img_noise_t], initial_noise # row 2, column 2
377
+ current_mask = masks.pop()
378
+ img = current_mask * img_noise_t + (1 - current_mask) * img
379
+ # if i == 3:
380
+ # return [img], initial_noise # row 1.5, last column
381
+ else:
382
+ img = current_mask * guide_img + (1 - current_mask) * img
383
+
384
+ if return_tensor:
385
+ imgs.append(img)
386
+ else:
387
+ imgs.append(img.cpu().numpy())
388
+
389
+ return imgs, initial_noise
390
+
391
+
392
+ def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None, seed=None):
393
+ if not seed is None:
394
+ torch.manual_seed(seed)
395
+ return self.p_sample_loop(model, shape, initial_noise=initial_noise, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
396
+ return_tensor=return_tensor, condition=condition, sampler=sampler)
397
+
398
+ def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None, return_tensor=False,
399
+ condition=None, sampler="ddim", seed=None):
400
+ if not seed is None:
401
+ torch.manual_seed(seed)
402
+ linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint,
403
+ second_endpoint=second_endpoint)
404
+ return self.p_sample_loop(model, shape, initial_noise=linear_noise, start_noise_level_ratio=1.0,
405
+ end_noise_level_ratio=0.0,
406
+ return_tensor=return_tensor, condition=condition, sampler=sampler)
407
+
408
+ def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False, condition=None,
409
+ sampler="ddim", initial_noise=None, seed=None):
410
+ if not seed is None:
411
+ torch.manual_seed(seed)
412
+ assert guide_img.shape[-1] == shape[-1], "guide_img.shape[-1] != shape[-1]"
413
+ return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=0.0,
414
+ return_tensor=return_tensor, condition=condition, sampler=sampler,
415
+ guide_img=guide_img, initial_noise=initial_noise)
416
+
417
+ def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False, condition=None,
418
+ sampler="ddim", initial_noise=None, use_dynamic_mask=False, end_noise_level_ratio=0.0, seed=None,
419
+ mask_flexivity=0.8):
420
+ if not seed is None:
421
+ torch.manual_seed(seed)
422
+ return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=end_noise_level_ratio,
423
+ return_tensor=return_tensor, condition=condition, guide_img=guide_img, mask=mask,
424
+ sampler=sampler, inpaint=True, initial_noise=initial_noise, use_dynamic_mask=use_dynamic_mask,
425
+ mask_flexivity=mask_flexivity)
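A minimal sampling sketch for DiffSynthSampler; the UNet, its condition batch, and the latent shape are placeholders, and only the sampler setup itself is taken from this file.

import torch
from model.DiffSynthSampler import DiffSynthSampler

sampler = DiffSynthSampler(timesteps=1000, height=128, channels=4, train_width=64, device="cpu")
sampler.respace(use_timesteps=list(range(0, 1000, 50)))   # 20 respaced sampling steps
# With a trained noise-prediction model `unet` taking (x, t, condition) and a condition batch `cond`:
# latents, init_noise = sampler.sample(unet, shape=(1, 4, 128, 64), condition=cond, return_tensor=True)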
model/GAN.py ADDED
@@ -0,0 +1,262 @@
1
+ import json
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from six.moves import xrange
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ import random
8
+
9
+ from model.diffusion import ConditionedUnet
10
+ from tools import create_key
11
+
12
+ class Discriminator(nn.Module):
13
+ def __init__(self, label_emb_dim):
14
+ super(Discriminator, self).__init__()
15
+ # Convolutional layers for the feature map
16
+ self.conv_layers = nn.Sequential(
17
+ nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
18
+ nn.LeakyReLU(0.2, inplace=True),
19
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
20
+ nn.BatchNorm2d(128),
21
+ nn.LeakyReLU(0.2, inplace=True),
22
+ nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
23
+ nn.BatchNorm2d(256),
24
+ nn.LeakyReLU(0.2, inplace=True),
25
+ nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
26
+ nn.BatchNorm2d(512),
27
+ nn.LeakyReLU(0.2, inplace=True),
28
+ nn.AdaptiveAvgPool2d(1), # adaptive pooling layer
29
+ nn.Flatten()
30
+ )
31
+
32
+ # Text-embedding projection
33
+ self.text_embedding = nn.Sequential(
34
+ nn.Linear(label_emb_dim, 512),
35
+ nn.LeakyReLU(0.2, inplace=True)
36
+ )
37
+
38
+ # Final fully connected layer of the discriminator
39
+ self.fc = nn.Linear(512 + 512, 1) # the two 512s come from the feature map and the text embedding
40
+
41
+ def forward(self, x, text_emb):
42
+ x = self.conv_layers(x)
43
+ text_emb = self.text_embedding(text_emb)
44
+ combined = torch.cat((x, text_emb), dim=1)
45
+ output = self.fc(combined)
46
+ return output
47
+
48
+
49
+
50
+ def evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping):
51
+ generator.to(device)
52
+ discriminator.to(device)
53
+ generator.eval()
54
+ discriminator.eval()
55
+
56
+ real_accs = []
57
+ fake_accs = []
58
+
59
+ with torch.no_grad():
60
+ for i in range(100):
61
+ data, attributes = next(iter(iterator))
62
+ data = data.to(device)
63
+
64
+ conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
65
+ selected_conditions = [random.choice(conditions_of_one_sample) for conditions_of_one_sample in conditions]
66
+ selected_conditions = torch.stack(selected_conditions).float().to(device)
67
+
68
+ # Move data and labels to the device
69
+ real_images = data.to(device)
70
+ labels = selected_conditions.to(device)
71
+
72
+ # Generate noise and fake images
73
+ noise = torch.randn_like(real_images).to(device)
74
+ fake_images = generator(noise)
75
+
76
+ # Evaluate the discriminator's performance
77
+ real_preds = discriminator(real_images, labels).reshape(-1)
78
+ fake_preds = discriminator(fake_images, labels).reshape(-1)
79
+ real_acc = (real_preds > 0.5).float().mean().item() # accuracy on real images
80
+ fake_acc = (fake_preds < 0.5).float().mean().item() # accuracy on generated images
81
+
82
+ real_accs.append(real_acc)
83
+ fake_accs.append(fake_acc)
84
+
85
+
86
+ # Compute the average accuracies
87
+ average_real_acc = sum(real_accs) / len(real_accs)
88
+ average_fake_acc = sum(fake_accs) / len(fake_accs)
89
+
90
+ return average_real_acc, average_fake_acc
91
+
92
+
93
+ def get_Generator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
94
+ generator = ConditionedUnet(**model_Config)
95
+ print(f"Model initialized, size: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
96
+ generator.to(device)
97
+
98
+ if load_pretrain:
99
+ print(f"Loading weights from models/{model_name}_generator.pth")
100
+ checkpoint = torch.load(f'models/{model_name}_generator.pth', map_location=device)
101
+ generator.load_state_dict(checkpoint['model_state_dict'])
102
+ generator.eval()
103
+ return generator
104
+
105
+
106
+ def get_Discriminator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
107
+ discriminator = Discriminator(**model_Config)
108
+ print(f"Model initialized, size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
109
+ discriminator.to(device)
110
+
111
+ if load_pretrain:
112
+ print(f"Loading weights from models/{model_name}_discriminator.pth")
113
+ checkpoint = torch.load(f'models/{model_name}_discriminator.pth', map_location=device)
114
+ discriminator.load_state_dict(checkpoint['model_state_dict'])
115
+ discriminator.eval()
116
+ return discriminator
117
+
118
+
119
+ def train_GAN(device, init_model_name, unetConfig, BATCH_SIZE, lr_G, lr_D, max_iter, iterator, load_pretrain,
120
+ encodes2embeddings_mapping, save_steps, unconditional_condition, uncondition_rate, save_model_name=None):
121
+
122
+ if save_model_name is None:
123
+ save_model_name = init_model_name
124
+
125
+ def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, model_size, current_iter, current_loss):
126
+ model_hyperparameter = unetConfig
127
+ model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
128
+ model_hyperparameter["lr_G"] = lr_G
129
+ model_hyperparameter["lr_D"] = lr_D
130
+ model_hyperparameter["model_size"] = model_size
131
+ model_hyperparameter["current_iter"] = current_iter
132
+ model_hyperparameter["current_loss"] = current_loss
133
+ with open(f"models/hyperparameters/{model_name}_GAN.json", "w") as json_file:
134
+ json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
135
+
136
+ generator = ConditionedUnet(**unetConfig)
137
+ discriminator = Discriminator(unetConfig["label_emb_dim"])
138
+ generator_size = sum(p.numel() for p in generator.parameters() if p.requires_grad)
139
+ discriminator_size = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
140
+
141
+ print(f"Generator trainable parameters: {generator_size}, discriminator trainable parameters: {discriminator_size}")
142
+ generator.to(device)
143
+ discriminator.to(device)
144
+ optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=lr_G, amsgrad=False)
145
+ optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr_D, amsgrad=False)
146
+
147
+ if load_pretrain:
148
+ print(f"Loading weights from models/{init_model_name}_generator.pth")
149
+ checkpoint = torch.load(f'models/{init_model_name}_generator.pth')
150
+ generator.load_state_dict(checkpoint['model_state_dict'])
151
+ optimizer_G.load_state_dict(checkpoint['optimizer_state_dict'])
152
+ print(f"Loading weights from models/{init_model_name}_discriminator.pth")
153
+ checkpoint = torch.load(f'models/{init_model_name}_discriminator.pth')
154
+ discriminator.load_state_dict(checkpoint['model_state_dict'])
155
+ optimizer_D.load_state_dict(checkpoint['optimizer_state_dict'])
156
+ else:
157
+ print("Model initialized.")
158
+ if max_iter == 0:
159
+ print("Return model directly.")
160
+ return generator, discriminator, optimizer_G, optimizer_D
161
+
162
+
163
+ train_loss_G, train_loss_D = [], []
164
+ writer = SummaryWriter(f'runs/{save_model_name}_GAN')
165
+
166
+ # average_real_acc, average_fake_acc = evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping)
167
+ # print(f"average_real_acc, average_fake_acc: {average_real_acc, average_fake_acc}")
168
+
169
+ criterion = nn.BCEWithLogitsLoss()
170
+ generator.train()
171
+ for i in xrange(max_iter):
172
+ data, attributes = next(iter(iterator))
173
+ data = data.to(device)
174
+
175
+ conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
176
+ unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
177
+ selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
178
+ conditions_of_one_sample) for conditions_of_one_sample in conditions]
179
+ batch_size = len(selected_conditions)
180
+ selected_conditions = torch.stack(selected_conditions).float().to(device)
181
+
182
+ # Move data and labels to the device
183
+ real_images = data.to(device)
184
+ labels = selected_conditions.to(device)
185
+
186
+ # Real and fake labels
187
+ real_labels = torch.ones(batch_size, 1).to(device)
188
+ fake_labels = torch.zeros(batch_size, 1).to(device)
189
+
190
+ # ========== Train the discriminator ==========
191
+ optimizer_D.zero_grad()
192
+
193
+ # Discriminator loss on real images
194
+ outputs_real = discriminator(real_images, labels)
195
+ loss_D_real = criterion(outputs_real, real_labels)
196
+
197
+ # Generate fake images
198
+ noise = torch.randn_like(real_images).to(device)
199
+ fake_images = generator(noise, labels)
200
+
201
+ # Discriminator loss on fake images
202
+ outputs_fake = discriminator(fake_images.detach(), labels)
203
+ loss_D_fake = criterion(outputs_fake, fake_labels)
204
+
205
+ # Backpropagation and optimization
206
+ loss_D = loss_D_real + loss_D_fake
207
+ loss_D.backward()
208
+ optimizer_D.step()
209
+
210
+ # ========== Train the generator ==========
211
+ optimizer_G.zero_grad()
212
+
213
+ # Generator loss
214
+ outputs_fake = discriminator(fake_images, labels)
215
+ loss_G = criterion(outputs_fake, real_labels)
216
+
217
+ # Backpropagation and optimization
218
+ loss_G.backward()
219
+ optimizer_G.step()
220
+
221
+
222
+ train_loss_G.append(loss_G.item())
223
+ train_loss_D.append(loss_D.item())
224
+ step = int(optimizer_G.state_dict()['state'][list(optimizer_G.state_dict()['state'].keys())[0]]['step'].numpy())
225
+
226
+ if (i + 1) % 100 == 0:
227
+ print('%d step' % (step))
228
+
229
+ if (i + 1) % save_steps == 0:
230
+ current_loss_D = np.mean(train_loss_D[-save_steps:])
231
+ current_loss_G = np.mean(train_loss_G[-save_steps:])
232
+ print('current_loss_G: %.5f' % current_loss_G)
233
+ print('current_loss_D: %.5f' % current_loss_D)
234
+
235
+ writer.add_scalar(f"current_loss_G", current_loss_G, step)
236
+ writer.add_scalar(f"current_loss_D", current_loss_D, step)
237
+
238
+
239
+ torch.save({
240
+ 'model_state_dict': generator.state_dict(),
241
+ 'optimizer_state_dict': optimizer_G.state_dict(),
242
+ }, f'models/{save_model_name}_generator.pth')
243
+ save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, generator_size, step, current_loss_G)
244
+ torch.save({
245
+ 'model_state_dict': discriminator.state_dict(),
246
+ 'optimizer_state_dict': optimizer_D.state_dict(),
247
+ }, f'models/{save_model_name}_discriminator.pth')
248
+ save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, discriminator_size, step, current_loss_D)
249
+
250
+ if step % 10000 == 0:
251
+ torch.save({
252
+ 'model_state_dict': generator.state_dict(),
253
+ 'optimizer_state_dict': optimizer_G.state_dict(),
254
+ }, f'models/history/{save_model_name}_{step}_generator.pth')
255
+ torch.save({
256
+ 'model_state_dict': discriminator.state_dict(),
257
+ 'optimizer_state_dict': optimizer_D.state_dict(),
258
+ }, f'models/history/{save_model_name}_{step}_discriminator.pth')
259
+
260
+ return generator, discriminator, optimizer_G, optimizer_D
261
+
262
+
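A hypothetical loading sketch for the helpers above; the config dict and model_name are placeholders and must match an existing checkpoint under models/.

# unet_config = {...}   # the ConditionedUnet configuration used at training time (not reproduced here)
# generator = get_Generator(unet_config, load_pretrain=True, model_name="my_gan", device="cpu")
# discriminator = get_Discriminator({"label_emb_dim": unet_config["label_emb_dim"]},
#                                   load_pretrain=True, model_name="my_gan", device="cpu")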
model/VQGAN.py ADDED
@@ -0,0 +1,684 @@
1
+ import json
2
+ from torch.utils.tensorboard import SummaryWriter
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from six.moves import xrange
8
+ from einops import rearrange
9
+ from torchvision import models
10
+
11
+
12
+ def Normalize(in_channels, num_groups=32, norm_type="groupnorm"):
13
+ """Normalization layer"""
14
+
15
+ if norm_type == "batchnorm":
16
+ return torch.nn.BatchNorm2d(in_channels)
17
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
18
+
19
+
20
+ def nonlinearity(x, act_type="relu"):
21
+ """Nonlinear activation function"""
22
+
23
+ if act_type == "relu":
24
+ return F.relu(x)
25
+ else:
26
+ # swish
27
+ return x * torch.sigmoid(x)
28
+
29
+
30
+ class VectorQuantizer(nn.Module):
31
+ """Vector quantization layer"""
32
+
33
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost):
34
+ super(VectorQuantizer, self).__init__()
35
+
36
+ self._embedding_dim = embedding_dim
37
+ self._num_embeddings = num_embeddings
38
+
39
+ self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
40
+ self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
41
+ self._commitment_cost = commitment_cost
42
+
43
+ def forward(self, inputs):
44
+ # convert inputs from BCHW -> BHWC
45
+ inputs = inputs.permute(0, 2, 3, 1).contiguous()
46
+ input_shape = inputs.shape
47
+
48
+ # Flatten input BCHW -> (BHW)C
49
+ flat_input = inputs.view(-1, self._embedding_dim)
50
+
51
+ # Calculate distances (input-embedding)^2
52
+ distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
53
+ + torch.sum(self._embedding.weight ** 2, dim=1)
54
+ - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
55
+
56
+ # Encoding (one-hot-encoding matrix)
57
+ encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
58
+ encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
59
+ encodings.scatter_(1, encoding_indices, 1)
60
+
61
+ # Quantize and unflatten
62
+ quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
63
+
64
+ # Loss
65
+ e_latent_loss = F.mse_loss(quantized.detach(), inputs)
66
+ q_latent_loss = F.mse_loss(quantized, inputs.detach())
67
+ loss = q_latent_loss + self._commitment_cost * e_latent_loss
68
+
69
+ quantized = inputs + (quantized - inputs).detach()
70
+ avg_probs = torch.mean(encodings, dim=0)
71
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
72
+
73
+ # convert quantized from BHWC -> BCHW
74
+ min_encodings, min_encoding_indices = None, None
75
+ return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
76
+
77
+
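The VectorQuantizer above implements the usual VQ-VAE objective with a straight-through estimator (sg denotes stop-gradient, \beta the commitment cost):

L_{VQ} = \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta\,\lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2, \qquad z_q = z_e + \mathrm{sg}[z_q - z_e]

In the EMA variant defined next, the first (codebook) term is replaced by exponential-moving-average updates of the embeddings, so only the commitment term remains in the returned loss.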
78
+ class VectorQuantizerEMA(nn.Module):
79
+ """Vector quantization layer based on exponential moving average"""
80
+
81
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
82
+ super(VectorQuantizerEMA, self).__init__()
83
+
84
+ self._embedding_dim = embedding_dim
85
+ self._num_embeddings = num_embeddings
86
+
87
+ self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
88
+ self._embedding.weight.data.normal_()
89
+ self._commitment_cost = commitment_cost
90
+
91
+ self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
92
+ self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
93
+ self._ema_w.data.normal_()
94
+
95
+ self._decay = decay
96
+ self._epsilon = epsilon
97
+
98
+ def forward(self, inputs):
99
+ # convert inputs from BCHW -> BHWC
100
+ inputs = inputs.permute(0, 2, 3, 1).contiguous()
101
+ input_shape = inputs.shape
102
+
103
+ # Flatten input
104
+ flat_input = inputs.view(-1, self._embedding_dim)
105
+
106
+ # Calculate distances
107
+ distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
108
+ + torch.sum(self._embedding.weight ** 2, dim=1)
109
+ - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
110
+
111
+ # Encoding
112
+ encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
113
+ encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
114
+ encodings.scatter_(1, encoding_indices, 1)
115
+
116
+ # Quantize and unflatten
117
+ quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
118
+
119
+ # Use EMA to update the embedding vectors
120
+ if self.training:
121
+ self._ema_cluster_size = self._ema_cluster_size * self._decay + \
122
+ (1 - self._decay) * torch.sum(encodings, 0)
123
+
124
+ # Laplace smoothing of the cluster size
125
+ n = torch.sum(self._ema_cluster_size.data)
126
+ self._ema_cluster_size = (
127
+ (self._ema_cluster_size + self._epsilon)
128
+ / (n + self._num_embeddings * self._epsilon) * n)
129
+
130
+ dw = torch.matmul(encodings.t(), flat_input)
131
+ self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
132
+
133
+ self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
134
+
135
+ # Loss
136
+ e_latent_loss = F.mse_loss(quantized.detach(), inputs)
137
+ loss = self._commitment_cost * e_latent_loss
138
+
139
+ # Straight Through Estimator
140
+ quantized = inputs + (quantized - inputs).detach()
141
+ avg_probs = torch.mean(encodings, dim=0)
142
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
143
+
144
+ # convert quantized from BHWC -> BCHW
145
+ min_encodings, min_encoding_indices = None, None
146
+ return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices)
147
+
148
+
149
+ class DownSample(nn.Module):
150
+ """DownSample layer"""
151
+
152
+ def __init__(self, in_channels, out_channels):
153
+ super(DownSample, self).__init__()
154
+ self._conv2d = nn.Conv2d(in_channels=in_channels,
155
+ out_channels=out_channels,
156
+ kernel_size=4,
157
+ stride=2, padding=1)
158
+
159
+ def forward(self, x):
160
+ return self._conv2d(x)
161
+
162
+
163
+ class UpSample(nn.Module):
164
+ """UpSample layer"""
165
+
166
+ def __init__(self, in_channels, out_channels):
167
+ super(UpSample, self).__init__()
168
+ self._conv2d = nn.ConvTranspose2d(in_channels=in_channels,
169
+ out_channels=out_channels,
170
+ kernel_size=4,
171
+ stride=2, padding=1)
172
+
173
+ def forward(self, x):
174
+ return self._conv2d(x)
175
+
176
+
177
+ class ResnetBlock(nn.Module):
178
+ """ResnetBlock is a combination of non-linearity, convolution, and normalization"""
179
+
180
+ def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False,
181
+ dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32):
182
+ super().__init__()
183
+ self.in_channels = in_channels
184
+ out_channels = in_channels if out_channels is None else out_channels
185
+ self.out_channels = out_channels
186
+ self.use_conv_shortcut = conv_shortcut
187
+ self.act_type = act_type
188
+
189
+ self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups)
190
+ self.conv1 = torch.nn.Conv2d(in_channels,
191
+ out_channels,
192
+ kernel_size=3,
193
+ stride=1,
194
+ padding=1)
195
+ if temb_channels > 0:
196
+ self.temb_proj = torch.nn.Linear(temb_channels,
197
+ out_channels)
198
+
199
+ self.double_conv = double_conv
200
+ if self.double_conv:
201
+ self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups)
202
+ self.dropout = torch.nn.Dropout(dropout)
203
+ self.conv2 = torch.nn.Conv2d(out_channels,
204
+ out_channels,
205
+ kernel_size=3,
206
+ stride=1,
207
+ padding=1)
208
+
209
+ if self.in_channels != self.out_channels:
210
+ if self.use_conv_shortcut:
211
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
212
+ out_channels,
213
+ kernel_size=3,
214
+ stride=1,
215
+ padding=1)
216
+ else:
217
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
218
+ out_channels,
219
+ kernel_size=1,
220
+ stride=1,
221
+ padding=0)
222
+
223
+ def forward(self, x, temb=None):
224
+ h = x
225
+ h = self.norm1(h)
226
+ h = nonlinearity(h, act_type=self.act_type)
227
+ h = self.conv1(h)
228
+
229
+ if temb is not None:
230
+ h = h + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None]
231
+
232
+ if self.double_conv:
233
+ h = self.norm2(h)
234
+ h = nonlinearity(h, act_type=self.act_type)
235
+ h = self.dropout(h)
236
+ h = self.conv2(h)
237
+
238
+ if self.in_channels != self.out_channels:
239
+ if self.use_conv_shortcut:
240
+ x = self.conv_shortcut(x)
241
+ else:
242
+ x = self.nin_shortcut(x)
243
+
244
+ return x + h
245
+
246
+
247
+ class LinearAttention(nn.Module):
248
+ """Efficient attention block based on <https://proceedings.mlr.press/v119/katharopoulos20a.html>"""
249
+
250
+ def __init__(self, dim, heads=4, dim_head=32, with_skip=True):
251
+ super().__init__()
252
+ self.heads = heads
253
+ hidden_dim = dim_head * heads
254
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
255
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
256
+
257
+ self.with_skip = with_skip
258
+ if self.with_skip:
259
+ self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
260
+
261
+ def forward(self, x):
262
+ b, c, h, w = x.shape
263
+ qkv = self.to_qkv(x)
264
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
265
+ k = k.softmax(dim=-1)
266
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
267
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
268
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
269
+
270
+ if self.with_skip:
271
+ return self.to_out(out) + self.nin_shortcut(x)
272
+ return self.to_out(out)
273
+
274
+
275
+ class Encoder(nn.Module):
276
+ """The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers."""
277
+
278
+ def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2,
279
+ attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32):
280
+ super(Encoder, self).__init__()
281
+
282
+ if attn_pos is None:
283
+ attn_pos = []
284
+ self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])])
285
+ current_channel = hidden_channels[0]
286
+
287
+ for i in range(1, len(hidden_channels)):
288
+ for _ in range(block_depth - 1):
289
+ self._layers.append(ResnetBlock(in_channels=current_channel,
290
+ out_channels=current_channel,
291
+ double_conv=False,
292
+ conv_shortcut=False,
293
+ norm_type=norm_type,
294
+ act_type=act_type,
295
+ num_groups=num_groups))
296
+ if current_channel in attn_pos:
297
+ self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
298
+
299
+ self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
300
+ self._layers.append(nn.ReLU())
301
+ self._layers.append(DownSample(current_channel, hidden_channels[i]))
302
+ current_channel = hidden_channels[i]
303
+
304
+ for _ in range(block_depth - 1):
305
+ self._layers.append(ResnetBlock(in_channels=current_channel,
306
+ out_channels=current_channel,
307
+ double_conv=False,
308
+ conv_shortcut=False,
309
+ norm_type=norm_type,
310
+ act_type=act_type,
311
+ num_groups=num_groups))
312
+ if current_channel in attn_pos:
313
+ self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
314
+
315
+ # Conv1x1: hidden_channels[-1] -> embedding_dim
316
+ self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
317
+ self._layers.append(nn.ReLU())
318
+ self._layers.append(nn.Conv2d(in_channels=current_channel,
319
+ out_channels=embedding_dim,
320
+ kernel_size=1,
321
+ stride=1))
322
+
323
+ def forward(self, x):
324
+ for layer in self._layers:
325
+ x = layer(x)
326
+ return x
327
+
328
+
329
+ class Decoder(nn.Module):
330
+ """The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers."""
331
+
332
+ def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2,
333
+ attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
334
+ num_groups=32):
335
+ super(Decoder, self).__init__()
336
+
337
+ if attn_pos is None:
338
+ attn_pos = []
339
+ reversed_hidden_channels = list(reversed(hidden_channels))
340
+
341
+ # Conv1x1: embedding_dim -> hidden_channels[-1]
342
+ self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim,
343
+ out_channels=reversed_hidden_channels[0],
344
+ kernel_size=1, stride=1, bias=False)])
345
+
346
+ current_channel = reversed_hidden_channels[0]
347
+
348
+ for _ in range(block_depth - 1):
349
+ if current_channel in attn_pos:
350
+ self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
351
+ self._layers.append(ResnetBlock(in_channels=current_channel,
352
+ out_channels=current_channel,
353
+ double_conv=False,
354
+ conv_shortcut=False,
355
+ norm_type=norm_type,
356
+ act_type=act_type,
357
+ num_groups=num_groups))
358
+
359
+ for i in range(1, len(reversed_hidden_channels)):
360
+ self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
361
+ self._layers.append(nn.ReLU())
362
+ self._layers.append(UpSample(current_channel, reversed_hidden_channels[i]))
363
+ current_channel = reversed_hidden_channels[i]
364
+
365
+ for _ in range(block_depth - 1):
366
+ if current_channel in attn_pos:
367
+ self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip))
368
+ self._layers.append(ResnetBlock(in_channels=current_channel,
369
+ out_channels=current_channel,
370
+ double_conv=False,
371
+ conv_shortcut=False,
372
+ norm_type=norm_type,
373
+ act_type=act_type,
374
+ num_groups=num_groups))
375
+
376
+ self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups))
377
+ self._layers.append(nn.ReLU())
378
+ self._layers.append(UpSample(current_channel, current_channel))
379
+
380
+ # final layers
381
+ self._layers.append(ResnetBlock(in_channels=current_channel,
382
+ out_channels=out_channels,
383
+ double_conv=False,
384
+ conv_shortcut=False,
385
+ norm_type=norm_type,
386
+ act_type=act_type,
387
+ num_groups=num_groups))
388
+
389
+
390
+ def forward(self, x):
391
+ for layer in self._layers:
392
+ x = layer(x)
393
+
394
+ log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :])
395
+
396
+ cos_phase = torch.tanh(x[:, 1, :, :])
397
+ sin_phase = torch.tanh(x[:, 2, :, :])
398
+ x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1)
399
+
400
+ return x
401
+
402
+
403
+ class VQGAN_Discriminator(nn.Module):
404
+ """The discriminator employs an 18-layer ResNet architecture, with the first layer replaced by a 2D convolutional
405
+ layer that accommodates spectral representation inputs and the last two layers replaced by a binary classifier
406
+ layer."""
407
+
408
+ def __init__(self, in_channels=1):
409
+ super(VQGAN_Discriminator, self).__init__()
410
+ resnet = models.resnet18(pretrained=True)
411
+
412
+ # Replace the first layer to accept single-channel (grayscale) images
413
+ resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
414
+
415
+ # Use the feature-extraction part of the ResNet
416
+ self.features = nn.Sequential(*list(resnet.children())[:-2])
417
+
418
+ # Additional classifier layers for the discriminator
419
+ self.classifier = nn.Sequential(
420
+ nn.Linear(512, 1),
421
+ nn.Sigmoid()
422
+ )
423
+
424
+ def forward(self, x):
425
+ x = self.features(x)
426
+ x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
427
+ x = torch.flatten(x, 1)
428
+ x = self.classifier(x)
429
+ return x
430
+
431
+
432
+ class VQGAN(nn.Module):
433
+ """The VQ-GAN model. <https://openaccess.thecvf.com/content/CVPR2021/html/Esser_Taming_Transformers_for_High-Resolution_Image_Synthesis_CVPR_2021_paper.html?ref=>"""
434
+
435
+ def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2,
436
+ attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu",
437
+ num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32):
438
+ super(VQGAN, self).__init__()
439
+
440
+ self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth,
441
+ attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type, act_type=act_type, num_groups=num_groups)
442
+
443
+ if decay > 0.0:
444
+ self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
445
+ commitment_cost, decay)
446
+ else:
447
+ self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
448
+ commitment_cost)
449
+ self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth,
450
+ attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type,
451
+ act_type=act_type, num_groups=num_groups)
452
+
453
+ def forward(self, x):
454
+ z = self._encoder(x)
455
+ quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z)
456
+ x_recon = self._decoder(quantized)
457
+
458
+ return vq_loss, x_recon, perplexity
459
+
460
+
461
+ class ReconstructionLoss(nn.Module):
462
+ def __init__(self, w1, w2, epsilon=1e-3):
463
+ super(ReconstructionLoss, self).__init__()
464
+ self.w1 = w1
465
+ self.w2 = w2
466
+ self.epsilon = epsilon
467
+
468
+ def weighted_mae_loss(self, y_true, y_pred):
469
+ # avoid divide by zero
470
+ y_true_safe = torch.clamp(y_true, min=self.epsilon)
471
+
472
+ # compute weighted MAE
473
+ loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe)
474
+ return loss
475
+
476
+ def mae_loss(self, y_true, y_pred):
477
+ loss = torch.mean(torch.abs(y_pred - y_true))
478
+ return loss
479
+
480
+ def forward(self, y_pred, y_true):
481
+ # loss for magnitude channel
482
+ log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_pred[:, 0, :, :], y_true[:, 0, :, :])
483
+
484
+ # loss for phase channels
485
+ phase_loss = self.w2 * self.mae_loss(y_pred[:, 1:, :, :], y_true[:, 1:, :, :])
486
+
487
+ # sum up
488
+ rec_loss = log_magnitude_loss + phase_loss
489
+ return log_magnitude_loss, phase_loss, rec_loss
490
+
491
+
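In ReconstructionLoss above, the magnitude channel uses a relative MAE and the phase channels a plain MAE; with prediction \hat{y}, target y, channel 0 the log-magnitude and channels 1: the cos/sin phase:

L_{rec} = w_1 \cdot \mathrm{mean}\!\left(\frac{\lvert \hat{y}_0 - y_0 \rvert}{\max(y_0, \varepsilon)}\right) + w_2 \cdot \mathrm{mean}\big(\lvert \hat{y}_{1:} - y_{1:} \rvert\big)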
492
+ def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig):
493
+ model.to(trainingConfig["device"])
494
+ model.eval()
495
+ train_res_error = []
496
+ for i in xrange(100):
497
+ data = next(iter(iterator))
498
+ data = data.to(trainingConfig["device"])
499
+
500
+ # true/fake labels
501
+ real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
502
+
503
+ vq_loss, data_recon, perplexity = model(data)
504
+
505
+
506
+ fake_preds = discriminator(data_recon)
507
+ adver_loss = adversarial_loss(fake_preds, real_labels)
508
+
509
+ log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
510
+ loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
511
+
512
+ train_res_error.append(loss.item())
513
+ initial_loss = np.mean(train_res_error)
514
+ return initial_loss
515
+
516
+
517
+ def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"):
518
+ VQVAE = VQGAN(**model_Config)
519
+ print(f"Model initialized, size: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}")
520
+ VQVAE.to(device)
521
+
522
+ if load_pretrain:
523
+ print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
524
+ checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device)
525
+ VQVAE.load_state_dict(checkpoint['model_state_dict'])
526
+ VQVAE.eval()
527
+ return VQVAE
528
+
529
+
530
+ def train_VQGAN(model_Config, trainingConfig, iterator):
531
+
532
+ def save_model_hyperparameter(model_Config, trainingConfig, current_iter,
533
+ log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss,
534
+ current_loss):
535
+ model_name = trainingConfig["model_name"]
536
+ model_hyperparameter = model_Config
537
+ model_hyperparameter.update(trainingConfig)
538
+ model_hyperparameter["current_iter"] = current_iter
539
+ model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss
540
+ model_hyperparameter["phase_loss"] = phase_loss
541
+ model_hyperparameter["perplexity"] = current_perplexity
542
+ model_hyperparameter["vq_loss"] = current_vq_loss
543
+ model_hyperparameter["total_loss"] = current_loss
544
+
545
+ with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file:
546
+ json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
547
+
548
+ # initialize VAE
549
+ model = VQGAN(**model_Config)
550
+ print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
551
+ model.to(trainingConfig["device"])
552
+
553
+ VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False)
554
+ model_name = trainingConfig["model_name"]
555
+
556
+ if trainingConfig["load_pretrain"]:
557
+ print(f"Loading weights from models/{model_name}_imageVQVAE.pth")
558
+ checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"])
559
+ model.load_state_dict(checkpoint['model_state_dict'])
560
+ VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
561
+ else:
562
+ print("VAE initialized.")
563
+ if trainingConfig["max_iter"] == 0:
564
+ print("Return VAE directly.")
565
+ return model
566
+
567
+ # initialize discriminator
568
+ discriminator = VQGAN_Discriminator(model_Config["in_channels"])
569
+ print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
570
+ discriminator.to(trainingConfig["device"])
571
+
572
+ discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False)
573
+
574
+ if trainingConfig["load_pretrain"]:
575
+ print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth")
576
+ checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"])
577
+ discriminator.load_state_dict(checkpoint['model_state_dict'])
578
+ discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
579
+ else:
580
+ print("Discriminator initialized.")
581
+
582
+ # Training
583
+
584
+ train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], []
585
+ train_discriminator_loss, train_adverserial_loss = [], []
586
+
587
+ reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"])
588
+
589
+ adversarial_loss = nn.BCEWithLogitsLoss()
590
+ writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4')
591
+
592
+ previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator,
593
+ reconstructionLoss, adversarial_loss, trainingConfig)
594
+ print(f"initial_loss: {previous_lowest_loss}")
595
+
596
+ model.train()
597
+ for i in xrange(trainingConfig["max_iter"]):
598
+ data = next(iter(iterator))
599
+ data = data.to(trainingConfig["device"])
600
+
601
+ # true/fake labels
602
+ real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"])
603
+ fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"])
604
+
605
+ # update discriminator
606
+ discriminator_optimizer.zero_grad()
607
+
608
+ vq_loss, data_recon, perplexity = model(data)
609
+
610
+ real_preds = discriminator(data)
611
+ fake_preds = discriminator(data_recon.detach())
612
+
613
+ loss_real = adversarial_loss(real_preds, real_labels)
614
+ loss_fake = adversarial_loss(fake_preds, fake_labels)
615
+
616
+ loss_D = loss_real + loss_fake
617
+ loss_D.backward()
618
+ discriminator_optimizer.step()
619
+
620
+
621
+ # update VQVAE
622
+ VAE_optimizer.zero_grad()
623
+
624
+ fake_preds = discriminator(data_recon)
625
+ adver_loss = adversarial_loss(fake_preds, real_labels)
626
+
627
+ log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data)
628
+
629
+ loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss
630
+ loss.backward()
631
+ VAE_optimizer.step()
632
+
633
+ train_discriminator_loss.append(loss_D.item())
634
+ train_adverserial_loss.append(trainingConfig["adver_weight"] * adver_loss.item())
635
+ train_res_log_magnitude_loss.append(log_magnitude_loss.item())
636
+ train_res_phase_loss.append(phase_loss.item())
637
+ train_res_perplexity.append(perplexity.item())
638
+ train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item())
639
+ train_res_loss.append(loss.item())
640
+ step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
641
+
642
+ save_steps = trainingConfig["save_steps"]
643
+ if (i + 1) % 100 == 0:
644
+ print('%d step' % (step))
645
+
646
+ if (i + 1) % save_steps == 0:
647
+ current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:])
648
+ current_adverserial_loss = np.mean(train_adverserial_loss[-save_steps:])
649
+ current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:])
650
+ current_phase_loss = np.mean(train_res_phase_loss[-save_steps:])
651
+ current_perplexity = np.mean(train_res_perplexity[-save_steps:])
652
+ current_vq_loss = np.mean(train_res_vq_loss[-save_steps:])
653
+ current_loss = np.mean(train_res_loss[-save_steps:])
654
+
655
+ print('discriminator_loss: %.3f' % current_discriminator_loss)
656
+ print('adverserial_loss: %.3f' % current_adverserial_loss)
657
+ print('log_magnitude_loss: %.3f' % current_log_magnitude_loss)
658
+ print('phase_loss: %.3f' % current_phase_loss)
659
+ print('perplexity: %.3f' % current_perplexity)
660
+ print('vq_loss: %.3f' % current_vq_loss)
661
+ print('total_loss: %.3f' % current_loss)
662
+ writer.add_scalar(f"log_magnitude_loss", current_log_magnitude_loss, step)
663
+ writer.add_scalar(f"phase_loss", current_phase_loss, step)
664
+ writer.add_scalar(f"perplexity", current_perplexity, step)
665
+ writer.add_scalar(f"vq_loss", current_vq_loss, step)
666
+ writer.add_scalar(f"total_loss", current_loss, step)
667
+ if current_loss < previous_lowest_loss:
668
+ previous_lowest_loss = current_loss
669
+
670
+ torch.save({
671
+ 'model_state_dict': model.state_dict(),
672
+ 'optimizer_state_dict': VAE_optimizer.state_dict(),
673
+ }, f'models/{model_name}_imageVQVAE.pth')
674
+
675
+ torch.save({
676
+ 'model_state_dict': discriminator.state_dict(),
677
+ 'optimizer_state_dict': discriminator_optimizer.state_dict(),
678
+ }, f'models/{model_name}_imageVQVAE_discriminator.pth')
679
+
680
+ save_model_hyperparameter(model_Config, trainingConfig, step,
681
+ current_log_magnitude_loss, current_phase_loss, current_perplexity, current_vq_loss,
682
+ current_loss)
683
+
684
+ return model
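A minimal loading sketch for the helper above; `model_Config` is assumed to hold the keyword arguments used to construct VQGAN (e.g. as stored under models/hyperparameters/), the model name matches the checkpoint files shipped in models/, and `spectrogram_batch` is an assumed input tensor.

# model_Config is assumed to be loaded from the corresponding hyperparameter JSON.
vqgan = get_VQGAN(model_Config, load_pretrain=True,
                  model_name="24_1_2024-52_4x_L_D", device="cpu")
vq_loss, recon, perplexity = vqgan(spectrogram_batch)  # spectrogram_batch: (B, C, H, W) tensor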
model/__pycache__/DiffSynthSampler.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
model/__pycache__/GAN.cpython-310.pyc ADDED
Binary file (7.49 kB). View file
 
model/__pycache__/VQGAN.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
model/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
model/__pycache__/diffusion_components.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
model/__pycache__/multimodal_model.cpython-310.pyc ADDED
Binary file (9.88 kB). View file
 
model/__pycache__/perceptual_label_predictor.cpython-37.pyc ADDED
Binary file (1.67 kB). View file
 
model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc ADDED
Binary file (7.96 kB). View file
 
model/diffusion.py ADDED
@@ -0,0 +1,371 @@
1
+ import json
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from six.moves import xrange
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ import random
11
+
12
+ from metrics.IS import get_inception_score
13
+ from tools import create_key
14
+
15
+ from model.diffusion_components import default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, \
16
+ PreNorm, \
17
+ Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding, \
18
+ LinearCrossAttention, LinearCrossAttentionAdd
19
+
20
+
21
+ class ConditionedUnet(nn.Module):
22
+ def __init__(
23
+ self,
24
+ in_dim,
25
+ out_dim=None,
26
+ down_dims=None,
27
+ up_dims=None,
28
+ mid_depth=3,
29
+ with_time_emb=True,
30
+ time_dim=None,
31
+ resnet_block_groups=8,
32
+ use_convnext=True,
33
+ convnext_mult=2,
34
+ attn_type="linear_cat",
35
+ n_label_class=11,
36
+ condition_type="instrument_family",
37
+ label_emb_dim=128,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type)
42
+
43
+ if up_dims is None:
44
+ up_dims = [128, 128, 64, 32]
45
+ if down_dims is None:
46
+ down_dims = [32, 32, 64, 128]
47
+
48
+ out_dim = default(out_dim, in_dim)
49
+ assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)"
50
+ assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]"
51
+ assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]"
52
+ down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
53
+ up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
54
+ print(f"down_in_out: {down_in_out}")
55
+ print(f"up_in_out: {up_in_out}")
56
+ time_dim = default(time_dim, int(down_dims[0] * 4))
57
+
58
+ self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3)
59
+
60
+ if use_convnext:
61
+ block_klass = partial(ConvNextBlock, mult=convnext_mult)
62
+ else:
63
+ block_klass = partial(ResnetBlock, groups=resnet_block_groups)
64
+
65
+ if attn_type == "linear_cat":
66
+ attn_klass = partial(LinearCrossAttention)
67
+ elif attn_type == "linear_add":
68
+ attn_klass = partial(LinearCrossAttentionAdd)
69
+ else:
70
+ raise NotImplementedError()
71
+
72
+ # time embeddings
73
+ if with_time_emb:
74
+ self.time_mlp = nn.Sequential(
75
+ SinusoidalPositionEmbeddings(down_dims[0]),
76
+ nn.Linear(down_dims[0], time_dim),
77
+ nn.GELU(),
78
+ nn.Linear(time_dim, time_dim),
79
+ )
80
+ else:
81
+ time_dim = None
82
+ self.time_mlp = None
83
+
84
+ # left layers
85
+ self.downs = nn.ModuleList([])
86
+ self.ups = nn.ModuleList([])
87
+ skip_dims = []
88
+
89
+ for down_dim_in, down_dim_out in down_in_out:
90
+ self.downs.append(
91
+ nn.ModuleList(
92
+ [
93
+ block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim),
94
+
95
+ Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
96
+ block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim),
97
+ Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))),
98
+ Downsample(down_dim_out),
99
+ ]
100
+ )
101
+ )
102
+ skip_dims.append(down_dim_out)
103
+
104
+ # bottleneck
105
+ mid_dim = down_dims[-1]
106
+ self.mid_left = nn.ModuleList([])
107
+ self.mid_right = nn.ModuleList([])
108
+ for _ in range(mid_depth - 1):
109
+ self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim))
110
+ self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim))
111
+ self.mid_mid = nn.ModuleList(
112
+ [
113
+ block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
114
+ Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim, ))),
115
+ block_klass(mid_dim, mid_dim, time_emb_dim=time_dim),
116
+ ]
117
+ )
118
+
119
+ # right layers
120
+ for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out):
121
+ skip_dim = skip_dims.pop() # down_dim_out
122
+ self.ups.append(
123
+ nn.ModuleList(
124
+ [
125
+ # pop&cat (h/2, w/2, down_dim_out)
126
+ block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim),
127
+ Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim, ))),
128
+ Upsample(up_dim_in),
129
+ # pop&cat (h, w, down_dim_out)
130
+ block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim),
131
+ Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
132
+ # pop&cat (h, w, down_dim_out)
133
+ block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim),
134
+ Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))),
135
+ ]
136
+ )
137
+ )
138
+
139
+ self.final_conv = nn.Sequential(
140
+ block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1)
141
+ )
142
+
143
+ def size(self):
144
+ total_params = sum(p.numel() for p in self.parameters())
145
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
146
+ print(f"Total parameters: {total_params}")
147
+ print(f"Trainable parameters: {trainable_params}")
148
+
149
+
150
+ def forward(self, x, time, condition=None):
151
+
152
+ if condition is not None:
153
+ condition_emb = self.label_embedding(condition)
154
+ else:
155
+ condition_emb = None
156
+
157
+ h = []
158
+
159
+ x = self.init_conv(x)
160
+ h.append(x)
161
+
162
+ time_emb = self.time_mlp(time) if exists(self.time_mlp) else None
163
+
164
+ # downsample
165
+ for block1, attn1, block2, attn2, downsample in self.downs:
166
+ x = block1(x, time_emb)
167
+ x = attn1(x, condition_emb)
168
+ h.append(x)
169
+ x = block2(x, time_emb)
170
+ x = attn2(x, condition_emb)
171
+ h.append(x)
172
+ x = downsample(x)
173
+ h.append(x)
174
+
175
+ # bottleneck
176
+
177
+ for block in self.mid_left:
178
+ x = block(x, time_emb)
179
+ h.append(x)
180
+
181
+ (block1, attn, block2) = self.mid_mid
182
+ x = block1(x, time_emb)
183
+ x = attn(x, condition_emb)
184
+ x = block2(x, time_emb)
185
+
186
+ for block in self.mid_right:
187
+ # This is U-Net!!!
188
+ x = pad_and_concat(h.pop(), x)
189
+ x = block(x, time_emb)
190
+
191
+ # upsample
192
+ for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups:
193
+ x = pad_and_concat(h.pop(), x)
194
+ x = block1(x, time_emb)
195
+ x = attn1(x, condition_emb)
196
+ x = upsample(x)
197
+
198
+ x = pad_and_concat(h.pop(), x)
199
+ x = block2(x, time_emb)
200
+ x = attn2(x, condition_emb)
201
+
202
+ x = pad_and_concat(h.pop(), x)
203
+ x = block3(x, time_emb)
204
+ x = attn3(x, condition_emb)
205
+
206
+ x = pad_and_concat(h.pop(), x)
207
+ x = self.final_conv(x)
208
+ return x
209
+
210
+
211
+ def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
212
+ noise=None, loss_type="l1"):
213
+ if noise is None:
214
+ noise = torch.randn_like(x_start)
215
+
216
+ x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod,
217
+ sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
218
+ predicted_noise = denoise_model(x_noisy, t, condition)
219
+
220
+ if loss_type == 'l1':
221
+ loss = F.l1_loss(noise, predicted_noise)
222
+ elif loss_type == 'l2':
223
+ loss = F.mse_loss(noise, predicted_noise)
224
+ elif loss_type == "huber":
225
+ loss = F.smooth_l1_loss(noise, predicted_noise)
226
+ else:
227
+ raise NotImplementedError()
228
+
229
+ return loss
230
+
231
+
232
+ def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
233
+ uncondition_rate, unconditional_condition):
234
+ model.to(device)
235
+ model.eval()
236
+ eva_loss = []
237
+ sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
238
+ for i in xrange(500):
239
+ data, attributes = next(iter(iterator))
240
+ data = data.to(device)
241
+
242
+ conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
243
+ selected_conditions = [
244
+ unconditional_condition if random.random() < uncondition_rate else random.choice(conditions_of_one_sample)
245
+ for conditions_of_one_sample in conditions]
246
+
247
+ selected_conditions = torch.stack(selected_conditions).float().to(device)
248
+
249
+ t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
250
+ loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
251
+ sqrt_alphas_cumprod=sqrt_alphas_cumprod,
252
+ sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
253
+
254
+ eva_loss.append(loss.item())
255
+ initial_loss = np.mean(eva_loss)
256
+ return initial_loss
257
+
258
+
259
+ def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"):
260
+ UNet = ConditionedUnet(**model_Config)
261
+ print(f"Model intialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}")
262
+ UNet.to(device)
263
+
264
+ if load_pretrain:
265
+ print(f"Loading weights from models/{model_name}_UNet.pth")
266
+ checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device)
267
+ UNet.load_state_dict(checkpoint['model_state_dict'])
268
+ UNet.eval()
269
+ return UNet
270
+
271
+
272
+ def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig, BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain,
273
+ encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None, save_model_name=None,
274
+ n_IS_batches=50):
275
+
276
+ if save_model_name is None:
277
+ save_model_name = init_model_name
278
+
279
+ def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss):
280
+ model_hyperparameter = unetConfig
281
+ model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
282
+ model_hyperparameter["lr"] = lr
283
+ model_hyperparameter["model_size"] = model_size
284
+ model_hyperparameter["current_iter"] = current_iter
285
+ model_hyperparameter["current_loss"] = current_loss
286
+ with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file:
287
+ json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
288
+
289
+ model = ConditionedUnet(**unetConfig)
290
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
291
+ print(f"Trainable parameters: {model_size}")
292
+ model.to(device)
293
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False)
294
+
295
+ if load_pretrain:
296
+ print(f"Loading weights from models/{init_model_name}_UNet.pt")
297
+ checkpoint = torch.load(f'models/{init_model_name}_UNet.pth')
298
+ model.load_state_dict(checkpoint['model_state_dict'])
299
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
300
+ else:
301
+ print("Model initialized.")
302
+ if max_iter == 0:
303
+ print("Return model directly.")
304
+ return model, optimizer
305
+
306
+
307
+ train_loss = []
308
+ writer = SummaryWriter(f'runs/{save_model_name}_UNet')
309
+ if init_loss is None:
310
+ previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping,
311
+ uncondition_rate, unconditional_condition)
312
+ else:
313
+ previous_loss = init_loss
314
+ print(f"initial_IS: {previous_loss}")
315
+ sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps)
316
+
317
+ model.train()
318
+ for i in xrange(max_iter):
319
+ data, attributes = next(iter(iterator))
320
+ data = data.to(device)
321
+
322
+ conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
323
+ unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
324
+ selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
325
+ conditions_of_one_sample) for conditions_of_one_sample in conditions]
326
+
327
+ selected_conditions = torch.stack(selected_conditions).float().to(device)
328
+
329
+ optimizer.zero_grad()
330
+
331
+ t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long()
332
+ loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber",
333
+ sqrt_alphas_cumprod=sqrt_alphas_cumprod,
334
+ sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod)
335
+
336
+ loss.backward()
337
+ optimizer.step()
338
+
339
+ train_loss.append(loss.item())
340
+ step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
341
+
342
+ if step % 100 == 0:
343
+ print('%d step' % (step))
344
+
345
+ if step % save_steps == 0:
346
+ current_loss = np.mean(train_loss[-save_steps:])
347
+ print(f"current_loss = {current_loss}")
348
+ torch.save({
349
+ 'model_state_dict': model.state_dict(),
350
+ 'optimizer_state_dict': optimizer.state_dict(),
351
+ }, f'models/{save_model_name}_UNet.pth')
352
+ save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
353
+
354
+
355
+ if step % 20000 == 0:
356
+ current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder, n_IS_batches,
357
+ positive_prompts="", negative_prompts="", CFG=1, sample_steps=20, task="STFT")
358
+ print('current_IS: %.5f' % current_IS)
359
+ current_loss = np.mean(train_loss[-save_steps:])
360
+
361
+ writer.add_scalar(f"current_IS", current_IS, step)
362
+
363
+ torch.save({
364
+ 'model_state_dict': model.state_dict(),
365
+ 'optimizer_state_dict': optimizer.state_dict(),
366
+ }, f'models/history/{save_model_name}_{step}_UNet.pth')
367
+ save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss)
368
+
369
+ return model, optimizer
370
+
371
+
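A minimal smoke-test sketch of the conditional U-Net defined above; the latent channel count, spatial size, and condition-embedding width are illustrative assumptions.

import torch

# Hypothetical dimensions: 4 latent channels, 64x32 latent maps, 512-d text condition.
unet = ConditionedUnet(in_dim=4, condition_type="natural_language_prompt", label_emb_dim=512)
x = torch.randn(2, 4, 64, 32)              # noisy latents
t = torch.randint(0, 1000, (2,)).long()    # diffusion time steps
condition = torch.randn(2, 512)            # e.g. a projected text embedding
predicted_noise = unet(x, t, condition)    # same shape as x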
model/diffusion_components.py ADDED
@@ -0,0 +1,351 @@
1
+ import torch.nn.functional as F
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange
5
+ from inspect import isfunction
6
+ import math
7
+ from tqdm import tqdm
8
+
9
+
10
+ def exists(x):
11
+ """Return true for x is not None."""
12
+ return x is not None
13
+
14
+
15
+ def default(val, d):
16
+ """Helper function"""
17
+ if exists(val):
18
+ return val
19
+ return d() if isfunction(d) else d
20
+
21
+
22
+ class Residual(nn.Module):
23
+ """Skip connection"""
24
+ def __init__(self, fn):
25
+ super().__init__()
26
+ self.fn = fn
27
+
28
+ def forward(self, x, *args, **kwargs):
29
+ return self.fn(x, *args, **kwargs) + x
30
+
31
+
32
+ def Upsample(dim):
33
+ """Upsample layer, a transposed convolution layer with stride=2"""
34
+ return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
35
+
36
+
37
+ def Downsample(dim):
38
+ """Downsample layer, a convolution layer with stride=2"""
39
+ return nn.Conv2d(dim, dim, 4, 2, 1)
40
+
41
+
42
+ class SinusoidalPositionEmbeddings(nn.Module):
43
+ """Return sinusoidal embedding for integer time step."""
44
+
45
+ def __init__(self, dim):
46
+ super().__init__()
47
+ self.dim = dim
48
+
49
+ def forward(self, time):
50
+ device = time.device
51
+ half_dim = self.dim // 2
52
+ embeddings = math.log(10000) / (half_dim - 1)
53
+ embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
54
+ embeddings = time[:, None] * embeddings[None, :]
55
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
56
+ return embeddings
57
+
58
+
59
+ class Block(nn.Module):
60
+ """Stack of convolution, normalization, and non-linear activation"""
61
+
62
+ def __init__(self, dim, dim_out, groups=8):
63
+ super().__init__()
64
+ self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
65
+ self.norm = nn.GroupNorm(groups, dim_out)
66
+ self.act = nn.SiLU()
67
+
68
+ def forward(self, x, scale_shift=None):
69
+ x = self.proj(x)
70
+ x = self.norm(x)
71
+
72
+ if exists(scale_shift):
73
+ scale, shift = scale_shift
74
+ x = x * (scale + 1) + shift
75
+
76
+ x = self.act(x)
77
+ return x
78
+
79
+
80
+ class ResnetBlock(nn.Module):
81
+ """Stack of [conv + norm + act (+ scale&shift)], with positional embedding inserted <https://arxiv.org/abs/1512.03385>"""
82
+
83
+ def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
84
+ super().__init__()
85
+ self.mlp = (
86
+ nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
87
+ if exists(time_emb_dim)
88
+ else None
89
+ )
90
+
91
+ self.block1 = Block(dim, dim_out, groups=groups)
92
+ self.block2 = Block(dim_out, dim_out, groups=groups)
93
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
94
+
95
+ def forward(self, x, time_emb=None):
96
+ h = self.block1(x)
97
+
98
+ if exists(self.mlp) and exists(time_emb):
99
+ time_emb = self.mlp(time_emb)
100
+ # Adding positional embedding to intermediate layer (by broadcasting along spatial dimension)
101
+ h = rearrange(time_emb, "b c -> b c 1 1") + h
102
+
103
+ h = self.block2(h)
104
+ return h + self.res_conv(x)
105
+
106
+
107
+ class ConvNextBlock(nn.Module):
108
+ """Stack of [conv7x7 (+ condition(pos)) + norm + conv3x3 + act + norm + conv3x3 + res1x1],with positional embedding inserted"""
109
+
110
+ def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
111
+ super().__init__()
112
+ self.mlp = (
113
+ nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
114
+ if exists(time_emb_dim)
115
+ else None
116
+ )
117
+
118
+ self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
119
+
120
+ self.net = nn.Sequential(
121
+ nn.GroupNorm(1, dim) if norm else nn.Identity(),
122
+ nn.Conv2d(dim, dim_out * mult, 3, padding=1),
123
+ nn.GELU(),
124
+ nn.GroupNorm(1, dim_out * mult),
125
+ nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
126
+ )
127
+
128
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
129
+
130
+ def forward(self, x, time_emb=None):
131
+ h = self.ds_conv(x)
132
+
133
+ if exists(self.mlp) and exists(time_emb):
134
+ assert exists(time_emb), "time embedding must be passed in"
135
+ condition = self.mlp(time_emb)
136
+ h = h + rearrange(condition, "b c -> b c 1 1")
137
+
138
+ h = self.net(h)
139
+ return h + self.res_conv(x)
140
+
141
+
142
+ class PreNorm(nn.Module):
143
+ """Apply normalization before 'fn'"""
144
+
145
+ def __init__(self, dim, fn):
146
+ super().__init__()
147
+ self.fn = fn
148
+ self.norm = nn.GroupNorm(1, dim)
149
+
150
+ def forward(self, x, *args, **kwargs):
151
+ x = self.norm(x)
152
+ return self.fn(x, *args, **kwargs)
153
+
154
+
155
+ class ConditionalEmbedding(nn.Module):
156
+ """Return embedding for label and projection for text embedding"""
157
+
158
+ def __init__(self, num_labels, embedding_dim, condition_type="instrument_family"):
159
+ super(ConditionalEmbedding, self).__init__()
160
+ if condition_type == "instrument_family":
161
+ self.embedding = nn.Embedding(num_labels, embedding_dim)
162
+ elif condition_type == "natural_language_prompt":
163
+ self.embedding = nn.Linear(embedding_dim, embedding_dim, bias=True)
164
+ else:
165
+ raise NotImplementedError()
166
+
167
+ def forward(self, labels):
168
+ return self.embedding(labels)
169
+
170
+
171
+ class LinearCrossAttention(nn.Module):
172
+ """Combination of efficient attention and cross attention."""
173
+
174
+ def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
175
+ super().__init__()
176
+ self.dim_head = dim_head
177
+ self.scale = dim_head ** -0.5
178
+ self.heads = heads
179
+ hidden_dim = dim_head * heads
180
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
181
+ self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
182
+
183
+ # embedding for key and value
184
+ self.label_key = nn.Linear(label_emb_dim, hidden_dim)
185
+ self.label_value = nn.Linear(label_emb_dim, hidden_dim)
186
+
187
+ def forward(self, x, label_embedding=None):
188
+ b, c, h, w = x.shape
189
+ qkv = self.to_qkv(x).chunk(3, dim=1)
190
+ q, k, v = map(
191
+ lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
192
+ )
193
+
194
+ if label_embedding is not None:
195
+ label_k = self.label_key(label_embedding).view(b, self.heads, self.dim_head, 1)
196
+ label_v = self.label_value(label_embedding).view(b, self.heads, self.dim_head, 1)
197
+
198
+ k = torch.cat([k, label_k], dim=-1)
199
+ v = torch.cat([v, label_v], dim=-1)
200
+
201
+ q = q.softmax(dim=-2)
202
+ k = k.softmax(dim=-1)
203
+ q = q * self.scale
204
+ context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
205
+ out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
206
+ out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
207
+ return self.to_out(out)
208
+
209
+
210
+ def pad_to_match(encoder_tensor, decoder_tensor):
211
+ """
212
+ Pads the decoder_tensor to match the spatial dimensions of encoder_tensor.
213
+
214
+ :param encoder_tensor: The feature map from the encoder.
215
+ :param decoder_tensor: The feature map from the decoder that needs to be upsampled.
216
+ :return: Padded decoder_tensor with the same spatial dimensions as encoder_tensor.
217
+ """
218
+
219
+ enc_shape = encoder_tensor.shape[2:] # spatial dimensions are at index 2 and 3
220
+ dec_shape = decoder_tensor.shape[2:]
221
+
222
+ # assume enc_shape >= dec_shape
223
+ delta_w = enc_shape[1] - dec_shape[1]
224
+ delta_h = enc_shape[0] - dec_shape[0]
225
+
226
+ # padding
227
+ padding_left = delta_w // 2
228
+ padding_right = delta_w - padding_left
229
+ padding_top = delta_h // 2
230
+ padding_bottom = delta_h - padding_top
231
+ decoder_tensor_padded = F.pad(decoder_tensor, (padding_left, padding_right, padding_top, padding_bottom))
232
+
233
+ return decoder_tensor_padded
234
+
235
+
236
+ def pad_and_concat(encoder_tensor, decoder_tensor):
237
+ """
238
+ Pads the decoder_tensor and concatenates it with the encoder_tensor along the channel dimension.
239
+
240
+ :param encoder_tensor: The feature map from the encoder.
241
+ :param decoder_tensor: The feature map from the decoder that needs to be concatenated with encoder_tensor.
242
+ :return: Concatenated tensor.
243
+ """
244
+
245
+ # pad decoder_tensor
246
+ decoder_tensor_padded = pad_to_match(encoder_tensor, decoder_tensor)
247
+ # concat encoder_tensor and decoder_tensor_padded
248
+ concatenated_tensor = torch.cat((encoder_tensor, decoder_tensor_padded), dim=1)
249
+ return concatenated_tensor
250
+
251
+
252
+ class LinearCrossAttentionAdd(nn.Module):
253
+ def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
254
+ super().__init__()
255
+ self.dim = dim
256
+ self.dim_head = dim_head
257
+ self.scale = dim_head ** -0.5
258
+ self.heads = heads
259
+ self.label_emb_dim = label_emb_dim
260
+ self.dim_head = dim_head
261
+
262
+ self.hidden_dim = dim_head * heads
263
+ self.to_qkv = nn.Conv2d(self.dim, self.hidden_dim * 3, 1, bias=False)
264
+ self.to_out = nn.Sequential(nn.Conv2d(self.hidden_dim, dim, 1), nn.GroupNorm(1, dim))
265
+
266
+ # embedding for key and value
267
+ self.label_key = nn.Linear(label_emb_dim, self.hidden_dim)
268
+ self.label_query = nn.Linear(label_emb_dim, self.hidden_dim)
269
+
270
+
271
+ def forward(self, x, condition=None):
272
+ b, c, h, w = x.shape
273
+
274
+ qkv = self.to_qkv(x).chunk(3, dim=1)
275
+
276
+ q, k, v = map(
277
+ lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
278
+ )
279
+
280
+ # if a condition is given, add its projected key and query to the attention keys and queries
281
+ if condition is not None:
282
+ label_k = self.label_key(condition).view(b, self.heads, self.dim_head, 1)
283
+ label_q = self.label_query(condition).view(b, self.heads, self.dim_head, 1)
284
+ k = k + label_k
285
+ q = q + label_q
286
+
287
+ q = q.softmax(dim=-2)
288
+ k = k.softmax(dim=-1)
289
+ q = q * self.scale
290
+ context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
291
+ out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
292
+ out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
293
+ return self.to_out(out)
294
+
295
+
296
+
297
+ def linear_beta_schedule(timesteps):
298
+ beta_start = 0.0001
299
+ beta_end = 0.02
300
+ return torch.linspace(beta_start, beta_end, timesteps)
301
+
302
+
303
+ def get_beta_schedule(timesteps):
304
+ betas = linear_beta_schedule(timesteps=timesteps)
305
+
306
+ # define alphas
307
+ alphas = 1. - betas
308
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
309
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
310
+ sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
311
+
312
+ # calculations for diffusion q(x_t | x_{t-1}) and others
313
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
314
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
315
+
316
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
317
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
318
+ return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, posterior_variance, sqrt_recip_alphas
319
+
320
+
321
+ def extract(a, t, x_shape):
322
+ batch_size = t.shape[0]
323
+ out = a.gather(-1, t.cpu())
324
+ return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
325
+
326
+
327
+ # forward diffusion
328
+ def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
329
+ if noise is None:
330
+ noise = torch.randn_like(x_start)
331
+
332
+ sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
333
+ sqrt_one_minus_alphas_cumprod_t = extract(
334
+ sqrt_one_minus_alphas_cumprod, t, x_start.shape
335
+ )
336
+
337
+ return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
338
+
339
+
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
+
349
+
350
+
351
+
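A short sketch of the forward-diffusion helpers above (get_beta_schedule and q_sample); the tensor shape and number of timesteps are illustrative.

import torch

timesteps = 1000
sqrt_ac, sqrt_one_minus_ac, _, _ = get_beta_schedule(timesteps)

x_start = torch.randn(8, 4, 64, 32)                         # clean latents (shape is illustrative)
t = torch.randint(0, timesteps, (8,)).long()                # a random diffusion step per sample
x_noisy = q_sample(x_start, t, sqrt_ac, sqrt_one_minus_ac)  # noised latents at step t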
model/multimodal_model.py ADDED
@@ -0,0 +1,274 @@
1
+ import itertools
2
+ import json
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from tools import create_key
11
+ from model.timbre_encoder_pretrain import get_timbre_encoder
12
+
13
+
14
+ class ProjectionLayer(nn.Module):
15
+ """Single-layer Linear projection with dropout, layer norm, and Gelu activation"""
16
+
17
+ def __init__(self, input_dim, output_dim, dropout):
18
+ super(ProjectionLayer, self).__init__()
19
+ self.projection = nn.Linear(input_dim, output_dim)
20
+ self.gelu = nn.GELU()
21
+ self.fc = nn.Linear(output_dim, output_dim)
22
+ self.dropout = nn.Dropout(dropout)
23
+ self.layer_norm = nn.LayerNorm(output_dim)
24
+
25
+ def forward(self, x):
26
+ projected = self.projection(x)
27
+ x = self.gelu(projected)
28
+ x = self.fc(x)
29
+ x = self.dropout(x)
30
+ x = x + projected
31
+ x = self.layer_norm(x)
32
+ return x
33
+
34
+
35
+ class ProjectionHead(nn.Module):
36
+ """Stack of 'ProjectionLayer'"""
37
+
38
+ def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
39
+ super(ProjectionHead, self).__init__()
40
+ self.layers = nn.ModuleList([ProjectionLayer(embedding_dim if i == 0 else projection_dim,
41
+ projection_dim,
42
+ dropout) for i in range(num_layers)])
43
+
44
+ def forward(self, x):
45
+ for layer in self.layers:
46
+ x = layer(x)
47
+ return x
48
+
49
+
50
+ class multi_modal_model(nn.Module):
51
+ """The multi-modal model for contrastive learning"""
52
+
53
+ def __init__(
54
+ self,
55
+ timbre_encoder,
56
+ text_encoder,
57
+ spectrogram_feature_dim,
58
+ text_feature_dim,
59
+ multi_modal_emb_dim,
60
+ temperature,
61
+ dropout,
62
+ num_projection_layers=1,
63
+ freeze_spectrogram_encoder=True,
64
+ freeze_text_encoder=True,
65
+ ):
66
+ super().__init__()
67
+ self.timbre_encoder = timbre_encoder
68
+ self.text_encoder = text_encoder
69
+
70
+ self.multi_modal_emb_dim = multi_modal_emb_dim
71
+
72
+ self.text_projection = ProjectionHead(embedding_dim=text_feature_dim,
73
+ projection_dim=self.multi_modal_emb_dim, dropout=dropout,
74
+ num_layers=num_projection_layers)
75
+
76
+ self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim,
77
+ projection_dim=self.multi_modal_emb_dim, dropout=dropout,
78
+ num_layers=num_projection_layers)
79
+
80
+ self.temperature = temperature
81
+
82
+ # Make spectrogram_encoder parameters non-trainable
83
+ for param in self.timbre_encoder.parameters():
84
+ param.requires_grad = not freeze_spectrogram_encoder
85
+
86
+ # Make text_encoder parameters non-trainable
87
+ for param in self.text_encoder.parameters():
88
+ param.requires_grad = not freeze_text_encoder
89
+
90
+ def forward(self, spectrogram_batch, tokenized_text_batch):
91
+ # Getting Image and Text Embeddings (with same dimension)
92
+ spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
93
+ text_features = self.text_encoder.get_text_features(**tokenized_text_batch)
94
+
95
+ # Concat and apply projection
96
+ spectrogram_embeddings = self.spectrogram_projection(spectrogram_features)
97
+ text_embeddings = self.text_projection(text_features)
98
+
99
+ # Calculating the Loss
100
+ logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature
101
+ images_similarity = spectrogram_embeddings @ spectrogram_embeddings.T
102
+ texts_similarity = text_embeddings @ text_embeddings.T
103
+ targets = F.softmax(
104
+ (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
105
+ )
106
+ texts_loss = cross_entropy(logits, targets, reduction='none')
107
+ images_loss = cross_entropy(logits.T, targets.T, reduction='none')
108
+ contrastive_loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
109
+ contrastive_loss = contrastive_loss.mean()
110
+
111
+ return contrastive_loss
112
+
113
+
114
+ def get_text_features(self, input_ids, attention_mask):
115
+ text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
116
+ return self.text_projection(text_features)
117
+
118
+
119
+ def get_timbre_features(self, spectrogram_batch):
120
+ spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
121
+ return self.spectrogram_projection(spectrogram_features)
122
+
123
+
124
+ def cross_entropy(preds, targets, reduction='none'):
125
+ log_softmax = nn.LogSoftmax(dim=-1)
126
+ loss = (-targets * log_softmax(preds)).sum(1)
127
+ if reduction == "none":
128
+ return loss
129
+ elif reduction == "mean":
130
+ return loss.mean()
131
+
132
+
133
+ def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"):
134
+ mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
135
+ print(f"Model intialized, size: {sum(p.numel() for p in mmm.parameters() if p.requires_grad)}")
136
+ mmm.to(device)
137
+
138
+ if load_pretrain:
139
+ print(f"Loading weights from models/{model_name}_MMM.pth")
140
+ checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
141
+ mmm.load_state_dict(checkpoint['model_state_dict'])
142
+ mmm.eval()
143
+ return mmm
144
+
145
+
146
+ def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
147
+ (data, attributes) = next(iter(train_loader))
148
+ keys = [create_key(attribute) for attribute in attributes]
149
+
150
+ while(len(set(keys)) != len(keys)):
151
+ (data, attributes) = next(iter(train_loader))
152
+ keys = [create_key(attribute) for attribute in attributes]
153
+
154
+ data = data.to(device)
155
+
156
+ texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
157
+ selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
158
+
159
+ tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
160
+
161
+ loss = model(data, tokenized_text)
162
+ optimizer.zero_grad()
163
+ loss.backward()
164
+ optimizer.step()
165
+
166
+ return loss.item()
167
+
168
+
169
+ def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
170
+ (data, attributes) = next(iter(valid_loader))
171
+ keys = [create_key(attribute) for attribute in attributes]
172
+
173
+ while(len(set(keys)) != len(keys)):
174
+ (data, attributes) = next(iter(valid_loader))
175
+ keys = [create_key(attribute) for attribute in attributes]
176
+
177
+ data = data.to(device)
178
+ texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
179
+ selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
180
+
181
+ tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
182
+
183
+ loss = model(data, tokenized_text)
184
+ return loss.item()
185
+
186
+
187
+ def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
188
+ timbre_encoder_Config, MMM_config, MMM_training_config,
189
+ mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True,
190
+ timbre_encoder_name=None, init_loss=None, save_steps=2000):
191
+
192
+ def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter,
193
+ current_loss):
194
+
195
+ model_hyperparameter = MMM_config
196
+ model_hyperparameter.update(MMM_training_config)
197
+ model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
198
+ model_hyperparameter["model_size"] = model_size
199
+ model_hyperparameter["current_iter"] = current_iter
200
+ model_hyperparameter["current_loss"] = current_loss
201
+ with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
202
+ json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
203
+
204
+ timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
205
+ device=device)
206
+
207
+ mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)
208
+
209
+ print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
210
+ print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
211
+ print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
212
+ print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}")
213
+ total_parameters = sum(p.numel() for p in mmm.parameters())
214
+ trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
215
+ print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}")
216
+
217
+ params = [
218
+ {"params": itertools.chain(
219
+ mmm.spectrogram_projection.parameters(),
220
+ mmm.text_projection.parameters(),
221
+ ), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]},
222
+ ]
223
+ if not MMM_config["freeze_text_encoder"]:
224
+ params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"],
225
+ "weight_decay": MMM_training_config["text_encoder_weight_decay"]})
226
+ if not MMM_config["freeze_spectrogram_encoder"]:
227
+ params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"],
228
+ "weight_decay": MMM_training_config["timbre_encoder_weight_decay"]})
229
+
230
+ optimizer = torch.optim.AdamW(params, weight_decay=0.)
231
+
232
+ if load_pretrain:
233
+ print(f"Loading weights from models/{mmm_name}_MMM.pt")
234
+ checkpoint = torch.load(f'models/{mmm_name}_MMM.pth')
235
+ mmm.load_state_dict(checkpoint['model_state_dict'])
236
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
237
+ else:
238
+ print("Model initialized.")
239
+
240
+ if max_iter == 0:
241
+ print("Return model directly.")
242
+ return mmm, optimizer
243
+
244
+ if init_loss is None:
245
+ previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device)
246
+ else:
247
+ previous_lowest_loss = init_loss
248
+ print(f"Initial total loss: {previous_lowest_loss}")
249
+
250
+ train_loss_list = []
251
+ for i in range(max_iter):
252
+
253
+ mmm.train()
254
+ train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device)
255
+ train_loss_list.append(train_loss)
256
+
257
+ step = int(
258
+ optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
259
+ if (i + 1) % 100 == 0:
260
+ print('%d step' % (step))
261
+
262
+ if (i + 1) % save_steps == 0:
263
+ current_loss = np.mean(train_loss_list[-save_steps:])
264
+ print(f"train_total_loss: {current_loss}")
265
+ if current_loss < previous_lowest_loss:
266
+ previous_lowest_loss = current_loss
267
+ torch.save({
268
+ 'model_state_dict': mmm.state_dict(),
269
+ 'optimizer_state_dict': optimizer.state_dict(),
270
+ }, f'models/{mmm_name}_MMM.pth')
271
+ save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step,
272
+ current_loss)
273
+
274
+ return mmm, optimizer
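A minimal retrieval sketch using the projection heads above; `mmm`, `CLAP_tokenizer`, and `spectrogram_batch` are assumed to already exist (e.g. built via get_multi_modal_model and the data pipeline), and the prompt text is only an example.

import torch.nn.functional as F

# Assumed objects: mmm (multi_modal_model), CLAP_tokenizer, spectrogram_batch on the same device.
tokens = CLAP_tokenizer(["a bright string sound"], padding=True, return_tensors="pt")
text_emb = mmm.get_text_features(tokens["input_ids"], tokens["attention_mask"])
spec_emb = mmm.get_timbre_features(spectrogram_batch)
similarity = F.cosine_similarity(text_emb, spec_emb)   # higher means a closer text/timbre match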
model/timbre_encoder_pretrain.py ADDED
@@ -0,0 +1,220 @@
1
+ import json
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from tools import create_key
7
+
8
+
9
+ class TimbreEncoder(nn.Module):
10
+ def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes, num_instrument_family_classes, num_velocity_classes, num_qualities, num_layers=1):
11
+ super(TimbreEncoder, self).__init__()
12
+
13
+ # Input layer
14
+ self.input_layer = nn.Linear(input_dim, feature_dim)
15
+
16
+ # LSTM Layer
17
+ self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True)
18
+
19
+ # Fully Connected Layers for classification
20
+ self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes)
21
+ self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes)
22
+ self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes)
23
+ self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities)
24
+
25
+ # Softmax for converting output to probabilities
26
+ self.softmax = nn.LogSoftmax(dim=1)
27
+
28
+ def forward(self, x):
29
+ # Merge the channel and frequency dimensions into a single feature dimension
30
+ batch_size, _, _, seq_len = x.shape
31
+ x = x.view(batch_size, -1, seq_len) # [batch_size, input_dim, seq_len]
32
+
33
+ # Forward propagate LSTM
34
+ x = x.permute(0, 2, 1)
35
+ x = self.input_layer(x)
36
+ feature, _ = self.lstm(x)
37
+ feature = feature[:, -1, :]
38
+
39
+ # Apply classification layers
40
+ instrument_logits = self.instrument_classifier_layer(feature)
41
+ instrument_family_logits = self.instrument_family_classifier_layer(feature)
42
+ velocity_logits = self.velocity_classifier_layer(feature)
43
+ qualities = self.qualities_classifier_layer(feature)
44
+
45
+ # Apply Softmax
46
+ instrument_logits = self.softmax(instrument_logits)
47
+ instrument_family_logits= self.softmax(instrument_family_logits)
48
+ velocity_logits = self.softmax(velocity_logits)
49
+ qualities = torch.sigmoid(qualities)
50
+
51
+ return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities
52
+
53
+
54
+ def get_multiclass_acc(outputs, ground_truth):
55
+ _, predicted = torch.max(outputs.data, 1)
56
+ total = ground_truth.size(0)
57
+ correct = (predicted == ground_truth).sum().item()
58
+ accuracy = 100 * correct / total
59
+ return accuracy
60
+
61
+ def get_binary_accuracy(y_pred, y_true):
62
+ predictions = (y_pred > 0.5).int()
63
+
64
+ correct_predictions = (predictions == y_true).float()
65
+
66
+ accuracy = correct_predictions.mean()
67
+
68
+ return accuracy.item() * 100.0
69
+
70
+
71
+ def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"):
72
+ timbreEncoder = TimbreEncoder(**model_Config)
73
+ print(f"Model intialized, size: {sum(p.numel() for p in timbreEncoder.parameters() if p.requires_grad)}")
74
+ timbreEncoder.to(device)
75
+
76
+ if load_pretrain:
77
+ print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
78
+ checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth', map_location=device)
79
+ timbreEncoder.load_state_dict(checkpoint['model_state_dict'])
80
+ timbreEncoder.eval()
81
+ return timbreEncoder
82
+
83
+
84
+ def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100):
85
+ model.to(device)
86
+ model.eval()
87
+
88
+ eva_loss = []
89
+ for i in range(n_sample):
90
+ representation, attributes = next(iter(iterator))
91
+
92
+ instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
93
+ instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
94
+ velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
95
+ qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)
96
+
97
+ _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))
98
+
99
+ # compute loss
100
+ instrument_loss = nll_Loss(instrument_logits, instrument)
101
+ instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
102
+ velocity_loss = nll_Loss(velocity_logits, velocity)
103
+ qualities_loss = bce_Loss(qualities_pred, qualities)
104
+
105
+ loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
106
+
107
+ eva_loss.append(loss.item())
108
+
109
+ eva_loss = np.mean(eva_loss)
110
+ return eva_loss
111
+
112
+
113
+ def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter, training_iterator, load_pretrain):
114
+ def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, current_iter,
115
+ current_loss):
116
+ model_hyperparameter = timbre_encoder_Config
117
+ model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
118
+ model_hyperparameter["lr"] = lr
119
+ model_hyperparameter["model_size"] = model_size
120
+ model_hyperparameter["current_iter"] = current_iter
121
+ model_hyperparameter["current_loss"] = current_loss
122
+ with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w") as json_file:
123
+ json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
124
+
125
+ model = TimbreEncoder(**timbre_encoder_Config)
126
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
127
+ print(f"Model size: {model_size}")
128
+ model.to(device)
129
+ nll_Loss = torch.nn.NLLLoss()
130
+ bce_Loss = torch.nn.BCELoss()
131
+
132
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)
133
+
134
+ if load_pretrain:
135
+ print(f"Loading weights from models/{model_name}_timbre_encoder.pt")
136
+ checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth')
137
+ model.load_state_dict(checkpoint['model_state_dict'])
138
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
139
+ else:
140
+ print("Model initialized.")
141
+ if max_iter == 0:
142
+ print("Return model directly.")
143
+ return model, model
144
+
145
+ train_loss, training_instrument_acc, training_instrument_family_acc, training_velocity_acc, training_qualities_acc = [], [], [], [], []
146
+ writer = SummaryWriter(f'runs/{model_name}_timbre_encoder')
147
+ current_best_model = model
148
+ previous_lowest_loss = 100.0
149
+ print(f"initial__loss: {previous_lowest_loss}")
150
+
151
+ for i in range(max_iter):
152
+ model.train()
153
+
154
+ representation, attributes = next(iter(training_iterator))
155
+
156
+ instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
157
+ instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
158
+ velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
159
+ qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)
160
+
161
+ optimizer.zero_grad()
162
+
163
+ _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))
164
+
165
+ # compute loss
166
+ instrument_loss = nll_Loss(instrument_logits, instrument)
167
+ instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
168
+ velocity_loss = nll_Loss(velocity_logits, velocity)
169
+ qualities_loss = bce_Loss(qualities_pred, qualities)
170
+
171
+ loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
172
+
173
+ loss.backward()
174
+ optimizer.step()
175
+ instrument_acc = get_multiclass_acc(instrument_logits, instrument)
176
+ instrument_family_acc = get_multiclass_acc(instrument_family_logits, instrument_family)
177
+ velocity_acc = get_multiclass_acc(velocity_logits, velocity)
178
+ qualities_acc = get_binary_accuracy(qualities_pred, qualities)
179
+
180
+ train_loss.append(loss.item())
181
+ training_instrument_acc.append(instrument_acc)
182
+ training_instrument_family_acc.append(instrument_family_acc)
183
+ training_velocity_acc.append(velocity_acc)
184
+ training_qualities_acc.append(qualities_acc)
185
+ step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy())
186
+
187
+ if (i + 1) % 100 == 0:
188
+ print('%d step' % (step))
189
+
190
+ save_steps = 500
191
+ if (i + 1) % save_steps == 0:
192
+ current_loss = np.mean(train_loss[-save_steps:])
193
+ current_instrument_acc = np.mean(training_instrument_acc[-save_steps:])
194
+ current_instrument_family_acc = np.mean(training_instrument_family_acc[-save_steps:])
195
+ current_velocity_acc = np.mean(training_velocity_acc[-save_steps:])
196
+ current_qualities_acc = np.mean(training_qualities_acc[-save_steps:])
197
+ print('train_loss: %.5f' % current_loss)
198
+ print('current_instrument_acc: %.5f' % current_instrument_acc)
199
+ print('current_instrument_family_acc: %.5f' % current_instrument_family_acc)
200
+ print('current_velocity_acc: %.5f' % current_velocity_acc)
201
+ print('current_qualities_acc: %.5f' % current_qualities_acc)
202
+ writer.add_scalar(f"train_loss", current_loss, step)
203
+ writer.add_scalar(f"current_instrument_acc", current_instrument_acc, step)
204
+ writer.add_scalar(f"current_instrument_family_acc", current_instrument_family_acc, step)
205
+ writer.add_scalar(f"current_velocity_acc", current_velocity_acc, step)
206
+ writer.add_scalar(f"current_qualities_acc", current_qualities_acc, step)
207
+
208
+ if current_loss < previous_lowest_loss:
209
+ previous_lowest_loss = current_loss
210
+ current_best_model = model
211
+ torch.save({
212
+ 'model_state_dict': model.state_dict(),
213
+ 'optimizer_state_dict': optimizer.state_dict(),
214
+ }, f'models/{model_name}_timbre_encoder.pth')
215
+ save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, step,
216
+ current_loss)
217
+
218
+ return model, current_best_model
219
+
220
+
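A minimal shape sketch for the encoder above; all sizes below are illustrative assumptions (input_dim must equal channels × frequency bins of the spectrogram representation).

import torch

encoder = TimbreEncoder(input_dim=3 * 512, feature_dim=256, hidden_dim=256,
                        num_instrument_classes=1006, num_instrument_family_classes=11,
                        num_velocity_classes=128, num_qualities=10)   # sizes are illustrative
x = torch.randn(2, 3, 512, 256)   # (batch, channels, freq_bins, frames)
feature, instrument, family, velocity, qualities = encoder(x)
# feature: (2, 256) timbre embedding; the remaining outputs are per-task (log-)probabilities.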
models/24_1_2024-52_4x_L_D_imageVQVAE.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5feec46219e25e6f95bfa453a4bddb3ec7bc26d29f2e01748defa4901762c9f
3
+ size 16069859
models/24_1_2024-52_4x_L_D_imageVQVAE_discriminator.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69080ff5094f82bba6e69e4310cda27864420dfea9b08d3a001dafe46bbf6808
3
+ size 134268962
models/24_1_2024_MMM.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:494f7fa1f9874ebd2cd870b6da993cb796e907e857a28550e62be8c335fb9f5a
3
+ size 1930637291
models/24_1_2024_STFT_timbre_encoder.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b9288a7fc4a8cdc4b0db5803fd0a5e88e6bcf07bdde22f49ebb8dec12fb33e6
3
+ size 294502949
models/history/28_1_2024_TE_STFT_300000_UNet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edcdd3c65275badad3233eabdcb7a3ffa7adbf95c673172ca2e35338097d1a1c
3
+ size 1284015362
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ torchmetrics==0.7.0
2
+ torchsynth==1.0.2
3
+ torchaudio
4
+ soundfile
5
+ einops
6
+ pytorch-ssim
7
+ piqa
8
+ torchinfo
9
+ mido
10
+ tensorboard
11
+ librosa
12
+ transformers
13
+ matplotlib
14
+ gradio==3.50.2
15
+
tools.py ADDED
@@ -0,0 +1,344 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib
4
+ import librosa
5
+ from scipy.io.wavfile import write
6
+ import torch
7
+
8
+ k = 1e-16
9
+
10
+ def np_log10(x):
11
+ """Safe log function with base 10."""
12
+ numerator = np.log(x + 1e-16)
13
+ denominator = np.log(10)
14
+ return numerator / denominator
15
+
16
+
17
+ def sigmoid(x):
18
+ """Safe log function with base 10."""
19
+ s = 1 / (1 + np.exp(-x))
20
+ return s
21
+
22
+
23
+ def inv_sigmoid(s):
24
+ """Safe inverse sigmoid function."""
25
+ x = np.log((s / (1 - s)) + 1e-16)
26
+ return x
27
+
28
+
29
+ def spc_to_VAE_input(spc):
30
+ """Restrict value range from [0, infinite] to [0, 1]. (deprecated )"""
31
+ return spc / (1 + spc)
32
+
33
+
34
+ def VAE_out_put_to_spc(o):
35
+ """Inverse transform of function 'spc_to_VAE_input'. (deprecated )"""
36
+ return o / (1 - o + k)
37
+
38
+
39
+
40
+ def np_power_to_db(S, amin=1e-16, top_db=80.0):
41
+ """Helper method for numpy data scaling. (deprecated )"""
42
+ ref = S.max()
43
+
44
+ log_spec = 10.0 * np_log10(np.maximum(amin, S))
45
+ log_spec -= 10.0 * np_log10(np.maximum(amin, ref))
46
+
47
+ log_spec = np.maximum(log_spec, log_spec.max() - top_db)
48
+
49
+ return log_spec
50
+
51
+
52
+ def show_spc(spc):
53
+ """Show a spectrogram. (deprecated )"""
54
+ s = np.shape(spc)
55
+ spc = np.reshape(spc, (s[0], s[1]))
56
+ magnitude_spectrum = np.abs(spc)
57
+ log_spectrum = np_power_to_db(magnitude_spectrum)
58
+ plt.imshow(np.flipud(log_spectrum))
59
+ plt.show()
60
+
61
+
62
+ def save_results(spectrogram, spectrogram_image_path, waveform_path):
63
+ """Save the input 'spectrogram' and its waveform (reconstructed by Griffin Lim)
64
+ to the paths provided by 'spectrogram_image_path' and 'waveform_path'."""
65
+ magnitude_spectrum = np.abs(spectrogram)
66
+ log_spc = np_power_to_db(magnitude_spectrum)
67
+ log_spc = np.reshape(log_spc, (512, 256))
68
+ matplotlib.pyplot.imsave(spectrogram_image_path, log_spc, vmin=-100, vmax=0,
69
+ origin='lower')
70
+
71
+ # save waveform
72
+ abs_spec = np.zeros((513, 256))
73
+ abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spectrogram, (512, 256)))
74
+ rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
75
+ write(waveform_path, 16000, rec_signal)
76
+
77
+
78
+ def plot_log_spectrogram(signal: np.ndarray,
79
+ path: str,
80
+ n_fft=2048,
81
+ frame_length=1024,
82
+ frame_step=256):
83
+ """Save spectrogram."""
84
+ stft = librosa.stft(signal, n_fft=n_fft, hop_length=frame_step, win_length=frame_length)
85
+ amp = np.square(np.real(stft)) + np.square(np.imag(stft))
86
+ magnitude_spectrum = np.abs(amp)
87
+ log_mel = np_power_to_db(magnitude_spectrum)
88
+ matplotlib.pyplot.imsave(path, log_mel, vmin=-100, vmax=0, origin='lower')
89
+
90
+
91
+ def visualize_feature_maps(device, model, inputs, channel_indices=[0, 3,]):
92
+ """
93
+ Visualize feature maps before and after quantization for given input.
94
+
95
+ Parameters:
96
+ - model: Your VQ-VAE model.
97
+ - inputs: A batch of input data.
98
+ - channel_indices: Indices of feature map channels to visualize.
99
+ """
100
+ model.eval()
101
+ inputs = inputs.to(device)
102
+
103
+ with torch.no_grad():
104
+ z_e = model._encoder(inputs)
105
+ z_q, loss, (perplexity, min_encodings, min_encoding_indices) = model._vq_vae(z_e)
106
+
107
+ # Assuming inputs have shape [batch_size, channels, height, width]
108
+ batch_size = z_e.size(0)
109
+
110
+ for idx in range(batch_size):
111
+ fig, axs = plt.subplots(1, len(channel_indices)*2, figsize=(15, 5))
112
+
113
+ for i, channel_idx in enumerate(channel_indices):
114
+ # Plot encoder output
115
+ axs[2*i].imshow(z_e[idx][channel_idx].cpu().numpy(), cmap='viridis')
116
+ axs[2*i].set_title(f"Encoder Output - Channel {channel_idx}")
117
+
118
+ # Plot quantized output
119
+ axs[2*i+1].imshow(z_q[idx][channel_idx].cpu().numpy(), cmap='viridis')
120
+ axs[2*i+1].set_title(f"Quantized Output - Channel {channel_idx}")
121
+
122
+ plt.show()
123
+
124
+
125
+ def adjust_audio_length(audio, desired_length, original_sample_rate, target_sample_rate):
126
+ """
127
+ Adjust the audio length to the desired length and resample to target sample rate.
128
+
129
+ Parameters:
130
+ - audio (np.array): The input audio signal
131
+ - desired_length (int): The desired length of the output audio
132
+ - original_sample_rate (int): The original sample rate of the audio
133
+ - target_sample_rate (int): The target sample rate for the output audio
134
+
135
+ Returns:
136
+ - np.array: The adjusted and resampled audio
137
+ """
138
+
139
+ if not (original_sample_rate == target_sample_rate):
140
+ audio = librosa.core.resample(audio, orig_sr=original_sample_rate, target_sr=target_sample_rate)
141
+
142
+ if len(audio) > desired_length:
143
+ return audio[:desired_length]
144
+
145
+ elif len(audio) < desired_length:
146
+ padded_audio = np.zeros(desired_length)
147
+ padded_audio[:len(audio)] = audio
148
+ return padded_audio
149
+ else:
150
+ return audio
151
+
152
+
153
+ def safe_int(s, default=0):
154
+ try:
155
+ return int(s)
156
+ except ValueError:
157
+ return default
158
+
159
+
160
+ def pad_spectrogram(D):
161
+ """Resize spectrogram to (512, 256). (deprecated )"""
162
+ D = D[1:, :]
163
+
164
+ padding_length = 256 - D.shape[1]
165
+ D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant')
166
+ return D_padded
167
+
168
+
169
+ def pad_STFT(D, time_resolution=256):
170
+ """Resize spectral matrix by padding and cropping"""
171
+ D = D[1:, :]
172
+
173
+ if time_resolution is None:
174
+ return D
175
+
176
+ padding_length = time_resolution - D.shape[1]
177
+ if padding_length > 0:
178
+ D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant')
179
+ return D_padded
180
+ else:
181
+ return D
182
+
183
+
184
+ def depad_STFT(D_padded):
185
+ """Inverse function of 'pad_STFT'"""
186
+ zero_row = np.zeros((1, D_padded.shape[1]))
187
+
188
+ D_restored = np.concatenate([zero_row, D_padded], axis=0)
189
+
190
+ return D_restored
191
+
192
+
193
+ def nnData2Audio(spectrogram_batch, resolution=(512, 256), squared=False):
194
+ """Transform batch of numpy spectrogram into signals and encodings."""
195
+ # Todo: remove resolution hard-coding
196
+ frequency_resolution, time_resolution = resolution
197
+
198
+ if isinstance(spectrogram_batch, torch.Tensor):
199
+ spectrogram_batch = spectrogram_batch.to("cpu").detach().numpy()
200
+
201
+ origin_signals = []
202
+ for spectrogram in spectrogram_batch:
203
+ spc = VAE_out_put_to_spc(spectrogram)
204
+
205
+ # get_audio
206
+ abs_spec = np.zeros((frequency_resolution+1, time_resolution))
207
+
208
+ if squared:
209
+ abs_spec[1:, :] = abs_spec[1:, :] + np.sqrt(np.reshape(spc, (frequency_resolution, time_resolution)))
210
+ else:
211
+ abs_spec[1:, :] = abs_spec[1:, :] + np.reshape(spc, (frequency_resolution, time_resolution))
212
+
213
+ origin_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024)
214
+ origin_signals.append(origin_signal)
215
+
216
+ return origin_signals
217
+
218
+
219
+ def amp_to_audio(amp, n_iter=50):
220
+ """The Griffin-Lim algorithm."""
221
+ y_reconstructed = librosa.griffinlim(amp, n_iter=n_iter, hop_length=256, win_length=1024)
222
+ return y_reconstructed
223
+
224
+
225
+ def rescale(amp, method="log1p"):
226
+ """Rescale function."""
227
+ if method == "log1p":
228
+ return np.log1p(amp)
229
+ elif method == "NormalizedLogisticCompression":
230
+ return amp / (1.0 + amp)
231
+ else:
232
+ raise NotImplementedError()
233
+
234
+
235
+ def unrescale(scaled_amp, method="NormalizedLogisticCompression"):
236
+ """Inverse function of 'rescale'"""
237
+ if method == "log1p":
238
+ return np.expm1(scaled_amp)
239
+ elif method == "NormalizedLogisticCompression":
240
+ return scaled_amp / (1.0 - scaled_amp + 1e-10)
241
+ else:
242
+ raise NotImplementedError()
243
+
244
+
245
+ def create_key(attributes):
246
+ """Create unique key for each multi-label."""
247
+ qualities_str = ''.join(map(str, attributes["qualities"]))
248
+ instrument_source_str = attributes["instrument_source_str"]
249
+ instrument_family = attributes["instrument_family_str"]
250
+ key = f"{instrument_source_str}_{instrument_family}_{qualities_str}"
251
+ return key
252
+
253
+
254
+ def merge_dictionaries(dicts):
255
+ """Merge dictionaries."""
256
+ merged_dict = {}
257
+ for dictionary in dicts:
258
+ for key, value in dictionary.items():
259
+ if key in merged_dict:
260
+ merged_dict[key] += value
261
+ else:
262
+ merged_dict[key] = value
263
+ return merged_dict
264
+
265
+
266
+ def adsr_envelope(signal, sample_rate, duration, attack_time, decay_time, sustain_level, release_time):
267
+ """
268
+ Apply an ADSR envelope to an audio signal.
269
+
270
+ :param signal: The original audio signal (numpy array).
271
+ :param sample_rate: The sample rate of the audio signal.
272
+ :param attack_time: Attack time in seconds.
273
+ :param decay_time: Decay time in seconds.
274
+ :param sustain_level: Sustain level as a fraction of the peak (0 to 1).
275
+ :param release_time: Release time in seconds.
276
+ :return: The audio signal with the ADSR envelope applied.
277
+ """
278
+ # Calculate the number of samples for each ADSR phase
279
+ duration_samples = int(duration * sample_rate)
280
+
281
+ # assert (duration_samples + int(1.0 * sample_rate)) <= len(signal), "(duration_samples + sample_rate) > len(signal)"
282
+ assert release_time <= 1.0, "release_time > 1.0"
283
+
284
+ attack_samples = int(attack_time * sample_rate)
285
+ decay_samples = int(decay_time * sample_rate)
286
+ release_samples = int(release_time * sample_rate)
287
+ sustain_samples = max(0, duration_samples - attack_samples - decay_samples)
288
+
289
+ # Create ADSR envelope
290
+ attack_env = np.linspace(0, 1, attack_samples)
291
+ decay_env = np.linspace(1, sustain_level, decay_samples)
292
+ sustain_env = np.full(sustain_samples, sustain_level)
293
+ release_env = np.linspace(sustain_level, 0, release_samples)
294
+ release_env_expand = np.zeros(int(1.0 * sample_rate))
295
+ release_env_expand[:len(release_env)] = release_env
296
+
297
+ # Concatenate all phases to create the complete envelope
298
+ envelope = np.concatenate([attack_env, decay_env, sustain_env, release_env_expand])
299
+
300
+ # Apply the envelope to the signal
301
+ if len(envelope) <= len(signal):
302
+ applied_signal = signal[:len(envelope)] * envelope
303
+ else:
304
+ signal_expanded = np.zeros(len(envelope))
305
+ signal_expanded[:len(signal)] = signal
306
+ applied_signal = signal_expanded * envelope
307
+
308
+ return applied_signal
309
+
310
+
311
+ def rms_normalize(audio, target_rms=0.1):
312
+ """Normalize the RMS value."""
313
+ current_rms = np.sqrt(np.mean(audio**2))
314
+ scaling_factor = target_rms / current_rms
315
+ normalized_audio = audio * scaling_factor
316
+ return normalized_audio
317
+
318
+
319
+ def encode_stft(D):
320
+ """'STFT+' function that transform spectral matrix into spectral representation."""
321
+ magnitude = np.abs(D)
322
+ phase = np.angle(D)
323
+
324
+ log_magnitude = np.log1p(magnitude)
325
+
326
+ cos_phase = np.cos(phase)
327
+ sin_phase = np.sin(phase)
328
+
329
+ encoded_D = np.stack([log_magnitude, cos_phase, sin_phase], axis=0)
330
+ return encoded_D
331
+
332
+
333
+ def decode_stft(encoded_D):
334
+ """'ISTFT+' function that reconstructs spectral matrix from spectral representation."""
335
+ log_magnitude = encoded_D[0, ...]
336
+ cos_phase = encoded_D[1, ...]
337
+ sin_phase = encoded_D[2, ...]
338
+
339
+ magnitude = np.expm1(log_magnitude)
340
+
341
+ phase = np.arctan2(sin_phase, cos_phase)
342
+
343
+ D = magnitude * (np.cos(phase) + 1j * np.sin(phase))
344
+ return D
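`encode_stft`/`decode_stft` together with `pad_STFT`/`depad_STFT` define the "STFT+" representation used by the rest of the repository. A round trip might look like the sketch below; the file name and the STFT parameters (n_fft=1024, hop_length=256, win_length=1024, 16 kHz audio) are assumptions chosen to match the 512x256 resolution used in this file, not values confirmed by this excerpt.

```python
import librosa

# Load a short audio clip ("example.wav" is a placeholder path).
signal, sr = librosa.load("example.wav", sr=16000)

# Complex STFT -> (513, T); drop the DC row and zero-pad to 256 frames.
D = librosa.stft(signal, n_fft=1024, hop_length=256, win_length=1024)
D_padded = pad_STFT(D)              # (512, 256)

encoded = encode_stft(D_padded)     # (3, 512, 256): log-magnitude, cos(phase), sin(phase)
D_restored = decode_stft(encoded)   # complex (512, 256)

# Re-insert a zero DC row and invert the STFT.
D_full = depad_STFT(D_restored)     # (513, 256)
reconstructed = librosa.istft(D_full, hop_length=256, win_length=1024)
```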
webUI/__pycache__/app.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
webUI/deprecated/interpolationWithCondition.py ADDED
@@ -0,0 +1,178 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from model.DiffSynthSampler import DiffSynthSampler
6
+ from tools import safe_int
7
+ from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image
8
+
9
+
10
+ def get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state):
11
+ # Load configurations
12
+ uNet = gradioWebUI.uNet
13
+ freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
14
+ VAE_scale = gradioWebUI.VAE_scale
15
+ height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
16
+ timesteps = gradioWebUI.timesteps
17
+ VAE_quantizer = gradioWebUI.VAE_quantizer
18
+ VAE_decoder = gradioWebUI.VAE_decoder
19
+ CLAP = gradioWebUI.CLAP
20
+ CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
21
+ device = gradioWebUI.device
22
+ squared = gradioWebUI.squared
23
+ sample_rate = gradioWebUI.sample_rate
24
+ noise_strategy = gradioWebUI.noise_strategy
25
+
26
+ def diffusion_random_sample(text2sound_prompts_1, text2sound_prompts_2, text2sound_negative_prompts, text2sound_batchsize,
27
+ text2sound_duration,
28
+ text2sound_guidance_scale, text2sound_sampler,
29
+ text2sound_sample_steps, text2sound_seed,
30
+ interpolation_with_text_dict):
31
+ text2sound_sample_steps = int(text2sound_sample_steps)
32
+ text2sound_seed = safe_int(text2sound_seed, 12345678)
33
+ # Todo: take care of text2sound_time_resolution/width
34
+ width = int(time_resolution*((text2sound_duration+1)/4) / VAE_scale)
35
+ text2sound_batchsize = int(text2sound_batchsize)
36
+
37
+ text2sound_embedding_1 = \
38
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_1], padding=True, return_tensors="pt"))[0].to(device)
39
+ text2sound_embedding_2 = \
40
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_2], padding=True, return_tensors="pt"))[0].to(device)
41
+
42
+ CFG = int(text2sound_guidance_scale)
43
+
44
+ mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
45
+ unconditional_condition = \
46
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
47
+ mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))
48
+
49
+ mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))
50
+
51
+ condition = torch.linspace(1, 0, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_1 + \
52
+ torch.linspace(0, 1, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_2
53
+
54
+ # Todo: move this code
55
+ torch.manual_seed(text2sound_seed)
56
+ initial_noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)
57
+
58
+ latent_representations, initial_noise = \
59
+ mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
60
+ return_tensor=True, condition=condition, sampler=text2sound_sampler, initial_noise=initial_noise)
61
+
62
+ latent_representations = latent_representations[-1]
63
+
64
+ interpolation_with_text_dict["latent_representations"] = latent_representations
65
+
66
+ latent_representation_gradio_images = []
67
+ quantized_latent_representation_gradio_images = []
68
+ new_sound_spectrogram_gradio_images = []
69
+ new_sound_rec_signals_gradio = []
70
+
71
+ quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
72
+ # Todo: remove hard-coding
73
+ flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
74
+ resolution=(512, width * VAE_scale), centralized=False,
75
+ squared=squared)
76
+
77
+ for i in range(text2sound_batchsize):
78
+ latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
79
+ quantized_latent_representation_gradio_images.append(
80
+ latent_representation_to_Gradio_image(quantized_latent_representations[i]))
81
+ new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
82
+ new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
83
+
84
+ def concatenate_arrays(arrays_list):
85
+ return np.concatenate(arrays_list, axis=1)
86
+
87
+ concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)
88
+
89
+ interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
90
+ interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
91
+ interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
92
+ interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
93
+
94
+ return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
95
+ text2sound_quantized_latent_representation_image:
96
+ interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
97
+ text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
98
+ text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
99
+ text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
100
+ text2sound_seed_textbox: text2sound_seed,
101
+ interpolation_with_text_state: interpolation_with_text_dict,
102
+ text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
103
+ visible=True,
104
+ label="Sample index.",
105
+ info="Swipe to view other samples")}
106
+
107
+ def show_random_sample(sample_index, text2sound_dict):
108
+ sample_index = int(sample_index)
109
+ return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
110
+ sample_index],
111
+ text2sound_quantized_latent_representation_image:
112
+ text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
113
+ text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
114
+ text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}
115
+
116
+ with gr.Tab("InterpolationCond."):
117
+ gr.Markdown("Use interpolation to generate a gradient sound sequence.")
118
+ with gr.Row(variant="panel"):
119
+ with gr.Column(scale=3):
120
+ text2sound_prompts_1_textbox = gr.Textbox(label="Positive prompt 1", lines=2, value="organ")
121
+ text2sound_prompts_2_textbox = gr.Textbox(label="Positive prompt 2", lines=2, value="string")
122
+ text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
123
+
124
+ with gr.Column(scale=1):
125
+ text2sound_sampling_button = gr.Button(variant="primary",
126
+ value="Generate a batch of samples and show "
127
+ "the first one",
128
+ scale=1)
129
+ text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
130
+ label="Sample index",
131
+ info="Swipe to view other samples")
132
+ with gr.Row(variant="panel"):
133
+ with gr.Column(scale=1, variant="panel"):
134
+ text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
135
+ text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
136
+ text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
137
+ text2sound_duration_slider = gradioWebUI.get_duration_slider()
138
+ text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
139
+ text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
140
+
141
+ with gr.Column(scale=1):
142
+ with gr.Row(variant="panel"):
143
+ text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
144
+ height=420, scale=8)
145
+ text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
146
+ height=420, scale=1)
147
+ text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
148
+
149
+ with gr.Row(variant="panel"):
150
+ text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
151
+ height=200, width=100)
152
+ text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
153
+ type="numpy", height=200, width=100)
154
+
155
+ text2sound_sampling_button.click(diffusion_random_sample,
156
+ inputs=[text2sound_prompts_1_textbox,
157
+ text2sound_prompts_2_textbox,
158
+ text2sound_negative_prompts_textbox,
159
+ text2sound_batchsize_slider,
160
+ text2sound_duration_slider,
161
+ text2sound_guidance_scale_slider, text2sound_sampler_radio,
162
+ text2sound_sample_steps_slider,
163
+ text2sound_seed_textbox,
164
+ interpolation_with_text_state],
165
+ outputs=[text2sound_latent_representation_image,
166
+ text2sound_quantized_latent_representation_image,
167
+ text2sound_sampled_concatenated_spectrogram_image,
168
+ text2sound_sampled_spectrogram_image,
169
+ text2sound_sampled_audio,
170
+ text2sound_seed_textbox,
171
+ interpolation_with_text_state,
172
+ text2sound_sample_index_slider])
173
+ text2sound_sample_index_slider.change(show_random_sample,
174
+ inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
175
+ outputs=[text2sound_latent_representation_image,
176
+ text2sound_quantized_latent_representation_image,
177
+ text2sound_sampled_spectrogram_image,
178
+ text2sound_sampled_audio])
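The core of this deprecated module is the condition blend: it builds a batch of CLAP text embeddings that moves linearly from prompt 1 to prompt 2, so the i-th sample of the batch is one interpolation step. Isolated from the UI code, that step is roughly:

```python
import torch

def blend_text_conditions(emb_1, emb_2, batch_size):
    """Return a (batch_size, dim) batch of conditions that fades linearly
    from emb_1 (first row) to emb_2 (last row), mirroring the
    torch.linspace blending used in diffusion_random_sample above."""
    w = torch.linspace(0.0, 1.0, steps=batch_size).unsqueeze(1)  # (batch, 1)
    return (1.0 - w) * emb_1 + w * emb_2
```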
webUI/deprecated/interpolationWithXT.py ADDED
@@ -0,0 +1,173 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from model.DiffSynthSampler import DiffSynthSampler
6
+ from tools import safe_int
7
+ from webUI.natural_language_guided.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image
8
+
9
+
10
+ def get_interpolation_with_xT_module(gradioWebUI, interpolation_with_text_state):
11
+ # Load configurations
12
+ uNet = gradioWebUI.uNet
13
+ freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
14
+ VAE_scale = gradioWebUI.VAE_scale
15
+ height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels
16
+ timesteps = gradioWebUI.timesteps
17
+ VAE_quantizer = gradioWebUI.VAE_quantizer
18
+ VAE_decoder = gradioWebUI.VAE_decoder
19
+ CLAP = gradioWebUI.CLAP
20
+ CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
21
+ device = gradioWebUI.device
22
+ squared = gradioWebUI.squared
23
+ sample_rate = gradioWebUI.sample_rate
24
+ noise_strategy = gradioWebUI.noise_strategy
25
+
26
+ def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
27
+ text2sound_duration,
28
+ text2sound_noise_variance, text2sound_guidance_scale, text2sound_sampler,
29
+ text2sound_sample_steps, text2sound_seed,
30
+ interpolation_with_text_dict):
31
+ text2sound_sample_steps = int(text2sound_sample_steps)
32
+ text2sound_seed = safe_int(text2sound_seed, 12345678)
33
+ # Todo: take care of text2sound_time_resolution/width
34
+ width = int(time_resolution*((text2sound_duration+1)/4) / VAE_scale)
35
+ text2sound_batchsize = int(text2sound_batchsize)
36
+
37
+ text2sound_embedding = \
38
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device)
39
+
40
+ CFG = int(text2sound_guidance_scale)
41
+
42
+ mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
43
+ unconditional_condition = \
44
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0]
45
+ mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))
46
+
47
+ mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))
48
+
49
+ condition = text2sound_embedding.repeat(text2sound_batchsize, 1)
50
+ latent_representations, initial_noise = \
51
+ mySampler.interpolate(model=uNet, shape=(text2sound_batchsize, channels, height, width),
52
+ seed=text2sound_seed,
53
+ variance=text2sound_noise_variance,
54
+ return_tensor=True, condition=condition, sampler=text2sound_sampler)
55
+
56
+ latent_representations = latent_representations[-1]
57
+
58
+ interpolation_with_text_dict["latent_representations"] = latent_representations
59
+
60
+ latent_representation_gradio_images = []
61
+ quantized_latent_representation_gradio_images = []
62
+ new_sound_spectrogram_gradio_images = []
63
+ new_sound_rec_signals_gradio = []
64
+
65
+ quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
66
+ # Todo: remove hard-coding
67
+ flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
68
+ resolution=(512, width * VAE_scale), centralized=False,
69
+ squared=squared)
70
+
71
+ for i in range(text2sound_batchsize):
72
+ latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
73
+ quantized_latent_representation_gradio_images.append(
74
+ latent_representation_to_Gradio_image(quantized_latent_representations[i]))
75
+ new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
76
+ new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
77
+
78
+ def concatenate_arrays(arrays_list):
79
+ return np.concatenate(arrays_list, axis=1)
80
+
81
+ concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images)
82
+
83
+ interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
84
+ interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
85
+ interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
86
+ interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
87
+
88
+ return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0],
89
+ text2sound_quantized_latent_representation_image:
90
+ interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0],
91
+ text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image,
92
+ text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0],
93
+ text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0],
94
+ text2sound_seed_textbox: text2sound_seed,
95
+ interpolation_with_text_state: interpolation_with_text_dict,
96
+ text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
97
+ visible=True,
98
+ label="Sample index.",
99
+ info="Swipe to view other samples")}
100
+
101
+ def show_random_sample(sample_index, text2sound_dict):
102
+ sample_index = int(sample_index)
103
+ return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
104
+ sample_index],
105
+ text2sound_quantized_latent_representation_image:
106
+ text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
107
+ text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
108
+ text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}
109
+
110
+ with gr.Tab("InterpolationXT"):
111
+ gr.Markdown("Use interpolation to generate a gradient sound sequence.")
112
+ with gr.Row(variant="panel"):
113
+ with gr.Column(scale=3):
114
+ text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
115
+ text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
116
+
117
+ with gr.Column(scale=1):
118
+ text2sound_sampling_button = gr.Button(variant="primary",
119
+ value="Generate a batch of samples and show "
120
+ "the first one",
121
+ scale=1)
122
+ text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
123
+ label="Sample index",
124
+ info="Swipe to view other samples")
125
+ with gr.Row(variant="panel"):
126
+ with gr.Column(scale=1, variant="panel"):
127
+ text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
128
+ text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
129
+ text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3)
130
+ text2sound_duration_slider = gradioWebUI.get_duration_slider()
131
+ text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
132
+ text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
133
+ text2sound_noise_variance_slider = gr.Slider(minimum=0., maximum=5., value=1., step=0.01,
134
+ label="Noise variance",
135
+ info="The larger this value, the more diversity the interpolation has.")
136
+
137
+ with gr.Column(scale=1):
138
+ with gr.Row(variant="panel"):
139
+ text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy",
140
+ height=420, scale=8)
141
+ text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy",
142
+ height=420, scale=1)
143
+ text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
144
+
145
+ with gr.Row(variant="panel"):
146
+ text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
147
+ height=200, width=100)
148
+ text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
149
+ type="numpy", height=200, width=100)
150
+
151
+ text2sound_sampling_button.click(diffusion_random_sample,
152
+ inputs=[text2sound_prompts_textbox, text2sound_negative_prompts_textbox,
153
+ text2sound_batchsize_slider,
154
+ text2sound_duration_slider,
155
+ text2sound_noise_variance_slider,
156
+ text2sound_guidance_scale_slider, text2sound_sampler_radio,
157
+ text2sound_sample_steps_slider,
158
+ text2sound_seed_textbox,
159
+ interpolation_with_text_state],
160
+ outputs=[text2sound_latent_representation_image,
161
+ text2sound_quantized_latent_representation_image,
162
+ text2sound_sampled_concatenated_spectrogram_image,
163
+ text2sound_sampled_spectrogram_image,
164
+ text2sound_sampled_audio,
165
+ text2sound_seed_textbox,
166
+ interpolation_with_text_state,
167
+ text2sound_sample_index_slider])
168
+ text2sound_sample_index_slider.change(show_random_sample,
169
+ inputs=[text2sound_sample_index_slider, interpolation_with_text_state],
170
+ outputs=[text2sound_latent_representation_image,
171
+ text2sound_quantized_latent_representation_image,
172
+ text2sound_sampled_spectrogram_image,
173
+ text2sound_sampled_audio])
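This variant interpolates in the space of the initial noise x_T (via `DiffSynthSampler.interpolate`) rather than in the condition, with the "Noise variance" slider controlling how far apart the endpoints are. A common, generic way to interpolate diffusion initial noise is spherical linear interpolation (slerp), sketched below for illustration only; the sampler's actual scheme may differ.

```python
import torch

def slerp(x0, x1, t):
    """Spherical linear interpolation between two noise tensors of the same
    shape; intermediate points keep a plausible Gaussian-like norm."""
    a, b = x0.flatten(), x1.flatten()
    omega = torch.acos(torch.clamp(torch.dot(a, b) / (a.norm() * b.norm()), -1.0, 1.0))
    so = torch.sin(omega)
    if so.abs() < 1e-8:                      # nearly parallel: fall back to lerp
        return (1.0 - t) * x0 + t * x1
    return (torch.sin((1.0 - t) * omega) / so) * x0 + (torch.sin(t * omega) / so) * x1
```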
webUI/natural_language_guided/GAN.py ADDED
@@ -0,0 +1,164 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from tools import safe_int
6
+ from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image, \
7
+ add_instrument
8
+
9
+
10
+ def get_testGAN(gradioWebUI, text2sound_state, virtual_instruments_state):
11
+ # Load configurations
12
+ gan_generator = gradioWebUI.GAN_generator
13
+ freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
14
+ VAE_scale = gradioWebUI.VAE_scale
15
+ height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels
16
+
17
+ timesteps = gradioWebUI.timesteps
18
+ VAE_quantizer = gradioWebUI.VAE_quantizer
19
+ VAE_decoder = gradioWebUI.VAE_decoder
20
+ CLAP = gradioWebUI.CLAP
21
+ CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
22
+ device = gradioWebUI.device
23
+ squared = gradioWebUI.squared
24
+ sample_rate = gradioWebUI.sample_rate
25
+ noise_strategy = gradioWebUI.noise_strategy
26
+
27
+ def gan_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
28
+ text2sound_duration,
29
+ text2sound_guidance_scale, text2sound_sampler,
30
+ text2sound_sample_steps, text2sound_seed,
31
+ text2sound_dict):
32
+ text2sound_seed = safe_int(text2sound_seed, 12345678)
33
+
34
+ width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
35
+
36
+ text2sound_batchsize = int(text2sound_batchsize)
37
+
38
+ text2sound_embedding = \
39
+ CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
40
+ device)
41
+
42
+ CFG = int(text2sound_guidance_scale)
43
+
44
+ condition = text2sound_embedding.repeat(text2sound_batchsize, 1)
45
+
46
+ noise = torch.randn(text2sound_batchsize, channels, height, width).to(device)
47
+ latent_representations = gan_generator(noise, condition)
48
+
49
+ print(latent_representations[0, 0, :3, :3])
50
+
51
+ latent_representation_gradio_images = []
52
+ quantized_latent_representation_gradio_images = []
53
+ new_sound_spectrogram_gradio_images = []
54
+ new_sound_rec_signals_gradio = []
55
+
56
+ quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
57
+ # Todo: remove hard-coding
58
+ flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations,
59
+ resolution=(512, width * VAE_scale),
60
+ centralized=False,
61
+ squared=squared)
62
+
63
+ for i in range(text2sound_batchsize):
64
+ latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
65
+ quantized_latent_representation_gradio_images.append(
66
+ latent_representation_to_Gradio_image(quantized_latent_representations[i]))
67
+ new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
68
+ new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))
69
+
70
+ text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
71
+ text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy()
72
+ text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
73
+ text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
74
+ text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
75
+ text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
76
+
77
+ text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
78
+ # text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy()
79
+ text2sound_dict["guidance_scale"] = CFG
80
+ text2sound_dict["sampler"] = text2sound_sampler
81
+
82
+ return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
83
+ text2sound_quantized_latent_representation_image:
84
+ text2sound_dict["quantized_latent_representation_gradio_images"][0],
85
+ text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0],
86
+ text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
87
+ text2sound_seed_textbox: text2sound_seed,
88
+ text2sound_state: text2sound_dict,
89
+ text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
90
+ visible=True,
91
+ label="Sample index.",
92
+ info="Swipe to view other samples")}
93
+
94
+ def show_random_sample(sample_index, text2sound_dict):
95
+ sample_index = int(sample_index)
96
+ text2sound_dict["sample_index"] = sample_index
97
+ return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
98
+ sample_index],
99
+ text2sound_quantized_latent_representation_image:
100
+ text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
101
+ text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][
102
+ sample_index],
103
+ text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}
104
+
105
+
106
+ with gr.Tab("Text2sound_GAN"):
107
+ gr.Markdown("Use neural networks to select random sounds using your favorite instrument!")
108
+ with gr.Row(variant="panel"):
109
+ with gr.Column(scale=3):
110
+ text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
111
+ text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
112
+
113
+ with gr.Column(scale=1):
114
+ text2sound_sampling_button = gr.Button(variant="primary",
115
+ value="Generate a batch of samples and show "
116
+ "the first one",
117
+ scale=1)
118
+ text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
119
+ label="Sample index",
120
+ info="Swipe to view other samples")
121
+ with gr.Row(variant="panel"):
122
+ with gr.Column(scale=1, variant="panel"):
123
+ text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
124
+ text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
125
+ text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
126
+ text2sound_duration_slider = gradioWebUI.get_duration_slider()
127
+ text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
128
+ text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
129
+
130
+ with gr.Column(scale=1):
131
+ text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=420)
132
+ text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
133
+
134
+
135
+ with gr.Row(variant="panel"):
136
+ text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
137
+ height=200, width=100)
138
+ text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
139
+ type="numpy", height=200, width=100)
140
+
141
+ text2sound_sampling_button.click(gan_random_sample,
142
+ inputs=[text2sound_prompts_textbox,
143
+ text2sound_negative_prompts_textbox,
144
+ text2sound_batchsize_slider,
145
+ text2sound_duration_slider,
146
+ text2sound_guidance_scale_slider, text2sound_sampler_radio,
147
+ text2sound_sample_steps_slider,
148
+ text2sound_seed_textbox,
149
+ text2sound_state],
150
+ outputs=[text2sound_latent_representation_image,
151
+ text2sound_quantized_latent_representation_image,
152
+ text2sound_sampled_spectrogram_image,
153
+ text2sound_sampled_audio,
154
+ text2sound_seed_textbox,
155
+ text2sound_state,
156
+ text2sound_sample_index_slider])
157
+
158
+
159
+ text2sound_sample_index_slider.change(show_random_sample,
160
+ inputs=[text2sound_sample_index_slider, text2sound_state],
161
+ outputs=[text2sound_latent_representation_image,
162
+ text2sound_quantized_latent_representation_image,
163
+ text2sound_sampled_spectrogram_image,
164
+ text2sound_sampled_audio])
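`gan_generator(noise, condition)` maps a Gaussian latent grid plus a CLAP text embedding to a latent representation in a single forward pass, with no iterative sampling. The real generator is defined in model/GAN.py; the toy skeleton below only illustrates the calling convention, with channels=4 and cond_dim=512 as assumed values.

```python
import torch
import torch.nn as nn

class ToyConditionalGenerator(nn.Module):
    """Illustrative stand-in for gan_generator: the condition is projected,
    broadcast over the spatial grid, and concatenated with the noise."""
    def __init__(self, channels=4, cond_dim=512):
        super().__init__()
        self.cond_proj = nn.Linear(cond_dim, channels)
        self.net = nn.Sequential(
            nn.Conv2d(2 * channels, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, channels, 3, padding=1),
        )

    def forward(self, noise, condition):
        cond_map = self.cond_proj(condition)[:, :, None, None].expand_as(noise)
        return self.net(torch.cat([noise, cond_map], dim=1))
```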
webUI/natural_language_guided/README.py ADDED
@@ -0,0 +1,53 @@
1
+ import gradio as gr
2
+
3
+ readme_content = """## Stable Diffusion for Sound Generation
4
+
5
+ This project applies Stable Diffusion[1] to sound generation. Inspired by the work of AUTOMATIC1111 (2022)[2], we have implemented a preliminary version of text2sound, sound2sound, and inpaint, as well as an additional interpolation feature, all accessible through a web UI.
6
+
7
+ ### Neural Network Training Data:
8
+ The neural network is trained on a filtered version of the NSynth dataset[3], a large-scale, high-quality collection of 305,979 annotated musical notes. For this project, only samples with pitch E3 were used, giving an effective training set of 4,096 samples, which makes this a low-resource setup.
9
+
10
+ The training took place on an NVIDIA Tesla T4 GPU and spanned approximately 10 hours.
11
+
12
+ ### Natural Language Guidance:
13
+ Natural language guidance is derived from the multi-label annotations of the NSynth dataset. The labels included in the training are:
14
+
15
+ - **Instrument Families**: bass, brass, flute, guitar, keyboard, mallet, organ, reed, string, synth lead, vocal.
16
+
17
+ - **Instrument Sources**: acoustic, electronic, synthetic.
18
+
19
+ - **Note Qualities**: bright, dark, distortion, fast decay, long release, multiphonic, nonlinear env, percussive, reverb, tempo-synced.
20
+
21
+ ### Usage Hints:
22
+
23
+ 1. **Prompt Format**: It's recommended to use the format "label1, label2, label3", e.g., "organ, dark, long release".
24
+
25
+ 2. **Unique Sounds**: If you keep generating the same sound, try setting a different seed!
26
+
27
+ 3. **Sample Indexing**: Drag the "Sample index slider" to view other samples within the generated batch.
28
+
29
+ 4. **Running on CPU**: Be cautious with the settings for 'batchsize' and 'sample_steps' when running on CPU to avoid timeouts. Recommended settings are batchsize ≤ 4 and sample_steps = 15.
30
+
31
+ 5. **Editing Sounds**: Generated audio can be downloaded and then re-uploaded for further editing at the sound2sound/inpaint sections.
32
+
33
+ 6. **Guidance Scale**: A higher 'guidance_scale' intensifies the influence of natural language conditioning on the generation[4]. It's recommended to set it between 3 and 10.
34
+
35
+ 7. **Noising Strength**: A smaller 'noising_strength' value makes the generated sound closer to the input sound.
36
+
37
+ References:
38
+
39
+ [1] Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (pp. 10684-10695).
40
+
41
+ [2] AUTOMATIC1111. (2022). Stable Diffusion Web UI [Computer software]. Retrieved from https://github.com/AUTOMATIC1111/stable-diffusion-webui
42
+
43
+ [3] Engel, J., Resnick, C., Roberts, A., Dieleman, S., Eck, D., Simonyan, K., & Norouzi, M. (2017). Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders.
44
+
45
+ [4] Ho, J., & Salimans, T. (2022). Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598.
46
+ """
47
+
48
+ def get_readme_module():
49
+
50
+ with gr.Tab("README"):
51
+ # gr.Markdown("Use interpolation to generate a gradient sound sequence.")
52
+ with gr.Column(scale=3):
53
+ readme_textbox = gr.Textbox(label="readme", lines=40, value=readme_content, interactive=False)
webUI/natural_language_guided/__pycache__/README.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc ADDED
Binary file (8.26 kB). View file
 
webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc ADDED
Binary file (8.08 kB). View file
 
webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc ADDED
Binary file (3.61 kB). View file
 
webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc ADDED
Binary file (3.62 kB). View file
 
webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc ADDED
Binary file (3.61 kB). View file
 
webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc ADDED
Binary file (6.11 kB). View file
 
webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc ADDED
Binary file (6.11 kB). View file