#!/usr/bin/env python3
"""MarkovSpline CLI interface for BitTransformerLM integration.

Provides command-line tools for using MarkovSpline data smoothing with
BitTransformerLM training and inference pipelines.
"""

import argparse
import sys
import os
import json
import numpy as np
import torch
from pathlib import Path
from typing import List, Dict, Any, Optional

# Add MarkovSpline to path
sys.path.insert(0, '/data/MarkovSpline')

from bitpipe_integration import MarkovSplineBitPipeModule, create_markov_spline_bitpipe_module
from core import SplineType


class TextToBitsConverter:
    """Simple text-to-bits converter for the CLI."""

    def text_to_bits(self, text, max_length=128):
        """Convert text to a fixed-length bit sequence (8 bits per byte)."""
        bit_sequence = []
        # Encode to UTF-8 bytes so every element fits in exactly 8 bits;
        # ord() on non-ASCII characters can yield values wider than one byte.
        for byte in text.encode('utf-8')[:max_length // 8]:
            byte_bits = format(byte, '08b')
            bit_sequence.extend([int(b) for b in byte_bits])

        # Pad or truncate to max_length
        if len(bit_sequence) < max_length:
            bit_sequence.extend([0] * (max_length - len(bit_sequence)))
        else:
            bit_sequence = bit_sequence[:max_length]

        return bit_sequence


class MarkovSplineBitTransformerCLI:
    """CLI interface for MarkovSpline + BitTransformerLM integration."""

    def __init__(self):
        self.markov_module = None
        self.text_converter = TextToBitsConverter()

    def initialize_markov_spline(self, config: Optional[Dict] = None) -> bool:
        """Initialize MarkovSpline module with configuration."""
        try:
            self.markov_module = create_markov_spline_bitpipe_module(config)
            print(f"✅ Initialized MarkovSpline module: {self.markov_module.module_name}")
            return True
        except Exception as e:
            print(f"❌ Failed to initialize MarkovSpline: {e}")
            return False

    def preprocess_text_data(self, input_file: str, output_file: str,
                             smoothing_strength: float = 0.15,
                             chunk_size: int = 128) -> bool:
        """Preprocess text data using MarkovSpline for BitTransformerLM training."""
        if not self.markov_module:
            print("❌ MarkovSpline module not initialized")
            return False

        try:
            # Read input text, one sample per line
            with open(input_file, 'r', encoding='utf-8') as f:
                text_data = f.read().strip().split('\n')

            print(f"📖 Processing {len(text_data)} text samples...")

            # Convert text to bit sequences
            bit_sequences = []
            for text in text_data:
                if text.strip():
                    bits = self.text_converter.text_to_bits(text, max_length=chunk_size)
                    bit_sequences.append(bits)

            print(f"🔄 Converting to bit sequences: {len(bit_sequences)} sequences")

            # Initialize MarkovSpline preprocessor
            self.markov_module.initialize_application('data_preprocessor',
                                                      smoothing_strength=smoothing_strength,
                                                      preserve_features=True)

            # Process bit sequences through MarkovSpline
            result = self.markov_module.process_data(
                bit_sequences,
                'preprocess_training',
                binary_data=True
            )

            if not result['success']:
                print(f"❌ Processing failed: {result.get('error', 'Unknown error')}")
                return False

            # Save processed sequences alongside the preprocessing settings
            processed_data = {
                'processed_sequences': result['processed_sequences'],
                'preprocessing_summary': result['preprocessing_summary'],
                'original_count': len(bit_sequences),
                'smoothing_strength': smoothing_strength,
                'chunk_size': chunk_size
            }

            with open(output_file, 'w') as f:
                json.dump(processed_data, f, indent=2, default=str)

            print(f"✅ Preprocessed data saved to: {output_file}")
            print(f"📊 Summary: {result['preprocessing_summary']}")
            return True

        except Exception as e:
            print(f"❌ Preprocessing failed: {e}")
            return False

    def smooth_bit_sequence(self, bit_sequence: List[int],
                            smoothing_type: str = 'predict_binary',
                            num_predictions: int = 10) -> Dict[str, Any]:
        """Smooth or predict a bit sequence using MarkovSpline."""
        if not self.markov_module:
            print("❌ MarkovSpline module not initialized")
            return {'success': False, 'error': 'Module not initialized'}

        try:
            result = self.markov_module.process_data(
                bit_sequence,
                smoothing_type,
                num_predictions=num_predictions
            )
            return result
        except Exception as e:
            print(f"❌ Bit sequence processing failed: {e}")
            return {'success': False, 'error': str(e)}

    def smooth_training_gradients(self, gradient_file: str, output_file: str,
                                  learning_rate: float = 0.01,
                                  smoothing_strength: float = 0.2) -> bool:
        """Apply MarkovSpline gradient smoothing to BitTransformerLM training."""
        if not self.markov_module:
            print("❌ MarkovSpline module not initialized")
            return False

        try:
            # Load gradient data (assuming PyTorch checkpoint format)
            checkpoint = torch.load(gradient_file, map_location='cpu')

            if 'gradients' not in checkpoint or 'parameters' not in checkpoint:
                print("❌ Invalid gradient file format")
                return False

            # Initialize gradient smoother
            self.markov_module.initialize_application('gradient_smoother',
                                                      learning_rate=learning_rate,
                                                      smoothing_strength=smoothing_strength)

            # Process gradients
            result = self.markov_module.process_data(
                {
                    'parameters': checkpoint['parameters'],
                    'gradients': checkpoint['gradients']
                },
                'smooth_gradients'
            )

            if not result['success']:
                print(f"❌ Gradient smoothing failed: {result.get('error', 'Unknown error')}")
                return False

            # Save smoothed parameters
            smoothed_checkpoint = {
                'smoothed_parameters': result['smoothed_parameters'],
                'optimization_metrics': result['optimization_metrics'],
                'original_gradients': checkpoint['gradients']
            }

            torch.save(smoothed_checkpoint, output_file)

            print(f"✅ Smoothed gradients saved to: {output_file}")
            print(f"📊 Optimization metrics: {result['optimization_metrics']}")
            return True

        except Exception as e:
            print(f"❌ Gradient smoothing failed: {e}")
            return False

    def create_smoothed_dataset(self, input_dataset: str, output_dataset: str,
                                config: Optional[Dict] = None) -> bool:
        """Create smoothed dataset for BitTransformerLM training."""
        # Default configuration for dataset smoothing
        default_config = {
            'smoothing_strength': 0.1,
            'num_states': 20,
            'spline_type': 'cubic',
            'preserve_features': True
        }

        if config:
            default_config.update(config)

        if not self.markov_module:
            self.initialize_markov_spline(default_config)

        return self.preprocess_text_data(input_dataset, output_dataset,
                                         default_config['smoothing_strength'])


def main():
    parser = argparse.ArgumentParser(description='MarkovSpline CLI for BitTransformerLM')
    parser.add_argument('command',
                        choices=['preprocess', 'smooth-gradients', 'create-dataset', 'predict-bits'],
                        help='Command to execute')

    # Common arguments
    parser.add_argument('--input', '-i', required=True, help='Input file path')
    parser.add_argument('--output', '-o', required=True, help='Output file path')
    parser.add_argument('--config', '-c', help='Configuration JSON file')

    # Preprocessing arguments
    parser.add_argument('--smoothing-strength', type=float, default=0.15,
                        help='Smoothing strength (0.0-1.0)')
    parser.add_argument('--chunk-size', type=int, default=128,
                        help='Text chunk size for bit conversion')

    # Gradient smoothing arguments
    parser.add_argument('--learning-rate', type=float, default=0.01,
                        help='Learning rate for gradient smoothing')

    # Bit prediction arguments
    parser.add_argument('--num-predictions', type=int, default=10,
                        help='Number of bit predictions to generate')

    args = parser.parse_args()

    # Load configuration if provided
    config = None
    if args.config:
        try:
            with open(args.config, 'r') as f:
                config = json.load(f)
        except Exception as e:
            print(f"❌ Failed to load config: {e}")
            return 1

    # Initialize CLI
    cli = MarkovSplineBitTransformerCLI()
    if not cli.initialize_markov_spline(config):
        return 1

    # Execute command
    success = False

    if args.command == 'preprocess':
        success = cli.preprocess_text_data(
            args.input, args.output,
            args.smoothing_strength, args.chunk_size
        )

    elif args.command == 'smooth-gradients':
        success = cli.smooth_training_gradients(
            args.input, args.output,
            args.learning_rate, args.smoothing_strength
        )

    elif args.command == 'create-dataset':
        success = cli.create_smoothed_dataset(
            args.input, args.output, config
        )

    elif args.command == 'predict-bits':
        # Read bit sequence from input file
        try:
            with open(args.input, 'r') as f:
                bit_data = json.load(f)

            bit_sequence = bit_data.get('bits', [])
            result = cli.smooth_bit_sequence(bit_sequence, 'predict_binary',
                                             args.num_predictions)

            if result['success']:
                with open(args.output, 'w') as f:
                    json.dump(result, f, indent=2, default=str)
                print(f"✅ Bit predictions saved to: {args.output}")
                success = True
            else:
                print(f"❌ Bit prediction failed: {result.get('error', 'Unknown error')}")

        except Exception as e:
            print(f"❌ Bit prediction failed: {e}")

    return 0 if success else 1


if __name__ == '__main__':
    sys.exit(main())
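
# Example invocations (illustrative only; the script filename below is a placeholder,
# but the subcommands and flags correspond to the argparse definitions above):
#
#   python markov_spline_cli.py preprocess -i corpus.txt -o corpus_smoothed.json \
#       --smoothing-strength 0.15 --chunk-size 128
#
#   python markov_spline_cli.py smooth-gradients -i grads.pt -o grads_smoothed.pt \
#       --learning-rate 0.01 --smoothing-strength 0.2
#
#   python markov_spline_cli.py predict-bits -i bits.json -o predictions.json \
#       --num-predictions 10
#
# For predict-bits, the input JSON is expected to provide a "bits" key holding a
# list of 0/1 integers, e.g. {"bits": [1, 0, 1, 1, 0, 0, 1, 0]}.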