import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random
from collections import deque
import os  # For creating log directory

# --- 0. Global Setup and Vocabulary ---

# Simulated language data
sample_sentences = [
    "I love AI",
    "Deep learning is fun",
    "Spiking neurons are cool",
    "Brain inspired models rock",
    "Replay buffer helps learning",
]


def simple_tokenize(text):
    """Lower-case *text*, split on whitespace and keep at most 10 tokens."""
    return text.lower().split()[:10]


# Build vocab manually. Index 0 (the empty string) is the padding token used
# by the language data loader.
vocab = {"": 0}
word_counter = 1
for sentence in sample_sentences:
    for word in simple_tokenize(sentence):
        if word not in vocab:
            vocab[word] = word_counter
            word_counter += 1


# Task IDs (using an Enum-like class for clarity)
class Task:
    """Integer identifiers for the three training tasks."""

    REGRESSION = 3
    BINARY = 4
    VISION = 5


# --- 1. Base Classes and Spiking Neuron Implementations ---


class Module(nn.Module):
    """Base class for brain-inspired modules; subclasses implement ``forward``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class SurrogateLIF(torch.autograd.Function):
    """Heaviside spike function with a rectangular (boxcar) surrogate gradient.

    Forward: returns ``1.0`` where ``x > 0`` and ``0.0`` elsewhere, where ``x``
    is expected to be *membrane potential minus threshold*.
    Backward: the upstream gradient is passed through where ``|x| < 1`` and
    zeroed elsewhere, i.e. the step's derivative is approximated by a
    unit-height box around the threshold.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        # Chain rule: upstream gradient times the surrogate derivative
        # (1 inside the |x| < 1 window, 0 outside it).
        surrogate = (x.abs() < 1).to(grad_output.dtype)
        return grad_output * surrogate


