WrinkleBrane / simple_demo.py
WCNegentropy's picture
📚 Updated with scientifically rigorous documentation
dc2b9f3 verified
raw
history blame
11.7 kB
#!/usr/bin/env python3
"""
Simple WrinkleBrane Demo
Shows basic functionality and a few simple optimizations working.
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent / "src"))
import torch
import numpy as np
import matplotlib.pyplot as plt
from wrinklebrane.membrane_bank import MembraneBank
from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim
def create_test_patterns(K, H, W, device):
    """Create diverse test patterns for demonstration.

    Pattern ``i`` is selected by ``i % 4`` and grows with ``i // 4``:

    * 0 - filled disc of radius ``2 + i // 4`` centred at row ``H // 2``,
      column ``W // 2``;
    * 1 - filled square of side ``4 + i // 4`` centred in the image
      (left blank when the side exceeds ``H`` or ``W``);
    * 2 - horizontal line at row ``H // 2 + i // 4 - 1`` (blank if off-image);
    * 3 - vertical line at column ``W // 2 + i // 4 - 1`` (blank if off-image).

    Args:
        K: Number of patterns to create (``K >= 0``).
        H: Pattern height (number of rows).
        W: Pattern width (number of columns).
        device: Torch device (or device string) for the returned tensor.

    Returns:
        A float tensor of shape ``(K, H, W)`` whose entries are 0.0 or 1.0.
    """
    if K == 0:
        # torch.stack rejects an empty list; return a well-shaped empty batch.
        return torch.zeros(0, H, W, device=device)

    # Broadcastable row / column index grids with shapes (H, 1) and (1, W).
    rows = torch.arange(H, device=device).view(H, 1)
    cols = torch.arange(W, device=device).view(1, W)
    center_row, center_col = H // 2, W // 2

    patterns = []
    for i in range(K):
        pattern = torch.zeros(H, W, device=device)
        kind, step = i % 4, i // 4
        if kind == 0:  # Circles
            radius = 2 + step
            # Rows are measured from the row centre and columns from the
            # column centre, so the disc stays centred when H != W.
            mask = (rows - center_row) ** 2 + (cols - center_col) ** 2 <= radius ** 2
            pattern[mask] = 1.0
        elif kind == 1:  # Squares
            size = 4 + step
            # Test the side length directly. Checking ``start + size <= H``
            # lets size == H + 1 through with start == -1, and Python slicing
            # would wrap that negative start to a single corner pixel.
            if size <= H and size <= W:
                top = (H - size) // 2
                left = (W - size) // 2
                pattern[top:top + size, left:left + size] = 1.0
        elif kind == 2:  # Horizontal lines
            y = center_row + step - 1
            if 0 <= y < H:
                pattern[y, :] = 1.0
        else:  # Vertical lines
            x = center_col + step - 1
            if 0 <= x < W:
                pattern[:, x] = 1.0
        patterns.append(pattern)
    return torch.stack(patterns)
def demonstrate_basic_functionality():
    """Store a handful of synthetic patterns and check how well they come back.

    Uses a 32-layer membrane bank with Hadamard codes to store 8 test
    patterns of size 16x16, reads them back through the slicer, and prints
    per-pattern PSNR/SSIM, the averages, and a quality verdict.

    Returns:
        The mean PSNR (dB) over the stored patterns.
    """
    print("πŸŒŠ WrinkleBrane Basic Functionality Demo")
    print("="*40)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 32, 16, 16, 8
    print(f"Configuration: L={L}, H={H}, W={W}, K={K} patterns")
    print(f"Device: {device}")

    # Setup: an empty bank, one code per pattern, and the matching read-out slicer.
    bank = MembraneBank(L, H, W, device=device)
    bank.allocate(B)
    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)
    patterns = create_test_patterns(K, H, W, device)
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    print("\nπŸ“ Storing patterns...")
    target = store_pairs(bank.read(), C, keys, patterns, alphas)
    # NOTE(review): presumably bank.write adds its argument to the state, so
    # writing the difference leaves the bank equal to ``target`` — confirm.
    bank.write(target - bank.read())

    print("πŸ“– Retrieving patterns...")
    readouts = slicer(bank.read()).squeeze(0)

    print("\nπŸ“Š Fidelity Results:")
    total_psnr = 0
    total_ssim = 0
    for idx in range(K):
        reference = patterns[idx].cpu().numpy()
        recovered = readouts[idx].cpu().numpy()
        psnr_val = psnr(reference, recovered)
        ssim_val = ssim(reference, recovered)
        total_psnr += psnr_val
        total_ssim += ssim_val
        print(f" Pattern {idx}: PSNR={psnr_val:.1f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = total_psnr / K
    avg_ssim = total_ssim / K
    print(f"\n🎯 Summary:")
    print(f" Average PSNR: {avg_psnr:.1f}dB")
    print(f" Average SSIM: {avg_ssim:.4f}")

    if avg_psnr > 100:
        verdict = "βœ… EXCELLENT: >100dB PSNR (near-perfect recall)"
    elif avg_psnr > 50:
        verdict = "βœ… GOOD: >50dB PSNR (high-quality recall)"
    else:
        verdict = "⚠️ LOW: <50dB PSNR (may need optimization)"
    print(verdict)
    return avg_psnr
def compare_code_types():
    """Compare different orthogonal code types.

    For Hadamard, DCT and Gaussian codes (L=32 layers, K=16 codes) this
    prints the off-diagonal coherence statistics returned by
    ``coherence_stats`` and the mean/std PSNR of a store-then-retrieve round
    trip on 16x16 test patterns. It then names the code type with the highest
    mean PSNR.

    Returns:
        Dict mapping code-type name to a dict with keys ``'orthogonality'``
        (the ``max_abs_offdiag`` statistic) and ``'performance'`` (mean PSNR).
    """
    print("\n🧬 Code Types Comparison")
    print("="*40)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, K = 32, 16
    B, H, W = 1, 16, 16
    # All code sets are built up front, in this order, before any round trip.
    code_types = {
        "Hadamard": hadamard_codes(L, K).to(device),
        "DCT": dct_codes(L, K).to(device),
        "Gaussian": gaussian_codes(L, K).to(device)
    }

    results = {}
    for name, codes in code_types.items():
        print(f"\n{name} Codes:")

        # Orthogonality analysis
        stats = coherence_stats(codes)
        max_offdiag = stats['max_abs_offdiag']
        print(f" Max off-diagonal correlation: {max_offdiag:.6f}")
        print(f" Mean off-diagonal correlation: {stats['mean_abs_offdiag']:.6f}")

        # Round trip through a fresh single-batch bank.
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)
        slicer = make_slicer(codes)
        patterns = create_test_patterns(K, H, W, device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)
        stored = store_pairs(bank.read(), codes, keys, patterns, alphas)
        bank.write(stored - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        psnr_values = [
            psnr(patterns[idx].cpu().numpy(), readouts[idx].cpu().numpy())
            for idx in range(K)
        ]
        avg_psnr = np.mean(psnr_values)
        std_psnr = np.std(psnr_values)
        print(f" Performance: {avg_psnr:.1f}Β±{std_psnr:.1f}dB PSNR")

        results[name] = {
            'orthogonality': max_offdiag,
            'performance': avg_psnr
        }

    best_name = max(results, key=lambda code_name: results[code_name]['performance'])
    print(f"\nπŸ† Best Performing: {best_name} ({results[best_name]['performance']:.1f}dB)")
    return results
def test_capacity_scaling(*, L=64, H=16, W=16, pattern_counts=(8, 16, 32, 64)):
    """Test how performance scales with number of stored patterns.

    For each ``K`` in ``pattern_counts``, stores ``K`` patterns from
    ``create_test_patterns`` in a fresh ``L``-layer membrane bank using
    Hadamard codes. It reads them back through the slicer and prints the
    average and minimum PSNR, then prints a summary table.

    Args:
        L: Number of membrane layers (code length). Defaults to 64.
        H: Pattern height. Defaults to 16.
        W: Pattern width. Defaults to 16.
        pattern_counts: Pattern counts to try. Each must satisfy
            ``1 <= K <= L``, because at most ``L`` codes of length ``L`` can
            be mutually orthogonal. Defaults to ``(8, 16, 32, 64)``.

    Returns:
        A list with one dict per ``K``, holding keys ``'K'``,
        ``'capacity_ratio'`` (``K / L``), ``'avg_psnr'`` and ``'min_psnr'``.

    Raises:
        ValueError: if any pattern count lies outside ``1..L``.
    """
    counts = list(pattern_counts)
    # Validate before doing any work so a bad count fails with a clear message
    # instead of somewhere inside code generation or torch.stack.
    for K in counts:
        if not 1 <= K <= L:
            raise ValueError(f"pattern count K={K} must be between 1 and L={L}")

    print("\nπŸ“ˆ Capacity Scaling Test")
    print("="*40)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    results = []
    for K in counts:
        print(f"\nTesting {K} patterns (capacity: {K/L:.1%})...")
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(1)
        # Hadamard codes, the same family used by demonstrate_basic_functionality.
        C = hadamard_codes(L, K).to(device)
        slicer = make_slicer(C)
        patterns = create_test_patterns(K, H, W, device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)

        # Store and retrieve
        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        psnr_values = [
            psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy())
            for i in range(K)
        ]
        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)
        print(f" PSNR: {avg_psnr:.1f}dB average, {min_psnr:.1f}dB minimum")
        results.append({
            'K': K,
            'capacity_ratio': K / L,
            'avg_psnr': avg_psnr,
            'min_psnr': min_psnr
        })

    # Show scaling trend (>100 dB: ok, >50 dB: warning, otherwise: fail).
    print(f"\nπŸ“Š Capacity Scaling Summary:")
    for result in results:
        status = "βœ…" if result['avg_psnr'] > 100 else "⚠️" if result['avg_psnr'] > 50 else "❌"
        print(f" {result['capacity_ratio']:3.0%} capacity: {result['avg_psnr']:5.1f}dB {status}")
    return results
def demonstrate_wave_interference():
    """Show the wave interference pattern that gives WrinkleBrane its name.

    Writes two hand-made 8x8 patterns into a 16-layer membrane bank using
    the first two Hadamard codes:
    * a single lit pixel;
    * a half-amplitude cross through the centre.

    It prints the per-layer norms of the resulting membrane state and the
    retrieval PSNR of both patterns. It then compares the total membrane norm
    with the sum of the two pattern norms.

    Returns:
        The membrane state with its leading (batch) dimension squeezed out.
    """
    print("\nπŸŒŠ Wave Interference Demonstration")
    print("="*40)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 16, 8, 8
    mid_row, mid_col = H // 2, W // 2

    bank = MembraneBank(L, H, W, device=device)
    bank.allocate(1)
    K = 2
    C = hadamard_codes(L, K).to(device)

    # Pattern 1: a single lit pixel at the centre.
    dot = torch.zeros(H, W, device=device)
    dot[mid_row, mid_col] = 1.0
    # Pattern 2: a cross of half-amplitude lines through the centre.
    cross = torch.zeros(H, W, device=device)
    cross[mid_row, :] = 0.5
    cross[:, mid_col] = 0.5

    stacked = torch.stack([dot, cross])
    keys = torch.tensor([0, 1], device=device)
    alphas = torch.ones(2, device=device)

    updated = store_pairs(bank.read(), C, keys, stacked, alphas)
    bank.write(updated - bank.read())

    membrane_state = bank.read().squeeze(0)  # Remove batch dimension: [L, H, W]
    print(f"Membrane state shape: {membrane_state.shape}")
    print(f"Pattern 1 energy: {torch.norm(dot):.3f}")
    print(f"Pattern 2 energy: {torch.norm(cross):.3f}")

    # torch.norm's default is the Frobenius norm of each layer slice.
    layer_energies = [torch.norm(membrane_state[layer]).item() for layer in range(L)]
    print(f"Layer energies (first 8): {[f'{e:.3f}' for e in layer_energies[:8]]}")

    slicer = make_slicer(C)
    readouts = slicer(bank.read()).squeeze(0)
    psnr1 = psnr(dot.cpu().numpy(), readouts[0].cpu().numpy())
    psnr2 = psnr(cross.cpu().numpy(), readouts[1].cpu().numpy())
    print(f"\nRetrieval fidelity:")
    print(f" Pattern 1: {psnr1:.1f}dB PSNR")
    print(f" Pattern 2: {psnr2:.1f}dB PSNR")

    # NOTE(review): the "no interference" baseline adds the two pattern norms
    # linearly. With orthogonal codes, norms usually combine in quadrature, and
    # the overall scale depends on how hadamard_codes normalises its columns.
    # Confirm before reading much into the factor printed below.
    total_membrane_energy = torch.norm(membrane_state).item()
    expected_energy = torch.norm(dot).item() + torch.norm(cross).item()
    print(f"\nWave interference analysis:")
    print(f" Total membrane energy: {total_membrane_energy:.3f}")
    print(f" Expected (no interference): {expected_energy:.3f}")
    print(f" Interference factor: {total_membrane_energy/expected_energy:.3f}")
    return membrane_state
def main():
    """Run complete WrinkleBrane demonstration.

    Seeds torch and NumPy with 42 for reproducible results, runs every demo
    section in turn, and prints a short summary of their results.

    Returns:
        True when every section completed. False when any section raised;
        in that case the error message and traceback are printed.
    """
    rule = "="*50
    print("πŸš€ WrinkleBrane Complete Demonstration")
    print(rule)
    torch.manual_seed(42)  # Reproducible results
    np.random.seed(42)

    try:
        basic_psnr = demonstrate_basic_functionality()
        code_results = compare_code_types()
        capacity_results = test_capacity_scaling()
        membrane_state = demonstrate_wave_interference()

        print("\n" + rule)
        print("πŸŽ‰ WrinkleBrane Demonstration Complete!")
        print(rule)
        print("\nπŸ“‹ Key Results:")
        print(f"β€’ Basic fidelity: {basic_psnr:.1f}dB PSNR")
        best_code = max(code_results, key=lambda name: code_results[name]['performance'])
        print(f"β€’ Best code type: {best_code}")
        largest = capacity_results[-1]
        print(f"β€’ Maximum capacity: {largest['K']} patterns at {largest['avg_psnr']:.1f}dB")
        print(f"β€’ Membrane state shape: {membrane_state.shape}")

        if basic_psnr > 100:
            print("\nπŸ† WrinkleBrane is performing EXCELLENTLY!")
            print(" Wave-interference associative memory working at near-perfect fidelity!")
        else:
            print(f"\nβœ… WrinkleBrane is working correctly with {basic_psnr:.1f}dB fidelity")
    except Exception as e:
        # Top-level boundary: report the failure and signal it to the caller.
        print(f"\n❌ Demo failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False
    return True
if __name__ == "__main__":
    # main() returns False when any demo section raised; turn that into a
    # non-zero exit status so shells and CI jobs can detect the failure.
    sys.exit(0 if main() else 1)