|
import io |
|
import time |
|
import contextlib |
|
from pathlib import Path |
|
import sys |
|
import torch |
|
|
|
# ROOT is the directory two levels up from this file (the parent of the folder
# that contains this script).
ROOT = Path(__file__).resolve().parents[1]

# Put ROOT at the front of the import search path exactly once; presumably this
# lets the project-level imports that follow resolve when run as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
|
|
from progressive_scaleup import progressive_scale_up_text |
|
from unified_workflow import run_workflow |
|
from bit_transformer.bit_io import text_to_bits |
|
from bit_transformer.safety import hil_safe_inference |
|
|
|
|
|
def capture_run(func, *args, **kwargs):
    """Call ``func`` while capturing everything it prints to stdout.

    Args:
        func: The callable to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A 3-tuple ``(result, output, duration)`` where ``result`` is whatever
        ``func`` returned, ``output`` is the text written to ``sys.stdout``
        during the call, and ``duration`` is the elapsed time in seconds.

    Raises:
        Any exception raised by ``func`` propagates unchanged; stdout is
        restored by the context manager before it reaches the caller.
    """
    buf = io.StringIO()
    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot be skewed (or go negative) by system clock adjustments
    # the way a time.time() difference can.
    start = time.perf_counter()
    with contextlib.redirect_stdout(buf):
        result = func(*args, **kwargs)
    duration = time.perf_counter() - start
    return result, buf.getvalue(), duration
|
|
|
|
|
def main() -> None:
    """Run four demo configurations and print one combined report.

    The sequence is: progressive scale-up with ``causal=True``, then with
    ``causal=False``; the unified workflow with ``diffusion=False`` followed
    by a single inference on the text ``"hi"``; and finally the unified
    workflow with ``diffusion=True``. Each run's stdout is captured via
    ``capture_run`` and emitted under a ``###`` heading together with its
    duration. The report sections are joined with newlines and printed once
    at the end.
    """
    sections: list[str] = []

    # Keyword arguments shared by both progressive scale-up runs; only the
    # ``causal`` flag inside ``forward_kwargs`` differs between them.
    scale_up_settings = {
        "improve_thresh": 0.01,
        "steps": 10,
        "width_mult": 2.0,
        "max_len": 64,
        "dataset_size": 512,
    }
    scale_up_runs = (
        (True, "### Progressive Scale-Up (causal=True)\n"),
        (False, "### Progressive Scale-Up (causal=False)\n"),
    )
    for causal, heading in scale_up_runs:
        _, captured, elapsed = capture_run(
            progressive_scale_up_text,
            **scale_up_settings,
            forward_kwargs={"causal": causal},
        )
        sections += [heading, captured.strip(), f"Duration: {elapsed:.2f}s\n"]

    # Keyword arguments shared by both unified-workflow runs; only the
    # ``diffusion`` flag differs between them.
    workflow_settings = {
        "steps": 2,
        "max_len": 32,
        "dataset_size": 32,
        "plateau_steps": 1,
        "epochs_per_step": 1,
        "extra_steps": 1,
    }

    # Non-diffusion run: keep the first element of the returned pair (used as
    # the model below) and run one inference on the bits of "hi".
    (model, _), captured, elapsed = capture_run(
        run_workflow,
        **workflow_settings,
        diffusion=False,
    )
    prompt = torch.tensor(text_to_bits("hi"), dtype=torch.long).unsqueeze(0)
    predicted, _ = hil_safe_inference(model, prompt, c_floor=0.0, s_floor=0.0)
    sections += [
        "### Unified Workflow (causal=True)\n",
        captured.strip(),
        f"Inference on 'hi': {predicted.squeeze(0).tolist()}\n",
        f"Duration: {elapsed:.2f}s\n",
    ]

    # Diffusion run: only the captured log and the timing are reported.
    (_, _), captured, elapsed = capture_run(
        run_workflow,
        **workflow_settings,
        diffusion=True,
    )
    sections += [
        "### Unified Workflow (causal=False / Diffusion)\n",
        captured.strip(),
        f"Duration: {elapsed:.2f}s\n",
    ]

    print("\n".join(sections))
|
|
|
|
|
# Script entry point: build and print the report only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|