import gradio as gr
import torch
import numpy as np
from torchaudio.functional import resample
from processAudio import upscaleAudio
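
# Gradio front end for the Broadcast Audio Upscaler: validates an uploaded
# mono or stereo clip, runs it through the selected super-resolution
# checkpoint via processAudio.upscaleAudio, and returns restored 44.1 kHz audio.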
# Empty placeholder class (not referenced in this file).
class Object(object):
    pass

with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as layout:
    with gr.Row():
        gr.Markdown("<h2>Broadcast Audio Upscaler</h2>")
    with gr.Row():
        with open("html/directions.html", "r") as directionsHtml:
            gr.Markdown(directionsHtml.read())
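    # Each dropdown choice is a [display label, checkpoint filename] pair; the
    # selected value is the filename handed to upscaleAudio in processAudio().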
    with gr.Row():
        modelSelect = gr.Dropdown(
            [
                ["FM Radio Super Resolution", "FM_Radio_SR.th"],
                ["AM Radio Super Resolution (Beta)", "AM_Radio_SR.th"],
            ],
            label="Select Model:",
            value="FM_Radio_SR.th",
        )
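    # Uploads are capped at 6 minutes; the output widget accepts no sources,
    # so it only plays back (and offers for download) the restored result.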
    with gr.Row():
        with gr.Column():
            audioFileSelect = gr.Audio(label="Audio File (Mono or Stereo, Max 6 Minutes):", sources="upload", max_length=360)
        with gr.Column():
            audioOutput = gr.Audio(show_download_button=True, label="Restored Audio:", sources=[], max_length=360)
    with gr.Row():
        with gr.Column():
            # Starts disabled; audioFileSelectChanged enables it once a valid file is loaded.
            submit = gr.Button("Process Audio", variant="primary", interactive=False)
    with gr.Row():
        with gr.Accordion("More Information:", open=False):
            with open("html/information.html", "r") as informationHtml:
                gr.Markdown(informationHtml.read())
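
    # gr.Audio delivers its value as a (sample_rate, numpy array) tuple, or None
    # when the file is cleared; enable the Process button only for mono or
    # stereo input.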
    @audioFileSelect.input(inputs=audioFileSelect, outputs=[submit, audioFileSelect])
    def audioFileSelectChanged(audioData: tuple | None):
        if audioData is None:
            return gr.update(interactive=False), None
        if len(audioData[1].shape) == 1:  # Mono
            return gr.update(interactive=True), audioData
        if audioData[1].shape[1] > 2:
            gr.Warning("Audio with more than 2 channels is not supported.")
            return gr.update(interactive=False), None
        return gr.update(interactive=True), audioData  # Stereo
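
    # Processing pipeline: convert the upload to a float32 [2, samples] tensor
    # in [-1, 1], resample to 44.1 kHz if needed, run the selected model, and
    # return peak-limited 16-bit PCM.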
    @submit.click(inputs=[modelSelect, audioFileSelect], outputs=audioOutput)
    def processAudio(model: str, audioData: tuple | None):
        if audioData is None:
            raise gr.Error("Load an audio file.")
        elif len(audioData[1].shape) == 1:  # Duplicate the mono channel to form a stereo pair
            lrAudio = torch.tensor(np.array([
                audioData[1].copy().astype(np.float32) / 32768,
                audioData[1].copy().astype(np.float32) / 32768,
            ]))
        elif audioData[1].shape[1] > 2:
            raise gr.Error("Audio with more than 2 channels is not supported.")
        else:  # Re-order channel data from [samples, 2] to [2, samples]
            lrAudio = torch.tensor(audioData[1].copy().astype(np.float32) / 32768).transpose(0, 1)
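        # The models operate at 44.1 kHz; resample any other input rate.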
        if audioData[0] != 44100:
            lrAudio = resample(lrAudio, audioData[0], 44100)
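        # Divide by the peak only when it exceeds full scale, so quiet output
        # is left alone but clipping is prevented.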
        hrAudio = upscaleAudio(lrAudio, "models/" + model)
        hrAudio = hrAudio / max(hrAudio.abs().max().item(), 1)
        outAudio = (hrAudio * 32767).numpy().astype(np.int16).transpose(1, 0)  # Back to [samples, 2] int16
        return (44100, outAudio)

layout.launch()