#!/usr/bin/env python3
"""
Simple WrinkleBrane Demo

Stores test patterns in a WrinkleBrane membrane bank and reads them back,
compares several code families, measures how fidelity changes with the number
of stored patterns, and inspects the layered membrane state.
"""

import sys
from pathlib import Path

# Make the in-repo ``src`` package importable when this script is run directly.
sys.path.append(str(Path(__file__).resolve().parent / "src"))

import torch
import numpy as np
import matplotlib.pyplot as plt

from wrinklebrane.membrane_bank import MembraneBank
from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim


def _default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _store_patterns(codes, patterns, L, H, W, device, batch=1):
    """Create a fresh ``MembraneBank`` and store ``patterns`` in it.

    Pattern ``k`` is written under key ``k`` with unit strength (alpha = 1).

    Args:
        codes: code matrix passed to ``store_pairs``.
        patterns: stacked patterns; ``patterns.shape[0]`` is the number stored.
        L, H, W: membrane dimensions forwarded to ``MembraneBank``.
        device: torch device for the bank, keys and alphas.
        batch: batch size passed to ``bank.allocate``.

    Returns:
        The populated ``MembraneBank``.
    """
    bank = MembraneBank(L, H, W, device=device)
    bank.allocate(batch)

    count = patterns.shape[0]
    keys = torch.arange(count, device=device)
    alphas = torch.ones(count, device=device)

    new_state = store_pairs(bank.read(), codes, keys, patterns, alphas)
    # The bank is handed the delta from its current state. Presumably write()
    # is additive, so the bank ends up holding new_state; confirm against
    # MembraneBank.write.
    bank.write(new_state - bank.read())
    return bank


def _retrieve_patterns(bank, codes):
    """Slice every stored pattern back out of ``bank`` using ``codes``.

    The leading batch dimension is squeezed away, so (as the demos below
    assume) ``readouts[k]`` is the reconstruction stored under key ``k``.
    """
    slicer = make_slicer(codes)
    return slicer(bank.read()).squeeze(0)


def _psnr_per_pattern(patterns, readouts):
    """Return the PSNR of each readout against its original pattern, in key order."""
    return [
        psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy())
        for i in range(patterns.shape[0])
    ]


def create_test_patterns(K, H, W, device):
    """Create ``K`` diverse binary ``H x W`` test patterns on ``device``.

    Pattern ``i`` cycles through four shapes, which grow with ``i // 4``:

    * ``i % 4 == 0``: filled circle centred in the image, radius ``2 + i // 4``.
    * ``i % 4 == 1``: centred filled square with side ``4 + i // 4``.
    * ``i % 4 == 2``: horizontal line at row ``H // 2 + i // 4 - 1``.
    * ``i % 4 == 3``: vertical line at column ``W // 2 + i // 4 - 1``.

    A square or line that does not fit inside the image is left blank
    (all zeros).

    Returns:
        Tensor of shape ``(K, H, W)``.
    """
    if K <= 0:
        return torch.zeros(0, H, W, device=device)

    # Broadcastable row/column coordinate grids for the circle mask.
    rows = torch.arange(H, device=device).view(H, 1)
    cols = torch.arange(W, device=device).view(1, W)
    center_y, center_x = H // 2, W // 2

    patterns = []
    for i in range(K):
        shape, step = i % 4, i // 4
        pattern = torch.zeros(H, W, device=device)
        if shape == 0:  # Circles
            radius = 2 + step
            inside = (cols - center_x) ** 2 + (rows - center_y) ** 2 <= radius ** 2
            pattern[inside] = 1.0
        elif shape == 1:  # Squares
            size = 4 + step
            # The square must fit. Otherwise the start offset goes negative and
            # Python slicing wraps around, drawing a stray corner pixel.
            if size <= H and size <= W:
                top = (H - size) // 2
                left = (W - size) // 2
                pattern[top:top + size, left:left + size] = 1.0
        elif shape == 2:  # Horizontal lines
            row = center_y + step - 1
            if 0 <= row < H:
                pattern[row, :] = 1.0
        else:  # Vertical lines
            col = center_x + step - 1
            if 0 <= col < W:
                pattern[:, col] = 1.0
        patterns.append(pattern)
    return torch.stack(patterns)


