import os
from pathlib import Path


def get_latest_checkpoint(path: Path | str) -> Path | None:
    # Find the latest checkpoint
    ckpt_dir = Path(path)

    if ckpt_dir.exists() is False:
        return None

    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
    if len(ckpts) == 0:
        return None

    return ckpts[-1]