Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Pick the conditioning frames that are nearest in time to `frame_idx`.

    At most `max_cond_frame_num` entries of `cond_frame_outputs` (a mapping from
    frame index to that frame's output) are kept, chosen in this order:
    - a) the closest conditioning frame strictly before `frame_idx` (if any);
    - b) the closest conditioning frame at or after `frame_idx` (if any);
    - c) the remaining frames with the smallest `abs(t - frame_idx)`, until the
      limit is reached. Ties keep the iteration order of `cond_frame_outputs`
      because the sort is stable.

    A limit of -1, or a limit that is not smaller than the number of available
    frames, disables the selection: the input mapping itself is returned (not a
    copy) together with an empty dict.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    # Nothing to prune: hand back the caller's mapping untouched.
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"

    chosen = {}

    # a) nearest conditioning frame before the current one
    earlier = [t for t in cond_frame_outputs if t < frame_idx]
    if earlier:
        nearest_before = max(earlier)
        chosen[nearest_before] = cond_frame_outputs[nearest_before]

    # b) nearest conditioning frame at or after the current one
    later = [t for t in cond_frame_outputs if t >= frame_idx]
    if later:
        nearest_after = min(later)
        chosen[nearest_after] = cond_frame_outputs[nearest_after]

    # c) fill the remaining budget with whatever else is temporally closest
    budget = max_cond_frame_num - len(chosen)
    leftovers = [t for t in cond_frame_outputs if t not in chosen]
    leftovers.sort(key=lambda t: abs(t - frame_idx))
    for t in leftovers[:budget]:
        chosen[t] = cond_frame_outputs[t]

    not_chosen = {
        t: out for t, out in cond_frame_outputs.items() if t not in chosen
    }
    return chosen, not_chosen
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Build a 1D sinusoidal positional embedding (as in the original Transformer paper).

    `pos_inds` is a tensor of positions; a new trailing axis is appended to it.
    That axis holds `dim // 2` sine values followed by `dim // 2` cosine values,
    so its length is `2 * (dim // 2)`. `temperature` is the base of the
    geometric progression of per-channel divisors.
    """
    half_dim = dim // 2
    channel = torch.arange(half_dim, dtype=torch.float32, device=pos_inds.device)
    # Consecutive channel pairs share one exponent, hence the floor division.
    exponent = 2 * (channel // 2) / half_dim
    divisor = temperature**exponent
    angles = pos_inds.unsqueeze(-1) / divisor
    return torch.cat((angles.sin(), angles.cos()), dim=-1)
def get_activation_fn(activation):
    """Return the functional activation matching the given name.

    Supported names are "relu", "gelu" and "glu", mapped to the corresponding
    functions in `torch.nn.functional`.

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    supported = {
        "relu": F.relu,
        "gelu": F.gelu,
        "glu": F.glu,
    }
    try:
        return supported[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, so any unsupported input is
        # still reported as a RuntimeError.
        raise RuntimeError(
            f"activation should be relu/gelu/glu, not {activation}."
        ) from None
def get_clones(module, N):
    """Return an `nn.ModuleList` holding `N` independent deep copies of `module`."""
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
class DropPath(nn.Module):
    """Randomly zero whole entries along the first dimension during training.

    One Bernoulli draw with keep probability `1 - drop_prob` is made per index
    of dimension 0 and broadcast over all remaining dimensions. When
    `scale_by_keep` is set and the keep probability is positive, the kept
    entries are divided by the keep probability. The module returns its input
    unchanged when `drop_prob` is 0 or when it is not in training mode.
    """

    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        dropping_active = self.training and self.drop_prob != 0.0
        if not dropping_active:
            return x

        survive_prob = 1 - self.drop_prob
        # One mask value per index of dim 0; singleton dims broadcast over the rest.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(survive_prob)
        if self.scale_by_keep and survive_prob > 0.0:
            mask.div_(survive_prob)
        return x * mask
| # Lightly adapted from | |
| # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa | |
class MLP(nn.Module):
    """Stack of `nn.Linear` layers with an activation between consecutive layers.

    The first layer maps `input_dim` to `hidden_dim` (or directly to
    `output_dim` when only one layer is built), intermediate layers map
    `hidden_dim` to `hidden_dim`, and the last layer maps to `output_dim`.
    `activation` is called with no arguments to create the activation module,
    which is applied after every layer except the last. When `sigmoid_output`
    is true, a sigmoid is applied to the final output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        hidden_sizes = [hidden_dim] * (num_layers - 1)
        in_sizes = [input_dim, *hidden_sizes]
        out_sizes = [*hidden_sizes, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(in_sizes, out_sizes)]
        )
        self.sigmoid_output = sigmoid_output
        self.act = activation()

    def forward(self, x):
        last_idx = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            # The final layer's output is left un-activated.
            if idx < last_idx:
                x = self.act(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
| # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa | |
| # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa | |
class LayerNorm2d(nn.Module):
    """Layer normalization over dimension 1 with a per-channel affine transform.

    For each position, the values along dimension 1 are shifted to zero mean and
    scaled to unit (biased) variance, with `eps` added inside the square root.
    The learnable `weight` and `bias` (each of length `num_channels`) are then
    applied, broadcast over two trailing dimensions.
    NOTE(review): this presumably expects channels-first image tensors
    (N, C, H, W) - confirm against callers.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        # Reshape (C,) -> (C, 1, 1) so the affine terms broadcast over the trailing dims.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift