# Hosting-page header (not code): uploaded by WeixuanYuan - "Upload 70 files", commit bd6e54b (verified).
import gradio as gr
class GradioWebUI():
    """Container for model handles, audio settings and Gradio widget factories.

    The constructor stores the objects and settings it receives and creates
    four ``gr.State`` holders (text2sound, interpolation, sound2sound and
    inpaint), each initialised with an empty dict. Every ``get_*`` method
    builds and returns a new Gradio input component with preset bounds and
    defaults; some defaults depend on ``device``, ``VAE_scale`` or
    ``flexible_duration``.
    """

    def __init__(self, device, VAE, uNet, CLAP, CLAP_tokenizer,
                 freq_resolution=512, time_resolution=256, channels=4, timesteps=1000,
                 sample_rate=16000, squared=False, VAE_scale=4,
                 flexible_duration=False, noise_strategy="repeat",
                 GAN_generator=None):
        """Store models and settings and create the per-workflow state holders.

        Args:
            device: Compute device identifier. Only used in this class to pick
                lighter slider defaults when it denotes the CPU.
            VAE: Object exposing ``_encoder``, ``_vq_vae`` and ``_decoder``
                attributes; the three parts are stored separately.
            uNet: Stored unchanged (presumably the diffusion network --
                TODO confirm against callers).
            CLAP: Stored unchanged (presumably the text/audio encoder --
                TODO confirm against callers).
            CLAP_tokenizer: Stored unchanged alongside ``CLAP``.
            freq_resolution: Stored unchanged; not interpreted in this class.
            time_resolution: Stored unchanged; not interpreted in this class.
            channels: Stored unchanged; not interpreted in this class.
            timesteps: Stored unchanged; not interpreted in this class.
            sample_rate: Stored unchanged; not interpreted in this class.
            squared: Stored unchanged; not interpreted in this class.
            VAE_scale: Divisor applied to 1024 and 256 to derive the maximum
                and default of the time-resolution slider.
            flexible_duration: When true, the duration slider allows
                fractional values; otherwise it moves in whole steps.
            noise_strategy: Stored unchanged; not interpreted in this class.
            GAN_generator: Optional object stored unchanged.
        """
        self.device = device

        # Keep the three VAE parts as separate attributes.
        self.VAE_encoder = VAE._encoder
        self.VAE_quantizer = VAE._vq_vae
        self.VAE_decoder = VAE._decoder

        self.uNet = uNet
        self.CLAP = CLAP
        self.CLAP_tokenizer = CLAP_tokenizer
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.channels = channels
        self.GAN_generator = GAN_generator

        self.timesteps = timesteps
        self.sample_rate = sample_rate
        self.squared = squared
        self.VAE_scale = VAE_scale
        self.flexible_duration = flexible_duration
        self.noise_strategy = noise_strategy

        # One independent state holder per workflow, each starting empty.
        self.text2sound_state = gr.State(value={})
        self.interpolation_state = gr.State(value={})
        self.sound2sound_state = gr.State(value={})
        self.inpaint_state = gr.State(value={})

    def _runs_on_cpu(self):
        """Return True when ``self.device`` denotes a CPU device.

        The comparison is done on the string form so that plain strings
        ("cpu"), indexed forms ("cpu:0") and device objects whose ``str()``
        yields such a value are all recognised. A direct ``== "cpu"`` check
        only matches the exact plain string.
        """
        device_type = str(self.device).split(":", 1)[0]
        return device_type == "cpu"

    def get_sample_steps_slider(self):
        """Build the "Sample steps" slider (10-100; default 10 on CPU, else 20)."""
        default_steps = 10 if self._runs_on_cpu() else 20
        return gr.Slider(
            minimum=10,
            maximum=100,
            value=default_steps,
            step=1,
            label="Sample steps",
            info="Sampling steps. The more sampling steps, the better the "
                 "theoretical result, but the more time it consumes.",
        )

    def get_sampler_radio(self):
        """Build the "Sampler" radio group offering "ddpm" and "ddim" (default)."""
        # NOTE: "dpmsolver++" and "dpmsolver" were part of an earlier choice
        # list and are currently not offered.
        return gr.Radio(choices=["ddpm", "ddim"], value="ddim", label="Sampler")

    def get_batchsize_slider(self, cpu_batchsize=1):
        """Build the "Batchsize" slider (1-16).

        Args:
            cpu_batchsize: Default value used when running on the CPU; on any
                other device the default is 8.
        """
        default_batchsize = cpu_batchsize if self._runs_on_cpu() else 8
        return gr.Slider(minimum=1., maximum=16, value=default_batchsize, step=1,
                         label="Batchsize")

    def get_time_resolution_slider(self):
        """Build the interactive "Time resolution" slider.

        The upper bound is ``int(1024 / VAE_scale)`` and the default is
        ``int(256 / VAE_scale)``; the lower bound is fixed at 16.
        """
        upper_bound = int(1024 / self.VAE_scale)
        default_value = int(256 / self.VAE_scale)
        return gr.Slider(minimum=16., maximum=upper_bound, value=default_value, step=1,
                         label="Time resolution", interactive=True)

    def get_duration_slider(self):
        """Build the "duration in sec" slider (default 3, maximum 8).

        With ``flexible_duration`` the slider starts at 0.25 and moves in
        0.01 steps; otherwise it starts at 1 and moves in whole steps.
        """
        if self.flexible_duration:
            return gr.Slider(minimum=0.25, maximum=8., value=3., step=0.01,
                             label="duration in sec")
        return gr.Slider(minimum=1., maximum=8., value=3., step=1.,
                         label="duration in sec")

    def get_guidance_scale_slider(self):
        """Build the "Guidance scale" slider (0-20, step 1, default 6)."""
        return gr.Slider(
            minimum=0.,
            maximum=20.,
            value=6.,
            step=1.,
            label="Guidance scale",
            info="The larger this value, the more the generated sound is "
                 "influenced by the condition. Setting it to 0 is equivalent to "
                 "the negative case.",
        )

    def get_noising_strength_slider(self, default_noising_strength=0.7):
        """Build the "noising strength" slider (0-1, step 0.01).

        Args:
            default_noising_strength: Initial slider value.
        """
        return gr.Slider(
            minimum=0.0,
            maximum=1.00,
            value=default_noising_strength,
            step=0.01,
            label="noising strength",
            info="The smaller this value, the more the generated sound is "
                 "close to the original.",
        )

    def get_seed_textbox(self):
        """Build the single-line "Seed" textbox with an initial value of 0."""
        return gr.Textbox(label="Seed", lines=1, placeholder="seed", value=0)