sereich committed on
Commit
3551fa7
·
1 Parent(s): 872a4c6

Add GPU support when available, use proportional fading.

Browse files
Files changed (1) hide show
  1. processAudio.py +11 -8
processAudio.py CHANGED
@@ -17,7 +17,7 @@ from src.utils import bold
17
  logger = logging.getLogger(__name__)
18
 
19
  SEGMENT_DURATION_SEC = 5
20
- SEGMENT_OVERLAP_SAMPLES = 2048
21
  SERIALIZE_KEY_STATE = 'state'
22
 
23
  def _load_model(checkpoint_file="models/FM_Radio_SR.th",model_name="aero"):
@@ -30,16 +30,18 @@ def _load_model(checkpoint_file="models/FM_Radio_SR.th",model_name="aero"):
30
 
31
  return model
32
 
33
- def crossfade_and_blend(out_clip, in_clip):
34
- fade_out = torchaudio.transforms.Fade(0,SEGMENT_OVERLAP_SAMPLES)
35
- fade_in = torchaudio.transforms.Fade(SEGMENT_OVERLAP_SAMPLES, 0)
36
  return fade_out(out_clip) + fade_in(in_clip)
37
 
38
  def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):
39
 
40
  model = _load_model(checkpoint_file,model_name)
41
  device = torch.device('cpu')
42
- #model.cuda()
 
 
43
 
44
  logger.info(f'lr wav shape: {lr_sig.shape}')
45
 
@@ -55,7 +57,8 @@ def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name
55
 
56
  pr_chunks = []
57
 
58
- lr_segment_overlap_samples = int((sr/hr_sr) * SEGMENT_OVERLAP_SAMPLES)
 
59
 
60
  model.eval()
61
  pred_start = time.time()
@@ -68,8 +71,8 @@ def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name
68
  if previous_chunk is not None:
69
  combined_chunk = torch.cat((previous_chunk[...,-lr_segment_overlap_samples:], lr_chunk), 1)
70
  pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
71
- pr_chunk = pr_combined_chunk[...,SEGMENT_OVERLAP_SAMPLES:]
72
- pr_chunks[-1][...,-SEGMENT_OVERLAP_SAMPLES:] = crossfade_and_blend(pr_chunks[-1][...,-SEGMENT_OVERLAP_SAMPLES:], pr_combined_chunk.cpu()[...,:SEGMENT_OVERLAP_SAMPLES] )
73
  else:
74
  pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
75
  logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
 
17
  logger = logging.getLogger(__name__)
18
 
19
  SEGMENT_DURATION_SEC = 5
20
+ SEGMENT_OVERLAP_RATIO = 0.25
21
  SERIALIZE_KEY_STATE = 'state'
22
 
23
  def _load_model(checkpoint_file="models/FM_Radio_SR.th",model_name="aero"):
 
30
 
31
  return model
32
 
33
def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
    """Blend two overlapping clips with an equal-gain linear crossfade.

    Ramps ``out_clip`` down over ``segment_overlap_samples`` while ramping
    ``in_clip`` up over the same span, then sums the two. With torchaudio's
    default linear fade shape the two ramps sum to unity gain, so the
    overlap region keeps a constant level.

    Args:
        out_clip: tail of the previous segment (length == overlap).
        in_clip: head of the next segment (length == overlap).
        segment_overlap_samples: number of samples over which to crossfade.

    Returns:
        The crossfaded tensor covering the overlap region.
    """
    ramp_down = torchaudio.transforms.Fade(fade_in_len=0, fade_out_len=segment_overlap_samples)
    ramp_up = torchaudio.transforms.Fade(fade_in_len=segment_overlap_samples, fade_out_len=0)
    return ramp_down(out_clip) + ramp_up(in_clip)
37
 
38
  def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):
39
 
40
  model = _load_model(checkpoint_file,model_name)
41
  device = torch.device('cpu')
42
+ if torch.cuda.is_available():
43
+ device = torch.device('cuda')
44
+ model.cuda()
45
 
46
  logger.info(f'lr wav shape: {lr_sig.shape}')
47
 
 
57
 
58
  pr_chunks = []
59
 
60
+ lr_segment_overlap_samples = int(sr*SEGMENT_OVERLAP_RATIO)
61
+ hr_segment_overlap_samples = int(hr_sr*SEGMENT_OVERLAP_RATIO)
62
 
63
  model.eval()
64
  pred_start = time.time()
 
71
  if previous_chunk is not None:
72
  combined_chunk = torch.cat((previous_chunk[...,-lr_segment_overlap_samples:], lr_chunk), 1)
73
  pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
74
+ pr_chunk = pr_combined_chunk[...,hr_segment_overlap_samples:]
75
+ pr_chunks[-1][...,-hr_segment_overlap_samples:] = crossfade_and_blend(pr_chunks[-1][...,-hr_segment_overlap_samples:], pr_combined_chunk.cpu()[...,:hr_segment_overlap_samples], hr_segment_overlap_samples )
76
  else:
77
  pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
78
  logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')