"""
Helpers for distributed training.
"""

import builtins
import datetime
import io
import os
import socket

try:
    import blobfile as bf
except ImportError:
    bf = None  # Optional dependency; only needed by load_state_dict().

try:
    from mpi4py import MPI
except ImportError:
    MPI = None  # Optional dependency; only needed by load_state_dict().

import torch as th
import torch.distributed as dist


GPUS_PER_NODE = 8

SETUP_RETRY_COUNT = 3


def synchronize():
    """
    Barrier across all ranks; a no-op outside of an initialized
    multi-process run.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()


def is_dist_avail_and_initialized():
    """
    Return True only if torch.distributed is both available and initialized.
    """
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """
    Return the number of ranks, or 1 outside of a distributed run.
    """
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def setup_for_distributed(is_master):
    """
    Override the built-in print so that only the master process prints.

    Callers may pass force=True to print from any rank; printing is also
    forced on every rank when the job spans more than 8 ranks.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            # Prefix each message with a wall-clock timestamp.
            now = datetime.datetime.now().time()
            builtin_print("[{}] ".format(now), end="")
            builtin_print(*args, **kwargs)

    builtins.print = print


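# A minimal usage sketch: call this once after the process group is up, and
# use force=True as the escape hatch for per-rank debugging output.
#
#   setup_for_distributed(dist.get_rank() == 0)
#   print("visible on the master only")
#   print("visible on every rank", force=True)

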
def setup_dist_multinode(args):
    """
    Set up a distributed process group for a multi-node job launched with
    torchrun (expects the env:// variables, including LOCAL_RANK, to be set).
    """
    if not dist.is_available():
        raise RuntimeError("ddp failed: torch.distributed is not available")
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://")

    local_rank = int(os.environ["LOCAL_RANK"])
    print("rank", local_rank)
    th.cuda.set_device(local_rank)
    # Only the global master (rank 0), not each node's local rank 0, prints.
    setup_for_distributed(dist.get_rank() == 0)
    synchronize()


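# A hedged launch sketch, assuming a torchrun-style launcher that sets RANK,
# LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT (train.py is a
# hypothetical entry point that calls setup_dist_multinode):
#
#   torchrun --nnodes=2 --nproc-per-node=8 \
#       --rdzv-backend=c10d --rdzv-endpoint=$MASTER_ADDR:29500 train.py

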
def setup_dist(global_seed):
    """
    Set up a distributed process group and give each rank its own seed.
    """
    if dist.is_initialized():
        return
    th.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=datetime.timedelta(seconds=5400),
    )

    rank = dist.get_rank()
    device = rank % th.cuda.device_count()
    # Offset the seed by rank so randomness decorrelates across processes.
    seed = global_seed * dist.get_world_size() + rank
    th.manual_seed(seed)
    th.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
    synchronize()


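# A worked example of the seeding rule above: with global_seed=7 on 4 ranks,
# rank r gets seed 7 * 4 + r, i.e. 28, 29, 30, 31 -- distinct per rank, and
# disjoint from the seed block produced by any other global_seed.

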
def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Rank 0 reads the file and broadcasts it in chunks small enough to stay
    under MPI's 32-bit message-size limit; other ranks reassemble the bytes.
    """
    chunk_size = 2 ** 30  # 1 GiB per broadcast
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)


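# A hedged usage sketch (assumes an MPI launch and a checkpoint at a
# hypothetical path; extra kwargs are forwarded to th.load):
#
#   state = load_state_dict("checkpoints/model.pt", map_location="cpu")
#   model.load_state_dict(state)

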
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


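# A minimal end-to-end sketch (assumes a torchrun launch; MyModel is a
# hypothetical module):
#
#   setup_dist(global_seed=0)
#   model = MyModel().to(dev())
#   sync_params(model.parameters())  # every rank starts from rank 0's weights

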
def _find_free_port():
    # Create the socket before the try block so the finally clause cannot
    # hit an unbound name if socket() itself raises.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 lets the OS pick any available port.
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
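

# A hedged single-process sketch: hand-roll the env:// rendezvous variables
# before calling setup_dist (names follow torch.distributed's env://
# convention; in real multi-process jobs the launcher sets these instead):
#
#   os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
#   os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
#   os.environ.setdefault("RANK", "0")
#   os.environ.setdefault("WORLD_SIZE", "1")
#   os.environ.setdefault("LOCAL_RANK", "0")
#   setup_dist(global_seed=0)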