class SpikingNeuron(Module):
    """Leaky Integrate-and-Fire (LIF) neuron model.

    The membrane potential leaky-integrates the input across successive
    forward calls (``mem = decay * mem + x``), a spike is emitted where it
    exceeds ``threshold``, and spiking units are reset to zero.

    Gradients reach ``x`` through :class:`SurrogateLIF` applied to
    ``mem - threshold``. The membrane state kept between calls is detached,
    so backpropagation is truncated at each forward call.
    """

    def __init__(self, threshold: float = 1.0, decay: float = 0.95):
        super().__init__()
        self.threshold = threshold
        self.decay = decay
        # Membrane potential carried between forward calls. It is transient
        # simulation state whose shape follows the batch, so it is kept out of
        # the state_dict (persistent=False). A None buffer that later becomes
        # a tensor would otherwise make saved checkpoints fail strict loading.
        self.register_buffer('mem', None, persistent=False)

    def _integrate(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``decay * mem + x``, (re)initialising ``mem`` to zeros on a shape/device change."""
        if self.mem is None or self.mem.shape != x.shape or self.mem.device != x.device:
            self.mem = torch.zeros_like(x)
        return self.decay * self.mem + x

    def _fire(self, mem: torch.Tensor, threshold) -> torch.Tensor:
        """Emit differentiable spikes for ``mem`` and store the reset, detached state."""
        spike = SurrogateLIF.apply(mem - threshold)
        # Reset spiking units to zero. The stored state is detached so that
        # the next call does not backpropagate into this step's graph, which
        # is freed after backward().
        self.mem = torch.where(spike.detach().bool(), torch.zeros_like(mem), mem).detach()
        return spike

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mem = self._integrate(x)
        return self._fire(mem, self.threshold)


class AdaptiveLIFNeuron(SpikingNeuron):
    """LIF neuron with a threshold adapted from the recent firing rate.

    The threshold is kept in the ``threshold_state`` buffer. After each
    training-mode forward call it is updated as
    ``threshold_state = 0.99 * threshold_state + 0.01 * mean_spike_rate``.
    In eval mode it is left unchanged.

    NOTE(review): because the spike rate lies in [0, 1], this rule pulls the
    threshold toward the firing rate (usually down from 1.0) rather than
    raising it with activity. Confirm this is the intended adaptation.
    """

    def __init__(self, threshold: float = 1.0, decay: float = 0.95):
        super().__init__(threshold=threshold, decay=decay)
        # Register buffer for adaptive threshold state
        self.register_buffer('threshold_state', torch.tensor(threshold))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mem = self._integrate(x)
        spike = self._fire(mem, self.threshold_state)
        # Only adapt while training, so evaluation/inspection passes do not
        # mutate the model.
        if self.training:
            with torch.no_grad():
                spike_rate = spike.mean().item()  # Scalar rate
                # Simple moving average for threshold adaptation
                self.threshold_state = self.threshold_state * 0.99 + spike_rate * 0.01
        return spike
# --- 2. Encoder Modules (Task-Specific Input Processing) ---


class SharedEncoder(nn.Module):
    """Generic MLP encoder for numerical data.

    Linear -> LayerNorm -> ReLU -> Dropout(0.3).

    Args:
        input_size: number of features per sample.
        hidden_size: width of the encoded representation.
    """

    def __init__(self, input_size: int, hidden_size: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a ``[batch, input_size]`` tensor to ``[batch, hidden_size]``.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if x.dim() != 2:
            raise ValueError(f"Expected 2D input for SharedEncoder, got {x.shape}")
        return self.encoder(x.float())


class CNNVision(nn.Module):
    """CNN encoder for single-channel image data (e.g., MNIST).

    For a 28x28 input the spatial size goes 28 -> 24 (conv5) -> 12 (pool)
    -> 8 (conv5) -> 1 (adaptive pool). The result is then flattened to 8
    channels and projected to ``output_features``.
    """

    def __init__(self, output_features: int = 4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # Outputs [batch, 8, 1, 1]
            nn.Flatten(),  # Outputs [batch, 8]
            nn.Linear(8, output_features),  # Maps to desired output size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[batch, 1, H, W]`` images to ``[batch, output_features]``."""
        return self.conv(x)


class GRULanguage(nn.Module):
    """GRU encoder for token-id sequences.

    Embeds token ids, runs a single-layer GRU and returns the output at the
    last time step as the sequence representation.
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 4, hidden_dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[batch, seq]`` (or ``[seq]``) token ids to ``[batch, hidden_dim]``."""
        if x.dim() == 1:
            x = x.unsqueeze(0)  # Add batch dimension if missing
        emb = self.embed(x.long())  # Embedding requires integer indices
        out, _ = self.gru(emb)
        return out[:, -1, :]  # Last time step as the sequence representation
# --- 3. Brain-Inspired Modules with Specific Neuron Counts ---


class SensoryProcessor(Module):
    """Processes initial sensory input.

    Linear -> LayerNorm -> Dropout(0.3) -> LIF spiking neuron. Returns
    ``[batch, output_dim]`` spikes (0.0 / 1.0).
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.neuron = SpikingNeuron()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        return self.neuron(z)


class RelayLayer(Module):
    """Routes information through a self-attention step.

    The projected input is repeated ``seq_len`` times to form a pseudo
    sequence. MultiheadAttention is applied over it, the result is
    mean-pooled over that axis, and it is passed to a LIF neuron.

    NOTE: all ``seq_len`` tokens are identical copies, so the attention
    output is the same at every position regardless of ``seq_len``.

    Args:
        input_dim: features of the incoming tensor.
        output_dim: embedding size; must be divisible by the 2 attention heads.
        seq_len: length of the simulated sequence (default 4).
    """

    def __init__(self, input_dim: int, output_dim: int, seq_len: int = 4):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        # MultiheadAttention requires embed_dim to be divisible by num_heads
        self.attn = nn.MultiheadAttention(embed_dim=output_dim, num_heads=2, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.neuron = SpikingNeuron()
        self.seq_len = seq_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        # Create a fake time dimension by repeating the input for attention
        z_seq = z.unsqueeze(1).repeat(1, self.seq_len, 1)
        z_out, _ = self.attn(z_seq, z_seq, z_seq)
        # Pool over the simulated time dimension
        routed = z_out.mean(dim=1)
        return self.neuron(routed)


class InterneuronLogic(Module):
    """Core processing unit: Linear -> LayerNorm -> Dropout -> adaptive LIF neuron."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.neuron = AdaptiveLIFNeuron()  # Using adaptive neuron here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        return self.neuron(z)


class NeuroendocrineModulator(Module):
    """Modulates signals with a learned multiplicative gain.

    Computes ``z = Dropout(LayerNorm(Linear(x)))`` and returns
    ``z * sigmoid(gain_control(z))``, i.e. each feature scaled by a gain in (0, 1).
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.gain_control = nn.Linear(output_dim, output_dim)  # Maps output to gain factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        gain = torch.sigmoid(self.gain_control(z))  # Gain factor between 0 and 1
        return z * gain


class AutonomicProcessor(Module):
    """Recurrent processing of internal state.

    Linear -> LayerNorm -> single-step GRU (zero initial hidden state on every
    call), scaled by ``feedback_gain`` (0.9).
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.recurrent = nn.GRU(output_dim, output_dim, batch_first=True)
        self.feedback_gain = 0.9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.norm(self.linear(x))
        # GRU expects (batch, seq_len, features); use seq_len=1 for a single step
        h, _ = self.recurrent(z.unsqueeze(1))
        return self.feedback_gain * h.squeeze(1)


class MirrorComparator(Module):
    """Compares the current state to an optional reference.

    ``forward`` returns ``(spikes, similarity)``. The spikes come from a LIF
    neuron on ``LayerNorm(Linear(x))``. ``similarity`` is the per-sample
    cosine similarity between that projected state and ``reference``, or
    ``None`` when no reference is given.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.comparison_layer = nn.CosineSimilarity(dim=1)  # Cosine similarity over features
        self.neuron = SpikingNeuron()

    def compare(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Cosine similarity of two ``[batch, input_dim]`` tensors after projecting both with ``self.linear``."""
        a_proj = self.linear(a)
        b_proj = self.linear(b)
        return self.comparison_layer(a_proj, b_proj)

    def forward(self, x: torch.Tensor, reference: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.norm(self.linear(x))
        spike = self.neuron(z)
        similarity = None
        if reference is not None:
            # `reference` is expected in the projected (output_dim) space,
            # the same space as `z`, so the two are compared directly. Routing
            # `z` through compare() would re-apply self.linear, which expects
            # input_dim features and fails whenever input_dim != output_dim.
            similarity = self.comparison_layer(z, reference)
        return spike, similarity


class PlaceGridMemory(Module):
    """Spatial memory: population coding followed by a single-step LSTM."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.positional_encoder = self.population_code  # Reference to internal method
        self.memory_cell = nn.LSTM(output_dim, output_dim, batch_first=True)

    def population_code(self, x: torch.Tensor, pop_size: int, sigma: float = 1.0) -> torch.Tensor:
        """Encode each row of ``x`` as a one-hot vector of length ``pop_size``.

        The row mean is squashed by a sigmoid and scaled to a position in
        ``[0, pop_size - 1]``. The hot index is that position truncated to an
        integer.

        The forward value is exactly this one-hot code. Gradients instead flow
        through a Gaussian bump (width ``sigma``) centred at the continuous
        position (straight-through estimator), so upstream layers are trainable.
        """
        x_scalar = x.mean(dim=1)
        x_normalized = torch.sigmoid(x_scalar)  # Normalize to [0, 1]
        position = x_normalized * (pop_size - 1)
        idx = position.long().clamp(0, pop_size - 1)
        hard = torch.zeros(x.size(0), pop_size, device=x.device, dtype=x.dtype)
        hard.scatter_(1, idx.unsqueeze(1), 1.0)
        centers = torch.arange(pop_size, device=x.device, dtype=x.dtype)
        soft = torch.softmax(
            -((centers.unsqueeze(0) - position.unsqueeze(1)) ** 2) / (2 * sigma ** 2), dim=1
        )
        # Forward value == hard (soft - soft.detach() is exactly zero);
        # backward uses the gradient of `soft`.
        return hard + (soft - soft.detach())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        encoded = self.norm(self.linear(x))
        # Population size equals the configured output_dim
        pos_encoded = self.positional_encoder(encoded, pop_size=self.linear.out_features)
        # LSTM expects (batch, seq_len, features); use seq_len=1 for a single step
        memory_out, _ = self.memory_cell(pos_encoded.unsqueeze(1))
        return memory_out.squeeze(1)


# --- 4. Task Heads (Task-Specific Output Layers) ---


class TaskHead(nn.Module):
    """Linear task head with dropout on its input.

    For ``'binary'`` and ``'regression'`` heads the trailing dimension is
    squeezed, so ``[batch, 1]`` becomes ``[batch]``. Other heads (e.g.
    ``'vision'``) return ``[batch, output_dim]``.
    """

    def __init__(self, input_dim: int, output_dim: int, task_type: str):
        super().__init__()
        self.task_type = task_type
        # Dropout regularises the head's input features. Applying it to the
        # final outputs would randomly zero predictions/logits during training.
        self.dropout = nn.Dropout(0.3)
        # Kept as a Sequential so parameter names stay head.0.weight / head.0.bias.
        self.head = nn.Sequential(nn.Linear(input_dim, output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.head(self.dropout(x))
        if self.task_type in ('binary', 'regression'):
            return out.squeeze(-1)
        return out
# --- 5. Replay Buffer for Continual Learning ---


class TaskReplayBuffer:
    """Per-task FIFO replay buffer of (state, label) pairs.

    Each task keeps its own ``deque`` of at most ``buffer_size`` entries.
    Stored tensors are detached and on the CPU.
    """

    def __init__(self, buffer_size: int = 1000, device: str = "cpu"):
        self.buffer_size = buffer_size
        self.task_buffers = {
            task_id: deque(maxlen=buffer_size)
            for task_id in [Task.REGRESSION, Task.BINARY, Task.VISION]
        }
        # Kept for callers; sample() itself always returns CPU tensors.
        self.device = device

    def add(self, task_id: int, states: torch.Tensor, labels: torch.Tensor):
        """Append every (state, label) row of a batch to ``task_id``'s buffer."""
        if task_id not in self.task_buffers:
            self.task_buffers[task_id] = deque(maxlen=self.buffer_size)
        buffer = self.task_buffers[task_id]
        for state, label in zip(states, labels):
            # Detach and move to CPU so stored samples keep neither autograd
            # graphs nor device memory alive.
            buffer.append((state.detach().cpu(), label.detach().cpu()))

    @staticmethod
    def _stackable(a, b) -> bool:
        """True when two stored (state, label) pairs can share one ``torch.stack``."""
        (state_a, label_a), (state_b, label_b) = a, b
        return (
            state_a.shape == state_b.shape and state_a.dtype == state_b.dtype
            and label_a.shape == label_b.shape and label_a.dtype == label_b.dtype
        )

    def sample(self, batch_size: int = 32) -> tuple:
        """Sample up to ``batch_size`` experiences from one or two tasks.

        One active task is picked at random. A second task is added only if
        its stored samples have the same state/label shape and dtype, because
        the results are stacked into single tensors. Tasks with different
        input modalities (e.g. [2] vs [10] vs [1, 28, 28] states) cannot be
        stacked together.

        Returns:
            ``(states, labels, task_ids)`` where ``task_ids[i]`` is the task
            of row ``i``, or ``(None, None, None)`` when nothing can be sampled.
        """
        active_tasks = [tid for tid, buffer in self.task_buffers.items() if len(buffer) > 0]
        if not active_tasks or batch_size <= 0:
            return None, None, None

        anchor = random.choice(active_tasks)
        reference = self.task_buffers[anchor][0]
        partners = [
            tid for tid in active_tasks
            if tid != anchor and self._stackable(self.task_buffers[tid][0], reference)
        ]
        sampled_task_ids = [anchor] + random.sample(partners, min(1, len(partners)))

        # Split batch_size across the sampled tasks; the first ones get the remainder
        per_task_base, remainder = divmod(batch_size, len(sampled_task_ids))
        batch_states, batch_labels, batch_task_ids = [], [], []
        for i, task_id in enumerate(sampled_task_ids):
            task_list = list(self.task_buffers[task_id])
            k = min(per_task_base + (1 if i < remainder else 0), len(task_list))
            for state, label in random.sample(task_list, k):
                batch_states.append(state)
                batch_labels.append(label)
                batch_task_ids.append(task_id)

        if not batch_states:
            return None, None, None
        return torch.stack(batch_states), torch.stack(batch_labels), batch_task_ids


# --- 6. Full Agent Model ---


class ModularBrainAgent(nn.Module):
    """Modular multi-task agent built from brain-inspired regions.

    Pipeline: a modality encoder produces a ``neuron_counts['sensory']``-wide
    code. This code is chained through sensory -> relay -> interneurons ->
    neuroendocrine -> autonomic -> mirror -> place_grid. Between regions a
    ``connect_*`` linear layer projects into the *target* region's size,
    which is also that region's input size. The place_grid output feeds the
    task heads.
    """

    _DEFAULT_NEURON_COUNTS = {
        'sensory': 4,
        'relay': 12,
        'interneurons': 2,
        'neuroendocrine': 8,
        'autonomic': 10,
        'mirror': 14,
        'place_grid': 16,
    }

    def __init__(self, neuron_counts: dict = None):
        super().__init__()
        # Missing keys fall back to the defaults, so partial configs work. The
        # copy keeps later edits to the caller's dict from desyncing the model.
        counts = dict(self._DEFAULT_NEURON_COUNTS)
        if neuron_counts:
            counts.update(neuron_counts)
        self.neuron_counts = counts
        nc = counts

        # --- Encoders (Input Modalities) ---
        self.encoders = nn.ModuleDict({
            'regression': SharedEncoder(2, nc['sensory']),
            'language': GRULanguage(len(vocab), embedding_dim=nc['sensory'], hidden_dim=nc['sensory']),
            'vision': CNNVision(output_features=nc['sensory']),
        })

        # --- Brain Modules ---
        # Each module receives the output of a connect_* projection into its
        # own size, so input_dim == output_dim == that region's count.
        # (Using the source region's size as input_dim made the very first
        # forward pass fail with a matmul shape mismatch.)
        self.sensory = SensoryProcessor(input_dim=nc['sensory'], output_dim=nc['sensory'])
        self.relay = RelayLayer(input_dim=nc['relay'], output_dim=nc['relay'])
        self.interneurons = InterneuronLogic(input_dim=nc['interneurons'], output_dim=nc['interneurons'])
        self.neuroendocrine = NeuroendocrineModulator(
            input_dim=nc['neuroendocrine'], output_dim=nc['neuroendocrine']
        )
        self.autonomic = AutonomicProcessor(input_dim=nc['autonomic'], output_dim=nc['autonomic'])
        self.mirror = MirrorComparator(input_dim=nc['mirror'], output_dim=nc['mirror'])
        self.place_grid = PlaceGridMemory(input_dim=nc['place_grid'], output_dim=nc['place_grid'])

        # --- Interconnection Weights (source region -> target region size) ---
        self.connect_sensory_to_relay = nn.Linear(nc['sensory'], nc['relay'])
        self.connect_relay_to_inter = nn.Linear(nc['relay'], nc['interneurons'])
        self.connect_inter_to_modulators = nn.Linear(nc['interneurons'], nc['neuroendocrine'])
        self.connect_modulators_to_auto = nn.Linear(nc['neuroendocrine'], nc['autonomic'])
        self.connect_auto_to_mirror = nn.Linear(nc['autonomic'], nc['mirror'])
        self.connect_mirror_to_place = nn.Linear(nc['mirror'], nc['place_grid'])
        # Recurrent loop from place_grid back to relay
        self.connect_place_to_relay = nn.Linear(nc['place_grid'], nc['relay'])

        # --- Optional Feedback Connections ---
        self.feedback_relay_to_sensory = nn.Linear(nc['relay'], nc['sensory'])
        self.feedback_inter_to_relay = nn.Linear(nc['interneurons'], nc['relay'])

        # --- Task Heads ---
        self.task_heads = nn.ModuleDict({
            'binary': TaskHead(nc['place_grid'], 1, 'binary'),
            'vision': TaskHead(nc['place_grid'], 10, 'vision'),  # MNIST has 10 classes
            'regression': TaskHead(nc['place_grid'], 1, 'regression'),
        })

    def route_modules(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Run an encoded ``[batch, sensory]`` tensor through the region chain.

        Returns:
            ``(h_relay, h_place, h_mirror, h_auto, h_modulate)``. ``h_relay``
            includes the place->relay and inter->relay feedback terms.
        """
        h_sensory = self.sensory(x)
        h_relay = self.relay(self.connect_sensory_to_relay(h_sensory))
        h_inter = self.interneurons(self.connect_relay_to_inter(h_relay))
        h_modulate = self.neuroendocrine(self.connect_inter_to_modulators(h_inter))
        h_auto = self.autonomic(self.connect_modulators_to_auto(h_modulate))
        h_mirror, _ = self.mirror(self.connect_auto_to_mirror(h_auto))  # (spikes, similarity)
        h_place = self.place_grid(self.connect_mirror_to_place(h_mirror))

        # Recurrent loop from place_grid back to relay. ReLU keeps the
        # feedback term non-negative.
        h_relay = h_relay + torch.relu(self.connect_place_to_relay(h_place))

        # Optional feedback connections.
        # NOTE(review): h_sensory is not returned, so this feedback term
        # currently has no effect on any output.
        h_sensory = h_sensory + torch.relu(self.feedback_relay_to_sensory(h_relay))
        h_relay = h_relay + torch.relu(self.feedback_inter_to_relay(h_inter))

        return h_relay, h_place, h_mirror, h_auto, h_modulate

    def encode(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Apply the encoder matching ``task_id`` and the input shape.

        Raises:
            ValueError: for an unsupported vision shape, or when no encoder
                matches and ``x`` is not already ``sensory``-wide.
        """
        if task_id == Task.REGRESSION and x.dim() == 2 and x.size(1) == 2:
            return self.encoders['regression'](x.float())
        if task_id == Task.BINARY and x.dim() == 2 and x.size(1) == 10:
            return self.encoders['language'](x.long())
        if task_id == Task.VISION:
            if x.dim() == 4:  # (batch, 1, H, W)
                return self.encoders['vision'](x.float())
            if x.dim() == 2 and x.size(1) == 784:  # Flattened (batch, 784)
                return self.encoders['vision'](x.reshape(x.size(0), 1, 28, 28).float())
            raise ValueError(f"Unexpected vision input shape: {x.shape}")
        # Fallback: treat x as already encoded if it has the sensory width.
        if x.size(-1) != self.neuron_counts['sensory']:
            raise ValueError(f"Input {x.shape} does not match sensory input dim {self.neuron_counts['sensory']} and no specific encoder found for task {task_id}.")
        return x.float()

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Encode ``x``, route it through the regions and apply the task head.

        Unknown ``task_id`` values fall back to the regression head.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)  # Ensure batch dimension
        encoded = self.encode(x, task_id)
        _, h_place, _, _, _ = self.route_modules(encoded)  # Only h_place feeds the heads
        head_name = {
            Task.REGRESSION: 'regression',
            Task.BINARY: 'binary',
            Task.VISION: 'vision',
        }.get(task_id, 'regression')
        return self.task_heads[head_name](h_place)
# --- 7. Data Loaders (Synthetic and Real) ---


def get_mnist_loader(batch_size: int = 4) -> tuple[DataLoader, int]:
    """Return a shuffled MNIST training DataLoader (pixels normalised to [-1, 1]) and its task ID."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True), Task.VISION


def get_imdb_loader(batch_size: int = 4) -> tuple[DataLoader, int]:
    """Return a synthetic IMDB-like DataLoader of 1000 padded 10-token sequences and its task ID."""
    pad_id = vocab[""]
    texts = []
    labels = []
    for i in range(1000):
        sentence = random.choice(sample_sentences)
        tokens = [vocab.get(word, pad_id) for word in simple_tokenize(sentence)][:10]
        tokens += [pad_id] * (10 - len(tokens))  # Pad to a fixed length of 10
        texts.append(tokens)
        # NOTE(review): the label depends only on the sample index, not on the
        # sentence, so it is uncorrelated with the input. Confirm this is intended.
        labels.append(i % 2)
    X = torch.tensor(texts).long()
    Y = torch.tensor(labels).float()
    return DataLoader(TensorDataset(X, Y), batch_size=batch_size, shuffle=True), Task.BINARY


def get_regression_loader(batch_size: int = 4) -> tuple[DataLoader, int]:
    """Return a synthetic regression DataLoader (target = mean of 2 Gaussian features) and its task ID."""
    X = torch.randn(1000, 2)  # 1000 samples, 2 features
    Y = ((X[:, 0] + X[:, 1]) / 2).float()  # Simple linear relationship
    return DataLoader(TensorDataset(X, Y), batch_size=batch_size, shuffle=True), Task.REGRESSION


# --- 8. Helper Function: Shape Inspector ---


def inspect_shapes(agent: ModularBrainAgent):
    """Print tensor shapes along each task's forward path.

    Runs on the CPU in eval mode without gradient tracking. Afterwards the
    agent's original device and train/eval mode are restored, so calling
    this does not move a model that lives on the GPU.

    Raises:
        Exception: any error from the forward passes is printed and re-raised.
    """
    print("\n=== SHAPE INSPECTOR ===")
    try:
        original_device = next(agent.parameters()).device
    except StopIteration:
        original_device = torch.device("cpu")
    was_training = agent.training

    # Use CPU for inspection to avoid CUDA memory issues during debugging
    agent.cpu()
    agent.eval()  # Consistent behaviour (e.g. dropout off)
    try:
        probes = (
            ("Regression", Task.REGRESSION, 'regression', torch.randn(4, 2)),
            ("Language", Task.BINARY, 'binary', torch.randint(0, len(vocab), (4, 10))),
            ("Vision", Task.VISION, 'vision', torch.randn(4, 1, 28, 28)),
        )
        with torch.no_grad():
            for label, task_id, head_name, sample_input in probes:
                encoded = agent.encode(sample_input, task_id)
                print(f"Encoded ({label}): {encoded.shape}")
                h_relay, h_place, _, _, _ = agent.route_modules(encoded)
                out = agent.task_heads[head_name](h_place)
                print(f"{label} Path: Sensory({agent.sensory(encoded).shape}) -> Relay({h_relay.shape}) -> Place({h_place.shape}) -> Output({out.shape})")
    except Exception as e:
        print(f"Shape inspection failed: {e}")
        raise  # Do not hide architecture errors from the caller
    finally:
        agent.to(original_device)
        agent.train(was_training)
    print("=========================\n")
# --- 9. Training Function ---


def train(agent: ModularBrainAgent, episodes: int = 14400, buffer_size: int = 1000,
          replay_freq: int = 5, ema_decay: float = 0.98):
    """Train the agent with a task curriculum plus experience replay.

    Each episode proceeds as follows:
      * Draw one batch from a task chosen uniformly among those unlocked by
        the curriculum, and take an optimizer step on it.
      * Store that batch in a per-task replay buffer.
      * Every ``replay_freq`` episodes, take an extra step on a small batch
        replayed from up to two tasks.

    Args:
        agent: model to train (moved to CUDA when available).
        episodes: maximum number of episodes.
        buffer_size: per-task replay capacity.
        replay_freq: replay every N episodes; values <= 0 disable replay.
        ema_decay: smoothing factor in [0, 1) for the loss that drives the LR
            scheduler, best-model checkpointing and early stopping. 0.0 uses
            the raw per-batch loss.

    Returns:
        The best (lowest) smoothed training loss observed.
    """
    if not 0.0 <= ema_decay < 1.0:
        raise ValueError(f"ema_decay must be in [0, 1), got {ema_decay}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("🧠 Running shape inspector...")
    inspect_shapes(agent)

    def init_weights(m):
        """Xavier init for Linear, Kaiming (ReLU) for Conv2d, zero biases."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    agent.apply(init_weights)
    # Move to the training device only after inspection: the inspector runs
    # the model on the CPU, while every batch below is moved to `device`.
    agent.to(device)

    replay_buffer = TaskReplayBuffer(buffer_size=buffer_size, device=device)
    optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001, weight_decay=1e-5)
    # `verbose` is not passed: it is deprecated for PyTorch LR schedulers.
    # The learning rate is logged to TensorBoard below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100, factor=0.5)

    log_dir = "runs/modular_brain_agent"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    print(f"TensorBoard logs are being saved to: {log_dir}")

    all_tasks = [Task.REGRESSION, Task.BINARY, Task.VISION]
    task_names = {Task.REGRESSION: "Regression", Task.BINARY: "Binary", Task.VISION: "Vision"}
    best_loss = float('inf')
    smoothed_loss = None
    no_improvement = 0
    save_path = "best_model.pth"
    # Only the most recent 200 entries per task are ever summarised.
    task_history = {tid: deque(maxlen=200) for tid in all_tasks}

    curriculum_stages = {
        0: [Task.REGRESSION],
        300: [Task.REGRESSION, Task.BINARY],  # Start binary after 300 episodes
        600: [Task.REGRESSION, Task.BINARY, Task.VISION],  # Start vision after 600 episodes
    }

    loaders = {
        Task.REGRESSION: get_regression_loader()[0],
        Task.BINARY: get_imdb_loader()[0],
        Task.VISION: get_mnist_loader()[0],
    }
    # Persistent iterators, restarted when a task's loader is exhausted.
    loaders_iter = {tid: iter(loader) for tid, loader in loaders.items()}

    def next_batch(task_id):
        """Return the next (X, Y) batch for ``task_id``, restarting its loader when exhausted."""
        try:
            return next(loaders_iter[task_id])
        except StopIteration:
            loaders_iter[task_id] = iter(loaders[task_id])
            return next(loaders_iter[task_id])

    def compute_loss(out: torch.Tensor, Y: torch.Tensor, task_id: int) -> torch.Tensor:
        """Computes loss based on task type, handling shape alignments."""
        if task_id == Task.VISION:
            Y = Y.long()  # Target labels for CrossEntropy must be long
            if out.dim() != 2 or Y.dim() != 1:
                raise ValueError(f"Vision task expects out [batch, num_classes], Y [batch]. Got {out.shape} vs {Y.shape}")
            return F.cross_entropy(out, Y)
        elif task_id == Task.BINARY:
            if out.shape != Y.shape:
                if out.numel() == Y.numel():
                    out = out.view_as(Y)
                else:
                    raise ValueError(f"Binary task shape mismatch: out {out.shape} vs Y {Y.shape}")
            return F.binary_cross_entropy_with_logits(out, Y.float())
        else:  # REGRESSION
            if out.shape != Y.shape:
                if out.numel() == Y.numel():
                    out = out.view_as(Y)
                else:
                    raise ValueError(f"Regression task shape mismatch: out {out.shape} vs Y {Y.shape}")
            return F.smooth_l1_loss(out, Y.float())

    def sample_replay_groups(batch_size):
        """Draw up to ``batch_size`` stored samples from up to two tasks.

        Returns a dict ``{task_id: (states, labels)}`` on ``device``. Each task
        is stacked separately because tasks store inputs of different shapes,
        which cannot be stacked into one tensor.
        """
        active = [tid for tid, buf in replay_buffer.task_buffers.items() if buf]
        if not active or batch_size <= 0:
            return {}
        chosen = random.sample(active, min(2, len(active)))
        base, remainder = divmod(batch_size, len(chosen))
        groups = {}
        for i, tid in enumerate(chosen):
            pool = list(replay_buffer.task_buffers[tid])
            k = min(base + (1 if i < remainder else 0), len(pool))
            if k == 0:
                continue
            picks = random.sample(pool, k)
            states = torch.stack([state for state, _ in picks]).to(device)
            labels = torch.stack([label for _, label in picks]).to(device)
            groups[tid] = (states, labels)
        return groups

    loss_weights = {Task.REGRESSION: 1.5, Task.BINARY: 1.2, Task.VISION: 1.2}
    global_step = 0  # For TensorBoard logging

    try:
        for ep in range(episodes):
            agent.train()

            # Latest curriculum stage whose start episode has been reached
            current_stage_episodes = max(start for start in curriculum_stages if ep >= start)
            current_tasks = curriculum_stages[current_stage_episodes]
            task_id = int(np.random.choice(current_tasks))

            X, Y = next_batch(task_id)
            X, Y = X.to(device), Y.to(device)

            # --- Primary Training Step ---
            out = agent(X, task_id)
            loss = compute_loss(out, Y, task_id) * loss_weights.get(task_id, 1.0)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
            optimizer.step()
            loss_value = loss.item()

            replay_buffer.add(task_id, X, Y)

            # --- Accuracy Metric ---
            with torch.no_grad():
                if task_id == Task.VISION:
                    acc = (out.argmax(dim=1) == Y).float().mean().item()
                elif task_id == Task.BINARY:
                    acc = ((torch.sigmoid(out) > 0.5).float() == Y).float().mean().item()
                else:  # REGRESSION: fraction within 0.2 of the target
                    acc = ((out - Y).abs() < 0.2).float().mean().item()
            task_history[task_id].append((loss_value, acc))

            # --- Replay Phase ---
            if replay_freq > 0 and ep % replay_freq == 0:
                replay_groups = sample_replay_groups(batch_size=4)
                if replay_groups:
                    total_replay_loss = 0.0
                    num_replay_samples = 0
                    for replay_task_id, (replay_X, replay_Y) in replay_groups.items():
                        replay_out = agent(replay_X, replay_task_id)
                        n = replay_X.size(0)
                        # Weight by sample count so the result is a per-sample mean
                        total_replay_loss = total_replay_loss + compute_loss(replay_out, replay_Y, replay_task_id) * n
                        num_replay_samples += n
                    total_replay_loss = total_replay_loss / num_replay_samples
                    optimizer.zero_grad()
                    total_replay_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
                    optimizer.step()
                    writer.add_scalar('Loss/Replay_Loss', total_replay_loss.item(), global_step)

            # --- Scheduling, logging and early stopping ---
            # A smoothed loss keeps a single lucky batch from setting an
            # unbeatable "best" and from triggering premature LR cuts.
            smoothed_loss = (loss_value if smoothed_loss is None
                             else ema_decay * smoothed_loss + (1.0 - ema_decay) * loss_value)
            scheduler.step(smoothed_loss)

            writer.add_scalar(f'Loss/Current_Task_{task_id}', loss_value, global_step)
            writer.add_scalar(f'Accuracy/Current_Task_{task_id}', acc, global_step)
            writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], global_step)
            global_step += 1

            if smoothed_loss < best_loss:
                best_loss = smoothed_loss
                no_improvement = 0
                torch.save(agent.state_dict(), save_path)
            else:
                no_improvement += 1

            if ep % 200 == 0:
                print(f"\n--- Episode {ep} (Curriculum Stage: {current_stage_episodes}) ---")
                for t_id in sorted(task_history):
                    history = task_history[t_id]
                    if not history:
                        continue
                    avg_loss = np.mean([entry[0] for entry in history])
                    avg_acc = np.mean([entry[1] for entry in history])
                    task_name = task_names.get(t_id, f"Unknown_{t_id}")
                    print(f"Task {task_name} | Avg Loss: {avg_loss:.3f} | Avg Acc: {avg_acc:.2f} ({len(history)} samples)")
                print(f"Overall Current Loss: {loss_value:.4f} | Smoothed Loss: {smoothed_loss:.4f} | Best Loss: {best_loss:.4f} | No Improvement: {no_improvement}")
                print("--------------------\n")

            if no_improvement >= 1000:  # Early stopping threshold
                print(f"Early stopping at episode {ep} due to no improvement for {no_improvement} steps.")
                break
    finally:
        writer.close()  # Flush logs even if training raises

    print("✅ Training finished.")
    return best_loss
