# Required Packages

In [None]:
!pip install torch transformers tiktoken datasets

# Imports and Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from datasets import load_dataset
import numpy as np
from IPython.display import clear_output
import json
import os
from typing import Dict, List, Optional, Tuple
import tiktoken

# Select the compute device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dedicated RNG with a fixed seed; Trainer.train() passes it to RandomSampler
# so the shuffle order is driven by this generator rather than the global RNG.
shuffle_generator = torch.Generator()
shuffle_generator.manual_seed(42)

# Request expandable segments from PyTorch's CUDA caching allocator.
# NOTE(review): this is set after `import torch` and after the availability
# check above; presumably the allocator reads the variable lazily so it still
# takes effect - confirm, or move this line above the torch import.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Configuration Class

In [None]:
class Config:
    """Hyperparameters for the model architecture, training loop and checkpointing.

    Every value is a plain instance attribute, so ``vars(config)`` yields a
    dictionary that can be stored inside a checkpoint.
    """

    def __init__(self):
        architecture = {
            'vocab_size': 100283,
            'max_position_embeddings': 1024,
            'hidden_size': 768,
            'num_layers': 6,
            'num_heads': 12,
            'intermediate_size': 3072,
            'dropout': 0.1,
        }
        training = {
            'batch_size': 4,
            'learning_rate': 3e-4,
            'weight_decay': 0.01,
            'warmup_steps': 1000,
            'max_epochs': 3,
            'gradient_accumulation_steps': 8,
            'max_grad_norm': 1.0,
        }
        checkpointing = {
            'checkpoint_every': 300,  # Save every N batches
            'evaluation_every': 500,  # Evaluate every N batches
        }

        # Assign section by section so the attribute order (and therefore the
        # key order of vars(config)) matches the listing above.
        for section in (architecture, training, checkpointing):
            for field_name, field_value in section.items():
                setattr(self, field_name, field_value)

# Dataset Class

In [None]:
class TextDataset(Dataset):
    """Fixed-length token blocks cut from a single text.

    The text is tokenized once, the token stream is split into consecutive
    chunks of ``chunk_size`` tokens, and every window of ``block_size`` tokens
    inside a chunk (stride 1) becomes one example. A chunk shorter than
    ``block_size`` contributes no example.

    Args:
        text: Raw text to tokenize.
        block_size: Number of tokens in each example.
        tokenizer: Object exposing ``encode(text, allowed_special=...)``.
        chunk_size: Number of tokens per chunk; windows never cross chunks.
    """

    def __init__(self, text: str, block_size: int, tokenizer, chunk_size: int = 1024):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.chunk_size = chunk_size

        # The chat markup tags must be explicitly allowed as special tokens.
        chat_tags = {'<user>', '</user>', '<assistant>', '</assistant>', '<system>', '</system>'}
        token_ids = self.tokenizer.encode(text, allowed_special=chat_tags)

        self.examples = []
        for offset in range(0, len(token_ids), self.chunk_size):
            window = token_ids[offset:offset + self.chunk_size]
            # Number of full block_size windows in this chunk (may be <= 0).
            window_count = len(window) - block_size + 1
            self.examples.extend(
                window[start:start + block_size] for start in range(window_count)
            )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        block = self.examples[i]
        return torch.tensor(block, dtype=torch.long)

# Model Architecture Classes

