#!/usr/bin/env python3
"""
Test the trained BitTransformerLM model and validate all features.

Run as a script, this module loads the most recently modified training
session through the checkpoint manager, then exercises the model with a
forward pass on fixed and random bit patterns, a compression round-trip,
the K/C/S safety metrics, and a storage usage report. Progress is printed
to stdout.
"""

import logging

import numpy as np
import torch

from enhanced_checkpoint_system import create_checkpoint_manager
from bit_transformer.model import BitTransformerLM
from bit_transformer.compression import compress_bits_batch, model_output_decompress

logger = logging.getLogger(__name__)

# Architecture used to rebuild the model before the checkpoint is loaded
# into it.
# NOTE(review): presumably this has to match the training-time
# configuration -- confirm against the training script.
_MODEL_CONFIG = {
    "d_model": 256,
    "nhead": 8,
    "num_layers": 4,
    "dim_feedforward": 512,
    "max_seq_len": 128,
    "use_checkpoint": True,
    "chunk_size": None,
}


def _run_model(model, bits: torch.Tensor):
    """Run a forward pass and return ``(logits, telemetry)``.

    The model output is accepted either as bare logits or as a
    ``(logits, telemetry)`` tuple; ``telemetry`` is ``None`` for bare logits.
    """
    output = model(bits)
    if isinstance(output, tuple):
        logits, telemetry = output
        return logits, telemetry
    return output, None


def _logits_to_bits(logits: torch.Tensor, bits: torch.Tensor) -> torch.Tensor:
    """Return the arg-max class per position for ``logits``.

    Logits that are not already 3-D are first reshaped to
    ``(batch, seq_len, 2)`` using the shape of the 2-D input ``bits``.
    """
    if logits.dim() != 3:
        batch, seq_len = bits.shape
        logits = logits.reshape(batch, seq_len, 2)
    return torch.argmax(logits, dim=-1)


def _compute_safety_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> dict:
    """Compute the K/C/S metrics of predicted bits against target bits.

    Returns:
        A dict with three entries:

        * ``K_negentropy``: one minus the binary entropy of the fraction of
          ones among the predictions (1.0 when all predictions are equal).
        * ``C_complexity``: number of adjacent prediction pairs that differ,
          divided by the number of predictions (0.0 for fewer than two).
        * ``S_symbiosis``: one minus the absolute difference between the
          mean of the targets and the mean of the predictions.
    """
    pred_bits = predictions.float().flatten()
    target_bits = targets.float().flatten()

    # K metric (Negentropy)
    prob_1 = pred_bits.mean().item()
    prob_0 = 1 - prob_1
    if prob_0 > 0 and prob_1 > 0:
        entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
        negentropy = 1.0 - entropy
    else:
        # A constant prediction has zero entropy; log2(0) is avoided here.
        negentropy = 1.0

    # C metric (Complexity)
    changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
    complexity = changes / len(pred_bits) if len(pred_bits) > 1 else 0.0

    # S metric (Symbiosis)
    symbiosis = 1.0 - abs(target_bits.mean() - pred_bits.mean()).item()

    return {
        'K_negentropy': negentropy,
        'C_complexity': complexity,
        'S_symbiosis': symbiosis,
    }


def _check_compression_roundtrip(seq_len: int = 64) -> bool:
    """Compress and decompress a random bit sequence and report the result.

    Returns:
        ``True`` when the decompressed bits equal the original bits.
    """
    long_sequence = torch.randint(0, 2, (1, seq_len), dtype=torch.long)

    compressed = compress_bits_batch(long_sequence)
    compressed_len = len(compressed[0])
    print(f"Original length: {long_sequence.shape[-1]}")
    print(f"Compressed length: {compressed_len}")
    print(f"Compression ratio: {compressed_len / long_sequence.shape[-1]:.2f}")

    decompressed = model_output_decompress(compressed)
    # NOTE(review): the dtype returned by model_output_decompress is not
    # visible from this file. Cast to the input dtype so the check compares
    # bit values rather than tripping over a dtype difference.
    decompressed = torch.as_tensor(decompressed).to(long_sequence.dtype)
    compression_success = torch.equal(long_sequence, decompressed)

    status_icon = "✅" if compression_success else "❌"
    print(f"{status_icon} Compression/decompression successful: {compression_success}")
    return compression_success


