sereich committed on
Commit
efc318c
·
1 Parent(s): fe17acd

Add phone model (beta), allow models to use different architectures

Browse files
app.py CHANGED
@@ -18,7 +18,8 @@ with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as
18
  modelSelect = gr.Dropdown(
19
  [
20
  ["FM Radio Super Resolution","FM_Radio_SR.th"],
21
- ["AM Radio Super Resolution (Beta)","AM_Radio_SR.th"]
 
22
  ],
23
  label="Select Model:",
24
  value="FM_Radio_SR.th",
@@ -66,9 +67,19 @@ with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as
66
  lrAudio = torch.tensor(audioData[1].copy().astype(np.float32)/32768).transpose(0,1)
67
  if audioData[0] != 44100:
68
  lrAudio = resample(lrAudio, audioData[0], 44100)
69
- hrAudio=upscaleAudio(lrAudio, "models/" + model)
 
70
  hrAudio=hrAudio / max(hrAudio.abs().max().item(), 1)
71
  outAudio=(hrAudio*32767).numpy().astype(np.int16).transpose(1,0)
72
  return tuple([44100, outAudio])
 
 
 
 
 
 
 
 
 
73
 
74
  layout.launch()
 
18
  modelSelect = gr.Dropdown(
19
  [
20
  ["FM Radio Super Resolution","FM_Radio_SR.th"],
21
+ ["AM Radio Super Resolution (Beta)","AM_Radio_SR.th"],
22
+ ["Telephone Super Resolution (Beta)","Telephone_SR.th"]
23
  ],
24
  label="Select Model:",
25
  value="FM_Radio_SR.th",
 
67
  lrAudio = torch.tensor(audioData[1].copy().astype(np.float32)/32768).transpose(0,1)
68
  if audioData[0] != 44100:
69
  lrAudio = resample(lrAudio, audioData[0], 44100)
70
+ model_name, experiment_file = getModelInfo(model)
71
+ hrAudio=upscaleAudio(lrAudio, model, model_name=model_name, experiment_file=experiment_file)
72
  hrAudio=hrAudio / max(hrAudio.abs().max().item(), 1)
73
  outAudio=(hrAudio*32767).numpy().astype(np.int16).transpose(1,0)
74
  return tuple([44100, outAudio])
75
+
76
+ def getModelInfo(modelFilename: str):
77
+ if(modelFilename == "FM_Radio_SR.th"):
78
+ return "aero", "aero_441-441_512_256.yaml"
79
+ if(modelFilename == "AM_Radio_SR.th"):
80
+ return "aero", "aero_441-441_512_256.yaml"
81
+ if(modelFilename == "Telephone_SR.th"):
82
+ return "aero", "aero_441-441_512_256.yaml"
83
+ return "aero", "aero_441-441_512_256.yaml"
84
 
85
  layout.launch()
models/Telephone_SR.th ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b59e32fccaf83c7e038b8c5e894eeebbce272e9ab00db6b20d45b2fba6e911ca
3
+ size 136533968
processAudio.py CHANGED
@@ -20,9 +20,9 @@ SEGMENT_DURATION_SEC = 5
20
  SEGMENT_OVERLAP_RATIO = 0.25
21
  SERIALIZE_KEY_STATE = 'state'
22
 
23
- def _load_model(checkpoint_file="models/FM_Radio_SR.th",model_name="aero"):
24
- checkpoint_file = Path(checkpoint_file)
25
- model = modelFactory.get_model(model_name)['generator']
26
  package = torch.load(checkpoint_file, 'cpu')
27
  if 'state' in package.keys(): #raw model file
28
  logger.info(bold(f'Loading model {model_name} from file.'))
@@ -35,9 +35,9 @@ def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
35
  fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
36
  return fade_out(out_clip) + fade_in(in_clip)
37
 
38
- def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):
39
 
40
- model = _load_model(checkpoint_file,model_name)
41
  device = torch.device('cpu')
42
  if torch.cuda.is_available():
43
  device = torch.device('cuda')
 
20
  SEGMENT_OVERLAP_RATIO = 0.25
21
  SERIALIZE_KEY_STATE = 'state'
22
 
23
+ def _load_model(checkpoint_filename="FM_Radio_SR.th",model_name="aero",experiment_file="aero_441-441_512_256.yaml"):
24
+ checkpoint_file = Path("models/" + checkpoint_filename)
25
+ model = modelFactory.get_model(model_name,experiment_file)['generator']
26
  package = torch.load(checkpoint_file, 'cpu')
27
  if 'state' in package.keys(): #raw model file
28
  logger.info(bold(f'Loading model {model_name} from file.'))
 
35
  fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
36
  return fade_out(out_clip) + fade_in(in_clip)
37
 
38
+ def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", experiment_file="aero_441-441_512_256.yaml", progress=Progress()):
39
 
40
+ model = _load_model(checkpoint_file,model_name,experiment_file)
41
  device = torch.device('cpu')
42
  if torch.cuda.is_available():
43
  device = torch.device('cuda')
src/models/modelFactory.py CHANGED
@@ -2,12 +2,12 @@ from src.models.aero import Aero
2
  from src.models.seanet import Seanet
3
  from yaml import safe_load
4
 
5
- def get_model(model_name="aero"):
6
  if model_name == 'aero':
7
- with open("conf/experiment/aero_441-441_512_256.yaml") as f:
8
  generator = Aero(**safe_load(f)["aero"])
9
  elif model_name == 'seanet':
10
- with open("conf/experiment/seanet_441-441.yaml") as f:
11
  generator = Seanet(**safe_load(f)["seanet"])
12
 
13
  models = {'generator': generator}
 
2
  from src.models.seanet import Seanet
3
  from yaml import safe_load
4
 
5
+ def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
6
  if model_name == 'aero':
7
+ with open("conf/experiment/" + experiment_file) as f:
8
  generator = Aero(**safe_load(f)["aero"])
9
  elif model_name == 'seanet':
10
+ with open("conf/experiment/" + experiment_file) as f:
11
  generator = Seanet(**safe_load(f)["seanet"])
12
 
13
  models = {'generator': generator}