import torch
import torch.nn.functional as F
from datasets import load_dataset

from bit_transformer import text_to_bits, collapse_submodel
from progressive_scaleup import progressive_scale_up_text


def lines_to_bits(lines, max_len=64):
    """Convert text lines to fixed-length bit sequences, zero-padded to max_len."""
    data = []
    for text in lines:
        bits = text_to_bits(text)[:max_len]
        if len(bits) < max_len:
            bits.extend([0] * (max_len - len(bits)))
        data.append(bits)
    return data


def main():
    # Load small slices of WikiText-2 for training and validation.
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    val_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation[:1%]")
    train_lines = [item["text"] for item in ds][:256]
    valid_lines = [item["text"] for item in val_ds][:64]
    train_bits = lines_to_bits(train_lines)
    valid_bits = lines_to_bits(valid_lines)

    # Run the progressive scale-up schedule on the text-derived bit data.
    progressive_scale_up_text(
        eps=0.65,
        steps=4,
        width_mult=2.0,
        max_len=64,
        dataset_size=min(64, len(train_bits)),
    )

    # Distill a compact submodel from a subset of the training bits.
    target_params = dict(
        d_model=16, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=64
    )
    model, _ = collapse_submodel(train_bits[:64], target_params, max_rounds=1)

    # Evaluate next-bit prediction loss on the validation set.
    val_tensor = torch.tensor(valid_bits, dtype=torch.long)
    with torch.no_grad():
        logits, _ = model(val_tensor)
    pred = logits[:, :-1, :].reshape(-1, 2)
    target = val_tensor[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)
    print("Collapsed model validation loss:", loss.item())


if __name__ == "__main__":
    main()