"""
MarkovSpline-Enhanced BitTransformerLM Training

Integrates MarkovSpline data smoothing directly into the BitTransformerLM training
pipeline for improved data preprocessing and gradient optimization.
"""
|
|
import os |
|
import sys |
|
import json |
|
import time |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import numpy as np |
|
from pathlib import Path |
|
from typing import Dict, List, Tuple, Optional, Any |
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
|
|
sys.path.insert(0, '/data/MarkovSpline') |
|
from bitpipe_integration import MarkovSplineBitPipeModule, create_markov_spline_bitpipe_module |
|
|
|
|
|
from bit_transformer.model import BitTransformerLM |
|
from bit_transformer.telemetry import TelemetrySynthesizer |
|
|
|
|
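
# The trainers below assume batches are dicts of bit tensors. A minimal sketch of
# the expected layout (shapes and dtypes here are assumptions, not enforced anywhere):
#
#     batch = {
#         'input_bits':  torch.randint(0, 2, (8, 512)),   # (batch, seq_len) bit ids
#         'target_bits': torch.randint(0, 2, (8, 512)),   # next-bit targets
#     }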


class BitwiseTrainer:
    """Simple base trainer for BitTransformerLM."""

    def __init__(self, model, learning_rate=1e-3, max_grad_norm=1.0):
        self.model = model
        self.device = next(model.parameters()).device
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()
        self.max_grad_norm = max_grad_norm

    def train_step(self, batch):
        """Simple training step."""
        self.optimizer.zero_grad()

        outputs = self.model(batch['input_bits'])

        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits = outputs

        loss = self.criterion(logits.reshape(-1, logits.size(-1)), batch['target_bits'].reshape(-1))

        loss.backward()

        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optimizer.step()

        return {'loss': loss.item()}
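
# Minimal usage sketch for the base trainer (illustrative only; `model`, `inputs`,
# and `targets` are hypothetical stand-ins for real encoded bit sequences):
#
#     trainer = BitwiseTrainer(model, learning_rate=1e-3)
#     batch = {'input_bits': inputs, 'target_bits': targets}
#     print(trainer.train_step(batch)['loss'])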


class MarkovSplineEnhancedDataset(Dataset):
    """Dataset wrapper that applies MarkovSpline preprocessing."""

    def __init__(self,
                 base_dataset: Dataset,
                 markov_module: MarkovSplineBitPipeModule,
                 smoothing_strength: float = 0.1,
                 enable_smoothing: bool = True):
        self.base_dataset = base_dataset
        self.markov_module = markov_module
        self.smoothing_strength = smoothing_strength
        self.enable_smoothing = enable_smoothing

        if enable_smoothing:
            self.markov_module.initialize_application('data_preprocessor',
                                                      smoothing_strength=smoothing_strength,
                                                      preserve_features=True)

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        data = self.base_dataset[idx]

        if not self.enable_smoothing:
            return data

        if isinstance(data, dict) and 'input_bits' in data:
            try:
                result = self.markov_module.process_data(
                    [data['input_bits']],
                    'preprocess_training',
                    binary_data=True
                )

                if result['success'] and result['processed_sequences']:
                    data['input_bits'] = result['processed_sequences'][0]
                    data['smoothing_applied'] = True
                else:
                    data['smoothing_applied'] = False

            except Exception as e:
                print(f"Warning: MarkovSpline preprocessing failed for sample {idx}: {e}")
                data['smoothing_applied'] = False

        return data
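
# Sketch of wrapping an existing dataset (names here are hypothetical; any Dataset
# yielding dicts with 'input_bits'/'target_bits' keys works):
#
#     markov_module = create_markov_spline_bitpipe_module()
#     smoothed_ds = MarkovSplineEnhancedDataset(raw_dataset, markov_module,
#                                               smoothing_strength=0.1)
#     loader = DataLoader(smoothed_ds, batch_size=8, shuffle=True)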


class MarkovSplineEnhancedTrainer(BitwiseTrainer):
    """Enhanced BitTransformerLM trainer with MarkovSpline integration."""

    def __init__(self,
                 model: BitTransformerLM,
                 markov_config: Optional[Dict] = None,
                 gradient_smoothing: bool = True,
                 data_smoothing: bool = True,
                 smoothing_strength: float = 0.1,
                 **kwargs):
        super().__init__(model, **kwargs)

        self.markov_module = create_markov_spline_bitpipe_module(markov_config)
        self.gradient_smoothing = gradient_smoothing
        self.data_smoothing = data_smoothing
        self.smoothing_strength = smoothing_strength

        # The gradient smoother is configured once here and reused on every step.
        if gradient_smoothing:
            self.markov_module.initialize_application('gradient_smoother',
                                                      learning_rate=kwargs.get('learning_rate', 0.001),
                                                      smoothing_strength=smoothing_strength,
                                                      momentum_states=10)

        self.smoothing_metrics = {}
        self.gradient_smooth_history = []

        print("MarkovSpline Enhanced Trainer initialized")
        print(f"  - Gradient smoothing: {'enabled' if gradient_smoothing else 'disabled'}")
        print(f"  - Data smoothing: {'enabled' if data_smoothing else 'disabled'}")
        print(f"  - Smoothing strength: {smoothing_strength}")

    def create_enhanced_dataloader(self,
                                   dataset: Dataset,
                                   batch_size: int = 8,
                                   **kwargs) -> DataLoader:
        """Create dataloader with MarkovSpline preprocessing."""
        enhanced_dataset = MarkovSplineEnhancedDataset(
            dataset,
            self.markov_module,
            self.smoothing_strength,
            self.data_smoothing
        )

        return DataLoader(enhanced_dataset, batch_size=batch_size, **kwargs)

    def apply_gradient_smoothing(self,
                                 parameters: Dict[str, torch.Tensor],
                                 gradients: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Apply MarkovSpline gradient smoothing."""
        if not self.gradient_smoothing:
            return parameters

        try:
            result = self.markov_module.process_data(
                {
                    'parameters': parameters,
                    'gradients': gradients
                },
                'smooth_gradients'
            )

            if result['success']:
                self.gradient_smooth_history.append(result['optimization_metrics'])
                return result['smoothed_parameters']
            else:
                print(f"Warning: Gradient smoothing failed: {result.get('error', 'Unknown')}")
                return parameters

        except Exception as e:
            print(f"Warning: Gradient smoothing error: {e}")
            return parameters

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Enhanced training step with MarkovSpline integration."""
        self.optimizer.zero_grad()

        outputs = self.model(batch['input_bits'])

        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits = outputs

        loss = self.criterion(logits.reshape(-1, logits.size(-1)), batch['target_bits'].reshape(-1))

        loss.backward()

        # Optionally smooth the parameters against their raw gradients before the
        # optimizer step; Adam then applies the (clipped) gradients as usual.
        if self.gradient_smoothing:
            parameters = {}
            gradients = {}

            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    parameters[name] = param.data.clone()
                    gradients[name] = param.grad.data.clone()

            smoothed_params = self.apply_gradient_smoothing(parameters, gradients)

            for name, param in self.model.named_parameters():
                if name in smoothed_params:
                    param.data = smoothed_params[name]

        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optimizer.step()

        metrics = {
            'loss': loss.item(),
            'smoothing_applied': batch.get('smoothing_applied', torch.tensor(False)).float().mean().item()
        }

        # `batch` is a dict, so membership must be tested with `in`, not hasattr.
        if 'smoothing_applied' in batch:
            metrics['data_smoothing_rate'] = batch['smoothing_applied'].float().mean().item()

        return metrics

    def train_epoch(self,
                    train_loader: DataLoader,
                    epoch: int) -> Dict[str, float]:
        """Train one epoch with MarkovSpline enhancements."""
        self.model.train()
        epoch_metrics = {
            'loss': 0.0,
            'smoothing_applied': 0.0,
            'data_smoothing_rate': 0.0,
            'gradient_smoothing_success': 0.0
        }

        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to(self.device)

            step_metrics = self.train_step(batch)

            for key, value in step_metrics.items():
                if key in epoch_metrics:
                    epoch_metrics[key] += value

            num_batches += 1

            if batch_idx % 10 == 0:
                print(f"  Batch {batch_idx:3d}: Loss={step_metrics['loss']:.4f}")

        for key in epoch_metrics:
            epoch_metrics[key] /= num_batches

        return epoch_metrics
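
    # Sketch of a typical loop around train_epoch (the checkpoint path is
    # illustrative; see save_enhanced_checkpoint below):
    #
    #     for epoch in range(num_epochs):
    #         metrics = trainer.train_epoch(loader, epoch)
    #         trainer.save_enhanced_checkpoint(f'ckpt_epoch_{epoch}.pt', epoch, metrics)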

    def get_markov_spline_metrics(self) -> Dict[str, Any]:
        """Get comprehensive MarkovSpline performance metrics."""
        metrics = self.markov_module.get_performance_metrics()

        metrics['training_integration'] = {
            'gradient_smoothing_enabled': self.gradient_smoothing,
            'data_smoothing_enabled': self.data_smoothing,
            'smoothing_strength': self.smoothing_strength,
            'gradient_smooth_operations': len(self.gradient_smooth_history)
        }

        if self.gradient_smooth_history:
            recent_gradient_metrics = self.gradient_smooth_history[-10:]
            metrics['recent_gradient_smoothing'] = {
                'average_metrics': {
                    key: np.mean([m.get(key, 0) for m in recent_gradient_metrics])
                    for key in recent_gradient_metrics[0].keys()
                } if recent_gradient_metrics else {}
            }

        return metrics

    def save_enhanced_checkpoint(self,
                                 checkpoint_path: str,
                                 epoch: int,
                                 metrics: Dict[str, float]):
        """Save checkpoint with MarkovSpline state."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
            'config': self.model.get_config()
        }

        checkpoint['markov_spline_metrics'] = self.get_markov_spline_metrics()
        checkpoint['markov_spline_config'] = {
            'gradient_smoothing': self.gradient_smoothing,
            'data_smoothing': self.data_smoothing,
            'smoothing_strength': self.smoothing_strength
        }

        # MarkovSpline keeps its own internal state; persist it alongside the checkpoint.
        markov_state_path = Path(checkpoint_path).parent / 'markov_spline_state'
        self.markov_module.save_module_state(markov_state_path)

        torch.save(checkpoint, checkpoint_path)
        print(f"Enhanced checkpoint saved: {checkpoint_path}")


def create_markov_enhanced_training_config(base_config: Dict) -> Dict:
    """Create training configuration with MarkovSpline enhancements."""
    enhanced_config = base_config.copy()

    enhanced_config.update({
        'markov_spline': {
            'enabled': True,
            'gradient_smoothing': True,
            'data_smoothing': True,
            'smoothing_strength': 0.1,
            'num_states': 10,
            'spline_type': 'cubic',
            'adaptive_smoothing': True
        },
        'data_preprocessing': {
            'smooth_training_data': True,
            'preserve_features': True,
            'preprocessing_strength': 0.15
        },
        'gradient_optimization': {
            'smooth_gradients': True,
            'momentum_states': 10,
            'learning_rate_smoothing': 0.2
        }
    })

    return enhanced_config
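
# Example (sketch): enriching a plain model/training config with the MarkovSpline
# defaults above. The base dict shown is hypothetical.
#
#     base = {'model': {'d_model': 64}, 'training': {'batch_size': 4}}
#     cfg = create_markov_enhanced_training_config(base)
#     assert cfg['markov_spline']['spline_type'] == 'cubic'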


def run_markov_enhanced_training(config_file: Optional[str] = None):
    """Run BitTransformerLM training with MarkovSpline enhancements."""
    if config_file and os.path.exists(config_file):
        with open(config_file, 'r') as f:
            config = json.load(f)
    else:
        config = create_markov_enhanced_training_config({
            'model': {
                'd_model': 128,
                'nhead': 8,
                'num_layers': 4,
                'dim_feedforward': 512,
                'max_seq_len': 512
            },
            'training': {
                'batch_size': 8,
                'learning_rate': 1e-4,
                'epochs': 10,
                'max_grad_norm': 1.0
            }
        })

    print("Starting MarkovSpline-Enhanced BitTransformerLM Training")
    print(f"Configuration: {json.dumps(config, indent=2)}")

    model_config = config['model']
    model = BitTransformerLM(**model_config)

    # User-supplied config files may omit the MarkovSpline section, so fall back
    # to sensible defaults instead of assuming the keys exist.
    markov_cfg = config.get('markov_spline', {})
    training_cfg = config.get('training', {})

    trainer = MarkovSplineEnhancedTrainer(
        model=model,
        markov_config=config.get('markov_spline'),
        gradient_smoothing=markov_cfg.get('gradient_smoothing', True),
        data_smoothing=markov_cfg.get('data_smoothing', True),
        smoothing_strength=markov_cfg.get('smoothing_strength', 0.1),
        # BitwiseTrainer only accepts optimizer-level settings; batch_size and
        # epochs stay in the config for the caller's training loop.
        learning_rate=training_cfg.get('learning_rate', 1e-3),
        max_grad_norm=training_cfg.get('max_grad_norm', 1.0)
    )

    print("Enhanced training pipeline initialized successfully!")
    return trainer, config
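
# run_markov_enhanced_training only builds the trainer; a full run is expected to
# drive it afterwards. A sketch, assuming a Dataset named `train_dataset` exists:
#
#     trainer, config = run_markov_enhanced_training()
#     loader = trainer.create_enhanced_dataloader(train_dataset,
#                                                 batch_size=config['training']['batch_size'])
#     for epoch in range(config['training']['epochs']):
#         metrics = trainer.train_epoch(loader, epoch)
#         trainer.save_enhanced_checkpoint(f'ckpt_{epoch}.pt', epoch, metrics)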


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='MarkovSpline-Enhanced BitTransformerLM Training')
    parser.add_argument('--config', '-c', help='Configuration file path')
    parser.add_argument('--output-dir', '-o', default='./markov_enhanced_checkpoints',
                        help='Output directory for checkpoints')

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    trainer, config = run_markov_enhanced_training(args.config)

    print(f"MarkovSpline metrics: {trainer.get_markov_spline_metrics()}")