Transformer-MiniGPT / multiheadattention.py
Austin207's picture
Upload folder using huggingface_hub
e24d6f1 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    A single linear layer produces queries, keys and values for all heads at
    once. Its output is viewed as ``(B, T, num_heads, 3 * head_dim)``, so each
    head owns one contiguous slice of ``3 * head_dim`` features that is then
    split into that head's q, k and v. This layout fixes how the rows of
    ``qkv_proj.weight`` map onto heads, so it must not be changed without also
    converting existing weights.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # One projection emits q, k and v for every head (3 * embed_dim features).
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask: "torch.Tensor | None" = None) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, C)`` where ``C`` is ``embed_dim``.
            mask: Optional tensor broadcastable to ``(B, num_heads, T, T)``.
                Positions where ``mask == 0`` are excluded from attention;
                every other position is kept.

        Returns:
            Tensor of shape ``(B, T, C)``. A query position whose keys are all
            masked out receives a zero attention output (so its result is just
            the output projection's bias) instead of NaN.
        """
        B, T, C = x.shape

        # (B, T, 3C) -> (B, T, H, 3 * head_dim) -> (B, H, T, 3 * head_dim),
        # then split each head's slice into q, k, v of size head_dim each.
        qkv = self.qkv_proj(x).reshape(B, T, self.num_heads, 3 * self.head_dim)
        q, k, v = qkv.permute(0, 2, 1, 3).chunk(3, dim=-1)

        # Scaled dot-product scores, shape (B, H, T, T).
        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        blocked = None
        if mask is not None:
            blocked = mask == 0
            attn_scores = attn_scores.masked_fill(blocked, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)

        if blocked is not None:
            # softmax over a row that is entirely -inf yields NaN, which would
            # poison the output and every gradient flowing through it. Such a
            # row has nothing to attend to, so give it all-zero weights.
            fully_blocked = blocked.all(dim=-1, keepdim=True)
            attn_weights = attn_weights.masked_fill(fully_blocked, 0.0)

        # Weighted sum of values, then merge heads back: (B, H, T, hd) -> (B, T, C).
        attn_output = attn_weights @ v
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_output)