In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product self-attention head.

    Projects the input from ``hidden_size`` down to
    ``hidden_size // num_heads`` for queries, keys and values.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_heads
        self.query = nn.Linear(config.hidden_size, self.head_dim)
        self.key = nn.Linear(config.hidden_size, self.head_dim)
        self.value = nn.Linear(config.hidden_size, self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``x``; positions where ``mask`` equals 0 are excluded."""
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        # Similarity of every query with every key, scaled by sqrt(head_dim).
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (queries.size(-1) ** 0.5)

        if mask is not None:
            # -inf scores receive zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, values)

class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent attention heads and mixes their outputs.

    Head outputs are concatenated along the feature axis, passed through a
    ``hidden_size -> hidden_size`` linear layer and then through dropout.
    """

    def __init__(self, config: Config):
        super().__init__()
        head_modules = [AttentionHead(config) for _ in range(config.num_heads)]
        self.heads = nn.ModuleList(head_modules)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply every head to ``x`` with the same ``mask`` and merge the results."""
        per_head = []
        for head in self.heads:
            per_head.append(head(x, mask))
        merged = torch.cat(per_head, dim=-1)
        projected = self.linear(merged)
        return self.dropout(projected)

class TransformerBlock(nn.Module):
    """Attention and feed-forward sub-layers, each followed by residual add + LayerNorm."""

    def __init__(self, config: Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        # Position-wise MLP: expand to intermediate_size, GELU, project back, dropout.
        mlp_layers = [
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.dropout),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the block output for ``x``; ``mask`` is forwarded to the attention layer."""
        # Normalisation is applied after each residual sum.
        x = self.norm1(x + self.attention(x, mask))
        return self.norm2(x + self.feed_forward(x))

class SmallLanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through
    ``config.num_layers`` transformer blocks under a causal mask, normalised,
    and projected to vocabulary logits by a bias-free linear head.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        # Token and position embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])

        self.dropout = nn.Dropout(config.dropout)
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and reset LayerNorm to identity."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_causal_mask(self, size: int, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return a ``(size, size)`` boolean mask that is True where attention is allowed.

        Entry ``[i, j]`` is True when ``j <= i``, i.e. a position may attend to
        itself and to earlier positions only.

        Args:
            size: Sequence length.
            device: Device to create the mask on. ``None`` keeps torch's
                default device, which is what callers got before this
                parameter existed.
        """
        future_positions = torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()
        return ~future_positions

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            input_ids: Integer token ids of shape ``(batch, sequence)``.

        Returns:
            Logits of shape ``(batch, sequence, vocab_size)``.

        Raises:
            ValueError: If the sequence is longer than
                ``config.max_position_embeddings``; the position embedding
                table has no entry for those positions.
        """
        _, seq_len = input_ids.size()
        max_len = self.config.max_position_embeddings
        if seq_len > max_len:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings {max_len}"
            )

        # Build positions and the causal mask directly on the input's device
        # so no host-to-device copy is needed on every forward pass.
        positions = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device)
        mask = self.get_causal_mask(seq_len, device=input_ids.device)

        # Get token and position embeddings
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)

        # Combine embeddings (position embeddings broadcast over the batch)
        x = self.dropout(token_embeddings + position_embeddings)

        # Apply transformer blocks
        for block in self.transformer_blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        return self.head(x)

# Trainer Class

In [None]:
class TrainingMode:
    """String constants naming the modes accepted by ``Trainer.train``.

    The values are written into checkpoints under ``'original_mode'``.
    """
    PRETRAIN = "pretrain"  # train with the pretraining configuration
    RESUME = "resume"  # restore model, optimizer and scaler state, then continue
    NEW_DATASET = "new_dataset"  # load model weights only; epoch/batch restart at 0
    FINETUNE = "finetune"  # finetune configuration with layer-wise learning rates

class TrainerConfig:
    """Factory for the ``Config`` presets used by the different training modes."""

    @staticmethod
    def get_pretrain_config():
        """Return a ``Config`` with its defaults, which are the pretraining settings."""
        return Config()

    @staticmethod
    def get_finetune_config():
        """Return a ``Config`` with the training hyperparameters overridden for finetuning."""
        finetune_overrides = {
            'learning_rate': 2e-5,
            'max_epochs': 3,
            'batch_size': 4,
            'weight_decay': 0.02,
            'warmup_steps': 100,
            'gradient_accumulation_steps': 16,
        }
        config = Config()
        for field_name, field_value in finetune_overrides.items():
            setattr(config, field_name, field_value)
        return config

