WCNegentropy committed on
Commit ac13ef9 · verified · 1 Parent(s): 90c4ba8

🚀 OS Launch: Clean documentation and refined licensing


This OS launch commit includes:

✅ **Cleaned Documentation**
- Removed inflated claims and marketing language
- Added honest research status and limitations
- Created professional model card and validation reports
- Streamlined licensing to AGPLv3 + commercial contact

✅ **Refined Codebase**
- Complete experimental bit-native transformer implementation
- 57 Python files with a comprehensive research framework
- Safety telemetry and monitoring systems
- Distributed training and development tools

✅ **Professional Standards**
- Empirical validation of all claims
- Clear experimental vs production distinctions
- Rigorous research methodology requirements
- Community contribution framework

Ready for serious research evaluation and academic investigation.

Files changed (1)
  1. enhanced_checkpoint_system.py +374 -0
enhanced_checkpoint_system.py ADDED
@@ -0,0 +1,374 @@
+#!/usr/bin/env python3
+"""
+Enhanced checkpointing system for BitTransformerLM with multiple training runs support.
+Optimized for Claude Code environment with HF Pro + 20GB persistent storage.
+"""
+
+import os
+import json
+import shutil
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, List, Union, Callable
+from datetime import datetime
+import torch
+from huggingface_hub import HfApi, hf_hub_download
+
+from bit_transformer.error_handling import with_error_recovery, safe_operation
+from bit_transformer.types import PathLike, ModelConfig, TrainingConfig
+
+logger = logging.getLogger(__name__)
+
+
+class EnhancedCheckpointManager:
+    """Advanced checkpoint management for multiple training runs with HF integration."""
+
+    def __init__(self,
+                 base_dir: PathLike = "/data/checkpoints",
+                 hf_repo_id: str = "WCNegentropy/BitTransformerLM",
+                 hf_token: Optional[str] = None,
+                 max_local_checkpoints: int = 5):
+
+        self.base_dir = Path(base_dir)
+        self.base_dir.mkdir(parents=True, exist_ok=True)
+
+        self.hf_repo_id = hf_repo_id
+        self.hf_token = hf_token or os.getenv("HF_TOKEN")
+        self.api = HfApi(token=self.hf_token) if self.hf_token else None
+
+        self.max_local_checkpoints = max_local_checkpoints
+
+        # Training session tracking
+        self.sessions_dir = self.base_dir / "training_sessions"
+        self.sessions_dir.mkdir(exist_ok=True)
+
+        # Best models storage
+        self.best_models_dir = self.base_dir / "best_models"
+        self.best_models_dir.mkdir(exist_ok=True)
+
+    def create_training_session(self,
+                                session_name: str,
+                                model_config: ModelConfig,
+                                training_config: TrainingConfig) -> str:
+        """Create a new training session with metadata."""
+
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        session_id = f"{session_name}_{timestamp}"
+        session_dir = self.sessions_dir / session_id
+        session_dir.mkdir(exist_ok=True)
+
+        # Save session metadata
+        metadata = {
+            "session_id": session_id,
+            "session_name": session_name,
+            "created_at": timestamp,
+            "model_config": model_config,
+            "training_config": training_config,
+            "checkpoints": [],
+            "best_metric": None,
+            "status": "active"
+        }
+
+        with open(session_dir / "metadata.json", "w") as f:
+            json.dump(metadata, f, indent=2, default=str)
+
+        logger.info(f"Created training session: {session_id}")
+        return session_id
+
+    @with_error_recovery(recovery_value=False)
+    def save_checkpoint(self,
+                        model: torch.nn.Module,
+                        session_id: str,
+                        epoch: int,
+                        metrics: Dict[str, float],
+                        optimizer_state: Optional[Dict] = None,
+                        scheduler_state: Optional[Dict] = None,
+                        additional_data: Optional[Dict] = None) -> bool:
+        """Save checkpoint with comprehensive metadata."""
+
+        session_dir = self.sessions_dir / session_id
+        if not session_dir.exists():
+            raise ValueError(f"Training session {session_id} not found")
+
+        # Create checkpoint directory
+        checkpoint_name = f"checkpoint_epoch_{epoch:04d}"
+        checkpoint_dir = session_dir / checkpoint_name
+        checkpoint_dir.mkdir(exist_ok=True)
+
+        # Save model state
+        model_path = checkpoint_dir / "model.pt"
+        torch.save({
+            'model_state_dict': model.state_dict(),
+            'epoch': epoch,
+            'metrics': metrics,
+            'model_config': getattr(model, 'config', {}),
+            'timestamp': datetime.now().isoformat()
+        }, model_path)
+
+        # Save optimizer state if provided
+        if optimizer_state:
+            torch.save(optimizer_state, checkpoint_dir / "optimizer.pt")
+
+        # Save scheduler state if provided
+        if scheduler_state:
+            torch.save(scheduler_state, checkpoint_dir / "scheduler.pt")
+
+        # Save additional data
+        if additional_data:
+            with open(checkpoint_dir / "additional_data.json", "w") as f:
+                json.dump(additional_data, f, indent=2, default=str)
+
+        # Update session metadata
+        self._update_session_metadata(session_id, checkpoint_name, metrics)
+
+        # Cleanup old checkpoints to save space
+        self._cleanup_old_checkpoints(session_dir)
+
+        logger.info(f"Saved checkpoint {checkpoint_name} for session {session_id}")
+        return True
+
+    def load_checkpoint(self,
+                        session_id: str,
+                        checkpoint_name: Optional[str] = None,
+                        model: Optional[torch.nn.Module] = None) -> Dict[str, Any]:
+        """Load checkpoint with all associated data."""
+
+        session_dir = self.sessions_dir / session_id
+        if not session_dir.exists():
+            raise ValueError(f"Training session {session_id} not found")
+
+        # Use latest checkpoint if none specified
+        if checkpoint_name is None:
+            checkpoints = [d for d in session_dir.iterdir()
+                           if d.is_dir() and d.name.startswith("checkpoint_")]
+            if not checkpoints:
+                raise ValueError(f"No checkpoints found for session {session_id}")
+            checkpoint_name = max(checkpoints, key=lambda x: x.name).name
+
+        checkpoint_dir = session_dir / checkpoint_name
+        if not checkpoint_dir.exists():
+            raise ValueError(f"Checkpoint {checkpoint_name} not found in session {session_id}")
+
+        # Load model state
+        model_path = checkpoint_dir / "model.pt"
+        checkpoint_data = torch.load(model_path, map_location='cpu', weights_only=False)
+
+        if model is not None:
+            model.load_state_dict(checkpoint_data['model_state_dict'])
+
+        # Load optimizer state if it exists
+        optimizer_state = None
+        optimizer_path = checkpoint_dir / "optimizer.pt"
+        if optimizer_path.exists():
+            optimizer_state = torch.load(optimizer_path, map_location='cpu', weights_only=False)
+
+        # Load scheduler state if it exists
+        scheduler_state = None
+        scheduler_path = checkpoint_dir / "scheduler.pt"
+        if scheduler_path.exists():
+            scheduler_state = torch.load(scheduler_path, map_location='cpu', weights_only=False)
+
+        # Load additional data if it exists
+        additional_data = {}
+        additional_path = checkpoint_dir / "additional_data.json"
+        if additional_path.exists():
+            with open(additional_path) as f:
+                additional_data = json.load(f)
+
+        return {
+            'model_data': checkpoint_data,
+            'optimizer_state': optimizer_state,
+            'scheduler_state': scheduler_state,
+            'additional_data': additional_data,
+            'checkpoint_path': str(checkpoint_dir)
+        }
+
+    def save_best_model(self,
+                        session_id: str,
+                        model: torch.nn.Module,
+                        metric_name: str,
+                        metric_value: float,
+                        is_better_func: Callable[[float, float], bool] = lambda x, y: x > y) -> bool:
+        """Save model if it achieves best performance."""
+
+        best_model_path = self.best_models_dir / f"{session_id}_best.pt"
+        best_meta_path = self.best_models_dir / f"{session_id}_best_meta.json"
+
+        # Check if this is the best model so far
+        current_best = None
+        if best_meta_path.exists():
+            with open(best_meta_path) as f:
+                current_best = json.load(f)
+
+        if current_best is None or is_better_func(metric_value, current_best['metric_value']):
+            # Save new best model
+            torch.save({
+                'model_state_dict': model.state_dict(),
+                'metric_name': metric_name,
+                'metric_value': metric_value,
+                'session_id': session_id,
+                'timestamp': datetime.now().isoformat()
+            }, best_model_path)
+
+            # Save metadata
+            with open(best_meta_path, "w") as f:
+                json.dump({
+                    'metric_name': metric_name,
+                    'metric_value': metric_value,
+                    'session_id': session_id,
+                    'timestamp': datetime.now().isoformat()
+                }, f, indent=2)
+
+            logger.info(f"New best model saved for session {session_id}: {metric_name}={metric_value}")
+            return True
+
+        return False
+
+    def push_to_hf(self,
+                   session_id: str,
+                   checkpoint_name: Optional[str] = None,
+                   include_optimizer: bool = False) -> bool:
+        """Push checkpoint to HuggingFace Hub."""
+
+        if not self.api:
+            logger.error("HuggingFace API not available - check token")
+            return False
+
+        try:
+            checkpoint_data = self.load_checkpoint(session_id, checkpoint_name)
+            checkpoint_dir = Path(checkpoint_data['checkpoint_path'])
+
+            # Upload model weights
+            self.api.upload_file(
+                path_or_fileobj=str(checkpoint_dir / "model.pt"),
+                path_in_repo=f"checkpoints/{session_id}/model.pt",
+                repo_id=self.hf_repo_id,
+                commit_message=f"Upload checkpoint {checkpoint_name or 'latest'} from session {session_id}"
+            )
+
+            # Upload optimizer state if requested and present
+            if include_optimizer and (checkpoint_dir / "optimizer.pt").exists():
+                self.api.upload_file(
+                    path_or_fileobj=str(checkpoint_dir / "optimizer.pt"),
+                    path_in_repo=f"checkpoints/{session_id}/optimizer.pt",
+                    repo_id=self.hf_repo_id
+                )
+
+            logger.info(f"Successfully pushed checkpoint to HuggingFace: {self.hf_repo_id}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to push to HuggingFace: {e}")
+            return False
+
+    def pull_from_hf(self,
+                     session_id: str,
+                     local_session_id: Optional[str] = None) -> bool:
+        """Pull checkpoint from HuggingFace Hub."""
+
+        if not self.api:
+            logger.error("HuggingFace API not available - check token")
+            return False
+
+        try:
+            local_session = local_session_id or session_id
+            local_dir = self.sessions_dir / local_session / "checkpoint_from_hf"
+            local_dir.mkdir(parents=True, exist_ok=True)
+
+            # Download model weights
+            model_file = hf_hub_download(
+                repo_id=self.hf_repo_id,
+                filename=f"checkpoints/{session_id}/model.pt",
+                local_dir=str(local_dir),
+                local_dir_use_symlinks=False
+            )
+
+            logger.info(f"Successfully pulled checkpoint from HuggingFace to {local_dir}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to pull from HuggingFace: {e}")
+            return False
+
+    def get_storage_usage(self) -> Dict[str, Any]:
+        """Get detailed storage usage breakdown."""
+
+        def get_dir_size(path: Path) -> int:
+            total = 0
+            for item in path.rglob('*'):
+                if item.is_file():
+                    total += item.stat().st_size
+            return total
+
+        usage = {
+            'total_gb': get_dir_size(self.base_dir) / 1e9,
+            'sessions_gb': get_dir_size(self.sessions_dir) / 1e9,
+            'best_models_gb': get_dir_size(self.best_models_dir) / 1e9,
+            'num_sessions': len(list(self.sessions_dir.iterdir())),
+            'num_best_models': len(list(self.best_models_dir.glob('*_best.pt'))),
+        }
+
+        # Get per-session breakdown
+        sessions = []
+        for session_dir in self.sessions_dir.iterdir():
+            if session_dir.is_dir():
+                sessions.append({
+                    'session_id': session_dir.name,
+                    'size_gb': get_dir_size(session_dir) / 1e9,
+                    'num_checkpoints': len(list(session_dir.glob('checkpoint_*')))
+                })
+
+        usage['sessions'] = sorted(sessions, key=lambda x: x['size_gb'], reverse=True)
+
+        return usage
+
+    def _update_session_metadata(self, session_id: str, checkpoint_name: str, metrics: Dict[str, float]):
+        """Update session metadata with new checkpoint info."""
+        metadata_path = self.sessions_dir / session_id / "metadata.json"
+
+        with open(metadata_path) as f:
+            metadata = json.load(f)
+
+        metadata['checkpoints'].append({
+            'name': checkpoint_name,
+            'metrics': metrics,
+            'timestamp': datetime.now().isoformat()
+        })
+
+        # Update best metric if applicable
+        if 'loss' in metrics:
+            if metadata['best_metric'] is None or metrics['loss'] < metadata['best_metric'].get('loss', float('inf')):
+                metadata['best_metric'] = metrics.copy()
+
+        with open(metadata_path, "w") as f:
+            json.dump(metadata, f, indent=2, default=str)
+
+    def _cleanup_old_checkpoints(self, session_dir: Path):
+        """Remove oldest checkpoints to stay within limits."""
+        checkpoints = sorted([d for d in session_dir.iterdir()
+                              if d.is_dir() and d.name.startswith("checkpoint_")],
+                             key=lambda x: x.stat().st_mtime)
+
+        while len(checkpoints) > self.max_local_checkpoints:
+            old_checkpoint = checkpoints.pop(0)
+            shutil.rmtree(old_checkpoint)
+            logger.info(f"Cleaned up old checkpoint: {old_checkpoint.name}")
+
+
+# Convenience functions for easy usage
+def create_checkpoint_manager(hf_token: Optional[str] = None) -> EnhancedCheckpointManager:
+    """Create a pre-configured checkpoint manager for this environment."""
+    return EnhancedCheckpointManager(
+        base_dir="/data/checkpoints",
+        hf_repo_id="WCNegentropy/BitTransformerLM",
+        hf_token=hf_token or os.environ.get("HF_TOKEN"),
+        max_local_checkpoints=3  # Conservative for 20GB storage
+    )
+
+
+if __name__ == "__main__":
+    # Demo usage
+    manager = create_checkpoint_manager()
+    usage = manager.get_storage_usage()
+    print(f"Current storage usage: {usage['total_gb']:.2f} GB")
+    print(f"Number of training sessions: {usage['num_sessions']}")
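For reference, a minimal usage sketch of the new manager in a training loop (not part of this commit): the dict configs, the torch.nn.Linear placeholder model, and the dummy loss values below are illustrative stand-ins, and the script assumes the bit_transformer package is importable and /data/checkpoints is writable.

import torch
from enhanced_checkpoint_system import create_checkpoint_manager

# Placeholder model standing in for a BitTransformerLM instance.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

manager = create_checkpoint_manager()  # picks up HF_TOKEN from the environment if set
session_id = manager.create_training_session(
    session_name="demo_run",
    model_config={"in_features": 8, "out_features": 8},   # placeholder config
    training_config={"lr": 1e-3, "epochs": 3},            # placeholder config
)

for epoch in range(3):
    # Run the real training step here; a dummy loss keeps the sketch runnable.
    loss = float(torch.rand(1))
    manager.save_checkpoint(
        model,
        session_id=session_id,
        epoch=epoch,
        metrics={"loss": loss},
        optimizer_state=optimizer.state_dict(),
    )
    manager.save_best_model(
        session_id,
        model,
        metric_name="loss",
        metric_value=loss,
        is_better_func=lambda new, best: new < best,  # lower loss is better
    )

# Optionally mirror the latest checkpoint to the configured Hub repo.
manager.push_to_hf(session_id)

Note that push_to_hf and pull_from_hf log errors and return False instead of raising, so Hub mirroring can be treated as best-effort during long runs.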