BitTransformerLM / test_trained_model.py
WCNegentropy's picture
πŸ€– Updated BitTransformerLM from development space
36c78b1 verified
raw
history blame
6.95 kB
#!/usr/bin/env python3
"""
Smoke-test the most recently trained BitTransformerLM checkpoint: forward inference, compression round trip, K/C/S safety metrics, and checkpoint storage usage.
"""
import torch
import numpy as np
import logging
from enhanced_checkpoint_system import create_checkpoint_manager
from bit_transformer.model import BitTransformerLM
from bit_transformer.compression import compress_bits_batch, model_output_decompress
logger = logging.getLogger(__name__)
def _find_latest_session(manager):
    """Return the name of the most recently modified session directory.

    Args:
        manager: Checkpoint manager exposing ``sessions_dir``; assumed to be a
            ``pathlib.Path`` (it is used with ``iterdir``/``stat``/``name``)
            -- TODO confirm.

    Returns:
        The directory name of the newest session, or ``None`` when the
        sessions directory is missing or holds no session directories.
    """
    sessions_dir = manager.sessions_dir
    if not sessions_dir.is_dir():
        return None
    # Only directories are treated as sessions, so a stray file inside
    # sessions_dir is never handed to load_checkpoint as a session id.
    sessions = [entry for entry in sessions_dir.iterdir() if entry.is_dir()]
    if not sessions:
        return None
    return max(sessions, key=lambda entry: entry.stat().st_mtime).name


def _split_model_output(output):
    """Return ``(logits, telemetry)`` from a model forward result.

    The model may return either a bare logits tensor or a
    ``(logits, telemetry)`` tuple; telemetry is ``None`` in the former case.
    """
    if isinstance(output, tuple):
        logits, telemetry = output
        return logits, telemetry
    return output, None


def _predict_bits(logits, bits):
    """Argmax *logits* into one predicted bit per input bit.

    Args:
        logits: Model logits. A 3-D tensor is argmaxed over its last axis;
            anything else is reshaped to ``(batch, seq_len, 2)`` first.
        bits: The ``(batch, seq_len)`` input that produced *logits*; its shape
            drives the reshape instead of a hard-coded sequence length.

    Returns:
        Integer tensor of predicted bits.
    """
    if logits.dim() != 3:
        logits = logits.reshape(bits.shape[0], bits.shape[-1], 2)
    return torch.argmax(logits, dim=-1)


def _compute_safety_metrics(predictions, targets):
    """Compute the K/C/S diagnostics for predicted bits versus target bits.

    Returns a dict with:
        ``K_negentropy``: 1 - binary entropy (log2) of the predicted-ones
            rate; 1.0 when the predictions are all 0s or all 1s.
        ``C_complexity``: count of flips between consecutive predicted bits
            divided by the number of predicted bits (0.0 for length <= 1).
        ``S_symbiosis``: 1 - |mean(targets) - mean(predictions)|.
    """
    pred_bits = predictions.float().flatten()
    target_bits = targets.float().flatten()

    # K metric (Negentropy)
    prob_1 = pred_bits.mean().item()
    prob_0 = 1 - prob_1
    if prob_0 > 0 and prob_1 > 0:
        entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
        negentropy = 1.0 - entropy
    else:
        # log2(0) is undefined; a constant bit sequence has zero entropy.
        negentropy = 1.0

    # C metric (Complexity): bit flips normalized by sequence length.
    changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
    complexity = changes / len(pred_bits) if len(pred_bits) > 1 else 0.0

    # S metric (Symbiosis): agreement between target and predicted bit rates.
    symbiosis = 1.0 - abs(target_bits.mean() - pred_bits.mean()).item()

    return {
        'K_negentropy': negentropy,
        'C_complexity': complexity,
        'S_symbiosis': symbiosis,
    }


