WCNegentropy's picture
🚀 OS Launch: Clean documentation and refined licensing
e3d9952 verified
raw
history blame
2.1 kB
from __future__ import annotations
import gzip
import os
import shutil
import tempfile
from typing import Optional
import torch
from huggingface_hub import HfApi, hf_hub_download, login
# Default Hugging Face repository used by the checkpoint helpers below.
REPO_ID = "WCNegentropy/BitTransformerLM"
# Default checkpoint file name; the helpers store it under ``checkpoints/``.
FILENAME = "model.pt.gz"
def hf_login(token: Optional[str] = None) -> None:
    """Authenticate with Hugging Face.

    The ``token`` may be provided directly or via the ``HF_TOKEN`` environment
    variable. If omitted entirely, the library will attempt an interactive
    login.

    Args:
        token: Access token to log in with. When ``None``, the value of the
            ``HF_TOKEN`` environment variable is used if it is set and
            non-empty.
    """
    if token is None:
        # Honour the documented environment fallback explicitly; an empty
        # variable is treated as unset so ``login`` still receives ``None``.
        token = os.environ.get("HF_TOKEN") or None
    login(token=token)
def save_checkpoint(
    model: torch.nn.Module,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> None:
    """Upload the model weights to ``repo_id`` under ``filename``.

    The model's ``state_dict`` is serialized with ``torch.save``,
    gzip-compressed, and uploaded to ``checkpoints/<filename>`` in the model
    repository. The same repository path is reused on every call, so the
    checkpoint is replaced rather than accumulated.

    Args:
        model: Module whose ``state_dict()`` is saved.
        repo_id: Target Hugging Face repository.
        filename: File name used below ``checkpoints/`` in the repository.
    """
    with tempfile.TemporaryDirectory() as tmp:
        # Fixed local names: the remote name comes from ``path_in_repo``, so
        # the temp files never collide with each other and ``filename`` may
        # safely contain sub-directories.
        tmp_pt = os.path.join(tmp, "model.pt")
        tmp_gz = os.path.join(tmp, "model.pt.gz")
        torch.save(model.state_dict(), tmp_pt)
        with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
            # Stream in chunks instead of reading the whole checkpoint into
            # memory at once.
            shutil.copyfileobj(src, dst)
        # ``upload_file`` has no ``overwrite`` keyword; passing one raises
        # ``TypeError``. Uploading to an existing path already replaces it.
        HfApi().upload_file(
            path_or_fileobj=tmp_gz,
            path_in_repo=f"checkpoints/{filename}",
            repo_id=repo_id,
            repo_type="model",
        )
def download_checkpoint(
    dest_path: str,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> bool:
    """Download the latest checkpoint to ``dest_path``.

    Fetches ``checkpoints/<filename>`` from the model repository with
    ``force_download=True`` and copies the result to ``dest_path``, creating
    the parent directory when one is given.

    Args:
        dest_path: Local path the checkpoint is copied to. May be a bare file
            name, in which case it is written relative to the current
            directory.
        repo_id: Source Hugging Face repository.
        filename: File name below ``checkpoints/`` in the repository.

    Returns:
        ``True`` if the checkpoint was successfully retrieved, ``False`` if
        the download raised (the error is printed). Errors while creating the
        directory or copying the file are not caught.
    """
    try:
        cached_path = hf_hub_download(
            repo_id,
            f"checkpoints/{filename}",
            repo_type="model",
            force_download=True,
        )
    except Exception as exc:  # pragma: no cover - network errors
        print("Failed to download checkpoint", exc)
        return False
    parent_dir = os.path.dirname(dest_path)
    # ``dirname`` is "" for a bare file name and ``os.makedirs("")`` raises
    # FileNotFoundError, so only create the parent when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    shutil.copyfile(cached_path, dest_path)
    return True
# Names exported by ``from <module> import *``.
__all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]