from typing import Optional

import torch
from torch import Tensor
from torch.nn import Linear, Module
from transformers import PreTrainedModel

from .encoder import MarlinEncoder
from .decoder import MarlinDecoder
from .config import MarlinConfig


class Marlin(Module):
    """Encoder with an optional decoder head.

    A `MarlinEncoder` is always built. When `as_feature_extractor` is false,
    a `MarlinDecoder` and a bias-free linear projection from the encoder
    width to the decoder width are built as well; otherwise both are `None`
    and only `extract_features` is usable.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        n_frames: int,
        encoder_embed_dim: int,
        encoder_depth: int,
        encoder_num_heads: int,
        decoder_embed_dim: int,
        decoder_depth: int,
        decoder_num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop_rate: float,
        attn_drop_rate: float,
        norm_layer: str,
        init_values: float,
        tubelet_size: int,
        as_feature_extractor: bool = True,
    ):
        """Build the encoder and, if requested, the decoder and projection.

        The `encoder_*` arguments configure the encoder and the `decoder_*`
        arguments configure the decoder. `n_frames` is forwarded to the
        encoder only. The remaining hyper-parameters (`img_size`,
        `patch_size`, `mlp_ratio`, `qkv_bias`, `qk_scale`, `drop_rate`,
        `attn_drop_rate`, `norm_layer`, `init_values`, `tubelet_size`) are
        passed unchanged to both sub-modules.

        Args:
            as_feature_extractor: When true (the default), no decoder or
                projection is created and `forward` is disabled.
        """
        super().__init__()
        self.encoder = MarlinEncoder(
            img_size=img_size,
            patch_size=patch_size,
            n_frames=n_frames,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
        )
        self.as_feature_extractor = as_feature_extractor
        self.clip_frames = n_frames

        if as_feature_extractor:
            # Feature-extraction mode: skip the decoder parameters entirely.
            self.enc_dec_proj = None
            self.decoder = None
        else:
            self.decoder = MarlinDecoder(
                img_size=img_size,
                patch_size=patch_size,
                embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                norm_layer=norm_layer,
                init_values=init_values,
                tubelet_size=tubelet_size,
            )
            # Maps encoder outputs to the decoder's embedding width.
            self.enc_dec_proj = Linear(
                encoder_embed_dim, decoder_embed_dim, bias=False
            )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Run encoder -> projection -> decoder on `x` using `mask`.

        Args:
            x: Input tensor passed to the encoder together with `mask`.
                NOTE(review): expected layout is defined by `MarlinEncoder`
                and is not visible here.
            mask: Mask passed to both the encoder and the decoder. Required.

        Returns:
            The decoder output.

        Raises:
            RuntimeError: If the model was built as a feature extractor.
            ValueError: If `mask` is `None`.
        """
        if self.as_feature_extractor:
            # NOTE(review): `extract_video` is not defined on this class in
            # the visible source; confirm the message is still accurate.
            raise RuntimeError(
                "For feature extraction, please use `extract_features` or `extract_video`."
            )
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if mask is None:
            raise ValueError(
                "`mask` is required when the model is not used as a feature extractor."
            )

        x = self.encoder(x, mask)
        x = self.enc_dec_proj(x)
        x = self.decoder(x, mask)
        return x

    @property
    def device(self):
        """Device of the encoder's final norm weight."""
        return self.encoder.norm.weight.device

    def extract_features(self, x: Tensor, keep_seq: bool = True):
        """Extract encoder features for one video clip.

        Delegates to `self.encoder.extract_features`. Gradient tracking is
        left untouched while the module is in training mode and disabled
        via `torch.no_grad()` otherwise.

        Args:
            x: Input clip tensor forwarded to the encoder.
            keep_seq: When false, the encoder is asked to mean-pool over the
                sequence (`seq_mean_pool=True`); when true it is not.
        """
        seq_mean_pool = not keep_seq
        if self.training:
            return self.encoder.extract_features(x, seq_mean_pool=seq_mean_pool)
        with torch.no_grad():
            return self.encoder.extract_features(x, seq_mean_pool=seq_mean_pool)


class MarlinModel(PreTrainedModel):
    """`PreTrainedModel` wrapper that exposes `Marlin` as a feature extractor.

    The inner `Marlin` is created with its default
    `as_feature_extractor=True`, so no decoder is built and `forward`
    returns encoder features.
    """

    config_class = MarlinConfig

    def __init__(self, config: MarlinConfig):
        """Build the inner `Marlin` module from the fields of `config`."""
        super().__init__(config)
        self.config = config
        self.marlin = Marlin(
            img_size=config.img_size,
            patch_size=config.patch_size,
            n_frames=config.n_frames,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            qk_scale=config.qk_scale,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            norm_layer=config.norm_layer,
            init_values=config.init_values,
            tubelet_size=config.tubelet_size,
        )

    def forward(self, x: Tensor, keep_seq: bool = True):
        """Return `Marlin.extract_features(x, keep_seq=keep_seq)`."""
        return self.marlin.extract_features(x, keep_seq=keep_seq)