#!/usr/bin/env python3
"""
MarkovSpline-Enhanced BitTransformerLM Training
Integrates MarkovSpline data smoothing directly into the BitTransformerLM training pipeline
for improved data preprocessing and gradient optimization.
"""
import os
import sys
import json
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from torch.utils.data import DataLoader, Dataset
# Add MarkovSpline to path
sys.path.insert(0, '/data/MarkovSpline')
from bitpipe_integration import MarkovSplineBitPipeModule, create_markov_spline_bitpipe_module
# BitTransformerLM imports
from bit_transformer.model import BitTransformerLM
from bit_transformer.telemetry import TelemetrySynthesizer
# Simple trainer base class
class BitwiseTrainer:
"""Simple base trainer for BitTransformerLM."""
def __init__(self, model, learning_rate=1e-3, max_grad_norm=1.0):
self.model = model
self.device = next(model.parameters()).device
self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
self.criterion = nn.CrossEntropyLoss()
self.max_grad_norm = max_grad_norm
def train_step(self, batch):
"""Simple training step."""
self.optimizer.zero_grad()
outputs = self.model(batch['input_bits'])
# BitTransformerLM returns (logits, telemetry)
if isinstance(outputs, tuple):
logits, telemetry = outputs
else:
logits = outputs
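        # Flatten (batch, seq, vocab) logits and (batch, seq) targets for per-position cross-entropy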
loss = self.criterion(logits.reshape(-1, logits.size(-1)), batch['target_bits'].reshape(-1))
loss.backward()
if self.max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
return {'loss': loss.item()}
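# Illustrative usage (sketch, not executed here): drive the base trainer directly.
# `model` and `loader` below are placeholders; batches must be dicts with
# 'input_bits' and 'target_bits' tensors, as consumed by train_step above.
#
#     trainer = BitwiseTrainer(model, learning_rate=1e-3)
#     for batch in loader:
#         stats = trainer.train_step(batch)   # returns {'loss': ...}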
class MarkovSplineEnhancedDataset(Dataset):
"""Dataset wrapper that applies MarkovSpline preprocessing."""
def __init__(self,
base_dataset: Dataset,
markov_module: MarkovSplineBitPipeModule,
smoothing_strength: float = 0.1,
enable_smoothing: bool = True):
self.base_dataset = base_dataset
self.markov_module = markov_module
self.smoothing_strength = smoothing_strength
self.enable_smoothing = enable_smoothing
# Initialize data preprocessor
if enable_smoothing:
self.markov_module.initialize_application('data_preprocessor',
smoothing_strength=smoothing_strength,
preserve_features=True)
def __len__(self):
return len(self.base_dataset)
def __getitem__(self, idx):
# Get original data
data = self.base_dataset[idx]
if not self.enable_smoothing:
return data
# Apply MarkovSpline preprocessing to bit sequences
if isinstance(data, dict) and 'input_bits' in data:
try:
# Smooth input bits
result = self.markov_module.process_data(
[data['input_bits']],
'preprocess_training',
binary_data=True
)
if result['success'] and result['processed_sequences']:
data['input_bits'] = result['processed_sequences'][0]
data['smoothing_applied'] = True
else:
data['smoothing_applied'] = False
except Exception as e:
print(f"Warning: MarkovSpline preprocessing failed for sample {idx}: {e}")
data['smoothing_applied'] = False
return data
class MarkovSplineEnhancedTrainer(BitwiseTrainer):
"""Enhanced BitTransformerLM trainer with MarkovSpline integration."""
def __init__(self,
model: BitTransformerLM,
markov_config: Optional[Dict] = None,
gradient_smoothing: bool = True,
data_smoothing: bool = True,
smoothing_strength: float = 0.1,
**kwargs):
super().__init__(model, **kwargs)
# Initialize MarkovSpline module
self.markov_module = create_markov_spline_bitpipe_module(markov_config)
self.gradient_smoothing = gradient_smoothing
self.data_smoothing = data_smoothing
self.smoothing_strength = smoothing_strength
# Initialize gradient smoother if enabled
if gradient_smoothing:
self.markov_module.initialize_application('gradient_smoother',
learning_rate=kwargs.get('learning_rate', 0.001),
smoothing_strength=smoothing_strength,
momentum_states=10)
# Tracking
self.smoothing_metrics = {}
self.gradient_smooth_history = []
print(f"🌊 MarkovSpline Enhanced Trainer initialized")
print(f" - Gradient smoothing: {'βœ…' if gradient_smoothing else '❌'}")
print(f" - Data smoothing: {'βœ…' if data_smoothing else '❌'}")
print(f" - Smoothing strength: {smoothing_strength}")
def create_enhanced_dataloader(self,
dataset: Dataset,
batch_size: int = 8,
**kwargs) -> DataLoader:
"""Create dataloader with MarkovSpline preprocessing."""
enhanced_dataset = MarkovSplineEnhancedDataset(
dataset,
self.markov_module,
self.smoothing_strength,
self.data_smoothing
)
return DataLoader(enhanced_dataset, batch_size=batch_size, **kwargs)
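    # Illustrative usage (sketch): wrap any bit-level Dataset with MarkovSpline
    # preprocessing. `base_dataset` is a placeholder for a dataset yielding dicts
    # with an 'input_bits' entry, as expected by MarkovSplineEnhancedDataset.
    #
    #     loader = trainer.create_enhanced_dataloader(base_dataset, batch_size=8, shuffle=True)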
def apply_gradient_smoothing(self,
parameters: Dict[str, torch.Tensor],
gradients: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Apply MarkovSpline gradient smoothing."""
if not self.gradient_smoothing:
return parameters
try:
# Process through MarkovSpline gradient smoother
result = self.markov_module.process_data(
{
'parameters': parameters,
'gradients': gradients
},
'smooth_gradients'
)
if result['success']:
self.gradient_smooth_history.append(result['optimization_metrics'])
return result['smoothed_parameters']
else:
print(f"Warning: Gradient smoothing failed: {result.get('error', 'Unknown')}")
return parameters
except Exception as e:
print(f"Warning: Gradient smoothing error: {e}")
return parameters
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Enhanced training step with MarkovSpline integration."""
# Standard forward pass
self.optimizer.zero_grad()
# Forward pass
outputs = self.model(batch['input_bits'])
# BitTransformerLM returns (logits, telemetry)
if isinstance(outputs, tuple):
logits, telemetry = outputs
else:
logits = outputs
loss = self.criterion(logits.reshape(-1, logits.size(-1)), batch['target_bits'].reshape(-1))
# Backward pass
loss.backward()
# Extract parameters and gradients for smoothing
if self.gradient_smoothing:
parameters = {}
gradients = {}
for name, param in self.model.named_parameters():
if param.grad is not None:
parameters[name] = param.data.clone()
gradients[name] = param.grad.data.clone()
# Apply MarkovSpline gradient smoothing
smoothed_params = self.apply_gradient_smoothing(parameters, gradients)
# Update model parameters with smoothed values
for name, param in self.model.named_parameters():
if name in smoothed_params:
param.data = smoothed_params[name]
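        # Note: the optimizer step below still applies the original (unsmoothed)
        # gradients, so the Adam update lands on top of the parameters that were
        # just overwritten with their smoothed values.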
# Standard optimizer step
if self.max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
# Collect metrics
metrics = {
'loss': loss.item(),
'smoothing_applied': batch.get('smoothing_applied', torch.tensor(False)).float().mean().item()
}
        if 'smoothing_applied' in batch:
metrics['data_smoothing_rate'] = batch['smoothing_applied'].float().mean().item()
return metrics
def train_epoch(self,
train_loader: DataLoader,
epoch: int) -> Dict[str, float]:
"""Train one epoch with MarkovSpline enhancements."""
self.model.train()
epoch_metrics = {
'loss': 0.0,
'smoothing_applied': 0.0,
'data_smoothing_rate': 0.0,
'gradient_smoothing_success': 0.0
}
num_batches = 0
for batch_idx, batch in enumerate(train_loader):
# Move batch to device
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.device)
# Training step with MarkovSpline integration
step_metrics = self.train_step(batch)
# Accumulate metrics
for key, value in step_metrics.items():
if key in epoch_metrics:
epoch_metrics[key] += value
num_batches += 1
# Log progress
if batch_idx % 10 == 0:
print(f" Batch {batch_idx:3d}: Loss={step_metrics['loss']:.4f}")
        # Average metrics (guard against an empty loader)
        for key in epoch_metrics:
            epoch_metrics[key] /= max(num_batches, 1)
return epoch_metrics
def get_markov_spline_metrics(self) -> Dict[str, Any]:
"""Get comprehensive MarkovSpline performance metrics."""
metrics = self.markov_module.get_performance_metrics()
# Add training-specific metrics
metrics['training_integration'] = {
'gradient_smoothing_enabled': self.gradient_smoothing,
'data_smoothing_enabled': self.data_smoothing,
'smoothing_strength': self.smoothing_strength,
'gradient_smooth_operations': len(self.gradient_smooth_history)
}
if self.gradient_smooth_history:
recent_gradient_metrics = self.gradient_smooth_history[-10:] # Last 10 operations
metrics['recent_gradient_smoothing'] = {
'average_metrics': {
key: np.mean([m.get(key, 0) for m in recent_gradient_metrics])
for key in recent_gradient_metrics[0].keys()
} if recent_gradient_metrics else {}
}
return metrics
def save_enhanced_checkpoint(self,
checkpoint_path: str,
epoch: int,
metrics: Dict[str, float]):
"""Save checkpoint with MarkovSpline state."""
# Standard checkpoint data
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'metrics': metrics,
'config': self.model.get_config()
}
# Add MarkovSpline metrics
checkpoint['markov_spline_metrics'] = self.get_markov_spline_metrics()
checkpoint['markov_spline_config'] = {
'gradient_smoothing': self.gradient_smoothing,
'data_smoothing': self.data_smoothing,
'smoothing_strength': self.smoothing_strength
}
# Save MarkovSpline module state
markov_state_path = Path(checkpoint_path).parent / 'markov_spline_state'
self.markov_module.save_module_state(markov_state_path)
torch.save(checkpoint, checkpoint_path)
print(f"βœ… Enhanced checkpoint saved: {checkpoint_path}")
def create_markov_enhanced_training_config(base_config: Dict) -> Dict:
"""Create training configuration with MarkovSpline enhancements."""
enhanced_config = base_config.copy()
# MarkovSpline specific settings
enhanced_config.update({
'markov_spline': {
'enabled': True,
'gradient_smoothing': True,
'data_smoothing': True,
'smoothing_strength': 0.1,
'num_states': 10,
'spline_type': 'cubic',
'adaptive_smoothing': True
},
'data_preprocessing': {
'smooth_training_data': True,
'preserve_features': True,
'preprocessing_strength': 0.15
},
'gradient_optimization': {
'smooth_gradients': True,
'momentum_states': 10,
'learning_rate_smoothing': 0.2
}
})
return enhanced_config
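# Illustrative call (sketch): start from any base config dict and layer the
# MarkovSpline defaults on top of it.
#
#     enhanced = create_markov_enhanced_training_config({'training': {'epochs': 5}})
#     # enhanced now contains 'markov_spline', 'data_preprocessing', and
#     # 'gradient_optimization' sections with the defaults shown above.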
def run_markov_enhanced_training(config_file: Optional[str] = None):
"""Run BitTransformerLM training with MarkovSpline enhancements."""
# Load configuration
    if config_file and os.path.exists(config_file):
        with open(config_file, 'r') as f:
            config = json.load(f)
        if 'markov_spline' not in config:
            # User-supplied configs may omit the MarkovSpline section; fill in defaults
            config = create_markov_enhanced_training_config(config)
else:
# Default enhanced configuration
config = create_markov_enhanced_training_config({
'model': {
'd_model': 128,
'nhead': 8,
'num_layers': 4,
'dim_feedforward': 512,
'max_seq_len': 512
},
'training': {
'batch_size': 8,
'learning_rate': 1e-4,
'epochs': 10,
'max_grad_norm': 1.0
}
})
print("🌊 Starting MarkovSpline-Enhanced BitTransformerLM Training")
print(f"πŸ“‹ Configuration: {json.dumps(config, indent=2)}")
# Initialize model
model_config = config['model']
model = BitTransformerLM(**model_config)
    # Initialize enhanced trainer. Only forward the optimizer-related settings;
    # batch_size and epochs belong to the loop that drives the trainer, and
    # BitwiseTrainer.__init__ would reject them as unexpected keyword arguments.
    trainer = MarkovSplineEnhancedTrainer(
        model=model,
        markov_config=config.get('markov_spline'),
        gradient_smoothing=config['markov_spline']['gradient_smoothing'],
        data_smoothing=config['markov_spline']['data_smoothing'],
        smoothing_strength=config['markov_spline']['smoothing_strength'],
        learning_rate=config['training']['learning_rate'],
        max_grad_norm=config['training'].get('max_grad_norm', 1.0)
    )
print("πŸš€ Enhanced training pipeline initialized successfully!")
return trainer, config
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='MarkovSpline-Enhanced BitTransformerLM Training')
parser.add_argument('--config', '-c', help='Configuration file path')
parser.add_argument('--output-dir', '-o', default='./markov_enhanced_checkpoints',
help='Output directory for checkpoints')
args = parser.parse_args()
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Run enhanced training
trainer, config = run_markov_enhanced_training(args.config)
print(f"πŸ“Š MarkovSpline metrics: {trainer.get_markov_spline_metrics()}")