File size: 3,829 Bytes
416692d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f113387
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import torch
import numpy as np
from torchaudio.functional import resample

from processAudio import upscaleAudio

class Object:
    """Empty general-purpose class; instances accept arbitrary attributes."""

# Build the page. Components are laid out in the order they are created, so
# the statement order below is the visual order of the UI.
with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as layout:
    # Page title.
    with gr.Row():
        gr.Markdown("<h2>Broadcast Audio Upscaler</h2>")
    # Usage directions, read from a static HTML file (path is relative to the
    # working directory).
    with gr.Row():
        with open("html/directions.html", "r") as directionsHtml:
            gr.Markdown(directionsHtml.read())
    # Model picker: each choice is a [label, value] pair; the value is the
    # model filename that processAudio receives. FM Radio is the default.
    with gr.Row():
        modelSelect = gr.Dropdown(
            [
                ["FM Radio Super Resolution","FM_Radio_SR.th"],
                ["AM Radio Super Resolution (Beta v2)","AM_Radio_SR.th"],
                ["Telephone Super Resolution (Beta)","Telephone_SR.th"] 
            ],
            label="Select Model:",
            value="FM_Radio_SR.th",
        )
    # Left column: upload-only audio input. Right column: result player with
    # no input sources. Both are capped at max_length=360 (the 6 minutes
    # mentioned in the input label).
    with gr.Row():
        with gr.Column():
            audioFileSelect = gr.Audio(label="Audio File (Mono or Stereo, Max 6 Minutes):",sources="upload", max_length=360)
        with gr.Column():
            audioOutput = gr.Audio(show_download_button=True, label="Restored Audio:", sources=[], max_length=360)
    # The button starts disabled; audioFileSelectChanged enables it once a
    # mono or stereo file has been loaded.
    with gr.Row():
        with gr.Column():
            submit = gr.Button("Process Audio", variant="primary", interactive=False)
    # Collapsed-by-default section with extra information from a static file.
    with gr.Row():
        with gr.Accordion("More Information:", open=False):
            with open("html/information.html", "r") as informationHtml:
                gr.Markdown(informationHtml.read())    

    @audioFileSelect.input(inputs=audioFileSelect, outputs=[submit, audioFileSelect])
    def audioFileSelectChanged(audioData: gr.Audio):
        """Gate the submit button on the uploaded audio being mono or stereo.

        ``audioData`` is None when nothing is loaded; otherwise it is indexed
        as ``(sample_rate, samples)`` where a 1-D ``samples`` array is mono and
        a 2-D one carries its channel count in ``shape[1]``.

        Returns a pair of updates for (submit button, audio input): the button
        is enabled and the audio kept for mono/stereo input; otherwise the
        button is disabled and the audio input cleared. Audio with more than
        two channels additionally shows a warning.
        """
        if audioData is not None:
            samples = audioData[1]
            channelCount = 1 if len(samples.shape) == 1 else samples.shape[1]
            if channelCount <= 2:
                return gr.update(interactive=True), audioData
            gr.Warning("Audio with more than 2 channels is not supported.")
        # Nothing loaded, or the upload was rejected above.
        return gr.update(interactive=False), None
        

    @submit.click(inputs=[modelSelect, audioFileSelect], outputs=audioOutput)
    def processAudio(model: gr.Dropdown, audioData: gr.Audio):
        """Run the selected model over the uploaded audio and return the result.

        Args:
            model: Value of the model dropdown, i.e. a model filename such as
                "FM_Radio_SR.th". (The annotation names the component; the
                value received at runtime is the dropdown's selected value.)
            audioData: None when no file is loaded, otherwise a
                ``(sample_rate, samples)`` pair where ``samples`` is a NumPy
                array shaped ``(n,)`` for mono or ``(n, channels)`` otherwise.

        Returns:
            ``(44100, samples)`` with ``samples`` as int16, transposed from the
            model output so that the first axis is presumably time — assumes
            ``upscaleAudio`` returns ``(channels, n)``; TODO confirm.

        Raises:
            gr.Error: If no audio is loaded or it has more than two channels.
        """
        targetRate = 44100

        if audioData is None:
            raise gr.Error("Load an audio file.")
        sampleRate = audioData[0]
        samples = audioData[1]
        if samples.ndim != 1 and samples.shape[1] > 2:
            raise gr.Error("Audio with more than 2 channels is not supported.")

        # Scale signed-integer PCM by the full scale of its own dtype so the
        # model sees roughly [-1, 1). For int16 this is the original 32768.
        # NOTE(review): gr.Audio appears to deliver wider integers (e.g. int32)
        # for 24/32-bit files, which the fixed divisor left far out of range.
        # Any other dtype keeps the original divisor unchanged.
        if np.issubdtype(samples.dtype, np.signedinteger):
            fullScale = np.float32(int(np.iinfo(samples.dtype).max) + 1)
        else:
            fullScale = np.float32(32768)
        normalized = samples.astype(np.float32) / fullScale

        if normalized.ndim == 1:
            # Mono: duplicate the single channel to get [2, samples].
            lrAudio = torch.tensor(np.stack([normalized, normalized]))
        else:
            # Re-order channel data from [samples, channels] to [channels, samples].
            lrAudio = torch.tensor(normalized).transpose(0, 1)

        if sampleRate != targetRate:
            lrAudio = resample(lrAudio, sampleRate, targetRate)

        model_name, experiment_file = getModelInfo(model)
        hrAudio = upscaleAudio(lrAudio, model, model_name=model_name, experiment_file=experiment_file)

        # Attenuate only when the peak exceeds full scale (divisor is >= 1), so
        # the int16 conversion below cannot overflow.
        peak = max(hrAudio.abs().max().item(), 1)
        hrAudio = hrAudio / peak
        outAudio = (hrAudio * 32767).numpy().astype(np.int16).transpose(1, 0)
        return targetRate, outAudio
    
    def getModelInfo(modelFilename: str):
        """Return the ``(model_name, experiment_file)`` pair for a model file.

        Every known model file currently maps to the same pair, which is also
        what an unrecognised filename falls back to.
        """
        fallback = ("aero", "aero_441-441_512_256.yaml")
        # Ordered (filename, info) pairs; compared with == rather than hashed
        # so any argument type behaves as it would in a chain of comparisons.
        knownModels = (
            ("FM_Radio_SR.th", ("aero", "aero_441-441_512_256.yaml")),
            ("AM_Radio_SR.th", ("aero", "aero_441-441_512_256.yaml")),
            ("Telephone_SR.th", ("aero", "aero_441-441_512_256.yaml")),
        )
        for filename, info in knownModels:
            if modelFilename == filename:
                return info
        return fallback

# Launch the Gradio app for the `layout` Blocks built above. This runs at
# module level, so it executes whenever the file is run or imported.
layout.launch()