|
|
|
|
|
|
|
|
|
import random |
|
import numpy as np |
|
import torch |
|
|
|
from .shardedtensor import * |
|
from .load_config import * |
|
|
|
|
|
def set_seed(seed=43211):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is handed, in order, to ``random``, ``numpy.random``,
    ``torch`` and every CUDA device known to ``torch.cuda``.  When cuDNN is
    enabled, its benchmark mode is switched off and its deterministic mode is
    switched on.

    Args:
        seed: value passed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        return
    cudnn.benchmark = False
    cudnn.deterministic = True
|
|
|
|
|
def get_world_size():
    """Return the number of processes in the default process group.

    Returns:
        ``torch.distributed.get_world_size()`` when the distributed package
        is available and a process group has been initialized, otherwise 1.
    """
    dist = torch.distributed
    # is_available() must be checked first: on builds without distributed
    # support the package does not expose is_initialized() at all.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
|
def get_local_rank():
    """Return the rank of the current process, or 0 outside distributed runs.

    NOTE: despite the name, the value comes from
    ``torch.distributed.get_rank()`` (the rank in the default process group),
    not from a per-node local index.

    Returns:
        ``torch.distributed.get_rank()`` when the distributed package is
        available and a process group has been initialized, otherwise 0.
    """
    dist = torch.distributed
    # is_available() must be checked first: on builds without distributed
    # support the package does not expose is_initialized() at all.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
|
def print_on_rank0(func):
    """Print ``func`` with an ``[INFO]`` prefix on the rank-0 process only.

    Args:
        func: object to print.  Despite the parameter name it is not called;
            it is passed straight to ``print``.
    """
    # Every process other than rank 0 stays silent.
    if get_local_rank() != 0:
        return
    print("[INFO]", func)
|
|
|
|
|
class RetriMeter(object):
    """
    Statistics on whether retrieval yields a better pair.

    The meter accumulates a ``replace`` count and a ``total`` count over
    successive calls and prints itself on rank 0 every ``freq`` calls.
    """

    def __init__(self, freq=1024):
        """Create an empty meter.

        Args:
            freq: print the running statistics once every ``freq`` calls.
                A falsy value (0 or None) disables the periodic printing.
        """
        self.freq = freq
        # Number of items seen so far (first dimension of every input).
        self.total = 0
        # Number of items counted as replaced.
        self.replace = 0
        # Number of times the meter has been called.
        self.updates = 0

    def __call__(self, data):
        """Accumulate statistics from one batch.

        Args:
            data: either a ``numpy.ndarray`` indexed as ``data[:, 0]``, where
                every row whose first column is not -1 counts as replaced, or
                a ``torch.Tensor`` whose sum is added to the replaced count.
                In both cases the size of the first dimension is added to
                ``total``.

        Raises:
            ValueError: if ``data`` is neither an ndarray nor a tensor.
        """
        if isinstance(data, np.ndarray):
            self.replace += data.shape[0] - int((data[:, 0] == -1).sum())
            self.total += data.shape[0]
        elif torch.is_tensor(data):
            self.replace += int(data.sum())
            self.total += data.size(0)
        else:
            raise ValueError("unsupported RetriMeter data type.", type(data))

        self.updates += 1
        # Check the cheap counter first so the rank is only queried when a
        # report is actually due; a falsy freq avoids a modulo-by-zero.
        if (
            self.freq
            and self.updates % self.freq == 0
            and get_local_rank() == 0
        ):
            print("[INFO]", self)

    def __repr__(self):
        """Return ``RetriMeter (<ratio>/<replace>/<total>)``.

        The ratio is reported as 0.0 while nothing has been recorded, so that
        printing a fresh meter does not raise ``ZeroDivisionError``.
        """
        ratio = self.replace / self.total if self.total else 0.0
        return "RetriMeter ({}/{}/{})".format(ratio, self.replace, self.total)
|
|