class Trainer:
    """Trains a language model with mixed precision, gradient accumulation and checkpoints.

    The model is moved to CUDA when available (CPU otherwise). ``train`` picks
    the hyperparameters for the requested mode, optionally restores state from
    a checkpoint, and runs the epoch loop, saving ``checkpoint.pt`` every
    ``config.checkpoint_every`` batches and at the end of each epoch.
    """

    def __init__(self, model: nn.Module, train_dataset: Dataset,
                 val_dataset: Optional[Dataset] = None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset  # stored only; train() does not evaluate
        self.scaler = GradScaler()
        self.optimizer = None  # created by setup_optimizer()

    def setup_optimizer(self, config: Config, mode: str):
        """Create ``self.optimizer`` (AdamW) for ``mode``.

        For ``TrainingMode.FINETUNE`` parameters are split into groups with
        scaled learning rates: embeddings at 0.1x, transformer blocks rising
        towards 1.0x with depth, and the output head / final norm at 1.0x.
        Every other mode uses one group with ``config.learning_rate``.

        The group order is part of the optimizer state layout: a checkpoint
        can only be resumed by an optimizer built with the same mode.
        """
        if mode != TrainingMode.FINETUNE:
            # Standard optimizer for pretraining
            self.optimizer = AdamW(
                self.model.parameters(),
                lr=config.learning_rate,
                weight_decay=config.weight_decay
            )
            return

        # Layer-wise learning rate decay for finetuning
        param_groups = []
        assigned_params = set()

        def add_params_to_group(name_fragments, lr_scale):
            # Collect not-yet-assigned parameters whose name contains a fragment.
            group_params = []
            for name, param in self.model.named_parameters():
                if param in assigned_params:
                    continue
                if any(fragment in name for fragment in name_fragments):
                    group_params.append(param)
                    assigned_params.add(param)
            if group_params:
                param_groups.append({
                    'params': group_params,
                    'lr': config.learning_rate * lr_scale
                })

        # Embedding parameters (lowest learning rate)
        add_params_to_group(['embedding'], 0.1)

        # Transformer layers with progressively larger learning rates
        for i in range(config.num_layers):
            lr_scale = 1 - (0.1 * (config.num_layers - i - 1) / config.num_layers)
            add_params_to_group([f'transformer_blocks.{i}.'], lr_scale)

        # Output layer parameters (highest learning rate)
        add_params_to_group(['head', 'ln_f'], 1.0)

        # Anything not matched above trains at the base learning rate
        remaining_params = [p for _, p in self.model.named_parameters()
                            if p not in assigned_params]
        if remaining_params:
            param_groups.append({
                'params': remaining_params,
                'lr': config.learning_rate
            })

        self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay)

    def save_checkpoint(self, epoch: int, batch_idx: int, loss: float, config: Config,
                        mode: str, path: str = 'checkpoint.pt'):
        """Write model, optimizer and scaler state plus run metadata to ``path``.

        Args:
            epoch: Zero-based epoch the checkpoint belongs to.
            batch_idx: Index of the last batch that was trained on.
            loss: Loss of that batch.
            config: Hyperparameters; stored as ``vars(config)``.
            mode: Mode the run was started in; stored as ``'original_mode'``.
            path: Destination file; defaults to ``checkpoint.pt``.
        """
        checkpoint = {
            'epoch': epoch,
            'batch_idx': batch_idx,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'loss': loss,
            'config': vars(config),
            'original_mode': mode
        }
        torch.save(checkpoint, path)
        print(f"Saved checkpoint to {path}")

    def load_checkpoint(self, path: str, mode: str) -> Tuple[int, int, str]:
        """Load the checkpoint at ``path`` and restore the state ``mode`` requires.

        Returns:
            ``(epoch, batch_idx, original_mode)``. Epoch and batch index are
            the stored values for ``RESUME`` and ``0, 0`` otherwise.

        Raises:
            RuntimeError: If ``mode`` is ``RESUME`` and ``setup_optimizer``
                has not been called yet.
        """
        print(f"Loading checkpoint from {path}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        return self._restore_checkpoint(checkpoint, mode)

    def _restore_checkpoint(self, checkpoint: dict, mode: str) -> Tuple[int, int, str]:
        """Apply an already-loaded checkpoint dict; see ``load_checkpoint``."""
        if mode == TrainingMode.RESUME and self.optimizer is None:
            raise RuntimeError("setup_optimizer() must be called before resuming from a checkpoint")

        # Every mode starts from the stored model weights.
        self.model.load_state_dict(checkpoint['model_state_dict'])
        saved_mode = checkpoint.get('original_mode', TrainingMode.PRETRAIN)

        if mode == TrainingMode.NEW_DATASET:
            # Only model weights are reused for a new dataset
            return 0, 0, saved_mode

        if mode == TrainingMode.RESUME:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # A disabled GradScaler (e.g. a CPU run) saves an empty state,
            # which an enabled scaler refuses to load; skip it in that case.
            scaler_state = checkpoint.get('scaler_state_dict')
            if scaler_state:
                self.scaler.load_state_dict(scaler_state)
            self._move_optimizer_state_to_device()
            return checkpoint['epoch'], checkpoint['batch_idx'], saved_mode

        # For direct pretrain/finetune, just the model weights are loaded
        return 0, 0, mode

    def _move_optimizer_state_to_device(self):
        """Move every tensor held in the optimizer state to ``self.device``."""
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(self.device)

    def train(self, mode: str, checkpoint_path: Optional[str] = None):
        """Run the training loop.

        Args:
            mode: One of the ``TrainingMode`` constants.
            checkpoint_path: Optional checkpoint to start from. What is
                restored depends on ``mode`` (see ``load_checkpoint``).

        When resuming, both the hyperparameters and the optimizer layout
        follow the mode stored in the checkpoint, and training continues with
        the batch after the checkpointed one. Gradients accumulated since the
        last optimizer step are not part of a checkpoint, so the first
        accumulation window after a resume may be shorter than configured.
        """
        # Read the checkpoint once; it is needed both to pick the config and
        # to restore state.
        checkpoint = None
        if checkpoint_path:
            print(f"Loading checkpoint from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        resuming = mode == TrainingMode.RESUME and checkpoint is not None

        # The optimizer must be built exactly as it was when the checkpoint
        # was written, otherwise its param groups do not match the saved state.
        setup_mode = (checkpoint.get('original_mode', TrainingMode.PRETRAIN)
                      if resuming else mode)
        config = (TrainerConfig.get_finetune_config()
                  if setup_mode == TrainingMode.FINETUNE
                  else TrainerConfig.get_pretrain_config())

        # Setup optimizer before restoring checkpoint state into it
        self.setup_optimizer(config, setup_mode)

        start_epoch, start_batch, original_mode = 0, 0, mode
        if checkpoint is not None:
            start_epoch, start_batch, original_mode = self._restore_checkpoint(checkpoint, mode)
            checkpoint = None  # drop the extra copy of all weights

        # Training loop setup
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.batch_size,
            sampler=RandomSampler(self.train_dataset, generator=shuffle_generator),
            pin_memory=True,
            num_workers=4
        )
        num_batches = len(train_loader)

        if resuming:
            # The checkpointed batch has already been trained on.
            start_batch += 1
            if start_batch >= num_batches:
                start_epoch, start_batch = start_epoch + 1, 0

        print(f"Training started in {mode} mode")
        print(f"Total batches per epoch: {num_batches}")
        print(f"Total epochs: {config.max_epochs}")
        print(f"Device: {self.device}")

        accumulation_steps = config.gradient_accumulation_steps
        # Start from clean gradients in case the model was trained before.
        self.optimizer.zero_grad()

        # Training loop
        for epoch in range(start_epoch, config.max_epochs):
            self.model.train()
            total_loss = 0.0
            batches_seen = 0
            last_loss = float('nan')

            for batch_idx, batch in enumerate(train_loader):
                # Skip batches if resuming from checkpoint
                if epoch == start_epoch and batch_idx < start_batch:
                    continue

                # Move batch to device
                batch = batch.to(self.device)

                # Forward pass with mixed precision; position t predicts token t+1
                with autocast():
                    logits = self.model(batch)
                    targets = batch[:, 1:]
                    logits = logits[:, :-1, :]
                    loss = F.cross_entropy(
                        logits.reshape(-1, logits.size(-1)),
                        targets.reshape(-1)
                    )

                # Divide by the window size so the accumulated gradient is an
                # average over the micro-batches rather than a sum.
                self.scaler.scale(loss / accumulation_steps).backward()

                # Step at the end of each accumulation window, and on the last
                # batch so leftover gradients do not leak into the next epoch.
                is_last_batch = batch_idx + 1 == num_batches
                if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        config.max_grad_norm
                    )

                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                last_loss = loss.item()
                total_loss += last_loss
                batches_seen += 1

                # Save checkpoint and display progress
                if batch_idx % config.checkpoint_every == 0:
                    self.save_checkpoint(epoch, batch_idx, last_loss, config, original_mode)
                    clear_output(wait=True)
                    print(f"Mode: {mode}")
                    print(f"Epoch {epoch+1}/{config.max_epochs}")
                    print(f"Batch {batch_idx+1}/{num_batches}")
                    # Average over the batches actually trained on this epoch
                    print(f"Loss {total_loss / batches_seen:.4f}")

            # Save checkpoint at end of epoch
            self.save_checkpoint(epoch, num_batches - 1, last_loss, config, original_mode)

        print(f'{mode} training complete.')

