#!/usr/bin/env python3
"""
Test the trained BitTransformerLM model and validate all features.
"""
import torch
import numpy as np
import logging
from enhanced_checkpoint_system import create_checkpoint_manager
from bit_transformer.model import BitTransformerLM
from bit_transformer.compression import compress_bits_batch, model_output_decompress
logger = logging.getLogger(__name__)


def test_trained_model():
    """Test the most recent trained model."""
    print("🧪 Testing trained BitTransformerLM model...")

    # Load checkpoint manager
    manager = create_checkpoint_manager()

    # Find the most recent session
    sessions = list(manager.sessions_dir.iterdir())
    if not sessions:
        print("❌ No training sessions found")
        return False

    latest_session = max(sessions, key=lambda x: x.stat().st_mtime)
    session_id = latest_session.name
    print(f"📂 Loading from session: {session_id}")

    # Initialize model with same config
    model = BitTransformerLM(
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        max_seq_len=128,
        use_checkpoint=True,
        chunk_size=None,
    )
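
    # NOTE: these hyperparameters must match the configuration the checkpoint
    # was trained with; loading weights into a mismatched architecture will
    # typically fail with shape errors on the parameter tensors.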

    # Load checkpoint
    try:
        checkpoint_data = manager.load_checkpoint(session_id, model=model)
        print(f"✅ Model loaded from: {checkpoint_data['checkpoint_path']}")
        metrics = checkpoint_data['model_data']['metrics']
        print(f"📊 Training metrics - Loss: {metrics['loss']:.4f}, "
              f"K: {metrics['K_negentropy']:.3f}, "
              f"C: {metrics['C_complexity']:.3f}, "
              f"S: {metrics['S_symbiosis']:.3f}")
    except Exception as e:
        print(f"❌ Failed to load checkpoint: {e}")
        return False

    # Test inference
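    # eval() puts layers such as dropout into inference mode, and no_grad()
    # disables gradient tracking so the forward passes run faster and use
    # less memory.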
    model.eval()
    with torch.no_grad():
        print("\n🧠 Testing model inference...")

        # Test 1: Simple alternating pattern
        test_input1 = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.long)
        output1 = model(test_input1)
        if isinstance(output1, tuple):
            logits1, telemetry1 = output1
            print(f"✅ Forward pass successful, output shape: {logits1.shape}")
            print(f"📡 Telemetry keys: {list(telemetry1.keys())}")
        else:
            logits1 = output1
            print(f"✅ Forward pass successful, output shape: {logits1.shape}")

        # Get predictions
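        # The model emits two logits (classes 0 and 1) per bit position; if the
        # head returns a flattened tensor instead of (batch, seq_len, 2), we
        # reshape before taking the argmax over the class dimension.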
        if logits1.dim() == 3:
            predictions1 = torch.argmax(logits1, dim=-1)
        else:
            predictions1 = torch.argmax(logits1.reshape(1, 8, 2), dim=-1)
        print(f"📥 Input: {test_input1.squeeze().tolist()}")
        print(f"📤 Output: {predictions1.squeeze().tolist()}")

        # Test 2: Random pattern
        test_input2 = torch.randint(0, 2, (1, 16), dtype=torch.long)
        output2 = model(test_input2)
        if isinstance(output2, tuple):
            logits2, telemetry2 = output2
        else:
            logits2 = output2
        predictions2 = torch.argmax(logits2.reshape(1, 16, 2), dim=-1)
        print(f"\n📥 Random input: {test_input2.squeeze().tolist()}")
        print(f"📤 Model output: {predictions2.squeeze().tolist()}")

        # Test 3: Compression/Decompression
        print("\n🗜️ Testing compression features...")
        # Create a longer sequence for compression testing
        long_sequence = torch.randint(0, 2, (1, 64), dtype=torch.long)

        # Test compression
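        # The ratio printed below is compressed length over original length:
        # values under 1.0 mean the codec shortened the sequence. Uniformly
        # random bits have little structure to exploit, so a ratio near 1.0
        # is normal here.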
        compressed = compress_bits_batch(long_sequence)
        print(f"Original length: {long_sequence.shape[-1]}")
        print(f"Compressed length: {len(compressed[0])}")
        print(f"Compression ratio: {len(compressed[0]) / long_sequence.shape[-1]:.2f}")

        # Test decompression
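        # Losslessness requires decompress(compress(x)) == x bit-for-bit;
        # torch.equal verifies that round-trip exactly.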
        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(long_sequence, decompressed)
        print(f"✅ Compression/decompression successful: {compression_success}")

        # Test 4: Safety metrics computation
        print("\n🛡️ Testing safety metrics...")

        def compute_safety_metrics(predictions, targets):
            pred_bits = predictions.float().flatten()
            target_bits = targets.float().flatten()

            # K metric (Negentropy)
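            # For a Bernoulli(p) bit stream the Shannon entropy
            # H(p) = -p*log2(p) - (1-p)*log2(1-p) lies in [0, 1], so
            # negentropy K = 1 - H is 0 for a fair coin and 1 for a
            # constant (all-0 or all-1) stream.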
            prob_1 = pred_bits.mean().item()
            prob_0 = 1 - prob_1
            if prob_0 > 0 and prob_1 > 0:
                entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
                negentropy = 1.0 - entropy
            else:
                negentropy = 1.0

            # C metric (Complexity)
            changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
            complexity = changes / len(pred_bits) if len(pred_bits) > 1 else 0.0

            # S metric (Symbiosis)
            target_mean = target_bits.mean()
            pred_mean = pred_bits.mean()
            symbiosis = 1.0 - abs(target_mean - pred_mean).item()

            return {
                'K_negentropy': negentropy,
                'C_complexity': complexity,
                'S_symbiosis': symbiosis,
            }
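
        # Worked example (assuming the model echoes its input exactly): for the
        # alternating pattern [0, 1, 0, 1, 0, 1, 0, 1], prob_1 = 0.5 gives
        # entropy = 1.0 and K = 0.0; 7 transitions over 8 bits give C = 0.875;
        # identical bit rates give S = 1.0.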

        # Test on several patterns
        test_patterns = [
            [0, 1, 0, 1, 0, 1, 0, 1],  # Alternating
            [1, 1, 1, 1, 0, 0, 0, 0],  # Block pattern
            [0, 1, 1, 0, 1, 0, 1, 1],  # Mixed
        ]
        for i, pattern in enumerate(test_patterns):
            test_seq = torch.tensor([pattern], dtype=torch.long)
            model_out = model(test_seq)
            if isinstance(model_out, tuple):
                model_logits, _ = model_out
            else:
                model_logits = model_out
            model_preds = torch.argmax(model_logits.reshape(1, len(pattern), 2), dim=-1)
            metrics = compute_safety_metrics(model_preds, test_seq)
            print(f"Pattern {i + 1}: K={metrics['K_negentropy']:.3f}, "
                  f"C={metrics['C_complexity']:.3f}, "
                  f"S={metrics['S_symbiosis']:.3f}")

    # Storage usage report
    print("\n💾 Storage usage report:")
    usage = manager.get_storage_usage()
    print(f"Total storage used: {usage['total_gb']:.3f} GB")
    print(f"Training sessions: {usage['num_sessions']}")
    print(f"Best models saved: {usage['num_best_models']}")
    for session in usage['sessions'][:3]:  # Top 3 sessions by size
        print(f"  - {session['session_id']}: {session['size_gb']:.3f} GB "
              f"({session['num_checkpoints']} checkpoints)")

    print("\n🎉 Model testing completed successfully!")
    return True


if __name__ == "__main__":
    success = test_trained_model()
    if success:
        print("✅ ALL TESTS PASSED!")
    else:
        print("❌ Some tests failed")