#!/usr/bin/env python3
"""
MarkovSpline-Enhanced BitTransformerLM Training

Integrates MarkovSpline data smoothing directly into the BitTransformerLM training
pipeline for improved data preprocessing and gradient optimization.
"""

import os
import sys
import json
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from torch.utils.data import DataLoader, Dataset

# Add MarkovSpline to path
sys.path.insert(0, '/data/MarkovSpline')
from bitpipe_integration import MarkovSplineBitPipeModule, create_markov_spline_bitpipe_module

# BitTransformerLM imports
from bit_transformer.model import BitTransformerLM
from bit_transformer.telemetry import TelemetrySynthesizer


# Simple trainer base class
class BitwiseTrainer:
    """Simple base trainer for BitTransformerLM."""

    def __init__(self, model, learning_rate=1e-3, max_grad_norm=1.0):
        self.model = model
        self.device = next(model.parameters()).device
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()
        self.max_grad_norm = max_grad_norm

    def train_step(self, batch):
        """Simple training step."""
        self.optimizer.zero_grad()

        outputs = self.model(batch['input_bits'])

        # BitTransformerLM returns (logits, telemetry)
        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits = outputs

        loss = self.criterion(logits.reshape(-1, logits.size(-1)),
                              batch['target_bits'].reshape(-1))
        loss.backward()

        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optimizer.step()

        return {'loss': loss.item()}


class MarkovSplineEnhancedDataset(Dataset):
    """Dataset wrapper that applies MarkovSpline preprocessing."""

    def __init__(self, base_dataset: Dataset, markov_module: MarkovSplineBitPipeModule,
                 smoothing_strength: float = 0.1, enable_smoothing: bool = True):
        self.base_dataset = base_dataset
        self.markov_module = markov_module
        self.smoothing_strength = smoothing_strength
        self.enable_smoothing = enable_smoothing

        # Initialize data preprocessor
        if enable_smoothing:
            self.markov_module.initialize_application('data_preprocessor',
                                                      smoothing_strength=smoothing_strength,
                                                      preserve_features=True)

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # Get original data
        data = self.base_dataset[idx]

        if not self.enable_smoothing:
            return data

        # Apply MarkovSpline preprocessing to bit sequences
        if isinstance(data, dict) and 'input_bits' in data:
            try:
                # Smooth input bits
                result = self.markov_module.process_data(
                    [data['input_bits']],
                    'preprocess_training',
                    binary_data=True
                )

                if result['success'] and result['processed_sequences']:
                    data['input_bits'] = result['processed_sequences'][0]
                    data['smoothing_applied'] = True
                else:
                    data['smoothing_applied'] = False

            except Exception as e:
                print(f"Warning: MarkovSpline preprocessing failed for sample {idx}: {e}")
                data['smoothing_applied'] = False

        return data
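

# The trainer below consumes batches as dicts with 'input_bits' and 'target_bits'
# tensors. As a minimal sketch (not part of the original pipeline), a synthetic
# next-bit dataset such as the following could be wrapped by
# MarkovSplineEnhancedDataset / create_enhanced_dataloader for smoke testing.
# The class name and default sizes are illustrative assumptions.
class ToyBitDataset(Dataset):
    """Hypothetical toy dataset of random bit sequences with next-bit targets."""

    def __init__(self, num_samples: int = 64, seq_len: int = 128):
        # Random 0/1 sequences; the extra bit lets inputs and targets overlap by one step
        self.bits = torch.randint(0, 2, (num_samples, seq_len + 1))

    def __len__(self):
        return len(self.bits)

    def __getitem__(self, idx):
        seq = self.bits[idx]
        return {'input_bits': seq[:-1], 'target_bits': seq[1:]}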


class MarkovSplineEnhancedTrainer(BitwiseTrainer):
    """Enhanced BitTransformerLM trainer with MarkovSpline integration."""

    def __init__(self, model: BitTransformerLM, markov_config: Optional[Dict] = None,
                 gradient_smoothing: bool = True, data_smoothing: bool = True,
                 smoothing_strength: float = 0.1, **kwargs):
        super().__init__(model, **kwargs)

        # Initialize MarkovSpline module
        self.markov_module = create_markov_spline_bitpipe_module(markov_config)
        self.gradient_smoothing = gradient_smoothing
        self.data_smoothing = data_smoothing
        self.smoothing_strength = smoothing_strength

        # Initialize gradient smoother if enabled
        if gradient_smoothing:
            self.markov_module.initialize_application('gradient_smoother',
                                                      learning_rate=kwargs.get('learning_rate', 0.001),
                                                      smoothing_strength=smoothing_strength,
                                                      momentum_states=10)

        # Tracking
        self.smoothing_metrics = {}
        self.gradient_smooth_history = []

        print("🌊 MarkovSpline Enhanced Trainer initialized")
        print(f"   - Gradient smoothing: {'✅' if gradient_smoothing else '❌'}")
        print(f"   - Data smoothing: {'✅' if data_smoothing else '❌'}")
        print(f"   - Smoothing strength: {smoothing_strength}")

    def create_enhanced_dataloader(self, dataset: Dataset, batch_size: int = 8,
                                   **kwargs) -> DataLoader:
        """Create dataloader with MarkovSpline preprocessing."""
        enhanced_dataset = MarkovSplineEnhancedDataset(
            dataset, self.markov_module, self.smoothing_strength, self.data_smoothing
        )
        return DataLoader(enhanced_dataset, batch_size=batch_size, **kwargs)

    def apply_gradient_smoothing(self, parameters: Dict[str, torch.Tensor],
                                 gradients: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Apply MarkovSpline gradient smoothing."""
        if not self.gradient_smoothing:
            return parameters

        try:
            # Process through MarkovSpline gradient smoother
            result = self.markov_module.process_data(
                {
                    'parameters': parameters,
                    'gradients': gradients
                },
                'smooth_gradients'
            )

            if result['success']:
                self.gradient_smooth_history.append(result['optimization_metrics'])
                return result['smoothed_parameters']
            else:
                print(f"Warning: Gradient smoothing failed: {result.get('error', 'Unknown')}")
                return parameters

        except Exception as e:
            print(f"Warning: Gradient smoothing error: {e}")
            return parameters

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Enhanced training step with MarkovSpline integration."""
        self.optimizer.zero_grad()

        # Forward pass
        outputs = self.model(batch['input_bits'])

        # BitTransformerLM returns (logits, telemetry)
        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits = outputs

        loss = self.criterion(logits.reshape(-1, logits.size(-1)),
                              batch['target_bits'].reshape(-1))

        # Backward pass
        loss.backward()

        # Extract parameters and gradients for smoothing
        if self.gradient_smoothing:
            parameters = {}
            gradients = {}

            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    parameters[name] = param.data.clone()
                    gradients[name] = param.grad.data.clone()

            # Apply MarkovSpline gradient smoothing
            smoothed_params = self.apply_gradient_smoothing(parameters, gradients)

            # Write smoothed values back before the optimizer applies the raw gradients
            for name, param in self.model.named_parameters():
                if name in smoothed_params:
                    param.data = smoothed_params[name]

        # Standard optimizer step
        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optimizer.step()

        # Collect metrics
        metrics = {
            'loss': loss.item(),
            'smoothing_applied': batch.get('smoothing_applied', torch.tensor(False)).float().mean().item()
        }

        if 'smoothing_applied' in batch:
            metrics['data_smoothing_rate'] = batch['smoothing_applied'].float().mean().item()

        return metrics

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train one epoch with MarkovSpline enhancements."""
        self.model.train()

        epoch_metrics = {
            'loss': 0.0,
            'smoothing_applied': 0.0,
            'data_smoothing_rate': 0.0,
            'gradient_smoothing_success': 0.0
        }

        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # Move batch to device
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to(self.device)

            # Training step with MarkovSpline integration
            step_metrics = self.train_step(batch)

            # Accumulate metrics
            for key, value in step_metrics.items():
                if key in epoch_metrics:
                    epoch_metrics[key] += value

            num_batches += 1

            # Log progress
            if batch_idx % 10 == 0:
                print(f"  Batch {batch_idx:3d}: Loss={step_metrics['loss']:.4f}")

        # Average metrics (guard against an empty loader)
        for key in epoch_metrics:
            epoch_metrics[key] /= max(num_batches, 1)

        return epoch_metrics

    def get_markov_spline_metrics(self) -> Dict[str, Any]:
        """Get comprehensive MarkovSpline performance metrics."""
        metrics = self.markov_module.get_performance_metrics()

        # Add training-specific metrics
        metrics['training_integration'] = {
            'gradient_smoothing_enabled': self.gradient_smoothing,
            'data_smoothing_enabled': self.data_smoothing,
            'smoothing_strength': self.smoothing_strength,
            'gradient_smooth_operations': len(self.gradient_smooth_history)
        }

        if self.gradient_smooth_history:
            recent_gradient_metrics = self.gradient_smooth_history[-10:]  # Last 10 operations
            metrics['recent_gradient_smoothing'] = {
                'average_metrics': {
                    key: np.mean([m.get(key, 0) for m in recent_gradient_metrics])
                    for key in recent_gradient_metrics[0].keys()
                } if recent_gradient_metrics else {}
            }

        return metrics

    def save_enhanced_checkpoint(self, checkpoint_path: str, epoch: int,
                                 metrics: Dict[str, float]):
        """Save checkpoint with MarkovSpline state."""
        # Standard checkpoint data
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
            'config': self.model.get_config()
        }

        # Add MarkovSpline metrics
        checkpoint['markov_spline_metrics'] = self.get_markov_spline_metrics()
        checkpoint['markov_spline_config'] = {
            'gradient_smoothing': self.gradient_smoothing,
            'data_smoothing': self.data_smoothing,
            'smoothing_strength': self.smoothing_strength
        }

        # Save MarkovSpline module state
        markov_state_path = Path(checkpoint_path).parent / 'markov_spline_state'
        self.markov_module.save_module_state(markov_state_path)

        torch.save(checkpoint, checkpoint_path)
        print(f"✅ Enhanced checkpoint saved: {checkpoint_path}")


def create_markov_enhanced_training_config(base_config: Dict) -> Dict:
    """Create training configuration with MarkovSpline enhancements."""
    enhanced_config = base_config.copy()

    # MarkovSpline specific settings
    enhanced_config.update({
        'markov_spline': {
            'enabled': True,
            'gradient_smoothing': True,
            'data_smoothing': True,
            'smoothing_strength': 0.1,
            'num_states': 10,
            'spline_type': 'cubic',
            'adaptive_smoothing': True
        },
        'data_preprocessing': {
            'smooth_training_data': True,
            'preserve_features': True,
            'preprocessing_strength': 0.15
        },
        'gradient_optimization': {
            'smooth_gradients': True,
            'momentum_states': 10,
            'learning_rate_smoothing': 0.2
        }
    })

    return enhanced_config


def run_markov_enhanced_training(config_file: Optional[str] = None):
    """Run BitTransformerLM training with MarkovSpline enhancements."""
    # Load configuration
    if config_file and os.path.exists(config_file):
        with open(config_file, 'r') as f:
            config = json.load(f)
    else:
        # Default enhanced configuration
        config = create_markov_enhanced_training_config({
            'model': {
                'd_model': 128,
                'nhead': 8,
                'num_layers': 4,
                'dim_feedforward': 512,
                'max_seq_len': 512
            },
            'training': {
                'batch_size': 8,
                'learning_rate': 1e-4,
                'epochs': 10,
                'max_grad_norm': 1.0
            }
        })

    print("🌊 Starting MarkovSpline-Enhanced BitTransformerLM Training")
    print(f"📋 Configuration: {json.dumps(config, indent=2)}")

    # Initialize model
    model_config = config['model']
    model = BitTransformerLM(**model_config)

    # Initialize enhanced trainer (pass only the optimizer-related training settings
    # through to the BitwiseTrainer base class)
    trainer = MarkovSplineEnhancedTrainer(
        model=model,
        markov_config=config.get('markov_spline'),
        gradient_smoothing=config['markov_spline']['gradient_smoothing'],
        data_smoothing=config['markov_spline']['data_smoothing'],
        smoothing_strength=config['markov_spline']['smoothing_strength'],
        learning_rate=config['training']['learning_rate'],
        max_grad_norm=config['training']['max_grad_norm']
    )

    print("🚀 Enhanced training pipeline initialized successfully!")

    return trainer, config


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='MarkovSpline-Enhanced BitTransformerLM Training')
    parser.add_argument('--config', '-c', help='Configuration file path')
    parser.add_argument('--output-dir', '-o', default='./markov_enhanced_checkpoints',
                        help='Output directory for checkpoints')

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Run enhanced training
    trainer, config = run_markov_enhanced_training(args.config)

    print(f"📊 MarkovSpline metrics: {trainer.get_markov_spline_metrics()}")
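
    # Example wiring (illustrative sketch, not executed here): with a dataset of
    # bit-sequence dicts -- e.g. the hypothetical ToyBitDataset defined above --
    # the enhanced dataloader and epoch loop could be driven as follows:
    #
    #     dataset = ToyBitDataset(num_samples=64, seq_len=config['model']['max_seq_len'])
    #     loader = trainer.create_enhanced_dataloader(dataset,
    #                                                 batch_size=config['training']['batch_size'],
    #                                                 shuffle=True)
    #     for epoch in range(config['training']['epochs']):
    #         epoch_metrics = trainer.train_epoch(loader, epoch)
    #         print(f"Epoch {epoch}: loss={epoch_metrics['loss']:.4f}")
    #         trainer.save_enhanced_checkpoint(
    #             os.path.join(args.output_dir, f'checkpoint_epoch_{epoch}.pt'),
    #             epoch, epoch_metrics)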