#!/usr/bin/env python3
"""CPU-Optimized Edge Deployment BitTransformerLM Training.

Optimized for consumer devices and edge applications.
"""

import os
import time

import torch
import torch.nn.functional as F
from datasets import load_dataset

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    bits_to_text,
    train_loop,
    configure_optimizer,
    save_model,
    load_model,
    set_dropout,
    hil_safe_inference,
    quantize_dynamic,
)
from bit_transformer.torch_utils import cpu_autocast


def create_optimal_cpu_model():
    """Create a BitTransformerLM optimized for CPU edge deployment."""
    print("🧠 Creating CPU-optimized BitTransformerLM...")

    # Optimal configuration for edge devices:
    # - Small model size for a low memory footprint
    # - CPU autocast for faster BF16 mixed-precision inference
    # - No reversible layers (simpler on CPU)
    # - Gradient checkpointing disabled for speed
    # - Short context length for efficiency
    model = BitTransformerLM(
        d_model=64,               # Small embedding dimension (vs 128 default)
        nhead=4,                  # Fewer attention heads (vs 8 default)
        num_layers=3,             # Shallow model (vs 4 default)
        dim_feedforward=128,      # Smaller FFN (vs 512 default)
        max_seq_len=256,          # Shorter context (vs 1024 default)
        reversible=False,         # Disable reversible layers (CPU doesn't benefit much)
        use_checkpoint=False,     # Disable gradient checkpointing (prioritize speed)
        use_autocast=True,        # Enable CPU autocast for BF16 mixed precision
        use_act=False,            # Disable ACT for simplicity
        chunk_size=32,            # Small chunks for memory efficiency
        full_attn_logging=False,  # Disable attention logging to save memory
        lambda_K=1.0,             # Standard telemetry weights
        lambda_C=1.0,
        lambda_S=1.0,
    )

    # Report model size
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("  📊 Model Configuration:")
    print("    d_model: 64")
    print("    num_layers: 3")
    print("    nhead: 4")
    print("    dim_feedforward: 128")
    print("    max_seq_len: 256")
    print(f"    Total parameters: {total_params:,}")
    print(f"    Trainable parameters: {trainable_params:,}")
    print(f"    Estimated size: {total_params * 4 / 1024 / 1024:.1f}MB (FP32)")
    print(f"    With autocast: ~{total_params * 2 / 1024 / 1024:.1f}MB (BF16)")

    return model
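# The pipeline converts raw text to bit sequences before training and back to
# text after generation.  The helper below is a minimal, optional sanity check:
# a sketch that assumes text_to_bits returns a flat list of 0/1 ints and that
# bits_to_text is its inverse (both imported above).  The helper name is
# illustrative and nothing in the main pipeline calls it.
def _check_bit_roundtrip(sample: str = "edge deployment test") -> bool:
    """Return True if text -> bits -> text reproduces the input (sketch)."""
    bits = text_to_bits(sample)
    recovered = bits_to_text(bits)
    return recovered == sample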
def load_training_dataset(dataset_size=512, max_len=128):
    """Load and prepare a training dataset optimized for edge training."""
    print("📚 Loading training dataset...")

    try:
        # Try to load the BitTransformerLM dataset from HuggingFace
        print("  Attempting to load BitTransformerLM dataset...")
        dataset = load_dataset(
            "WCNegentropy/BitTransformerLM",
            split=f"train[:{dataset_size}]",
        )

        if dataset and len(dataset) > 0:
            train_texts = [item["text"] for item in dataset if item.get("text")]
            if len(train_texts) > 0:
                print(f"  ✅ Loaded {len(train_texts)} samples from BitTransformerLM dataset")
            else:
                raise Exception("No text samples found in dataset")
        else:
            raise Exception("Dataset empty or not accessible")

    except Exception as e:
        print(f"  ⚠️ BitTransformerLM dataset not available: {e}")
        print("  📖 Falling back to WikiText-2...")

        try:
            # Fallback to WikiText-2 for training
            ds = load_dataset("wikitext", "wikitext-2-raw-v1")
            train_texts = [text for text in ds["train"]["text"] if text.strip()][:dataset_size]
            print(f"  ✅ Loaded {len(train_texts)} samples from WikiText-2")
        except Exception as e2:
            print(f"  ❌ Failed to load WikiText-2: {e2}")
            print("  🎲 Using synthetic text data...")

            # Generate simple synthetic text for demonstration
            synthetic_texts = [
                "The quick brown fox jumps over the lazy dog.",
                "Machine learning is transforming technology.",
                "Edge computing enables local AI processing.",
                "BitTransformerLM uses bit-native processing.",
                "CPU optimization improves inference speed.",
                "Neural networks learn from training data.",
                "Transformers use attention mechanisms.",
                "Language models understand text patterns.",
            ]
            train_texts = (synthetic_texts * (dataset_size // len(synthetic_texts) + 1))[:dataset_size]
            print(f"  ✅ Generated {len(train_texts)} synthetic samples")

    # Convert text to bits
    print("  🔄 Converting text to bits...")
    train_sequences = []
    valid_sequences = []

    for i, text in enumerate(train_texts):
        try:
            bits = text_to_bits(text)[:max_len]
            if len(bits) < max_len:
                bits.extend([0] * (max_len - len(bits)))  # Pad to max_len

            # Use an 80/20 split for train/validation
            if i < len(train_texts) * 0.8:
                train_sequences.append(bits)
            else:
                valid_sequences.append(bits)
        except Exception as e:
            print(f"  ⚠️ Failed to convert text to bits: {e}")
            continue

    train_tensor = torch.tensor(train_sequences, dtype=torch.long)
    valid_tensor = torch.tensor(valid_sequences, dtype=torch.long) if valid_sequences else train_tensor[:16]

    print("  📊 Dataset Statistics:")
    print(f"    Training sequences: {len(train_sequences)}")
    print(f"    Validation sequences: {len(valid_sequences)}")
    print(f"    Sequence length: {max_len}")
    print(f"    Training tensor shape: {train_tensor.shape}")

    return train_tensor, valid_tensor, train_texts[:len(train_sequences)]
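# train_cpu_optimized_model below optimizes a next-bit prediction objective:
# the logits at position t are scored against the bit at position t + 1.
# The helper below is a minimal sketch of that shift-by-one cross-entropy on
# random toy tensors (plain PyTorch, independent of the model); the shapes
# mirror what the training loop produces.  It is illustrative only and is not
# called by the pipeline.
def _toy_next_bit_loss() -> torch.Tensor:
    """Compute the shift-by-one cross-entropy on random toy data (sketch)."""
    batch = torch.randint(0, 2, (4, 16), dtype=torch.long)  # (B, T) bit sequences
    logits = torch.randn(4, 16, 2)                          # (B, T, 2) stand-in model output
    pred = logits[:, :-1, :].reshape(-1, 2)                 # predictions for positions 0..T-2
    target = batch[:, 1:].reshape(-1)                       # targets are the following bits
    return F.cross_entropy(pred, target)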
def train_cpu_optimized_model(model, train_data, valid_data, epochs=5):
    """Train the model with CPU-optimized settings."""
    print(f"🚀 Training CPU-optimized BitTransformerLM for {epochs} epochs...")

    # Set model to training mode
    model.train()
    set_dropout(model, 0.1)

    # Configure optimizer for edge deployment:
    # lower learning rate and smaller batch size for stable CPU training
    batch_size = 4        # Small batch size for memory efficiency
    learning_rate = 5e-4  # Conservative learning rate
    total_steps = max(1, epochs * (len(train_data) // batch_size))  # Ensure at least 1 step

    if len(train_data) == 0:
        raise ValueError("No training data available - check dataset loading")

    optimizer, scheduler = configure_optimizer(
        model, lr=learning_rate, total_steps=total_steps, weight_decay=0.01
    )

    print("  📋 Training Configuration:")
    print(f"    Batch size: {batch_size}")
    print(f"    Learning rate: {learning_rate}")
    print(f"    Total steps: {total_steps}")
    print("    CPU autocast: Enabled")

    # Training loop with CPU optimizations
    train_losses = []

    for epoch in range(epochs):
        print(f"\n  📖 Epoch {epoch + 1}/{epochs}")
        epoch_losses = []
        epoch_start_time = time.time()

        # Shuffle training data
        perm = torch.randperm(len(train_data))
        train_data_shuffled = train_data[perm]

        # Process in small batches
        for batch_idx in range(0, len(train_data_shuffled), batch_size):
            batch_end = min(batch_idx + batch_size, len(train_data_shuffled))
            batch = train_data_shuffled[batch_idx:batch_end]

            if len(batch) == 0:
                continue

            optimizer.zero_grad()

            # Use CPU autocast for mixed precision
            with cpu_autocast():
                logits, telemetry = model(batch)

                # Standard autoregressive loss
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(pred, target)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Only step the scheduler if we haven't exceeded the total steps
            if scheduler.last_epoch < scheduler.total_steps - 1:
                scheduler.step()

            batch_loss = loss.item()
            epoch_losses.append(batch_loss)

            # Log progress every 50 steps
            if (batch_idx // batch_size) % 50 == 0:
                avg_loss = sum(epoch_losses[-10:]) / min(10, len(epoch_losses))
                telemetry_str = (
                    f"K={telemetry.get('K', 0):.3f}, "
                    f"C={telemetry.get('C', 0):.3f}, "
                    f"S={telemetry.get('S', 0):.3f}"
                )
                print(f"    Step {batch_idx // batch_size}: Loss={avg_loss:.4f}, {telemetry_str}")

        epoch_time = time.time() - epoch_start_time
        avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
        train_losses.append(avg_epoch_loss)

        print(f"  ⏱️ Epoch {epoch + 1} completed in {epoch_time:.1f}s, Avg Loss: {avg_epoch_loss:.4f}")

        # Validation every epoch
        if len(valid_data) > 0:
            model.eval()
            set_dropout(model, 0.0)

            with torch.no_grad():
                with cpu_autocast():
                    val_batch = valid_data[:min(8, len(valid_data))]  # Small validation batch
                    val_logits, val_telemetry = model(val_batch)
                    val_pred = val_logits[:, :-1, :].reshape(-1, 2)
                    val_target = val_batch[:, 1:].reshape(-1)
                    val_loss = F.cross_entropy(val_pred, val_target).item()

            print(f"  📊 Validation Loss: {val_loss:.4f}")
            print(
                f"  📈 Telemetry - K: {val_telemetry.get('K', 0):.3f}, "
                f"C: {val_telemetry.get('C', 0):.3f}, S: {val_telemetry.get('S', 0):.3f}"
            )

            model.train()
            set_dropout(model, 0.1)

    print("\n✅ Training completed!")
    print(f"  Final training loss: {train_losses[-1]:.4f}")

    return model, train_losses
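# test_model_inference below samples only a single next bit from the final
# position.  For longer continuations, an autoregressive loop can append each
# sampled bit and re-run the model.  This is a minimal sketch, assuming the
# model takes a (1, T) long tensor of bits and returns (logits, telemetry) as
# it does elsewhere in this script; the helper is illustrative and is not
# wired into the pipeline (hil_safe_inference handles gated generation).
def _generate_bits(model, prompt_bits, num_new_bits=16):
    """Extend a bit sequence one sampled bit at a time (sketch)."""
    seq = torch.tensor([prompt_bits], dtype=torch.long)
    model.eval()
    with torch.no_grad():
        for _ in range(num_new_bits):
            with cpu_autocast():
                logits, _ = model(seq)
            probs = F.softmax(logits[0, -1, :], dim=-1)
            next_bit = torch.multinomial(probs, 1).reshape(1, 1)
            seq = torch.cat([seq, next_bit], dim=1)  # append and feed back in
    return seq[0].tolist()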
def test_model_inference(model, test_texts):
    """Test the trained model with inference and safety checks."""
    print("\n🧪 Testing Model Inference...")

    model.eval()
    set_dropout(model, 0.0)

    # Test basic inference with the first 3 samples
    test_samples = test_texts[:3]

    for i, text in enumerate(test_samples):
        print(f"\n  Test {i + 1}: {text[:50]}...")

        try:
            # Convert to bits (shorter sequence for the demo)
            input_bits = text_to_bits(text)[:64]
            if len(input_bits) < 64:
                input_bits.extend([0] * (64 - len(input_bits)))

            input_tensor = torch.tensor([input_bits], dtype=torch.long)

            # Run inference with CPU autocast
            with torch.no_grad():
                with cpu_autocast():
                    logits, telemetry = model(input_tensor)

            # Sample the next bit from the final position
            next_token_logits = logits[0, -1, :]
            next_token_probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(next_token_probs, 1).item()

            print(f"    Input bits: {input_bits[:16]}... (showing first 16)")
            print(f"    Next token prediction: {next_token}")
            print(f"    Next token confidence: {next_token_probs[next_token]:.3f}")
            print(
                f"    Telemetry - K: {telemetry.get('K', 0):.3f}, "
                f"C: {telemetry.get('C', 0):.3f}, S: {telemetry.get('S', 0):.3f}"
            )
        except Exception as e:
            print(f"    ❌ Inference failed: {e}")

    # Test safe inference
    print("\n🛡️ Testing Safe Inference...")
    try:
        # Create a simple prompt
        test_prompt = "The future of AI is"
        prompt_bits = text_to_bits(test_prompt)
        prompt_tensor = torch.tensor([prompt_bits], dtype=torch.long)

        with cpu_autocast():
            safe_result = hil_safe_inference(model, prompt_tensor, max_new_tokens=16)

        if safe_result is not None:
            print("  ✅ Safe inference successful")
            print(f"  Generated {len(safe_result[0]) - len(prompt_bits)} new tokens")
        else:
            print("  ⚠️ Safe inference blocked by safety gates")
    except Exception as e:
        print(f"  ❌ Safe inference test failed: {e}")


def benchmark_cpu_performance(model):
    """Benchmark the model's CPU performance."""
    print("\n⚡ CPU Performance Benchmark...")

    model.eval()
    set_dropout(model, 0.0)

    # Benchmark grid
    batch_sizes = [1, 2, 4]
    sequence_lengths = [32, 64, 128]

    results = []

    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            print(f"\n  Testing batch_size={batch_size}, seq_len={seq_len}")

            # Create random test data
            test_data = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long)

            # Warmup
            with torch.no_grad():
                with cpu_autocast():
                    for _ in range(3):
                        _, _ = model(test_data)

            # Benchmark
            times = []
            for _ in range(10):
                start_time = time.time()
                with torch.no_grad():
                    with cpu_autocast():
                        logits, telemetry = model(test_data)
                end_time = time.time()
                times.append(end_time - start_time)

            avg_time = sum(times) / len(times)
            throughput = (batch_size * seq_len) / avg_time

            result = {
                "batch_size": batch_size,
                "seq_len": seq_len,
                "avg_time_ms": avg_time * 1000,
                "throughput_tokens_per_sec": throughput,
            }
            results.append(result)

            print(f"    Average time: {avg_time * 1000:.2f}ms")
            print(f"    Throughput: {throughput:.0f} tokens/sec")

    # Summary
    print("\n📊 Performance Summary:")
    best_throughput = max(results, key=lambda x: x["throughput_tokens_per_sec"])
    print(f"  Best throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
    print(f"  At batch_size={best_throughput['batch_size']}, seq_len={best_throughput['seq_len']}")

    return results


def quantize_for_deployment(model):
    """Apply dynamic quantization for deployment."""
    print("\n🗜️ Applying Dynamic Quantization for Deployment...")

    try:
        quantized_model = quantize_dynamic(model)

        # Compare parameter counts (dynamic quantization keeps the same
        # architecture; the savings come from storing linear weights as INT8)
        original_params = sum(p.numel() for p in model.parameters())
        quantized_params = sum(p.numel() for p in quantized_model.parameters())

        print(f"  Original parameters: {original_params:,}")
        print(f"  Quantized parameters: {quantized_params:,}")
        print("  Model size reduction: up to ~75% for quantized linear layers (FP32 -> INT8)")

        # Quick inference test
        test_input = torch.randint(0, 2, (1, 32), dtype=torch.long)
        with torch.no_grad():
            original_output = model(test_input)
            quantized_output = quantized_model(test_input)

        print("  ✅ Quantization successful - model still functional")
        return quantized_model
    except Exception as e:
        print(f"  ❌ Quantization failed: {e}")
        return model
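# Edge CPU throughput also depends on how many threads PyTorch uses.  The
# benchmark above runs with PyTorch's defaults; the helper below is a minimal,
# optional sketch that pins intra-op threads to the available core count
# before any inference runs.  The helper name and the "reserve one core"
# heuristic are illustrative choices, not part of the measured pipeline.
def _configure_cpu_threads(reserve_one_core: bool = True) -> int:
    """Set PyTorch intra-op threads for this process and return the count (sketch)."""
    cores = os.cpu_count() or 1
    num_threads = max(1, cores - 1) if reserve_one_core else cores
    torch.set_num_threads(num_threads)
    return num_threads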
def main():
    """Main training and testing pipeline."""
    print("🚀 CPU-Optimized BitTransformerLM Training Pipeline")
    print("=" * 60)

    # Step 1: Create optimal CPU model
    model = create_optimal_cpu_model()

    # Step 2: Load training dataset
    train_data, valid_data, train_texts = load_training_dataset(dataset_size=256, max_len=128)

    # Step 3: Train the model
    trained_model, train_losses = train_cpu_optimized_model(model, train_data, valid_data, epochs=3)

    # Step 4: Test inference
    test_model_inference(trained_model, train_texts)

    # Step 5: Benchmark performance
    benchmark_results = benchmark_cpu_performance(trained_model)

    # Step 6: Apply quantization
    quantized_model = quantize_for_deployment(trained_model)

    # Step 7: Save models
    print("\n💾 Saving Models...")

    # Create the weights directory if it doesn't exist
    os.makedirs("weights", exist_ok=True)

    try:
        save_model(trained_model, "weights/cpu_edge_model.pt.gz")
        print("  ✅ Saved trained model: weights/cpu_edge_model.pt.gz")

        save_model(quantized_model, "weights/cpu_edge_model_quantized.pt.gz")
        print("  ✅ Saved quantized model: weights/cpu_edge_model_quantized.pt.gz")
    except Exception as e:
        print(f"  ⚠️ Model saving failed: {e}")

    # Final summary
    print("\n" + "=" * 60)
    print("🎉 CPU-Optimized BitTransformerLM Training Complete!")
    print("=" * 60)

    total_params = sum(p.numel() for p in trained_model.parameters())
    final_loss = train_losses[-1] if train_losses else "N/A"
    best_throughput = max(benchmark_results, key=lambda x: x["throughput_tokens_per_sec"])

    print("📊 Final Results:")
    print(f"  Model Parameters: {total_params:,}")
    print(f"  Final Training Loss: {final_loss}")
    print(f"  Peak Throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
    print(f"  Model Size (quantized): ~{total_params * 1 / 1024 / 1024:.1f}MB")
    print("  CPU Optimizations: BF16 autocast, no gradient checkpointing, small chunks")
    print("  Edge Ready: ✅ Optimized for consumer CPUs")


if __name__ == "__main__":
    main()
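# Deployment-side usage: the weights saved above can be reloaded in a separate
# process and served behind the safety gate.  The snippet below is a minimal
# sketch, assuming load_model(path) returns a ready-to-use BitTransformerLM
# (load_model is imported above, but its exact signature may differ in your
# version of the library); the path and prompt are illustrative.
#
#     model = load_model("weights/cpu_edge_model.pt.gz")
#     model.eval()
#     prompt = torch.tensor([text_to_bits("Hello edge device")], dtype=torch.long)
#     with cpu_autocast():
#         result = hil_safe_inference(model, prompt, max_new_tokens=16)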