|
from __future__ import annotations |
|
|
|
import gzip |
|
import os |
|
import shutil |
|
import tempfile |
|
from typing import Optional |
|
|
|
import torch |
|
from huggingface_hub import HfApi, hf_hub_download, login |
|
|
|
# Default Hugging Face Hub model repository used for uploads and downloads.
REPO_ID: str = "WCNegentropy/BitTransformerLM"
# Default checkpoint file name; stored under ``checkpoints/`` in the repository
# (save_checkpoint writes it as a gzip-compressed ``torch.save`` state dict).
FILENAME: str = "model.pt.gz"
|
|
|
|
|
def hf_login(token: Optional[str] = None) -> None:
    """Authenticate with Hugging Face.

    The ``token`` may be provided directly or via the ``HF_TOKEN`` environment
    variable. If neither is available, the library will attempt an
    interactive login.

    Args:
        token: Hugging Face access token. When ``None`` or empty, the value of
            the ``HF_TOKEN`` environment variable is used instead.
    """
    # ``login(token=None)`` falls through to the interactive prompt, so the
    # documented ``HF_TOKEN`` fallback has to be applied here explicitly.
    # Empty strings (argument or env var) are treated as "not provided".
    token = token or os.environ.get("HF_TOKEN") or None
    login(token=token)
|
|
|
|
|
def save_checkpoint(
    model: torch.nn.Module,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> None:
    """Upload the model weights to ``repo_id`` under ``checkpoints/{filename}``.

    The model's ``state_dict()`` is serialized with ``torch.save``,
    gzip-compressed, and uploaded as a single file. Each upload replaces the
    file at that path in the repository, so checkpoints do not accumulate
    there.

    Args:
        model: Module whose ``state_dict()`` is saved.
        repo_id: Hugging Face Hub model repository to upload to.
        filename: Name of the file inside the repository's ``checkpoints/``
            directory.
    """
    with tempfile.TemporaryDirectory() as tmp:
        # Fixed local names: a caller-supplied ``filename`` must not collide
        # with the raw ``.pt`` file (e.g. ``filename="model.pt"`` would make
        # gzip.open truncate the source before it is read) nor point into a
        # sub-directory that does not exist inside ``tmp``.
        tmp_pt = os.path.join(tmp, "model.pt")
        tmp_gz = os.path.join(tmp, "model.pt.gz")
        torch.save(model.state_dict(), tmp_pt)
        # Stream in chunks rather than reading the whole checkpoint into memory.
        with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
            shutil.copyfileobj(src, dst)
        # ``HfApi.upload_file`` accepts no ``overwrite`` argument; committing to
        # an existing ``path_in_repo`` already replaces the previous file.
        HfApi().upload_file(
            path_or_fileobj=tmp_gz,
            path_in_repo=f"checkpoints/{filename}",
            repo_id=repo_id,
            repo_type="model",
        )
|
|
|
|
|
def download_checkpoint(
    dest_path: str,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> bool:
    """Download the latest checkpoint to ``dest_path``.

    Fetches ``checkpoints/{filename}`` from ``repo_id``, re-downloading even if
    a cached copy exists (``force_download=True``), and copies it verbatim to
    ``dest_path``. The file is not decompressed, so a checkpoint written by
    ``save_checkpoint`` arrives gzip-compressed.

    Args:
        dest_path: Local path to write the checkpoint to. Missing parent
            directories are created; a bare file name targets the current
            working directory.
        repo_id: Hugging Face Hub model repository to download from.
        filename: Name of the file inside the repository's ``checkpoints/``
            directory.

    Returns:
        ``True`` if the checkpoint was successfully retrieved, ``False`` if the
        download failed (the error is printed).

    Errors raised while creating the destination directory or copying the
    file are not caught and propagate to the caller.
    """
    try:
        local_path = hf_hub_download(
            repo_id,
            f"checkpoints/{filename}",
            repo_type="model",
            force_download=True,
        )
    except Exception as exc:
        # Best-effort download: report and signal failure instead of raising.
        print("Failed to download checkpoint", exc)
        return False
    dest_dir = os.path.dirname(dest_path)
    # ``os.makedirs("")`` raises FileNotFoundError, so only create a directory
    # when ``dest_path`` actually contains one.
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    shutil.copyfile(local_path, dest_path)
    return True
|
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    "hf_login",
    "save_checkpoint",
    "download_checkpoint",
]
|
|