|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single fused linear layer produces queries, keys and values for all
    heads. Each head attends over the sequence independently, and the
    concatenated head outputs are mixed by a final linear projection.

    Args:
        embed_dim: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One fused projection for Q, K and V of every head.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, embed_dim)``.
            mask: Optional tensor broadcastable to ``(B, num_heads, T, T)``
                (query positions along dim -2, key positions along dim -1).
                Entries equal to 0 block attention to that key. If every key
                of a query row is blocked, that row gets all-zero attention
                weights (the head contributes zeros for that position)
                instead of NaN.

        Returns:
            Tensor of shape ``(B, T, embed_dim)``.
        """
        B, T, C = x.shape

        # The fused projection is viewed per head as [q_h | k_h | v_h]: each
        # head owns a contiguous slice of 3 * head_dim features. This is not a
        # [Q | K | V] layout, so weights from such implementations are not
        # directly interchangeable with this module's qkv_proj.
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, T, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (B, num_heads, T, 3 * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, num_heads, T, head_dim)

        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        fully_masked = None
        if mask is not None:
            blocked = mask == 0
            attn_scores = attn_scores.masked_fill(blocked, float('-inf'))
            # Softmax over a row that is entirely -inf yields NaN, which would
            # poison the output and the gradients. Give such rows finite
            # scores for the softmax and zero their weights afterwards.
            fully_masked = blocked.all(dim=-1, keepdim=True)
            attn_scores = attn_scores.masked_fill(fully_masked, 0.0)

        attn_weights = F.softmax(attn_scores, dim=-1)
        if fully_masked is not None:
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

        attn_output = attn_weights @ v  # (B, num_heads, T, head_dim)

        # Merge the heads back into the feature dimension.
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)

        output = self.out_proj(attn_output)
        return output