# Initialize Tokenizer

In [None]:
# Base encoding that supplies the split pattern, merge table and stock special tokens.
cl100k_base = tiktoken.get_encoding("cl100k_base")

# Derived encoding that adds the chat markup tags as special tokens with
# ids 100277 through 100282.
tokenizer = tiktoken.Encoding(
    name="cl100k_xml",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    special_tokens={
        **cl100k_base._special_tokens,
        **dict(zip(
            ("<user>", "</user>", "<assistant>", "</assistant>", "<system>", "</system>"),
            range(100277, 100283),
        )),
    },
)

# Model configuration, with the vocabulary size taken from the tokenizer.
config = Config()
config.vocab_size = tokenizer.n_vocab

# Load Data

## Pretrain 1st Session

In [None]:
from datasets import load_dataset

# Load the corpus and merge all documents into one training text.
# Documents are separated by a blank line so the last word of one document is
# not fused with the first word of the next when the text is tokenized.
# Replace the token placeholder with a Hugging Face access token before running.
ds = '\n\n'.join(load_dataset("nampdn-ai/mini-en", split='train', token="[INSERT HF TOKEN HERE]")['text'])

# Cut the tokenized text into examples of max_position_embeddings tokens
train_dataset = TextDataset(ds, config.max_position_embeddings, tokenizer)