# --- 10. Main Execution Block ---
if __name__ == "__main__":
    print("🚀 Initializing Modular Brain Agent...")

    # Optional: define custom neuron counts here, e.g.
    # custom_neuron_config = {
    #     'sensory': 8,
    #     'relay': 24,
    #     'interneurons': 4,
    #     'neuroendocrine': 16,
    #     'autonomic': 20,
    #     'mirror': 28,
    #     'place_grid': 32,
    # }
    # agent = ModularBrainAgent(neuron_counts=custom_neuron_config)
    agent = ModularBrainAgent()  # Using default neuron counts for now

    # Print model summary
    total_params = sum(p.numel() for p in agent.parameters())
    trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
    print("📊 Model Summary:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print(f" Model size: ~{total_params * 4 / (1024**2):.2f} MB (approx. float32)")

    # Smoke-test a forward pass per task to verify shapes (no gradients needed)
    print("\n🔍 Testing forward passes...")
    try:
        with torch.no_grad():
            test_reg = torch.randn(2, 2)
            out_reg = agent(test_reg, Task.REGRESSION)
            print(f"✅ Regression test passed: {test_reg.shape} -> {out_reg.shape}")

            # Binary classification (language) task
            test_bin = torch.randint(0, len(vocab), (2, 10))
            out_bin = agent(test_bin, Task.BINARY)
            print(f"✅ Binary classification test passed: {test_bin.shape} -> {out_bin.shape}")

            test_vis = torch.randn(2, 1, 28, 28)
            out_vis = agent(test_vis, Task.VISION)
            print(f"✅ Vision test passed: {test_vis.shape} -> {out_vis.shape}")
    except Exception as e:
        print(f"❌ Forward pass test failed: {e}")
        # A failing forward pass means the architecture is broken; stop here.
        import sys
        sys.exit(1)

    print("\n🎯 Starting training...")
    try:
        final_best_loss = train(
            agent=agent,
            episodes=14400,  # Total training episodes
            buffer_size=1000,  # Replay buffer size
            replay_freq=5,  # Replay every N episodes
        )
        print("\n🎉 Training completed successfully!")
        print(f"📈 Best loss achieved during training: {final_best_loss:.4f}")
        print("💾 Best model saved to: best_model.pth")
    except KeyboardInterrupt:
        print("\n⏹ Training interrupted by user.")
    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        raise  # Preserve the original traceback

    print("\n🧠 Modular Brain Agent execution completed!")