# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet3d.registry import TASK_UTILS


def fp16_clamp(x, min=None, max=None):
    """Clamp a tensor, with a fallback for float16 tensors that live on CPU.

    Args:
        x (Tensor): Tensor to clamp.
        min (Number, optional): Lower bound passed to ``Tensor.clamp``.
        max (Number, optional): Upper bound passed to ``Tensor.clamp``.

    Returns:
        Tensor: The clamped tensor. For a non-CUDA float16 input the clamp is
        computed in float32 and the result is cast back to float16.
    """
    if not x.is_cuda and x.dtype == torch.float16:
        # clamp for cpu float16, tensor fp16 has no clamp implementation
        return x.float().clamp(min, max).half()

    return x.clamp(min, max)


def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    """Calculate overlap between two sets of bboxes.

    FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889

    Note:
        Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou',
        there are some newly generated variables when calculating IOU
        using the bbox_overlaps function:

        1) is_aligned is False
            area1: M x 1
            area2: N x 1
            lt: M x N x 2
            rb: M x N x 2
            wh: M x N x 2
            overlap: M x N x 1
            union: M x N x 1
            ious: M x N x 1

            Total memory:
                S = (9 x N x M + N + M) * 4 Byte,

            When using FP16, we can reduce:
                R = (9 x N x M + N + M) * 4 / 2 Byte
                R larger than (N + M) * 4 * 2 is always true when
                N and M >= 1.
                Obviously, N + M <= N * M < 3 * N * M, when N >= 2 and M >= 2,
                           N + 1 < 3 * N, when N or M is 1.

            Given M = 40 (ground truth), N = 400000 (three anchor boxes
            in per grid, FPN, R-CNNs),
                R = 275 MB (one times)

            A special case (dense detection), M = 512 (ground truth),
                R = 3516 MB = 3.43 GB

            When the batch size is B, reduce:
                B x R

            Therefore, CUDA memory runs out frequently.

            Experiments on GeForce RTX 2080Ti (11019 MiB):

            | dtype | M | N | Use | Real | Ideal |
            |:----:|:----:|:----:|:----:|:----:|:----:|
            | FP32 | 512 | 400000 | 8020 MiB | -- | -- |
            | FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
            | FP32 | 40 | 400000 | 1540 MiB | -- | -- |
            | FP16 | 40 | 400000 | 1264 MiB | 276MiB | 275 MiB |

        2) is_aligned is True
            area1: N x 1
            area2: N x 1
            lt: N x 2
            rb: N x 2
            wh: N x 2
            overlap: N x 1
            union: N x 1
            ious: N x 1

            Total memory:
                S = 11 x N * 4 Byte

            When using FP16, we can reduce:
                R = 11 x N * 4 / 2 Byte

        So does 'giou' (larger than 'iou').

        Time-wise, FP16 is generally faster than FP32.

        When gpu_assign_thr is not -1, it takes more time on cpu
        but does not reduce memory.
        There, we can reduce half the memory and keep the speed.

    If ``is_aligned`` is ``False``, then calculate the overlaps between each
    bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
    pair of bboxes1 and bboxes2.

    Args:
        bboxes1 (Tensor): shape (B, m, 4) in (x1, y1, x2, y2) format or
            empty.
        bboxes2 (Tensor): shape (B, n, 4) in (x1, y1, x2, y2) format or
            empty. B indicates the batch dim, in shape (B1, B2, ..., Bn).
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection over
            foreground) or "giou" (generalized intersection over union).
            Default "iou".
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        eps (float, optional): Lower bound applied to the denominator
            (union / enclosing area) for numerical stability.
            Default 1e-6.

    Returns:
        Tensor: shape (B, m, n) if ``is_aligned`` is False else shape (B, m),
        where the leading batch dims B are those of the inputs (absent for
        un-batched inputs). When either input holds no boxes the result has
        zero elements.

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 10, 10],
        >>>     [10, 10, 20, 20],
        >>>     [32, 32, 38, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 10, 20],
        >>>     [0, 10, 10, 19],
        >>>     [10, 10, 20, 20],
        >>> ])
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 4)
        >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
    """

    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

    # Batch dim must be the same
    # Batch dim: (B1, B2, ... Bn)
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = tuple(bboxes1.shape[:-2])

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    if rows * cols == 0:
        # Nothing to compare: return a zero-element tensor of the right
        # shape, dtype and device. ``new_empty`` always interprets its
        # argument as a shape, unlike the legacy ``Tensor.new`` constructor
        # which treats a plain tuple as data.
        if is_aligned:
            return bboxes1.new_empty(batch_shape + (rows, ))
        else:
            return bboxes1.new_empty(batch_shape + (rows, cols))

    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
        bboxes1[..., 3] - bboxes1[..., 1])
    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
        bboxes2[..., 3] - bboxes2[..., 1])

    if is_aligned:
        lt = torch.max(bboxes1[..., :2], bboxes2[..., :2])  # [B, rows, 2]
        rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:])  # [B, rows, 2]

        # Negative extents mean the boxes do not intersect.
        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1 + area2 - overlap
        else:
            # 'iof': normalise by the area of the first (foreground) box.
            union = area1
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
            enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
    else:
        lt = torch.max(bboxes1[..., :, None, :2],
                       bboxes2[..., None, :, :2])  # [B, rows, cols, 2]
        rb = torch.min(bboxes1[..., :, None, 2:],
                       bboxes2[..., None, :, 2:])  # [B, rows, cols, 2]

        # Negative extents mean the boxes do not intersect.
        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1[..., None] + area2[..., None, :] - overlap
        else:
            # 'iof': normalise by the area of the first (foreground) box.
            union = area1[..., None]
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :, None, :2],
                                    bboxes2[..., None, :, :2])
            enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
                                    bboxes2[..., None, :, 2:])

    # Keep the denominator at least ``eps`` to avoid division by zero.
    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode in ['iou', 'iof']:
        return ious
    # calculate gious
    enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    enclose_area = torch.max(enclose_area, eps)
    gious = ious - (enclose_area - union) / enclose_area
    return gious