## Pretrain 2nd Session

In [None]:
from datasets import load_dataset

# Load the corpus and merge all documents into one training text.
# Documents are separated by a blank line so the last word of one document is
# not fused with the first word of the next when the text is tokenized.
# Replace the token placeholder with a Hugging Face access token before running.
ds = '\n\n'.join(load_dataset("HuggingFaceTB/cosmopedia-100k", split='train', token="[INSERT HF TOKEN HERE]")['text'])

# Cut the tokenized text into examples of max_position_embeddings tokens
train_dataset = TextDataset(ds, config.max_position_embeddings, tokenizer)

## Finetune Session

In [None]:
fds1 = load_dataset("HuggingFaceH4/SystemChat", split="train_sft")
fds2 = load_dataset("HuggingFaceH4/no_robots", split="train")
fds3 = load_dataset('b-ai/deepseek_synthetic_conversation_dialogue', split='train')
finetune_datasets = [fds1, fds2, fds3]

# One text per dataset: each message is wrapped in <role>...</role> tags,
# messages of a conversation are joined by a single newline, and
# conversations are separated by a blank line.
# NOTE(review): assumes every dataset has a 'messages' column whose items
# carry 'role' and 'content' keys - confirm for all three sources.
finetune_parts = []
for chat_dataset in finetune_datasets:
    rendered_conversations = [
        '\n'.join(
            f"<{message['role']}>{message['content']}</{message['role']}>"
            for message in conversation
        )
        for conversation in chat_dataset['messages']
    ]
    finetune_parts.append('\n\n'.join(rendered_conversations))

# Datasets are separated by a blank line as well
finetune_text = '\n\n'.join(finetune_parts)
train_dataset = TextDataset(finetune_text, config.max_position_embeddings, tokenizer)

# Initialize Model

In [None]:
# Build a freshly initialised model from `config` (vocab_size was set from the tokenizer above)
model = SmallLanguageModel(config)

# Start or Resume Training

## Initialize Trainer

In [None]:
# Create the trainer; this moves `model` to CUDA when available, otherwise CPU.
# No validation dataset is passed.
trainer = Trainer(model, train_dataset)

### Start Pretrain Mode

In [None]:
# Train with the pretraining configuration; no checkpoint is loaded.
trainer.train(mode=TrainingMode.PRETRAIN)

### Resume Mode

In [None]:
# Continue an earlier run: model, optimizer and scaler state plus the
# epoch/batch position are restored from checkpoint.pt.
trainer.train(mode=TrainingMode.RESUME, checkpoint_path='checkpoint.pt')

### New Dataset Pretrain Mode

In [None]:
# Reuse only the model weights from checkpoint.pt; the optimizer is fresh and
# training starts at epoch 0, batch 0 on the currently loaded dataset.
trainer.train(mode=TrainingMode.NEW_DATASET, checkpoint_path='checkpoint.pt')

### Finetune Mode

In [None]:
# Finetune starting from the model weights in checkpoint.pt, using the
# finetune configuration and layer-wise learning rates.
trainer.train(mode=TrainingMode.FINETUNE, checkpoint_path='checkpoint.pt')

# Inference

