#!/usr/bin/env python3
"""
Full end-to-end BitTransformerLM training run with all optimizations!
Small scale test to validate our enhanced system.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import logging
from pathlib import Path
import time
from typing import List, Dict, Any

# Import our enhanced modules
from bit_transformer.model import BitTransformerLM
from bit_transformer.compression import compress_bits_batch, model_output_decompress
from bit_transformer.error_handling import safe_model_forward, setup_error_logging
from bit_transformer.types import BitSequence, TelemetryDict
from enhanced_checkpoint_system import create_checkpoint_manager

# Setup logging
logger = setup_error_logging("INFO")


class SimpleBitDataset(Dataset):
    """Simple dataset of bit sequences for training."""

    def __init__(self, num_samples: int = 1000, seq_length: int = 128):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.data = self._generate_bit_sequences()

    def _generate_bit_sequences(self) -> List[torch.Tensor]:
        """Generate diverse bit sequences with different patterns."""
        sequences = []

        # Pattern 1: Alternating sequences (0101... / 1010...)
        for i in range(self.num_samples // 4):
            pattern = torch.tensor(
                [(i + j) % 2 for j in range(self.seq_length)], dtype=torch.long
            )
            sequences.append(pattern)

        # Pattern 2: Random sequences
        for i in range(self.num_samples // 4):
            pattern = torch.randint(0, 2, (self.seq_length,), dtype=torch.long)
            sequences.append(pattern)

        # Pattern 3: Structured patterns (runs)
        for i in range(self.num_samples // 4):
            pattern = []
            pos = 0
            while pos < self.seq_length:
                run_length = min(np.random.randint(1, 20), self.seq_length - pos)
                bit_value = np.random.randint(0, 2)
                pattern.extend([bit_value] * run_length)
                pos += run_length
            pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
            sequences.append(pattern)

        # Pattern 4: XOR-feedback ("Fibonacci-like") sequences
        remaining = self.num_samples - len(sequences)
        for i in range(remaining):
            pattern = [0, 1]
            while len(pattern) < self.seq_length:
                pattern.append(pattern[-1] ^ pattern[-2])  # XOR of last two bits
            pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
            sequences.append(pattern)

        return sequences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data[idx]
        # For language modeling, input is sequence[:-1], target is sequence[1:]
        return sequence[:-1], sequence[1:]


def compute_safety_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    """Compute K/C/S safety metrics."""
    pred_bits = (predictions > 0.5).float().flatten()

    # K metric (Negentropy): measure of order vs. randomness
    if len(pred_bits) > 0:
        prob_1 = pred_bits.mean().item()
        prob_0 = 1 - prob_1
        if prob_0 > 0 and prob_1 > 0:
            entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
            negentropy = 1.0 - entropy  # Higher = more ordered
        else:
            negentropy = 1.0 if prob_1 == 1.0 or prob_1 == 0.0 else 0.0
    else:
        negentropy = 0.0

    # C metric (Complexity): simple run-length approximation
    changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
    complexity = min(changes / len(pred_bits), 1.0) if len(pred_bits) > 1 else 0.0

    # S metric (Symbiosis): alignment with target distribution
    target_bits = targets.float().flatten()
    if len(target_bits) > 0:
        target_mean = target_bits.mean()
        pred_mean = pred_bits.mean()
        symbiosis = 1.0 - abs(target_mean - pred_mean).item()
    else:
        symbiosis = 1.0

    return {
        'K_negentropy': negentropy,
        'C_complexity': complexity,
        'S_symbiosis': symbiosis,
    }

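
# Illustrative usage sketch (not part of the original training flow): shows
# what compute_safety_metrics returns for a trivial all-ones prediction. The
# helper name _demo_safety_metrics is hypothetical and is never called during
# training; given the definitions above, the expected result is
# K = 1.0 (fully ordered), C = 0.0 (no bit flips), S = 1.0 (means match).
def _demo_safety_metrics() -> Dict[str, float]:
    """Minimal sanity-check example for the K/C/S metrics."""
    toy_predictions = torch.tensor([0.9, 0.9, 0.9, 0.9])  # probabilities of bit=1
    toy_targets = torch.tensor([1, 1, 1, 1])
    return compute_safety_metrics(toy_predictions, toy_targets)
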

def train_bittransformer():
    """Main training function with all optimizations."""
    logger.info("🚀 Starting BitTransformerLM end-to-end training run!")

    # Model configuration - small but meaningful
    model_config = {
        'd_model': 256,
        'nhead': 8,
        'num_layers': 4,
        'dim_feedforward': 512,
        'max_seq_len': 128,
        'use_checkpoint': True,
        'chunk_size': None,  # Disable chunking for small model
    }

    training_config = {
        'batch_size': 16,
        'learning_rate': 1e-3,
        'num_epochs': 10,
        'save_every_n_epochs': 2,
        'log_every_n_steps': 10,
    }

    # Initialize enhanced checkpoint manager
    checkpoint_manager = create_checkpoint_manager()
    session_id = checkpoint_manager.create_training_session(
        session_name="end_to_end_test",
        model_config=model_config,
        training_config=training_config,
    )
    logger.info(f"📁 Created training session: {session_id}")

    # Create dataset and dataloader
    logger.info("📊 Creating training dataset...")
    dataset = SimpleBitDataset(num_samples=800, seq_length=model_config['max_seq_len'])
    dataloader = DataLoader(dataset, batch_size=training_config['batch_size'], shuffle=True)

    # Initialize model
    logger.info("🧠 Initializing BitTransformerLM model...")
    model = BitTransformerLM(
        d_model=model_config['d_model'],
        nhead=model_config['nhead'],
        num_layers=model_config['num_layers'],
        dim_feedforward=model_config['dim_feedforward'],
        max_seq_len=model_config['max_seq_len'],
        use_checkpoint=model_config['use_checkpoint'],
        chunk_size=model_config['chunk_size'],
    )

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"🔢 Model parameters: {total_params:,} total, {trainable_params:,} trainable")

    # Setup optimizer, scheduler, and loss
    optimizer = optim.AdamW(model.parameters(), lr=training_config['learning_rate'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_config['num_epochs'])
    criterion = nn.CrossEntropyLoss()

    # Training loop
    logger.info("🏃‍♂️ Starting training loop...")
    for epoch in range(training_config['num_epochs']):
        model.train()
        epoch_loss = 0.0
        epoch_metrics = {'K_negentropy': 0.0, 'C_complexity': 0.0, 'S_symbiosis': 0.0}
        num_batches = 0
        start_time = time.time()

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()

            # Forward pass with safety monitoring
            try:
                # BitTransformerLM returns (logits, telemetry)
                output = safe_model_forward(model, inputs)
                if isinstance(output, tuple):
                    logits, telemetry = output
                else:
                    logits = output
                    telemetry = {}

                # BitTransformerLM outputs logits for binary classification.
                # Shape is either [batch, seq_len, 2] or already flattened
                # to [batch*seq_len, 2] for the binary vocabulary.
                if logits.dim() == 2:
                    # Already flattened to [batch*seq_len, 2]
                    logits_flat = logits
                    targets_flat = targets.reshape(-1)
                else:
                    # [batch, seq_len, 2] -> flatten
                    logits_flat = logits.reshape(-1, 2)
                    targets_flat = targets.reshape(-1)

                loss = criterion(logits_flat, targets_flat)

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                # Compute metrics
                with torch.no_grad():
                    # Handle different logits shapes for predictions
                    if logits.dim() == 2:
                        # [batch*seq_len, 2] -> reshape back to [batch, seq_len, 2]
                        batch_size = inputs.shape[0]
                        seq_len = inputs.shape[1]
                        logits_reshaped = logits.reshape(batch_size, seq_len, 2)
                        predictions = torch.softmax(logits_reshaped, dim=-1)[:, :, 1]  # Prob of bit=1
                    else:
                        # [batch, seq_len, 2]
                        predictions = torch.softmax(logits, dim=-1)[:, :, 1]  # Prob of bit=1

                    safety_metrics = compute_safety_metrics(predictions, targets)

                epoch_loss += loss.item()
                for key, value in safety_metrics.items():
                    epoch_metrics[key] += value
                num_batches += 1

                # Logging
                if batch_idx % training_config['log_every_n_steps'] == 0:
                    logger.info(f"Epoch {epoch+1}/{training_config['num_epochs']}, "
                                f"Batch {batch_idx}/{len(dataloader)}, "
                                f"Loss: {loss.item():.4f}, "
                                f"K: {safety_metrics['K_negentropy']:.3f}, "
                                f"C: {safety_metrics['C_complexity']:.3f}, "
                                f"S: {safety_metrics['S_symbiosis']:.3f}")

            except Exception as e:
                logger.error(f"Error in batch {batch_idx}: {e}")
                continue

        # End of epoch processing
        scheduler.step()
        epoch_time = time.time() - start_time

        if num_batches > 0:
            avg_loss = epoch_loss / num_batches
            avg_metrics = {k: v / num_batches for k, v in epoch_metrics.items()}

            logger.info(f"✅ Epoch {epoch+1} completed in {epoch_time:.2f}s")
            logger.info(f"📊 Avg Loss: {avg_loss:.4f}")
            logger.info(f"📈 Safety Metrics - K: {avg_metrics['K_negentropy']:.3f}, "
                        f"C: {avg_metrics['C_complexity']:.3f}, "
                        f"S: {avg_metrics['S_symbiosis']:.3f}")

            # Save checkpoint
            if (epoch + 1) % training_config['save_every_n_epochs'] == 0:
                checkpoint_success = checkpoint_manager.save_checkpoint(
                    model=model,
                    session_id=session_id,
                    epoch=epoch + 1,
                    metrics={
                        'loss': avg_loss,
                        'learning_rate': scheduler.get_last_lr()[0],
                        **avg_metrics,
                    },
                    optimizer_state=optimizer.state_dict(),
                    scheduler_state=scheduler.state_dict(),
                )
                if checkpoint_success:
                    logger.info(f"💾 Checkpoint saved for epoch {epoch+1}")

            # Save best model if loss improved
            checkpoint_manager.save_best_model(
                session_id=session_id,
                model=model,
                metric_name='loss',
                metric_value=avg_loss,
                is_better_func=lambda x, y: x < y,  # Lower loss is better
            )

    logger.info("🎉 Training completed successfully!")

    # Test inference and compression
    logger.info("🧪 Testing model inference and compression...")
    model.eval()
    with torch.no_grad():
        # Create a test sequence
        test_input = torch.randint(0, 2, (1, 64), dtype=torch.long)
        logger.info(f"📥 Input sequence: {test_input.squeeze().tolist()}")

        # Model inference (unpack telemetry if the model returns a tuple)
        output = model(test_input)
        output_logits = output[0] if isinstance(output, tuple) else output
        output_probs = torch.softmax(output_logits, dim=-1)
        predicted_bits = torch.argmax(output_probs, dim=-1)
        logger.info(f"📤 Predicted sequence: {predicted_bits.squeeze().tolist()}")

        # Test compression
        compressed = compress_bits_batch(predicted_bits)
        logger.info(f"🗜️ Compressed length: {len(compressed[0])} (original: {predicted_bits.shape[-1]})")

        # Decompress to verify
        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(predicted_bits, decompressed)
        logger.info(f"✅ Compression/decompression successful: {compression_success}")

    # Final storage usage report
    storage_usage = checkpoint_manager.get_storage_usage()
    logger.info(f"💾 Final storage usage: {storage_usage['total_gb']:.3f} GB")
    logger.info(f"📁 Training sessions: {storage_usage['num_sessions']}")

    return session_id, model, checkpoint_manager

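
# Hypothetical resume sketch (not exercised by this script): it assumes that
# checkpoint_manager.load_checkpoint(session_id) -- the call suggested in the
# usage hint printed under __main__ below -- returns a dict containing
# 'model_state', 'optimizer_state', and 'scheduler_state'. The real
# enhanced_checkpoint_system API may use different key names, so treat these
# keys and the helper name resume_from_checkpoint as assumptions.
def resume_from_checkpoint(checkpoint_manager, session_id: str,
                           model: nn.Module,
                           optimizer: optim.Optimizer,
                           scheduler: optim.lr_scheduler.CosineAnnealingLR) -> None:
    """Restore model/optimizer/scheduler state from a saved training session."""
    checkpoint = checkpoint_manager.load_checkpoint(session_id)
    model.load_state_dict(checkpoint['model_state'])          # assumed key name
    optimizer.load_state_dict(checkpoint['optimizer_state'])  # mirrors the save_checkpoint kwarg
    scheduler.load_state_dict(checkpoint['scheduler_state'])  # mirrors the save_checkpoint kwarg
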

if __name__ == "__main__":
    try:
        session_id, trained_model, manager = train_bittransformer()
        print(f"\n🎉 SUCCESS! Training session completed: {session_id}")
        print(f"🔍 Use checkpoint_manager.load_checkpoint('{session_id}') to resume")
    except Exception as e:
        logger.error(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        raise