WCNegentropy committed on
Commit 568e652 · verified · 1 Parent(s): 533159d

🚀 OS Launch: Clean documentation and refined licensing


This OS launch commit includes:

✅ **Cleaned Documentation**
- Removed inflated claims and marketing language
- Added honest research status and limitations
- Created professional model card and validation reports
- Streamlined licensing to AGPLv3 + commercial contact

✅ **Refined Codebase**
- Complete experimental bit-native transformer implementation
- 57 Python files with a comprehensive research framework
- Safety telemetry and monitoring systems
- Distributed training and development tools

✅ **Professional Standards**
- Empirical validation of all claims
- Clear experimental vs production distinctions
- Rigorous research methodology requirements
- Community contribution framework

Ready for serious research evaluation and academic investigation.

Files changed (1)
  1. progressive_scaleup.py +216 -0
progressive_scaleup.py ADDED
@@ -0,0 +1,216 @@
"""Legacy progressive scale-up demo.

This script is retained for historical reference but has been superseded by
``integration_schedule.py`` which provides a more flexible scaling workflow.
"""

import argparse
import warnings
import torch
import torch.nn.functional as F
from bit_transformer import (
    BitTransformerLM,
    configure_optimizer,
    expand_model,
    text_to_bits,
)
from bit_transformer.training import train_loop as basic_train

warnings.warn(
    "progressive_scaleup.py is deprecated; use integration_schedule.py instead.",
    DeprecationWarning,
    stacklevel=2,
)


def progressive_scale_up(
    eps: float = 0.65,
    steps: int = 2,
    width_mult: float = 1.0,
    forward_kwargs: dict | None = None,
) -> None:
    """Demonstrate automatic scaling of the model on random data."""
    params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
    model = BitTransformerLM(**params)
    steps_per_epoch = 64 // 8
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=steps * steps_per_epoch
    )

    train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long)
    valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long)

    for step in range(steps):
        # one epoch over train
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=forward_kwargs,
        )

        with torch.no_grad():
            logits, _ = model(valid, **(forward_kwargs or {}))
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        print(f"Step {step} validation loss: {val_loss:.4f}")
        if val_loss < eps:
            params["num_layers"] *= 2
            params["d_model"] = int(params["d_model"] * width_mult)
            params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
            model = expand_model(model, params)
            optimizer, scheduler = configure_optimizer(
                model, lr=1e-3, total_steps=steps * steps_per_epoch
            )
            print(
                "Scaled model to", params["num_layers"], "layers and width", params["d_model"]
            )


def progressive_scale_up_text(
    improve_thresh: float = 0.01,
    steps: int = 2,
    width_mult: float = 2.0,
    max_len: int = 64,
    dataset_size: int = 512,
    forward_kwargs: dict | None = None,
) -> None:
    """Scale up using WikiText2 lines converted to bits.

    Parameters
    ----------
    improve_thresh: float
        Relative validation loss improvement required to avoid scaling.
        If improvement is <= this threshold, model size is increased.
    steps: int
        Number of training steps.
    width_mult: float
        Multiplier applied when increasing model width.
    max_len: int
        Initial sequence length.
    dataset_size: int
        Number of training lines to load from WikiText2.
    forward_kwargs: dict | None
        Extra keyword arguments for the forward pass.
    """
    from datasets import load_dataset

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_iter = ds["train"]["text"]
    valid_iter = ds["validation"]["text"]

    train_lines = []
    for line in train_iter:
        train_lines.append(line)
        if len(train_lines) >= dataset_size:
            break

    valid_lines = []
    for line in valid_iter:
        valid_lines.append(line)
        if len(valid_lines) >= dataset_size // 4:
            break

    def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor:
        seqs = []
        for text in lines:
            bits = text_to_bits(text)[:length]
            if len(bits) < length:
                bits.extend([0] * (length - len(bits)))
            seqs.append(bits)
        return torch.tensor(seqs, dtype=torch.long)

    train = lines_to_tensor(train_lines, max_len)
    valid = lines_to_tensor(valid_lines, max_len)

    params = dict(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
    )
    model = BitTransformerLM(**params)
    steps_per_epoch = len(train) // 8
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
    )

    prev_loss: float | None = None
    scale_length = True

    for step in range(steps):
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=forward_kwargs,
        )

        with torch.no_grad():
            logits, _ = model(valid, **(forward_kwargs or {}))
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        print(f"Step {step} validation loss: {val_loss:.4f}")
        if prev_loss is not None:
            improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8)
            if improvement <= improve_thresh:
                if scale_length:
                    params["max_seq_len"] *= 2
                    train = lines_to_tensor(train_lines, params["max_seq_len"])
                    valid = lines_to_tensor(valid_lines, params["max_seq_len"])
                    model = model.double_length()
                    steps_per_epoch = len(train) // 8
                    scale_type = "length"
                else:
                    params["d_model"] = int(params["d_model"] * width_mult)
                    params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
                    model = expand_model(model, params)
                    scale_type = "width"
                optimizer, scheduler = configure_optimizer(
                    model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
                )
                scale_length = not scale_length
                param_count = sum(p.numel() for p in model.parameters())
                print(
                    f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}"
                )
        prev_loss = val_loss

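# Worked example of the scaling rule above (illustrative numbers, not from a
# real run): with prev_loss = 0.70 and val_loss = 0.695, the relative
# improvement is (0.70 - 0.695) / 0.70 ≈ 0.0071, which is <= improve_thresh
# (0.01), so the model is scaled, alternating between doubling the sequence
# length and widening the layers.
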
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Progressively scale model length and width")
    parser.add_argument("--steps", type=int, default=2, help="number of training steps")
    parser.add_argument(
        "--improve-thresh",
        type=float,
        default=0.01,
        help="relative loss improvement required to avoid scaling",
    )
    parser.add_argument(
        "--width-mult", type=float, default=2.0, help="width multiplier when scaling"
    )
    parser.add_argument("--causal", action="store_true", help="use causal attention during training")
    parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset")
    args = parser.parse_args()
    if args.wikitext:
        progressive_scale_up_text(
            improve_thresh=args.improve_thresh,
            steps=args.steps,
            width_mult=args.width_mult,
            forward_kwargs={"causal": args.causal} if args.causal else None,
        )
    else:
        progressive_scale_up(
            eps=args.improve_thresh,
            steps=args.steps,
            width_mult=args.width_mult,
            forward_kwargs={"causal": args.causal} if args.causal else None,
        )
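For reviewers who still want to exercise the deprecated demo, a minimal usage sketch follows. It only uses functions and parameters defined in this file; it assumes `bit_transformer` is importable, and the WikiText-2 path additionally requires the `datasets` package. Equivalent CLI flags (from the argparse block above) are noted in comments.

```python
# Minimal sketch of invoking the legacy demo functions defined above.
# Importing the module emits the DeprecationWarning pointing to integration_schedule.py.
from progressive_scaleup import progressive_scale_up, progressive_scale_up_text

# Random-bit demo: doubles depth whenever validation loss falls below eps.
# CLI equivalent: python progressive_scaleup.py --steps 2 --improve-thresh 0.65 --width-mult 1.0
progressive_scale_up(eps=0.65, steps=2, width_mult=1.0)

# WikiText-2 demo: alternates length and width scaling when improvement stalls.
# CLI equivalent: python progressive_scaleup.py --wikitext --causal
progressive_scale_up_text(
    improve_thresh=0.01,
    steps=2,
    width_mult=2.0,
    forward_kwargs={"causal": True},
)
```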