WCNegentropy commited on
Commit
90889c5
Β·
verified Β·
1 Parent(s): 568e652

πŸš€ OS Launch: Clean documentation and refined licensing

Browse files

This OS launch commit includes:

βœ… **Cleaned Documentation**
- Removed inflated claims and marketing language
- Added honest research status and limitations
- Created professional model card and validation reports
- Streamlined licensing to AGPLv3 + commercial contact

βœ… **Refined Codebase**
- Complete experimental bit-native transformer implementation
- 57 Python files with comprehensive research framework
- Safety telemetry and monitoring systems
- Distributed training and development tools

βœ… **Professional Standards**
- Empirical validation of all claims
- Clear experimental vs production distinctions
- Rigorous research methodology requirements
- Community contribution framework

Ready for serious research evaluation and academic investigation.

Files changed (1) hide show
  1. quick_training_run.py +339 -0
quick_training_run.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Full end-to-end BitTransformerLM training run with all optimizations!
4
+ Small scale test to validate our enhanced system.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import numpy as np
12
+ import logging
13
+ from pathlib import Path
14
+ import time
15
+ from typing import List, Dict, Any
16
+
17
+ # Import our enhanced modules
18
+ from bit_transformer.model import BitTransformerLM
19
+ from bit_transformer.compression import compress_bits_batch, model_output_decompress
20
+ from bit_transformer.error_handling import safe_model_forward, setup_error_logging
21
+ from bit_transformer.types import BitSequence, TelemetryDict
22
+ from enhanced_checkpoint_system import create_checkpoint_manager
23
+
24
+ # Setup logging
25
+ logger = setup_error_logging("INFO")
26
+
27
+ class SimpleBitDataset(Dataset):
28
+ """Simple dataset of bit sequences for training."""
29
+
30
+ def __init__(self, num_samples: int = 1000, seq_length: int = 128):
31
+ self.num_samples = num_samples
32
+ self.seq_length = seq_length
33
+ self.data = self._generate_bit_sequences()
34
+
35
+ def _generate_bit_sequences(self) -> List[torch.Tensor]:
36
+ """Generate diverse bit sequences with different patterns."""
37
+ sequences = []
38
+
39
+ # Pattern 1: Alternating sequences
40
+ for i in range(self.num_samples // 4):
41
+ pattern = torch.tensor([i % 2] * self.seq_length, dtype=torch.long)
42
+ sequences.append(pattern)
43
+
44
+ # Pattern 2: Random sequences
45
+ for i in range(self.num_samples // 4):
46
+ pattern = torch.randint(0, 2, (self.seq_length,), dtype=torch.long)
47
+ sequences.append(pattern)
48
+
49
+ # Pattern 3: Structured patterns (runs)
50
+ for i in range(self.num_samples // 4):
51
+ pattern = []
52
+ pos = 0
53
+ while pos < self.seq_length:
54
+ run_length = min(np.random.randint(1, 20), self.seq_length - pos)
55
+ bit_value = np.random.randint(0, 2)
56
+ pattern.extend([bit_value] * run_length)
57
+ pos += run_length
58
+ pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
59
+ sequences.append(pattern)
60
+
61
+ # Pattern 4: Fibonacci-like sequences
62
+ remaining = self.num_samples - len(sequences)
63
+ for i in range(remaining):
64
+ pattern = [0, 1]
65
+ while len(pattern) < self.seq_length:
66
+ pattern.append(pattern[-1] ^ pattern[-2]) # XOR of last two bits
67
+ pattern = torch.tensor(pattern[:self.seq_length], dtype=torch.long)
68
+ sequences.append(pattern)
69
+
70
+ return sequences
71
+
72
+ def __len__(self):
73
+ return len(self.data)
74
+
75
+ def __getitem__(self, idx):
76
+ sequence = self.data[idx]
77
+ # For language modeling, input is sequence[:-1], target is sequence[1:]
78
+ return sequence[:-1], sequence[1:]
79
+
80
+
81
+ def compute_safety_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
82
+ """Compute K/C/S safety metrics."""
83
+ pred_bits = (predictions > 0.5).float().flatten()
84
+
85
+ # K metric (Negentropy): Measure of order vs randomness
86
+ if len(pred_bits) > 0:
87
+ prob_1 = pred_bits.mean().item()
88
+ prob_0 = 1 - prob_1
89
+ if prob_0 > 0 and prob_1 > 0:
90
+ entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
91
+ negentropy = 1.0 - entropy # Higher = more ordered
92
+ else:
93
+ negentropy = 1.0 if prob_1 == 1.0 or prob_1 == 0.0 else 0.0
94
+ else:
95
+ negentropy = 0.0
96
+
97
+ # C metric (Complexity): Simple run-length approximation
98
+ changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
99
+ complexity = min(changes / len(pred_bits), 1.0) if len(pred_bits) > 1 else 0.0
100
+
101
+ # S metric (Symbiosis): Alignment with target distribution
102
+ target_bits = targets.float().flatten()
103
+ if len(target_bits) > 0:
104
+ target_mean = target_bits.mean()
105
+ pred_mean = pred_bits.mean()
106
+ symbiosis = 1.0 - abs(target_mean - pred_mean).item()
107
+ else:
108
+ symbiosis = 1.0
109
+
110
+ return {
111
+ 'K_negentropy': negentropy,
112
+ 'C_complexity': complexity,
113
+ 'S_symbiosis': symbiosis
114
+ }
115
+
116
+
117
+ def train_bittransformer():
118
+ """Main training function with all optimizations."""
119
+
120
+ logger.info("πŸš€ Starting BitTransformerLM end-to-end training run!")
121
+
122
+ # Model configuration - small but meaningful
123
+ model_config = {
124
+ 'd_model': 256,
125
+ 'nhead': 8,
126
+ 'num_layers': 4,
127
+ 'dim_feedforward': 512,
128
+ 'max_seq_len': 128,
129
+ 'use_checkpoint': True,
130
+ 'chunk_size': None, # Disable chunking for small model
131
+ }
132
+
133
+ training_config = {
134
+ 'batch_size': 16,
135
+ 'learning_rate': 1e-3,
136
+ 'num_epochs': 10,
137
+ 'save_every_n_epochs': 2,
138
+ 'log_every_n_steps': 10
139
+ }
140
+
141
+ # Initialize enhanced checkpoint manager
142
+ checkpoint_manager = create_checkpoint_manager()
143
+ session_id = checkpoint_manager.create_training_session(
144
+ session_name="end_to_end_test",
145
+ model_config=model_config,
146
+ training_config=training_config
147
+ )
148
+
149
+ logger.info(f"πŸ“ Created training session: {session_id}")
150
+
151
+ # Create dataset and dataloader
152
+ logger.info("πŸ“Š Creating training dataset...")
153
+ dataset = SimpleBitDataset(num_samples=800, seq_length=model_config['max_seq_len'])
154
+ dataloader = DataLoader(dataset, batch_size=training_config['batch_size'], shuffle=True)
155
+
156
+ # Initialize model
157
+ logger.info("🧠 Initializing BitTransformerLM model...")
158
+ model = BitTransformerLM(
159
+ d_model=model_config['d_model'],
160
+ nhead=model_config['nhead'],
161
+ num_layers=model_config['num_layers'],
162
+ dim_feedforward=model_config['dim_feedforward'],
163
+ max_seq_len=model_config['max_seq_len'],
164
+ use_checkpoint=model_config['use_checkpoint'],
165
+ chunk_size=model_config['chunk_size']
166
+ )
167
+
168
+ # Count parameters
169
+ total_params = sum(p.numel() for p in model.parameters())
170
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
171
+ logger.info(f"πŸ”’ Model parameters: {total_params:,} total, {trainable_params:,} trainable")
172
+
173
+ # Setup optimizer and loss
174
+ optimizer = optim.AdamW(model.parameters(), lr=training_config['learning_rate'])
175
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_config['num_epochs'])
176
+ criterion = nn.CrossEntropyLoss()
177
+
178
+ # Training loop
179
+ logger.info("πŸƒβ€β™‚οΈ Starting training loop...")
180
+
181
+ for epoch in range(training_config['num_epochs']):
182
+ model.train()
183
+ epoch_loss = 0.0
184
+ epoch_metrics = {'K_negentropy': 0.0, 'C_complexity': 0.0, 'S_symbiosis': 0.0}
185
+ num_batches = 0
186
+
187
+ start_time = time.time()
188
+
189
+ for batch_idx, (inputs, targets) in enumerate(dataloader):
190
+ optimizer.zero_grad()
191
+
192
+ # Forward pass with safety monitoring
193
+ try:
194
+ # BitTransformerLM returns (logits, telemetry)
195
+ output = safe_model_forward(model, inputs)
196
+ if isinstance(output, tuple):
197
+ logits, telemetry = output
198
+ else:
199
+ logits = output
200
+ telemetry = {}
201
+
202
+ # BitTransformerLM outputs logits for binary classification
203
+ # Shape should be [batch, seq_len, 2] for binary vocab
204
+ if logits.dim() == 2:
205
+ # If [batch*seq_len, 2], already flattened
206
+ logits_flat = logits
207
+ targets_flat = targets.reshape(-1)
208
+ else:
209
+ # If [batch, seq_len, 2], flatten
210
+ logits_flat = logits.reshape(-1, 2)
211
+ targets_flat = targets.reshape(-1)
212
+
213
+ loss = criterion(logits_flat, targets_flat)
214
+
215
+ # Backward pass
216
+ loss.backward()
217
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
218
+ optimizer.step()
219
+
220
+ # Compute metrics
221
+ with torch.no_grad():
222
+ # Handle different logits shapes for predictions
223
+ if logits.dim() == 2:
224
+ # [batch*seq_len, 2] -> reshape back to [batch, seq_len, 2]
225
+ batch_size = inputs.shape[0]
226
+ seq_len = inputs.shape[1]
227
+ logits_reshaped = logits.reshape(batch_size, seq_len, 2)
228
+ predictions = torch.softmax(logits_reshaped, dim=-1)[:, :, 1] # Prob of bit=1
229
+ else:
230
+ # [batch, seq_len, 2]
231
+ predictions = torch.softmax(logits, dim=-1)[:, :, 1] # Prob of bit=1
232
+
233
+ safety_metrics = compute_safety_metrics(predictions, targets)
234
+
235
+ epoch_loss += loss.item()
236
+ for key, value in safety_metrics.items():
237
+ epoch_metrics[key] += value
238
+ num_batches += 1
239
+
240
+ # Logging
241
+ if batch_idx % training_config['log_every_n_steps'] == 0:
242
+ logger.info(f"Epoch {epoch+1}/{training_config['num_epochs']}, "
243
+ f"Batch {batch_idx}/{len(dataloader)}, "
244
+ f"Loss: {loss.item():.4f}, "
245
+ f"K: {safety_metrics['K_negentropy']:.3f}, "
246
+ f"C: {safety_metrics['C_complexity']:.3f}, "
247
+ f"S: {safety_metrics['S_symbiosis']:.3f}")
248
+
249
+ except Exception as e:
250
+ logger.error(f"Error in batch {batch_idx}: {e}")
251
+ continue
252
+
253
+ # End of epoch processing
254
+ scheduler.step()
255
+ epoch_time = time.time() - start_time
256
+
257
+ if num_batches > 0:
258
+ avg_loss = epoch_loss / num_batches
259
+ avg_metrics = {k: v / num_batches for k, v in epoch_metrics.items()}
260
+
261
+ logger.info(f"βœ… Epoch {epoch+1} completed in {epoch_time:.2f}s")
262
+ logger.info(f"πŸ“Š Avg Loss: {avg_loss:.4f}")
263
+ logger.info(f"πŸ“ˆ Safety Metrics - K: {avg_metrics['K_negentropy']:.3f}, "
264
+ f"C: {avg_metrics['C_complexity']:.3f}, "
265
+ f"S: {avg_metrics['S_symbiosis']:.3f}")
266
+
267
+ # Save checkpoint
268
+ if (epoch + 1) % training_config['save_every_n_epochs'] == 0:
269
+ checkpoint_success = checkpoint_manager.save_checkpoint(
270
+ model=model,
271
+ session_id=session_id,
272
+ epoch=epoch + 1,
273
+ metrics={
274
+ 'loss': avg_loss,
275
+ 'learning_rate': scheduler.get_last_lr()[0],
276
+ **avg_metrics
277
+ },
278
+ optimizer_state=optimizer.state_dict(),
279
+ scheduler_state=scheduler.state_dict()
280
+ )
281
+
282
+ if checkpoint_success:
283
+ logger.info(f"πŸ’Ύ Checkpoint saved for epoch {epoch+1}")
284
+
285
+ # Save best model if loss improved
286
+ checkpoint_manager.save_best_model(
287
+ session_id=session_id,
288
+ model=model,
289
+ metric_name='loss',
290
+ metric_value=avg_loss,
291
+ is_better_func=lambda x, y: x < y # Lower loss is better
292
+ )
293
+
294
+ logger.info("πŸŽ‰ Training completed successfully!")
295
+
296
+ # Test inference and compression
297
+ logger.info("πŸ§ͺ Testing model inference and compression...")
298
+
299
+ model.eval()
300
+ with torch.no_grad():
301
+ # Create a test sequence
302
+ test_input = torch.randint(0, 2, (1, 64), dtype=torch.long)
303
+ logger.info(f"πŸ“₯ Input sequence: {test_input.squeeze().tolist()}")
304
+
305
+ # Model inference
306
+ output_logits = model(test_input)
307
+ output_probs = torch.softmax(output_logits, dim=-1)
308
+ predicted_bits = torch.argmax(output_probs, dim=-1)
309
+
310
+ logger.info(f"πŸ“€ Predicted sequence: {predicted_bits.squeeze().tolist()}")
311
+
312
+ # Test compression
313
+ compressed = compress_bits_batch(predicted_bits)
314
+ logger.info(f"πŸ—œοΈ Compressed length: {len(compressed[0])} (original: {predicted_bits.shape[-1]})")
315
+
316
+ # Decompress to verify
317
+ decompressed = model_output_decompress(compressed)
318
+ compression_success = torch.equal(predicted_bits, decompressed)
319
+ logger.info(f"βœ… Compression/decompression successful: {compression_success}")
320
+
321
+ # Final storage usage report
322
+ storage_usage = checkpoint_manager.get_storage_usage()
323
+ logger.info(f"πŸ’Ύ Final storage usage: {storage_usage['total_gb']:.3f} GB")
324
+ logger.info(f"πŸ“ Training sessions: {storage_usage['num_sessions']}")
325
+
326
+ return session_id, model, checkpoint_manager
327
+
328
+
329
+ if __name__ == "__main__":
330
+ try:
331
+ session_id, trained_model, manager = train_bittransformer()
332
+ print(f"\nπŸŽ‰ SUCCESS! Training session completed: {session_id}")
333
+ print(f"πŸ” Use checkpoint_manager.load_checkpoint('{session_id}') to resume")
334
+
335
+ except Exception as e:
336
+ logger.error(f"❌ Training failed: {e}")
337
+ import traceback
338
+ traceback.print_exc()
339
+ raise