Character-Level GPT Model Trained on WikiText-2

This repository contains a character-level GPT model trained on the WikiText-2 dataset. The model architecture is a custom implementation of a transformer-based language model with Rotary Positional Embeddings (RoPE) and SwiGLU feed-forward networks.

This project was largely AI-generated.

That includes all of the code used to build the model.

It also includes a large part of this description.

Model Description

  • Model Type: Custom Transformer-based Causal Language Model
  • Training Data: WikiText-2
  • Tokenization: Character-level
  • Architecture Details:
    • Rotary Positional Embeddings (RoPE)
    • SwiGLU Feed-Forward Networks
  • Hyperparameters (≈25.5M trainable parameters in total):
    • n_layer: 8
    • n_head: 8
    • n_embd: 512
    • block_size: 512
    • dropout: 0.1
    • vocab_size: 283

Intended Use

This model is intended for research purposes, including:

  • Experimenting with character-level language modeling.
  • Studying the effects of different training techniques on transformer models.

Limitations

  • Character-Level Tokenization: The model uses character-level tokenization, which is less efficient than subword tokenization (e.g., Byte-Pair Encoding) for capturing long-range dependencies and generating coherent text. As a result, the quality of generated text may be limited compared to models using subword tokenization.
  • Limited Training Data: The model was trained on the WikiText-2 dataset, which is relatively small. Training on a larger dataset would likely improve performance.
  • Custom Architecture: This is a custom model implementation, not a standard pre-trained model from the transformers library.
  • Requires Manual Intervention: Loading and using this model requires manual intervention and a deeper understanding of the architecture. The AutoModelForCausalLM class from transformers cannot be used.

Inference code (requires torch and huggingface_hub):

import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import json
import os
from huggingface_hub import hf_hub_download

# --- Configuration ---
# Hugging Face Hub repository that holds config.json, the weights and vocab.json.
repo_id = "Ma7ee7/WikiGPT-25M"
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# --- Download Necessary Files ---
# Each hf_hub_download call returns the local filesystem path of the requested
# file from `repo_id`; those paths are used by the loading steps below.
print(f"Downloading files from {repo_id}...")
try:
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
    # tokenizer_config_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_config.json") # Optional
    print("Files downloaded successfully.")
except Exception as e:
    print(f"Error downloading files: {e}")
    print("Please ensure the repository ID is correct and the files exist.")
    # Exit with a non-zero status so shells/CI see the failure. A bare exit()
    # reports success (status 0) and is only available when `site` is loaded.
    raise SystemExit(1) from e

# --- Load Configuration ---
# The hyperparameters extracted here become module-level globals; Head,
# MultiHeadAttention and Block read n_embd / block_size / dropout directly
# from these names rather than taking them as constructor arguments.
print("Loading configuration...")
try:
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    print("Configuration loaded:")
    print(config)

    # Extract necessary hyperparameters from config (a missing key raises KeyError)
    vocab_size = config["vocab_size"]
    n_layer = config["n_layer"]
    n_head = config["n_head"]
    n_embd = config["n_embd"]
    block_size = config["block_size"]
    dropout = config["dropout"]

except Exception as e:
    print(f"Error loading config.json: {e}")
    # Non-zero exit status; a bare exit() would report success.
    raise SystemExit(1) from e

