Add GPU support when available; use proportional fading.
processAudio.py (+11 -8)
@@ -17,7 +17,7 @@ from src.utils import bold
 logger = logging.getLogger(__name__)
 
 SEGMENT_DURATION_SEC = 5
-
+SEGMENT_OVERLAP_RATIO = 0.25
 SERIALIZE_KEY_STATE = 'state'
 
 def _load_model(checkpoint_file="models/FM_Radio_SR.th",model_name="aero"):
@@ -30,16 +30,18 @@ def _load_model(checkpoint_file="models/FM_Radio_SR.th",model_name="aero"):
 
     return model
 
-def crossfade_and_blend(out_clip, in_clip):
-    fade_out = torchaudio.transforms.Fade(0,…
-    fade_in = torchaudio.transforms.Fade(…
+def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
+    fade_out = torchaudio.transforms.Fade(0,segment_overlap_samples)
+    fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
     return fade_out(out_clip) + fade_in(in_clip)
 
 def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):
 
     model = _load_model(checkpoint_file,model_name)
     device = torch.device('cpu')
-
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+        model.cuda()
 
     logger.info(f'lr wav shape: {lr_sig.shape}')
 
@@ -55,7 +57,8 @@ def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name
 
     pr_chunks = []
 
-    lr_segment_overlap_samples = int(…
+    lr_segment_overlap_samples = int(sr*SEGMENT_OVERLAP_RATIO)
+    hr_segment_overlap_samples = int(hr_sr*SEGMENT_OVERLAP_RATIO)
 
     model.eval()
     pred_start = time.time()
@@ -68,8 +71,8 @@ def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name
         if previous_chunk is not None:
             combined_chunk = torch.cat((previous_chunk[...,-lr_segment_overlap_samples:], lr_chunk), 1)
             pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
-            pr_chunk = pr_combined_chunk[...,…
-            pr_chunks[-1][...,-…
+            pr_chunk = pr_combined_chunk[...,hr_segment_overlap_samples:]
+            pr_chunks[-1][...,-hr_segment_overlap_samples:] = crossfade_and_blend(pr_chunks[-1][...,-hr_segment_overlap_samples:], pr_combined_chunk.cpu()[...,:hr_segment_overlap_samples], hr_segment_overlap_samples )
         else:
             pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
         logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
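The rewritten crossfade_and_blend takes the overlap length as a parameter instead of relying on a module-level constant. With torchaudio's default linear fade shape, the fade-out and fade-in ramps are complementary (they sum to 1 at every sample), so a signal that agrees on both sides of the seam passes through the blend unchanged. A minimal standalone sketch; the check values (0.25 s at 44.1 kHz) are illustrative, not from the commit:

```python
# Standalone sketch of the patched crossfade_and_blend.
import torch
import torchaudio

def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
    # Fade(fade_in_len, fade_out_len): ramp the outgoing clip down to 0
    # and the incoming clip up from 0, then sum the two.
    fade_out = torchaudio.transforms.Fade(0, segment_overlap_samples)
    fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
    return fade_out(out_clip) + fade_in(in_clip)

overlap = int(44100 * 0.25)             # 0.25 s of overlap at 44.1 kHz
tail = torch.ones(1, overlap)           # end of the previous predicted chunk
head = torch.ones(1, overlap)           # start of the next predicted chunk
blended = crossfade_and_blend(tail, head, overlap)
# Linear ramps sum to 1, so a constant signal is preserved across the seam.
assert torch.allclose(blended, torch.ones(1, overlap))
```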
|
|
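The device handling follows the usual eager-PyTorch pattern: pick the device once, move the model's weights, then ship each input chunk to the device and pull each prediction back to host memory. A small sketch of that round trip; the helper name pick_device is mine, not an identifier from processAudio.py:

```python
# Hypothetical helper mirroring the patch's device selection.
import torch

def pick_device(model):
    device = torch.device('cpu')
    if torch.cuda.is_available():  # fall back to CPU when no GPU is present
        device = torch.device('cuda')
        model.cuda()               # move the weights in place, as the patch does
    return device

# Per-chunk round trip, as in upscaleAudio: inputs go to the device for
# inference and predictions come back to the CPU for blending.
model = torch.nn.Identity()
device = pick_device(model)
chunk = torch.randn(1, 44100)
pred = model(chunk.unsqueeze(0).to(device)).squeeze(0).cpu()
```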
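Taken together, the changes make the overlap proportional to the sample rate: SEGMENT_OVERLAP_RATIO seconds' worth of input samples (sr*SEGMENT_OVERLAP_RATIO) is re-fed to the model with each chunk, and the matching span of output samples (hr_sr*SEGMENT_OVERLAP_RATIO) is cropped from the new prediction and crossfaded into the tail of the previous one, so the seams stay aligned even when sr and hr_sr differ. The following is a runnable toy of that bookkeeping, not the commit's code: the range-based chunking, the upscale_chunked name, and the identity stand-in for the AERO model are all assumptions made for illustration.

```python
# Toy version of the chunked-inference stitching in upscaleAudio.
import torch
import torchaudio

SEGMENT_DURATION_SEC = 5
SEGMENT_OVERLAP_RATIO = 0.25

def crossfade_and_blend(out_clip, in_clip, n):
    return (torchaudio.transforms.Fade(0, n)(out_clip)
            + torchaudio.transforms.Fade(n, 0)(in_clip))

def upscale_chunked(lr_sig, model, sr=44100, hr_sr=44100):
    seg = SEGMENT_DURATION_SEC * sr
    lr_ov = int(sr * SEGMENT_OVERLAP_RATIO)     # overlap in input samples
    hr_ov = int(hr_sr * SEGMENT_OVERLAP_RATIO)  # the same span in output samples
    pr_chunks, previous_chunk = [], None
    for start in range(0, lr_sig.shape[-1], seg):
        lr_chunk = lr_sig[..., start:start + seg]
        if previous_chunk is not None:
            # Re-feed the tail of the previous input so the model sees context,
            combined = torch.cat((previous_chunk[..., -lr_ov:], lr_chunk), 1)
            pr_combined = model(combined.unsqueeze(0)).squeeze(0)
            # drop the regenerated overlap from the new prediction,
            pr_chunk = pr_combined[..., hr_ov:]
            # and crossfade it into the tail of the previous prediction.
            pr_chunks[-1][..., -hr_ov:] = crossfade_and_blend(
                pr_chunks[-1][..., -hr_ov:], pr_combined[..., :hr_ov], hr_ov)
        else:
            pr_chunk = model(lr_chunk.unsqueeze(0)).squeeze(0)
        pr_chunks.append(pr_chunk)
        previous_chunk = lr_chunk
    return torch.cat(pr_chunks, dim=-1)

sig = torch.randn(1, 44100 * 12)        # 12 s mono test signal
out = upscale_chunked(sig, lambda x: x) # identity model, so sr == hr_sr here
assert out.shape == sig.shape           # each input sample is emitted exactly once
```

Because every overlapped prediction is cropped before it is appended, each output sample is produced exactly once; the crossfade only smooths the join, which is why the identity run above preserves the signal length.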