Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
from typing import Any, Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: ``Linear -> activation -> Linear``.

    The first linear layer maps ``embedding_dim`` features to ``mlp_dim``
    features, the activation is applied, and the second linear layer maps
    back to ``embedding_dim`` features.

    Args:
        embedding_dim: Size of the last dimension of the input and output.
        mlp_dim: Size of the hidden representation between the two layers.
        act: Activation module class; it is instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Creation order is kept (lin1, then lin2) so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expand / activate / project sequence to ``x``."""
        hidden = self.lin1(x)
        activated = self.act(hidden)
        return self.lin2(activated)
| # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa | |
| # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa | |
class LayerNorm2d(nn.Module):
    """Layer normalisation over dimension 1 with a per-channel affine transform.

    The mean and (biased) variance are taken over dimension 1 of the input.
    The learned ``weight`` and ``bias`` vectors of length ``num_channels`` are
    broadcast over the two trailing dimensions, so the layer is intended for
    inputs laid out as ``(N, C, H, W)``.

    Args:
        num_channels: Length of the ``weight`` and ``bias`` parameters.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over dimension 1, then scale and shift per channel."""
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Reshape (C,) -> (C, 1, 1) so the affine terms broadcast over the
        # two trailing dimensions.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
def val2list(x: Any, repeat_time: int = 1) -> list:
    """Return ``x`` as a list.

    Args:
        x: A list or tuple, or any other value.
        repeat_time: How many times a non-list/tuple ``x`` is repeated in
            the result. Ignored when ``x`` is a list or tuple.

    Returns:
        A new list holding the elements of ``x`` when ``x`` is a list or
        tuple; otherwise a list containing ``x`` ``repeat_time`` times.
    """
    if isinstance(x, (list, tuple)):
        # Always copy so callers can mutate the result without touching ``x``.
        return list(x)
    return [x] * repeat_time
def val2tuple(x: Any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return ``x`` as a tuple, padded to at least ``min_len`` elements.

    ``x`` is first converted with ``val2list``. If the resulting list is
    non-empty and shorter than ``min_len``, copies of the element at index
    ``idx_repeat`` are inserted at position ``idx_repeat`` until the length
    reaches ``min_len``. For example, with the default ``idx_repeat=-1``,
    ``[1, 2]`` padded to length 4 becomes ``(1, 2, 2, 2)``.

    Args:
        x: A list or tuple, or any other value (treated as one element).
        min_len: Minimum length of the returned tuple.
        idx_repeat: Index of the element to repeat, and the insertion point.

    Returns:
        The (possibly padded) elements as a tuple. An empty input is
        returned as an empty tuple, since there is no element to repeat.

    Raises:
        IndexError: If padding is needed and ``idx_repeat`` is out of range.
    """
    items = val2list(x)
    missing = min_len - len(items)
    # ``items[idx_repeat]`` is only looked up when padding is actually
    # required, so an out-of-range index is harmless otherwise.
    if items and missing > 0:
        items[idx_repeat:idx_repeat] = [items[idx_repeat]] * missing
    return tuple(items)
def list_sum(x: list) -> Any:
    """Combine the elements of a non-empty sequence with ``+``.

    Elements are folded from the right, i.e. the result is
    ``x[0] + (x[1] + (... + x[-1]))``. The fold is iterative, so it does not
    depend on the interpreter recursion limit and does not re-slice the
    input at every step.

    Args:
        x: Non-empty list whose elements support ``+`` with each other.

    Returns:
        The combined value; for a single-element list, that element itself.

    Raises:
        IndexError: If ``x`` is empty.
    """
    total = x[-1]
    # Use ``item + total`` rather than ``+=`` so no element is mutated in place.
    for item in reversed(x[:-1]):
        total = item + total
    return total
def resize(
    x: torch.Tensor,
    size: Optional[Any] = None,
    scale_factor: Optional[Any] = None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    """Resize ``x`` with ``torch.nn.functional.interpolate``.

    Args:
        x: Input tensor, forwarded unchanged to ``F.interpolate``.
        size: Target output size, forwarded as ``size``.
        scale_factor: Scale multiplier, forwarded as ``scale_factor``.
        mode: One of ``"bilinear"``, ``"bicubic"``, ``"nearest"`` or
            ``"area"``.
        align_corners: Forwarded to ``F.interpolate`` only for the
            ``"bilinear"`` and ``"bicubic"`` modes; ignored otherwise.

    Returns:
        The interpolated tensor.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
    """
    if mode in ("bilinear", "bicubic"):
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    if mode in ("nearest", "area"):
        # ``align_corners`` is deliberately not forwarded for these modes.
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    raise NotImplementedError(f"resize(mode={mode}) not implemented.")
class UpSampleLayer(nn.Module):
    """Module wrapper around ``resize`` with a fixed configuration.

    Either an explicit output ``size`` or a scale ``factor`` is used: when
    ``size`` is given it takes precedence and ``factor`` is stored as
    ``None``.

    Args:
        mode: Interpolation mode passed to ``resize``.
        size: Target output size; a list/tuple is copied to a list, any other
            non-``None`` value is repeated twice. ``None`` selects ``factor``.
        factor: Scale factor, used only when ``size`` is ``None``.
        align_corners: Passed through to ``resize``.
    """

    def __init__(
        self,
        mode="bicubic",
        size=None,
        factor=2,
        align_corners=False,
    ):
        super().__init__()
        if size is None:
            target_size, scale = None, factor
        else:
            # An explicit size wins, so the scale factor is dropped.
            target_size, scale = val2list(size, 2), None
        self.mode = mode
        self.size = target_size
        self.factor = scale
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` using the configuration stored on the layer."""
        return resize(
            x,
            size=self.size,
            scale_factor=self.factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
class OpSequential(nn.Module):
    """Apply a sequence of modules one after another, skipping ``None`` entries.

    Args:
        op_list: Iterable of modules and/or ``None``. ``None`` entries are
            dropped; the remaining modules are kept, in order, in an
            ``nn.ModuleList`` stored as ``self.op_list``.
    """

    def __init__(self, op_list):
        super().__init__()
        self.op_list = nn.ModuleList([op for op in op_list if op is not None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every stored module in order and return the result."""
        out = x
        for layer in self.op_list:
            out = layer(out)
        return out