#!/usr/bin/env python3
"""
CPU-Optimized Edge Deployment BitTransformerLM Training
Optimized for consumer devices and edge applications.
"""
import os
import time
import torch
import torch.nn.functional as F
from datasets import load_dataset
from bit_transformer import (
BitTransformerLM,
text_to_bits,
bits_to_text,
train_loop,
configure_optimizer,
save_model,
load_model,
set_dropout,
hil_safe_inference,
quantize_dynamic,
)
from bit_transformer.torch_utils import cpu_autocast
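
# Optional CPU tuning (a minimal sketch; the thread counts below are assumptions
# and should be matched to the target device's core count, not taken as defaults):
# torch.set_num_threads(4)          # cap intra-op parallelism
# torch.set_num_interop_threads(2)  # cap inter-op parallelism
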
def create_optimal_cpu_model():
"""Create BitTransformerLM optimized for CPU edge deployment."""
print("🧠 Creating CPU-optimized BitTransformerLM...")
# Optimal configuration for edge devices:
# - Small model size for low memory footprint
    # - CPU autocast (BF16 mixed precision) for faster inference
# - No reversible layers (simpler for CPU)
# - Gradient checkpointing disabled for speed
# - Small context length for efficiency
model = BitTransformerLM(
d_model=64, # Small embedding dimension (vs 128 default)
nhead=4, # Fewer attention heads (vs 8 default)
num_layers=3, # Shallow model (vs 4 default)
dim_feedforward=128, # Smaller FFN (vs 512 default)
max_seq_len=256, # Shorter context (vs 1024 default)
reversible=False, # Disable reversible layers (CPU doesn't benefit much)
use_checkpoint=False, # Disable gradient checkpointing (prioritize speed)
use_autocast=True, # Enable CPU autocast for BF16 mixed precision
use_act=False, # Disable ACT for simplicity
chunk_size=32, # Small chunks for memory efficiency
full_attn_logging=False, # Disable attention logging to save memory
lambda_K=1.0, # Standard telemetry weights
lambda_C=1.0,
lambda_S=1.0,
)
# Calculate model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" πŸ“Š Model Configuration:")
print(f" d_model: {64}")
print(f" num_layers: {3}")
print(f" nhead: {4}")
print(f" dim_feedforward: {128}")
print(f" max_seq_len: {256}")
print(f" Total parameters: {total_params:,}")
print(f" Trainable parameters: {trainable_params:,}")
print(f" Estimated size: {total_params * 4 / 1024 / 1024:.1f}MB (FP32)")
print(f" With autocast: ~{total_params * 2 / 1024 / 1024:.1f}MB (BF16)")
return model
def load_training_dataset(dataset_size=512, max_len=128):
"""Load and prepare training dataset optimized for edge training."""
print("πŸ“š Loading training dataset...")
try:
# Try to load BitTransformerLM dataset from HuggingFace
print(" Attempting to load BitTransformerLM dataset...")
dataset = load_dataset("WCNegentropy/BitTransformerLM", split="train[:{}]".format(dataset_size))
if dataset and len(dataset) > 0:
train_texts = [item['text'] for item in dataset if item.get('text')]
if len(train_texts) > 0:
print(f" βœ… Loaded {len(train_texts)} samples from BitTransformerLM dataset")
else:
raise Exception("No text samples found in dataset")
else:
raise Exception("Dataset empty or not accessible")
except Exception as e:
print(f" ⚠️ BitTransformerLM dataset not available: {e}")
print(" πŸ“– Falling back to WikiText-2...")
try:
# Fallback to WikiText-2 for training
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
train_texts = [text for text in ds["train"]["text"] if text.strip()][:dataset_size]
print(f" βœ… Loaded {len(train_texts)} samples from WikiText-2")
except Exception as e2:
print(f" ❌ Failed to load WikiText-2: {e2}")
print(" 🎲 Using synthetic text data...")
# Generate simple synthetic text for demonstration
synthetic_texts = [
"The quick brown fox jumps over the lazy dog.",
"Machine learning is transforming technology.",
"Edge computing enables local AI processing.",
"BitTransformerLM uses bit-native processing.",
"CPU optimization improves inference speed.",
"Neural networks learn from training data.",
"Transformers use attention mechanisms.",
"Language models understand text patterns.",
]
train_texts = (synthetic_texts * (dataset_size // len(synthetic_texts) + 1))[:dataset_size]
print(f" βœ… Generated {len(train_texts)} synthetic samples")
# Convert text to bits
print(" πŸ”„ Converting text to bits...")
train_sequences = []
valid_sequences = []
for i, text in enumerate(train_texts):
try:
bits = text_to_bits(text)[:max_len]
if len(bits) < max_len:
bits.extend([0] * (max_len - len(bits))) # Pad to max_len
# Use 80/20 split for train/validation
if i < len(train_texts) * 0.8:
train_sequences.append(bits)
else:
valid_sequences.append(bits)
except Exception as e:
print(f" ⚠️ Failed to convert text to bits: {e}")
continue
train_tensor = torch.tensor(train_sequences, dtype=torch.long)
valid_tensor = torch.tensor(valid_sequences, dtype=torch.long) if valid_sequences else train_tensor[:16]
print(f" πŸ“Š Dataset Statistics:")
print(f" Training sequences: {len(train_sequences)}")
print(f" Validation sequences: {len(valid_sequences)}")
print(f" Sequence length: {max_len}")
print(f" Training tensor shape: {train_tensor.shape}")
return train_tensor, valid_tensor, train_texts[:len(train_sequences)]
def train_cpu_optimized_model(model, train_data, valid_data, epochs=5):
"""Train the model with CPU-optimized settings."""
print(f"πŸš€ Training CPU-optimized BitTransformerLM for {epochs} epochs...")
# Set model to training mode
model.train()
set_dropout(model, 0.1)
# Configure optimizer for edge deployment
# Lower learning rate and smaller batch size for stable CPU training
batch_size = 4 # Small batch size for memory efficiency
learning_rate = 5e-4 # Conservative learning rate
total_steps = max(1, epochs * (len(train_data) // batch_size)) # Ensure at least 1 step
if len(train_data) == 0:
raise ValueError("No training data available - check dataset loading")
optimizer, scheduler = configure_optimizer(
model,
lr=learning_rate,
total_steps=total_steps,
weight_decay=0.01
)
print(f" πŸ“‹ Training Configuration:")
print(f" Batch size: {batch_size}")
print(f" Learning rate: {learning_rate}")
print(f" Total steps: {total_steps}")
print(f" CPU autocast: Enabled")
# Training loop with CPU optimizations
train_losses = []
for epoch in range(epochs):
print(f"\n πŸ“– Epoch {epoch + 1}/{epochs}")
epoch_losses = []
epoch_start_time = time.time()
# Shuffle training data
perm = torch.randperm(len(train_data))
train_data_shuffled = train_data[perm]
# Process in small batches
for batch_idx in range(0, len(train_data_shuffled), batch_size):
batch_end = min(batch_idx + batch_size, len(train_data_shuffled))
batch = train_data_shuffled[batch_idx:batch_end]
if len(batch) == 0:
continue
optimizer.zero_grad()
# Use CPU autocast for mixed precision
with cpu_autocast():
logits, telemetry = model(batch)
# Standard autoregressive loss
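                # Position t's logits predict bit t+1, so drop the final logit
                # and shift the targets left by one before the 2-way cross-entropy.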
pred = logits[:, :-1, :].reshape(-1, 2)
target = batch[:, 1:].reshape(-1)
loss = F.cross_entropy(pred, target)
# Backward pass
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# Only step scheduler if we haven't exceeded total steps
if scheduler.last_epoch < scheduler.total_steps - 1:
scheduler.step()
batch_loss = loss.item()
epoch_losses.append(batch_loss)
# Log progress every 50 steps
if (batch_idx // batch_size) % 50 == 0:
avg_loss = sum(epoch_losses[-10:]) / min(10, len(epoch_losses))
telemetry_str = f"K={telemetry.get('K', 0):.3f}, C={telemetry.get('C', 0):.3f}, S={telemetry.get('S', 0):.3f}"
print(f" Step {batch_idx // batch_size}: Loss={avg_loss:.4f}, {telemetry_str}")
epoch_time = time.time() - epoch_start_time
avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
train_losses.append(avg_epoch_loss)
print(f" ⏱️ Epoch {epoch + 1} completed in {epoch_time:.1f}s, Avg Loss: {avg_epoch_loss:.4f}")
# Validation every epoch
if len(valid_data) > 0:
model.eval()
set_dropout(model, 0.0)
with torch.no_grad():
with cpu_autocast():
val_batch = valid_data[:min(8, len(valid_data))] # Small validation batch
val_logits, val_telemetry = model(val_batch)
val_pred = val_logits[:, :-1, :].reshape(-1, 2)
val_target = val_batch[:, 1:].reshape(-1)
val_loss = F.cross_entropy(val_pred, val_target).item()
print(f" πŸ“Š Validation Loss: {val_loss:.4f}")
print(f" πŸ“ˆ Telemetry - K: {val_telemetry.get('K', 0):.3f}, C: {val_telemetry.get('C', 0):.3f}, S: {val_telemetry.get('S', 0):.3f}")
model.train()
set_dropout(model, 0.1)
print(f"\nβœ… Training completed!")
print(f" Final training loss: {train_losses[-1]:.4f}")
return model, train_losses
def test_model_inference(model, test_texts):
"""Test the trained model with inference and safety checks."""
print("\nπŸ§ͺ Testing Model Inference...")
model.eval()
set_dropout(model, 0.0)
# Test basic inference
test_samples = test_texts[:3] # Test with first 3 samples
for i, text in enumerate(test_samples):
print(f"\n Test {i + 1}: {text[:50]}...")
try:
# Convert to bits
input_bits = text_to_bits(text)[:64] # Shorter for demo
if len(input_bits) < 64:
input_bits.extend([0] * (64 - len(input_bits)))
input_tensor = torch.tensor([input_bits], dtype=torch.long)
# Run inference with CPU autocast
with torch.no_grad():
with cpu_autocast():
logits, telemetry = model(input_tensor)
            # Sample the next bit from the 2-way distribution at the final position
next_token_logits = logits[0, -1, :]
next_token_probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(next_token_probs, 1).item()
print(f" Input bits: {input_bits[:16]}... (showing first 16)")
print(f" Next token prediction: {next_token}")
print(f" Next token confidence: {next_token_probs[next_token]:.3f}")
print(f" Telemetry - K: {telemetry.get('K', 0):.3f}, C: {telemetry.get('C', 0):.3f}, S: {telemetry.get('S', 0):.3f}")
except Exception as e:
print(f" ❌ Inference failed: {e}")
# Test safe inference
print(f"\nπŸ›‘οΈ Testing Safe Inference...")
try:
# Create a simple prompt
test_prompt = "The future of AI is"
prompt_bits = text_to_bits(test_prompt)
prompt_tensor = torch.tensor([prompt_bits], dtype=torch.long)
with cpu_autocast():
safe_result = hil_safe_inference(model, prompt_tensor, max_new_tokens=16)
if safe_result is not None:
print(f" βœ… Safe inference successful")
print(f" Generated {len(safe_result[0]) - len(prompt_bits)} new tokens")
else:
print(f" ⚠️ Safe inference blocked by safety gates")
except Exception as e:
print(f" ❌ Safe inference test failed: {e}")
def benchmark_cpu_performance(model):
"""Benchmark the model's CPU performance."""
print("\n⚑ CPU Performance Benchmark...")
model.eval()
set_dropout(model, 0.0)
# Prepare test data
batch_sizes = [1, 2, 4]
sequence_lengths = [32, 64, 128]
results = []
for batch_size in batch_sizes:
for seq_len in sequence_lengths:
print(f"\n Testing batch_size={batch_size}, seq_len={seq_len}")
# Create random test data
test_data = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long)
# Warmup
with torch.no_grad():
with cpu_autocast():
for _ in range(3):
_, _ = model(test_data)
# Benchmark
times = []
for _ in range(10):
start_time = time.time()
with torch.no_grad():
with cpu_autocast():
logits, telemetry = model(test_data)
end_time = time.time()
times.append(end_time - start_time)
avg_time = sum(times) / len(times)
throughput = (batch_size * seq_len) / avg_time
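            # "Tokens" here are individual bits, so throughput is effectively bits/sec.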
result = {
'batch_size': batch_size,
'seq_len': seq_len,
'avg_time_ms': avg_time * 1000,
'throughput_tokens_per_sec': throughput
}
results.append(result)
print(f" Average time: {avg_time * 1000:.2f}ms")
print(f" Throughput: {throughput:.0f} tokens/sec")
# Summary
print(f"\nπŸ“Š Performance Summary:")
best_throughput = max(results, key=lambda x: x['throughput_tokens_per_sec'])
print(f" Best throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
print(f" At batch_size={best_throughput['batch_size']}, seq_len={best_throughput['seq_len']}")
return results
def quantize_for_deployment(model):
"""Apply dynamic quantization for deployment."""
print("\nπŸ—œοΈ Applying Dynamic Quantization for Deployment...")
try:
quantized_model = quantize_dynamic(model)
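        # Note (assumption): if quantize_dynamic wraps torch.ao.quantization.quantize_dynamic,
        # only nn.Linear weights are converted to INT8 (stored as packed buffers, not in
        # .parameters()), so the counts below cover the remaining floating-point tensors.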
# Compare model sizes
original_params = sum(p.numel() for p in model.parameters())
quantized_params = sum(p.numel() for p in quantized_model.parameters())
print(f" Original parameters: {original_params:,}")
print(f" Quantized parameters: {quantized_params:,}")
print(f" Model size reduction: ~50% (FP32 -> INT8)")
# Quick inference test
test_input = torch.randint(0, 2, (1, 32), dtype=torch.long)
with torch.no_grad():
original_output = model(test_input)
quantized_output = quantized_model(test_input)
print(f" βœ… Quantization successful - model still functional")
return quantized_model
except Exception as e:
print(f" ❌ Quantization failed: {e}")
return model
def main():
"""Main training and testing pipeline."""
print("πŸš€ CPU-Optimized BitTransformerLM Training Pipeline")
print("="*60)
# Step 1: Create optimal CPU model
model = create_optimal_cpu_model()
# Step 2: Load training dataset
train_data, valid_data, train_texts = load_training_dataset(dataset_size=256, max_len=128)
# Step 3: Train the model
trained_model, train_losses = train_cpu_optimized_model(model, train_data, valid_data, epochs=3)
# Step 4: Test inference
test_model_inference(trained_model, train_texts)
# Step 5: Benchmark performance
benchmark_results = benchmark_cpu_performance(trained_model)
# Step 6: Apply quantization
quantized_model = quantize_for_deployment(trained_model)
# Step 7: Save models
print("\nπŸ’Ύ Saving Models...")
# Create weights directory if it doesn't exist
os.makedirs("weights", exist_ok=True)
try:
save_model(trained_model, "weights/cpu_edge_model.pt.gz")
print(" βœ… Saved trained model: weights/cpu_edge_model.pt.gz")
save_model(quantized_model, "weights/cpu_edge_model_quantized.pt.gz")
print(" βœ… Saved quantized model: weights/cpu_edge_model_quantized.pt.gz")
except Exception as e:
print(f" ⚠️ Model saving failed: {e}")
# Final summary
print("\n" + "="*60)
print("πŸŽ‰ CPU-Optimized BitTransformerLM Training Complete!")
print("="*60)
total_params = sum(p.numel() for p in trained_model.parameters())
final_loss = train_losses[-1] if train_losses else "N/A"
best_throughput = max(benchmark_results, key=lambda x: x['throughput_tokens_per_sec'])
print(f"πŸ“Š Final Results:")
print(f" Model Parameters: {total_params:,}")
print(f" Final Training Loss: {final_loss}")
print(f" Peak Throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
print(f" Model Size (quantized): ~{total_params * 1 / 1024 / 1024:.1f}MB")
print(f" CPU Optimizations: BF16 autocast, no gradient checkpointing, small chunks")
print(f" Edge Ready: βœ… Optimized for consumer CPUs")
if __name__ == "__main__":
main()