import logging
import math
import os
import time
from pathlib import Path

import torch
import torchaudio
from torchaudio.functional import resample
from gradio import Progress

from src.models import modelFactory
from src.utils import bold

logger = logging.getLogger(__name__)

SEGMENT_DURATION_SEC = 5
SEGMENT_OVERLAP_RATIO = 0.25
SERIALIZE_KEY_STATE = 'state'


def _load_model(checkpoint_filename="FM_Radio_SR.th", model_name="aero",
                experiment_file="aero_441-441_512_256.yaml"):
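    """Build the model defined by the experiment config and load its weights
    from the serialized checkpoint package."""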
    checkpoint_file = Path("models") / checkpoint_filename
    model = modelFactory.get_model(model_name, experiment_file)['generator']
    package = torch.load(checkpoint_file, map_location='cpu')
    if SERIALIZE_KEY_STATE in package:
        logger.info(bold(f'Loading model {model_name} from file.'))
        model.load_state_dict(package[SERIALIZE_KEY_STATE])

    return model


def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
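    """Blend two overlapping clips: a linear fade-out is applied to out_clip
    and a linear fade-in to in_clip before summing them."""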
    fade_out = torchaudio.transforms.Fade(0, segment_overlap_samples)
    fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
    return fade_out(out_clip) + fade_in(in_clip)


def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero",
                 experiment_file="aero_441-441_512_256.yaml", progress=Progress()):
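    """Upscale lr_sig by running the model over overlapping fixed-length
    segments and cross-fading the overlaps while stitching the output."""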
    model = _load_model(checkpoint_file, model_name, experiment_file)
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')
        model.cuda()

    logger.info(f'lr wav shape: {lr_sig.shape}')
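
    # Split the LR signal into SEGMENT_DURATION_SEC-long chunks for inference.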
    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')

    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []
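
    # Overlap length (SEGMENT_OVERLAP_RATIO of a segment) in LR and HR samples.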
    lr_segment_overlap_samples = int(sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)
    hr_segment_overlap_samples = int(hr_sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)

    model.eval()
    pred_start = time.time()
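
    # Run inference chunk by chunk, reporting progress to the Gradio callback.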
    with torch.no_grad():
        previous_chunk = None
        progress(0)
        for i, lr_chunk in enumerate(lr_chunks):
            progress(i / n_chunks)
            if previous_chunk is not None:
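                # Prepend the tail of the previous LR chunk so the model sees
                # context across the boundary; the leading part of the prediction
                # is then cross-faded into the previous HR chunk.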
                combined_chunk = torch.cat((previous_chunk[..., -lr_segment_overlap_samples:], lr_chunk), dim=1)
                pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
                pr_chunk = pr_combined_chunk[..., hr_segment_overlap_samples:]
                pr_chunks[-1][..., -hr_segment_overlap_samples:] = crossfade_and_blend(
                    pr_chunks[-1][..., -hr_segment_overlap_samples:],
                    pr_combined_chunk.cpu()[..., :hr_segment_overlap_samples],
                    hr_segment_overlap_samples)
            else:
                pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())
            previous_chunk = lr_chunk

    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')
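
    # Stitch the upscaled chunks back into a single waveform.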
    pr = torch.concat(pr_chunks, dim=-1)

    logger.info(f'pr wav shape: {pr.shape}')

    return pr