import torch
import torch.nn as nn

from .PositionalEncoding import PositionalEncoding


class AntiCheatPT_256(nn.Module):
    def __init__(
        self,
        feature_dim=44,          # number of features per tick
        seq_len=256,             # number of ticks per sample
        nhead=1,                 # number of attention heads
        num_layers=4,            # number of transformer encoder layers
        dim_feedforward=176,     # hidden size of the feedforward network (MLP)
        dropout=0.1              # dropout rate
    ):
        super().__init__()

        # Positional encoding over seq_len + 1 positions (+1 for the CLS token).
        self.positional_encoding = PositionalEncoding(d_model=feature_dim, max_len=seq_len + 1)

        # Encoder layer operating on (batch, seq_len, d_model) tensors.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learnable classification (CLS) token prepended to every sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, feature_dim))

        # Classification head mapping the CLS representation to a single logit.
        self.fc_out = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x: (batch_size, seq_len, feature_dim), i.e. (batch_size, 256, 44) with the defaults
        B = x.size(0)

        # Prepend the CLS token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(B, -1, -1).to(x.device)  # (batch_size, 1, d_model)
        x = torch.cat((cls_tokens, x), dim=1)                       # (batch_size, 257, d_model)

        x = self.positional_encoding(x)   # add positional encoding
        x = self.transformer_encoder(x)   # (batch_size, 257, d_model)

        cls_output = x[:, 0]              # output at the CLS token position
        out = self.fc_out(cls_output)     # (batch_size, 1)
        return out.squeeze(1)             # (batch_size,)
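

# Minimal usage sketch (not part of the original module): a shape-only smoke test
# on random data showing how the model might be invoked. Because of the relative
# import above, it must be run as a module, e.g. `python -m <package>.AntiCheatPT_256`
# (package name hypothetical); the real training pipeline lives elsewhere.
if __name__ == "__main__":
    model = AntiCheatPT_256()
    dummy = torch.randn(8, 256, 44)   # (batch_size, seq_len, feature_dim)
    logits = model(dummy)             # raw logits, shape (8,)
    probs = torch.sigmoid(logits)     # per-sequence probability in [0, 1]
    print(logits.shape, probs.shape)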