from typing import Dict

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
import torch
import torch.distributed as dist
import torch.nn as nn


def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]:
    """Map each parameter's fully qualified name to the dimension along which it is
    sharded by model parallelism (-1 means the parameter is replicated)."""
    ret_dict = {}
    for module_name, module in model.named_modules():

        def param_fqn(param_name):
            # Build the fully qualified parameter name from the module name.
            return param_name if module_name == "" else module_name + "." + param_name

        if isinstance(module, ColumnParallelLinear):
            # Column-parallel layers shard both weight and bias along dim 0.
            ret_dict[param_fqn("weight")] = 0
            if module.bias is not None:
                ret_dict[param_fqn("bias")] = 0
        elif isinstance(module, RowParallelLinear):
            # Row-parallel layers shard the weight along dim 1; the bias is replicated.
            ret_dict[param_fqn("weight")] = 1
            if module.bias is not None:
                ret_dict[param_fqn("bias")] = -1
        elif isinstance(module, ParallelEmbedding):
            # Parallel embeddings shard the weight along the embedding dim (dim 1).
            ret_dict[param_fqn("weight")] = 1
        else:
            # All remaining parameters are replicated across model-parallel ranks.
            for param_name, _ in module.named_parameters(recurse=False):
                ret_dict[param_fqn(param_name)] = -1
    return ret_dict


def calculate_l2_grad_norm(
    model: nn.Module,
    model_parallel_dim_dict: Dict[str, int],
) -> float:
    """Compute the global L2 norm of all gradients across ranks, counting parameters
    that are replicated across model-parallel ranks only once."""
    mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
    non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")

    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # Drop wrapper-inserted name components (e.g. FSDP's "_fsdp_wrapped_module")
        # so the name matches the keys produced by get_model_parallel_dim_dict.
        name = ".".join(x for x in name.split(".") if not x.startswith("_"))
        assert name in model_parallel_dim_dict
        if model_parallel_dim_dict[name] < 0:
            non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
        else:
            mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2

    # Sharded parameters hold a distinct slice on every rank, so summing their squared
    # norms is correct; replicated parameters are counted once per model-parallel rank,
    # so their contribution is divided by the model-parallel world size.
    dist.all_reduce(mp_norm_sq)
    dist.all_reduce(non_mp_norm_sq)
    non_mp_norm_sq /= fs_init.get_model_parallel_world_size()

    return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5


def scale_grad(model: nn.Module, factor: float) -> None:
    """Multiply every existing gradient in the model by `factor`, in place."""
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(factor)
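

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how the helpers above might be combined to implement
# gradient clipping in a model-parallel training step. The function name
# `clip_grad_norm_example` and the `max_grad_norm` default are assumptions
# made purely for illustration; they do not appear in the original code.
def clip_grad_norm_example(
    model: nn.Module,
    model_parallel_dim_dict: Dict[str, int],
    max_grad_norm: float = 1.0,
) -> float:
    """Clip gradients so their global L2 norm does not exceed `max_grad_norm`.

    Assumes backward() has already populated `param.grad` on every rank and that
    torch.distributed and fairscale model parallelism have been initialized.
    """
    grad_norm = calculate_l2_grad_norm(model, model_parallel_dim_dict)
    if grad_norm > max_grad_norm:
        # Scale all gradients uniformly so the global norm equals max_grad_norm.
        scale_grad(model, max_grad_norm / grad_norm)
    return grad_norm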