@TASK_UTILS.register_module()
class BBox3DL1Cost:
    """BBox3DL1Cost.

    Pairwise L1 (Manhattan) distance between predicted and ground-truth box
    vectors, scaled by ``weight``.

    Args:
        weight (int | float, optional): loss_weight
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted boxes, shape [num_query, box_dim].
            gt_bboxes (Tensor): Ground truth boxes, shape [num_gt, box_dim].
                No coordinate conversion is applied here, so both inputs
                must already use the same box encoding and the same
                ``box_dim``.

        Returns:
            torch.Tensor: bbox_cost value with weight, shape
            [num_query, num_gt].
        """
        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return bbox_cost * self.weight


@TASK_UTILS.register_module()
class FocalLossCost:
    """FocalLossCost.

    Args:
        weight (int | float, optional): loss_weight
        alpha (int | float, optional): focal_loss alpha
        gamma (int | float, optional): focal_loss gamma
        eps (float, optional): default 1e-12
        binary_input (bool, optional): Whether the input is binary,
            default False.

    Examples:
        >>> import torch
        >>> self = FocalLossCost()
        >>> cls_pred = torch.rand(4, 3)
        >>> gt_labels = torch.tensor([0, 1, 2])
        >>> self(cls_pred, gt_labels)  # values vary: cls_pred is random
        tensor([[-0.3236, -0.3364, -0.2699],
                [-0.3439, -0.3209, -0.4807],
                [-0.4099, -0.3795, -0.2929],
                [-0.1950, -0.1207, -0.2626]])
    """

    def __init__(self,
                 weight=1.,
                 alpha=0.25,
                 gamma=2,
                 eps=1e-12,
                 binary_input=False):
        self.weight = weight
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps
        self.binary_input = binary_input

    def _focal_loss_cost(self, cls_pred, gt_labels):
        """
        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                (num_query, num_class).
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).

        Returns:
            torch.Tensor: cls_cost value with weight, shape
            (num_query, num_gt).
        """
        cls_pred = cls_pred.sigmoid()
        neg_cost = -(1 - cls_pred + self.eps).log() * (
            1 - self.alpha) * cls_pred.pow(self.gamma)
        pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
            1 - cls_pred).pow(self.gamma)

        # Pick, for every query, the cost column of each ground-truth class.
        cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
        return cls_cost * self.weight

    def _mask_focal_loss_cost(self, cls_pred, gt_labels):
        """
        Args:
            cls_pred (Tensor): Predicted classification logits
                in shape (num_query, d1, ..., dn), dtype=torch.float32.
            gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
                dtype=torch.long. Labels should be binary.

        Returns:
            Tensor: Focal cost matrix with weight in shape
            (num_query, num_gt).
        """
        cls_pred = cls_pred.flatten(1)
        gt_labels = gt_labels.flatten(1).float()
        n = cls_pred.shape[1]
        cls_pred = cls_pred.sigmoid()
        neg_cost = -(1 - cls_pred + self.eps).log() * (
            1 - self.alpha) * cls_pred.pow(self.gamma)
        pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
            1 - cls_pred).pow(self.gamma)

        # Sum positive cost where the label is 1 and negative cost where it
        # is 0, for every (query, gt) pair.
        cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \
            torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
        # Average over the flattened element count.
        return cls_cost / n * self.weight

    def __call__(self, cls_pred, gt_labels):
        """
        Args:
            cls_pred (Tensor): Predicted classification logits.
            gt_labels (Tensor): Labels.

        Returns:
            Tensor: Focal cost matrix with weight in shape
            (num_query, num_gt).
        """
        if self.binary_input:
            return self._mask_focal_loss_cost(cls_pred, gt_labels)
        else:
            return self._focal_loss_cost(cls_pred, gt_labels)


@TASK_UTILS.register_module()
class IoUCost:
    """IoUCost.

    Args:
        iou_mode (str, optional): iou mode such as 'iou' | 'giou'
        weight (int | float, optional): loss weight

    Examples:
        >>> import torch
        >>> self = IoUCost()
        >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
        >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        >>> self(bboxes, gt_bboxes)
        tensor([[-0.1250,  0.1667],
                [ 0.1667, -0.5000]])
    """

    def __init__(self, iou_mode='giou', weight=1.):
        self.weight = weight
        self.iou_mode = iou_mode

    def __call__(self, bboxes, gt_bboxes):
        """
        Args:
            bboxes (Tensor): Predicted boxes with unnormalized coordinates
                (x1, y1, x2, y2). Shape (num_query, 4).
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape (num_gt, 4).

        Returns:
            torch.Tensor: iou_cost value with weight
        """
        # overlaps: [num_bboxes, num_gt]
        overlaps = bbox_overlaps(
            bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
        # The 1 is a constant that doesn't change the matching, so omitted.
        iou_cost = -overlaps
        return iou_cost * self.weight