Spaces:
Running
Running
import gradio as gr | |
class GradioWebUI(): | |
def __init__(self, device, VAE, uNet, CLAP, CLAP_tokenizer, | |
freq_resolution=512, time_resolution=256, channels=4, timesteps=1000, | |
sample_rate=16000, squared=False, VAE_scale=4, | |
flexible_duration=False, noise_strategy="repeat", | |
GAN_generator = None): | |
self.device = device | |
self.VAE_encoder, self.VAE_quantizer, self.VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder | |
self.uNet = uNet | |
self.CLAP, self.CLAP_tokenizer = CLAP, CLAP_tokenizer | |
self.freq_resolution, self.time_resolution = freq_resolution, time_resolution | |
self.channels = channels | |
self.GAN_generator = GAN_generator | |
self.timesteps = timesteps | |
self.sample_rate = sample_rate | |
self.squared = squared | |
self.VAE_scale = VAE_scale | |
self.flexible_duration = flexible_duration | |
self.noise_strategy = noise_strategy | |
self.text2sound_state = gr.State(value={}) | |
self.interpolation_state = gr.State(value={}) | |
self.sound2sound_state = gr.State(value={}) | |
self.inpaint_state = gr.State(value={}) | |
def get_sample_steps_slider(self): | |
default_steps = 10 if (self.device == "cpu") else 20 | |
return gr.Slider(minimum=10, maximum=100, value=default_steps, step=1, | |
label="Sample steps", | |
info="Sampling steps. The more sampling steps, the better the " | |
"theoretical result, but the time it consumes.") | |
def get_sampler_radio(self): | |
# return gr.Radio(choices=["ddpm", "ddim", "dpmsolver++", "dpmsolver"], value="ddim", label="Sampler") | |
return gr.Radio(choices=["ddpm", "ddim"], value="ddim", label="Sampler") | |
def get_batchsize_slider(self, cpu_batchsize=1): | |
return gr.Slider(minimum=1., maximum=16, value=cpu_batchsize if (self.device == "cpu") else 8, step=1, label="Batchsize") | |
def get_time_resolution_slider(self): | |
return gr.Slider(minimum=16., maximum=int(1024/self.VAE_scale), value=int(256/self.VAE_scale), step=1, label="Time resolution", interactive=True) | |
def get_duration_slider(self): | |
if self.flexible_duration: | |
return gr.Slider(minimum=0.25, maximum=8., value=3., step=0.01, label="duration in sec") | |
else: | |
return gr.Slider(minimum=1., maximum=8., value=3., step=1., label="duration in sec") | |
def get_guidance_scale_slider(self): | |
return gr.Slider(minimum=0., maximum=20., value=6., step=1., | |
label="Guidance scale", | |
info="The larger this value, the more the generated sound is " | |
"influenced by the condition. Setting it to 0 is equivalent to " | |
"the negative case.") | |
def get_noising_strength_slider(self, default_noising_strength=0.7): | |
return gr.Slider(minimum=0.0, maximum=1.00, value=default_noising_strength, step=0.01, | |
label="noising strength", | |
info="The smaller this value, the more the generated sound is " | |
"closed to the origin.") | |
def get_seed_textbox(self): | |
return gr.Textbox(label="Seed", lines=1, placeholder="seed", value=0) | |