from typing import Dict
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
import torch
import torch.distributed as dist
import torch.nn as nn
def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]:
    """Map each parameter's fully qualified name to its model-parallel shard dim.

    A value of 0 or 1 is the tensor dimension along which fairscale shards the
    parameter across model-parallel ranks; -1 means the parameter is replicated.
    """
    ret_dict = {}
    for module_name, module in model.named_modules():

        def param_fqn(param_name):
            return param_name if module_name == "" else module_name + "." + param_name

        if isinstance(module, ColumnParallelLinear):
            # Output features are split: weight (and bias) are sharded along dim 0.
            ret_dict[param_fqn("weight")] = 0
            if module.bias is not None:
                ret_dict[param_fqn("bias")] = 0
        elif isinstance(module, RowParallelLinear):
            # Input features are split: weight is sharded along dim 1; the bias
            # is added after the reduction and is therefore replicated.
            ret_dict[param_fqn("weight")] = 1
            if module.bias is not None:
                ret_dict[param_fqn("bias")] = -1
        elif isinstance(module, ParallelEmbedding):
            # The embedding dimension is split across model-parallel ranks.
            ret_dict[param_fqn("weight")] = 1
        else:
            # Any other parameter is replicated on every model-parallel rank.
            for param_name, param in module.named_parameters(recurse=False):
                ret_dict[param_fqn(param_name)] = -1
    return ret_dict
def calculate_l2_grad_norm(
    model: nn.Module,
    model_parallel_dim_dict: Dict[str, int],
) -> float:
    """Compute the global L2 norm of all gradients across every rank."""
    mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
    non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # Strip name components inserted by wrappers (e.g. FSDP's
        # "_fsdp_wrapped_module") so the name matches the keys produced by
        # get_model_parallel_dim_dict.
        name = ".".join(x for x in name.split(".") if not x.startswith("_"))
        assert name in model_parallel_dim_dict
        if model_parallel_dim_dict[name] < 0:
            # Replicated parameter: identical on every model-parallel rank.
            non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
        else:
            # Sharded parameter: each model-parallel rank holds a distinct slice.
            mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
    dist.all_reduce(mp_norm_sq)
    dist.all_reduce(non_mp_norm_sq)
    # The all-reduce counts each replicated gradient once per model-parallel
    # rank, so divide that contribution back out.
    non_mp_norm_sq /= fs_init.get_model_parallel_world_size()
    return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5
def scale_grad(model: nn.Module, factor: float) -> None:
    """Multiply every existing gradient in-place by ``factor``."""
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(factor)
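

# --- Usage sketch (illustrative, not part of the original module) --------------
# A minimal example of how these helpers might be combined for global gradient
# clipping inside a training step. `optimizer`, `loss`, and `max_grad_norm` are
# assumptions introduced only for illustration; in practice the dim dict would
# typically be computed once, outside the training loop.
def _clip_grad_example(model: nn.Module, optimizer, loss, max_grad_norm: float = 1.0) -> None:
    dim_dict = get_model_parallel_dim_dict(model)
    loss.backward()
    grad_norm = calculate_l2_grad_norm(model, dim_dict)
    if grad_norm > max_grad_norm:
        # Rescale all gradients so their global L2 norm equals max_grad_norm.
        scale_grad(model, max_grad_norm / grad_norm)
    optimizer.step()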