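"""Chunked inference for the AERO audio super-resolution model.

Loads a serialized generator checkpoint and upscales a low-resolution signal in
overlapping 5-second segments, crossfading the overlaps so chunk boundaries stay seamless.
"""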
import logging
import math
import os
import time
from pathlib import Path

import torch
import torchaudio
from torchaudio.functional import resample
from gradio import Progress

from src.models import modelFactory
from src.utils import bold

logger = logging.getLogger(__name__)

SEGMENT_DURATION_SEC = 5
SEGMENT_OVERLAP_RATIO = 0.25
SERIALIZE_KEY_STATE = 'state'

def _load_model(checkpoint_filename="FM_Radio_SR.th", model_name="aero",
                experiment_file="aero_441-441_512_256.yaml"):
    """Build the AERO generator and load its weights from a serialized checkpoint."""
    checkpoint_file = Path("models") / checkpoint_filename
    model = modelFactory.get_model(model_name, experiment_file)['generator']
    package = torch.load(checkpoint_file, map_location='cpu')
    if SERIALIZE_KEY_STATE in package:  # checkpoint package holding the generator state dict
        logger.info(bold(f'Loading model {model_name} from file.'))
        model.load_state_dict(package[SERIALIZE_KEY_STATE])

    return model

def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
    """Blend two overlapping clips: fade the outgoing tail out while fading the incoming head in."""
    fade_out = torchaudio.transforms.Fade(0, segment_overlap_samples)  # fade out over the last overlap samples
    fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)   # fade in over the first overlap samples
    return fade_out(out_clip) + fade_in(in_clip)

def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero",
                 experiment_file="aero_441-441_512_256.yaml", progress=Progress()):
    """Upscale a low-resolution signal chunk by chunk, crossfading overlapping segments."""
    model = _load_model(checkpoint_file, model_name, experiment_file)
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')
        model.cuda()

    logger.info(f'lr wav shape: {lr_sig.shape}')

    # Split the LR signal into fixed-duration segments (the last one may be shorter).
    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')

    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []

    # Overlap length in samples at the LR input rate and at the HR output rate.
    lr_segment_overlap_samples = int(sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)
    hr_segment_overlap_samples = int(hr_sr * SEGMENT_OVERLAP_RATIO * SEGMENT_DURATION_SEC)

    model.eval()
    pred_start = time.time()
    with torch.no_grad():
        previous_chunk = None
        progress(0)
        for i, lr_chunk in enumerate(lr_chunks):
            progress(i / n_chunks)
            if previous_chunk is not None:
                # Prepend the tail of the previous LR chunk so the model sees context across the
                # seam, keep only the new part of the prediction, and crossfade the overlapping
                # part into the tail of the previous HR chunk.
                combined_chunk = torch.cat((previous_chunk[..., -lr_segment_overlap_samples:], lr_chunk), dim=1)
                pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
                pr_chunk = pr_combined_chunk[..., hr_segment_overlap_samples:]
                pr_chunks[-1][..., -hr_segment_overlap_samples:] = crossfade_and_blend(
                    pr_chunks[-1][..., -hr_segment_overlap_samples:],
                    pr_combined_chunk.cpu()[..., :hr_segment_overlap_samples],
                    hr_segment_overlap_samples)
            else:
                pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())
            previous_chunk = lr_chunk

    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')

    pr = torch.concat(pr_chunks, dim=-1)

    logger.info(f'pr wav shape: {pr.shape}')

    return pr
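
# Minimal usage sketch, assuming the default checkpoint models/FM_Radio_SR.th exists and the
# module is run outside a Gradio app. The file names below are illustrative, not part of this
# project, and a no-op callback replaces the Gradio Progress default.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    lr_sig, sr = torchaudio.load("example_fm_recording.wav")   # (channels, samples) float tensor
    hr_sig = upscaleAudio(lr_sig, "FM_Radio_SR.th", sr=sr, hr_sr=44100,
                          progress=lambda p: None)             # progress receives a single float
    torchaudio.save("example_upscaled.wav", hr_sig, 44100)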