WCNegentropy committed
Commit 6193257 · verified · 1 Parent(s): c2af0a0

Clear existing file: wrinklebrane_dataset_builder.py

Files changed (1)
  1. wrinklebrane_dataset_builder.py +0 -723
wrinklebrane_dataset_builder.py DELETED
@@ -1,723 +0,0 @@
1
- """
2
- WrinkleBrane Dataset Builder & HuggingFace Integration
3
-
4
- Creates curated datasets optimized for associative memory training with
5
- membrane storage, interference studies, and orthogonality benchmarks.
6
- """
7
-
8
- import os
9
- import json
10
- import gzip
11
- import random
12
- import math
13
- from typing import List, Dict, Any, Optional, Tuple, Union
14
- from pathlib import Path
15
- from datetime import datetime
16
- import tempfile
17
-
18
- import torch
19
- import numpy as np
20
- from datasets import Dataset, DatasetDict
21
- from huggingface_hub import HfApi, login, create_repo
22
-
23
-
24
- class WrinkleBraneDatasetBuilder:
25
- """
26
- Comprehensive dataset builder for WrinkleBrane associative memory training.
27
-
28
- Generates:
29
- - Key-value pairs for associative memory tasks
30
- - Visual patterns (MNIST-style, geometric shapes)
31
- - Interference benchmark sequences
32
- - Orthogonality optimization data
33
- - Persistence decay studies
34
- """
35
-
36
- def __init__(self, hf_token: str, repo_id: str = "WrinkleBrane"):
37
- """Initialize with HuggingFace credentials."""
38
- self.hf_token = hf_token
39
- self.repo_id = repo_id
40
- self.api = HfApi()
41
-
42
- # Login to HuggingFace
43
- login(token=hf_token)
44
-
45
- # Dataset configuration
46
- self.config = {
47
- "version": "1.0.0",
48
- "created": datetime.now().isoformat(),
49
- "model_compatibility": "WrinkleBrane",
50
- "membrane_encoding": "2D_spatial_maps",
51
- "default_H": 64,
52
- "default_W": 64,
53
- "default_L": 64, # membrane layers
54
- "default_K": 64, # codebook size
55
- "total_samples": 20000,
56
- "quality_thresholds": {
57
- "min_fidelity_psnr": 20.0,
58
- "max_interference_rms": 0.1,
59
- "min_orthogonality": 0.8
60
- }
61
- }
62
-
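The quality_thresholds above are recorded in the dataset config but never enforced anywhere in this builder. A minimal sketch of how they could be applied to a generated sample (the helper name passes_quality is hypothetical and not part of this file):

def passes_quality(sample: dict, thresholds: dict) -> bool:
    # Hypothetical filter: a sample passes if every metric it actually defines meets the configured threshold.
    psnr_ok = sample["fidelity_psnr"] is None or sample["fidelity_psnr"] >= thresholds["min_fidelity_psnr"]
    interference_ok = sample["interference_rms"] is None or sample["interference_rms"] <= thresholds["max_interference_rms"]
    orthogonality_ok = sample["orthogonality"] is None or sample["orthogonality"] >= thresholds["min_orthogonality"]
    return psnr_ok and interference_ok and orthogonality_ok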
63
- def generate_visual_memory_pairs(self, num_samples: int = 5000, H: int = 64, W: int = 64) -> List[Dict]:
64
- """Generate visual key-value pairs for associative memory."""
65
- samples = []
66
-
67
- visual_types = [
68
- "mnist_digits",
69
- "geometric_shapes",
70
- "noise_patterns",
71
- "edge_features",
72
- "texture_patches",
73
- "sparse_dots"
74
- ]
75
-
76
- for i in range(num_samples):
77
- visual_type = random.choice(visual_types)
78
-
79
- # Generate key pattern
80
- key_pattern = self._generate_visual_pattern(visual_type, H, W, is_key=True)
81
-
82
- # Generate corresponding value pattern
83
- value_pattern = self._generate_visual_pattern(visual_type, H, W, is_key=False)
84
-
85
- # Compute quality metrics
86
- fidelity_psnr = self._compute_psnr(key_pattern, value_pattern)
87
- orthogonality = self._compute_orthogonality(key_pattern.flatten(), value_pattern.flatten())
88
- compressibility = self._compute_gzip_ratio(key_pattern)
89
-
90
- sample = {
91
- "id": f"visual_{visual_type}_{i:06d}",
92
- "key_pattern": key_pattern.tolist(),
93
- "value_pattern": value_pattern.tolist(),
94
- "pattern_type": visual_type,
95
- "H": H,
96
- "W": W,
97
- "fidelity_psnr": float(fidelity_psnr),
98
- "orthogonality": float(orthogonality),
99
- "compressibility": float(compressibility),
100
- "category": "visual_memory",
101
- # Consistent schema fields
102
- "interference_rms": None,
103
- "persistence_lambda": None,
104
- "codebook_type": None,
105
- "capacity_load": None,
106
- "time_step": None,
107
- "energy_retention": None,
108
- "temporal_correlation": None,
109
- "L": None,
110
- "K": None,
111
- "reconstruction_error": None,
112
- "reconstructed_pattern": None,
113
- "codebook_matrix": None
114
- }
115
- samples.append(sample)
116
-
117
- return samples
118
-
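Each sample serializes its patterns as nested Python lists (via tolist()) so the Arrow backend can store them. A minimal sketch of turning one sample back into the float32 tensors a WrinkleBrane trainer would consume (torch is already imported by this module; how the trainer uses the tensors is an assumption):

def sample_to_tensors(sample: dict):
    # Rebuild the (H, W) key and value patterns from their serialized list form.
    key = torch.tensor(sample["key_pattern"], dtype=torch.float32)
    value = torch.tensor(sample["value_pattern"], dtype=torch.float32)
    return key, value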
119
- def generate_synthetic_maps(self, num_samples: int = 3000, H: int = 64, W: int = 64) -> List[Dict]:
120
- """Generate synthetic spatial pattern mappings."""
121
- samples = []
122
-
123
- map_types = [
124
- "gaussian_fields",
125
- "spiral_patterns",
126
- "frequency_domains",
127
- "cellular_automata",
128
- "fractal_structures",
129
- "gradient_maps"
130
- ]
131
-
132
- for i in range(num_samples):
133
- map_type = random.choice(map_types)
134
-
135
- # Generate synthetic key-value mapping
136
- key_map = self._generate_synthetic_map(map_type, H, W, seed=i*2)
137
- value_map = self._generate_synthetic_map(map_type, H, W, seed=i*2+1)
138
-
139
- # Apply transformation relationship
140
- value_map = self._apply_map_transform(key_map, value_map, map_type)
141
-
142
- # Compute metrics
143
- fidelity_psnr = self._compute_psnr(key_map, value_map)
144
- orthogonality = self._compute_orthogonality(key_map.flatten(), value_map.flatten())
145
- compressibility = self._compute_gzip_ratio(key_map)
146
-
147
- sample = {
148
- "id": f"synthetic_{map_type}_{i:06d}",
149
- "key_pattern": key_map.tolist(),
150
- "value_pattern": value_map.tolist(),
151
- "pattern_type": map_type,
152
- "H": H,
153
- "W": W,
154
- "fidelity_psnr": float(fidelity_psnr),
155
- "orthogonality": float(orthogonality),
156
- "compressibility": float(compressibility),
157
- "category": "synthetic_maps",
158
- # Consistent schema fields
159
- "interference_rms": None,
160
- "persistence_lambda": None,
161
- "codebook_type": None,
162
- "capacity_load": None,
163
- "time_step": None,
164
- "energy_retention": None,
165
- "temporal_correlation": None,
166
- "L": None,
167
- "K": None,
168
- "reconstruction_error": None,
169
- "reconstructed_pattern": None,
170
- "codebook_matrix": None
171
- }
172
- samples.append(sample)
173
-
174
- return samples
175
-
176
- def generate_interference_studies(self, num_samples: int = 2000, H: int = 64, W: int = 64) -> List[Dict]:
177
- """Generate data for studying memory interference and capacity limits."""
178
- samples = []
179
-
180
- # Test different capacity loads
181
- capacity_loads = [0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
182
-
183
- for load in capacity_loads:
184
- load_samples = num_samples // len(capacity_loads)  # distribute samples evenly across the capacity loads
185
-
186
- for i in range(load_samples):
187
- # Generate multiple overlapping patterns to study interference
188
- num_patterns = max(1, int(64 * load)) # Scale with capacity load
189
-
190
- patterns = []
191
- for p in range(min(num_patterns, 10)): # Limit for memory
192
- pattern = np.random.randn(H, W).astype(np.float32)
193
- pattern = (pattern - pattern.mean()) / pattern.std() # Normalize
194
- patterns.append(pattern)
195
-
196
- # Create composite pattern (mean of all stored patterns)
197
- composite = np.sum(patterns, axis=0) / len(patterns)
198
- target = patterns[0] if patterns else composite # Try to retrieve first pattern
199
-
200
- # Compute interference metrics
201
- interference_rms = self._compute_interference_rms(patterns, target)
202
- fidelity_psnr = self._compute_psnr(composite, target)
203
- orthogonality = self._compute_pattern_orthogonality(patterns)
204
-
205
- sample = {
206
- "id": f"interference_load_{load}_{i:06d}",
207
- "key_pattern": composite.tolist(),
208
- "value_pattern": target.tolist(),
209
- "pattern_type": "interference_test",
210
- "H": H,
211
- "W": W,
212
- "capacity_load": float(load),
213
- "interference_rms": float(interference_rms),
214
- "fidelity_psnr": float(fidelity_psnr),
215
- "orthogonality": float(orthogonality),
216
- "category": "interference_study",
217
- # Consistent schema fields
218
- "compressibility": None,
219
- "persistence_lambda": None,
220
- "codebook_type": None,
221
- "time_step": None,
222
- "energy_retention": None,
223
- "temporal_correlation": None,
224
- "L": None,
225
- "K": None,
226
- "reconstruction_error": None,
227
- "reconstructed_pattern": None,
228
- "codebook_matrix": None
229
- }
230
- samples.append(sample)
231
-
232
- return samples
233
-
234
- def generate_orthogonality_benchmarks(self, num_samples: int = 1500, L: int = 64, K: int = 64) -> List[Dict]:
235
- """Generate codebook optimization data for orthogonality studies."""
236
- samples = []
237
-
238
- codebook_types = [
239
- "hadamard",
240
- "random_orthogonal",
241
- "dct_basis",
242
- "wavelet_basis",
243
- "learned_sparse"
244
- ]
245
-
246
- for codebook_type in codebook_types:
247
- type_samples = num_samples // len(codebook_types)
248
-
249
- for i in range(type_samples):
250
- # Generate codebook matrix C[L, K]
251
- codebook = self._generate_codebook(codebook_type, L, K, seed=i)
252
-
253
- # Test multiple read/write operations
254
- H, W = 64, 64
255
- test_key = np.random.randn(H, W).astype(np.float32)
256
- test_value = np.random.randn(H, W).astype(np.float32)
257
-
258
- # Simulate membrane write and read
259
- written_membrane, read_result = self._simulate_membrane_operation(
260
- codebook, test_key, test_value, H, W
261
- )
262
-
263
- # Compute orthogonality metrics
264
- orthogonality = self._compute_codebook_orthogonality(codebook)
265
- reconstruction_error = np.mean((test_value - read_result) ** 2)
266
-
267
- sample = {
268
- "id": f"orthogonal_{codebook_type}_{i:06d}",
269
- "key_pattern": test_key.tolist(),
270
- "value_pattern": test_value.tolist(),
271
- "reconstructed_pattern": read_result.tolist(),
272
- "codebook_matrix": codebook.tolist(),
273
- "pattern_type": "orthogonality_test",
274
- "codebook_type": codebook_type,
275
- "H": H,
276
- "W": W,
277
- "L": L,
278
- "K": K,
279
- "orthogonality": float(orthogonality),
280
- "reconstruction_error": float(reconstruction_error),
281
- "category": "orthogonality_benchmark",
282
- # Consistent schema fields
283
- "fidelity_psnr": None,
284
- "compressibility": None,
285
- "interference_rms": None,
286
- "persistence_lambda": None,
287
- "capacity_load": None,
288
- "time_step": None,
289
- "energy_retention": None,
290
- "temporal_correlation": None
291
- }
292
- samples.append(sample)
293
-
294
- return samples
295
-
296
- def generate_persistence_traces(self, num_samples: int = 1000, H: int = 64, W: int = 64) -> List[Dict]:
297
- """Generate temporal decay studies for persistence analysis."""
298
- samples = []
299
-
300
- # Test different decay rates
301
- lambda_values = [0.95, 0.97, 0.98, 0.99, 0.995]
302
- time_steps = [1, 5, 10, 20, 50, 100]
303
-
304
- for lambda_val in lambda_values:
305
- for time_step in time_steps:
306
- step_samples = max(1, num_samples // (len(lambda_values) * len(time_steps)))
307
-
308
- for i in range(step_samples):
309
- # Generate initial pattern
310
- initial_pattern = np.random.randn(H, W).astype(np.float32)
311
- initial_pattern = (initial_pattern - initial_pattern.mean()) / initial_pattern.std()
312
-
313
- # Simulate temporal decay: M_t+1 = λ * M_t
314
- decayed_pattern = initial_pattern * (lambda_val ** time_step)
315
-
316
- # Add noise for realism
317
- noise_level = 0.01 * (1 - lambda_val) # More noise for faster decay
318
- noise = np.random.normal(0, noise_level, (H, W)).astype(np.float32)
319
- decayed_pattern += noise
320
-
321
- # Compute persistence metrics
322
- energy_retention = np.mean(decayed_pattern ** 2) / np.mean(initial_pattern ** 2)
323
- correlation = np.corrcoef(initial_pattern.flatten(), decayed_pattern.flatten())[0, 1]
324
-
325
- sample = {
326
- "id": f"persistence_l{lambda_val}_t{time_step}_{i:06d}",
327
- "key_pattern": initial_pattern.tolist(),
328
- "value_pattern": decayed_pattern.tolist(),
329
- "pattern_type": "persistence_decay",
330
- "persistence_lambda": float(lambda_val),
331
- "time_step": int(time_step),
332
- "H": H,
333
- "W": W,
334
- "energy_retention": float(energy_retention),
335
- "temporal_correlation": float(correlation if not np.isnan(correlation) else 0.0),
336
- "category": "persistence_trace",
337
- # Consistent schema fields - set all to None for consistency
338
- "fidelity_psnr": None,
339
- "orthogonality": None,
340
- "compressibility": None,
341
- "interference_rms": None,
342
- "codebook_type": None,
343
- "capacity_load": None,
344
- # Additional fields that other samples might have
345
- "L": None,
346
- "K": None,
347
- "reconstruction_error": None,
348
- "reconstructed_pattern": None,
349
- "codebook_matrix": None
350
- }
351
- samples.append(sample)
352
-
353
- return samples
354
-
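In the noise-free case energy_retention has a closed form: with M_t = λ^t · M_0, mean(M_t²) / mean(M_0²) = λ^(2t). For example, λ = 0.99 at time_step = 50 retains 0.99^100 ≈ 0.37 of the initial energy; the small Gaussian noise added above only perturbs this value slightly.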
355
- def _generate_visual_pattern(self, pattern_type: str, H: int, W: int, is_key: bool = True) -> np.ndarray:
356
- """Generate visual patterns for different types."""
357
- if pattern_type == "mnist_digits":
358
- # Simple digit-like patterns
359
- digit = random.randint(0, 9)
360
- pattern = self._create_digit_pattern(digit, H, W)
361
- if not is_key:
362
- # For value, create slightly transformed version
363
- pattern = self._apply_simple_transform(pattern, "rotate_small")
364
-
365
- elif pattern_type == "geometric_shapes":
366
- shape = random.choice(["circle", "square", "triangle", "cross"])
367
- pattern = self._create_geometric_pattern(shape, H, W)
368
- if not is_key:
369
- pattern = self._apply_simple_transform(pattern, "scale")
370
-
371
- elif pattern_type == "noise_patterns":
372
- pattern = np.random.randn(H, W).astype(np.float32)
373
- pattern = (pattern - pattern.mean()) / pattern.std()
374
- if not is_key:
375
- pattern = pattern + 0.1 * np.random.randn(H, W)
376
-
377
- else:
378
- # Default random pattern
379
- pattern = np.random.uniform(-1, 1, (H, W)).astype(np.float32)
380
-
381
- return pattern
382
-
383
- def _generate_synthetic_map(self, map_type: str, H: int, W: int, seed: int) -> np.ndarray:
384
- """Generate synthetic spatial maps."""
385
- np.random.seed(seed)
386
-
387
- if map_type == "gaussian_fields":
388
- # Random Gaussian field
389
- x, y = np.meshgrid(np.linspace(-2, 2, W), np.linspace(-2, 2, H))
390
- pattern = np.exp(-(x**2 + y**2) / (2 * (0.5 + random.random())**2))
391
-
392
- elif map_type == "spiral_patterns":
393
- # Spiral pattern
394
- x, y = np.meshgrid(np.linspace(-np.pi, np.pi, W), np.linspace(-np.pi, np.pi, H))
395
- r = np.sqrt(x**2 + y**2)
396
- theta = np.arctan2(y, x)
397
- pattern = np.sin(r * 3 + theta * random.randint(1, 5))
398
-
399
- elif map_type == "frequency_domains":
400
- # Frequency domain pattern
401
- freq_x, freq_y = random.randint(1, 8), random.randint(1, 8)
402
- x, y = np.meshgrid(np.linspace(0, 2*np.pi, W), np.linspace(0, 2*np.pi, H))
403
- pattern = np.sin(freq_x * x) * np.cos(freq_y * y)
404
-
405
- else:
406
- # Default random field
407
- pattern = np.random.randn(H, W)
408
-
409
- # Normalize
410
- pattern = (pattern - pattern.mean()) / (pattern.std() + 1e-7)
411
- return pattern.astype(np.float32)
412
-
413
- def _create_digit_pattern(self, digit: int, H: int, W: int) -> np.ndarray:
414
- """Create simple digit-like pattern."""
415
- pattern = np.zeros((H, W), dtype=np.float32)
416
-
417
- # Simple digit patterns
418
- h_center, w_center = H // 2, W // 2
419
- size = min(H, W) // 3
420
-
421
- if digit in [0, 6, 8, 9]:
422
- # Draw circle/oval
423
- y, x = np.ogrid[:H, :W]
424
- mask = ((x - w_center) ** 2 / size**2 + (y - h_center) ** 2 / size**2) <= 1
425
- pattern[mask] = 1.0
426
-
427
- if digit in [1, 4, 7]:
428
- # Draw vertical line
429
- pattern[h_center-size:h_center+size, w_center-2:w_center+2] = 1.0
430
-
431
- # Add some randomization
432
- noise = 0.1 * np.random.randn(H, W)
433
- pattern = np.clip(pattern + noise, -1, 1)
434
-
435
- return pattern
436
-
437
- def _create_geometric_pattern(self, shape: str, H: int, W: int) -> np.ndarray:
438
- """Create geometric shape patterns."""
439
- pattern = np.zeros((H, W), dtype=np.float32)
440
- center_h, center_w = H // 2, W // 2
441
- size = min(H, W) // 4
442
-
443
- if shape == "circle":
444
- y, x = np.ogrid[:H, :W]
445
- mask = ((x - center_w) ** 2 + (y - center_h) ** 2) <= size**2
446
- pattern[mask] = 1.0
447
-
448
- elif shape == "square":
449
- pattern[center_h-size:center_h+size, center_w-size:center_w+size] = 1.0
450
-
451
- elif shape == "cross":
452
- pattern[center_h-size:center_h+size, center_w-3:center_w+3] = 1.0
453
- pattern[center_h-3:center_h+3, center_w-size:center_w+size] = 1.0
454
-
455
- return pattern
456
-
457
- def _apply_simple_transform(self, pattern: np.ndarray, transform: str) -> np.ndarray:
458
- """Apply simple transformations to patterns."""
459
- if transform == "rotate_small":
460
- # Small circular shift used as a cheap stand-in for rotation
461
- return np.roll(pattern, random.randint(-2, 2), axis=random.randint(0, 1))
462
- elif transform == "scale":
463
- # Simple scaling via interpolation approximation
464
- return pattern * (0.8 + 0.4 * random.random())
465
- else:
466
- return pattern
467
-
468
- def _apply_map_transform(self, key_map: np.ndarray, value_map: np.ndarray, map_type: str) -> np.ndarray:
469
- """Apply transformation relationship between key and value maps."""
470
- if map_type == "gaussian_fields":
471
- # Value is a weighted blend of the key and an independent field
472
- return 0.7 * key_map + 0.3 * value_map
473
- elif map_type == "spiral_patterns":
474
- # Value is phase-shifted version
475
- return np.roll(key_map, random.randint(-3, 3), axis=1)
476
- else:
477
- # Default: slightly correlated
478
- return 0.8 * key_map + 0.2 * value_map
479
-
480
- def _compute_psnr(self, pattern1: np.ndarray, pattern2: np.ndarray) -> float:
481
- """Compute Peak Signal-to-Noise Ratio."""
482
- mse = np.mean((pattern1 - pattern2) ** 2)
483
- if mse == 0:
484
- return float('inf')
485
- # Use the peak absolute value so zero-mean or all-negative patterns do not produce log of a negative number
- max_val = max(np.max(np.abs(pattern1)), np.max(np.abs(pattern2)))
- psnr = 20 * np.log10(max_val / np.sqrt(mse))
487
- return psnr
488
-
489
- def _compute_orthogonality(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
490
- """Compute orthogonality score between two vectors."""
491
- vec1_norm = vec1 / (np.linalg.norm(vec1) + 1e-7)
492
- vec2_norm = vec2 / (np.linalg.norm(vec2) + 1e-7)
493
- dot_product = np.abs(np.dot(vec1_norm, vec2_norm))
494
- orthogonality = 1.0 - dot_product # 1 = orthogonal, 0 = parallel
495
- return orthogonality
496
-
497
- def _compute_gzip_ratio(self, pattern: np.ndarray) -> float:
498
- """Compute compressibility using gzip ratio."""
499
- # Convert to bytes
500
- # Rescale to [0, 255] before quantizing; the zero-mean patterns would otherwise wrap around in uint8
- p_min, p_max = float(pattern.min()), float(pattern.max())
- scaled = (pattern - p_min) / (p_max - p_min + 1e-7)
- pattern_bytes = (scaled * 255).astype(np.uint8).tobytes()
501
- compressed = gzip.compress(pattern_bytes)
502
- ratio = len(compressed) / len(pattern_bytes)
503
- return ratio
504
-
505
- def _compute_interference_rms(self, patterns: List[np.ndarray], target: np.ndarray) -> float:
506
- """Compute RMS interference from multiple patterns."""
507
- if not patterns:
508
- return 0.0
509
-
510
- # Sum all patterns except target
511
- interference = np.zeros_like(target)
512
- for p in patterns[1:]: # Skip first pattern (target)
513
- interference += p
514
-
515
- rms = np.sqrt(np.mean(interference ** 2))
516
- return rms
517
-
518
- def _compute_pattern_orthogonality(self, patterns: List[np.ndarray]) -> float:
519
- """Compute average orthogonality between patterns."""
520
- if len(patterns) < 2:
521
- return 1.0
522
-
523
- orthogonalities = []
524
- for i in range(len(patterns)):
525
- for j in range(i + 1, min(i + 5, len(patterns))): # Limit comparisons
526
- orth = self._compute_orthogonality(patterns[i].flatten(), patterns[j].flatten())
527
- orthogonalities.append(orth)
528
-
529
- return np.mean(orthogonalities) if orthogonalities else 1.0
530
-
531
- def _generate_codebook(self, codebook_type: str, L: int, K: int, seed: int) -> np.ndarray:
532
- """Generate codebook matrix for different types."""
533
- np.random.seed(seed)
534
-
535
- if codebook_type == "hadamard" and L <= 64 and K <= 64:
536
- # Simple Hadamard-like matrix (for small sizes)
537
- codebook = np.random.choice([-1, 1], size=(L, K))
538
-
539
- elif codebook_type == "random_orthogonal":
540
- # Random orthogonal matrix
541
- random_matrix = np.random.randn(L, K)
542
- if L >= K:
543
- q, _ = np.linalg.qr(random_matrix)
544
- codebook = q[:, :K]
545
- else:
546
- codebook = random_matrix
547
-
548
- else:
549
- # Default random matrix
550
- codebook = np.random.randn(L, K) / np.sqrt(L)
551
-
552
- return codebook.astype(np.float32)
553
-
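Note that the "hadamard" branch above draws independent random ±1 entries rather than a true Hadamard matrix. A sketch of the Sylvester construction, which yields exactly orthogonal ±1 columns (assuming the size is a power of two, as the default L = K = 64 is):

def sylvester_hadamard(n: int) -> np.ndarray:
    # Build the n x n Hadamard matrix by repeated doubling: H_2m = [[H_m, H_m], [H_m, -H_m]].
    assert n > 0 and (n & (n - 1)) == 0, "n must be a power of two"
    h = np.array([[1.0]], dtype=np.float32)
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h  # satisfies h.T @ h == n * identity

Rescaling the columns by 1/sqrt(n) makes the Gram matrix exactly the identity, which is the optimum of the _compute_codebook_orthogonality score below.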
554
- def _simulate_membrane_operation(self, codebook: np.ndarray, key: np.ndarray,
555
- value: np.ndarray, H: int, W: int) -> Tuple[np.ndarray, np.ndarray]:
556
- """Simulate membrane write and read operation."""
557
- L, K = codebook.shape
558
-
559
- # Simulate write: M += alpha * C[:, k] ⊗ V
560
- # For simplicity, use first codebook column
561
- alpha = 1.0
562
- membrane = np.zeros((L, H, W))
563
-
564
- # Write operation (simplified)
565
- for l in range(min(L, 16)): # Limit for memory
566
- membrane[l] = codebook[l, 0] * value
567
-
568
- # Read operation: Y = ReLU(einsum('lhw,lk->khw', M, C))
569
- # Simplified readout
570
- read_result = np.zeros((H, W))
571
- for l in range(min(L, 16)):
572
- read_result += codebook[l, 0] * membrane[l]
573
-
574
- # Apply ReLU
575
- read_result = np.maximum(0, read_result)
576
-
577
- return membrane, read_result.astype(np.float32)
578
-
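The simulation above only writes through the first codebook column and a 16-layer slice. A vectorized sketch of the write/read described in the comments, storing one value per slot across all L layers (the 1/||column||² normalization is my addition so the readout lands on the same scale as the stored values):

def write_read_all_slots(codebook: np.ndarray, values: np.ndarray) -> np.ndarray:
    # codebook: (L, K); values: (K, H, W), one value pattern per slot k.
    # Write: superpose the outer product of every codebook column with its value.
    membrane = np.einsum('lk,khw->lhw', codebook, values)
    # Read: project the membrane back onto each column, normalize, then rectify (Y = ReLU(...)).
    readout = np.einsum('lhw,lk->khw', membrane, codebook)
    readout /= np.sum(codebook ** 2, axis=0)[:, None, None] + 1e-7
    return np.maximum(0.0, readout)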
579
- def _compute_codebook_orthogonality(self, codebook: np.ndarray) -> float:
580
- """Compute orthogonality measure of codebook."""
581
- # Compute Gram matrix G = C^T C
582
- gram = codebook.T @ codebook
583
-
584
- # Orthogonality measure: how close to identity matrix
585
- identity = np.eye(gram.shape[0])
586
- frobenius_dist = np.linalg.norm(gram - identity, 'fro')
587
-
588
- # Normalize by matrix size
589
- orthogonality = 1.0 / (1.0 + frobenius_dist / gram.shape[0])
590
- return orthogonality
591
-
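For reference, this score is exactly 1.0 for an orthonormal codebook (Gram matrix equal to the identity), while a raw random ±1 codebook has diagonal Gram entries equal to L and off-diagonal entries of typical size √L, so it scores well below 1 unless its columns are first rescaled by 1/√L.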
592
- def build_complete_dataset(self) -> DatasetDict:
593
- """Build the complete WrinkleBrane dataset."""
594
- print("🧠 Building WrinkleBrane Dataset...")
595
-
596
- all_samples = []
597
-
598
- # 1. Visual memory pairs (40% of dataset)
599
- print("👁️ Generating visual memory pairs...")
600
- visual_samples = self.generate_visual_memory_pairs(8000)
601
- all_samples.extend(visual_samples)
602
-
603
- # 2. Synthetic maps (25% of dataset)
604
- print("🗺️ Generating synthetic maps...")
605
- map_samples = self.generate_synthetic_maps(5000)
606
- all_samples.extend(map_samples)
607
-
608
- # 3. Interference studies (20% of dataset)
609
- print("⚡ Generating interference studies...")
610
- interference_samples = self.generate_interference_studies(4000)
611
- all_samples.extend(interference_samples)
612
-
613
- # 4. Orthogonality benchmarks (10% of dataset)
614
- print("📐 Generating orthogonality benchmarks...")
615
- orthogonal_samples = self.generate_orthogonality_benchmarks(2000)
616
- all_samples.extend(orthogonal_samples)
617
-
618
- # 5. Persistence traces (5% of dataset)
619
- print("⏰ Generating persistence traces...")
620
- persistence_samples = self.generate_persistence_traces(1000)
621
- all_samples.extend(persistence_samples)
622
-
623
- # Split into train/validation/test
624
- random.shuffle(all_samples)
625
-
626
- total = len(all_samples)
627
- train_split = int(0.8 * total)
628
- val_split = int(0.9 * total)
629
-
630
- train_data = all_samples[:train_split]
631
- val_data = all_samples[train_split:val_split]
632
- test_data = all_samples[val_split:]
633
-
634
- # Create HuggingFace datasets
635
- dataset_dict = DatasetDict({
636
- 'train': Dataset.from_list(train_data),
637
- 'validation': Dataset.from_list(val_data),
638
- 'test': Dataset.from_list(test_data)
639
- })
640
-
641
- print(f"✅ Dataset built: {len(train_data)} train, {len(val_data)} val, {len(test_data)} test")
642
- return dataset_dict
643
-
644
- def upload_to_huggingface(self, dataset: DatasetDict, private: bool = True) -> str:
645
- """Upload dataset to HuggingFace Hub."""
646
- print(f"🌐 Uploading to HuggingFace: {self.repo_id}")
647
-
648
- try:
649
- # Create repository
650
- create_repo(
651
- repo_id=self.repo_id,
652
- repo_type="dataset",
653
- private=private,
654
- exist_ok=True,
655
- token=self.hf_token
656
- )
657
-
658
- # Add dataset metadata
659
- dataset_info = {
660
- "dataset_info": self.config,
661
- "splits": {
662
- "train": len(dataset["train"]),
663
- "validation": len(dataset["validation"]),
664
- "test": len(dataset["test"])
665
- },
666
- "features": {
667
- "id": "string",
668
- "key_pattern": "2D array of floats (H x W)",
669
- "value_pattern": "2D array of floats (H x W)",
670
- "pattern_type": "string",
671
- "H": "integer (height)",
672
- "W": "integer (width)",
673
- "category": "string",
674
- "optional_metrics": "various floats for specific sample types"
675
- },
676
- "usage_notes": [
677
- "Optimized for WrinkleBrane associative memory training",
678
- "Key-value pairs for membrane storage and retrieval",
679
- "Includes interference studies and capacity analysis",
680
- "Supports orthogonality optimization research"
681
- ]
682
- }
683
-
684
- # Push dataset with metadata
685
- dataset.push_to_hub(
686
- repo_id=self.repo_id,
687
- token=self.hf_token,
688
- private=private
689
- )
690
-
691
- # Upload additional metadata
692
- # Write the metadata to a closed temporary file before uploading so the contents are fully flushed
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
- json.dump(dataset_info, f, indent=2)
- metadata_path = f.name
- self.api.upload_file(
- path_or_fileobj=metadata_path,
- path_in_repo="dataset_info.json",
- repo_id=self.repo_id,
- repo_type="dataset",
- token=self.hf_token
- )
- os.unlink(metadata_path)  # remove the temporary file once uploaded
701
-
702
- print(f"✅ Dataset uploaded successfully to: https://huggingface.co/datasets/{self.repo_id}")
703
- return f"https://huggingface.co/datasets/{self.repo_id}"
704
-
705
- except Exception as e:
706
- print(f"❌ Upload failed: {e}")
707
- raise
708
-
709
-
710
- def create_wrinklebrane_dataset(hf_token: str, repo_id: str = "WrinkleBrane") -> str:
711
- """
712
- Convenience function to create and upload WrinkleBrane dataset.
713
-
714
- Args:
715
- hf_token: HuggingFace access token
716
- repo_id: Dataset repository ID
717
-
718
- Returns:
719
- URL to the uploaded dataset
720
- """
721
- builder = WrinkleBraneDatasetBuilder(hf_token, repo_id)
722
- dataset = builder.build_complete_dataset()
723
- return builder.upload_to_huggingface(dataset, private=True)
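A minimal end-to-end usage sketch for the module as it existed before this deletion (the HF_TOKEN environment variable and the namespaced repo id your-org/WrinkleBrane are assumptions; load_dataset is the standard datasets API and needs the same token because the upload is private):

import os
from datasets import load_dataset

# Build the full dataset and push it to the Hub.
url = create_wrinklebrane_dataset(os.environ["HF_TOKEN"], repo_id="your-org/WrinkleBrane")
print(url)

# Later, pull it back down for training.
ds = load_dataset("your-org/WrinkleBrane", token=os.environ["HF_TOKEN"])
print(ds["train"][0]["pattern_type"], ds["train"][0]["category"])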