"""
Helpers for distributed training.
"""
import io
import os
import socket
import builtins
import datetime

try:
    import blobfile as bf
except ImportError:
    bf = None

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3


def synchronize():
    """
    Barrier across all ranks, if a distributed process group is initialized.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def is_dist_avail_and_initialized():
    """
    Return True if torch.distributed is both available and initialized.
    """
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """
    Return the distributed world size, or 1 when not running distributed.
    """
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def setup_for_distributed(is_master):
    """
    Disable printing on non-master processes; a call can still force output by
    passing force=True.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # prefix with a timestamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def setup_dist_multinode(args):
    """
    Set up a distributed process group for multi-node training.
    """
    if not dist.is_available() or not dist.is_initialized():
        th.distributed.init_process_group(backend="nccl", init_method="env://")
        world_size = dist.get_world_size()
        local_rank = int(os.getenv("LOCAL_RANK"))
        print("rank", local_rank)
        device = local_rank
        th.cuda.set_device(device)
        setup_for_distributed(device == 0)
        synchronize()
    else:
        print("DDP setup failed: a process group is already initialized!")
        exit()


def setup_dist(global_seed):
    """
    Set up a distributed process group and seed each rank deterministically.
    """
    if dist.is_initialized():
        return
    th.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    th.distributed.init_process_group(
        backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=5400)
    )

    # Derive a distinct, reproducible seed for every rank.
    rank = dist.get_rank()
    device = rank % th.cuda.device_count()
    seed = global_seed * dist.get_world_size() + rank
    th.manual_seed(seed)
    th.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
    synchronize()
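
# Usage sketch (not part of the original file): `setup_dist` assumes a launcher such
# as torchrun has already exported RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and
# MASTER_PORT, e.g.
#
#   torchrun --nproc_per_node=8 train.py
#
# and the training script (a hypothetical train.py) would then call:
#
#   setup_dist(global_seed=0)
#   model = build_model().to(dev())  # build_model is a placeholder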


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Rank 0 reads the file and broadcasts the raw bytes in chunks; every other
    rank reconstructs the byte string from those broadcasts.
    """
    chunk_size = 2 ** 30  # MPI has a relatively small message-size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)
    return th.load(io.BytesIO(data), **kwargs)
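
# Usage sketch (assumption, not from the original file): every rank calls
# load_state_dict with the same path; only MPI rank 0 reads the file, and the raw
# bytes are broadcast in sub-1 GiB chunks to stay under MPI's message-size limit:
#
#   state = load_state_dict("checkpoints/model.pt", map_location="cpu")  # hypothetical path
#   model.load_state_dict(state)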


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)
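
# Usage sketch (assumption): after constructing a model, broadcast rank 0's copy of
# the weights so every rank starts from identical parameters:
#
#   sync_params(model.parameters())
#   sync_params(model.buffers())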


def _find_free_port():
    # Create the socket outside the try block so the finally clause cannot
    # raise a NameError if socket creation itself fails.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
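
# Usage sketch (assumption): when the launcher does not provide MASTER_PORT, rank 0
# could pick a free port and export it before the process group is initialized:
#
#   os.environ.setdefault("MASTER_PORT", str(_find_free_port()))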