"""Chunked audio super-resolution inference.

Splits a low-resolution signal into overlapping segments, runs each
through the model, and crossfades the overlaps back together.
"""

import math
import time
import logging
from pathlib import Path

import torch
import torchaudio
from torchaudio.functional import resample
from gradio import Progress

from src.models import modelFactory
from src.utils import bold

logger = logging.getLogger(__name__)

SEGMENT_DURATION_SEC = 5        # length of each inference segment
SEGMENT_OVERLAP_RATIO = 0.25    # fraction of a segment shared with its neighbor
SERIALIZE_KEY_STATE = 'state'


def _load_model(checkpoint_filename="FM_Radio_SR.th",
                model_name="aero",
                experiment_file="aero_441-441_512_256.yaml"):
    checkpoint_file = Path("models") / checkpoint_filename
    model = modelFactory.get_model(model_name, experiment_file)['generator']
    package = torch.load(checkpoint_file, map_location='cpu')
    if SERIALIZE_KEY_STATE in package:  # raw model file
        logger.info(bold(f'Loading model {model_name} from file.'))
        model.load_state_dict(package[SERIALIZE_KEY_STATE])
    return model


def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
    """Fade out_clip out while fading in_clip in, then sum them."""
    fade_out = torchaudio.transforms.Fade(0, segment_overlap_samples)
    fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
    return fade_out(out_clip) + fade_in(in_clip)


def upscaleAudio(lr_sig,
                 checkpoint_file: str,
                 sr=44100,
                 hr_sr=44100,
                 model_name="aero",
                 experiment_file="aero_441-441_512_256.yaml",
                 progress=Progress()):
    model = _load_model(checkpoint_file, model_name, experiment_file)

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')
        model.cuda()

    logger.info(f'lr wav shape: {lr_sig.shape}')

    # Split the signal into fixed-length chunks; the last one may be shorter.
    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')

    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []

    # Overlap length in samples at the input (lr) and output (hr) rates.
    lr_segment_overlap_samples = int(sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)
    hr_segment_overlap_samples = int(hr_sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)

    model.eval()
    pred_start = time.time()

    with torch.no_grad():
        previous_chunk = None
        progress(0)
        for i, lr_chunk in enumerate(lr_chunks):
            progress(i / n_chunks)
            if previous_chunk is not None:
                # Prepend the tail of the previous chunk so the model sees
                # context across the boundary, then crossfade the overlapping
                # region of the prediction into the previous chunk's tail.
                combined_chunk = torch.cat(
                    (previous_chunk[..., -lr_segment_overlap_samples:], lr_chunk), dim=-1)
                pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
                pr_chunk = pr_combined_chunk[..., hr_segment_overlap_samples:]
                pr_chunks[-1][..., -hr_segment_overlap_samples:] = crossfade_and_blend(
                    pr_chunks[-1][..., -hr_segment_overlap_samples:],
                    pr_combined_chunk.cpu()[..., :hr_segment_overlap_samples],
                    hr_segment_overlap_samples)
            else:
                pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())
            previous_chunk = lr_chunk

    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')

    pr = torch.cat(pr_chunks, dim=-1)
    logger.info(f'pr wav shape: {pr.shape}')
    return pr
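

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes a checkpoint at models/FM_Radio_SR.th and an input file named
# input.wav; both names are placeholders. A no-op callable stands in for
# gradio's Progress so the script can run outside a Gradio event.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    wav, in_sr = torchaudio.load('input.wav')   # (channels, samples)
    if in_sr != 44100:
        wav = resample(wav, in_sr, 44100)       # match the model's input rate
    out = upscaleAudio(wav, 'FM_Radio_SR.th', progress=lambda frac: None)
    torchaudio.save('output.wav', out, 44100)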