# --- RoPE Helper Functions ---
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, device='cpu'):
    """Build the complex rotation table used by rotary position embeddings.

    Args:
        dim: Per-head feature size. One frequency is produced per pair of
            features, i.e. ``dim // 2`` frequencies.
        end: Number of positions (rows) to precompute.
        theta: Base of the geometric frequency progression.
        device: Device on which the table is created.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[t, j]`` is
        ``exp(i * t * theta ** (-2j / dim))`` (unit magnitude).
    """
    exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = 1.0 / (theta ** exponents).to(device)
    positions = torch.arange(end, device=device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.polar(torch.ones_like(angles), angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``x`` must carry the sequence axis at dim 1 and the (complex) feature-pair
    axis at its last dim, e.g. ``(B, T, D/2)`` or ``(B, T, H, D/2)``.
    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``; the result
    keeps those two axes and has singleton dims everywhere else.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims, or the shapes disagree.
    """
    ndim = x.ndim
    if ndim < 3:
        raise ValueError(f"x must have at least 3 dims (batch, seq, ..., features), got shape {tuple(x.shape)}")
    if freqs_cis.shape != (x.shape[1], x.shape[-1]): # Check dimensions used
        raise ValueError(f"Freqs shape {freqs_cis.shape} does not match x shape {x.shape} at dims 1 and -1")
    # Place the table's axes exactly at dim 1 and the last dim. (Prepending
    # ndim-2 singleton dims instead would put the sequence axis at dim -2,
    # which is only right when ndim == 3.)
    shape = [d if i in (1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate query and key features by their rotary-position angles.

    Consecutive feature pairs ``(x[..., 2j], x[..., 2j + 1])`` are viewed as
    complex numbers and multiplied by ``freqs_cis``, which is lined up with
    the sequence axis (dim 1) via ``reshape_for_broadcast``. The arithmetic is
    done in float32 and the results are cast back to the input dtypes.

    Args:
        xq: Query tensor with an even-sized last dim and the sequence axis at
            dim 1, e.g. ``(B, T, head_size)``.
        xk: Key tensor with the same shape as ``xq``.
        freqs_cis: Complex table of shape ``(T, head_size // 2)``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same dtypes as the inputs.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out_complex = xq_ * freqs_cis
    xk_out_complex = xk_ * freqs_cis
    # Merge the trailing (pairs, 2) axes back into one feature axis. Using -2
    # works for any rank; start_dim=2 only did so for 3-D inputs.
    xq_out = torch.view_as_real(xq_out_complex).flatten(start_dim=-2)
    xk_out = torch.view_as_real(xk_out_complex).flatten(start_dim=-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
# --- End RoPE Helpers ---


class Head(nn.Module):
    """Single causal self-attention head that applies RoPE to queries and keys.

    Layer sizes and the dropout rate come from the module-level ``n_embd``,
    ``block_size`` and ``dropout`` names (read from config.json), not from
    constructor arguments.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Causal mask (ones on and below the diagonal). This is a regular,
        # persistent buffer, so it is part of this module's state_dict.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis):
        """Attend over ``x`` of shape ``(B, T, n_embd)``; returns ``(B, T, head_size)``."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        queries, keys = apply_rotary_emb(queries, keys, freqs_cis=freqs_cis)

        # 1/sqrt(head_size) scaling, kept as a tensor in the activations' dtype.
        scale = torch.tensor(queries.shape[-1] ** -0.5, device=queries.device, dtype=queries.dtype)
        scores = (queries @ keys.transpose(-2, -1)) * scale

        # Positions where the causal mask is 0 (the future) get -inf scores.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future.to(scores.device), float('-inf'))

        attn = self.dropout(F.softmax(scores, dim=-1))
        return attn @ self.value(x)

class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` RoPE attention heads in parallel and mixes their outputs.

    The concatenated head outputs pass through an ``n_embd -> n_embd`` linear
    projection (module-level ``n_embd``), which presumes
    ``num_heads * head_size == n_embd``.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis):
        """Concatenate every head's output along the last dim, then project."""
        merged = torch.cat([head(x, freqs_cis) for head in self.heads], dim=-1)
        return self.dropout(self.proj(merged))

class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``dropout(w2(silu(w3(x)) * w1(x)))``.

    Args:
        n_embd: Input and output feature size.
        hidden_dim: Inner width; defaults to ``int(4 * n_embd * 2 / 3)``.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, n_embd, hidden_dim=None, dropout=0.1):
        super().__init__()
        inner = int(4 * n_embd * 2 / 3) if hidden_dim is None else hidden_dim
        self.w1 = nn.Linear(n_embd, inner, bias=False)
        self.w2 = nn.Linear(inner, n_embd, bias=False)
        self.w3 = nn.Linear(n_embd, inner, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the gated projection to ``x`` (last dim ``n_embd``)."""
        gated = F.silu(self.w3(x)) * self.w1(x)
        return self.dropout(self.w2(gated))

class Block(nn.Module):
    """Pre-norm transformer layer: RoPE self-attention, then SwiGLU, each with a residual add."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        # The FFN dropout rate is the module-level `dropout` read from config.json.
        self.ffwd = SwiGLU(n_embd, dropout=dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, freqs_cis):
        """Return ``x`` updated by the attention and feed-forward sub-layers."""
        attended = x + self.sa(self.ln1(x), freqs_cis)
        return attended + self.ffwd(self.ln2(attended))

class GPTLanguageModel(nn.Module):
    """Character-level decoder-only transformer with RoPE attention and SwiGLU FFNs.

    Args:
        vocab_size: Number of token ids (characters) in the vocabulary.
        n_embd: Embedding / residual-stream width.
        n_head: Attention heads per block; ``n_embd // n_head`` is the head size.
        n_layer: Number of transformer blocks.
        block_size: Maximum context length the model attends over.
        dropout: Accepted for config compatibility. NOTE(review): it is not
            used here; the submodules read the module-level ``dropout`` name.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        self.block_size = block_size

        # RoPE table covering twice the context length. Built on the CPU and
        # registered as a non-persistent buffer: it moves with the module on
        # .to(device) but is neither saved in nor expected from the state_dict.
        head_dim = n_embd // n_head
        freqs_cis_buffer = precompute_freqs_cis(head_dim, block_size * 2, device='cpu')
        self.register_buffer("freqs_cis", freqs_cis_buffer, persistent=False)

        # No weight initialization needed here for inference

    def forward(self, idx, targets=None):
        """Score the token ids ``idx`` of shape ``(B, T)``.

        Only the last ``block_size`` positions are used when ``T`` exceeds the
        context length; ``targets`` (same shape as ``idx``) is cropped the same
        way so logits and labels stay aligned.

        Returns:
            ``(logits, loss)``. With ``targets``: logits for every kept position,
            ``(B, T_used, vocab_size)``, and the mean cross-entropy loss.
            Without ``targets``: logits for the last position only,
            ``(B, 1, vocab_size)``, and ``loss`` is ``None``.
        """
        _, T = idx.shape
        T_used = min(T, self.block_size)

        tok_emb = self.token_embedding_table(idx[:, -T_used:])

        # RoPE angles for the positions actually used, on the embeddings' device.
        freqs_cis_for_block = self.freqs_cis[:T_used].to(tok_emb.device)

        x = tok_emb
        for block in self.blocks:
            x = block(x, freqs_cis_for_block)

        x = self.ln_f(x)

        if targets is None:
            # Inference: only the last position's logits are needed.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        else:
            logits = self.lm_head(x)
            # Crop targets exactly like idx; reshape (not view) because the
            # cropped slice is non-contiguous when T > block_size.
            targets = targets[:, -T_used:]
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` ids autoregressively and append them to ``idx``.

        Args:
            idx: ``(B, T)`` tensor of prompt ids; ``T`` must be at least 1.
            max_new_tokens: Number of sampling steps.
            temperature: Logits are divided by this before the softmax (< 1
                sharpens, > 1 flattens). Must be > 0; the default 1.0 leaves
                the distribution unchanged.
            top_k: If given (>= 1), sample only among the ``top_k`` highest
                logits at each step.

        Returns:
            ``(B, T + max_new_tokens)`` tensor: the prompt followed by samples.

        The model is put in eval mode while sampling and afterwards restored
        to the mode (train/eval) it was in before the call.

        Raises:
            ValueError: if ``temperature <= 0`` or ``top_k < 1``.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop context to the last block_size tokens before the forward pass.
                idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]  # (B, vocab_size)

                if temperature != 1.0:
                    logits = logits / temperature
                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    kth_largest = torch.topk(logits, k, dim=-1).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_largest, float('-inf'))

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            # Restore the caller's mode instead of forcing training mode,
            # which would silently re-enable dropout on an eval-mode model.
            self.train(was_training)
        return idx
# --- End Model Definitions ---


# --- Instantiate Model ---
# Build the architecture from the hyperparameters read out of config.json.
print("Instantiating model...")
try:
    model = GPTLanguageModel(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
        dropout=dropout
    )
    print("Model instantiated.")
except Exception as e:
    print(f"Error instantiating model: {e}")
    print("Ensure the class definitions above match the configuration.")
    # Non-zero exit status; a bare exit() would report success.
    raise SystemExit(1) from e

# --- Load Weights ---
print("Loading model weights...")
try:
    # The checkpoint comes from the network, so restrict unpickling to tensors
    # and plain containers (weights_only=True, torch >= 1.13); a full
    # torch.load would execute arbitrary pickle code. Load to CPU first; the
    # model is moved to `device` afterwards.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'), weights_only=True)
    # Drop a stray RoPE table if one was saved: the model rebuilds it as a
    # non-persistent buffer, so strict loading would reject the extra key.
    state_dict.pop("freqs_cis", None)

    # strict=True raises on any missing/unexpected key, surfacing
    # architecture mismatches immediately.
    load_result = model.load_state_dict(state_dict, strict=True)
    print(f"Weight loading result: {load_result}")
    print("Model weights loaded successfully.")
except Exception as e:
    print(f"Error loading weights: {e}")
    print("Ensure the model architecture definition matches the saved weights.")
    # If using strict=False, check missing/unexpected keys printed by load_state_dict
    # Non-zero exit status; a bare exit() would report success.
    raise SystemExit(1) from e

# --- Setup for Inference ---
# eval() disables dropout and .to() moves parameters and buffers to `device`;
# both return the module itself, so the two calls are chained.
model.eval().to(device)
print(f"Model moved to {device} and set to evaluation mode.")

# --- Load Vocabulary and Define Encode/Decode ---
# vocab.json maps each character to its integer id (stoi); itos is the inverse.
print("Loading vocabulary...")
try:
    with open(vocab_path, 'r', encoding='utf-8') as f:
        stoi = json.load(f)
    itos = {i: ch for ch, i in stoi.items()}
    print(f"Vocabulary loaded ({len(stoi)} chars).")
except Exception as e:
    print(f"Error loading vocabulary: {e}")
    # Non-zero exit status; a bare exit() would report success.
    raise SystemExit(1) from e

# Id substituted for characters missing from the vocabulary: the newline id,
# or None if the vocabulary has no newline entry.
_fallback_id = stoi.get('\n')


def encode(s):
    """Convert string ``s`` into a list of vocabulary ids.

    Characters not in the vocabulary are mapped to the newline id.

    Raises:
        KeyError: if an unknown character appears and the vocabulary has no
            newline entry to fall back on.
    """
    ids = []
    for c in s:
        i = stoi.get(c, _fallback_id)
        if i is None:
            raise KeyError(f"Character {c!r} is not in the vocabulary and there is no '\\n' fallback id")
        ids.append(i)
    return ids


def decode(l):
    """Convert a sequence of ids back into a string; unknown ids become '?'."""
    return ''.join(itos.get(i, '?') for i in l)


# --- Inference ---
print("\n--- Starting Inference ---")
# WikiText-style example prompt (section heading followed by an opening sentence).
prompt = " = = Artificial Intelligence in Medicine = = \n\n Artificial intelligence ( AI ) has"
max_tokens_to_generate = 500

print(f"Prompt:\n{prompt}")
print(f"\nGenerating ({max_tokens_to_generate} tokens)...")

# Batch of one: a (1, T) LongTensor of prompt ids on the target device.
prompt_ids = encode(prompt)
encoded_prompt = torch.tensor([prompt_ids], dtype=torch.long, device=device)

# generate() runs under torch.no_grad() and returns the prompt ids followed
# by the sampled continuation; row 0 is decoded back to text.
generated_ids = model.generate(encoded_prompt, max_new_tokens=max_tokens_to_generate)
generated_text = decode(generated_ids[0].tolist())

print("\n--- Generated Text ---")
print(generated_text)
print("\n----------------------")

This model was trained for about 1 hour and 30 minutes, so it has only learned basic connections between English words, and not much more.

Downloads last month
4
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support