def demonstrate_basic_functionality():
    """Store 8 patterns with Hadamard codes, read them back and report fidelity.

    Returns:
        The mean PSNR (dB) over all stored patterns.
    """
    print("🌊 WrinkleBrane Basic Functionality Demo")
    print("=" * 40)

    device = _default_device()
    B, L, H, W, K = 1, 32, 16, 16, 8

    print(f"Configuration: L={L}, H={H}, W={W}, K={K} patterns")
    print(f"Device: {device}")

    C = hadamard_codes(L, K).to(device)
    patterns = create_test_patterns(K, H, W, device)

    print("\n📝 Storing patterns...")
    bank = _store_patterns(C, patterns, L, H, W, device, batch=B)

    print("📖 Retrieving patterns...")
    readouts = _retrieve_patterns(bank, C)

    print("\n📊 Fidelity Results:")
    total_psnr = 0
    total_ssim = 0
    for i in range(K):
        original = patterns[i].cpu().numpy()
        retrieved = readouts[i].cpu().numpy()
        psnr_val = psnr(original, retrieved)
        ssim_val = ssim(original, retrieved)
        total_psnr += psnr_val
        total_ssim += ssim_val
        print(f" Pattern {i}: PSNR={psnr_val:.1f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = total_psnr / K
    avg_ssim = total_ssim / K

    print("\n🎯 Summary:")
    print(f" Average PSNR: {avg_psnr:.1f}dB")
    print(f" Average SSIM: {avg_ssim:.4f}")

    if avg_psnr > 100:
        print("✅ EXCELLENT: >100dB PSNR (near-perfect recall)")
    elif avg_psnr > 50:
        print("✅ GOOD: >50dB PSNR (high-quality recall)")
    else:
        print("⚠️ LOW: <50dB PSNR (may need optimization)")

    return avg_psnr


def compare_code_types():
    """Compare Hadamard, DCT and Gaussian codes on orthogonality and recall.

    Returns:
        Dict mapping code name to ``{'orthogonality': max |off-diagonal|
        correlation, 'performance': mean PSNR in dB}``.
    """
    print("\n🧬 Code Types Comparison")
    print("=" * 40)

    device = _default_device()
    L, K = 32, 16
    H, W = 16, 16

    # All code sets are generated up front, so the RNG consumption order
    # (relevant for the Gaussian codes) does not depend on the loop below.
    code_types = {
        "Hadamard": hadamard_codes(L, K).to(device),
        "DCT": dct_codes(L, K).to(device),
        "Gaussian": gaussian_codes(L, K).to(device),
    }
    # The patterns are deterministic and identical for every code type.
    patterns = create_test_patterns(K, H, W, device)

    results = {}
    for name, codes in code_types.items():
        print(f"\n{name} Codes:")

        stats = coherence_stats(codes)
        print(f" Max off-diagonal correlation: {stats['max_abs_offdiag']:.6f}")
        print(f" Mean off-diagonal correlation: {stats['mean_abs_offdiag']:.6f}")

        bank = _store_patterns(codes, patterns, L, H, W, device)
        readouts = _retrieve_patterns(bank, codes)
        psnr_values = _psnr_per_pattern(patterns, readouts)

        avg_psnr = np.mean(psnr_values)
        std_psnr = np.std(psnr_values)
        print(f" Performance: {avg_psnr:.1f}±{std_psnr:.1f}dB PSNR")

        results[name] = {
            'orthogonality': stats['max_abs_offdiag'],
            'performance': avg_psnr,
        }

    best_name, best = max(results.items(), key=lambda item: item[1]['performance'])
    print(f"\n🏆 Best Performing: {best_name} ({best['performance']:.1f}dB)")

    return results


def test_capacity_scaling():
    """Measure recall fidelity as the number of stored patterns grows up to L.

    Returns:
        List of dicts with keys ``'K'``, ``'capacity_ratio'`` (K / L),
        ``'avg_psnr'`` and ``'min_psnr'``, one per tested pattern count.
    """
    print("\n📈 Capacity Scaling Test")
    print("=" * 40)

    device = _default_device()
    L, H, W = 64, 16, 16

    # The largest count equals L, the number of membrane layers.
    pattern_counts = [8, 16, 32, 64]
    results = []

    for K in pattern_counts:
        print(f"\nTesting {K} patterns (capacity: {K/L:.1%})...")

        C = hadamard_codes(L, K).to(device)
        patterns = create_test_patterns(K, H, W, device)

        bank = _store_patterns(C, patterns, L, H, W, device)
        readouts = _retrieve_patterns(bank, C)
        psnr_values = _psnr_per_pattern(patterns, readouts)

        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)
        print(f" PSNR: {avg_psnr:.1f}dB average, {min_psnr:.1f}dB minimum")

        results.append({
            'K': K,
            'capacity_ratio': K / L,
            'avg_psnr': avg_psnr,
            'min_psnr': min_psnr,
        })

    print("\n📊 Capacity Scaling Summary:")
    for result in results:
        avg = result['avg_psnr']
        if avg > 100:
            status = "✅"
        elif avg > 50:
            status = "⚠️"
        else:
            status = "❌"
        print(f" {result['capacity_ratio']:3.0%} capacity: {avg:5.1f}dB {status}")

    return results


