File size: 2,096 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
e3d9952
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import annotations

import gzip
import os
import shutil
import tempfile
from typing import Optional

import torch
from huggingface_hub import HfApi, hf_hub_download, login

# Default Hub repository used by ``save_checkpoint`` and ``download_checkpoint``.
REPO_ID = "WCNegentropy/BitTransformerLM"
# Default checkpoint file name; both functions address it as
# ``checkpoints/<FILENAME>`` inside the repository.
FILENAME = "model.pt.gz"


def hf_login(token: Optional[str] = None) -> None:
    """Authenticate with Hugging Face.

    Args:
        token: Access token to log in with. When omitted (or empty), the
            value of the ``HF_TOKEN`` environment variable is used instead.

    If neither source provides a token, ``login`` is called with
    ``token=None`` and the library will attempt an interactive login.
    """
    # Resolve the documented ``HF_TOKEN`` fallback explicitly: passing
    # ``token=None`` straight to ``login`` would go to the interactive prompt
    # even when the environment variable is set.
    if not token:
        token = os.environ.get("HF_TOKEN") or None
    login(token=token)


def save_checkpoint(
    model: torch.nn.Module,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> None:
    """Upload the gzip-compressed model weights to ``repo_id``.

    The model's ``state_dict`` is serialized with ``torch.save``, compressed
    with gzip, and uploaded to ``checkpoints/<filename>`` in the model
    repository. Uploading to the same path each time replaces the previous
    file, so checkpoints do not accumulate in the repository tree.

    Args:
        model: Module whose ``state_dict()`` is saved.
        repo_id: Target Hugging Face model repository.
        filename: Name of the file inside the repo's ``checkpoints/`` folder.
    """
    with tempfile.TemporaryDirectory() as tmp:
        # Fixed staging names: deriving the local path from ``filename`` would
        # collide with the uncompressed file when ``filename == "model.pt"``
        # and would fail if ``filename`` contained a sub-directory.
        tmp_pt = os.path.join(tmp, "state_dict.pt")
        tmp_gz = os.path.join(tmp, "state_dict.pt.gz")
        torch.save(model.state_dict(), tmp_pt)
        with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst:
            # Stream in chunks instead of reading the whole checkpoint into
            # memory at once.
            shutil.copyfileobj(src, dst)
        # ``upload_file`` has no ``overwrite`` keyword; committing to an
        # existing ``path_in_repo`` already replaces the file.
        HfApi().upload_file(
            path_or_fileobj=tmp_gz,
            path_in_repo=f"checkpoints/{filename}",
            repo_id=repo_id,
            repo_type="model",
        )


def download_checkpoint(
    dest_path: str,
    *,
    repo_id: str = REPO_ID,
    filename: str = FILENAME,
) -> bool:
    """Download the latest checkpoint to ``dest_path``.

    The file ``checkpoints/<filename>`` is fetched from the model repository
    (bypassing any cached copy via ``force_download=True``) and copied to
    ``dest_path`` unchanged; no decompression is performed here.

    Args:
        dest_path: Local path the checkpoint file is copied to. Missing parent
            directories are created.
        repo_id: Hugging Face model repository to download from.
        filename: Name of the file inside the repo's ``checkpoints/`` folder.

    Returns:
        ``True`` if the checkpoint was successfully retrieved, ``False`` if the
        download raised an exception (the error is printed, not re-raised).
    """
    try:
        cached_path = hf_hub_download(
            repo_id,
            f"checkpoints/{filename}",
            repo_type="model",
            force_download=True,
        )
    except Exception as exc:  # pragma: no cover - network errors
        # Deliberately best-effort: callers get ``False`` instead of an error.
        print("Failed to download checkpoint", exc)
        return False
    # ``dirname`` is "" for a bare file name, and ``os.makedirs("")`` raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    shutil.copyfile(cached_path, dest_path)
    return True


# Public API of this module (the names exported by a star-import).
__all__ = ["hf_login", "save_checkpoint", "download_checkpoint"]