Add phone model (beta), allow models to use different architectures
- app.py +13 -2
- models/Telephone_SR.th +3 -0
- processAudio.py +5 -5
- src/models/modelFactory.py +3 -3
app.py
CHANGED
@@ -18,7 +18,8 @@ with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as
     modelSelect = gr.Dropdown(
         [
             ["FM Radio Super Resolution","FM_Radio_SR.th"],
-            ["AM Radio Super Resolution (Beta)","AM_Radio_SR.th"]
+            ["AM Radio Super Resolution (Beta)","AM_Radio_SR.th"],
+            ["Telephone Super Resolution (Beta)","Telephone_SR.th"]
         ],
         label="Select Model:",
         value="FM_Radio_SR.th",
@@ -66,9 +67,19 @@
     lrAudio = torch.tensor(audioData[1].copy().astype(np.float32)/32768).transpose(0,1)
     if audioData[0] != 44100:
         lrAudio = resample(lrAudio, audioData[0], 44100)
-    hrAudio=upscaleAudio(lrAudio, model)
+    model_name, experiment_file = getModelInfo(model)
+    hrAudio=upscaleAudio(lrAudio, model, model_name=model_name, experiment_file=experiment_file)
     hrAudio=hrAudio / max(hrAudio.abs().max().item(), 1)
     outAudio=(hrAudio*32767).numpy().astype(np.int16).transpose(1,0)
     return tuple([44100, outAudio])
+
+def getModelInfo(modelFilename: str):
+    if(modelFilename == "FM_Radio_SR.th"):
+        return "aero", "aero_441-441_512_256.yaml"
+    if(modelFilename == "AM_Radio_SR.th"):
+        return "aero", "aero_441-441_512_256.yaml"
+    if(modelFilename == "Telephone_SR.th"):
+        return "aero", "aero_441-441_512_256.yaml"
+    return "aero", "aero_441-441_512_256.yaml"
 
 layout.launch()
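Note: every dropdown entry in this commit resolves to the same ("aero", "aero_441-441_512_256.yaml") pair, so the chained if blocks in getModelInfo could equally be a table lookup. A minimal sketch of that alternative (hypothetical, not part of the commit; MODEL_INFO is an invented name):

# Hypothetical refactor of getModelInfo: one dict maps each checkpoint
# filename to its (architecture, experiment config) pair.
MODEL_INFO = {
    "FM_Radio_SR.th": ("aero", "aero_441-441_512_256.yaml"),
    "AM_Radio_SR.th": ("aero", "aero_441-441_512_256.yaml"),
    "Telephone_SR.th": ("aero", "aero_441-441_512_256.yaml"),
}

def getModelInfo(modelFilename: str):
    # Unknown filenames fall back to the aero defaults, matching the
    # final return statement in app.py.
    return MODEL_INFO.get(modelFilename, ("aero", "aero_441-441_512_256.yaml"))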
models/Telephone_SR.th
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b59e32fccaf83c7e038b8c5e894eeebbce272e9ab00db6b20d45b2fba6e911ca
+size 136533968
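The checkpoint is stored with Git LFS, so the file added here is only a pointer: per the LFS v1 spec, oid is the SHA-256 of the actual file contents and size is its byte count. A minimal sketch for verifying a fetched checkpoint against the pointer (the helper name is an invented example):

import hashlib

def verify_lfs_object(path: str, expected_oid: str) -> bool:
    # Hash the file in 1 MiB chunks and compare against the pointer's oid.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid

verify_lfs_object(
    "models/Telephone_SR.th",
    "b59e32fccaf83c7e038b8c5e894eeebbce272e9ab00db6b20d45b2fba6e911ca",
)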
processAudio.py
CHANGED
@@ -20,9 +20,9 @@ SEGMENT_DURATION_SEC = 5
 SEGMENT_OVERLAP_RATIO = 0.25
 SERIALIZE_KEY_STATE = 'state'
 
-def _load_model(
-    checkpoint_file = Path(
-    model = modelFactory.get_model(model_name)['generator']
+def _load_model(checkpoint_filename="FM_Radio_SR.th",model_name="aero",experiment_file="aero_441-441_512_256.yaml"):
+    checkpoint_file = Path("models/" + checkpoint_filename)
+    model = modelFactory.get_model(model_name,experiment_file)['generator']
     package = torch.load(checkpoint_file, 'cpu')
     if 'state' in package.keys(): #raw model file
         logger.info(bold(f'Loading model {model_name} from file.'))
@@ -35,9 +35,9 @@ def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
     fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
     return fade_out(out_clip) + fade_in(in_clip)
 
-def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):
+def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", experiment_file="aero_441-441_512_256.yaml", progress=Progress()):
 
-    model = _load_model(checkpoint_file,model_name)
+    model = _load_model(checkpoint_file,model_name,experiment_file)
     device = torch.device('cpu')
     if torch.cuda.is_available():
         device = torch.device('cuda')
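For reference, a minimal usage sketch of the updated signature, mirroring the call app.py now makes (the shape comment reflects the transpose app.py performs before calling):

# lrAudio: float tensor of shape (channels, samples), already at 44.1 kHz.
hrAudio = upscaleAudio(
    lrAudio,
    "Telephone_SR.th",                            # checkpoint filename under models/
    model_name="aero",                            # architecture for modelFactory
    experiment_file="aero_441-441_512_256.yaml",  # config under conf/experiment/
)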
src/models/modelFactory.py
CHANGED
@@ -2,12 +2,12 @@ from src.models.aero import Aero
 from src.models.seanet import Seanet
 from yaml import safe_load
 
-def get_model(model_name="aero"):
+def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
     if model_name == 'aero':
-        with open("conf/experiment/
+        with open("conf/experiment/" + experiment_file) as f:
             generator = Aero(**safe_load(f)["aero"])
     elif model_name == 'seanet':
-        with open("conf/experiment/
+        with open("conf/experiment/" + experiment_file) as f:
             generator = Seanet(**safe_load(f)["seanet"])
 
     models = {'generator': generator}
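Sketch of the factory's contract after this change: the YAML under conf/experiment/ is expected to contain a top-level section named after the architecture ("aero" or "seanet"), which safe_load returns as a dict and ** unpacks into the constructor. Note the function assumes model_name is one of those two values; any other name would leave generator unbound:

models = get_model(model_name="aero",
                   experiment_file="aero_441-441_512_256.yaml")
generator = models["generator"]  # an Aero built from the YAML's "aero" mapping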