{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "b6055f45", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from torchvision import datasets, transforms\n", "from torch.utils.tensorboard import SummaryWriter\n", "import numpy as np\n", "import random\n", "from collections import deque\n", "import os # For creating log directory\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8db6a5c2", "metadata": {}, "outputs": [], "source": [ "# --- 0. Global Setup and Vocabulary ---\n", "\n", "# Simulated language data\n", "sample_sentences = [\n", "\"I love AI\",\n", "\"Deep learning is fun\",\n", "\"Spiking neurons are cool\",\n", "\"Brain inspired models rock\",\n", "\"Replay buffer helps learning\"\n", "]\n", "\n", "\n", "def simple_tokenize(text):\n", " \"\"\"Basic tokenizer for sample sentences.\"\"\"\n", " return text.lower().split()[:10]\n", "\n", " # Build vocab manually\n", " vocab = {\"\": 0}\n", " word_counter = 1\n", " for sentence in sample_sentences:\n", " for word in simple_tokenize(sentence):\n", " if word not in vocab:\n", " vocab[word] = word_counter\n", " word_counter += 1\n", "\n", " # Task IDs (using an Enum-like class for clarity)\n", " class Task:\n", " REGRESSION = 3\n", " BINARY = 4\n", " VISION = 5\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b55bcf2b", "metadata": {}, "outputs": [], "source": [ " # --- 1. Base Classes and Spiking Neuron Implementations ---\n", "\n", " class Module(nn.Module):\n", " \"\"\"Base class for brain-inspired modules.\"\"\"\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " raise NotImplementedError\n", "\n", " class SurrogateLIF(torch.autograd.Function):\n", " \"\"\"\n", " Surrogate gradient function for LIF neurons.\n", " Allows backpropagation through the non-differentiable spiking\n", " function.\n", " \"\"\"\n", " @staticmethod\n", " def forward(\n", " ctx, input: torch.Tensor) -> torch.Tensor:\n", " ctx.save_for_backward(input)\n", " return (input > 0).float()\n", "\n", " @staticmethod\n", " def backward(\n", " ctx, grad_output: torch.Tensor) -> torch.Tensor:\n", " input, = ctx.saved_tensors\n", " grad_input = grad_output.clone()\n", " # Approximate derivative: a\n", " # constant value where\n", " # input is\n", " close to 0\n", " grad_input[input.abs(\n", " ) < 1] = 1.0\n", " return grad_input\n", "\n", " class SpikingNeuron(\n", " Module):\n", " \"\"\"\n", " Leaky Integrate-and-Fire (LIF) neuron model.\n", " Resets membrane potential to zero after spiking.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, threshold: float = 1.0, decay: float = 0.95):\n", " super().__init__()\n", " self.threshold = threshold\n", " self.decay = decay\n", " # Register buffer\n", " # for membrane\n", " # potential,\n", " # persists across\n", " forward calls\n", " # type:\n", " # torch.Tensor\n", " self.register_buffer(\n", " 'mem', None)\n", "\n", " def forward(\n", " self, x: torch.Tensor) -> torch.Tensor:\n", " # Initialize or\n", " # re-initialize\n", " # membrane\n", " # potential if\n", " # shape\n", " changes\n", " if self.mem is None or self.mem.shape != x.shape:\n", " self.mem = torch.zeros_like(\n", " x)\n", "\n", " # Update\n", " # membrane\n", " # potential\n", " self.mem = self.decay * self.mem + x\n", " # Generate\n", " # spike\n", " spike = (\n", " self.mem >= self.threshold).float()\n", 
"\n", " # Reset\n", " # membrane\n", " # potential\n", " # for\n", " # spiking\n", " # neurons\n", " self.mem = torch.where(spike.bool(),\n", " torch.zeros_like(self.mem), self.mem)\n", "\n", " # Apply\n", " # surrogate\n", " # gradient\n", " # for\n", " # backpropagation\n", " return SurrogateLIF.apply(\n", " spike)\n", "\n", " class AdaptiveLIFNeuron(\n", " SpikingNeuron):\n", " \"\"\"\n", " LIF neuron with an adaptive threshold based on recent firing rate.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, threshold: float = 1.0, decay: float = 0.95):\n", " super().__init__(threshold=threshold, decay=decay)\n", " # Register\n", " # buffer\n", " # for\n", " # adaptive\n", " # threshold\n", " # state\n", " self.register_buffer('threshold_state',\n", " torch.tensor(threshold))\n", "\n", " def forward(\n", " self, x: torch.Tensor) -> torch.Tensor:\n", " if self.mem is None or self.mem.shape != x.shape:\n", " self.mem = torch.zeros_like(\n", " x)\n", "\n", " self.mem = self.decay * self.mem + x\n", " spike = (\n", " self.mem >= self.threshold_state).float()\n", " self.mem = torch.where(spike.bool(),\n", " torch.zeros_like(self.mem), self.mem)\n", "\n", " # Adapt\n", " # threshold\n", " # based\n", " # on\n", " # scalar\n", " # mean\n", " # spike\n", " # rate\n", " # (no_grad\n", " # for\n", " threshold update)\n", " with torch.no_grad():\n", " spike_rate = spike.mean().item() # Scalar rate\n", " # Simple\n", " # moving\n", " # average\n", " # for\n", " # threshold\n", " # adaptation\n", " self.threshold_state = self.threshold_state * 0.99 +\n", " spike_rate * 0.01\n", "\n", " return SurrogateLIF.apply(\n", " spike)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "618a2828", "metadata": {}, "outputs": [], "source": [ " # --- 2. Encoder Modules (Task-Specific Input Processing) ---\n", "\n", " class SharedEncoder(\n", " nn.Module):\n", " \"\"\"Generic encoder for numerical data.\"\"\"\n", " def __init__(\n", " self, input_size: int, hidden_size: int = 4):\n", " super().__init__()\n", " self.encoder = nn.Sequential(\n", " nn.Linear(\n", " input_size, hidden_size),\n", " nn.LayerNorm(\n", " hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(\n", " 0.3)\n", "\n", " )\n", "\n", " def forward(\n", " self, x: torch.Tensor) -> torch.Tensor:\n", " assert x.dim() == 2, f\"Expected 2D input for SharedEncoder,\n", " got {x.shape}\"\n", " return self.encoder(x.float())\n", "\n", "\n", " class CNNVision(nn.Module):\n", " \"\"\"CNN encoder for image data (e.g., MNIST).\"\"\"\n", " def __init__(self, output_features: int = 4):\n", " super().__init__()\n", " self.conv = nn.Sequential(\n", " nn.Conv2d(1, 4, kernel_size=5),\n", " nn.ReLU(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(4, 8, kernel_size=5),\n", " nn.ReLU(),\n", " nn.AdaptiveAvgPool2d((1, 1)), # Outputs [batch, 8, 1, 1]\n", " nn.Flatten(), # Outputs [batch, 8]\n", " nn.Linear(8, output_features) # Maps to desired output\n", " size\n", " )\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " return self.conv(x)\n", "\n", "\n", " class GRULanguage(nn.Module):\n", " \"\"\"GRU encoder for sequential language data.\"\"\"\n", " def __init__(self, vocab_size: int, embedding_dim: int = 4,\n", " hidden_dim: int = 4):\n", " super().__init__()\n", " self.embed = nn.Embedding(vocab_size, embedding_dim)\n", " self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " if x.dim() == 1:\n", " x = x.unsqueeze(0) # Add batch dimension if missing\n", " emb = 
{ "cell_type": "code", "execution_count": null, "id": "618a2828", "metadata": {}, "outputs": [], "source": [
"# --- 2. Encoder Modules (Task-Specific Input Processing) ---\n",
"\n",
"class SharedEncoder(nn.Module):\n",
"    \"\"\"Generic encoder for numerical data.\"\"\"\n",
"    def __init__(self, input_size: int, hidden_size: int = 4):\n",
"        super().__init__()\n",
"        self.encoder = nn.Sequential(\n",
"            nn.Linear(input_size, hidden_size),\n",
"            nn.LayerNorm(hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.3)\n",
"        )\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        assert x.dim() == 2, f\"Expected 2D input for SharedEncoder, got {x.shape}\"\n",
"        return self.encoder(x.float())\n",
"\n",
"\n",
"class CNNVision(nn.Module):\n",
"    \"\"\"CNN encoder for image data (e.g., MNIST).\"\"\"\n",
"    def __init__(self, output_features: int = 4):\n",
"        super().__init__()\n",
"        self.conv = nn.Sequential(\n",
"            nn.Conv2d(1, 4, kernel_size=5),\n",
"            nn.ReLU(),\n",
"            nn.MaxPool2d(2),\n",
"            nn.Conv2d(4, 8, kernel_size=5),\n",
"            nn.ReLU(),\n",
"            nn.AdaptiveAvgPool2d((1, 1)),  # Outputs [batch, 8, 1, 1]\n",
"            nn.Flatten(),                  # Outputs [batch, 8]\n",
"            nn.Linear(8, output_features)  # Maps to the desired output size\n",
"        )\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        return self.conv(x)\n",
"\n",
"\n",
"class GRULanguage(nn.Module):\n",
"    \"\"\"GRU encoder for sequential language data.\"\"\"\n",
"    def __init__(self, vocab_size: int, embedding_dim: int = 4, hidden_dim: int = 4):\n",
"        super().__init__()\n",
"        self.embed = nn.Embedding(vocab_size, embedding_dim)\n",
"        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        if x.dim() == 1:\n",
"            x = x.unsqueeze(0)  # Add batch dimension if missing\n",
"        emb = self.embed(x.long())  # Ensure input is long type for the embedding\n",
"        out, _ = self.gru(emb)\n",
"        return out[:, -1, :]  # Use the last hidden state as the sequence representation\n"
] },
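{ "cell_type": "code", "execution_count": null, "id": "f7b3a9d0", "metadata": {}, "outputs": [], "source": [
"# Added shape check (a sketch, not part of the original pipeline): all three\n",
"# encoders are expected to map their modality into the same 4-dimensional\n",
"# 'sensory' space so they can share the downstream brain modules.\n",
"print(\"SharedEncoder :\", SharedEncoder(2)(torch.randn(2, 2)).shape)\n",
"print(\"CNNVision     :\", CNNVision()(torch.randn(2, 1, 28, 28)).shape)\n",
"print(\"GRULanguage   :\", GRULanguage(len(vocab))(torch.randint(0, len(vocab), (2, 10))).shape)\n"
] },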
{ "cell_type": "code", "execution_count": null, "id": "e3333c1f", "metadata": {}, "outputs": [], "source": [
"# --- 3. Brain-Inspired Modules with Specific Neuron Counts ---\n",
"\n",
"class SensoryProcessor(Module):\n",
"    \"\"\"Processes initial sensory input, often with low-level features.\"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int):\n",
"        super().__init__()\n",
"        self.linear = nn.Linear(input_dim, output_dim)\n",
"        self.norm = nn.LayerNorm(output_dim)\n",
"        self.dropout = nn.Dropout(0.3)\n",
"        self.neuron = SpikingNeuron()\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        z = self.dropout(self.norm(self.linear(x)))\n",
"        return self.neuron(z)\n",
"\n",
"\n",
"class RelayLayer(Module):\n",
"    \"\"\"\n",
"    Routes information, potentially performing attention-like operations.\n",
"    Simulates a small sequence for MultiheadAttention.\n",
"    \"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int):\n",
"        super().__init__()\n",
"        self.linear = nn.Linear(input_dim, output_dim)\n",
"        self.norm = nn.LayerNorm(output_dim)\n",
"        # MultiheadAttention requires embed_dim to be divisible by num_heads\n",
"        self.attn = nn.MultiheadAttention(embed_dim=output_dim, num_heads=2, batch_first=True)\n",
"        self.dropout = nn.Dropout(0.3)\n",
"        self.neuron = SpikingNeuron()\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        z = self.dropout(self.norm(self.linear(x)))\n",
"\n",
"        # Create a fake time dimension by repeating the input for attention\n",
"        seq_len = 4\n",
"        z_seq = z.unsqueeze(1).repeat(1, seq_len, 1)\n",
"\n",
"        # Apply attention across the simulated time steps\n",
"        z_out, _ = self.attn(z_seq, z_seq, z_seq)\n",
"\n",
"        # Pool over the time dimension to get a single representation\n",
"        routed = z_out.mean(dim=1)\n",
"\n",
"        return self.neuron(routed)\n",
"\n",
"\n",
"class InterneuronLogic(Module):\n",
"    \"\"\"Core processing unit, potentially for decision making or high-level abstraction.\"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int):\n",
"        super().__init__()\n",
"        self.linear = nn.Linear(input_dim, output_dim)\n",
"        self.norm = nn.LayerNorm(output_dim)\n",
"        self.dropout = nn.Dropout(0.3)\n",
"        self.neuron = AdaptiveLIFNeuron()  # Using the adaptive neuron here\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        z = self.dropout(self.norm(self.linear(x)))\n",
"        return self.neuron(z)\n",
"\n",
"\n",
"class NeuroendocrineModulator(Module):\n",
"    \"\"\"\n",
"    Modulates signals, potentially for gain control or emotional states.\n",
"    Applies a sigmoid-based gain to its input.\n",
"    \"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int):\n",
"        super().__init__()\n",
"        self.linear = nn.Linear(input_dim, output_dim)\n",
"        self.norm = nn.LayerNorm(output_dim)\n",
"        self.dropout = nn.Dropout(0.3)\n",
"        self.gain_control = nn.Linear(output_dim, output_dim)  # Maps output to a gain factor\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        z = self.dropout(self.norm(self.linear(x)))\n",
"        gain = torch.sigmoid(self.gain_control(z))  # Gain factor between 0 and 1\n",
"        return z * gain\n",
"\n",
"\n",
"class AutonomicProcessor(Module):\n",
"    \"\"\"\n",
"    Manages internal states, often involves recurrent processing.\n",
"    Uses a GRU for sequential processing.\n",
"    \"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int):\n",
"        super().__init__()\n",
"        self.linear = nn.Linear(input_dim, output_dim)\n",
"        self.norm = nn.LayerNorm(output_dim)\n",
"        self.recurrent = nn.GRU(output_dim, output_dim, batch_first=True)\n",
"        self.feedback_gain = 0.9\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        z = self.norm(self.linear(x))\n",
"        # GRU expects (batch, seq_len, features) -> seq_len=1 for a single step\n",
"        h, _ = self.recurrent(z.unsqueeze(1))\n",
"        # Squeeze the sequence dimension back out\n",
"        return self.feedback_gain * h.squeeze(1)\n",
"\n",
"\n",
"class MirrorComparator(Module):\n",
"    \"\"\"\n",
"    Compares the current state to a reference, potentially for self-other distinction or goal comparison.\n",
"    Outputs spikes and an optional similarity score.\n",
"    \"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int):\n",
"        super().__init__()\n",
"        self.linear = nn.Linear(input_dim, output_dim)\n",
"        self.norm = nn.LayerNorm(output_dim)\n",
"        self.comparison_layer = nn.CosineSimilarity(dim=1)  # Cosine similarity for comparison\n",
"        self.neuron = SpikingNeuron()\n",
"\n",
"    def compare(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Compare pre-spike representations for similarity.\"\"\"\n",
"        a_proj = self.linear(a)\n",
"        b_proj = self.linear(b)\n",
"        return self.comparison_layer(a_proj, b_proj)\n",
"\n",
"    def forward(self, x: torch.Tensor, reference: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):\n",
"        z = self.norm(self.linear(x))\n",
"        spike = self.neuron(z)\n",
"        similarity = None\n",
"        if reference is not None:\n",
"            # Note: the 'reference' input would need to be processed similarly before comparison.\n",
"            # For simplicity, assume 'reference' is already in the correct feature space for comparison_layer.\n",
"            similarity = self.compare(z, reference)  # Pass the processed 'z' for comparison\n",
"        return spike, similarity\n",
"\n",
"\n",
"class PlaceGridMemory(Module):\n",
"    \"\"\"\n",
"    Spatial memory system using population codes and an LSTM for sequential memory.\n",
"    \"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int):\n",
"        super().__init__()\n",
"        self.linear = nn.Linear(input_dim, output_dim)\n",
"        self.norm = nn.LayerNorm(output_dim)\n",
"        self.positional_encoder = self.population_code  # Reference to the internal method\n",
"        self.memory_cell = nn.LSTM(output_dim, output_dim, batch_first=True)\n",
"\n",
"    def population_code(self, x: torch.Tensor, pop_size: int) -> torch.Tensor:\n",
"        \"\"\"Expands a scalar value into a population-coded (one-hot like) vector.\"\"\"\n",
"        # Calculate a scalar representation (e.g., mean)\n",
"        x_scalar = x.mean(dim=1)\n",
"        x_normalized = torch.sigmoid(x_scalar)  # Normalize to [0, 1]\n",
"        # Map the normalized value to an index within the population size\n",
"        idx = (x_normalized * (pop_size - 1)).long().clamp(0, pop_size - 1)\n",
"\n",
"        encoded = torch.zeros(x.size(0), pop_size, device=x.device)\n",
"        # Set the corresponding index to 1.0\n",
"        encoded.scatter_(1, idx.unsqueeze(1), 1.0)\n",
"        return encoded\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        encoded = self.norm(self.linear(x))\n",
"        # Apply population coding based on the configured output_dim\n",
"        pos_encoded = self.positional_encoder(encoded, pop_size=self.linear.out_features)\n",
"        # LSTM expects (batch, seq_len, features) -> seq_len=1 for a single step\n",
"        memory_out, _ = self.memory_cell(pos_encoded.unsqueeze(1))\n",
"        # Squeeze the sequence dimension back out\n",
"        return memory_out.squeeze(1)\n"
] },
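{ "cell_type": "code", "execution_count": null, "id": "9a4c6e1b", "metadata": {}, "outputs": [], "source": [
"# Added illustration (not part of the original pipeline): population_code\n",
"# collapses each input row to a scalar, squashes it into [0, 1], and lights up\n",
"# a single unit in a pop_size-wide one-hot vector. Dimensions here are arbitrary.\n",
"demo_memory = PlaceGridMemory(input_dim=4, output_dim=8)\n",
"demo_code = demo_memory.population_code(torch.randn(3, 4), pop_size=8)\n",
"print(demo_code)                              # exactly one 1.0 per row\n",
"print(demo_memory(torch.randn(3, 4)).shape)   # full module output: [3, 8]\n"
] },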
{ "cell_type": "code", "execution_count": null, "id": "c7c74040", "metadata": {}, "outputs": [], "source": [
"# --- 4. Task Heads (Task-Specific Output Layers) ---\n",
"\n",
"class TaskHead(nn.Module):\n",
"    \"\"\"Generic task head for different output types.\"\"\"\n",
"    def __init__(self, input_dim: int, output_dim: int, task_type: str):\n",
"        super().__init__()\n",
"        self.task_type = task_type\n",
"\n",
"        self.head = nn.Sequential(\n",
"            nn.Linear(input_dim, output_dim),\n",
"            nn.Dropout(0.3)\n",
"        )\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        # Regression and binary outputs are squeezed to [batch];\n",
"        # vision outputs stay [batch, num_classes]\n",
"        return self.head(x).squeeze(-1) if self.task_type in ['binary', 'regression'] else self.head(x)\n"
] },
{ "cell_type": "code", "execution_count": null, "id": "65e6bc3e", "metadata": {}, "outputs": [], "source": [
"# --- 5. Replay Buffer for Continual Learning ---\n",
"\n",
"class TaskReplayBuffer:\n",
"    \"\"\"\n",
"    A replay buffer that stores experiences tagged with their task ID.\n",
"    Samples are drawn across tasks so that replay batches stay mixed.\n",
"    \"\"\"\n",
"    def __init__(self, buffer_size: int = 1000, device: str = \"cpu\"):\n",
"        self.buffer_size = buffer_size\n",
"        self.task_buffers = {task_id: deque(maxlen=buffer_size)\n",
"                             for task_id in [Task.REGRESSION, Task.BINARY, Task.VISION]}\n",
"        self.device = device\n",
"\n",
"    def add(self, task_id: int, states: torch.Tensor, labels: torch.Tensor):\n",
"        \"\"\"Adds experiences (state-label pairs) to the buffer for a specific task.\"\"\"\n",
"        if task_id not in self.task_buffers:\n",
"            # Should not happen if self.task_buffers is pre-initialized with all Task IDs\n",
"            self.task_buffers[task_id] = deque(maxlen=self.buffer_size)\n",
"\n",
"        for state, label in zip(states, labels):\n",
"            # Detach from the graph and move to CPU to prevent memory leaks\n",
"            self.task_buffers[task_id].append((state.detach().cpu(), label.detach().cpu()))\n",
"\n",
"    def sample(self, batch_size: int = 32) -> (list, list, list):\n",
"        \"\"\"\n",
"        Samples a batch of experiences from the buffer, mixing tasks where possible.\n",
"        Returns: (list of states, list of labels, list of corresponding task_ids).\n",
"        States are returned as flat lists rather than stacked tensors because\n",
"        different tasks have different input shapes; the caller stacks per task.\n",
"        \"\"\"\n",
"        active_tasks = [tid for tid, buffer in self.task_buffers.items() if len(buffer) > 0]\n",
"        if not active_tasks:\n",
"            return None, None, None\n",
"\n",
"        # Sample at least two distinct tasks if possible\n",
"        num_tasks_to_sample = min(2, len(active_tasks))\n",
"        sampled_tasks_ids = random.sample(active_tasks, num_tasks_to_sample)\n",
"\n",
"        batch_states, batch_labels, batch_task_ids = [], [], []\n",
"\n",
"        # Distribute batch_size across the sampled tasks\n",
"        per_task_samples_base = batch_size // num_tasks_to_sample\n",
"        remainder = batch_size % num_tasks_to_sample\n",
"\n",
"        for i, task_id in enumerate(sampled_tasks_ids):\n",
"            task_list = list(self.task_buffers[task_id])  # Convert deque to list for random.sample\n",
"            if not task_list:\n",
"                continue\n",
"\n",
"            # Add the remainder to the first few tasks\n",
"            k = per_task_samples_base + (1 if i < remainder else 0)\n",
"            k = min(k, len(task_list))  # Ensure we don't sample more than available\n",
"\n",
"            if k == 0:\n",
"                continue\n",
"\n",
"            try:\n",
"                samples = random.sample(task_list, k)\n",
"            except ValueError:\n",
"                # k larger than the task list (already guarded by min(k, len(task_list)),\n",
"                # but kept as a fallback for robustness)\n",
"                samples = random.choices(task_list, k=k)\n",
"\n",
"            for state, label in samples:\n",
"                batch_states.append(state)\n",
"                batch_labels.append(label)\n",
"                batch_task_ids.append(task_id)\n",
"\n",
"        if not batch_states:\n",
"            return None, None, None\n",
"\n",
"        # Return flat lists; the task IDs preserve each sample's identity\n",
"        return batch_states, batch_labels, batch_task_ids\n"
] },
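{ "cell_type": "code", "execution_count": null, "id": "2f8d5b77", "metadata": {}, "outputs": [], "source": [
"# Added usage sketch (not part of the original pipeline): fill the buffer with a\n",
"# few fake regression and language samples, then draw a small mixed batch. The\n",
"# tensors are random placeholders with the same shapes the real loaders produce.\n",
"demo_buffer = TaskReplayBuffer(buffer_size=100)\n",
"demo_buffer.add(Task.REGRESSION, torch.randn(8, 2), torch.randn(8))\n",
"demo_buffer.add(Task.BINARY, torch.randint(0, len(vocab), (8, 10)), torch.randint(0, 2, (8,)).float())\n",
"demo_states, demo_labels, demo_task_ids = demo_buffer.sample(batch_size=4)\n",
"print(\"sampled task ids:\", demo_task_ids)\n",
"print(\"state shapes    :\", [tuple(s.shape) for s in demo_states])\n"
] },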
" samples = random.sample(task_list, k)\n", " except ValueError:\n", " # If k is larger than task_list length (should be\n", " caught by min(k, len(task_list))\n", " # but good fallback for robustness)\n", " samples = random.choices(task_list, k=k)\n", "\n", " for state, label in samples:\n", " batch_states.append(state)\n", " batch_labels.append(label)\n", " batch_task_ids.append(task_id)\n", "\n", " if not batch_states:\n", " return None, None, None\n", "\n", " # Stack tensors and return task IDs for individual processing\n", " return (\n", " torch.stack(batch_states),\n", "\n", " torch.stack(batch_labels),\n", " batch_task_ids # Return as list to preserve individual\n", " task identities\n", " )\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "cf664c01", "metadata": {}, "outputs": [], "source": [ " # --- 6. Full Agent Model ---\n", "\n", " class ModularBrainAgent(nn.Module):\n", " \"\"\"\n", " A comprehensive modular neural agent inspired by brain\n", " architecture,\n", " integrating spiking neurons, attention, and recurrent memory for\n", " multi-task learning.\n", " \"\"\"\n", " def __init__(self, neuron_counts: dict = None):\n", " super().__init__()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4afb88ab", "metadata": {}, "outputs": [], "source": [ " # --- Parameterization of Neuron Counts ---\n", " if neuron_counts is None:\n", " # Default neuron counts for each brain region\n", " neuron_counts = {\n", " 'sensory': 4,\n", " 'relay': 12,\n", " 'interneurons': 2,\n", " 'neuroendocrine': 8,\n", " 'autonomic': 10,\n", " 'mirror': 14,\n", " 'place_grid': 16\n", " }\n", " self.neuron_counts = neuron_counts\n" ] }, { "cell_type": "code", "execution_count": null, "id": "cd012993", "metadata": {}, "outputs": [], "source": [ " # --- Encoders (Input Modalities) ---\n", " self.encoders = nn.ModuleDict({\n", " 'regression': SharedEncoder(2,\n", " self.neuron_counts['sensory']),\n", " 'language': GRULanguage(len(vocab),\n", " embedding_dim=self.neuron_counts['sensory'],\n", " hidden_dim=self.neuron_counts['sensory']),\n", " 'vision':\n", " CNNVision(output_features=self.neuron_counts['sensory'])\n", " })\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6ab19dd3", "metadata": {}, "outputs": [], "source": [ " # --- Brain Modules with Realistic Neuron Counts ---\n", " self.sensory = SensoryProcessor(\n", " input_dim=self.neuron_counts['sensory'],\n", " output_dim=self.neuron_counts['sensory']\n", " )\n", "\n", " self.relay = RelayLayer(\n", " input_dim=self.neuron_counts['sensory'], # Input from\n", " SensoryProcessor\n", " output_dim=self.neuron_counts['relay']\n", " )\n", " self.interneurons = InterneuronLogic(\n", " input_dim=self.neuron_counts['relay'], # Input from\n", " RelayLayer\n", " output_dim=self.neuron_counts['interneurons']\n", " )\n", " self.neuroendocrine = NeuroendocrineModulator(\n", " input_dim=self.neuron_counts['interneurons'],\n", " output_dim=self.neuron_counts['neuroendocrine']\n", " )\n", " self.autonomic = AutonomicProcessor(\n", " input_dim=self.neuron_counts['neuroendocrine'],\n", " output_dim=self.neuron_counts['autonomic']\n", " )\n", " self.mirror = MirrorComparator(\n", " input_dim=self.neuron_counts['autonomic'],\n", " output_dim=self.neuron_counts['mirror']\n", " )\n", " self.place_grid = PlaceGridMemory(\n", " input_dim=self.neuron_counts['mirror'],\n", " output_dim=self.neuron_counts['place_grid']\n", " )\n" ] }, { "cell_type": "code", "execution_count": null, "id": "224830f4", "metadata": {}, "outputs": [], 
"source": [ " # --- Interconnection Weights (between brain modules) ---\n", " # Ensure input/output dimensions match the neuron_counts\n", " self.connect_sensory_to_relay =\n", " nn.Linear(self.neuron_counts['sensory'], self.neuron_counts['relay'])\n", " self.connect_relay_to_inter =\n", " nn.Linear(self.neuron_counts['relay'],\n", " self.neuron_counts['interneurons'])\n", " self.connect_inter_to_modulators =\n", " nn.Linear(self.neuron_counts['interneurons'],\n", " self.neuron_counts['neuroendocrine'])\n", " self.connect_modulators_to_auto =\n", " nn.Linear(self.neuron_counts['neuroendocrine'],\n", " self.neuron_counts['autonomic'])\n", " self.connect_auto_to_mirror =\n", " nn.Linear(self.neuron_counts['autonomic'],\n", " self.neuron_counts['mirror'])\n", " self.connect_mirror_to_place =\n", " nn.Linear(self.neuron_counts['mirror'],\n", " self.neuron_counts['place_grid'])\n", " # Recurrent loop from place_grid back to relay\n", " self.connect_place_to_relay =\n", "\n", " nn.Linear(self.neuron_counts['place_grid'],\n", " self.neuron_counts['relay'])\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fc9114ee", "metadata": {}, "outputs": [], "source": [ " # --- Optional Feedback Connections ---\n", " self.feedback_relay_to_sensory =\n", " nn.Linear(self.neuron_counts['relay'], self.neuron_counts['sensory'])\n", " self.feedback_inter_to_relay =\n", " nn.Linear(self.neuron_counts['interneurons'],\n", " self.neuron_counts['relay'])\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1ebf7548", "metadata": {}, "outputs": [], "source": [ " # --- Task Heads (Outputs for specific tasks) ---\n", " self.task_heads = nn.ModuleDict({\n", " 'binary': TaskHead(self.neuron_counts['place_grid'], 1,\n", " 'binary'),\n", " 'vision': TaskHead(self.neuron_counts['place_grid'], 10,\n", " 'vision'), # MNIST has 10 classes\n", " 'regression': TaskHead(self.neuron_counts['place_grid'],\n", " 1, 'regression')\n", " })\n", "\n", " def route_modules(self, x: torch.Tensor) -> tuple[torch.Tensor,\n", " ...]:\n", " \"\"\"\n", " Defines the forward pass through the interconnected brain\n", " modules.\n", " Returns intermediate activations for potential monitoring or\n", " later use.\n", " \"\"\"\n", " # Forward pass through brain circuits\n", " h_sensory = self.sensory(x)\n", " h_relay = self.connect_sensory_to_relay(h_sensory)\n", " h_relay = self.relay(h_relay)\n", "\n", " h_inter = self.connect_relay_to_inter(h_relay)\n", " h_inter = self.interneurons(h_inter)\n", "\n", " h_modulate = self.connect_inter_to_modulators(h_inter)\n", " h_modulate = self.neuroendocrine(h_modulate)\n", "\n", " h_auto = self.connect_modulators_to_auto(h_modulate)\n", " h_auto = self.autonomic(h_auto)\n", "\n", " h_mirror_result, _ =\n", " self.mirror(self.connect_auto_to_mirror(h_auto)) # Mirror returns\n", " tuple\n", " # if isinstance(mirror_result, tuple): # No longer needed as\n", " forward always returns tuple\n", " # h_mirror, _ = mirror_result\n", "\n", " # else:\n", " # h_mirror = mirror_result\n", " h_mirror = h_mirror_result # Renamed for clarity after tuple\n", " unpack\n", "\n", " h_place = self.connect_mirror_to_place(h_mirror)\n", " h_place = self.place_grid(h_place)\n", "\n", " # Recurrent loop from hippocampus (place_grid) back to\n", " thalamus (relay)\n", " # Note: Added ReLU here to prevent negative values from\n", " feeding back into spiking neurons potentially\n", " h_relay = h_relay +\n", " torch.relu(self.connect_place_to_relay(h_place))\n", "\n", " # Optional feedback connections\n", " 
h_sensory = h_sensory +\n", " torch.relu(self.feedback_relay_to_sensory(h_relay))\n", " h_relay = h_relay +\n", " torch.relu(self.feedback_inter_to_relay(h_inter))\n", "\n", " return h_relay, h_place, h_mirror, h_auto, h_modulate\n", "\n", "\n", " def encode(self, x: torch.Tensor, task_id: int) -> torch.Tensor:\n", " \"\"\"Selects and applies the appropriate encoder based on task\n", " ID and input shape.\"\"\"\n", " # Check task ID and input shape for specific encoders\n", " if task_id == Task.REGRESSION and x.dim() == 2 and x.size(1)\n", " == 2:\n", " return self.encoders['regression'](x.float())\n", " elif task_id == Task.BINARY and x.dim() == 2 and x.size(1) ==\n", " 10:\n", " return self.encoders['language'](x.long())\n", " elif task_id == Task.VISION:\n", " if x.dim() == 4: # Already 4D (batch, 1, H, W)\n", " return self.encoders['vision'](x.float())\n", " elif x.dim() == 2 and x.size(1) == 784: # Flattened 2D\n", " (batch, 784)\n", " return self.encoders['vision'](x.view(x.size(0), 1,\n", " 28, 28))\n", " else:\n", " raise ValueError(f\"Unexpected vision input shape:\n", " {x.shape}\")\n", " else:\n", " # Fallback for unexpected task_id/shape combination, or if\n", " x is already \"encoded\"\n", " # In a real system, you might want a specific error or\n", "\n", " default encoder here.\n", " # For now, assumes x is already in the 'sensory' input\n", " dimension if it falls through.\n", " if x.size(-1) != self.neuron_counts['sensory']:\n", " raise ValueError(f\"Input {x.shape} does not match\n", " sensory input dim {self.neuron_counts['sensory']} and no specific\n", " encoder found for task {task_id}.\")\n", " return x.float()\n", "\n", "\n", " def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:\n", " \"\"\"\n", " Main forward pass of the Modular Brain Agent.\n", " Encodes input, routes through brain modules, and selects task\n", " head.\n", " \"\"\"\n", " if x.dim() == 1:\n", " x = x.unsqueeze(0) # Ensure batch dimension\n", "\n", " encoded = self.encode(x, task_id)\n", " # Route through brain modules, focusing on the outputs used\n", " for task heads\n", " _, h_place, _, _, _ = self.route_modules(encoded) # Only\n", " h_place is used for task heads\n", "\n", " # Map task IDs to the appropriate task head name\n", " head_name = {\n", " Task.REGRESSION: 'regression',\n", " Task.BINARY: 'binary',\n", " Task.VISION: 'vision'\n", " }.get(task_id, 'regression') # Default to regression head if\n", " task_id is unexpected\n", "\n", " out = self.task_heads[head_name](h_place) # All tasks use the\n", " PlaceGridMemory output\n", " return out\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "28b1b910", "metadata": {}, "outputs": [], "source": [ " # --- 7. 
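{ "cell_type": "code", "execution_count": null, "id": "7d2b9f03", "metadata": {}, "outputs": [], "source": [
"# Added configuration sketch (values are arbitrary, not tuned): the per-region\n",
"# sizes are free parameters, with one constraint inherited from RelayLayer --\n",
"# 'relay' must be divisible by the 2 attention heads. A quick forward pass on\n",
"# dummy regression data confirms the wiring.\n",
"demo_counts = {\n",
"    'sensory': 8,\n",
"    'relay': 16,\n",
"    'interneurons': 4,\n",
"    'neuroendocrine': 8,\n",
"    'autonomic': 8,\n",
"    'mirror': 8,\n",
"    'place_grid': 16\n",
"}\n",
"demo_agent = ModularBrainAgent(neuron_counts=demo_counts)\n",
"print(\"parameters       :\", sum(p.numel() for p in demo_agent.parameters()))\n",
"print(\"regression output:\", demo_agent(torch.randn(2, 2), Task.REGRESSION).shape)\n"
] },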
{ "cell_type": "code", "execution_count": null, "id": "28b1b910", "metadata": {}, "outputs": [], "source": [
"# --- 7. Data Loaders (Synthetic and Real) ---\n",
"\n",
"def get_mnist_loader() -> (DataLoader, int):\n",
"    \"\"\"Returns the MNIST DataLoader and its Task ID.\"\"\"\n",
"    transform = transforms.Compose([\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize((0.5,), (0.5,))\n",
"    ])\n",
"    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
"    return DataLoader(dataset, batch_size=4, shuffle=True), Task.VISION\n",
"\n",
"\n",
"def get_imdb_loader() -> (DataLoader, int):\n",
"    \"\"\"Returns a synthetic IMDB-like DataLoader and its Task ID.\"\"\"\n",
"    texts = []\n",
"    labels = []\n",
"\n",
"    for i in range(1000):  # Create 1000 synthetic samples\n",
"        sentence = random.choice(sample_sentences)\n",
"        tokens = [vocab.get(word, vocab[\"\"]) for word in simple_tokenize(sentence)]\n",
"\n",
"        # Pad or truncate tokens to a fixed length\n",
"        while len(tokens) < 10:\n",
"            tokens.append(vocab[\"\"])\n",
"        if len(tokens) > 10:\n",
"            tokens = tokens[:10]\n",
"\n",
"        texts.append(tokens)\n",
"        labels.append(i % 2)  # Binary sentiment: 0 or 1\n",
"\n",
"    X = torch.tensor(texts).long()\n",
"    Y = torch.tensor(labels).float()\n",
"    dataset = TensorDataset(X, Y)\n",
"    return DataLoader(dataset, batch_size=4, shuffle=True), Task.BINARY\n",
"\n",
"\n",
"def get_regression_loader() -> (DataLoader, int):\n",
"    \"\"\"Returns a synthetic regression DataLoader and its Task ID.\"\"\"\n",
"    X = torch.randn(1000, 2)  # 1000 samples, 2 features\n",
"    Y = ((X[:, 0] + X[:, 1]) / 2).float()  # Simple linear relationship\n",
"    dataset = TensorDataset(X, Y)\n",
"    return DataLoader(dataset, batch_size=4, shuffle=True), Task.REGRESSION\n"
] },
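{ "cell_type": "code", "execution_count": null, "id": "3c61ab58", "metadata": {}, "outputs": [], "source": [
"# Added peek (not part of the original pipeline): pull one batch from the two\n",
"# synthetic loaders to confirm the shapes the agent will see. The MNIST loader\n",
"# is skipped here to avoid triggering a download outside of train().\n",
"demo_loader, demo_task = get_regression_loader()\n",
"demo_X, demo_Y = next(iter(demo_loader))\n",
"print(\"regression batch:\", demo_X.shape, demo_Y.shape, \"task id:\", demo_task)\n",
"demo_loader, demo_task = get_imdb_loader()\n",
"demo_X, demo_Y = next(iter(demo_loader))\n",
"print(\"language batch  :\", demo_X.shape, demo_Y.shape, \"task id:\", demo_task)\n"
] },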
{ "cell_type": "code", "execution_count": null, "id": "c3dced65", "metadata": {}, "outputs": [], "source": [
"# --- 8. Helper Function: Shape Inspector ---\n",
"\n",
"def inspect_shapes(agent: ModularBrainAgent):\n",
"    \"\"\"\n",
"    Prints the shapes of tensors at various points in the agent's forward pass\n",
"    to help verify architectural correctness.\n",
"    \"\"\"\n",
"    print(\"\\n=== SHAPE INSPECTOR ===\")\n",
"\n",
"    # Use CPU for inspection to avoid CUDA memory issues during debugging\n",
"    agent.cpu()\n",
"    agent.eval()  # Set to eval mode for consistent behavior (e.g., dropout off)\n",
"\n",
"    try:\n",
"        # Test the regression path\n",
"        sample_input_reg = torch.randn(4, 2)\n",
"        encoded_reg = agent.encode(sample_input_reg, Task.REGRESSION)\n",
"        print(f\"Encoded (Regression): {encoded_reg.shape}\")\n",
"        h_relay_reg, h_place_reg, h_mirror_reg, h_auto_reg, h_modulate_reg = agent.route_modules(encoded_reg)\n",
"        out_reg = agent.task_heads['regression'](h_place_reg)\n",
"        print(f\"Regression Path: Sensory({agent.sensory(encoded_reg).shape}) -> Relay({h_relay_reg.shape}) -> Place({h_place_reg.shape}) -> Output({out_reg.shape})\")\n",
"\n",
"        # Test the language path (the BINARY task uses the language encoder)\n",
"        sample_input_lang = torch.randint(0, len(vocab), (4, 10))\n",
"        encoded_lang = agent.encode(sample_input_lang, Task.BINARY)\n",
"        print(f\"Encoded (Language): {encoded_lang.shape}\")\n",
"        h_relay_lang, h_place_lang, h_mirror_lang, h_auto_lang, h_modulate_lang = agent.route_modules(encoded_lang)\n",
"        out_lang = agent.task_heads['binary'](h_place_lang)\n",
"        print(f\"Language Path: Sensory({agent.sensory(encoded_lang).shape}) -> Relay({h_relay_lang.shape}) -> Place({h_place_lang.shape}) -> Output({out_lang.shape})\")\n",
"\n",
"        # Test the vision path\n",
"        sample_input_vis = torch.randn(4, 1, 28, 28)\n",
"        encoded_vis = agent.encode(sample_input_vis, Task.VISION)\n",
"        print(f\"Encoded (Vision): {encoded_vis.shape}\")\n",
"        h_relay_vis, h_place_vis, h_mirror_vis, h_auto_vis, h_modulate_vis = agent.route_modules(encoded_vis)\n",
"        out_vis = agent.task_heads['vision'](h_place_vis)\n",
"        print(f\"Vision Path: Sensory({agent.sensory(encoded_vis).shape}) -> Relay({h_relay_vis.shape}) -> Place({h_place_vis.shape}) -> Output({out_vis.shape})\")\n",
"\n",
"    except Exception as e:\n",
"        print(f\"Shape inspection failed: {e}\")\n",
"        # Re-raise to ensure the error is not ignored in the main execution\n",
"        raise e\n",
"    finally:\n",
"        agent.train()  # Set back to train mode\n",
"        print(\"=========================\\n\")\n"
] },
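{ "cell_type": "code", "execution_count": null, "id": "5e0d7c42", "metadata": {}, "outputs": [], "source": [
"# Optional standalone check (added): run the shape inspector on a fresh agent.\n",
"# train() below also calls this before training, so this cell is purely for\n",
"# interactive exploration and can be skipped.\n",
"inspect_shapes(ModularBrainAgent())\n"
] },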
{ "cell_type": "code", "execution_count": null, "id": "46a6f97b", "metadata": {}, "outputs": [], "source": [
"# --- 9. Training Function ---\n",
"\n",
"def train(agent: ModularBrainAgent, episodes: int = 14400, buffer_size: int = 1000, replay_freq: int = 5):\n",
"    \"\"\"\n",
"    Trains the ModularBrainAgent using a curriculum learning strategy\n",
"    and experience replay.\n",
"    \"\"\"\n",
"    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"    agent.to(device)\n",
"    print(f\"Using device: {device}\")\n",
"\n",
"    print(\"🧠 Running shape inspector...\")\n",
"    inspect_shapes(agent)\n",
"    agent.to(device)  # inspect_shapes runs on CPU; move the agent back to the training device\n",
"\n",
"    replay_buffer = TaskReplayBuffer(buffer_size=buffer_size, device=device)\n",
"\n",
"    optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001, weight_decay=1e-5)\n",
"    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100, factor=0.5)\n",
"\n",
"    # Initialize weights\n",
"    def init_weights(m):\n",
"        if isinstance(m, nn.Linear):\n",
"            torch.nn.init.xavier_uniform_(m.weight)\n",
"            if m.bias is not None:\n",
"                torch.nn.init.zeros_(m.bias)\n",
"        elif isinstance(m, nn.Conv2d):\n",
"            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')\n",
"            if m.bias is not None:\n",
"                torch.nn.init.zeros_(m.bias)\n",
"    agent.apply(init_weights)\n",
"\n",
"    # Set up the TensorBoard writer\n",
"    log_dir = \"runs/modular_brain_agent\"\n",
"    os.makedirs(log_dir, exist_ok=True)\n",
"    writer = SummaryWriter(log_dir=log_dir)\n",
"    print(f\"TensorBoard logs are being saved to: {log_dir}\")\n",
"\n",
"    best_loss = float('inf')\n",
"    no_improvement = 0\n",
"    save_path = \"best_model.pth\"\n",
"    task_history = {i: [] for i in [Task.REGRESSION, Task.BINARY, Task.VISION]}\n",
"    curriculum_stages = {\n",
"        0: [Task.REGRESSION],\n",
"        300: [Task.REGRESSION, Task.BINARY],              # Start binary after 300 episodes\n",
"        600: [Task.REGRESSION, Task.BINARY, Task.VISION]  # Start vision after 600 episodes\n",
"    }\n",
"\n",
"    raw_loaders = {\n",
"        Task.REGRESSION: get_regression_loader(),\n",
"        Task.BINARY: get_imdb_loader(),\n",
"        Task.VISION: get_mnist_loader()\n",
"    }\n",
"    loaders = {k: v[0] for k, v in raw_loaders.items()}\n",
"    # task_ids are fixed as keys for the loaders in this setup\n",
"\n",
"    # Initialize persistent iterators for each task's DataLoader\n",
"    loaders_iter = {\n",
"        Task.REGRESSION: iter(loaders[Task.REGRESSION]),\n",
"        Task.BINARY: iter(loaders[Task.BINARY]),\n",
"        Task.VISION: iter(loaders[Task.VISION])\n",
"    }\n",
"\n",
"    # Helper for loss computation\n",
"    def compute_loss(out: torch.Tensor, Y: torch.Tensor, task_id: int) -> torch.Tensor:\n",
"        \"\"\"Computes the loss based on task type, handling shape alignments.\"\"\"\n",
"        if task_id == Task.VISION:\n",
"            Y = Y.long()  # Target labels for CrossEntropy must be long\n",
"            if out.dim() != 2 or Y.dim() != 1:\n",
"                raise ValueError(f\"Vision task expects out [batch, num_classes], Y [batch]. Got {out.shape} vs {Y.shape}\")\n",
"            return F.cross_entropy(out, Y)\n",
"\n",
"        elif task_id == Task.BINARY:\n",
"            # Ensure output and target have matching shapes for BCEWithLogitsLoss\n",
"            if out.shape != Y.shape:\n",
"                if out.numel() == Y.numel():\n",
"                    out = out.view_as(Y)  # Reshape the output to match the target if the element counts agree\n",
"                else:\n",
"                    raise ValueError(f\"Binary task shape mismatch: out {out.shape} vs Y {Y.shape}\")\n",
"            return F.binary_cross_entropy_with_logits(out, Y.float())\n",
"\n",
"        else:  # REGRESSION\n",
"            if out.shape != Y.shape:\n",
"                if out.numel() == Y.numel():\n",
"                    out = out.view_as(Y)  # Reshape the output to match the target if the element counts agree\n",
"                else:\n",
"                    raise ValueError(f\"Regression task shape mismatch: out {out.shape} vs Y {Y.shape}\")\n",
"            return F.smooth_l1_loss(out, Y.float())\n",
"\n",
"    loss_weights = {Task.REGRESSION: 1.5, Task.BINARY: 1.2, Task.VISION: 1.2}\n",
"    global_step = 0  # For TensorBoard logging\n",
"\n",
"    for ep in range(episodes):\n",
"        agent.train()  # Ensure the model is in training mode\n",
"\n",
"        # Determine the current curriculum stage\n",
"        current_stage_episodes = 0\n",
"        for stage_start_ep, tasks in sorted(curriculum_stages.items()):\n",
"            if ep >= stage_start_ep:\n",
"                current_stage_episodes = stage_start_ep\n",
"        current_tasks = curriculum_stages[current_stage_episodes]\n",
"\n",
"        # Randomly select a task from the active curriculum stage\n",
"        task_id = np.random.choice(current_tasks)\n",
"\n",
"        # Fetch data using the persistent iterator\n",
"        try:\n",
"            X, Y = next(loaders_iter[task_id])\n",
"        except StopIteration:\n",
"            # Reset the iterator when it is exhausted for that task\n",
"            loaders_iter[task_id] = iter(loaders[task_id])\n",
"            X, Y = next(loaders_iter[task_id])\n",
"\n",
"        X, Y = X.to(device), Y.to(device)\n",
"\n",
"        # --- Primary Training Step ---\n",
"        out = agent(X, task_id)\n",
"        loss = compute_loss(out, Y, task_id) * loss_weights.get(task_id, 1.0)\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)  # Clip gradients\n",
"        optimizer.step()\n",
"\n",
"        # Store the current experience in the replay buffer\n",
"        replay_buffer.add(task_id, X, Y)\n",
"\n",
"        # --- Accuracy Metric Calculation ---\n",
"        acc = 0.0\n",
"        with torch.no_grad():  # Don't track gradients for the accuracy calculation\n",
"            if task_id == Task.VISION:\n",
"                pred = out.argmax(dim=1)\n",
"                acc = (pred == Y).float().mean().item()\n",
"            elif task_id == Task.BINARY:\n",
"                pred = (torch.sigmoid(out) > 0.5).float()\n",
"                acc = (pred == Y).float().mean().item()\n",
"            else:  # REGRESSION\n",
"                acc = ((out - Y).abs() < 0.2).float().mean().item()  # Within a 0.2 error margin\n",
"\n",
"        task_history[task_id].append((loss.item(), acc))\n",
"\n",
"        # --- Replay Phase ---\n",
"        if ep % replay_freq == 0:\n",
"            replay_states, replay_labels, replay_task_ids_list = replay_buffer.sample(batch_size=4)\n",
"\n",
"            if replay_states is not None:\n",
"                unique_replay_tasks = list(set(replay_task_ids_list))\n",
"                total_replay_loss = 0.0\n",
"                num_replay_samples = 0\n",
"\n",
"                # Iterate through each unique task in the sampled batch\n",
"                for current_replay_task_id in unique_replay_tasks:\n",
"                    # Filter samples belonging to this specific task\n",
"                    indices_for_this_task = [i for i, tid in enumerate(replay_task_ids_list)\n",
"                                             if tid == current_replay_task_id]\n",
"                    if not indices_for_this_task:\n",
"                        continue\n",
"\n",
"                    # Stack per task (tasks have different input shapes) and move to the device\n",
"                    task_replay_X = torch.stack([replay_states[i] for i in indices_for_this_task]).to(device)\n",
"                    task_replay_Y = torch.stack([replay_labels[i] for i in indices_for_this_task]).to(device)\n",
"\n",
"                    # Process these samples using the correct encoder and task head\n",
"                    replay_out_for_task = agent(task_replay_X, current_replay_task_id)\n",
"\n",
"                    # Compute the loss for this task's samples (weighted by sample count)\n",
"                    replay_loss_for_task = compute_loss(replay_out_for_task, task_replay_Y, current_replay_task_id)\n",
"                    total_replay_loss += replay_loss_for_task * len(indices_for_this_task)\n",
"                    num_replay_samples += len(indices_for_this_task)\n",
"\n",
"                if num_replay_samples > 0:\n",
"                    total_replay_loss /= num_replay_samples  # Average over all replayed samples\n",
"                    optimizer.zero_grad()\n",
"                    total_replay_loss.backward()\n",
"                    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)\n",
"                    optimizer.step()\n",
"                    # Log the replay loss\n",
"                    writer.add_scalar('Loss/Replay_Loss', total_replay_loss.item(), global_step)\n",
"\n",
"        # --- Logging and Early Stopping ---\n",
"        scheduler.step(loss.item())  # Step the scheduler based on the current task loss\n",
"\n",
"        # TensorBoard logging\n",
"        writer.add_scalar(f'Loss/Current_Task_{task_id}', loss.item(), global_step)\n",
"        writer.add_scalar(f'Accuracy/Current_Task_{task_id}', acc, global_step)\n",
"        writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], global_step)\n",
"\n",
"        global_step += 1\n",
"\n",
"        # Check for the best model and early stopping\n",
"        if loss.item() < best_loss:\n",
"            best_loss = loss.item()\n",
"            no_improvement = 0\n",
"            torch.save(agent.state_dict(), save_path)\n",
"            # print(f\"New best model saved at episode {ep} with loss: {best_loss:.4f}\")\n",
"        else:\n",
"            no_improvement += 1\n",
"\n",
"        # Print a summary periodically\n",
"        if ep % 200 == 0:\n",
"            print(f\"\\n--- Episode {ep} (Curriculum Stage: {current_stage_episodes}) ---\")\n",
"            for t_id in sorted(task_history.keys()):\n",
"                if task_history[t_id]:\n",
"                    # Get recent history; defaults to empty if there are not enough entries\n",
"                    recent_losses = [x[0] for x in task_history[t_id][-200:]]\n",
"                    recent_accs = [x[1] for x in task_history[t_id][-200:]]\n",
"\n",
"                    avg_loss = np.mean(recent_losses) if recent_losses else 0.0\n",
"                    avg_acc = np.mean(recent_accs) if recent_accs else 0.0\n",
"\n",
"                    # Map the task ID to a readable name\n",
"                    task_name = {Task.REGRESSION: \"Regression\", Task.BINARY: \"Binary\", Task.VISION: \"Vision\"}.get(t_id, f\"Unknown_{t_id}\")\n",
"                    print(f\"Task {task_name} | Avg Loss: {avg_loss:.3f} | Avg Acc: {avg_acc:.2f} ({len(recent_losses)} samples)\")\n",
"            print(f\"Overall Current Loss: {loss.item():.4f} | Best Loss: {best_loss:.4f} | No Improvement: {no_improvement}\")\n",
"            print(\"--------------------\\n\")\n",
"\n",
"        if no_improvement >= 1000:  # Early stopping threshold\n",
"            print(f\"Early stopping at episode {ep} due to no improvement for {no_improvement} steps.\")\n",
"            break\n",
"\n",
"    writer.close()\n",
"    print(\"✅ Training finished.\")\n",
"    return best_loss\n"
] },
{ "cell_type": "code", "execution_count": null, "id": "2300bd2d", "metadata": {}, "outputs": [], "source": [
"# --- 10. Main Execution Block ---\n",
"\n",
"if __name__ == \"__main__\":\n",
"    print(\"🚀 Initializing Modular Brain Agent...\")\n",
"\n",
"    # Optional: Define custom neuron counts here\n",
"    # custom_neuron_config = {\n",
"    #     'sensory': 8,\n",
"    #     'relay': 24,\n",
"    #     'interneurons': 4,\n",
"    #     'neuroendocrine': 16,\n",
"    #     'autonomic': 20,\n",
"    #     'mirror': 28,\n",
"    #     'place_grid': 32\n",
"    # }\n",
"    # agent = ModularBrainAgent(neuron_counts=custom_neuron_config)\n",
"\n",
"    agent = ModularBrainAgent()  # Using the default neuron counts for now\n",
"\n",
"    # Print a model summary\n",
"    total_params = sum(p.numel() for p in agent.parameters())\n",
"    trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)\n",
"\n",
"    print(\"📊 Model Summary:\")\n",
"    print(f\"   Total parameters: {total_params:,}\")\n",
"    print(f\"   Trainable parameters: {trainable_params:,}\")\n",
"    print(f\"   Model size: ~{total_params * 4 / (1024**2):.2f} MB (approx. float32)\")\n",
"\n",
"    # Test a forward pass on each task to verify shapes\n",
"    print(\"\\n🔍 Testing forward passes...\")\n",
"    try:\n",
"        # Test the regression task\n",
"        test_reg = torch.randn(2, 2)\n",
"        out_reg = agent(test_reg, Task.REGRESSION)\n",
"        print(f\"✅ Regression test passed: {test_reg.shape} -> {out_reg.shape}\")\n",
"\n",
"        # Test the binary classification (language) task\n",
"        test_bin = torch.randint(0, len(vocab), (2, 10))\n",
"        out_bin = agent(test_bin, Task.BINARY)\n",
"        print(f\"✅ Binary classification test passed: {test_bin.shape} -> {out_bin.shape}\")\n",
"\n",
"        # Test the vision task\n",
"        test_vis = torch.randn(2, 1, 28, 28)\n",
"        out_vis = agent(test_vis, Task.VISION)\n",
"        print(f\"✅ Vision test passed: {test_vis.shape} -> {out_vis.shape}\")\n",
"\n",
"    except Exception as e:\n",
"        print(f\"❌ Forward pass test failed: {e}\")\n",
"        # If the forward pass fails, there's a fundamental issue, so exit\n",
"        import sys\n",
"        sys.exit(1)\n",
"\n",
"    print(\"\\n🎯 Starting training...\")\n",
"\n",
"    # Train the agent\n",
"    try:\n",
"        final_best_loss = train(\n",
"            agent=agent,\n",
"            episodes=14400,    # Total training episodes\n",
"            buffer_size=1000,  # Replay buffer size\n",
"            replay_freq=5      # Replay every N episodes\n",
"        )\n",
"\n",
"        print(\"\\n🎉 Training completed successfully!\")\n",
"        print(f\"📈 Best loss achieved during training: {final_best_loss:.4f}\")\n",
"        print(\"💾 Best model saved to: best_model.pth\")\n",
"\n",
"    except KeyboardInterrupt:\n",
"        print(\"\\n⏹ Training interrupted by user.\")\n",
"    except Exception as e:\n",
"        print(f\"\\n❌ Training failed with error: {e}\")\n",
"        raise e\n",
"\n",
"    print(\"\\n🧠 Modular Brain Agent execution completed!\")"
] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 }