|
|
|
""" |
|
CPU-Optimized Edge Deployment BitTransformerLM Training |
|
Optimized for consumer devices and edge applications. |
|
""" |
|
|
|
import os |
|
import time |
|
import torch |
|
import torch.nn.functional as F |
|
from datasets import load_dataset |
|
|
|
from bit_transformer import ( |
|
BitTransformerLM, |
|
text_to_bits, |
|
bits_to_text, |
|
train_loop, |
|
configure_optimizer, |
|
save_model, |
|
load_model, |
|
set_dropout, |
|
hil_safe_inference, |
|
quantize_dynamic, |
|
) |
|
from bit_transformer.torch_utils import cpu_autocast |
|
|
|
|
|
|
def create_optimal_cpu_model(): |
|
"""Create BitTransformerLM optimized for CPU edge deployment.""" |
|
print("π§ Creating CPU-optimized BitTransformerLM...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = BitTransformerLM( |
|
d_model=64, |
|
nhead=4, |
|
num_layers=3, |
|
dim_feedforward=128, |
|
max_seq_len=256, |
|
reversible=False, |
|
use_checkpoint=False, |
|
use_autocast=True, |
|
use_act=False, |
|
chunk_size=32, |
|
full_attn_logging=False, |
|
lambda_K=1.0, |
|
lambda_C=1.0, |
|
lambda_S=1.0, |
|
) |
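
    # Parameter count drives the size estimates below: ~4 bytes/param in
    # FP32, roughly halved under BF16 autocast.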
|
|
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
|
print(f" π Model Configuration:") |
|
print(f" d_model: {64}") |
|
print(f" num_layers: {3}") |
|
print(f" nhead: {4}") |
|
print(f" dim_feedforward: {128}") |
|
print(f" max_seq_len: {256}") |
|
print(f" Total parameters: {total_params:,}") |
|
print(f" Trainable parameters: {trainable_params:,}") |
|
print(f" Estimated size: {total_params * 4 / 1024 / 1024:.1f}MB (FP32)") |
|
print(f" With autocast: ~{total_params * 2 / 1024 / 1024:.1f}MB (BF16)") |
|
|
|
return model |
|
|
|
|
|
def load_training_dataset(dataset_size=512, max_len=128): |
|
"""Load and prepare training dataset optimized for edge training.""" |
|
print("π Loading training dataset...") |
|
|
|
try: |
|
|
|
print(" Attempting to load BitTransformerLM dataset...") |
|
        dataset = load_dataset("WCNegentropy/BitTransformerLM", split=f"train[:{dataset_size}]")
|
if dataset and len(dataset) > 0: |
|
train_texts = [item['text'] for item in dataset if item.get('text')] |
|
if len(train_texts) > 0: |
|
print(f" β
Loaded {len(train_texts)} samples from BitTransformerLM dataset") |
|
else: |
|
raise Exception("No text samples found in dataset") |
|
else: |
|
raise Exception("Dataset empty or not accessible") |
|
|
|
except Exception as e: |
|
print(f" β οΈ BitTransformerLM dataset not available: {e}") |
|
print(" π Falling back to WikiText-2...") |
|
try: |
|
|
|
ds = load_dataset("wikitext", "wikitext-2-raw-v1") |
|
train_texts = [text for text in ds["train"]["text"] if text.strip()][:dataset_size] |
|
print(f" β
Loaded {len(train_texts)} samples from WikiText-2") |
|
except Exception as e2: |
|
print(f" β Failed to load WikiText-2: {e2}") |
|
print(" π² Using synthetic text data...") |
|
|
|
synthetic_texts = [ |
|
"The quick brown fox jumps over the lazy dog.", |
|
"Machine learning is transforming technology.", |
|
"Edge computing enables local AI processing.", |
|
"BitTransformerLM uses bit-native processing.", |
|
"CPU optimization improves inference speed.", |
|
"Neural networks learn from training data.", |
|
"Transformers use attention mechanisms.", |
|
"Language models understand text patterns.", |
|
] |
|
train_texts = (synthetic_texts * (dataset_size // len(synthetic_texts) + 1))[:dataset_size] |
|
print(f" β
Generated {len(train_texts)} synthetic samples") |
|
|
|
|
|
print(" π Converting text to bits...") |
|
train_sequences = [] |
|
valid_sequences = [] |
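
    # Encode each text as a fixed-length, zero-padded bit sequence, with an
    # ordered 80/20 train/validation split.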
|
|
|
for i, text in enumerate(train_texts): |
|
try: |
|
bits = text_to_bits(text)[:max_len] |
|
if len(bits) < max_len: |
|
bits.extend([0] * (max_len - len(bits))) |
|
|
|
|
|
if i < len(train_texts) * 0.8: |
|
train_sequences.append(bits) |
|
else: |
|
valid_sequences.append(bits) |
|
|
|
except Exception as e: |
|
print(f" β οΈ Failed to convert text to bits: {e}") |
|
continue |
|
|
|
train_tensor = torch.tensor(train_sequences, dtype=torch.long) |
|
valid_tensor = torch.tensor(valid_sequences, dtype=torch.long) if valid_sequences else train_tensor[:16] |
|
|
|
print(f" π Dataset Statistics:") |
|
print(f" Training sequences: {len(train_sequences)}") |
|
print(f" Validation sequences: {len(valid_sequences)}") |
|
print(f" Sequence length: {max_len}") |
|
print(f" Training tensor shape: {train_tensor.shape}") |
|
|
|
return train_tensor, valid_tensor, train_texts[:len(train_sequences)] |
|
|
|
|
|
def train_cpu_optimized_model(model, train_data, valid_data, epochs=5): |
|
"""Train the model with CPU-optimized settings.""" |
|
print(f"π Training CPU-optimized BitTransformerLM for {epochs} epochs...") |
|
|
|
|
|
model.train() |
|
set_dropout(model, 0.1) |
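
    # Small batches and a modest learning rate keep per-step memory and
    # latency manageable on CPU.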
|
|
|
|
|
|
|
batch_size = 4 |
|
learning_rate = 5e-4 |
|
total_steps = max(1, epochs * (len(train_data) // batch_size)) |
|
|
|
if len(train_data) == 0: |
|
raise ValueError("No training data available - check dataset loading") |
|
|
|
optimizer, scheduler = configure_optimizer( |
|
model, |
|
lr=learning_rate, |
|
total_steps=total_steps, |
|
weight_decay=0.01 |
|
) |
|
|
|
print(f" π Training Configuration:") |
|
print(f" Batch size: {batch_size}") |
|
print(f" Learning rate: {learning_rate}") |
|
print(f" Total steps: {total_steps}") |
|
print(f" CPU autocast: Enabled") |
|
|
|
|
|
train_losses = [] |
|
|
|
for epoch in range(epochs): |
|
print(f"\n π Epoch {epoch + 1}/{epochs}") |
|
epoch_losses = [] |
|
epoch_start_time = time.time() |
|
|
|
|
|
perm = torch.randperm(len(train_data)) |
|
train_data_shuffled = train_data[perm] |
|
|
|
|
|
for batch_idx in range(0, len(train_data_shuffled), batch_size): |
|
batch_end = min(batch_idx + batch_size, len(train_data_shuffled)) |
|
batch = train_data_shuffled[batch_idx:batch_end] |
|
|
|
if len(batch) == 0: |
|
continue |
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
with cpu_autocast(): |
|
logits, telemetry = model(batch) |
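
            # Next-bit objective: the logit at position t is scored against
            # the bit at t+1 over the two classes {0, 1}.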
|
|
|
|
|
pred = logits[:, :-1, :].reshape(-1, 2) |
|
target = batch[:, 1:].reshape(-1) |
|
loss = F.cross_entropy(pred, target) |
|
|
|
|
|
loss.backward() |
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
|
optimizer.step() |
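
            # Advance the LR schedule only while within its configured horizon.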
|
|
|
|
|
if scheduler.last_epoch < scheduler.total_steps - 1: |
|
scheduler.step() |
|
|
|
batch_loss = loss.item() |
|
epoch_losses.append(batch_loss) |
|
|
|
|
|
if (batch_idx // batch_size) % 50 == 0: |
|
avg_loss = sum(epoch_losses[-10:]) / min(10, len(epoch_losses)) |
|
telemetry_str = f"K={telemetry.get('K', 0):.3f}, C={telemetry.get('C', 0):.3f}, S={telemetry.get('S', 0):.3f}" |
|
print(f" Step {batch_idx // batch_size}: Loss={avg_loss:.4f}, {telemetry_str}") |
|
|
|
epoch_time = time.time() - epoch_start_time |
|
avg_epoch_loss = sum(epoch_losses) / len(epoch_losses) |
|
train_losses.append(avg_epoch_loss) |
|
|
|
print(f" β±οΈ Epoch {epoch + 1} completed in {epoch_time:.1f}s, Avg Loss: {avg_epoch_loss:.4f}") |
|
|
|
|
|
if len(valid_data) > 0: |
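
            # Quick validation pass on a small held-out batch with dropout off.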
|
model.eval() |
|
set_dropout(model, 0.0) |
|
|
|
with torch.no_grad(): |
|
with cpu_autocast(): |
|
val_batch = valid_data[:min(8, len(valid_data))] |
|
val_logits, val_telemetry = model(val_batch) |
|
val_pred = val_logits[:, :-1, :].reshape(-1, 2) |
|
val_target = val_batch[:, 1:].reshape(-1) |
|
val_loss = F.cross_entropy(val_pred, val_target).item() |
|
|
|
print(f" π Validation Loss: {val_loss:.4f}") |
|
print(f" π Telemetry - K: {val_telemetry.get('K', 0):.3f}, C: {val_telemetry.get('C', 0):.3f}, S: {val_telemetry.get('S', 0):.3f}") |
|
|
|
model.train() |
|
set_dropout(model, 0.1) |
|
|
|
print(f"\nβ
Training completed!") |
|
print(f" Final training loss: {train_losses[-1]:.4f}") |
|
|
|
return model, train_losses |
|
|
|
|
|
def test_model_inference(model, test_texts): |
|
"""Test the trained model with inference and safety checks.""" |
|
print("\nπ§ͺ Testing Model Inference...") |
|
|
|
model.eval() |
|
set_dropout(model, 0.0) |
|
|
|
|
|
test_samples = test_texts[:3] |
|
|
|
for i, text in enumerate(test_samples): |
|
print(f"\n Test {i + 1}: {text[:50]}...") |
|
|
|
try: |
|
|
|
input_bits = text_to_bits(text)[:64] |
|
if len(input_bits) < 64: |
|
input_bits.extend([0] * (64 - len(input_bits))) |
|
|
|
input_tensor = torch.tensor([input_bits], dtype=torch.long) |
|
|
|
|
|
with torch.no_grad(): |
|
with cpu_autocast(): |
|
logits, telemetry = model(input_tensor) |
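
            # Sample the next bit from the softmax over the two classes.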
|
|
|
|
|
next_token_logits = logits[0, -1, :] |
|
next_token_probs = F.softmax(next_token_logits, dim=-1) |
|
next_token = torch.multinomial(next_token_probs, 1).item() |
|
|
|
print(f" Input bits: {input_bits[:16]}... (showing first 16)") |
|
print(f" Next token prediction: {next_token}") |
|
print(f" Next token confidence: {next_token_probs[next_token]:.3f}") |
|
print(f" Telemetry - K: {telemetry.get('K', 0):.3f}, C: {telemetry.get('C', 0):.3f}, S: {telemetry.get('S', 0):.3f}") |
|
|
|
except Exception as e: |
|
print(f" β Inference failed: {e}") |
|
|
|
|
|
print(f"\nπ‘οΈ Testing Safe Inference...") |
|
try: |
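        # hil_safe_inference wraps generation with the K/C/S telemetry gates
        # and may return None when the output is blocked.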
|
|
|
test_prompt = "The future of AI is" |
|
prompt_bits = text_to_bits(test_prompt) |
|
prompt_tensor = torch.tensor([prompt_bits], dtype=torch.long) |
|
|
|
with cpu_autocast(): |
|
safe_result = hil_safe_inference(model, prompt_tensor, max_new_tokens=16) |
|
|
|
if safe_result is not None: |
|
print(f" β
Safe inference successful") |
|
print(f" Generated {len(safe_result[0]) - len(prompt_bits)} new tokens") |
|
else: |
|
print(f" β οΈ Safe inference blocked by safety gates") |
|
|
|
except Exception as e: |
|
print(f" β Safe inference test failed: {e}") |
|
|
|
|
|
def benchmark_cpu_performance(model): |
|
"""Benchmark the model's CPU performance.""" |
|
print("\nβ‘ CPU Performance Benchmark...") |
|
|
|
model.eval() |
|
set_dropout(model, 0.0) |
|
|
|
|
|
batch_sizes = [1, 2, 4] |
|
sequence_lengths = [32, 64, 128] |
|
|
|
results = [] |
|
|
|
for batch_size in batch_sizes: |
|
for seq_len in sequence_lengths: |
|
print(f"\n Testing batch_size={batch_size}, seq_len={seq_len}") |
|
|
|
|
|
test_data = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long) |
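
            # A few warm-up passes let thread pools and caches settle before timing.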
|
|
|
|
|
with torch.no_grad(): |
|
with cpu_autocast(): |
|
for _ in range(3): |
|
_, _ = model(test_data) |
|
|
|
|
|
times = [] |
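            # Average wall-clock latency over 10 timed forward passes.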
|
for _ in range(10): |
|
start_time = time.time() |
|
with torch.no_grad(): |
|
with cpu_autocast(): |
|
logits, telemetry = model(test_data) |
|
end_time = time.time() |
|
times.append(end_time - start_time) |
|
|
|
avg_time = sum(times) / len(times) |
|
throughput = (batch_size * seq_len) / avg_time |
|
|
|
result = { |
|
'batch_size': batch_size, |
|
'seq_len': seq_len, |
|
'avg_time_ms': avg_time * 1000, |
|
'throughput_tokens_per_sec': throughput |
|
} |
|
results.append(result) |
|
|
|
print(f" Average time: {avg_time * 1000:.2f}ms") |
|
print(f" Throughput: {throughput:.0f} tokens/sec") |
|
|
|
|
|
print(f"\nπ Performance Summary:") |
|
best_throughput = max(results, key=lambda x: x['throughput_tokens_per_sec']) |
|
print(f" Best throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec") |
|
print(f" At batch_size={best_throughput['batch_size']}, seq_len={best_throughput['seq_len']}") |
|
|
|
return results |
|
|
|
|
|
def quantize_for_deployment(model): |
|
"""Apply dynamic quantization for deployment.""" |
|
print("\nποΈ Applying Dynamic Quantization for Deployment...") |
|
|
|
try: |
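        # Dynamic quantization (assuming the quantize_dynamic helper wraps
        # torch's dynamic quantization) converts linear-layer weights to INT8
        # while activations stay in floating point. Note: packed INT8 weights
        # may not appear in parameters(), so the counts below are approximate.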
|
quantized_model = quantize_dynamic(model) |
|
|
|
|
|
original_params = sum(p.numel() for p in model.parameters()) |
|
quantized_params = sum(p.numel() for p in quantized_model.parameters()) |
|
|
|
print(f" Original parameters: {original_params:,}") |
|
print(f" Quantized parameters: {quantized_params:,}") |
|
print(f" Model size reduction: ~50% (FP32 -> INT8)") |
|
|
|
|
|
test_input = torch.randint(0, 2, (1, 32), dtype=torch.long) |
|
|
|
with torch.no_grad(): |
|
original_output = model(test_input) |
|
quantized_output = quantized_model(test_input) |
|
|
|
print(f" β
Quantization successful - model still functional") |
|
|
|
return quantized_model |
|
|
|
except Exception as e: |
|
print(f" β Quantization failed: {e}") |
|
return model |
|
|
|
|
|
def main(): |
|
"""Main training and testing pipeline.""" |
|
print("π CPU-Optimized BitTransformerLM Training Pipeline") |
|
print("="*60) |
|
|
|
|
|
model = create_optimal_cpu_model() |
|
|
|
|
|
train_data, valid_data, train_texts = load_training_dataset(dataset_size=256, max_len=128) |
|
|
|
|
|
trained_model, train_losses = train_cpu_optimized_model(model, train_data, valid_data, epochs=3) |
|
|
|
|
|
test_model_inference(trained_model, train_texts) |
|
|
|
|
|
benchmark_results = benchmark_cpu_performance(trained_model) |
|
|
|
|
|
quantized_model = quantize_for_deployment(trained_model) |
|
|
|
|
|
print("\nπΎ Saving Models...") |
|
|
|
|
|
os.makedirs("weights", exist_ok=True) |
|
|
|
try: |
|
save_model(trained_model, "weights/cpu_edge_model.pt.gz") |
|
print(" β
Saved trained model: weights/cpu_edge_model.pt.gz") |
|
|
|
save_model(quantized_model, "weights/cpu_edge_model_quantized.pt.gz") |
|
print(" β
Saved quantized model: weights/cpu_edge_model_quantized.pt.gz") |
|
|
|
except Exception as e: |
|
print(f" β οΈ Model saving failed: {e}") |
|
|
|
|
|
print("\n" + "="*60) |
|
print("π CPU-Optimized BitTransformerLM Training Complete!") |
|
print("="*60) |
|
|
|
total_params = sum(p.numel() for p in trained_model.parameters()) |
|
final_loss = train_losses[-1] if train_losses else "N/A" |
|
best_throughput = max(benchmark_results, key=lambda x: x['throughput_tokens_per_sec']) |
|
|
|
print(f"π Final Results:") |
|
print(f" Model Parameters: {total_params:,}") |
|
print(f" Final Training Loss: {final_loss}") |
|
print(f" Peak Throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec") |
|
print(f" Model Size (quantized): ~{total_params * 1 / 1024 / 1024:.1f}MB") |
|
print(f" CPU Optimizations: BF16 autocast, no gradient checkpointing, small chunks") |
|
print(f" Edge Ready: β
Optimized for consumer CPUs") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |