"""
Full end-to-end BitTransformerLM training run with all optimizations!

Small-scale test to validate our enhanced system.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import logging
from pathlib import Path
import time
from typing import List, Dict, Any

from bit_transformer.model import BitTransformerLM
from bit_transformer.compression import compress_bits_batch, model_output_decompress
from bit_transformer.error_handling import safe_model_forward, setup_error_logging
from bit_transformer.types import BitSequence, TelemetryDict
from enhanced_checkpoint_system import create_checkpoint_manager

logger = setup_error_logging("INFO")

class SimpleBitDataset(Dataset):
    """Simple dataset of bit sequences for training."""

    def __init__(self, num_samples: int = 1000, seq_length: int = 128):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.data = self._generate_bit_sequences()

    def _generate_bit_sequences(self) -> List[torch.Tensor]:
        """Generate diverse bit sequences with different patterns."""
        sequences = []

        # Constant sequences: all zeros or all ones, depending on the parity of i.
        for i in range(self.num_samples // 4):
            pattern = torch.tensor([i % 2] * self.seq_length, dtype=torch.long)
            sequences.append(pattern)

        # Uniformly random bit sequences.
        for i in range(self.num_samples // 4):
            pattern = torch.randint(0, 2, (self.seq_length,), dtype=torch.long)
            sequences.append(pattern)

        # Run-length sequences: random bit values repeated in runs of 1-19.
        for i in range(self.num_samples // 4):
            pattern = []
            pos = 0
            while pos < self.seq_length:
                run_length = min(np.random.randint(1, 20), self.seq_length - pos)
                bit_value = np.random.randint(0, 2)
                pattern.extend([bit_value] * run_length)
                pos += run_length
            pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
            sequences.append(pattern)

        # XOR-feedback sequences: each new bit is the XOR of the previous two.
        remaining = self.num_samples - len(sequences)
        for i in range(remaining):
            pattern = [0, 1]
            while len(pattern) < self.seq_length:
                pattern.append(pattern[-1] ^ pattern[-2])
            pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
            sequences.append(pattern)

        return sequences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data[idx]
        # Next-bit prediction: input is the sequence without its last bit,
        # target is the sequence shifted left by one.
        return sequence[:-1], sequence[1:]


def compute_safety_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    """Compute K/C/S safety metrics."""
    pred_bits = (predictions > 0.5).float().flatten()

    if len(pred_bits) > 0:
        prob_1 = pred_bits.mean().item()
        prob_0 = 1 - prob_1
        if prob_0 > 0 and prob_1 > 0:
            entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
            negentropy = 1.0 - entropy
        else:
            negentropy = 1.0 if prob_1 == 1.0 or prob_1 == 0.0 else 0.0
    else:
        negentropy = 0.0

    changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
    complexity = min(changes / len(pred_bits), 1.0) if len(pred_bits) > 1 else 0.0

    target_bits = targets.float().flatten()
    if len(target_bits) > 0:
        target_mean = target_bits.mean()
        pred_mean = pred_bits.mean()
        symbiosis = 1.0 - abs(target_mean - pred_mean).item()
    else:
        symbiosis = 1.0

    return {
        'K_negentropy': negentropy,
        'C_complexity': complexity,
        'S_symbiosis': symbiosis
    }
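
# Worked example (hypothetical values) for compute_safety_metrics(): probabilities
# [0.9, 0.1, 0.8, 0.2] against targets [1, 0, 1, 0] threshold to bits [1, 0, 1, 0],
# giving K = 0.0 (balanced 0/1 counts), C = 0.75 (3 flips / 4 bits), and
# S = 1.0 (predicted and target bit rates both 0.5).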


def train_bittransformer():
    """Main training function with all optimizations."""

    logger.info("Starting BitTransformerLM end-to-end training run!")

    model_config = {
        'd_model': 256,
        'nhead': 8,
        'num_layers': 4,
        'dim_feedforward': 512,
        'max_seq_len': 128,
        'use_checkpoint': True,
        'chunk_size': None,
    }
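    # Note: these keys are passed straight through to BitTransformerLM below.
    # use_checkpoint presumably toggles activation (gradient) checkpointing and
    # chunk_size chunked attention; chunk_size=None is assumed to mean full attention.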

    training_config = {
        'batch_size': 16,
        'learning_rate': 1e-3,
        'num_epochs': 10,
        'save_every_n_epochs': 2,
        'log_every_n_steps': 10
    }

    checkpoint_manager = create_checkpoint_manager()
    session_id = checkpoint_manager.create_training_session(
        session_name="end_to_end_test",
        model_config=model_config,
        training_config=training_config
    )

    logger.info(f"Created training session: {session_id}")

    logger.info("Creating training dataset...")
    dataset = SimpleBitDataset(num_samples=800, seq_length=model_config['max_seq_len'])
    dataloader = DataLoader(dataset, batch_size=training_config['batch_size'], shuffle=True)

    logger.info("Initializing BitTransformerLM model...")
    model = BitTransformerLM(
        d_model=model_config['d_model'],
        nhead=model_config['nhead'],
        num_layers=model_config['num_layers'],
        dim_feedforward=model_config['dim_feedforward'],
        max_seq_len=model_config['max_seq_len'],
        use_checkpoint=model_config['use_checkpoint'],
        chunk_size=model_config['chunk_size']
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {total_params:,} total, {trainable_params:,} trainable")

    optimizer = optim.AdamW(model.parameters(), lr=training_config['learning_rate'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_config['num_epochs'])
    criterion = nn.CrossEntropyLoss()
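    # Bits are treated as a two-class token vocabulary, so CrossEntropyLoss over
    # logits reshaped to (-1, 2) in the loop below gives the standard next-bit
    # negative log-likelihood.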

    logger.info("Starting training loop...")

    for epoch in range(training_config['num_epochs']):
        model.train()
        epoch_loss = 0.0
        epoch_metrics = {'K_negentropy': 0.0, 'C_complexity': 0.0, 'S_symbiosis': 0.0}
        num_batches = 0

        start_time = time.time()

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()

            try:
                # Forward pass; the model may return (logits, telemetry) or logits only.
                output = safe_model_forward(model, inputs)
                if isinstance(output, tuple):
                    logits, telemetry = output
                else:
                    logits = output
                    telemetry = {}

                # Flatten logits and targets for the 2-class cross-entropy loss.
                if logits.dim() == 2:
                    logits_flat = logits
                    targets_flat = targets.reshape(-1)
                else:
                    logits_flat = logits.reshape(-1, 2)
                    targets_flat = targets.reshape(-1)

                loss = criterion(logits_flat, targets_flat)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                # Compute K/C/S safety metrics from the predicted bit probabilities.
                with torch.no_grad():
                    if logits.dim() == 2:
                        batch_size = inputs.shape[0]
                        seq_len = inputs.shape[1]
                        logits_reshaped = logits.reshape(batch_size, seq_len, 2)
                        predictions = torch.softmax(logits_reshaped, dim=-1)[:, :, 1]
                    else:
                        predictions = torch.softmax(logits, dim=-1)[:, :, 1]

                    safety_metrics = compute_safety_metrics(predictions, targets)

                epoch_loss += loss.item()
                for key, value in safety_metrics.items():
                    epoch_metrics[key] += value
                num_batches += 1

                if batch_idx % training_config['log_every_n_steps'] == 0:
                    logger.info(f"Epoch {epoch+1}/{training_config['num_epochs']}, "
                                f"Batch {batch_idx}/{len(dataloader)}, "
                                f"Loss: {loss.item():.4f}, "
                                f"K: {safety_metrics['K_negentropy']:.3f}, "
                                f"C: {safety_metrics['C_complexity']:.3f}, "
                                f"S: {safety_metrics['S_symbiosis']:.3f}")

            except Exception as e:
                logger.error(f"Error in batch {batch_idx}: {e}")
                continue

        scheduler.step()
        epoch_time = time.time() - start_time

        if num_batches > 0:
            avg_loss = epoch_loss / num_batches
            avg_metrics = {k: v / num_batches for k, v in epoch_metrics.items()}

            logger.info(f"Epoch {epoch+1} completed in {epoch_time:.2f}s")
            logger.info(f"Avg Loss: {avg_loss:.4f}")
            logger.info(f"Safety Metrics - K: {avg_metrics['K_negentropy']:.3f}, "
                        f"C: {avg_metrics['C_complexity']:.3f}, "
                        f"S: {avg_metrics['S_symbiosis']:.3f}")

            # Periodic checkpointing.
            if (epoch + 1) % training_config['save_every_n_epochs'] == 0:
                checkpoint_success = checkpoint_manager.save_checkpoint(
                    model=model,
                    session_id=session_id,
                    epoch=epoch + 1,
                    metrics={
                        'loss': avg_loss,
                        'learning_rate': scheduler.get_last_lr()[0],
                        **avg_metrics
                    },
                    optimizer_state=optimizer.state_dict(),
                    scheduler_state=scheduler.state_dict()
                )

                if checkpoint_success:
                    logger.info(f"Checkpoint saved for epoch {epoch+1}")

            # Track the best model by average training loss.
            checkpoint_manager.save_best_model(
                session_id=session_id,
                model=model,
                metric_name='loss',
                metric_value=avg_loss,
                is_better_func=lambda x, y: x < y
            )

    logger.info("Training completed successfully!")

    # Quick inference and compression round-trip test.
    logger.info("Testing model inference and compression...")

    model.eval()
    with torch.no_grad():
        test_input = torch.randint(0, 2, (1, 64), dtype=torch.long)
        logger.info(f"Input sequence: {test_input.squeeze().tolist()}")

        # As in training, the model may return (logits, telemetry) or logits only.
        output = model(test_input)
        output_logits = output[0] if isinstance(output, tuple) else output
        output_probs = torch.softmax(output_logits, dim=-1)
        predicted_bits = torch.argmax(output_probs, dim=-1)

        logger.info(f"Predicted sequence: {predicted_bits.squeeze().tolist()}")

        compressed = compress_bits_batch(predicted_bits)
        logger.info(f"Compressed length: {len(compressed[0])} (original: {predicted_bits.shape[-1]})")

        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(predicted_bits, decompressed)
        logger.info(f"Compression/decompression successful: {compression_success}")

    storage_usage = checkpoint_manager.get_storage_usage()
    logger.info(f"Final storage usage: {storage_usage['total_gb']:.3f} GB")
    logger.info(f"Training sessions: {storage_usage['num_sessions']}")

    return session_id, model, checkpoint_manager


if __name__ == "__main__":
    try:
        session_id, trained_model, manager = train_bittransformer()
        print(f"\nSUCCESS! Training session completed: {session_id}")
        print(f"Use checkpoint_manager.load_checkpoint('{session_id}') to resume")

    except Exception as e:
        logger.error(f"Training failed: {e}")
        import traceback
        traceback.print_exc()
        raise