In [None]:
class TextGenerator:
    """Samples text from a language model with temperature, top-k and top-p filtering.

    Args:
        model: Module mapping ``(1, T)`` token ids to ``(1, T, vocab)`` logits.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` and
            ``decode(ids)``; an optional ``eot_token`` attribute stops
            generation early.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.model.eval()  # Set to evaluation mode
        self.tokenizer = tokenizer

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(iter(self.model.parameters()), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def _max_context(self) -> int:
        """Return the longest sequence the model accepts."""
        model_config = getattr(self.model, 'config', None)
        if model_config is None:
            # Fall back to the notebook-level config for models that do not
            # carry their own.
            model_config = config
        return model_config.max_position_embeddings

    @torch.no_grad()  # Disable gradient calculation for inference
    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9
    ):
        """Generate up to ``max_length`` tokens after ``prompt``.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of new tokens.
            temperature: Logit divisor; values ``<= 0`` select the most likely
                token at every step instead of sampling.
            top_k: Keep only the ``top_k`` most likely tokens (``0`` disables).
            top_p: Keep the smallest set of tokens whose cumulative
                probability exceeds ``top_p`` (``>= 1.0`` disables).

        Returns:
            The decoded prompt plus generated tokens. If anything fails the
            error is printed and the unmodified prompt is returned. An empty
            prompt is returned as is.
        """
        try:
            # The model may have been switched back to train mode since __init__.
            self.model.eval()

            prompt_ids = self.tokenizer.encode(
                prompt,
                allowed_special={'<user>', '</user>', '<assistant>', '</assistant>', '<system>', '</system>'}
            )
            if not prompt_ids:
                # Nothing to condition on; the model cannot be run on length 0.
                return prompt

            max_context = self._max_context()
            generated = torch.tensor(
                prompt_ids, dtype=torch.long, device=self._model_device()
            ).unsqueeze(0)
            eot_token = getattr(self.tokenizer, 'eot_token', None)

            for _ in range(max_length):
                # Feed only the most recent tokens to the model, but keep the
                # full sequence so the returned text still starts with the prompt.
                context = generated[:, -max_context:]
                next_token_logits = self.model(context)[:, -1, :]

                if temperature <= 0:
                    # Greedy decoding; dividing by zero would yield inf/nan.
                    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                else:
                    next_token_logits = next_token_logits / temperature

                    # Apply top-k filtering
                    if top_k > 0:
                        # topk raises when k exceeds the vocabulary size
                        k = min(top_k, next_token_logits.size(-1))
                        top_values, _ = torch.topk(next_token_logits, k)
                        kth_value = top_values[:, -1].unsqueeze(-1)
                        next_token_logits = next_token_logits.masked_fill(
                            next_token_logits < kth_value, float('-inf')
                        )

                    # Apply top-p (nucleus) filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                        # Remove tokens with cumulative probability above the threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold is kept too
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0

                        # Scatter the mask back to vocabulary order
                        indices_to_remove = sorted_indices_to_remove.scatter(
                            1, sorted_indices, sorted_indices_to_remove
                        )
                        next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))

                    # Sample next token
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                generated = torch.cat((generated, next_token), dim=1)

                # Stop at the end-of-text token when the tokenizer defines one
                if eot_token is not None and next_token.item() == eot_token:
                    break

            # Decode the generated tokens
            return self.tokenizer.decode(generated[0].tolist())

        except Exception as e:
            # Deliberate best effort: report the problem and hand the prompt back.
            print(f"Error during generation: {str(e)}")
            return prompt  # Return original prompt if generation fails

# Restore the trained weights from disk onto the active device
checkpoint = torch.load('checkpoint.pt', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

# Wrap the model and tokenizer for sampling
generator = TextGenerator(model, tokenizer)

# Example usage
prompt = "Hello"
generated_text = generator.generate(
    prompt=prompt,
    max_length=100,   # Maximum tokens to generate
    temperature=0.7,  # Higher = more random, lower = more focused
    top_k=50,         # Consider only top k tokens
    top_p=0.9,        # Nucleus sampling threshold
)

print("Generated text:", generated_text, sep="\n")

# Inspect Checkpoint

In [None]:
import torch
from typing import Any

class CheckpointManager:
    """Inspect and edit a training checkpoint held in memory.

    The checkpoint is loaded onto the CPU on construction; changes only reach
    disk when ``save`` is called.
    """

    def __init__(self, checkpoint_path: str):
        self.checkpoint_path = checkpoint_path
        # NOTE: weights_only=False unpickles arbitrary objects; only open
        # checkpoints from a trusted source.
        self.checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    def display_contents(self):
        """Display all contents of the checkpoint"""
        print("\n=== Checkpoint Contents ===")
        print(f"Training Mode: {self.checkpoint.get('original_mode', 'Not specified')}")
        print(f"Epoch: {self.checkpoint.get('epoch', 'Not specified')}")
        print(f"Batch Index: {self.checkpoint.get('batch_idx', 'Not specified')}")
        print(f"Loss: {self.checkpoint.get('loss', 'Not specified')}")

        print("\n=== Configuration ===")
        if 'config' in self.checkpoint:
            for key, value in self.checkpoint['config'].items():
                print(f"{key}: {value}")
        else:
            print("No configuration found in checkpoint")

        print("\n=== State Dictionaries Present ===")
        print("Model state dict:", 'model_state_dict' in self.checkpoint)
        print("Optimizer state dict:", 'optimizer_state_dict' in self.checkpoint)
        print("Scaler state dict:", 'scaler_state_dict' in self.checkpoint)

    @staticmethod
    def _grow_rows(weight: torch.Tensor, new_rows: int, init_method: str) -> torch.Tensor:
        """Return a copy of ``weight`` with ``new_rows`` rows; added rows follow ``init_method``."""
        old_rows = weight.size(0)
        # Keep dtype and device of the stored weight.
        grown = torch.zeros((new_rows, weight.size(1)), dtype=weight.dtype, device=weight.device)
        grown[:old_rows] = weight
        if new_rows > old_rows:
            if init_method == "mean":
                grown[old_rows:] = weight.mean(dim=0)
            else:  # "normal"
                grown[old_rows:] = torch.randn_like(grown[old_rows:]) * 0.02
        return grown

    def expand_embeddings(self, new_vocab_size: int, init_method: str = "mean"):
        """Grow the token embedding and output head to ``new_vocab_size`` rows.

        Existing rows are kept. New rows are set to the mean of the existing
        rows (``"mean"``) or drawn from N(0, 0.02) (``"normal"``). The stored
        config's ``vocab_size`` is updated when a config is present.

        Only ``model_state_dict`` is changed; any optimizer state stored in
        the checkpoint keeps its old shapes.

        Raises:
            ValueError: If ``init_method`` is unknown or ``new_vocab_size`` is
                smaller than the current size. Nothing is modified then.
        """
        if init_method not in ("mean", "normal"):
            raise ValueError(f"Unknown init_method '{init_method}'; expected 'mean' or 'normal'")

        state_dict = self.checkpoint['model_state_dict']

        # The two layers are sized independently, so read both.
        old_token_vocab_size = state_dict['token_embedding.weight'].size(0)
        old_head_vocab_size = state_dict['head.weight'].size(0)

        if new_vocab_size < max(old_token_vocab_size, old_head_vocab_size):
            raise ValueError(
                f"new_vocab_size {new_vocab_size} is smaller than the current size "
                f"(tokens {old_token_vocab_size}, head {old_head_vocab_size})"
            )

        state_dict['token_embedding.weight'] = self._grow_rows(
            state_dict['token_embedding.weight'], new_vocab_size, init_method
        )
        state_dict['head.weight'] = self._grow_rows(
            state_dict['head.weight'], new_vocab_size, init_method
        )

        # Update config
        if 'config' in self.checkpoint:
            self.checkpoint['config']['vocab_size'] = new_vocab_size
        print(f"Expanded: Tokens {old_token_vocab_size}→{new_vocab_size}, Head {old_head_vocab_size}→{new_vocab_size}")

    @staticmethod
    def _coerce_like(old_value: Any, new_value: Any) -> Any:
        """Convert ``new_value`` to the type of ``old_value`` for bool, int and float."""
        # bool is checked first because bool is a subclass of int.
        if isinstance(old_value, bool):
            if isinstance(new_value, str):
                # bool("false") is True, so strings are parsed explicitly.
                text = new_value.strip().lower()
                if text in ('true', '1', 'yes', 'on'):
                    return True
                if text in ('false', '0', 'no', 'off'):
                    return False
                raise ValueError(f"Cannot interpret '{new_value}' as a boolean")
            return bool(new_value)
        if isinstance(old_value, int):
            return int(new_value)
        if isinstance(old_value, float):
            return float(new_value)
        return new_value

    def modify_value(self, key_path: str, new_value: Any):
        """
        Modify a value in the checkpoint using a dot-notation path
        Example: 'config.learning_rate' or 'original_mode'

        The new value is converted to the type of the existing value when that
        is a bool, int or float. Problems are reported with a printed message
        and leave the checkpoint unchanged. Because the path is split on dots,
        keys that contain a dot themselves (such as 'token_embedding.weight')
        cannot be addressed.
        """
        *parent_keys, final_key = key_path.split('.')
        current = self.checkpoint

        # Navigate to the nested location
        for key in parent_keys:
            if not isinstance(current, dict) or key not in current:
                print(f"Error: Key '{key}' not found in checkpoint")
                return
            current = current[key]

        if not isinstance(current, dict) or final_key not in current:
            print(f"Error: Final key '{final_key}' not found")
            return

        # Convert value to the same type as the existing value
        old_value = current[final_key]
        try:
            new_value = self._coerce_like(old_value, new_value)
        except (TypeError, ValueError):
            print(f"Error: Could not convert new value to type {type(old_value)}")
            return

        # Update the value
        current[final_key] = new_value
        print(f"Updated {key_path} from {old_value} to {new_value}")

    def save(self, output_path: str = None):
        """Save the modified checkpoint to ``output_path``, or over the original file if omitted."""
        save_path = output_path or self.checkpoint_path
        torch.save(self.checkpoint, save_path)
        print(f"Saved checkpoint to {save_path}")

# Load checkpoint.pt onto the CPU for inspection
manager = CheckpointManager('checkpoint.pt')

# Print run metadata, the stored configuration and which state dicts are present
manager.display_contents()

In [None]:
from IPython.display import FileLink

# Show a link to checkpoint.pt as the cell output (presumably so the file can
# be downloaded from the notebook host - confirm).
FileLink(r'checkpoint.pt')