def _print_storage_report(manager) -> None:
    """Print the storage usage summary provided by the checkpoint manager."""
    print("\n💾 Storage usage report:")
    usage = manager.get_storage_usage()
    print(f"Total storage used: {usage['total_gb']:.3f} GB")
    print(f"Training sessions: {usage['num_sessions']}")
    print(f"Best models saved: {usage['num_best_models']}")

    # Only the first three entries are shown.
    # NOTE(review): the original comment calls these the top 3 sessions by
    # size, which relies on the manager's ordering -- confirm.
    for session in usage['sessions'][:3]:
        print(f" - {session['session_id']}: {session['size_gb']:.3f} GB "
              f"({session['num_checkpoints']} checkpoints)")


def test_trained_model():
    """Test the most recent trained model.

    Returns:
        bool: ``True`` when a checkpoint was loaded and the compression
        round-trip reproduced its input; ``False`` when no training session
        exists, the checkpoint could not be loaded, or the round-trip
        check failed.
    """
    print("🧪 Testing trained BitTransformerLM model...")

    # Load checkpoint manager
    manager = create_checkpoint_manager()

    # Find the most recent session (latest modification time)
    sessions = list(manager.sessions_dir.iterdir())
    if not sessions:
        print("❌ No training sessions found")
        return False

    latest_session = max(sessions, key=lambda path: path.stat().st_mtime)
    session_id = latest_session.name
    print(f"📁 Loading from session: {session_id}")

    # Initialize model with same config
    model = BitTransformerLM(**_MODEL_CONFIG)

    # Load checkpoint
    try:
        checkpoint_data = manager.load_checkpoint(session_id, model=model)
        print(f"✅ Model loaded from: {checkpoint_data['checkpoint_path']}")

        metrics = checkpoint_data['model_data']['metrics']
        print(f"📊 Training metrics - Loss: {metrics['loss']:.4f}, "
              f"K: {metrics['K_negentropy']:.3f}, "
              f"C: {metrics['C_complexity']:.3f}, "
              f"S: {metrics['S_symbiosis']:.3f}")
    except Exception as e:
        # Top-level boundary of a diagnostic script: report and stop, but
        # keep the traceback available to anyone running with DEBUG logging.
        logger.debug("Checkpoint load failed", exc_info=True)
        print(f"❌ Failed to load checkpoint: {e}")
        return False

    # Test inference
    model.eval()
    with torch.no_grad():
        print("\n🔬 Testing model inference...")

        # Test 1: Simple alternating pattern
        test_input1 = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.long)
        logits1, telemetry1 = _run_model(model, test_input1)
        print(f"✅ Forward pass successful, output shape: {logits1.shape}")
        if telemetry1 is not None:
            print(f"📡 Telemetry keys: {list(telemetry1.keys())}")

        predictions1 = _logits_to_bits(logits1, test_input1)
        print(f"📥 Input:  {test_input1.squeeze().tolist()}")
        print(f"📤 Output: {predictions1.squeeze().tolist()}")

        # Test 2: Random pattern
        test_input2 = torch.randint(0, 2, (1, 16), dtype=torch.long)
        logits2, _ = _run_model(model, test_input2)
        predictions2 = _logits_to_bits(logits2, test_input2)
        print(f"\n📥 Random input:  {test_input2.squeeze().tolist()}")
        print(f"📤 Model output: {predictions2.squeeze().tolist()}")

        # Test 3: Compression/Decompression
        print("\n🗜️ Testing compression features...")
        compression_success = _check_compression_roundtrip()

        # Test 4: Safety metrics computation
        print("\n🛡️ Testing safety metrics...")
        test_patterns = [
            [0, 1, 0, 1, 0, 1, 0, 1],  # Alternating
            [1, 1, 1, 1, 0, 0, 0, 0],  # Block pattern
            [0, 1, 1, 0, 1, 0, 1, 1],  # Mixed
        ]
        for index, pattern in enumerate(test_patterns, start=1):
            test_seq = torch.tensor([pattern], dtype=torch.long)
            model_logits, _ = _run_model(model, test_seq)
            model_preds = _logits_to_bits(model_logits, test_seq)

            metrics = _compute_safety_metrics(model_preds, test_seq)
            print(f"Pattern {index}: K={metrics['K_negentropy']:.3f}, "
                  f"C={metrics['C_complexity']:.3f}, "
                  f"S={metrics['S_symbiosis']:.3f}")

    # Storage usage report
    _print_storage_report(manager)

    # The compression round-trip is the only check with a pass/fail outcome;
    # do not claim overall success when it failed.
    if not compression_success:
        print("\n❌ Model testing finished with failures: "
              "compression round-trip mismatch")
        return False

    print("\n🎉 Model testing completed successfully!")
    return True


if __name__ == "__main__":
    success = test_trained_model()
    if success:
        print("✅ ALL TESTS PASSED!")
    else:
        print("❌ Some tests failed")