import torch
import torch.nn as nn
import torch.nn.functional as F

from multiheadattention import MultiHeadAttention


class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: LayerNorm -> sublayer -> residual add."""

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        # Position-wise feed-forward network: expand to ff_dim, apply GELU,
        # then project back down to embed_dim.
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        # Residual connections around the attention and feed-forward sublayers,
        # with LayerNorm applied before each sublayer (pre-norm).
        x = x + self.attn(self.ln1(x), mask=mask)
        x = x + self.ff(self.ln2(x))
        return x
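
# Quick shape check -- a minimal sketch, assuming the local MultiHeadAttention
# module accepts an input of shape (batch, seq_len, embed_dim) plus an optional
# mask and returns a tensor of the same shape. The dimensions below are
# illustrative, not values taken from the original.
if __name__ == "__main__":
    block = TransformerBlock(embed_dim=64, num_heads=8, ff_dim=256)
    dummy = torch.randn(2, 10, 64)   # (batch, seq_len, embed_dim)
    out = block(dummy)
    print(out.shape)                 # expected: torch.Size([2, 10, 64])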