# BroadcastAudioUpscaling / processAudio.py
import logging
import math
import time
from pathlib import Path

import torch
import torchaudio
from gradio import Progress

from src.models import modelFactory
from src.utils import bold

logger = logging.getLogger(__name__)

SEGMENT_DURATION_SEC = 5
SEGMENT_OVERLAP_RATIO = 0.25
SERIALIZE_KEY_STATE = 'state'
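
# With the default 44.1 kHz sample rate, these settings give 44100 * 5 = 220500
# samples per segment, with a 0.25 * 5 s = 1.25 s overlap between neighbouring
# segments that is re-predicted and crossfaded away during inference.
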
def _load_model(checkpoint_filename="FM_Radio_SR.th", model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
    """Build the generator for model_name and load its weights from models/<checkpoint_filename>."""
    checkpoint_file = Path("models") / checkpoint_filename
    model = modelFactory.get_model(model_name, experiment_file)['generator']
    package = torch.load(checkpoint_file, map_location='cpu')
    if SERIALIZE_KEY_STATE in package:  # raw model file: a package wrapping the state dict
        logger.info(bold(f'Loading model {model_name} from file.'))
        model.load_state_dict(package[SERIALIZE_KEY_STATE])
    return model
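
# Minimal usage sketch: assumes models/FM_Radio_SR.th exists, as the defaults
# above imply; any other checkpoint/model/config triple is passed through to
# modelFactory.get_model the same way.
#
#     generator = _load_model("FM_Radio_SR.th", "aero", "aero_441-441_512_256.yaml")
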
def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
    """Blend two equal-length overlap clips: out_clip fades out while in_clip fades in."""
    fade_out = torchaudio.transforms.Fade(fade_in_len=0, fade_out_len=segment_overlap_samples)
    fade_in = torchaudio.transforms.Fade(fade_in_len=segment_overlap_samples, fade_out_len=0)
    return fade_out(out_clip) + fade_in(in_clip)
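
# Sanity-check sketch: torchaudio's default linear fades are complementary, so
# blending two constant clips returns (numerically) the same constant, i.e. the
# crossfade preserves level across the seam:
#
#     ones = torch.ones(1, 100)
#     assert torch.allclose(crossfade_and_blend(ones, ones, 100), ones)
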
def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", experiment_file="aero_441-441_512_256.yaml", progress=Progress()):
    """Upscale lr_sig, a (channels, samples) tensor sampled at sr, returning audio at hr_sr."""
    model = _load_model(checkpoint_file, model_name, experiment_file)
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')
        model.cuda()
    logger.info(f'lr wav shape: {lr_sig.shape}')

    # Split the input into fixed-duration segments; the last one may be shorter.
    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')
    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []
    # Overlap lengths in input (lr) and output (hr) samples; these differ whenever hr_sr != sr.
    lr_segment_overlap_samples = int(sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)
    hr_segment_overlap_samples = int(hr_sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)
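    # E.g. with the defaults both overlaps are int(44100 * 0.25 * 5) = 55125 samples;
    # a hypothetical sr=22050, hr_sr=44100 pairing would give 27562 lr vs 55125 hr samples.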
    model.eval()
    pred_start = time.time()
    with torch.no_grad():
        previous_chunk = None
        progress(0)
        for i, lr_chunk in enumerate(lr_chunks):
            progress(i / n_chunks)
            if previous_chunk is not None:
                # Prepend the tail of the previous chunk so the model predicts the
                # seam with context, then crossfade the re-predicted overlap into
                # the tail of the previous output chunk.
                combined_chunk = torch.cat((previous_chunk[..., -lr_segment_overlap_samples:], lr_chunk), dim=-1)
                pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
                pr_chunk = pr_combined_chunk[..., hr_segment_overlap_samples:]
                pr_chunks[-1][..., -hr_segment_overlap_samples:] = crossfade_and_blend(
                    pr_chunks[-1][..., -hr_segment_overlap_samples:],
                    pr_combined_chunk.cpu()[..., :hr_segment_overlap_samples],
                    hr_segment_overlap_samples)
            else:
                pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())
            previous_chunk = lr_chunk
    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')
    pr = torch.cat(pr_chunks, dim=-1)
    logger.info(f'pr wav shape: {pr.shape}')
    return pr
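
# Hedged end-to-end sketch, not part of the Space's Gradio call path:
# 'input.wav' / 'output.wav' are hypothetical file names, models/FM_Radio_SR.th
# is assumed to exist (as the defaults above imply), and gradio's Progress
# default is assumed to no-op when invoked outside an event.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    in_sig, in_sr = torchaudio.load('input.wav')  # (channels, samples) float tensor
    out_sig = upscaleAudio(in_sig, 'FM_Radio_SR.th', sr=in_sr, hr_sr=in_sr)
    torchaudio.save('output.wav', out_sig, in_sr)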