def test_trained_model():
    """Smoke-test the most recently trained BitTransformerLM checkpoint.

    Loads the newest session from the enhanced checkpoint manager into a model
    built with the configuration below. It then runs:

    - forward passes on fixed and random bit patterns;
    - a compression/decompression round trip;
    - the K/C/S safety metrics;
    - a storage usage report.

    Returns:
        ``True`` when every step ran and the compression round trip
        reproduced its input. ``False`` when no session exists, the checkpoint
        (or its metrics) cannot be read, or the round trip fails.
    """
    print("πŸ§ͺ Testing trained BitTransformerLM model...")

    manager = create_checkpoint_manager()
    session_id = _find_latest_session(manager)
    if session_id is None:
        print("❌ No training sessions found")
        return False
    print(f"πŸ“ Loading from session: {session_id}")

    # Same architecture as the training run so the checkpoint weights fit.
    model = BitTransformerLM(
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        max_seq_len=128,
        use_checkpoint=True,
        chunk_size=None,
    )

    try:
        checkpoint_data = manager.load_checkpoint(session_id, model=model)
        print(f"βœ… Model loaded from: {checkpoint_data['checkpoint_path']}")
        metrics = checkpoint_data['model_data']['metrics']
        print(f"πŸ“Š Training metrics - Loss: {metrics['loss']:.4f}, "
              f"K: {metrics['K_negentropy']:.3f}, "
              f"C: {metrics['C_complexity']:.3f}, "
              f"S: {metrics['S_symbiosis']:.3f}")
    except Exception as e:
        # Script boundary: any load/metadata failure ends the run with a
        # readable message instead of a traceback.
        print(f"❌ Failed to load checkpoint: {e}")
        return False

    model.eval()
    with torch.no_grad():
        print("\nπŸ”¬ Testing model inference...")

        # Test 1: Simple alternating pattern
        test_input1 = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.long)
        logits1, telemetry1 = _split_model_output(model(test_input1))
        print(f"βœ… Forward pass successful, output shape: {logits1.shape}")
        if telemetry1 is not None:
            print(f"πŸ“‘ Telemetry keys: {list(telemetry1.keys())}")
        predictions1 = _predict_bits(logits1, test_input1)
        print(f"πŸ“₯ Input: {test_input1.squeeze().tolist()}")
        print(f"πŸ“€ Output: {predictions1.squeeze().tolist()}")

        # Test 2: Random pattern
        test_input2 = torch.randint(0, 2, (1, 16), dtype=torch.long)
        logits2, _ = _split_model_output(model(test_input2))
        predictions2 = _predict_bits(logits2, test_input2)
        print(f"\nπŸ“₯ Random input: {test_input2.squeeze().tolist()}")
        print(f"πŸ“€ Model output: {predictions2.squeeze().tolist()}")

        # Test 3: Compression/Decompression round trip
        print("\nπŸ—œοΈ Testing compression features...")
        long_sequence = torch.randint(0, 2, (1, 64), dtype=torch.long)
        compressed = compress_bits_batch(long_sequence)
        original_len = long_sequence.shape[-1]
        compressed_len = len(compressed[0])
        print(f"Original length: {original_len}")
        print(f"Compressed length: {compressed_len}")
        print(f"Compression ratio: {compressed_len / original_len:.2f}")
        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(long_sequence, decompressed)
        print(f"βœ… Compression/decompression successful: {compression_success}")

        # Test 4: Safety metrics computation
        print("\nπŸ›‘οΈ Testing safety metrics...")
        test_patterns = [
            [0, 1, 0, 1, 0, 1, 0, 1],  # Alternating
            [1, 1, 1, 1, 0, 0, 0, 0],  # Block pattern
            [0, 1, 1, 0, 1, 0, 1, 1],  # Mixed
        ]
        for i, pattern in enumerate(test_patterns, start=1):
            test_seq = torch.tensor([pattern], dtype=torch.long)
            pattern_logits, _ = _split_model_output(model(test_seq))
            pattern_preds = _predict_bits(pattern_logits, test_seq)
            safety = _compute_safety_metrics(pattern_preds, test_seq)
            print(f"Pattern {i}: K={safety['K_negentropy']:.3f}, "
                  f"C={safety['C_complexity']:.3f}, "
                  f"S={safety['S_symbiosis']:.3f}")

    # Storage usage report
    print("\nπŸ’Ύ Storage usage report:")
    usage = manager.get_storage_usage()
    print(f"Total storage used: {usage['total_gb']:.3f} GB")
    print(f"Training sessions: {usage['num_sessions']}")
    print(f"Best models saved: {usage['num_best_models']}")
    # First three entries in the order get_storage_usage returns them
    # (presumably sorted by size -- confirm in the checkpoint manager).
    for session in usage['sessions'][:3]:
        print(f"  - {session['session_id']}: {session['size_gb']:.3f} GB "
              f"({session['num_checkpoints']} checkpoints)")

    # A failed round trip is a test failure, not just an informational line.
    if not compression_success:
        print("\n❌ Model testing finished, but the compression round trip failed.")
        return False

    print("\nπŸŽ‰ Model testing completed successfully!")
    return True
if __name__ == "__main__":
    # Run the smoke test and report an overall verdict; any falsy result
    # (including None) counts as a failure.
    passed = test_trained_model()
    print("βœ… ALL TESTS PASSED!" if passed else "❌ Some tests failed")