from __future__ import annotations import gzip import os import shutil import tempfile from typing import Optional import torch from huggingface_hub import HfApi, hf_hub_download, login REPO_ID = "WCNegentropy/BitTransformerLM" FILENAME = "model.pt.gz" def hf_login(token: Optional[str] = None) -> None: """Authenticate with Hugging Face. The ``token`` may be provided directly or via the ``HF_TOKEN`` environment variable. If omitted entirely, the library will attempt an interactive login. """ login(token=token) def save_checkpoint( model: torch.nn.Module, *, repo_id: str = REPO_ID, filename: str = FILENAME, ) -> None: """Upload the model weights to ``repo_id`` under ``filename``. The file within the repository is overwritten each time to avoid accumulating checkpoints. """ with tempfile.TemporaryDirectory() as tmp: tmp_pt = os.path.join(tmp, "model.pt") tmp_gz = os.path.join(tmp, filename) torch.save(model.state_dict(), tmp_pt) with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst: dst.write(src.read()) HfApi().upload_file( path_or_fileobj=tmp_gz, path_in_repo=f"checkpoints/{filename}", repo_id=repo_id, repo_type="model", overwrite=True, ) def download_checkpoint( dest_path: str, *, repo_id: str = REPO_ID, filename: str = FILENAME, ) -> bool: """Download the latest checkpoint to ``dest_path``. Returns ``True`` if the checkpoint was successfully retrieved. """ try: buf = hf_hub_download( repo_id, f"checkpoints/{filename}", repo_type="model", force_download=True, ) except Exception as exc: # pragma: no cover - network errors print("Failed to download checkpoint", exc) return False os.makedirs(os.path.dirname(dest_path), exist_ok=True) shutil.copyfile(buf, dest_path) return True __all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]