"""
Helpers for distributed training.
"""
import builtins
import datetime
import io
import os
import socket

try:
    import blobfile as bf  # optional dependency; only needed by load_state_dict
except ImportError:
    bf = None
try:
    from mpi4py import MPI  # optional dependency; only needed by load_state_dict
except ImportError:
    MPI = None

import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3

def synchronize():
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()

def setup_for_distributed(is_master):
    """
    Disable printing on non-master processes. Pass ``force=True`` to print
    anyway; output is also forced whenever the world size is greater than 8.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # prefix with a timestamp
            builtin_print(*args, **kwargs)

    builtins.print = print

def setup_dist_multinode(args):
    """
    Set up a distributed process group for multi-node training, using the
    RANK/WORLD_SIZE/LOCAL_RANK/MASTER_ADDR/MASTER_PORT environment variables
    provided by the launcher (e.g. torchrun).
    """
    if dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method='env://')
        local_rank = int(os.environ["LOCAL_RANK"])
        print("rank", local_rank)
        # One process per GPU: pin this rank to its local device.
        device = local_rank
        th.cuda.set_device(device)
        # Only the first process on each node keeps printing.
        setup_for_distributed(device == 0)
        synchronize()
    else:
        print("DDP setup failed: torch.distributed is unavailable or already initialized!")
        exit()
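

# Hedged usage sketch, not part of the original helpers: how a training script
# launched with torchrun (which sets RANK/WORLD_SIZE/LOCAL_RANK) might call
# setup_dist_multinode() and wrap a model in DistributedDataParallel. The
# ``args`` object and the Linear model below are hypothetical placeholders.
#
#     torchrun --nnodes=2 --nproc_per_node=8 train.py
#
def _example_multinode_entrypoint(args):
    setup_dist_multinode(args)
    model = th.nn.Linear(8, 8).cuda()  # placeholder for a real model
    ddp_model = th.nn.parallel.DistributedDataParallel(
        model, device_ids=[int(os.environ["LOCAL_RANK"])]
    )
    return ddp_model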

def setup_dist(global_seed):
    """
    Set up a distributed process group and give every rank its own RNG seed
    derived from ``global_seed``.
    """
    if dist.is_initialized():
        return
    th.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=datetime.timedelta(seconds=5400),
    )
    # Derive a distinct, reproducible seed per rank.
    rank = dist.get_rank()
    device = rank % th.cuda.device_count()
    seed = global_seed * dist.get_world_size() + rank
    th.manual_seed(seed)
    th.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
    synchronize()

def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")
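

# Hedged usage sketch, not part of the original helpers: typical single-script
# use of setup_dist() and dev() under torchrun. The per-rank seed computed in
# setup_dist() (global_seed * world_size + rank) keeps random tensors
# decorrelated across ranks while remaining reproducible for a fixed seed.
def _example_seeded_setup(global_seed=0):
    setup_dist(global_seed)
    device = dev()
    noise = th.randn(4, 4, device=device)  # differs per rank, fixed per seed
    return noise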

def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Rank 0 reads the file and broadcasts it in chunks; every other rank
    reassembles the bytes and deserializes them locally.
    """
    chunk_size = 2 ** 30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)
    return th.load(io.BytesIO(data), **kwargs)

def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)
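

# Hedged usage sketch, not part of the original helpers: a common resume
# pattern is to read a checkpoint once (rank 0 fetches it and load_state_dict's
# chunked MPI broadcast ships the bytes to the other ranks), then broadcast the
# live parameters so every rank starts from identical weights. The model and
# checkpoint path below are hypothetical placeholders.
def _example_resume(model, checkpoint_path):
    state_dict = load_state_dict(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    sync_params(model.parameters())
    return model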

def _find_free_port():
    # Bind to port 0 so the OS assigns an unused port, then release the socket
    # and return the port number it picked.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
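

# Hedged usage sketch, not part of the original helpers: _find_free_port() can
# seed MASTER_PORT when a single-node process group is bootstrapped by hand
# (e.g. with torch.multiprocessing) rather than by torchrun. MASTER_ADDR and
# MASTER_PORT are the standard variables read by init_method="env://".
if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
    print("MASTER_ADDR={} MASTER_PORT={}".format(
        os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"]))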