""" |
|
Simple WrinkleBrane Demo |
|
Shows basic functionality and a few simple optimizations working. |
|
""" |
|
|
|
import sys |
|
from pathlib import Path |
|
sys.path.append(str(Path(__file__).resolve().parent / "src")) |
|
|
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from wrinklebrane.membrane_bank import MembraneBank |
|
from wrinklebrane.codes import hadamard_codes, dct_codes, gaussian_codes, coherence_stats |
|
from wrinklebrane.slicer import make_slicer |
|
from wrinklebrane.write_ops import store_pairs |
|
from wrinklebrane.metrics import psnr, ssim |
|
|
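# Demo flow: allocate a MembraneBank, write each test pattern into it with
# store_pairs() using a per-key code vector, then read the patterns back through
# a slicer built from the same codes and score recall with PSNR/SSIM.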
|
def create_test_patterns(K, H, W, device):
    """Create diverse test patterns for demonstration."""
    patterns = []

    for i in range(K):
        pattern = torch.zeros(H, W, device=device)

        if i % 4 == 0:
            # Filled circle at the grid center; radius grows every fourth pattern.
            center = (H // 2, W // 2)
            radius = 2 + (i // 4)
            for y in range(H):
                for x in range(W):
                    if (y - center[0])**2 + (x - center[1])**2 <= radius**2:
                        pattern[y, x] = 1.0
        elif i % 4 == 1:
            # Centered square whose side length grows every fourth pattern.
            size = 4 + (i // 4)
            start = (H - size) // 2
            end = start + size
            if end <= H and end <= W:
                pattern[start:end, start:end] = 1.0
        elif i % 4 == 2:
            # Horizontal line near the middle row.
            y = H // 2 + (i // 4) - 1
            if 0 <= y < H:
                pattern[y, :] = 1.0
        else:
            # Vertical line near the middle column.
            x = W // 2 + (i // 4) - 1
            if 0 <= x < W:
                pattern[:, x] = 1.0

        patterns.append(pattern)

    return torch.stack(patterns)
|
|
def demonstrate_basic_functionality():
    """Show WrinkleBrane working with near-perfect recall."""
    print("WrinkleBrane Basic Functionality Demo")
    print("="*40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 32, 16, 16, 8

    print(f"Configuration: L={L}, H={H}, W={W}, K={K} patterns")
    print(f"Device: {device}")

    bank = MembraneBank(L, H, W, device=device)
    bank.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    patterns = create_test_patterns(K, H, W, device)
    keys = torch.arange(K, device=device)
    alphas = torch.ones(K, device=device)

    print("\nStoring patterns...")
    # store_pairs returns the updated membrane tensor; write the delta so the bank holds M.
    M = store_pairs(bank.read(), C, keys, patterns, alphas)
    bank.write(M - bank.read())

    print("Retrieving patterns...")
    readouts = slicer(bank.read()).squeeze(0)

    print("\nFidelity Results:")
    total_psnr = 0
    total_ssim = 0

    for i in range(K):
        original = patterns[i]
        retrieved = readouts[i]

        psnr_val = psnr(original.cpu().numpy(), retrieved.cpu().numpy())
        ssim_val = ssim(original.cpu().numpy(), retrieved.cpu().numpy())

        total_psnr += psnr_val
        total_ssim += ssim_val

        print(f"  Pattern {i}: PSNR={psnr_val:.1f}dB, SSIM={ssim_val:.4f}")

    avg_psnr = total_psnr / K
    avg_ssim = total_ssim / K

    print("\nSummary:")
    print(f"  Average PSNR: {avg_psnr:.1f}dB")
    print(f"  Average SSIM: {avg_ssim:.4f}")

    if avg_psnr > 100:
        print("EXCELLENT: >100dB PSNR (near-perfect recall)")
    elif avg_psnr > 50:
        print("GOOD: >50dB PSNR (high-quality recall)")
    else:
        print("LOW: <50dB PSNR (may need optimization)")

    return avg_psnr
|
|
def compare_code_types():
    """Compare different orthogonal code types."""
    print("\nCode Types Comparison")
    print("="*40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, K = 32, 16

    code_types = {
        "Hadamard": hadamard_codes(L, K).to(device),
        "DCT": dct_codes(L, K).to(device),
        "Gaussian": gaussian_codes(L, K).to(device)
    }

    results = {}

    for name, codes in code_types.items():
        print(f"\n{name} Codes:")

        # Lower off-diagonal correlation means the code vectors are closer to orthogonal.
        stats = coherence_stats(codes)
        print(f"  Max off-diagonal correlation: {stats['max_abs_offdiag']:.6f}")
        print(f"  Mean off-diagonal correlation: {stats['mean_abs_offdiag']:.6f}")

        B, H, W = 1, 16, 16
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)

        slicer = make_slicer(codes)
        patterns = create_test_patterns(K, H, W, device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)

        M = store_pairs(bank.read(), codes, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        psnr_values = []
        for i in range(K):
            psnr_val = psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy())
            psnr_values.append(psnr_val)

        avg_psnr = np.mean(psnr_values)
        std_psnr = np.std(psnr_values)

        print(f"  Performance: {avg_psnr:.1f}±{std_psnr:.1f}dB PSNR")

        results[name] = {
            'orthogonality': stats['max_abs_offdiag'],
            'performance': avg_psnr
        }

    best_code = max(results.items(), key=lambda x: x[1]['performance'])
    print(f"\nBest Performing: {best_code[0]} ({best_code[1]['performance']:.1f}dB)")

    return results
|
|
def test_capacity_scaling():
    """Test how performance scales with the number of stored patterns."""
    print("\nCapacity Scaling Test")
    print("="*40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 64, 16, 16

    pattern_counts = [8, 16, 32, 64]
    results = []

    for K in pattern_counts:
        print(f"\nTesting {K} patterns (capacity: {K/L:.1%})...")

        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(1)

        C = hadamard_codes(L, K).to(device)
        slicer = make_slicer(C)

        patterns = create_test_patterns(K, H, W, device)
        keys = torch.arange(K, device=device)
        alphas = torch.ones(K, device=device)

        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        psnr_values = []
        for i in range(K):
            psnr_val = psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy())
            psnr_values.append(psnr_val)

        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)

        print(f"  PSNR: {avg_psnr:.1f}dB average, {min_psnr:.1f}dB minimum")

        result = {
            'K': K,
            'capacity_ratio': K / L,
            'avg_psnr': avg_psnr,
            'min_psnr': min_psnr
        }
        results.append(result)

    print("\nCapacity Scaling Summary:")
    for result in results:
        status = "OK" if result['avg_psnr'] > 100 else "WARN" if result['avg_psnr'] > 50 else "FAIL"
        print(f"  {result['capacity_ratio']:3.0%} capacity: {result['avg_psnr']:5.1f}dB {status}")

    return results
|
|
def demonstrate_wave_interference():
    """Show the wave interference pattern that gives WrinkleBrane its name."""
    print("\nWave Interference Demonstration")
    print("="*40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 16, 8, 8

    bank = MembraneBank(L, H, W, device=device)
    bank.allocate(1)

    K = 2
    C = hadamard_codes(L, K).to(device)

    # Pattern 1: a single bright pixel at the center.
    pattern1 = torch.zeros(H, W, device=device)
    pattern1[H//2, W//2] = 1.0

    # Pattern 2: a half-intensity cross through the center.
    pattern2 = torch.zeros(H, W, device=device)
    pattern2[H//2, :] = 0.5
    pattern2[:, W//2] = 0.5

    patterns = torch.stack([pattern1, pattern2])
    keys = torch.tensor([0, 1], device=device)
    alphas = torch.ones(2, device=device)

    M = store_pairs(bank.read(), C, keys, patterns, alphas)
    bank.write(M - bank.read())

    membrane_state = bank.read().squeeze(0)

    print(f"Membrane state shape: {membrane_state.shape}")
    print(f"Pattern 1 energy: {torch.norm(pattern1):.3f}")
    print(f"Pattern 2 energy: {torch.norm(pattern2):.3f}")

    # Per-layer energy shows how the two stored patterns superpose across the membrane stack.
    layer_energies = []
    for l in range(L):
        energy = torch.norm(membrane_state[l]).item()
        layer_energies.append(energy)

    print(f"Layer energies (first 8): {[f'{e:.3f}' for e in layer_energies[:8]]}")

    slicer = make_slicer(C)
    readouts = slicer(bank.read()).squeeze(0)

    psnr1 = psnr(pattern1.cpu().numpy(), readouts[0].cpu().numpy())
    psnr2 = psnr(pattern2.cpu().numpy(), readouts[1].cpu().numpy())

    print("\nRetrieval fidelity:")
    print(f"  Pattern 1: {psnr1:.1f}dB PSNR")
    print(f"  Pattern 2: {psnr2:.1f}dB PSNR")

    total_membrane_energy = torch.norm(membrane_state).item()
    expected_energy = torch.norm(pattern1).item() + torch.norm(pattern2).item()

    print("\nWave interference analysis:")
    print(f"  Total membrane energy: {total_membrane_energy:.3f}")
    print(f"  Expected (no interference): {expected_energy:.3f}")
    print(f"  Interference factor: {total_membrane_energy/expected_energy:.3f}")

    return membrane_state
|
|
def main():
    """Run the complete WrinkleBrane demonstration."""
    print("WrinkleBrane Complete Demonstration")
    print("="*50)

    torch.manual_seed(42)
    np.random.seed(42)

    try:
        basic_psnr = demonstrate_basic_functionality()
        code_results = compare_code_types()
        capacity_results = test_capacity_scaling()
        membrane_state = demonstrate_wave_interference()

        print("\n" + "="*50)
        print("WrinkleBrane Demonstration Complete!")
        print("="*50)

        print("\nKey Results:")
        print(f"- Basic fidelity: {basic_psnr:.1f}dB PSNR")
        print(f"- Best code type: {max(code_results.items(), key=lambda x: x[1]['performance'])[0]}")
        print(f"- Maximum capacity: {capacity_results[-1]['K']} patterns at {capacity_results[-1]['avg_psnr']:.1f}dB")
        print(f"- Membrane state shape: {membrane_state.shape}")

        if basic_psnr > 100:
            print("\nWrinkleBrane is performing EXCELLENTLY!")
            print("  Wave-interference associative memory working at near-perfect fidelity!")
        else:
            print(f"\nWrinkleBrane is working correctly with {basic_psnr:.1f}dB fidelity")

    except Exception as e:
        print(f"\nDemo failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    return True
|
|
if __name__ == "__main__":
    main()