sereich's picture
Initial commit of Radio Upscaling UI (minus models)
f113387
raw
history blame
1.43 kB
import functools
import math
from torch.nn import functional as F
from contextlib import contextmanager
def capture_init(init):
    """Decorator for `__init__` that remembers the constructor arguments.

    Before delegating to the wrapped initializer, the positional and keyword
    arguments are stored on the instance as the tuple `(args, kwargs)` under
    `self._init_args_kwargs`, so they can be recovered later.
    """

    def recording_init(self, *args, **kwargs):
        # Record first, so the attribute exists even while `init` is running.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    # Copy the name, docstring and other metadata from the wrapped initializer.
    return functools.wraps(init)(recording_init)
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input on the right of the last dimension so that
    `F = ceil(T / stride)`.

    The result is built with `as_strided`, so frames share memory with the
    padded input (and with each other when `stride < kernel_size`).

    see https://github.com/pytorch/pytorch/issues/60466

    Args:
        a: tensor whose last dimension (size T) is split into frames.
        kernel_size: number of samples K in each frame.
        stride: hop between the starts of two consecutive frames.

    Returns:
        Tensor of size [*OT, F, K].
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    # Length needed so that the last frame is complete.
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    # The strided view below requires unit stride along the last dimension.
    # Make the data contiguous when it is not, rather than relying on an
    # `assert` that is stripped when Python runs with `-O`.
    if a.stride(-1) != 1:
        a = a.contiguous()
    # Keep the leading strides; frames advance by `stride`, samples by 1.
    strides = list(a.stride())[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
def colorize(text, color):
    """
    Wrap `text` in ANSI escape sequences so a terminal renders it with the
    given `color` code, then resets the formatting afterwards.
    """
    reset_seq = "\033[0m"
    start_seq = f"\033[{color}m"
    pieces = (start_seq, text, reset_seq)
    return "".join(pieces)
def bold(text):
    """
    Return `text` wrapped so that a terminal displays it in bold
    (ANSI code "1", applied through `colorize`).
    """
    ansi_bold = "1"
    return colorize(text, ansi_bold)