###########################################################################################################################################
#||- - - |8.19.2025| - - - || HEBBIAN BLOOM || - - - | 1990two | - - -||#
###########################################################################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import hashlib
SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8
#||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||#
def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    # Replace NaNs with 0 and +/-inf with the corresponding clamp bound before clamping.
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
return torch.clamp(tensor, min_val, max_val)
def safe_cosine_similarity(a, b, dim=-1, eps=EPS):
if a.dtype != torch.float32:
a = a.float()
if b.dtype != torch.float32:
b = b.float()
a_norm = torch.norm(a, dim=dim, keepdim=True).clamp(min=eps)
b_norm = torch.norm(b, dim=dim, keepdim=True).clamp(min=eps)
return torch.sum(a * b, dim=dim, keepdim=True) / (a_norm * b_norm)
def item_to_vector(item, vector_dim=64):
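    """Encode an arbitrary item as a fixed-length float32 vector.

    Strings use their MD5 digest, numbers use sinusoidal position-style
    features, tensors are flattened then padded or truncated, and anything
    else falls back to a hash-seeded random vector.
    """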
if isinstance(item, str):
hash_obj = hashlib.md5(item.encode())
hash_bytes = hash_obj.digest()
vector = torch.tensor([b / 255.0 for b in hash_bytes], dtype=torch.float32)
if len(vector) < vector_dim:
padding = torch.zeros(vector_dim - len(vector), dtype=torch.float32)
vector = torch.cat([vector, padding])
else:
vector = vector[:vector_dim]
elif isinstance(item, (int, float)):
vector = torch.zeros(vector_dim, dtype=torch.float32)
for i in range(vector_dim // 2):
freq = 10000 ** (-2 * i / vector_dim)
vector[2*i] = math.sin(item * freq)
vector[2*i + 1] = math.cos(item * freq)
elif torch.is_tensor(item):
vector = item.flatten().float()
if len(vector) < vector_dim:
padding = torch.zeros(vector_dim - len(vector), dtype=torch.float32, device=vector.device)
vector = torch.cat([vector, padding])
else:
vector = vector[:vector_dim]
else:
hash_val = hash(str(item)) % (2**31)
gen = torch.Generator(device='cpu')
gen.manual_seed(hash_val)
vector = torch.randn(vector_dim, generator=gen, dtype=torch.float32)
return make_safe(vector)
###########################################################################################################################################
###############################################- - - LEARNABLE HASH FUNCTION - - -#####################################################
class LearnableHashFunction(nn.Module):
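    """Learnable hash function whose output bits are modulated by Hebbian weights.

    A small MLP maps an item vector to activations in [-1, 1]; per-bit Hebbian
    weights, learned thresholds, and a co-activation matrix adapt online so that
    frequently co-occurring items can develop correlated hash patterns.
    """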
def __init__(self, input_dim, hash_output_bits=32, learning_rate=0.01):
super().__init__()
self.input_dim = input_dim
self.hash_output_bits = hash_output_bits
self.learning_rate = learning_rate
self.hash_network = nn.Sequential(
nn.Linear(input_dim, input_dim * 2),
nn.LayerNorm(input_dim * 2),
nn.Tanh(),
nn.Linear(input_dim * 2, hash_output_bits),
nn.Tanh() # Output in [-1, 1]
)
self.hebbian_weights = nn.Parameter(torch.ones(hash_output_bits) * 0.1)
self.plasticity_rate = nn.Parameter(torch.tensor(learning_rate))
self.register_buffer('activity_history', torch.zeros(100, hash_output_bits))
self.register_buffer('history_pointer', torch.tensor(0, dtype=torch.long))
self.coactivation_matrix = nn.Parameter(torch.eye(hash_output_bits) * 0.1)
self.activation_threshold = nn.Parameter(torch.zeros(hash_output_bits))
def compute_hash_activation(self, item_vector):
if item_vector.dim() == 1:
item_vector = item_vector.unsqueeze(0)
item_vector = item_vector.to(next(self.hash_network.parameters()).device, dtype=torch.float32)
base_hash = self.hash_network(item_vector).squeeze(0)
hebbian_modulation = torch.tanh(self.hebbian_weights)
modulated_hash = base_hash * hebbian_modulation
thresholded = modulated_hash - self.activation_threshold
hash_probs = torch.sigmoid(thresholded * 10.0) # Sharp sigmoid
return hash_probs, modulated_hash
def get_hash_bits(self, item_vector, deterministic=False):
hash_probs, _ = self.compute_hash_activation(item_vector)
if deterministic:
hash_bits = (hash_probs > 0.5).float()
else:
hash_bits = torch.bernoulli(hash_probs)
return hash_bits
def hebbian_update(self, item_vector, co_occurring_items=None):
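        """Apply a Hebbian update: strengthen each hash bit in proportion to its
        current activity, log the activation pattern in the rolling history, and
        optionally reinforce co-activations with co-occurring items."""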
hash_probs, modulated_hash = self.compute_hash_activation(item_vector)
with torch.no_grad():
ptr = int(self.history_pointer.item())
self.activity_history[ptr % self.activity_history.size(0)].copy_(hash_probs.detach())
self.history_pointer.add_(1)
self.history_pointer.remainder_(self.activity_history.size(0))
plasticity_rate = torch.clamp(self.plasticity_rate, 0.001, 0.1)
activity_strength = torch.abs(modulated_hash)
hebbian_delta = plasticity_rate * activity_strength * hash_probs
with torch.no_grad():
            # Detach so the Hebbian increment stays outside the autograd graph.
            self.hebbian_weights.data.add_(hebbian_delta.detach() * 0.05)
self.hebbian_weights.data.clamp_(-2.0, 2.0)
if co_occurring_items is not None:
self.update_coactivation_matrix(hash_probs, co_occurring_items)
return hash_probs
def update_coactivation_matrix(self, current_activation, co_occurring_items):
with torch.no_grad():
for co_item in co_occurring_items:
co_item_vector = item_to_vector(co_item, self.input_dim).to(current_activation.device)
co_activation, _ = self.compute_hash_activation(co_item_vector)
coactivation_update = torch.outer(current_activation, co_activation)
learning_rate = 0.01
self.coactivation_matrix.data.add_(learning_rate * coactivation_update)
self.coactivation_matrix.data.clamp_(-1.0, 1.0)
def get_similar_patterns(self, item_vector, top_k=5):
current_probs, _ = self.compute_hash_activation(item_vector)
similarities = []
for i in range(self.activity_history.shape[0]):
hist_pattern = self.activity_history[i]
if torch.sum(hist_pattern) > 0: # Non-zero pattern
similarity = safe_cosine_similarity(
current_probs.unsqueeze(0),
hist_pattern.unsqueeze(0)
).squeeze()
similarities.append((i, float(similarity.item())))
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
def apply_forgetting(self, forget_rate=0.99):
with torch.no_grad():
self.hebbian_weights.data.mul_(forget_rate)
self.coactivation_matrix.data.mul_(forget_rate)
###########################################################################################################################################
################################################- - - HEBBIAN BLOOM FILTER - - -#######################################################
class HebbianBloomFilter(nn.Module):
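    """Bloom filter whose bit positions come from learnable, Hebbian-adapting hash functions.

    Alongside the bit array it tracks per-bit confidence and access counts,
    keeps the original item vectors for similarity search, and supports
    temporal decay and threshold optimization.
    """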
def __init__(self, capacity=10000, error_rate=0.01, vector_dim=64, num_hash_functions=8):
super().__init__()
self.capacity = capacity
self.error_rate = error_rate
self.vector_dim = vector_dim
self.num_hash_functions = num_hash_functions
self.bit_array_size = self._calculate_bit_array_size(capacity, error_rate)
self.hash_functions = nn.ModuleList([
LearnableHashFunction(vector_dim, hash_output_bits=32)
for _ in range(num_hash_functions)
])
self.register_buffer('bit_array', torch.zeros(self.bit_array_size))
self.register_buffer('confidence_array', torch.zeros(self.bit_array_size))
self.stored_items = {}
self.item_vectors = {}
self.register_buffer('access_counts', torch.zeros(self.bit_array_size))
self.register_buffer('total_items_added', torch.tensor(0, dtype=torch.long))
self.association_strength = nn.Parameter(torch.tensor(0.1))
self.confidence_threshold = nn.Parameter(torch.tensor(0.5))
self.decay_rate = nn.Parameter(torch.tensor(0.999))
def _calculate_bit_array_size(self, capacity, error_rate):
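        """Standard Bloom filter sizing: m = -n * ln(p) / (ln 2)^2 bits for capacity n and target false-positive rate p."""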
return int(-capacity * math.log(error_rate) / (math.log(2) ** 2))
def _get_bit_indices(self, item_vector):
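        """Interpret each hash function's deterministic bit pattern as an integer index into
        the bit array (modulo its size); confidence is how far the bit probabilities sit
        from the uncertain value 0.5."""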
indices = []
confidences = []
for hash_func in self.hash_functions:
hash_bits = hash_func.get_hash_bits(item_vector, deterministic=True)
            # Powers of two; written with ** because `int << tensor` is not supported on all PyTorch versions.
            weights = 2 ** torch.arange(len(hash_bits), device=hash_bits.device, dtype=torch.int64)
bit_index = int((hash_bits.to(dtype=torch.int64) * weights).sum().item())
bit_index = bit_index % self.bit_array_size
hash_probs, _ = hash_func.compute_hash_activation(item_vector)
confidence = torch.mean(torch.abs(hash_probs - 0.5)) * 2 # Distance from uncertain (0.5)
indices.append(bit_index)
confidences.append(confidence.item())
return indices, confidences
def add(self, item, associated_items=None):
item_vector = item_to_vector(item, self.vector_dim)
item_key = str(item)
self.stored_items[item_key] = item
self.item_vectors[item_key] = item_vector
indices, confidences = self._get_bit_indices(item_vector)
with torch.no_grad():
for idx, conf in zip(indices, confidences):
self.bit_array[idx] = 1.0
self.confidence_array[idx] = max(float(self.confidence_array[idx].item()), conf)
self.access_counts[idx] += 1
for hash_func in self.hash_functions:
hash_func.hebbian_update(item_vector, associated_items)
with torch.no_grad():
self.total_items_added.add_(1)
if associated_items:
self._learn_associations(item, associated_items)
return indices
def _learn_associations(self, primary_item, associated_items):
primary_vector = item_to_vector(primary_item, self.vector_dim)
for assoc_item in associated_items:
assoc_vector = item_to_vector(assoc_item, self.vector_dim)
similarity = safe_cosine_similarity(
primary_vector.unsqueeze(0),
assoc_vector.unsqueeze(0)
).squeeze()
            # Reinforce the hash functions only when the associated item is sufficiently similar.
            if float(similarity.item()) > 0.5:
                for hash_func in self.hash_functions:
                    hash_func.hebbian_update(primary_vector, [assoc_item])
def query(self, item, return_confidence=False):
item_vector = item_to_vector(item, self.vector_dim)
indices, confidences = self._get_bit_indices(item_vector)
bit_checks = [self.bit_array[idx].item() > 0 for idx in indices]
is_member = all(bit_checks)
if return_confidence:
bit_confidences = [self.confidence_array[idx].item() for idx in indices]
hash_confidences = confidences
bit_conf = np.mean(bit_confidences) if bit_confidences else 0.0
hash_conf = np.mean(hash_confidences) if hash_confidences else 0.0
access_conf = np.mean([self.access_counts[idx].item() for idx in indices])
access_conf = min(access_conf / 10.0, 1.0) # Normalize
overall_confidence = (bit_conf + hash_conf + access_conf) / 3.0
return is_member, overall_confidence
return is_member
def find_similar_items(self, query_item, top_k=5):
query_vector = item_to_vector(query_item, self.vector_dim)
coact_weights = []
for hash_func in self.hash_functions:
q_act, _ = hash_func.compute_hash_activation(query_vector)
q_weight = torch.matmul(hash_func.coactivation_matrix.t(), q_act)
coact_weights.append((q_act, q_weight))
similarities = []
for item_key, item_vector in self.item_vectors.items():
base_sim = safe_cosine_similarity(
query_vector.unsqueeze(0),
item_vector.unsqueeze(0)
).squeeze().item()
co_sim_sum = 0.0
for (hash_func, (q_act, q_weight)) in zip(self.hash_functions, coact_weights):
i_act, _ = hash_func.compute_hash_activation(item_vector)
co_sim_sum += torch.dot(q_weight, i_act).item() / max(1, len(i_act))
co_sim = co_sim_sum / max(1, len(self.hash_functions))
alpha, beta = 0.6, 0.4
score = alpha * base_sim + beta * co_sim
similarities.append((self.stored_items[item_key], score))
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
def get_hash_statistics(self):
stats = {
'total_items': int(self.total_items_added.item()),
'bit_array_utilization': (self.bit_array > 0).float().mean().item(),
'average_confidence': self.confidence_array.mean().item(),
'hash_function_stats': []
}
for i, hash_func in enumerate(self.hash_functions):
hash_stats = {
'function_id': i,
'hebbian_weights_mean': hash_func.hebbian_weights.mean().item(),
'plasticity_rate': hash_func.plasticity_rate.item(),
'activation_threshold_mean': hash_func.activation_threshold.mean().item()
}
stats['hash_function_stats'].append(hash_stats)
return stats
def apply_temporal_decay(self):
decay_rate = torch.clamp(self.decay_rate, 0.9, 0.999)
with torch.no_grad():
self.confidence_array.mul_(decay_rate)
self.access_counts.mul_(decay_rate)
low_confidence_mask = self.confidence_array < 0.1
self.bit_array[low_confidence_mask] = 0.0
self.confidence_array[low_confidence_mask] = 0.0
for hash_func in self.hash_functions:
hash_func.apply_forgetting(float(decay_rate.item()))
def optimize_structure(self):
with torch.no_grad():
high_access_ratio = (self.access_counts > self.access_counts.mean()).float().mean().item()
adjustment = -0.01 * high_access_ratio
for hash_func in self.hash_functions:
hash_func.activation_threshold.data.add_(adjustment)
hash_func.activation_threshold.data.clamp_(-1.0, 1.0)
###########################################################################################################################################
############################################- - - ASSOCIATIVE HEBBIAN BLOOM SYSTEM - - -###############################################
class AssociativeHebbianBloomSystem(nn.Module):
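    """Ensemble of HebbianBloomFilters behind a learned, load-aware router.

    A softmax selector picks the top filters for each added item (penalized by
    current filter load); membership queries are decided by majority vote over
    the filters, and maintenance applies decay and threshold tuning per filter.
    """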
def __init__(self, capacity=10000, vector_dim=64, num_filters=3):
super().__init__()
self.capacity = capacity
self.vector_dim = vector_dim
self.num_filters = num_filters
self.filters = nn.ModuleList([
HebbianBloomFilter(
capacity=capacity // num_filters,
error_rate=0.01,
vector_dim=vector_dim,
num_hash_functions=6
) for _ in range(num_filters)
])
self.filter_selector = nn.Sequential(
nn.Linear(vector_dim, vector_dim // 2),
nn.ReLU(),
nn.Linear(vector_dim // 2, num_filters),
nn.Softmax(dim=-1)
)
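        # Pairwise association scorer (defined here but not called by any method below).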
self.global_association_net = nn.Sequential(
nn.Linear(vector_dim * 2, vector_dim),
nn.Tanh(),
nn.Linear(vector_dim, 1),
nn.Sigmoid()
)
self.register_buffer('global_access_count', torch.tensor(0, dtype=torch.long))
def add_item(self, item, category=None, associated_items=None):
item_vector = item_to_vector(item, self.vector_dim)
filter_weights = self.filter_selector(item_vector.unsqueeze(0)).squeeze(0)
with torch.no_grad():
loads = torch.tensor([f.total_items_added.item() / max(1, f.capacity) for f in self.filters], dtype=filter_weights.dtype, device=filter_weights.device)
filter_weights = filter_weights - 0.1 * loads
top_k_filters = min(2, self.num_filters) # Use top 2 filters
_, top_indices = torch.topk(filter_weights, top_k_filters)
added_to_filters = []
for filter_idx in top_indices:
filter_obj = self.filters[filter_idx.item()]
indices = filter_obj.add(item, associated_items)
added_to_filters.append((filter_idx.item(), indices))
with torch.no_grad():
self.global_access_count.add_(1)
return added_to_filters
def query_item(self, item, return_detailed=False):
item_vector = item_to_vector(item, self.vector_dim)
results = []
confidences = []
for i, filter_obj in enumerate(self.filters):
is_member, confidence = filter_obj.query(item, return_confidence=True)
results.append(is_member)
confidences.append(confidence)
positive_votes = sum(results)
avg_confidence = np.mean(confidences)
ensemble_decision = positive_votes > len(self.filters) // 2
if return_detailed:
return {
'is_member': ensemble_decision,
'confidence': avg_confidence,
'individual_results': list(zip(results, confidences)),
'positive_votes': positive_votes,
'total_filters': len(self.filters)
}
return ensemble_decision
def find_associations(self, query_item, top_k=10):
all_similarities = []
for filter_obj in self.filters:
similarities = filter_obj.find_similar_items(query_item, top_k)
all_similarities.extend(similarities)
unique_items = {}
for item, similarity in all_similarities:
item_key = str(item)
if item_key in unique_items:
unique_items[item_key] = max(unique_items[item_key], similarity)
else:
unique_items[item_key] = similarity
ranked_items = sorted(unique_items.items(), key=lambda x: x[1], reverse=True)
return ranked_items[:top_k]
def system_maintenance(self):
for filter_obj in self.filters:
filter_obj.apply_temporal_decay()
filter_obj.optimize_structure()
        if int(self.global_access_count.item()) % 1000 == 0:
self._global_optimization()
    def _global_optimization(self):
        # Collect per-filter bit-array utilization; currently used only as a diagnostic snapshot.
        print("Performing global Hebbian Bloom system optimization...")
        filter_utilizations = []
        for filter_obj in self.filters:
            stats = filter_obj.get_hash_statistics()
            filter_utilizations.append(stats['bit_array_utilization'])
        return filter_utilizations
def get_system_statistics(self):
"""Get comprehensive system statistics."""
stats = {
'global_access_count': int(self.global_access_count.item()),
'num_filters': self.num_filters,
'filter_statistics': []
}
for i, filter_obj in enumerate(self.filters):
filter_stats = filter_obj.get_hash_statistics()
filter_stats['filter_id'] = i
stats['filter_statistics'].append(filter_stats)
return stats
###########################################################################################################################################
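# Minimal usage sketch: exercises the associative system end-to-end on a few toy items
# (the item values and capacities below are arbitrary illustration, not part of the module API).
if __name__ == "__main__":
    system = AssociativeHebbianBloomSystem(capacity=300, vector_dim=64, num_filters=3)
    system.add_item("apple", associated_items=["fruit", "red"])
    system.add_item("banana", associated_items=["fruit", "yellow"])
    system.add_item(42)
    print("apple in system:", system.query_item("apple"))
    print("grape query (detailed):", system.query_item("grape", return_detailed=True))
    print("items associated with 'fruit':", system.find_associations("fruit", top_k=3))
    system.system_maintenance()
    print(system.get_system_statistics())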