soccer-qa-4b / src /models /utils /patch_embed.py
VarunKodathala's picture
Upload folder using huggingface_hub
0e37bb2 verified
raw
history blame
1.34 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from einops import rearrange
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    Cuts a batch of images into non-overlapping ``patch_size`` x ``patch_size``
    patches and linearly projects each patch to ``embed_dim`` features. This
    is done with a single ``nn.Conv2d`` whose kernel size and stride both
    equal ``patch_size``.

    Args:
        patch_size: Side length of each patch; used as both the kernel size
            and the stride of the projection convolution.
        in_chans: Number of channels in the input images.
        embed_dim: Number of features produced for each patch.
    """

    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # kernel == stride, so every output position sees exactly one patch.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Embed a batch of images as a sequence of patch tokens.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, N, embed_dim)``, where ``N`` is the number
            of patches the convolution produces; tokens are ordered row by
            row over the patch grid.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        # Explicit rank check: keeps the ValueError that the former
        # ``B, C, H, W = x.shape`` unpacking raised, but with a clear message
        # and without binding unused locals.
        if len(x.shape) != 4:
            raise ValueError(
                "PatchEmbed expects a 4-D input (B, C, H, W), "
                f"got shape {tuple(x.shape)}"
            )
        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        return self.proj(x).flatten(2).transpose(1, 2)
class PatchEmbed3D(nn.Module):
    """
    Video (3D) to Patch Embedding.

    Cuts a batch of clips into non-overlapping spatio-temporal "tubelets" of
    size ``tubelet_size`` x ``patch_size`` x ``patch_size`` and linearly
    projects each tubelet to ``embed_dim`` features. This is done with a
    single ``nn.Conv3d`` whose kernel size and stride are both
    ``(tubelet_size, patch_size, patch_size)``.

    Args:
        patch_size: Spatial side length (height and width) of each tubelet.
        tubelet_size: Temporal extent of each tubelet.
        in_chans: Number of channels in the input clips.
        embed_dim: Number of features produced for each tubelet.
    """

    def __init__(
        self,
        patch_size=16,
        tubelet_size=2,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        # kernel == stride, so every output position sees exactly one tubelet.
        tubelet_shape = (tubelet_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=tubelet_shape,
            stride=tubelet_shape,
        )

    def forward(self, x, **kwargs):
        """
        Embed a batch of clips as a sequence of tubelet tokens.

        Args:
            x: Tensor of shape ``(B, C, T, H, W)``.
            **kwargs: Accepted and ignored by this method.

        Returns:
            Tensor of shape ``(B, N, embed_dim)``, where ``N`` is the number
            of tubelets the convolution produces; tokens are ordered with
            time as the slowest-varying axis, then height, then width.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        # Explicit rank check: keeps the ValueError that the former
        # ``B, C, T, H, W = x.shape`` unpacking raised, but with a clear
        # message and without binding unused locals.
        if len(x.shape) != 5:
            raise ValueError(
                "PatchEmbed3D expects a 5-D input (B, C, T, H, W), "
                f"got shape {tuple(x.shape)}"
            )
        # (B, D, T', H', W') -> (B, D, T'*H'*W') -> (B, T'*H'*W', D)
        return self.proj(x).flatten(2).transpose(1, 2)