import os
import warnings
from typing import Any, List, Optional

from torch import distributed as dist

__all__ = [
    "init",
    "is_initialized",
    "size",
    "rank",
    "local_size",
    "local_rank",
    "is_main",
    "barrier",
    "gather",
    "all_gather",
]


def init() -> None:
    """Initialize the default process group from torchrun-style environment variables."""
    if "RANK" not in os.environ:
        warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
        return
    dist.init_process_group(backend="nccl", init_method="env://")


def is_initialized() -> bool:
    return dist.is_initialized()


def size() -> int:
    return int(os.environ.get("WORLD_SIZE", 1))


def rank() -> int:
    return int(os.environ.get("RANK", 0))


def local_size() -> int:
    return int(os.environ.get("LOCAL_WORLD_SIZE", 1))


def local_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", 0))


def is_main() -> bool:
    return rank() == 0


def barrier() -> None:
    dist.barrier()


def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
    """Gather picklable objects from all ranks onto rank ``dst``; other ranks receive None."""
    if not is_initialized():
        return [obj]
    if rank() == dst:
        # Only the destination rank supplies (and receives) the gather list.
        objs: List[Any] = [None for _ in range(size())]
        dist.gather_object(obj, objs, dst=dst)
        return objs
    dist.gather_object(obj, dst=dst)
    return None


def all_gather(obj: Any) -> List[Any]:
    """Gather picklable objects from all ranks onto every rank."""
    if not is_initialized():
        return [obj]
    objs: List[Any] = [None for _ in range(size())]
    dist.all_gather_object(objs, obj)
    return objs
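
# Usage sketch (an assumption, not part of the original module): a minimal
# smoke test of the helpers when this file is launched directly, e.g. via
# `torchrun --nproc_per_node=2 this_file.py`. The payload dict is illustrative.
if __name__ == "__main__":
    init()
    payload = {"rank": rank(), "local_rank": local_rank()}
    # Every rank receives the same full list from all_gather.
    gathered = all_gather(payload)
    if is_main():
        print(f"world size: {size()}")
        print(f"gathered payloads: {gathered}")
    if is_initialized():
        barrier()
        dist.destroy_process_group()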