import gradio as gr


class GradioWebUI:
    """Hold the models and settings used by the Gradio web UI and build its shared widgets.

    The constructor only stores its arguments (splitting ``VAE`` into encoder,
    quantizer and decoder) and creates four ``gr.State`` holders. Each
    ``get_*`` method returns a freshly constructed Gradio component whose
    defaults depend on the stored configuration.
    """

    def __init__(self, device, VAE, uNet, CLAP, CLAP_tokenizer, freq_resolution=512, time_resolution=256,
                 channels=4, timesteps=1000, sample_rate=16000, squared=False, VAE_scale=4,
                 flexible_duration=False, noise_strategy="repeat", GAN_generator=None):
        """Store models and configuration for later use by the UI.

        Args:
            device: Compute device. Its string form is compared with ``"cpu"``
                to choose lighter widget defaults (fewer sample steps, smaller
                batch size).
            VAE: Object exposing ``_encoder``, ``_vq_vae`` and ``_decoder``
                attributes; they are stored as ``VAE_encoder``,
                ``VAE_quantizer`` and ``VAE_decoder``.
            uNet: Stored as ``self.uNet`` (not used by the methods in this class).
            CLAP: Stored as ``self.CLAP``.
            CLAP_tokenizer: Stored as ``self.CLAP_tokenizer``.
            freq_resolution: Stored as-is.
            time_resolution: Stored as-is.
            channels: Stored as-is.
            timesteps: Stored as-is.
            sample_rate: Stored as-is.
            squared: Stored as-is.
            VAE_scale: Divisor applied to the time-resolution slider's range
                and default.
            flexible_duration: If True, the duration slider allows fractional
                seconds; otherwise whole seconds only.
            noise_strategy: Stored as-is.
            GAN_generator: Optional; stored as-is.
        """
        self.device = device
        self.VAE_encoder = VAE._encoder
        self.VAE_quantizer = VAE._vq_vae
        self.VAE_decoder = VAE._decoder
        self.uNet = uNet
        self.CLAP = CLAP
        self.CLAP_tokenizer = CLAP_tokenizer
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.channels = channels
        self.GAN_generator = GAN_generator
        self.timesteps = timesteps
        self.sample_rate = sample_rate
        self.squared = squared
        self.VAE_scale = VAE_scale
        self.flexible_duration = flexible_duration
        self.noise_strategy = noise_strategy

        # One per-session state holder per UI feature, each starting as an empty dict.
        # NOTE(review): these are created outside any visible `gr.Blocks` context;
        # confirm they are rendered/attached wherever they are used as event inputs.
        self.text2sound_state = gr.State(value={})
        self.interpolation_state = gr.State(value={})
        self.sound2sound_state = gr.State(value={})
        self.inpaint_state = gr.State(value={})

    def _on_cpu(self):
        """Return True when ``self.device`` refers to the CPU.

        The string form is compared so that both the plain string ``"cpu"``
        and a device object whose ``str()`` is ``"cpu"`` (for example
        ``torch.device("cpu")``) are recognised.
        """
        return str(self.device) == "cpu"

    def get_sample_steps_slider(self):
        """Return a slider for the number of sampling steps (10-100).

        Defaults to 10 steps on CPU and 20 otherwise.
        """
        default_steps = 10 if self._on_cpu() else 20
        return gr.Slider(minimum=10, maximum=100, value=default_steps, step=1, label="Sample steps",
                         info="Sampling steps. The more sampling steps, the better the "
                              "theoretical result, but the more time it consumes.")

    def get_sampler_radio(self):
        """Return a radio selector for the sampler; only "ddpm" and "ddim" are offered, default "ddim"."""
        return gr.Radio(choices=["ddpm", "ddim"], value="ddim", label="Sampler")

    def get_batchsize_slider(self, cpu_batchsize=1):
        """Return a batch-size slider ranging from 1 to 16.

        Args:
            cpu_batchsize: Default value used when running on CPU; on any
                other device the default is 8.
        """
        default_batchsize = cpu_batchsize if self._on_cpu() else 8
        return gr.Slider(minimum=1., maximum=16, value=default_batchsize, step=1, label="Batchsize")

    def get_time_resolution_slider(self):
        """Return an interactive slider for the time resolution.

        The maximum is ``int(1024 / VAE_scale)`` and the default is
        ``int(256 / VAE_scale)`` (256 and 64 for the default ``VAE_scale=4``).
        The minimum is fixed at 16.
        """
        return gr.Slider(minimum=16., maximum=int(1024 / self.VAE_scale), value=int(256 / self.VAE_scale),
                         step=1, label="Time resolution", interactive=True)

    def get_duration_slider(self):
        """Return a duration slider (seconds, default 3, maximum 8).

        With ``flexible_duration`` the range starts at 0.25 s with 0.01 s
        steps; otherwise it starts at 1 s with whole-second steps.
        """
        if self.flexible_duration:
            return gr.Slider(minimum=0.25, maximum=8., value=3., step=0.01, label="duration in sec")
        return gr.Slider(minimum=1., maximum=8., value=3., step=1., label="duration in sec")

    def get_guidance_scale_slider(self):
        """Return a guidance-scale slider (0-20, default 6, integer steps)."""
        return gr.Slider(minimum=0., maximum=20., value=6., step=1., label="Guidance scale",
                         info="The larger this value, the more the generated sound is "
                              "influenced by the condition. Setting it to 0 is equivalent to "
                              "the negative case.")

    def get_noising_strength_slider(self, default_noising_strength=0.7):
        """Return a noising-strength slider in [0, 1] with 0.01 steps.

        Args:
            default_noising_strength: Initial slider value.
        """
        return gr.Slider(minimum=0.0, maximum=1.00, value=default_noising_strength, step=0.01,
                         label="noising strength",
                         info="The smaller this value, the closer the generated sound "
                              "is to the original.")

    def get_seed_textbox(self):
        """Return a single-line textbox for the random seed, defaulting to "0".

        The default is given as a string because a Textbox holds text.
        """
        return gr.Textbox(label="Seed", lines=1, placeholder="seed", value="0")