"""Example: progressive scale-up and submodel collapse on a small WikiText-2 sample."""

import torch
import torch.nn.functional as F
from datasets import load_dataset

from bit_transformer import text_to_bits, collapse_submodel
from progressive_scaleup import progressive_scale_up_text


def lines_to_bits(lines, max_len=64):
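    """Convert each text line to a list of bits, truncated or zero-padded to max_len."""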
    data = []
    for text in lines:
        bits = text_to_bits(text)[:max_len]
        if len(bits) < max_len:
            bits.extend([0] * (max_len - len(bits)))
        data.append(bits)
    return data


def main():
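    # Use small slices of WikiText-2 so the example runs quickly.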
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    val_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation[:1%]")
    train_lines = [item["text"] for item in ds][:256]
    valid_lines = [item["text"] for item in val_ds][:64]

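    # Convert each line to a fixed-length bit sequence.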
    train_bits = lines_to_bits(train_lines)
    valid_bits = lines_to_bits(valid_lines)

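    # Run the progressive scale-up schedule on the text data.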
    progressive_scale_up_text(
        eps=0.65,
        steps=4,
        width_mult=2.0,
        max_len=64,
        dataset_size=min(64, len(train_bits)),
    )

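    # Collapse into a small submodel with the target hyperparameters.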
    target_params = dict(d_model=16, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=64)
    model, _ = collapse_submodel(train_bits[:64], target_params, max_rounds=1)

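    # Score the collapsed model on the validation bits.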
    val_tensor = torch.tensor(valid_bits, dtype=torch.long)
    logits, _ = model(val_tensor)
    # Shift by one position: the logits at step t are scored against the bit at step t + 1.
    pred = logits[:, :-1, :].reshape(-1, 2)
    target = val_tensor[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)
    print("Collapsed model validation loss:", loss.item())


if __name__ == "__main__":
    main()