diff --git "a/CLEANED_modular_brain_agent.ipynb" "b/CLEANED_modular_brain_agent.ipynb" new file mode 100644--- /dev/null +++ "b/CLEANED_modular_brain_agent.ipynb" @@ -0,0 +1,1458 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b6055f45", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "from torchvision import datasets, transforms\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "import numpy as np\n", + "import random\n", + "from collections import deque\n", + "import os # For creating log directory\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8db6a5c2", + "metadata": {}, + "outputs": [], + "source": [ + "# --- 0. Global Setup and Vocabulary ---\n", + "\n", + "# Simulated language data\n", + "sample_sentences = [\n", + "\"I love AI\",\n", + "\"Deep learning is fun\",\n", + "\"Spiking neurons are cool\",\n", + "\"Brain inspired models rock\",\n", + "\"Replay buffer helps learning\"\n", + "]\n", + "\n", + "\n", + "def simple_tokenize(text):\n", + " \"\"\"Basic tokenizer for sample sentences.\"\"\"\n", + " return text.lower().split()[:10]\n", + "\n", + " # Build vocab manually\n", + " vocab = {\"\": 0}\n", + " word_counter = 1\n", + " for sentence in sample_sentences:\n", + " for word in simple_tokenize(sentence):\n", + " if word not in vocab:\n", + " vocab[word] = word_counter\n", + " word_counter += 1\n", + "\n", + " # Task IDs (using an Enum-like class for clarity)\n", + " class Task:\n", + " REGRESSION = 3\n", + " BINARY = 4\n", + " VISION = 5\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b55bcf2b", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 1. Base Classes and Spiking Neuron Implementations ---\n", + "\n", + " class Module(nn.Module):\n", + " \"\"\"Base class for brain-inspired modules.\"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " raise NotImplementedError\n", + "\n", + " class SurrogateLIF(torch.autograd.Function):\n", + " \"\"\"\n", + " Surrogate gradient function for LIF neurons.\n", + " Allows backpropagation through the non-differentiable spiking\n", + " function.\n", + " \"\"\"\n", + " @staticmethod\n", + " def forward(\n", + " ctx, input: torch.Tensor) -> torch.Tensor:\n", + " ctx.save_for_backward(input)\n", + " return (input > 0).float()\n", + "\n", + " @staticmethod\n", + " def backward(\n", + " ctx, grad_output: torch.Tensor) -> torch.Tensor:\n", + " input, = ctx.saved_tensors\n", + " grad_input = grad_output.clone()\n", + " # Approximate derivative: a\n", + " # constant value where\n", + " # input is\n", + " close to 0\n", + " grad_input[input.abs(\n", + " ) < 1] = 1.0\n", + " return grad_input\n", + "\n", + " class SpikingNeuron(\n", + " Module):\n", + " \"\"\"\n", + " Leaky Integrate-and-Fire (LIF) neuron model.\n", + " Resets membrane potential to zero after spiking.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, threshold: float = 1.0, decay: float = 0.95):\n", + " super().__init__()\n", + " self.threshold = threshold\n", + " self.decay = decay\n", + " # Register buffer\n", + " # for membrane\n", + " # potential,\n", + " # persists across\n", + " forward calls\n", + " # type:\n", + " # torch.Tensor\n", + " self.register_buffer(\n", + " 'mem', None)\n", + "\n", + " def forward(\n", + " self, x: torch.Tensor) -> torch.Tensor:\n", + " # Initialize or\n", + " # re-initialize\n", + " # membrane\n", + " # potential if\n", + " # shape\n", + " changes\n", + " if self.mem is None or self.mem.shape != x.shape:\n", + " self.mem = torch.zeros_like(\n", + " x)\n", + "\n", + " # Update\n", + " # membrane\n", + " # potential\n", + " self.mem = self.decay * self.mem + x\n", + " # Generate\n", + " # spike\n", + " spike = (\n", + " self.mem >= self.threshold).float()\n", + "\n", + " # Reset\n", + " # membrane\n", + " # potential\n", + " # for\n", + " # spiking\n", + " # neurons\n", + " self.mem = torch.where(spike.bool(),\n", + " torch.zeros_like(self.mem), self.mem)\n", + "\n", + " # Apply\n", + " # surrogate\n", + " # gradient\n", + " # for\n", + " # backpropagation\n", + " return SurrogateLIF.apply(\n", + " spike)\n", + "\n", + " class AdaptiveLIFNeuron(\n", + " SpikingNeuron):\n", + " \"\"\"\n", + " LIF neuron with an adaptive threshold based on recent firing rate.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, threshold: float = 1.0, decay: float = 0.95):\n", + " super().__init__(threshold=threshold, decay=decay)\n", + " # Register\n", + " # buffer\n", + " # for\n", + " # adaptive\n", + " # threshold\n", + " # state\n", + " self.register_buffer('threshold_state',\n", + " torch.tensor(threshold))\n", + "\n", + " def forward(\n", + " self, x: torch.Tensor) -> torch.Tensor:\n", + " if self.mem is None or self.mem.shape != x.shape:\n", + " self.mem = torch.zeros_like(\n", + " x)\n", + "\n", + " self.mem = self.decay * self.mem + x\n", + " spike = (\n", + " self.mem >= self.threshold_state).float()\n", + " self.mem = torch.where(spike.bool(),\n", + " torch.zeros_like(self.mem), self.mem)\n", + "\n", + " # Adapt\n", + " # threshold\n", + " # based\n", + " # on\n", + " # scalar\n", + " # mean\n", + " # spike\n", + " # rate\n", + " # (no_grad\n", + " # for\n", + " threshold update)\n", + " with torch.no_grad():\n", + " spike_rate = spike.mean().item() # Scalar rate\n", + " # Simple\n", + " # moving\n", + " # average\n", + " # for\n", + " # threshold\n", + " # adaptation\n", + " self.threshold_state = self.threshold_state * 0.99 +\n", + " spike_rate * 0.01\n", + "\n", + " return SurrogateLIF.apply(\n", + " spike)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "618a2828", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 2. Encoder Modules (Task-Specific Input Processing) ---\n", + "\n", + " class SharedEncoder(\n", + " nn.Module):\n", + " \"\"\"Generic encoder for numerical data.\"\"\"\n", + " def __init__(\n", + " self, input_size: int, hidden_size: int = 4):\n", + " super().__init__()\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(\n", + " input_size, hidden_size),\n", + " nn.LayerNorm(\n", + " hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(\n", + " 0.3)\n", + "\n", + " )\n", + "\n", + " def forward(\n", + " self, x: torch.Tensor) -> torch.Tensor:\n", + " assert x.dim() == 2, f\"Expected 2D input for SharedEncoder,\n", + " got {x.shape}\"\n", + " return self.encoder(x.float())\n", + "\n", + "\n", + " class CNNVision(nn.Module):\n", + " \"\"\"CNN encoder for image data (e.g., MNIST).\"\"\"\n", + " def __init__(self, output_features: int = 4):\n", + " super().__init__()\n", + " self.conv = nn.Sequential(\n", + " nn.Conv2d(1, 4, kernel_size=5),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(2),\n", + " nn.Conv2d(4, 8, kernel_size=5),\n", + " nn.ReLU(),\n", + " nn.AdaptiveAvgPool2d((1, 1)), # Outputs [batch, 8, 1, 1]\n", + " nn.Flatten(), # Outputs [batch, 8]\n", + " nn.Linear(8, output_features) # Maps to desired output\n", + " size\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " return self.conv(x)\n", + "\n", + "\n", + " class GRULanguage(nn.Module):\n", + " \"\"\"GRU encoder for sequential language data.\"\"\"\n", + " def __init__(self, vocab_size: int, embedding_dim: int = 4,\n", + " hidden_dim: int = 4):\n", + " super().__init__()\n", + " self.embed = nn.Embedding(vocab_size, embedding_dim)\n", + " self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " if x.dim() == 1:\n", + " x = x.unsqueeze(0) # Add batch dimension if missing\n", + " emb = self.embed(x.long()) # Ensure input is long type for\n", + " embedding\n", + " out, _ = self.gru(emb)\n", + " return out[:, -1, :] # Use last hidden state as sequence\n", + " representation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3333c1f", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 3. Brain-Inspired Modules with Specific Neuron Counts ---\n", + "\n", + " class SensoryProcessor(Module):\n", + "\n", + " \"\"\"Processes initial sensory input, often with low-level\n", + " features.\"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " self.norm = nn.LayerNorm(output_dim)\n", + " self.dropout = nn.Dropout(0.3)\n", + " self.neuron = SpikingNeuron()\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " z = self.dropout(self.norm(self.linear(x)))\n", + " return self.neuron(z)\n", + "\n", + "\n", + " class RelayLayer(Module):\n", + " \"\"\"\n", + " Routes information, potentially performing attention-like\n", + " operations.\n", + " Simulates a small sequence for MultiheadAttention.\n", + " \"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " self.norm = nn.LayerNorm(output_dim)\n", + " # MultiheadAttention requires embed_dim to be divisible by\n", + " num_heads\n", + " self.attn = nn.MultiheadAttention(embed_dim=output_dim,\n", + " num_heads=2, batch_first=True)\n", + " self.dropout = nn.Dropout(0.3)\n", + " self.neuron = SpikingNeuron()\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " z = self.dropout(self.norm(self.linear(x)))\n", + "\n", + " # Create fake time dimension by repeating the input for\n", + " attention\n", + " seq_len = 4\n", + " z_seq = z.unsqueeze(1).repeat(1, seq_len, 1)\n", + "\n", + " # Apply attention across the simulated time steps\n", + " z_out, _ = self.attn(z_seq, z_seq, z_seq)\n", + "\n", + " # Pool over the time dimension to get a single representation\n", + " routed = z_out.mean(dim=1)\n", + "\n", + " return self.neuron(routed)\n", + "\n", + "\n", + "\n", + " class InterneuronLogic(Module):\n", + " \"\"\"Core processing unit, potentially for decision making or\n", + " high-level abstraction.\"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " self.norm = nn.LayerNorm(output_dim)\n", + " self.dropout = nn.Dropout(0.3)\n", + " self.neuron = AdaptiveLIFNeuron() # Using adaptive neuron here\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " z = self.dropout(self.norm(self.linear(x)))\n", + " return self.neuron(z)\n", + "\n", + "\n", + " class NeuroendocrineModulator(Module):\n", + " \"\"\"\n", + " Modulates signals, potentially for gain control or emotional\n", + " states.\n", + " Applies a sigmoid-based gain to its input.\n", + " \"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " self.norm = nn.LayerNorm(output_dim)\n", + " self.dropout = nn.Dropout(0.3)\n", + " self.gain_control = nn.Linear(output_dim, output_dim) # Maps\n", + " output to gain factor\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " z = self.dropout(self.norm(self.linear(x)))\n", + " gain = torch.sigmoid(self.gain_control(z)) # Gain factor\n", + " between 0 and 1\n", + " return z * gain\n", + "\n", + "\n", + " class AutonomicProcessor(Module):\n", + " \"\"\"\n", + " Manages internal states, often involves recurrent processing.\n", + " Uses a GRU for sequential processing.\n", + " \"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " self.norm = nn.LayerNorm(output_dim)\n", + " self.recurrent = nn.GRU(output_dim, output_dim,\n", + " batch_first=True)\n", + " self.feedback_gain = 0.9\n", + "\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " z = self.norm(self.linear(x))\n", + " # GRU expects (batch, seq_len, features) -> seq_len=1 for\n", + " single step\n", + " h, _ = self.recurrent(z.unsqueeze(1))\n", + " # Squeeze the sequence dimension back out\n", + " return self.feedback_gain * h.squeeze(1)\n", + "\n", + "\n", + " class MirrorComparator(Module):\n", + " \"\"\"\n", + " Compares current state to a reference, potentially for self-other\n", + " distinction or goal comparison.\n", + " Outputs spikes and an optional similarity score.\n", + " \"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " self.norm = nn.LayerNorm(output_dim)\n", + " self.comparison_layer = nn.CosineSimilarity(dim=1) # Cosine\n", + " similarity for comparison\n", + " self.neuron = SpikingNeuron()\n", + "\n", + " def compare(self, a: torch.Tensor, b: torch.Tensor) ->\n", + " torch.Tensor:\n", + " \"\"\"Compare pre-spike representations for similarity.\"\"\"\n", + " a_proj = self.linear(a)\n", + " b_proj = self.linear(b)\n", + " return self.comparison_layer(a_proj, b_proj)\n", + "\n", + " def forward(self, x: torch.Tensor, reference: torch.Tensor = None)\n", + " -> (torch.Tensor, torch.Tensor):\n", + " z = self.norm(self.linear(x))\n", + " spike = self.neuron(z)\n", + " similarity = None\n", + " if reference is not None:\n", + " # Note: The 'reference' input would need to be processed\n", + " similarly before comparison\n", + " # For simplicity, assuming 'reference' is already in the\n", + " correct feature space for comparison_layer\n", + " similarity = self.compare(z, reference) # Pass processed\n", + " 'z' for comparison\n", + " return spike, similarity\n", + "\n", + "\n", + " class PlaceGridMemory(Module):\n", + " \"\"\"\n", + "\n", + " Spatial memory system using population codes and LSTM for\n", + " sequential memory.\n", + " \"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int):\n", + " super().__init__()\n", + " self.linear = nn.Linear(input_dim, output_dim)\n", + " self.norm = nn.LayerNorm(output_dim)\n", + " self.positional_encoder = self.population_code # Reference to\n", + " internal method\n", + " self.memory_cell = nn.LSTM(output_dim, output_dim,\n", + " batch_first=True)\n", + "\n", + " def population_code(self, x: torch.Tensor, pop_size: int) ->\n", + " torch.Tensor:\n", + " \"\"\"Expands a scalar value into a population-coded (one-hot\n", + " like) vector.\"\"\"\n", + " # Calculate a scalar representation (e.g., mean)\n", + " x_scalar = x.mean(dim=1)\n", + " x_normalized = torch.sigmoid(x_scalar) # Normalize to [0, 1]\n", + " # Map normalized value to an index within the population size\n", + " idx = (x_normalized * (pop_size - 1)).long().clamp(0, pop_size\n", + " - 1)\n", + "\n", + " encoded = torch.zeros(x.size(0), pop_size, device=x.device)\n", + " # Set the corresponding index to 1.0\n", + " encoded.scatter_(1, idx.unsqueeze(1), 1.0)\n", + " return encoded\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " encoded = self.norm(self.linear(x))\n", + " # Apply population coding based on configured output_dim\n", + " pos_encoded = self.positional_encoder(encoded,\n", + " pop_size=self.linear.out_features)\n", + " # LSTM expects (batch, seq_len, features) -> seq_len=1 for\n", + " single step\n", + " memory_out, _ = self.memory_cell(pos_encoded.unsqueeze(1))\n", + " # Squeeze the sequence dimension back out\n", + " return memory_out.squeeze(1)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7c74040", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 4. Task Heads (Task-Specific Output Layers) ---\n", + "\n", + " class TaskHead(nn.Module):\n", + " \"\"\"Generic task head for different output types.\"\"\"\n", + " def __init__(self, input_dim: int, output_dim: int, task_type:\n", + " str):\n", + " super().__init__()\n", + " self.task_type = task_type\n", + "\n", + " self.head = nn.Sequential(\n", + " nn.Linear(input_dim, output_dim),\n", + " nn.Dropout(0.3)\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " # Regression and Binary will be squeezed to [batch]\n", + " # Vision will remain [batch, num_classes]\n", + " return self.head(x).squeeze(-1) if self.task_type in\n", + " ['binary', 'regression'] else self.head(x)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65e6bc3e", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 5. Replay Buffer for Continual Learning ---\n", + "\n", + " class TaskReplayBuffer:\n", + " \"\"\"\n", + " A replay buffer that stores experiences tagged with their task ID.\n", + " Samples are drawn proportionally or randomly across tasks.\n", + " \"\"\"\n", + " def __init__(self, buffer_size: int = 1000, device: str = \"cpu\"):\n", + " self.buffer_size = buffer_size\n", + " self.task_buffers = {task_id: deque(maxlen=buffer_size) for\n", + " task_id in [Task.REGRESSION, Task.BINARY, Task.VISION]}\n", + " self.device = device\n", + "\n", + " def add(self, task_id: int, states: torch.Tensor, labels:\n", + " torch.Tensor):\n", + " \"\"\"Adds experiences (state-label pairs) to the buffer for a\n", + " specific task.\"\"\"\n", + " if task_id not in self.task_buffers:\n", + " # Should not happen if self.task_buffers is\n", + " pre-initialized with all Task IDs\n", + " self.task_buffers[task_id] =\n", + " deque(maxlen=self.buffer_size)\n", + "\n", + " for state, label in zip(states, labels):\n", + " # Detach from graph and move to CPU to prevent memory\n", + " leaks\n", + " self.task_buffers[task_id].append((state.detach().cpu(),\n", + " label.detach().cpu()))\n", + "\n", + " def sample(self, batch_size: int = 32) -> (torch.Tensor,\n", + " torch.Tensor, list):\n", + " \"\"\"\n", + " Samples a batch of experiences from the buffer, ensuring mixed\n", + " tasks.\n", + " Returns: (states, labels, list of corresponding task_ids)\n", + " \"\"\"\n", + "\n", + " active_tasks = [tid for tid, buffer in\n", + " self.task_buffers.items() if len(buffer) > 0]\n", + " if not active_tasks:\n", + " return None, None, None\n", + "\n", + " # Sample at least two distinct tasks if possible\n", + " num_tasks_to_sample = min(2, len(active_tasks))\n", + " sampled_tasks_ids = random.sample(active_tasks,\n", + " num_tasks_to_sample)\n", + "\n", + " batch_states, batch_labels, batch_task_ids = [], [], []\n", + "\n", + " # Distribute batch_size across sampled tasks\n", + " per_task_samples_base = batch_size // num_tasks_to_sample\n", + " remainder = batch_size % num_tasks_to_sample\n", + "\n", + " for i, task_id in enumerate(sampled_tasks_ids):\n", + " task_list = list(self.task_buffers[task_id]) # Convert\n", + " deque to list for random.sample/choices\n", + " if not task_list:\n", + " continue\n", + "\n", + " # Add remainder to first few tasks\n", + " k = per_task_samples_base + (1 if i < remainder else 0)\n", + " k = min(k, len(task_list)) # Ensure we don't sample more\n", + " than available\n", + "\n", + " if k == 0: continue\n", + "\n", + " try:\n", + " samples = random.sample(task_list, k)\n", + " except ValueError:\n", + " # If k is larger than task_list length (should be\n", + " caught by min(k, len(task_list))\n", + " # but good fallback for robustness)\n", + " samples = random.choices(task_list, k=k)\n", + "\n", + " for state, label in samples:\n", + " batch_states.append(state)\n", + " batch_labels.append(label)\n", + " batch_task_ids.append(task_id)\n", + "\n", + " if not batch_states:\n", + " return None, None, None\n", + "\n", + " # Stack tensors and return task IDs for individual processing\n", + " return (\n", + " torch.stack(batch_states),\n", + "\n", + " torch.stack(batch_labels),\n", + " batch_task_ids # Return as list to preserve individual\n", + " task identities\n", + " )\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf664c01", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 6. Full Agent Model ---\n", + "\n", + " class ModularBrainAgent(nn.Module):\n", + " \"\"\"\n", + " A comprehensive modular neural agent inspired by brain\n", + " architecture,\n", + " integrating spiking neurons, attention, and recurrent memory for\n", + " multi-task learning.\n", + " \"\"\"\n", + " def __init__(self, neuron_counts: dict = None):\n", + " super().__init__()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4afb88ab", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Parameterization of Neuron Counts ---\n", + " if neuron_counts is None:\n", + " # Default neuron counts for each brain region\n", + " neuron_counts = {\n", + " 'sensory': 4,\n", + " 'relay': 12,\n", + " 'interneurons': 2,\n", + " 'neuroendocrine': 8,\n", + " 'autonomic': 10,\n", + " 'mirror': 14,\n", + " 'place_grid': 16\n", + " }\n", + " self.neuron_counts = neuron_counts\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd012993", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Encoders (Input Modalities) ---\n", + " self.encoders = nn.ModuleDict({\n", + " 'regression': SharedEncoder(2,\n", + " self.neuron_counts['sensory']),\n", + " 'language': GRULanguage(len(vocab),\n", + " embedding_dim=self.neuron_counts['sensory'],\n", + " hidden_dim=self.neuron_counts['sensory']),\n", + " 'vision':\n", + " CNNVision(output_features=self.neuron_counts['sensory'])\n", + " })\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ab19dd3", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Brain Modules with Realistic Neuron Counts ---\n", + " self.sensory = SensoryProcessor(\n", + " input_dim=self.neuron_counts['sensory'],\n", + " output_dim=self.neuron_counts['sensory']\n", + " )\n", + "\n", + " self.relay = RelayLayer(\n", + " input_dim=self.neuron_counts['sensory'], # Input from\n", + " SensoryProcessor\n", + " output_dim=self.neuron_counts['relay']\n", + " )\n", + " self.interneurons = InterneuronLogic(\n", + " input_dim=self.neuron_counts['relay'], # Input from\n", + " RelayLayer\n", + " output_dim=self.neuron_counts['interneurons']\n", + " )\n", + " self.neuroendocrine = NeuroendocrineModulator(\n", + " input_dim=self.neuron_counts['interneurons'],\n", + " output_dim=self.neuron_counts['neuroendocrine']\n", + " )\n", + " self.autonomic = AutonomicProcessor(\n", + " input_dim=self.neuron_counts['neuroendocrine'],\n", + " output_dim=self.neuron_counts['autonomic']\n", + " )\n", + " self.mirror = MirrorComparator(\n", + " input_dim=self.neuron_counts['autonomic'],\n", + " output_dim=self.neuron_counts['mirror']\n", + " )\n", + " self.place_grid = PlaceGridMemory(\n", + " input_dim=self.neuron_counts['mirror'],\n", + " output_dim=self.neuron_counts['place_grid']\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "224830f4", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Interconnection Weights (between brain modules) ---\n", + " # Ensure input/output dimensions match the neuron_counts\n", + " self.connect_sensory_to_relay =\n", + " nn.Linear(self.neuron_counts['sensory'], self.neuron_counts['relay'])\n", + " self.connect_relay_to_inter =\n", + " nn.Linear(self.neuron_counts['relay'],\n", + " self.neuron_counts['interneurons'])\n", + " self.connect_inter_to_modulators =\n", + " nn.Linear(self.neuron_counts['interneurons'],\n", + " self.neuron_counts['neuroendocrine'])\n", + " self.connect_modulators_to_auto =\n", + " nn.Linear(self.neuron_counts['neuroendocrine'],\n", + " self.neuron_counts['autonomic'])\n", + " self.connect_auto_to_mirror =\n", + " nn.Linear(self.neuron_counts['autonomic'],\n", + " self.neuron_counts['mirror'])\n", + " self.connect_mirror_to_place =\n", + " nn.Linear(self.neuron_counts['mirror'],\n", + " self.neuron_counts['place_grid'])\n", + " # Recurrent loop from place_grid back to relay\n", + " self.connect_place_to_relay =\n", + "\n", + " nn.Linear(self.neuron_counts['place_grid'],\n", + " self.neuron_counts['relay'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc9114ee", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Optional Feedback Connections ---\n", + " self.feedback_relay_to_sensory =\n", + " nn.Linear(self.neuron_counts['relay'], self.neuron_counts['sensory'])\n", + " self.feedback_inter_to_relay =\n", + " nn.Linear(self.neuron_counts['interneurons'],\n", + " self.neuron_counts['relay'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ebf7548", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Task Heads (Outputs for specific tasks) ---\n", + " self.task_heads = nn.ModuleDict({\n", + " 'binary': TaskHead(self.neuron_counts['place_grid'], 1,\n", + " 'binary'),\n", + " 'vision': TaskHead(self.neuron_counts['place_grid'], 10,\n", + " 'vision'), # MNIST has 10 classes\n", + " 'regression': TaskHead(self.neuron_counts['place_grid'],\n", + " 1, 'regression')\n", + " })\n", + "\n", + " def route_modules(self, x: torch.Tensor) -> tuple[torch.Tensor,\n", + " ...]:\n", + " \"\"\"\n", + " Defines the forward pass through the interconnected brain\n", + " modules.\n", + " Returns intermediate activations for potential monitoring or\n", + " later use.\n", + " \"\"\"\n", + " # Forward pass through brain circuits\n", + " h_sensory = self.sensory(x)\n", + " h_relay = self.connect_sensory_to_relay(h_sensory)\n", + " h_relay = self.relay(h_relay)\n", + "\n", + " h_inter = self.connect_relay_to_inter(h_relay)\n", + " h_inter = self.interneurons(h_inter)\n", + "\n", + " h_modulate = self.connect_inter_to_modulators(h_inter)\n", + " h_modulate = self.neuroendocrine(h_modulate)\n", + "\n", + " h_auto = self.connect_modulators_to_auto(h_modulate)\n", + " h_auto = self.autonomic(h_auto)\n", + "\n", + " h_mirror_result, _ =\n", + " self.mirror(self.connect_auto_to_mirror(h_auto)) # Mirror returns\n", + " tuple\n", + " # if isinstance(mirror_result, tuple): # No longer needed as\n", + " forward always returns tuple\n", + " # h_mirror, _ = mirror_result\n", + "\n", + " # else:\n", + " # h_mirror = mirror_result\n", + " h_mirror = h_mirror_result # Renamed for clarity after tuple\n", + " unpack\n", + "\n", + " h_place = self.connect_mirror_to_place(h_mirror)\n", + " h_place = self.place_grid(h_place)\n", + "\n", + " # Recurrent loop from hippocampus (place_grid) back to\n", + " thalamus (relay)\n", + " # Note: Added ReLU here to prevent negative values from\n", + " feeding back into spiking neurons potentially\n", + " h_relay = h_relay +\n", + " torch.relu(self.connect_place_to_relay(h_place))\n", + "\n", + " # Optional feedback connections\n", + " h_sensory = h_sensory +\n", + " torch.relu(self.feedback_relay_to_sensory(h_relay))\n", + " h_relay = h_relay +\n", + " torch.relu(self.feedback_inter_to_relay(h_inter))\n", + "\n", + " return h_relay, h_place, h_mirror, h_auto, h_modulate\n", + "\n", + "\n", + " def encode(self, x: torch.Tensor, task_id: int) -> torch.Tensor:\n", + " \"\"\"Selects and applies the appropriate encoder based on task\n", + " ID and input shape.\"\"\"\n", + " # Check task ID and input shape for specific encoders\n", + " if task_id == Task.REGRESSION and x.dim() == 2 and x.size(1)\n", + " == 2:\n", + " return self.encoders['regression'](x.float())\n", + " elif task_id == Task.BINARY and x.dim() == 2 and x.size(1) ==\n", + " 10:\n", + " return self.encoders['language'](x.long())\n", + " elif task_id == Task.VISION:\n", + " if x.dim() == 4: # Already 4D (batch, 1, H, W)\n", + " return self.encoders['vision'](x.float())\n", + " elif x.dim() == 2 and x.size(1) == 784: # Flattened 2D\n", + " (batch, 784)\n", + " return self.encoders['vision'](x.view(x.size(0), 1,\n", + " 28, 28))\n", + " else:\n", + " raise ValueError(f\"Unexpected vision input shape:\n", + " {x.shape}\")\n", + " else:\n", + " # Fallback for unexpected task_id/shape combination, or if\n", + " x is already \"encoded\"\n", + " # In a real system, you might want a specific error or\n", + "\n", + " default encoder here.\n", + " # For now, assumes x is already in the 'sensory' input\n", + " dimension if it falls through.\n", + " if x.size(-1) != self.neuron_counts['sensory']:\n", + " raise ValueError(f\"Input {x.shape} does not match\n", + " sensory input dim {self.neuron_counts['sensory']} and no specific\n", + " encoder found for task {task_id}.\")\n", + " return x.float()\n", + "\n", + "\n", + " def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Main forward pass of the Modular Brain Agent.\n", + " Encodes input, routes through brain modules, and selects task\n", + " head.\n", + " \"\"\"\n", + " if x.dim() == 1:\n", + " x = x.unsqueeze(0) # Ensure batch dimension\n", + "\n", + " encoded = self.encode(x, task_id)\n", + " # Route through brain modules, focusing on the outputs used\n", + " for task heads\n", + " _, h_place, _, _, _ = self.route_modules(encoded) # Only\n", + " h_place is used for task heads\n", + "\n", + " # Map task IDs to the appropriate task head name\n", + " head_name = {\n", + " Task.REGRESSION: 'regression',\n", + " Task.BINARY: 'binary',\n", + " Task.VISION: 'vision'\n", + " }.get(task_id, 'regression') # Default to regression head if\n", + " task_id is unexpected\n", + "\n", + " out = self.task_heads[head_name](h_place) # All tasks use the\n", + " PlaceGridMemory output\n", + " return out\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28b1b910", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 7. Data Loaders (Synthetic and Real) ---\n", + "\n", + " def get_mnist_loader() -> (DataLoader, int):\n", + " \"\"\"Returns MNIST DataLoader and its Task ID.\"\"\"\n", + " transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,))\n", + " ])\n", + " dataset = datasets.MNIST(root='./data', train=True, download=True,\n", + " transform=transform)\n", + "\n", + " return DataLoader(dataset, batch_size=4, shuffle=True),\n", + " Task.VISION\n", + "\n", + "\n", + " def get_imdb_loader() -> (DataLoader, int):\n", + " \"\"\"Returns a synthetic IMDB-like DataLoader and its Task ID.\"\"\"\n", + " texts = []\n", + " labels = []\n", + "\n", + " for i in range(1000): # Create 1000 synthetic samples\n", + " sentence = random.choice(sample_sentences)\n", + " tokens = [vocab.get(word, vocab[\"\"]) for word in\n", + " simple_tokenize(sentence)]\n", + "\n", + " # Pad or truncate tokens to a fixed length\n", + " while len(tokens) < 10:\n", + " tokens.append(vocab[\"\"])\n", + " if len(tokens) > 10:\n", + " tokens = tokens[:10]\n", + "\n", + " texts.append(tokens)\n", + " labels.append(i % 2) # Binary sentiment: 0 or 1\n", + "\n", + " X = torch.tensor(texts).long()\n", + " Y = torch.tensor(labels).float()\n", + " dataset = TensorDataset(X, Y)\n", + " return DataLoader(dataset, batch_size=4, shuffle=True),\n", + " Task.BINARY\n", + "\n", + "\n", + " def get_regression_loader() -> (DataLoader, int):\n", + " \"\"\"Returns a synthetic regression DataLoader and its Task ID.\"\"\"\n", + " X = torch.randn(1000, 2) # 1000 samples, 2 features\n", + " Y = ((X[:, 0] + X[:, 1]) / 2).float() # Simple linear relationship\n", + " dataset = TensorDataset(X, Y)\n", + " return DataLoader(dataset, batch_size=4, shuffle=True),\n", + " Task.REGRESSION\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3dced65", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 8. Helper Function: Shape Inspector ---\n", + "\n", + " def inspect_shapes(agent: ModularBrainAgent):\n", + " \"\"\"\n", + " Prints the shapes of tensors at various points in the agent's\n", + " forward pass\n", + " to help verify architectural correctness.\n", + " \"\"\"\n", + " print(\"\\n=== SHAPE INSPECTOR ===\")\n", + "\n", + " # Use CPU for inspection to avoid CUDA memory issues during\n", + " debugging\n", + " agent.cpu()\n", + " agent.eval() # Set to eval mode for consistent behavior (e.g.,\n", + " dropout off)\n", + "\n", + " try:\n", + " # Test regression path\n", + " sample_input_reg = torch.randn(4, 2)\n", + " encoded_reg = agent.encode(sample_input_reg, Task.REGRESSION)\n", + " print(f\"Encoded (Regression): {encoded_reg.shape}\")\n", + " h_relay_reg, h_place_reg, h_mirror_reg, h_auto_reg,\n", + " h_modulate_reg = agent.route_modules(encoded_reg)\n", + " out_reg = agent.task_heads['regression'](h_place_reg)\n", + " print(f\"Regression Path:\n", + " Sensory({agent.sensory(encoded_reg).shape}) ->\n", + " Relay({h_relay_reg.shape}) -> Place({h_place_reg.shape}) ->\n", + " Output({out_reg.shape})\")\n", + "\n", + " # Test language path\n", + " # Using fixed input_dim for language to map to sensory input\n", + " dim\n", + " sample_input_lang = torch.randint(0, len(vocab), (4, 10))\n", + " encoded_lang = agent.encode(sample_input_lang, Task.BINARY) #\n", + " Use BINARY task for language\n", + " print(f\"Encoded (Language): {encoded_lang.shape}\")\n", + " h_relay_lang, h_place_lang, h_mirror_lang, h_auto_lang,\n", + " h_modulate_lang = agent.route_modules(encoded_lang)\n", + " out_lang = agent.task_heads['binary'](h_place_lang)\n", + " print(f\"Language Path:\n", + " Sensory({agent.sensory(encoded_lang).shape}) ->\n", + " Relay({h_relay_lang.shape}) -> Place({h_place_lang.shape}) ->\n", + " Output({out_lang.shape})\")\n", + "\n", + " # Test vision path\n", + " sample_input_vis = torch.randn(4, 1, 28, 28)\n", + " encoded_vis = agent.encode(sample_input_vis, Task.VISION)\n", + " print(f\"Encoded (Vision): {encoded_vis.shape}\")\n", + " h_relay_vis, h_place_vis, h_mirror_vis, h_auto_vis,\n", + " h_modulate_vis = agent.route_modules(encoded_vis)\n", + " out_vis = agent.task_heads['vision'](h_place_vis)\n", + " print(f\"Vision Path:\n", + " Sensory({agent.sensory(encoded_vis).shape}) ->\n", + " Relay({h_relay_vis.shape}) -> Place({h_place_vis.shape}) ->\n", + " Output({out_vis.shape})\")\n", + "\n", + " except Exception as e:\n", + " print(f\"Shape inspection failed: {e}\")\n", + "\n", + " # Re-raise to ensure the error is not ignored in main\n", + " execution\n", + " raise e\n", + " finally:\n", + " agent.train() # Set back to train mode\n", + " print(\"=========================\\n\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46a6f97b", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 9. Training Function ---\n", + "\n", + " def train(agent: ModularBrainAgent, episodes: int = 14400,\n", + " buffer_size: int = 1000, replay_freq: int = 5):\n", + " \"\"\"\n", + " Trains the ModularBrainAgent using a curriculum learning strategy\n", + " and experience replay.\n", + " \"\"\"\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else\n", + " \"cpu\")\n", + " agent.to(device)\n", + " print(f\"Using device: {device}\")\n", + "\n", + " print(\"🧠 Running shape inspector...\")\n", + " inspect_shapes(agent)\n", + "\n", + " replay_buffer = TaskReplayBuffer(buffer_size=buffer_size,\n", + " device=device)\n", + "\n", + " optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001,\n", + " weight_decay=1e-5)\n", + " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n", + " patience=100, factor=0.5, verbose=True)\n", + "\n", + " # Initialize weights\n", + " def init_weights(m):\n", + " if isinstance(m, nn.Linear):\n", + " torch.nn.init.xavier_uniform_(m.weight)\n", + " if m.bias is not None:\n", + " torch.nn.init.zeros_(m.bias)\n", + " elif isinstance(m, nn.Conv2d):\n", + " torch.nn.init.kaiming_uniform_(m.weight,\n", + " nonlinearity='relu')\n", + " if m.bias is not None:\n", + " torch.nn.init.zeros_(m.bias)\n", + " agent.apply(init_weights)\n", + "\n", + " # Setup TensorBoard writer\n", + " log_dir = \"runs/modular_brain_agent\"\n", + " os.makedirs(log_dir, exist_ok=True)\n", + "\n", + " writer = SummaryWriter(log_dir=log_dir)\n", + " print(f\"TensorBoard logs are being saved to: {log_dir}\")\n", + "\n", + "\n", + " best_loss = float('inf')\n", + " no_improvement = 0\n", + " save_path = \"best_model.pth\"\n", + " task_history = {i: [] for i in [Task.REGRESSION, Task.BINARY,\n", + " Task.VISION]}\n", + " curriculum_stages = {\n", + " 0: [Task.REGRESSION],\n", + " 300: [Task.REGRESSION, Task.BINARY], # Start binary after 300\n", + " episodes\n", + " 600: [Task.REGRESSION, Task.BINARY, Task.VISION] # Start\n", + " vision after 600 episodes\n", + " }\n", + "\n", + " raw_loaders = {\n", + " Task.REGRESSION: get_regression_loader(),\n", + " Task.BINARY: get_imdb_loader(),\n", + " Task.VISION: get_mnist_loader()\n", + " }\n", + "\n", + " loaders = {k: v[0] for k, v in raw_loaders.items()}\n", + " # task_ids are fixed as keys for loaders in this setup\n", + "\n", + " # Initialize persistent iterators for each task's DataLoader\n", + " loaders_iter = {\n", + " Task.REGRESSION: iter(loaders[Task.REGRESSION]),\n", + " Task.BINARY: iter(loaders[Task.BINARY]),\n", + " Task.VISION: iter(loaders[Task.VISION])\n", + " }\n", + "\n", + " # Helper for loss computation\n", + " def compute_loss(out: torch.Tensor, Y: torch.Tensor, task_id: int)\n", + " -> torch.Tensor:\n", + " \"\"\"Computes loss based on task type, handling shape\n", + " alignments.\"\"\"\n", + " if task_id == Task.VISION:\n", + " Y = Y.long() # Target labels for CrossEntropy must be long\n", + " if out.dim() != 2 or Y.dim() != 1:\n", + " raise ValueError(f\"Vision task expects out [batch,\n", + " num_classes], Y [batch]. Got {out.shape} vs {Y.shape}\")\n", + " return F.cross_entropy(out, Y)\n", + "\n", + " elif task_id == Task.BINARY:\n", + " # Ensure output and target have matching shapes for\n", + " BCEWithLogitsLoss\n", + "\n", + " if out.shape != Y.shape:\n", + " if out.numel() == Y.numel():\n", + " out = out.view_as(Y) # Reshape output to match\n", + " target if num elements are same\n", + " else:\n", + " raise ValueError(f\"Binary task shape mismatch: out\n", + " {out.shape} vs Y {Y.shape}\")\n", + " return F.binary_cross_entropy_with_logits(out, Y.float())\n", + "\n", + " else: # REGRESSION\n", + " if out.shape != Y.shape:\n", + " if out.numel() == Y.numel():\n", + " out = out.view_as(Y) # Reshape output to match\n", + " target if num elements are same\n", + " else:\n", + " raise ValueError(f\"Regression task shape mismatch:\n", + " out {out.shape} vs Y {Y.shape}\")\n", + " return F.smooth_l1_loss(out, Y.float())\n", + "\n", + " loss_weights = {Task.REGRESSION: 1.5, Task.BINARY: 1.2,\n", + " Task.VISION: 1.2}\n", + " global_step = 0 # For TensorBoard logging\n", + "\n", + " for ep in range(episodes):\n", + " agent.train() # Ensure model is in training mode\n", + "\n", + " # Determine current curriculum stage\n", + " current_stage_episodes = 0\n", + " for stage_start_ep, tasks in\n", + " sorted(curriculum_stages.items()):\n", + " if ep >= stage_start_ep:\n", + " current_stage_episodes = stage_start_ep\n", + " current_tasks = curriculum_stages[current_stage_episodes]\n", + "\n", + " # Randomly select a task from the active curriculum stage\n", + " task_id = np.random.choice(current_tasks)\n", + "\n", + " # Fetch data using persistent iterator\n", + " try:\n", + " X, Y = next(loaders_iter[task_id])\n", + " except StopIteration:\n", + " # Reset iterator when exhausted for that task\n", + " loaders_iter[task_id] = iter(loaders[task_id])\n", + " X, Y = next(loaders_iter[task_id])\n", + "\n", + " X, Y = X.to(device), Y.to(device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cbe3ad7", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Primary Training Step ---\n", + "\n", + " out = agent(X, task_id)\n", + " loss = compute_loss(out, Y, task_id) *\n", + " loss_weights.get(task_id, 1.0)\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(agent.parameters(),\n", + " max_norm=1.0) # Clip gradients\n", + " optimizer.step()\n", + "\n", + " # Store current experience in replay buffer\n", + " replay_buffer.add(task_id, X, Y)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "857737d5", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Accuracy Metric Calculation ---\n", + " acc = 0.0\n", + " with torch.no_grad(): # Don't track gradients for accuracy\n", + " calculation\n", + " if task_id == Task.VISION:\n", + " pred = out.argmax(dim=1)\n", + " acc = (pred == Y).float().mean().item()\n", + " elif task_id == Task.BINARY:\n", + " pred = (torch.sigmoid(out) > 0.5).float()\n", + " acc = (pred == Y).float().mean().item()\n", + " else: # REGRESSION\n", + " acc = ((out - Y).abs() < 0.2).float().mean().item() #\n", + " within 0.2 error margin\n", + "\n", + " task_history[task_id].append((loss.item(), acc))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3724390", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Replay Phase (Fixed Logic) ---\n", + " if ep % replay_freq == 0:\n", + " replay_X_mixed, replay_Y_mixed, replay_task_ids_list =\n", + " replay_buffer.sample(batch_size=4)\n", + "\n", + " if replay_X_mixed is not None:\n", + " replay_X_mixed = replay_X_mixed.to(device)\n", + " replay_Y_mixed = replay_Y_mixed.to(device)\n", + "\n", + " unique_replay_tasks = list(set(replay_task_ids_list))\n", + " total_replay_loss = 0.0\n", + " num_replay_samples = 0\n", + "\n", + " # Iterate through each unique task in the sampled\n", + " batch\n", + " for current_replay_task_id in unique_replay_tasks:\n", + " # Filter samples belonging to this specific task\n", + " indices_for_this_task = [i for i, tid in\n", + " enumerate(replay_task_ids_list) if tid == current_replay_task_id]\n", + "\n", + "\n", + " if not indices_for_this_task:\n", + " continue\n", + "\n", + " task_replay_X =\n", + " replay_X_mixed[indices_for_this_task]\n", + " task_replay_Y =\n", + " replay_Y_mixed[indices_for_this_task]\n", + "\n", + " # Process these samples using the correct encoder\n", + " and task head\n", + " replay_out_for_task = agent(task_replay_X,\n", + " current_replay_task_id)\n", + "\n", + " # Compute loss for this task's samples\n", + " replay_loss_for_task =\n", + " compute_loss(replay_out_for_task, task_replay_Y,\n", + " current_replay_task_id)\n", + " total_replay_loss += replay_loss_for_task *\n", + " len(indices_for_this_task) # Weighted sum\n", + " num_replay_samples += len(indices_for_this_task)\n", + "\n", + " if num_replay_samples > 0:\n", + " total_replay_loss /= num_replay_samples # Average\n", + " over all replayed samples\n", + " optimizer.zero_grad()\n", + " total_replay_loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(agent.parameters(),\n", + " max_norm=1.0)\n", + " optimizer.step()\n", + " # Log replay loss\n", + " writer.add_scalar('Loss/Replay_Loss',\n", + " total_replay_loss.item(), global_step)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b3999e3", + "metadata": {}, + "outputs": [], + "source": [ + " # --- Logging and Early Stopping ---\n", + " scheduler.step(loss.item()) # Step scheduler based on current\n", + " task loss\n", + "\n", + " # TensorBoard Logging\n", + " writer.add_scalar(f'Loss/Current_Task_{task_id}', loss.item(),\n", + " global_step)\n", + " writer.add_scalar(f'Accuracy/Current_Task_{task_id}', acc,\n", + " global_step)\n", + " writer.add_scalar('LearningRate',\n", + " optimizer.param_groups[0]['lr'], global_step)\n", + "\n", + " global_step += 1\n", + "\n", + "\n", + " # Check for best model and early stopping\n", + " if loss.item() < best_loss:\n", + " best_loss = loss.item()\n", + " no_improvement = 0\n", + " torch.save(agent.state_dict(), save_path)\n", + " # print(f\"New best model saved at episode {ep} with loss:\n", + " {best_loss:.4f}\")\n", + " else:\n", + " no_improvement += 1\n", + "\n", + " # Print summary periodically\n", + " if ep % 200 == 0:\n", + " print(f\"\\n--- Episode {ep} (Curriculum Stage:\n", + " {current_stage_episodes}) ---\")\n", + " for t_id in sorted(task_history.keys()):\n", + " if task_history[t_id]:\n", + " # Get recent history, default to empty if not\n", + " enough entries\n", + " recent_losses = [x[0] for x in\n", + " task_history[t_id][-200:]]\n", + " recent_accs = [x[1] for x in\n", + " task_history[t_id][-200:]]\n", + "\n", + " avg_loss = np.mean(recent_losses) if recent_losses\n", + " else 0.0\n", + " avg_acc = np.mean(recent_accs) if recent_accs else\n", + " 0.0\n", + "\n", + " # Map task ID to a readable name\n", + " task_name = {Task.REGRESSION: \"Regression\",\n", + " Task.BINARY: \"Binary\", Task.VISION: \"Vision\"}.get(t_id,\n", + " f\"Unknown_{t_id}\")\n", + " print(f\"Task {task_name} | Avg Loss:\n", + " {avg_loss:.3f} | Avg Acc: {avg_acc:.2f} ({len(recent_losses)}\n", + " samples)\")\n", + " print(f\"Overall Current Loss: {loss.item():.4f} | Best\n", + " Loss: {best_loss:.4f} | No Improvement: {no_improvement}\")\n", + " print(\"--------------------\\n\")\n", + "\n", + " if no_improvement >= 1000: # Early stopping threshold\n", + " print(f\"Early stopping at episode {ep} due to no\n", + " improvement for {no_improvement} steps.\")\n", + " break\n", + "\n", + " writer.close()\n", + " print(\"āœ… Training finished.\")\n", + " return best_loss\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2300bd2d", + "metadata": {}, + "outputs": [], + "source": [ + " # --- 10. Main Execution Block ---\n", + "\n", + " if __name__ == \"__main__\":\n", + " print(\"šŸš€ Initializing Modular Brain Agent...\")\n", + "\n", + " # Optional: Define custom neuron counts here\n", + " # custom_neuron_config = {\n", + " # 'sensory': 8,\n", + " # 'relay': 24,\n", + " # 'interneurons': 4,\n", + " # 'neuroendocrine': 16,\n", + " # 'autonomic': 20,\n", + " # 'mirror': 28,\n", + " # 'place_grid': 32\n", + " # }\n", + " # agent = ModularBrainAgent(neuron_counts=custom_neuron_config)\n", + "\n", + " agent = ModularBrainAgent() # Using default neuron counts for now\n", + "\n", + " # Print model summary\n", + " total_params = sum(p.numel() for p in agent.parameters())\n", + " trainable_params = sum(p.numel() for p in agent.parameters() if\n", + " p.requires_grad)\n", + "\n", + " print(f\"šŸ“Š Model Summary:\")\n", + " print(f\" Total parameters: {total_params:,}\")\n", + " print(f\" Trainable parameters: {trainable_params:,}\")\n", + " print(f\" Model size: ~{total_params * 4 / (1024**2):.2f} MB\n", + " (approx. float32)\")\n", + "\n", + " # Test forward pass on each task to verify shapes\n", + " print(\"\\nšŸ” Testing forward passes...\")\n", + " try:\n", + " # Test regression task\n", + " test_reg = torch.randn(2, 2)\n", + " out_reg = agent(test_reg, Task.REGRESSION)\n", + " print(f\"āœ… Regression test passed: {test_reg.shape} ->\n", + " {out_reg.shape}\")\n", + "\n", + " # Test binary classification task (language task)\n", + " test_bin = torch.randint(0, len(vocab), (2, 10))\n", + " out_bin = agent(test_bin, Task.BINARY)\n", + " print(f\"āœ… Binary classification test passed: {test_bin.shape}\n", + " -> {out_bin.shape}\")\n", + "\n", + " # Test vision task\n", + "\n", + " test_vis = torch.randn(2, 1, 28, 28)\n", + " out_vis = agent(test_vis, Task.VISION)\n", + " print(f\"āœ… Vision test passed: {test_vis.shape} ->\n", + " {out_vis.shape}\")\n", + "\n", + " except Exception as e:\n", + " print(f\"āŒ Forward pass test failed: {e}\")\n", + " # If forward pass fails, there's a fundamental issue, so exit\n", + " import sys\n", + " sys.exit(1)\n", + "\n", + " print(\"\\nšŸŽÆ Starting training...\")\n", + "\n", + " # Train the agent\n", + " try:\n", + " final_best_loss = train(\n", + " agent=agent,\n", + " episodes=14400, # Total training episodes\n", + " buffer_size=1000, # Replay buffer size\n", + " replay_freq=5 # Replay every N episodes\n", + " )\n", + "\n", + " print(f\"\\nšŸŽ‰ Training completed successfully!\")\n", + " print(f\"šŸ“ˆ Best loss achieved during training:\n", + " {final_best_loss:.4f}\")\n", + " print(f\"šŸ’¾ Best model saved to: best_model.pth\")\n", + "\n", + " except KeyboardInterrupt:\n", + " print(\"\\nā¹ Training interrupted by user.\")\n", + " except Exception as e:\n", + " print(f\"\\nāŒ Training failed with error: {e}\")\n", + " raise e\n", + "\n", + " print(\"\\n🧠 Modular Brain Agent execution completed!\")" + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +}