""" |
|
Test the trained BitTransformerLM model and validate all features. |
|
""" |
|
|
|
import logging

import numpy as np
import torch

from enhanced_checkpoint_system import create_checkpoint_manager
from bit_transformer.model import BitTransformerLM
from bit_transformer.compression import compress_bits_batch, model_output_decompress

logger = logging.getLogger(__name__)
|
|
def test_trained_model():
    """Test the most recent trained model."""
    print("🧪 Testing trained BitTransformerLM model...")

    # Locate the most recent training session tracked by the checkpoint manager.
    manager = create_checkpoint_manager()

    sessions = list(manager.sessions_dir.iterdir())
    if not sessions:
        print("❌ No training sessions found")
        return False

    latest_session = max(sessions, key=lambda x: x.stat().st_mtime)
    session_id = latest_session.name
    print(f"📁 Loading from session: {session_id}")
|
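    # NOTE: this configuration is assumed to mirror the architecture used during
    # training; if it drifts from the checkpointed model, loading the weights
    # below is expected to fail with shape mismatches.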
    model = BitTransformerLM(
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        max_seq_len=128,
        use_checkpoint=True,
        chunk_size=None,
    )
|
    try:
        checkpoint_data = manager.load_checkpoint(session_id, model=model)
        print(f"✅ Model loaded from: {checkpoint_data['checkpoint_path']}")

        metrics = checkpoint_data['model_data']['metrics']
        print(f"📊 Training metrics - Loss: {metrics['loss']:.4f}, "
              f"K: {metrics['K_negentropy']:.3f}, "
              f"C: {metrics['C_complexity']:.3f}, "
              f"S: {metrics['S_symbiosis']:.3f}")
    except Exception as e:
        print(f"❌ Failed to load checkpoint: {e}")
        return False
|
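    # Run the functional checks below in eval mode with gradients disabled.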
    model.eval()
    with torch.no_grad():
        print("\n🔄 Testing model inference...")
|
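        # Test 1: a fixed alternating eight-bit pattern.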
        test_input1 = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.long)
        output1 = model(test_input1)

        # The model may return raw logits or a (logits, telemetry) tuple.
        if isinstance(output1, tuple):
            logits1, telemetry1 = output1
            print(f"✅ Forward pass successful, output shape: {logits1.shape}")
            print(f"📡 Telemetry keys: {list(telemetry1.keys())}")
        else:
            logits1 = output1
            print(f"✅ Forward pass successful, output shape: {logits1.shape}")

        # Logits hold two scores (bit 0 / bit 1) per position, so argmax over the
        # last dimension recovers the predicted bit sequence.
        if logits1.dim() == 3:
            predictions1 = torch.argmax(logits1, dim=-1)
        else:
            predictions1 = torch.argmax(logits1.reshape(1, 8, 2), dim=-1)

        print(f"📥 Input: {test_input1.squeeze().tolist()}")
        print(f"📤 Output: {predictions1.squeeze().tolist()}")
|
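        # Test 2: a random 16-bit sequence exercises a longer context.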
        test_input2 = torch.randint(0, 2, (1, 16), dtype=torch.long)
        output2 = model(test_input2)

        if isinstance(output2, tuple):
            logits2, telemetry2 = output2
        else:
            logits2 = output2

        predictions2 = torch.argmax(logits2.reshape(1, 16, 2), dim=-1)
        print(f"\n📥 Random input: {test_input2.squeeze().tolist()}")
        print(f"📤 Model output: {predictions2.squeeze().tolist()}")
|
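        # Compression round-trip: decompressing should reproduce the original bits
        # exactly. Note that a random bit sequence is generally hard to compress,
        # so the printed ratio may not be much below 1.0.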
print("\nποΈ Testing compression features...") |
|
|
|
|
|
long_sequence = torch.randint(0, 2, (1, 64), dtype=torch.long) |
|
|
|
|
|
compressed = compress_bits_batch(long_sequence) |
|
print(f"Original length: {long_sequence.shape[-1]}") |
|
print(f"Compressed length: {len(compressed[0])}") |
|
print(f"Compression ratio: {len(compressed[0]) / long_sequence.shape[-1]:.2f}") |
|
|
|
|
|
decompressed = model_output_decompress(compressed) |
|
compression_success = torch.equal(long_sequence, decompressed) |
|
print(f"β
Compression/decompression successful: {compression_success}") |
|
|
|
|
|
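        # Safety-metric sanity check: K (negentropy), C (complexity), and
        # S (symbiosis) are recomputed here with simple local stand-in formulas,
        # independent of the telemetry the model itself reports.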
print("\nπ‘οΈ Testing safety metrics...") |
|
|
|
def compute_safety_metrics(predictions, targets): |
|
pred_bits = predictions.float().flatten() |
|
target_bits = targets.float().flatten() |
|
|
|
|
|
prob_1 = pred_bits.mean().item() |
|
prob_0 = 1 - prob_1 |
|
if prob_0 > 0 and prob_1 > 0: |
|
entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1) |
|
negentropy = 1.0 - entropy |
|
else: |
|
negentropy = 1.0 |
|
|
|
|
|
changes = (pred_bits[1:] != pred_bits[:-1]).sum().item() |
|
complexity = changes / len(pred_bits) if len(pred_bits) > 1 else 0.0 |
|
|
|
|
|
target_mean = target_bits.mean() |
|
pred_mean = pred_bits.mean() |
|
symbiosis = 1.0 - abs(target_mean - pred_mean).item() |
|
|
|
return { |
|
'K_negentropy': negentropy, |
|
'C_complexity': complexity, |
|
'S_symbiosis': symbiosis |
|
} |
|
|
|
|
|
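        # Reference point for the formulas above: a prediction equal to the
        # alternating pattern [0, 1, 0, 1, 0, 1, 0, 1] has prob_1 = 0.5, so
        # entropy = 1.0 and K = 0.0; it flips at 7 of 8 positions, so C = 0.875;
        # and matching the target's mean exactly gives S = 1.0.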
        test_patterns = [
            [0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [0, 1, 1, 0, 1, 0, 1, 1],
        ]

        for i, pattern in enumerate(test_patterns):
            test_seq = torch.tensor([pattern], dtype=torch.long)
            model_out = model(test_seq)
            if isinstance(model_out, tuple):
                model_logits, _ = model_out
            else:
                model_logits = model_out

            model_preds = torch.argmax(model_logits.reshape(1, len(pattern), 2), dim=-1)
            metrics = compute_safety_metrics(model_preds, test_seq)

            print(f"Pattern {i+1}: K={metrics['K_negentropy']:.3f}, "
                  f"C={metrics['C_complexity']:.3f}, "
                  f"S={metrics['S_symbiosis']:.3f}")
|
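    # Summarize on-disk checkpoint usage; only the first three sessions in the
    # report are listed individually.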
print(f"\nπΎ Storage usage report:") |
|
usage = manager.get_storage_usage() |
|
print(f"Total storage used: {usage['total_gb']:.3f} GB") |
|
print(f"Training sessions: {usage['num_sessions']}") |
|
print(f"Best models saved: {usage['num_best_models']}") |
|
|
|
for session in usage['sessions'][:3]: |
|
print(f" - {session['session_id']}: {session['size_gb']:.3f} GB " |
|
f"({session['num_checkpoints']} checkpoints)") |
|
|
|
print("\nπ Model testing completed successfully!") |
|
return True |
|
|
|
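# Allow running this file directly as a standalone smoke test.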
if __name__ == "__main__": |
|
success = test_trained_model() |
|
if success: |
|
print("β
ALL TESTS PASSED!") |
|
else: |
|
print("β Some tests failed") |