WeixuanYuan committed
Commit 8ab6976 · verified · 1 Parent(s): 73b5f87

Update webUI/natural_language_guided_4/text2sound.py

webUI/natural_language_guided_4/text2sound.py CHANGED
@@ -1,220 +1,220 @@
 import gradio as gr
 import numpy as np

 from model.DiffSynthSampler import DiffSynthSampler
 from tools import safe_int
 from webUI.natural_language_guided_4.utils import latent_representation_to_Gradio_image, \
     encodeBatch2GradioOutput_STFT, add_instrument, resize_image_to_aspect_ratio


 def get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state):
     # Load configurations
     uNet = gradioWebUI.uNet
     freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
     VAE_scale = gradioWebUI.VAE_scale
     height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

     timesteps = gradioWebUI.timesteps
     VAE_quantizer = gradioWebUI.VAE_quantizer
     VAE_decoder = gradioWebUI.VAE_decoder
     CLAP = gradioWebUI.CLAP
     CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
     device = gradioWebUI.device
     squared = gradioWebUI.squared
     sample_rate = gradioWebUI.sample_rate
     noise_strategy = gradioWebUI.noise_strategy

     def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                                 text2sound_duration,
                                 text2sound_guidance_scale, text2sound_sampler,
                                 text2sound_sample_steps, text2sound_seed,
                                 text2sound_dict):
         text2sound_sample_steps = int(text2sound_sample_steps)
         text2sound_seed = safe_int(text2sound_seed, 12345678)

         width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)

         text2sound_batchsize = int(text2sound_batchsize)

         text2sound_embedding = \
             CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
                 device)

         CFG = int(text2sound_guidance_scale)

         mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
         negative_condition = \
             CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[
                 0]

         mySampler.activate_classifier_free_guidance(CFG, negative_condition.to(device))

         mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

         condition = text2sound_embedding.repeat(text2sound_batchsize, 1)

         latent_representations, initial_noise = \
             mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
                              return_tensor=True, condition=condition, sampler=text2sound_sampler)

         latent_representations = latent_representations[-1]

         latent_representation_gradio_images = []
         quantized_latent_representation_gradio_images = []
         new_sound_spectrogram_gradio_images = []
         new_sound_phase_gradio_images = []
         new_sound_rec_signals_gradio = []

         quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations)
         # Todo: remove hard-coding
         flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder,
                                                                                                     quantized_latent_representations,
                                                                                                     resolution=(
                                                                                                         512,
                                                                                                         width * VAE_scale),
                                                                                                     original_STFT_batch=None
                                                                                                     )

         for i in range(text2sound_batchsize):
             latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i]))
             quantized_latent_representation_gradio_images.append(
                 latent_representation_to_Gradio_image(quantized_latent_representations[i]))
             new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i])
             new_sound_phase_gradio_images.append(flipped_phases[i])
             new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i]))

         text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images
         text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images
         text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
         text2sound_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
         text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio

         # save instrument
         text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
         text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to(
             "cpu").detach().numpy()
         text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
         text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy()
         text2sound_dict["guidance_scale"] = CFG
         text2sound_dict["sampler"] = text2sound_sampler

         return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0],
                 text2sound_quantized_latent_representation_image:
                     text2sound_dict["quantized_latent_representation_gradio_images"][0],
                 text2sound_sampled_spectrogram_image: resize_image_to_aspect_ratio(
                     text2sound_dict["new_sound_spectrogram_gradio_images"][0],
                     1.55,
                     1),
                 text2sound_sampled_phase_image: resize_image_to_aspect_ratio(
                     text2sound_dict["new_sound_phase_gradio_images"][0],
                     1.55,
                     1),
                 text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0],
                 text2sound_seed_textbox: text2sound_seed,
                 text2sound_state: text2sound_dict,
                 text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                           visible=True,
                                                           label="Sample index.",
                                                           info="Swipe to view other samples")}

     def show_random_sample(sample_index, text2sound_dict):
         sample_index = int(sample_index)
         text2sound_dict["sample_index"] = sample_index
         print(text2sound_dict["new_sound_rec_signals_gradio"][sample_index])
         return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][
             sample_index],
                 text2sound_quantized_latent_representation_image:
                     text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
                 text2sound_sampled_spectrogram_image: resize_image_to_aspect_ratio(
                     text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index], 1.55, 1),
                 text2sound_sampled_phase_image: resize_image_to_aspect_ratio(text2sound_dict["new_sound_phase_gradio_images"][
                     sample_index], 1.55, 1),
                 text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]}

     def save_virtual_instrument(sample_index, virtual_instrument_name, text2sound_dict, virtual_instruments_dict):
         virtual_instruments_dict = add_instrument(text2sound_dict, virtual_instruments_dict, virtual_instrument_name,
                                                   sample_index)

         return {virtual_instruments_state: virtual_instruments_dict,
                 text2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
                                                                placeholder=f"Saved as {virtual_instrument_name}!")}

     with gr.Tab("Text2sound"):
         gr.Markdown("Use neural networks to select random sounds using your favorite instrument!")
         with gr.Row(variant="panel"):
             with gr.Column(scale=3):
-                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
+                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="string")
                 text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

             with gr.Column(scale=1):
                 text2sound_sampling_button = gr.Button(variant="primary",
                                                        value="Generate a batch of samples and show "
                                                              "the first one",
                                                        scale=1)
                 text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                            label="Sample index",
                                                            info="Swipe to view other samples")
         with gr.Row(variant="panel"):
             with gr.Column(variant="panel", scale=1):
                 text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                 text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                 text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                 text2sound_duration_slider = gradioWebUI.get_duration_slider()
                 text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                 text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

             with gr.Column(variant="panel", scale=1):
                 with gr.Row(variant="panel", ):
                     text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", )
                     text2sound_sampled_phase_image = gr.Image(label="Sampled phase", type="numpy")
                     text2sound_sampled_audio = gr.Audio(type="numpy", label="Play",
                                                         scale=1)

                 with gr.Row(variant="panel", ):
                     text2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=2,
                                                                     placeholder="Name of your instrument",
                                                                     scale=1)
                     text2sound_save_instrument_button = gr.Button(variant="primary",
                                                                   value="Save instrument",
                                                                   scale=1)

         with gr.Row(variant="panel"):
             text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
                                                               height=200, width=100, visible=False)
             text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                         type="numpy", height=200, width=100,
                                                                         visible=False)

     text2sound_sampling_button.click(diffusion_random_sample,
                                      inputs=[text2sound_prompts_textbox,
                                              text2sound_negative_prompts_textbox,
                                              text2sound_batchsize_slider,
                                              text2sound_duration_slider,
                                              text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                              text2sound_sample_steps_slider,
                                              text2sound_seed_textbox,
                                              text2sound_state],
                                      outputs=[text2sound_latent_representation_image,
                                               text2sound_quantized_latent_representation_image,
                                               text2sound_sampled_spectrogram_image,
                                               text2sound_sampled_phase_image,
                                               text2sound_sampled_audio,
                                               text2sound_seed_textbox,
                                               text2sound_state,
                                               text2sound_sample_index_slider])

     text2sound_save_instrument_button.click(save_virtual_instrument,
                                             inputs=[text2sound_sample_index_slider,
                                                     text2sound_instrument_name_textbox,
                                                     text2sound_state,
                                                     virtual_instruments_state],
                                             outputs=[virtual_instruments_state,
                                                      text2sound_instrument_name_textbox])

     text2sound_sample_index_slider.change(show_random_sample,
                                           inputs=[text2sound_sample_index_slider, text2sound_state],
                                           outputs=[text2sound_latent_representation_image,
                                                    text2sound_quantized_latent_representation_image,
                                                    text2sound_sampled_spectrogram_image,
                                                    text2sound_sampled_phase_image,
                                                    text2sound_sampled_audio])
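For orientation, get_text2sound_module builds the "Text2sound" tab and registers its callbacks, so it has to be called inside a parent Gradio Blocks context that also supplies the two gr.State objects it writes to. A minimal wiring sketch follows; the build_demo helper and the empty-dict state initialisation are illustrative assumptions and are not part of this commit.

import gradio as gr

from webUI.natural_language_guided_4.text2sound import get_text2sound_module


def build_demo(gradioWebUI):
    # Hypothetical assembly of the web UI; only the text2sound tab is shown here.
    with gr.Blocks() as demo:
        text2sound_state = gr.State(value={})            # populated by diffusion_random_sample
        virtual_instruments_state = gr.State(value={})   # populated by save_virtual_instrument
        get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state)
    return demo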