import torch
import torch.nn as nn

from .PositionalEncoding import PositionalEncoding


class AntiCheatPT_256(nn.Module):
    def __init__(
        self,
        feature_dim=44,        # nr of features per tick
        seq_len=256,           # nr of ticks
        nhead=1,               # nr of attention heads
        num_layers=4,          # nr of transformer encoder layers
        dim_feedforward=176,   # hidden size of feedforward network (MLP)
        dropout=0.1            # dropout rate
    ):
        super(AntiCheatPT_256, self).__init__()

        self.positional_encoding = PositionalEncoding(d_model=feature_dim, max_len=seq_len + 1)  # +1 for CLS token

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )  # input shape is (batch, seq_len, d_model)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, feature_dim))  # learnable classification token

        self.fc_out = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x shape: (batch_size, seq_len, feature_dim) = (batch_size, 256, d_model)
        B = x.size(0)

        # prepend classification token
        cls_tokens = self.cls_token.expand(B, -1, -1).to(x.device)  # (batch_size, 1, d_model)
        x = torch.cat((cls_tokens, x), dim=1)                       # -> (batch_size, 257, d_model)

        x = self.positional_encoding(x)   # add positional encoding
        x = self.transformer_encoder(x)   # -> (batch_size, 257, d_model)

        cls_output = x[:, 0]              # take output at the classification token
        out = self.fc_out(cls_output)     # -> (batch_size, 1)
        return out.squeeze(1)             # -> (batch_size,)
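

# Minimal usage sketch (not part of the original module): illustrates the expected
# input/output shapes. It assumes the model returns raw logits for binary
# classification, so a sigmoid (or BCEWithLogitsLoss during training) is applied
# externally. Because of the relative import above, run this via the package,
# e.g. `python -m <package>.<module>`, rather than as a standalone script.
if __name__ == "__main__":
    model = AntiCheatPT_256()
    dummy = torch.randn(8, 256, 44)    # (batch_size, seq_len, feature_dim)
    logits = model(dummy)              # -> (batch_size,)
    probs = torch.sigmoid(logits)      # per-sample probability of cheating
    print(logits.shape, probs.shape)   # torch.Size([8]) torch.Size([8])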