# Transformer-MiniGPT / transformer.py
# Uploaded by Austin207 ("Upload folder using huggingface_hub"), commit e24d6f1 (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from multiheadattention import MultiHeadAttention
class TransformerBlock(nn.Module):
    """One pre-LayerNorm transformer block: attention then feed-forward, each residual.

    Each sub-layer normalizes its input with its own ``nn.LayerNorm`` before
    the sub-layer runs, and adds the sub-layer output back onto the
    un-normalized stream.

    Args:
        embed_dim: Size of the last (feature) dimension of the input; used by
            both LayerNorms and as the feed-forward input/output width.
        num_heads: Forwarded positionally to ``MultiHeadAttention``.
        ff_dim: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        # NOTE: submodule names and creation order (attn, ln1, ff, ln2) are
        # kept fixed: they determine state_dict keys and the order in which
        # random parameter initialisation consumes the RNG.
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)

        # Feed-forward sub-layer: widen to ff_dim, GELU, narrow back to embed_dim.
        widen = nn.Linear(in_features=embed_dim, out_features=ff_dim)
        narrow = nn.Linear(in_features=ff_dim, out_features=embed_dim)
        self.ff = nn.Sequential(widen, nn.GELU(), narrow)

        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        """Run the block on ``x`` and return a tensor of the same shape.

        Args:
            x: Input tensor whose last dimension is ``embed_dim``
               (presumably (batch, seq, embed_dim) — TODO confirm with callers).
            mask: Optional mask, passed unchanged to the attention module as
                ``mask=``; its expected shape/meaning is defined by
                ``MultiHeadAttention``.

        Returns:
            ``x`` with the attention and feed-forward residual updates applied.
        """
        attended = self.attn(self.ln1(x), mask=mask)
        hidden = x + attended
        transformed = self.ff(self.ln2(hidden))
        return hidden + transformed