#!/usr/bin/env python3
"""
CPU-Optimized Edge Deployment BitTransformerLM Training
Optimized for consumer devices and edge applications.
"""

import os
import time
import torch
import torch.nn.functional as F
from datasets import load_dataset

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    bits_to_text,
    train_loop,
    configure_optimizer,
    save_model,
    load_model,
    set_dropout,
    hil_safe_inference,
    quantize_dynamic,
)
from bit_transformer.torch_utils import cpu_autocast
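# `cpu_autocast` is assumed to be a thin context manager around torch.autocast
# for CPU bfloat16 mixed precision. A minimal sketch of that assumption
# (hypothetical, not the library's actual implementation):
#
#     from contextlib import contextmanager
#
#     @contextmanager
#     def cpu_autocast():
#         with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
#             yield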


def create_optimal_cpu_model():
    """Create BitTransformerLM optimized for CPU edge deployment."""
    print("🧠 Creating CPU-optimized BitTransformerLM...")
    
    # Optimal configuration for edge devices:
    # - Small model size for low memory footprint
    # - CPU autocast for BF16 mixed-precision compute
    # - No reversible layers (simpler for CPU)
    # - Gradient checkpointing disabled for speed
    # - Small context length for efficiency
    
    model = BitTransformerLM(
        d_model=64,           # Small embedding dimension (vs 128 default)
        nhead=4,              # Fewer attention heads (vs 8 default)
        num_layers=3,         # Shallow model (vs 4 default)  
        dim_feedforward=128,  # Smaller FFN (vs 512 default)
        max_seq_len=256,      # Shorter context (vs 1024 default)
        reversible=False,     # Disable reversible layers (CPU doesn't benefit much)
        use_checkpoint=False, # Disable gradient checkpointing (prioritize speed)
        use_autocast=True,    # Enable CPU autocast for BF16 mixed precision
        use_act=False,        # Disable ACT for simplicity
        chunk_size=32,        # Small chunks for memory efficiency
        full_attn_logging=False,  # Disable attention logging to save memory
        lambda_K=1.0,         # Standard telemetry weights
        lambda_C=1.0,
        lambda_S=1.0,
    )
    
    # Calculate model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"   📊 Model Configuration:")
    print(f"     d_model: {64}")
    print(f"     num_layers: {3}")
    print(f"     nhead: {4}")
    print(f"     dim_feedforward: {128}")
    print(f"     max_seq_len: {256}")
    print(f"     Total parameters: {total_params:,}")
    print(f"     Trainable parameters: {trainable_params:,}")
    print(f"     Estimated size: {total_params * 4 / 1024 / 1024:.1f}MB (FP32)")
    print(f"     BF16 equivalent: ~{total_params * 2 / 1024 / 1024:.1f}MB (autocast computes in BF16; stored weights stay FP32)")
    
    return model


def load_training_dataset(dataset_size=512, max_len=128):
    """Load and prepare training dataset optimized for edge training."""
    print("📚 Loading training dataset...")
    
    try:
        # Try to load BitTransformerLM dataset from HuggingFace
        print("   Attempting to load BitTransformerLM dataset...")
        dataset = load_dataset("WCNegentropy/BitTransformerLM", split=f"train[:{dataset_size}]")
        if dataset and len(dataset) > 0:
            train_texts = [item['text'] for item in dataset if item.get('text')]
            if len(train_texts) > 0:
                print(f"   ✅ Loaded {len(train_texts)} samples from BitTransformerLM dataset")
            else:
                raise Exception("No text samples found in dataset")
        else:
            raise Exception("Dataset empty or not accessible")
            
    except Exception as e:
        print(f"   ⚠️  BitTransformerLM dataset not available: {e}")
        print("   📖 Falling back to WikiText-2...")
        try:
            # Fallback to WikiText-2 for training
            ds = load_dataset("wikitext", "wikitext-2-raw-v1")
            train_texts = [text for text in ds["train"]["text"] if text.strip()][:dataset_size]
            print(f"   ✅ Loaded {len(train_texts)} samples from WikiText-2")
        except Exception as e2:
            print(f"   ❌ Failed to load WikiText-2: {e2}")
            print("   🎲 Using synthetic text data...")
            # Generate simple synthetic text for demonstration
            synthetic_texts = [
                "The quick brown fox jumps over the lazy dog.",
                "Machine learning is transforming technology.",
                "Edge computing enables local AI processing.",
                "BitTransformerLM uses bit-native processing.",
                "CPU optimization improves inference speed.",
                "Neural networks learn from training data.",
                "Transformers use attention mechanisms.",
                "Language models understand text patterns.",
            ]
            train_texts = (synthetic_texts * (dataset_size // len(synthetic_texts) + 1))[:dataset_size]
            print(f"   ✅ Generated {len(train_texts)} synthetic samples")
    
    # Convert text to bits
    print("   🔄 Converting text to bits...")
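    # `text_to_bits` is assumed to map each character to a fixed-width bit
    # pattern (e.g. 8 bits per byte, possibly plus a parity bit), so max_len
    # bits cover only a short prefix of each sample; longer text is truncated
    # and shorter sequences are zero-padded below.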
    train_sequences = []
    valid_sequences = []
    
    for i, text in enumerate(train_texts):
        try:
            bits = text_to_bits(text)[:max_len]
            if len(bits) < max_len:
                bits.extend([0] * (max_len - len(bits)))  # Pad to max_len
            
            # Use 80/20 split for train/validation
            if i < len(train_texts) * 0.8:
                train_sequences.append(bits)
            else:
                valid_sequences.append(bits)
                
        except Exception as e:
            print(f"   ⚠️  Failed to convert text to bits: {e}")
            continue
    
    train_tensor = torch.tensor(train_sequences, dtype=torch.long)
    valid_tensor = torch.tensor(valid_sequences, dtype=torch.long) if valid_sequences else train_tensor[:16]
    
    print(f"   📊 Dataset Statistics:")
    print(f"     Training sequences: {len(train_sequences)}")
    print(f"     Validation sequences: {len(valid_sequences)}")
    print(f"     Sequence length: {max_len}")
    print(f"     Training tensor shape: {train_tensor.shape}")
    
    return train_tensor, valid_tensor, train_texts[:len(train_sequences)]


def train_cpu_optimized_model(model, train_data, valid_data, epochs=5):
    """Train the model with CPU-optimized settings."""
    print(f"🚀 Training CPU-optimized BitTransformerLM for {epochs} epochs...")
    
    # Set model to training mode
    model.train()
    set_dropout(model, 0.1)
    
    # Configure optimizer for edge deployment
    # Lower learning rate and smaller batch size for stable CPU training
    batch_size = 4  # Small batch size for memory efficiency
    learning_rate = 5e-4  # Conservative learning rate
    total_steps = max(1, epochs * (len(train_data) // batch_size))  # Ensure at least 1 step
    
    if len(train_data) == 0:
        raise ValueError("No training data available - check dataset loading")
    
    optimizer, scheduler = configure_optimizer(
        model,
        lr=learning_rate,
        total_steps=total_steps,
        weight_decay=0.01
    )
    
    print(f"   📋 Training Configuration:")
    print(f"     Batch size: {batch_size}")
    print(f"     Learning rate: {learning_rate}")
    print(f"     Total steps: {total_steps}")
    print(f"     CPU autocast: Enabled")
    
    # Training loop with CPU optimizations
    train_losses = []
    
    for epoch in range(epochs):
        print(f"\n   📖 Epoch {epoch + 1}/{epochs}")
        epoch_losses = []
        epoch_start_time = time.time()
        
        # Shuffle training data
        perm = torch.randperm(len(train_data))
        train_data_shuffled = train_data[perm]
        
        # Process in small batches
        for batch_idx in range(0, len(train_data_shuffled), batch_size):
            batch_end = min(batch_idx + batch_size, len(train_data_shuffled))
            batch = train_data_shuffled[batch_idx:batch_end]
            
            if len(batch) == 0:
                continue
            
            optimizer.zero_grad()
            
            # Use CPU autocast for mixed precision
            with cpu_autocast():
                logits, telemetry = model(batch)
                
                # Standard autoregressive loss
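                # Next-bit prediction: the logit at position t scores the bit at
                # position t + 1, so the last logit and the first target bit are
                # dropped; 2 is the bit vocabulary size ({0, 1}).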
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(pred, target)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            # Only step scheduler if we haven't exceeded total steps
            if scheduler.last_epoch < scheduler.total_steps - 1:
                scheduler.step()
            
            batch_loss = loss.item()
            epoch_losses.append(batch_loss)
            
            # Log progress every 50 steps
            if (batch_idx // batch_size) % 50 == 0:
                avg_loss = sum(epoch_losses[-10:]) / min(10, len(epoch_losses))
                telemetry_str = f"K={telemetry.get('K', 0):.3f}, C={telemetry.get('C', 0):.3f}, S={telemetry.get('S', 0):.3f}"
                print(f"     Step {batch_idx // batch_size}: Loss={avg_loss:.4f}, {telemetry_str}")
        
        epoch_time = time.time() - epoch_start_time
        avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
        train_losses.append(avg_epoch_loss)
        
        print(f"   ⏱️  Epoch {epoch + 1} completed in {epoch_time:.1f}s, Avg Loss: {avg_epoch_loss:.4f}")
        
        # Validation every epoch
        if len(valid_data) > 0:
            model.eval()
            set_dropout(model, 0.0)
            
            with torch.no_grad():
                with cpu_autocast():
                    val_batch = valid_data[:min(8, len(valid_data))]  # Small validation batch
                    val_logits, val_telemetry = model(val_batch)
                    val_pred = val_logits[:, :-1, :].reshape(-1, 2)
                    val_target = val_batch[:, 1:].reshape(-1)
                    val_loss = F.cross_entropy(val_pred, val_target).item()
            
            print(f"   📊 Validation Loss: {val_loss:.4f}")
            print(f"   📈 Telemetry - K: {val_telemetry.get('K', 0):.3f}, C: {val_telemetry.get('C', 0):.3f}, S: {val_telemetry.get('S', 0):.3f}")
            
            model.train()
            set_dropout(model, 0.1)
    
    print(f"\n✅ Training completed!")
    print(f"   Final training loss: {train_losses[-1]:.4f}")
    
    return model, train_losses


def test_model_inference(model, test_texts):
    """Test the trained model with inference and safety checks."""
    print("\n🧪 Testing Model Inference...")
    
    model.eval()
    set_dropout(model, 0.0)
    
    # Test basic inference
    test_samples = test_texts[:3]  # Test with first 3 samples
    
    for i, text in enumerate(test_samples):
        print(f"\n   Test {i + 1}: {text[:50]}...")
        
        try:
            # Convert to bits
            input_bits = text_to_bits(text)[:64]  # Shorter for demo
            if len(input_bits) < 64:
                input_bits.extend([0] * (64 - len(input_bits)))
            
            input_tensor = torch.tensor([input_bits], dtype=torch.long)
            
            # Run inference with CPU autocast
            with torch.no_grad():
                with cpu_autocast():
                    logits, telemetry = model(input_tensor)
                    
                    # Generate next tokens
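                    # The model emits two logits per position (bit 0 vs bit 1);
                    # a single next bit is sampled from that distribution below.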
                    next_token_logits = logits[0, -1, :]
                    next_token_probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(next_token_probs, 1).item()
                    
            print(f"     Input bits: {input_bits[:16]}... (showing first 16)")
            print(f"     Next token prediction: {next_token}")
            print(f"     Next token confidence: {next_token_probs[next_token]:.3f}")
            print(f"     Telemetry - K: {telemetry.get('K', 0):.3f}, C: {telemetry.get('C', 0):.3f}, S: {telemetry.get('S', 0):.3f}")
            
        except Exception as e:
            print(f"     ❌ Inference failed: {e}")
    
    # Test safe inference
    print(f"\n🛑️ Testing Safe Inference...")
    try:
        # Create a simple prompt
        test_prompt = "The future of AI is"
        prompt_bits = text_to_bits(test_prompt)
        prompt_tensor = torch.tensor([prompt_bits], dtype=torch.long)
        
        with cpu_autocast():
            safe_result = hil_safe_inference(model, prompt_tensor, max_new_tokens=16)
            
        if safe_result is not None:
            print(f"   ✅ Safe inference successful")
            print(f"   Generated {len(safe_result[0]) - len(prompt_bits)} new tokens")
        else:
            print(f"   ⚠️  Safe inference blocked by safety gates")
            
    except Exception as e:
        print(f"   ❌ Safe inference test failed: {e}")


def benchmark_cpu_performance(model):
    """Benchmark the model's CPU performance."""
    print("\n⚡ CPU Performance Benchmark...")
    
    model.eval()
    set_dropout(model, 0.0)
    
    # Prepare test data
    batch_sizes = [1, 2, 4]
    sequence_lengths = [32, 64, 128]
    
    results = []
    
    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            print(f"\n   Testing batch_size={batch_size}, seq_len={seq_len}")
            
            # Create random test data
            test_data = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long)
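            # Random bits are sufficient for latency measurement; input content
            # does not change the cost of a forward pass.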
            
            # Warmup
            with torch.no_grad():
                with cpu_autocast():
                    for _ in range(3):
                        _, _ = model(test_data)
            
            # Benchmark
            times = []
            for _ in range(10):
                start_time = time.time()
                with torch.no_grad():
                    with cpu_autocast():
                        logits, telemetry = model(test_data)
                end_time = time.time()
                times.append(end_time - start_time)
            
            avg_time = sum(times) / len(times)
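            # Throughput counts every sequence position in the batch as one
            # processed bit-token: (batch_size * seq_len) bits per forward pass.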
            throughput = (batch_size * seq_len) / avg_time
            
            result = {
                'batch_size': batch_size,
                'seq_len': seq_len,
                'avg_time_ms': avg_time * 1000,
                'throughput_tokens_per_sec': throughput
            }
            results.append(result)
            
            print(f"     Average time: {avg_time * 1000:.2f}ms")
            print(f"     Throughput: {throughput:.0f} tokens/sec")
    
    # Summary
    print(f"\n📊 Performance Summary:")
    best_throughput = max(results, key=lambda x: x['throughput_tokens_per_sec'])
    print(f"   Best throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
    print(f"   At batch_size={best_throughput['batch_size']}, seq_len={best_throughput['seq_len']}")
    
    return results


def quantize_for_deployment(model):
    """Apply dynamic quantization for deployment."""
    print("\n🗜️ Applying Dynamic Quantization for Deployment...")
    
    try:
        quantized_model = quantize_dynamic(model)
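        # `quantize_dynamic` is assumed to wrap torch's post-training dynamic
        # quantization (torch.ao.quantization.quantize_dynamic), which stores
        # nn.Linear weights as INT8 while activations remain floating point.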
        
        # Compare model sizes
        original_params = sum(p.numel() for p in model.parameters())
        quantized_params = sum(p.numel() for p in quantized_model.parameters())
        
        print(f"   Original parameters: {original_params:,}")
        print(f"   Quantized parameters: {quantized_params:,}")
        print(f"   Weight size reduction: up to ~75% (FP32 -> INT8 for quantized Linear layers)")
        
        # Quick inference test
        test_input = torch.randint(0, 2, (1, 32), dtype=torch.long)
        
        with torch.no_grad():
            original_output = model(test_input)
            quantized_output = quantized_model(test_input)
        
        print(f"   ✅ Quantization successful - model still functional")
        
        return quantized_model
        
    except Exception as e:
        print(f"   ❌ Quantization failed: {e}")
        return model


def main():
    """Main training and testing pipeline."""
    print("🚀 CPU-Optimized BitTransformerLM Training Pipeline")
    print("="*60)
    
    # Step 1: Create optimal CPU model
    model = create_optimal_cpu_model()
    
    # Step 2: Load training dataset  
    train_data, valid_data, train_texts = load_training_dataset(dataset_size=256, max_len=128)
    
    # Step 3: Train the model
    trained_model, train_losses = train_cpu_optimized_model(model, train_data, valid_data, epochs=3)
    
    # Step 4: Test inference
    test_model_inference(trained_model, train_texts)
    
    # Step 5: Benchmark performance
    benchmark_results = benchmark_cpu_performance(trained_model)
    
    # Step 6: Apply quantization
    quantized_model = quantize_for_deployment(trained_model)
    
    # Step 7: Save models
    print("\n💾 Saving Models...")
    
    # Create weights directory if it doesn't exist
    os.makedirs("weights", exist_ok=True)
    
    try:
        save_model(trained_model, "weights/cpu_edge_model.pt.gz")
        print("   ✅ Saved trained model: weights/cpu_edge_model.pt.gz")
        
        save_model(quantized_model, "weights/cpu_edge_model_quantized.pt.gz")
        print("   ✅ Saved quantized model: weights/cpu_edge_model_quantized.pt.gz")
        
    except Exception as e:
        print(f"   ⚠️  Model saving failed: {e}")
    
    # Final summary
    print("\n" + "="*60)
    print("🎉 CPU-Optimized BitTransformerLM Training Complete!")
    print("="*60)
    
    total_params = sum(p.numel() for p in trained_model.parameters())
    final_loss = train_losses[-1] if train_losses else "N/A"
    best_throughput = max(benchmark_results, key=lambda x: x['throughput_tokens_per_sec'])
    
    print(f"📊 Final Results:")
    print(f"   Model Parameters: {total_params:,}")
    print(f"   Final Training Loss: {final_loss}")
    print(f"   Peak Throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
    print(f"   Model Size (quantized): ~{total_params * 1 / 1024 / 1024:.1f}MB")
    print(f"   CPU Optimizations: BF16 autocast, no gradient checkpointing, small chunks")
    print(f"   Edge Ready: ✅ Optimized for consumer CPUs")


if __name__ == "__main__":
    main()