def demonstrate_wave_interference():
    """Store a point and a cross, then print per-layer and total membrane energies.

    Returns:
        The membrane state with the batch dimension squeezed away.
    """
    print("\n🌊 Wave Interference Demonstration")
    print("=" * 40)

    device = _default_device()
    L, H, W = 16, 8, 8

    K = 2
    C = hadamard_codes(L, K).to(device)

    # Pattern 1: single point in the centre.
    pattern1 = torch.zeros(H, W, device=device)
    pattern1[H // 2, W // 2] = 1.0

    # Pattern 2: half-intensity cross through the centre.
    pattern2 = torch.zeros(H, W, device=device)
    pattern2[H // 2, :] = 0.5
    pattern2[:, W // 2] = 0.5

    patterns = torch.stack([pattern1, pattern2])
    bank = _store_patterns(C, patterns, L, H, W, device)

    # Drop the batch dimension; the original comment describes the result as [L, H, W].
    membrane_state = bank.read().squeeze(0)
    print(f"Membrane state shape: {membrane_state.shape}")
    print(f"Pattern 1 energy: {torch.norm(pattern1):.3f}")
    print(f"Pattern 2 energy: {torch.norm(pattern2):.3f}")

    # Frobenius norm of each layer along the first dimension.
    layer_energies = [torch.norm(layer).item() for layer in membrane_state]
    print(f"Layer energies (first 8): {[f'{e:.3f}' for e in layer_energies[:8]]}")

    readouts = _retrieve_patterns(bank, C)
    psnr1 = psnr(pattern1.cpu().numpy(), readouts[0].cpu().numpy())
    psnr2 = psnr(pattern2.cpu().numpy(), readouts[1].cpu().numpy())

    print("\nRetrieval fidelity:")
    print(f" Pattern 1: {psnr1:.1f}dB PSNR")
    print(f" Pattern 2: {psnr2:.1f}dB PSNR")

    total_membrane_energy = torch.norm(membrane_state).item()
    # NOTE(review): with orthonormal codes, per-pattern energies would add in
    # quadrature (sqrt of the sum of squares) rather than linearly. This
    # "no interference" baseline may therefore overstate the expected total;
    # confirm against the normalisation used by hadamard_codes.
    expected_energy = torch.norm(pattern1).item() + torch.norm(pattern2).item()

    print("\nWave interference analysis:")
    print(f" Total membrane energy: {total_membrane_energy:.3f}")
    print(f" Expected (no interference): {expected_energy:.3f}")
    print(f" Interference factor: {total_membrane_energy/expected_energy:.3f}")

    return membrane_state


def main():
    """Run every demonstration in sequence.

    Returns:
        True if all demos completed, False if any raised (the traceback is printed).
    """
    print("🚀 WrinkleBrane Complete Demonstration")
    print("=" * 50)

    torch.manual_seed(42)  # Reproducible results
    np.random.seed(42)

    try:
        basic_psnr = demonstrate_basic_functionality()
        code_results = compare_code_types()
        capacity_results = test_capacity_scaling()
        membrane_state = demonstrate_wave_interference()

        print("\n" + "=" * 50)
        print("🎉 WrinkleBrane Demonstration Complete!")
        print("=" * 50)

        best_code = max(code_results.items(), key=lambda item: item[1]['performance'])[0]
        largest = capacity_results[-1]

        print("\n📋 Key Results:")
        print(f"• Basic fidelity: {basic_psnr:.1f}dB PSNR")
        print(f"• Best code type: {best_code}")
        print(f"• Maximum capacity: {largest['K']} patterns at {largest['avg_psnr']:.1f}dB")
        print(f"• Membrane state shape: {membrane_state.shape}")

        if basic_psnr > 100:
            print("\n🏆 WrinkleBrane is performing EXCELLENTLY!")
            print(" Wave-interference associative memory working at near-perfect fidelity!")
        else:
            print(f"\n✅ WrinkleBrane is working correctly with {basic_psnr:.1f}dB fidelity")

    except Exception as e:
        # Top-level boundary: report and signal failure via the return value.
        print(f"\n❌ Demo failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    return True


if __name__ == "__main__":
    # Propagate failure to the shell instead of always exiting with status 0.
    sys.exit(0 if main() else 1)