import os
import math
import time
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F


class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.nn_embed % config.nn_head == 0
        self.nn_head = config.nn_head
        self.nn_embed = config.nn_embed
        # Q, K, V projections computed in a single linear layer for every token of every batch
        self.w_qkv = nn.Linear(config.nn_embed, config.nn_embed * 3)  # (X, embed) -> (X, 3*embed)
        # Output projection layer that mixes the concatenated heads for every token of every batch
        self.proj = nn.Linear(config.nn_embed, config.nn_embed)  # (X, embed) -> (X, embed)
        # Lower-triangular causal mask of shape (1, 1, T_max, T_max), registered as a buffer so it
        # is saved with the module and moved to the right device, but is not a trainable parameter
        self.register_buffer("bias", torch.tril(torch.ones(config.nn_max_tok_seq, config.nn_max_tok_seq)).view(1, 1, config.nn_max_tok_seq, config.nn_max_tok_seq))

    def forward(self, x):
        B, T, E = x.size()  # batch size, number of tokens, embedding size (nn_embed)
        q, k, v = self.w_qkv(x).split(self.nn_embed, dim=2)  # split the last dim into 3 chunks of size nn_embed
        # Split the last dim of q, k, v into heads, then move the head dim before T for the matmuls
        q = q.view(B, T, self.nn_head, E // self.nn_head).transpose(1, 2)  # (B, head, T, headEmbedSize)
        k = k.view(B, T, self.nn_head, E // self.nn_head).transpose(1, 2)  # (B, head, T, headEmbedSize)
        v = v.view(B, T, self.nn_head, E // self.nn_head).transpose(1, 2)  # (B, head, T, headEmbedSize)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # Q*K^T / sqrt(headEmbedSize) ... (B, head, T, T)
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # mask future positions (key index > query index) with -inf
        att = F.softmax(att, dim=-1)  # the -inf entries become 0 after softmax
        y = att @ v  # (B, head, T, headEmbedSize)
        # Move the head dim back next to headEmbedSize and merge them to recover the embedding dim
        y = y.transpose(1, 2).contiguous().view(B, T, E)  # (B, T, head, headEmbedSize) -> (B, T, E)
        # Projection layer to mix the head outputs that were just concatenated
        y = self.proj(y)  # (B, T, E)
        return y
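
    # Note (alternative, not used here): on PyTorch >= 2.0 the explicit mask + softmax above can be
    # replaced by the fused F.scaled_dot_product_attention(q, k, v, is_causal=True), which computes
    # the same causal attention without materialising the full (B, head, T, T) score matrix.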


# Feed-forward (MLP) layer used inside each transformer block
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc = nn.Linear(config.nn_embed, config.nn_embed * config.nn_mlp_expansion)
        self.gelu = nn.GELU(approximate='tanh')
        self.proj = nn.Linear(config.nn_embed * config.nn_mlp_expansion, config.nn_embed)
        # Flag read by DecoderTransformer._init_weights to scale down this residual-path projection
        self.proj.NANGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.fc(x)
        x = self.gelu(x)
        x = self.proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.nn_embed)
        self.att = Attention(config)
        self.ln_2 = nn.LayerNorm(config.nn_embed)
        self.mlp = MLP(config)

    def forward(self, x):
        # Pre-norm residual connections around attention and MLP
        x = x + self.att(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class DecoderTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.nn_embed)      # token embeddings
        self.wpe = nn.Embedding(config.nn_max_tok_seq, config.nn_embed)  # positional embeddings
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.nn_layer)])
        self.lm_head = nn.Linear(config.nn_embed, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the LM head share the same matrix, saving parameters
        self.wte.weight = self.lm_head.weight
        # Weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANGPT_SCALE_INIT'):
                # GPT-2 style scaling: shrink residual-path projections by 1/sqrt(2 * n_layers)
                # so the variance of the residual stream does not grow with depth
                std *= (2 * self.config.nn_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.nn_max_tok_seq, f"Token length ({T}) cannot exceed the maximum allowed sequence length (block size) ({self.config.nn_max_tok_seq})"
        # Embedding layers
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # 1-D tensor 0..T-1 representing the token positions of a single sequence
        pos_embed = self.wpe(pos)  # positional embedding (T, nn_embed): every position in the sequence gets an nn_embed-sized vector
        tok_embed = self.wte(idx)  # token embedding (B, T, nn_embed): every token in every batch gets its own embedding
        # The positional embedding depends only on position (not on token values), so the same
        # (T, nn_embed) tensor is broadcast-added to every batch
        x = pos_embed + tok_embed  # (B, T, nn_embed)
        # Transformer blocks (nn_layer of them)
        for block in self.blocks:
            x = block(x)  # (B, T, nn_embed)
        # Head - the final linear layer projecting to vocabulary logits
        logits = self.lm_head(x)  # (B, T, vocab_size)
        # If targets are supplied, compute the loss and return both logits and loss;
        # otherwise return the logits with loss=None
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
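

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training script). The model
# expects a config object with the fields this file reads (vocab_size,
# nn_embed, nn_head, nn_layer, nn_max_tok_seq, nn_mlp_expansion); the dataclass
# name and the values below are illustrative GPT-2-small-like assumptions, not
# the author's settings.
# ---------------------------------------------------------------------------
@dataclass
class GPTConfig:
    vocab_size: int = 50257       # GPT-2 BPE vocabulary size (assumed)
    nn_embed: int = 768           # embedding width
    nn_head: int = 12             # attention heads per block
    nn_layer: int = 12            # number of transformer blocks
    nn_max_tok_seq: int = 1024    # maximum context length (block size)
    nn_mlp_expansion: int = 4     # MLP hidden width = 4 * nn_embed


if __name__ == "__main__":
    config = GPTConfig()
    model = DecoderTransformer(config)
    idx = torch.randint(0, config.vocab_size, (2, 16))      # (B=2, T=16) dummy token ids
    targets = torch.randint(0, config.vocab_size, (2, 16))  # same shape as idx
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())  # torch.Size([2, 16, 50257]) and a scalar cross-entropy loss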