import math
import os
import time
import torch
import logging
from pathlib import Path
import torchaudio
from torchaudio.functional import resample
from gradio import Progress
from src.models import modelFactory
from src.utils import bold

logger = logging.getLogger(__name__)

SEGMENT_DURATION_SEC = 5  # length of each chunk fed to the model, in seconds
SEGMENT_OVERLAP_RATIO = 0.25  # fraction of a chunk duration re-processed for cross-fading
SERIALIZE_KEY_STATE = 'state'  # key under which the checkpoint stores the model weights
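
# Loads the AERO generator defined by the experiment config and restores its weights
# from a checkpoint under models/. Checkpoints are expected to be training packages
# that keep the model weights under the 'state' key (SERIALIZE_KEY_STATE).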
def _load_model(checkpoint_filename="FM_Radio_SR.th", model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
    checkpoint_file = Path("models/" + checkpoint_filename)
    model = modelFactory.get_model(model_name, experiment_file)['generator']
    package = torch.load(checkpoint_file, map_location='cpu')
    if SERIALIZE_KEY_STATE in package:  # training package: weights stored under the 'state' key
        logger.info(bold(f'Loading model {model_name} from file.'))
        model.load_state_dict(package[SERIALIZE_KEY_STATE])
    return model
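
# Blend two overlapping predictions with a linear cross-fade: the tail of the earlier
# chunk is faded out while the head of the later chunk is faded in, and the two are
# summed so chunk boundaries do not leave audible seams.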
def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
    fade_out = torchaudio.transforms.Fade(0, segment_overlap_samples)
    fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
    return fade_out(out_clip) + fade_in(in_clip)
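
# Super-resolves a long signal chunk by chunk: the input is split into fixed-length
# segments, each new segment is processed together with a short overlap taken from the
# previous one, and the overlapping outputs are cross-faded before concatenation.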
def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero",
                 experiment_file="aero_441-441_512_256.yaml", progress=Progress()):
    model = _load_model(checkpoint_file, model_name, experiment_file)
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')
        model.cuda()
    logger.info(f'lr wav shape: {lr_sig.shape}')

    # Split the low-resolution signal into fixed-length chunks along the time axis.
    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')
    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []
    # Overlap length in samples at the input (lr) and output (hr) sample rates.
    lr_segment_overlap_samples = int(sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)
    hr_segment_overlap_samples = int(hr_sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)
    model.eval()
    pred_start = time.time()
    with torch.no_grad():
        previous_chunk = None
        progress(0)
        for i, lr_chunk in enumerate(lr_chunks):
            progress(i / n_chunks)
            pr_chunk = None
            if previous_chunk is not None:
                # Prepend the tail of the previous chunk so the model sees context across
                # the boundary, then drop the re-processed overlap from this chunk's output
                # and cross-fade it into the tail of the previous prediction.
                combined_chunk = torch.cat((previous_chunk[..., -lr_segment_overlap_samples:], lr_chunk), 1)
                pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
                pr_chunk = pr_combined_chunk[..., hr_segment_overlap_samples:]
                pr_chunks[-1][..., -hr_segment_overlap_samples:] = crossfade_and_blend(
                    pr_chunks[-1][..., -hr_segment_overlap_samples:],
                    pr_combined_chunk.cpu()[..., :hr_segment_overlap_samples],
                    hr_segment_overlap_samples)
            else:
                pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())
            previous_chunk = lr_chunk
    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')

    pr = torch.concat(pr_chunks, dim=-1)
    logger.info(f'pr wav shape: {pr.shape}')
    return pr
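

if __name__ == "__main__":
    # Illustrative usage sketch only -- the input/output file names, the 44.1 kHz source
    # assumption, and the no-op stand-in for Gradio's Progress are hypothetical and not
    # part of the pipeline above.
    logging.basicConfig(level=logging.INFO)
    sig, file_sr = torchaudio.load("input.wav")  # hypothetical input path
    if file_sr != 44100:
        sig = resample(sig, file_sr, 44100)  # bring the signal to the model's expected rate
    upscaled = upscaleAudio(sig, "FM_Radio_SR.th", progress=lambda *args, **kwargs: None)
    torchaudio.save("output.wav", upscaled, 44100)  # hypothetical output path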