|
import os
import warnings
from typing import Any, List, Optional

from torch import distributed as dist

__all__ = [
    "init",
    "is_initialized",
    "size",
    "rank",
    "local_size",
    "local_rank",
    "is_main",
    "barrier",
    "gather",
    "all_gather",
]
|
|
def init() -> None:
    """Initialize the default process group from environment variables.

    Expects the torchrun-style variables (`RANK`, `WORLD_SIZE`,
    `MASTER_ADDR`, `MASTER_PORT`) to be set by the launcher; in a plain
    single-process run it warns and leaves the process group uninitialized.
    """
    if "RANK" not in os.environ:
        warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
        return
    # `env://` reads the rendezvous info (MASTER_ADDR/MASTER_PORT, rank,
    # world size) from the environment; NCCL assumes one GPU per process.
    dist.init_process_group(backend="nccl", init_method="env://")
|
|
def is_initialized() -> bool:
    """Return True if the default process group has been initialized."""
    # Guard with `is_available()` so this is safe on builds compiled
    # without distributed support.
    return dist.is_available() and dist.is_initialized()
|
|
def size() -> int:
    """World size, read from the environment so it works before `init()`."""
    return int(os.environ.get("WORLD_SIZE", 1))


def rank() -> int:
    """Global rank of this process (0 when not launched distributed)."""
    return int(os.environ.get("RANK", 0))


def local_size() -> int:
    """Number of processes running on this node."""
    return int(os.environ.get("LOCAL_WORLD_SIZE", 1))


def local_rank() -> int:
    """Rank of this process within its node (e.g. the GPU index to use)."""
    return int(os.environ.get("LOCAL_RANK", 0))


def is_main() -> bool:
    """True on the global rank-0 process."""
    return rank() == 0
|
|
def barrier() -> None:
    """Synchronize all processes; a no-op in a single-process run."""
    if is_initialized():
        dist.barrier()
|
|
def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
    """Gather a picklable object from every rank onto rank `dst`.

    Returns the list of gathered objects on rank `dst` and None elsewhere.
    """
    if not is_initialized():
        return [obj]
    # Note: compare against `dst` rather than `is_main()`, otherwise the
    # output list is allocated on the wrong rank whenever dst != 0.
    if rank() == dst:
        objs: List[Any] = [None] * size()
        dist.gather_object(obj, objs, dst=dst)
        return objs
    dist.gather_object(obj, dst=dst)
    return None
|
|
def all_gather(obj: Any) -> List[Any]:
    """Gather a picklable object from every rank onto every rank."""
    if not is_initialized():
        return [obj]
    objs: List[Any] = [None] * size()
    dist.all_gather_object(objs, obj)
    return objs
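

# Minimal usage sketch (an illustration, not part of the module's API).
# Assumes a torchrun-style launcher, e.g.:
#   torchrun --nproc_per_node=2 this_module.py
# It also degrades to a plain single-process run, where `init()` warns and
# the collectives fall back to their local one-element results.
if __name__ == "__main__":
    import torch

    init()
    if is_initialized():
        # NCCL collectives require each process to be bound to its own GPU.
        torch.cuda.set_device(local_rank())
    # Each rank contributes its own rank number; only rank 0 receives the list.
    gathered = gather(rank())
    if is_main():
        print(f"gathered on rank 0: {gathered}")
    # Every rank receives the full list.
    print(f"all_gather on rank {rank()}: {all_gather(rank())}")
    barrier()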
|
|