|
import functools |
|
import math |
|
|
|
from torch.nn import functional as F |
|
|
|
from contextlib import contextmanager |
|
|
|
|
|
def capture_init(init):
    """Decorator for `__init__` that remembers the constructor arguments.

    After construction, the positional and keyword arguments that were
    passed are available on the instance as the tuple
    `self._init_args_kwargs == (args, kwargs)`.
    """

    def init_with_capture(self, *args, **kwargs):
        # Store the arguments first, then delegate to the wrapped initializer.
        self._init_args_kwargs = args, kwargs
        init(self, *args, **kwargs)

    # Copy name, docstring, etc. from the wrapped initializer onto the wrapper.
    return functools.update_wrapper(init_with_capture, init)
|
|
|
|
|
def unfold(a, kernel_size, stride):
    """Extract frames of `kernel_size` samples, one every `stride` samples.

    Given an input of size [*OT, T], returns a Tensor of size [*OT, F, K]
    where K is `kernel_size` and `F = ceil(T / stride)`. The last dimension
    is first adjusted with `F.pad` to length `(F - 1) * stride + K` so that
    the final frame is complete. The result is produced with `as_strided`
    on that padded tensor.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *lead_dims, n_samples = a.shape
    frame_count = math.ceil(n_samples / stride)
    # Total length required for `frame_count` full frames.
    padded_length = (frame_count - 1) * stride + kernel_size
    a = F.pad(a, (0, padded_length - n_samples))

    *outer_strides, inner_stride = a.stride()
    assert inner_stride == 1, 'data should be contiguous'

    # Insert a frame axis that advances by `stride` samples; within a frame
    # consecutive samples stay one element apart.
    frame_size = [*lead_dims, frame_count, kernel_size]
    frame_strides = [*outer_strides, stride, 1]
    return a.as_strided(frame_size, frame_strides)
|
|
|
def colorize(text, color):
    """
    Wrap `text` in ANSI escape sequences so a terminal shows it with the
    attribute `color`, then resets the formatting afterwards.
    """
    prefix = f"\033[{color}m"
    reset = "\033[0m"
    return prefix + text + reset
|
|
|
|
|
def bold(text):
    """
    Return `text` wrapped by `colorize` with the ANSI attribute "1" (bold).
    """
    bold_attribute = "1"
    return colorize(text, bold_attribute)