|  | from __future__ import annotations | 
					
						
						|  |  | 
					
						
						|  | import gzip | 
					
						
						|  | import os | 
					
						
						|  | import shutil | 
					
						
						|  | import tempfile | 
					
						
						|  | from typing import Optional | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from huggingface_hub import HfApi, hf_hub_download, login | 
					
						
						|  |  | 
					
						
						|  | REPO_ID = "WCNegentropy/BitTransformerLM" | 
					
						
						|  | FILENAME = "model.pt.gz" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def hf_login(token: Optional[str] = None) -> None: | 
					
						
						|  | """Authenticate with Hugging Face. | 
					
						
						|  |  | 
					
						
						|  | The ``token`` may be provided directly or via the ``HF_TOKEN`` environment | 
					
						
						|  | variable. If omitted entirely, the library will attempt an interactive login. | 
					
						
						|  | """ | 
					
						
						|  | login(token=token) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def save_checkpoint( | 
					
						
						|  | model: torch.nn.Module, | 
					
						
						|  | *, | 
					
						
						|  | repo_id: str = REPO_ID, | 
					
						
						|  | filename: str = FILENAME, | 
					
						
						|  | ) -> None: | 
					
						
						|  | """Upload the model weights to ``repo_id`` under ``filename``. | 
					
						
						|  |  | 
					
						
						|  | The file within the repository is overwritten each time to avoid | 
					
						
						|  | accumulating checkpoints. | 
					
						
						|  | """ | 
					
						
						|  | with tempfile.TemporaryDirectory() as tmp: | 
					
						
						|  | tmp_pt = os.path.join(tmp, "model.pt") | 
					
						
						|  | tmp_gz = os.path.join(tmp, filename) | 
					
						
						|  | torch.save(model.state_dict(), tmp_pt) | 
					
						
						|  | with open(tmp_pt, "rb") as src, gzip.open(tmp_gz, "wb") as dst: | 
					
						
						|  | dst.write(src.read()) | 
					
						
						|  | HfApi().upload_file( | 
					
						
						|  | path_or_fileobj=tmp_gz, | 
					
						
						|  | path_in_repo=f"checkpoints/{filename}", | 
					
						
						|  | repo_id=repo_id, | 
					
						
						|  | repo_type="model", | 
					
						
						|  | overwrite=True, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def download_checkpoint( | 
					
						
						|  | dest_path: str, | 
					
						
						|  | *, | 
					
						
						|  | repo_id: str = REPO_ID, | 
					
						
						|  | filename: str = FILENAME, | 
					
						
						|  | ) -> bool: | 
					
						
						|  | """Download the latest checkpoint to ``dest_path``. | 
					
						
						|  |  | 
					
						
						|  | Returns ``True`` if the checkpoint was successfully retrieved. | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  | buf = hf_hub_download( | 
					
						
						|  | repo_id, | 
					
						
						|  | f"checkpoints/{filename}", | 
					
						
						|  | repo_type="model", | 
					
						
						|  | force_download=True, | 
					
						
						|  | ) | 
					
						
						|  | except Exception as exc: | 
					
						
						|  | print("Failed to download checkpoint", exc) | 
					
						
						|  | return False | 
					
						
						|  | os.makedirs(os.path.dirname(dest_path), exist_ok=True) | 
					
						
						|  | shutil.copyfile(buf, dest_path) | 
					
						
						|  | return True | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | __all__ = ["hf_login", "save_checkpoint", "download_checkpoint"] | 
					
						
						|  |  |