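"""TransFusion-style detection head and multi-modal BEV feature fuser."""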
import copy
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer
from mmdet.models.task_modules import (AssignResult, PseudoSampler,
                                       build_assigner, build_bbox_coder,
                                       build_sampler)
from mmdet.models.utils import multi_apply
from mmengine.structures import InstanceData
from torch import nn

from mmdet3d.models import circle_nms, draw_heatmap_gaussian, gaussian_radius
from mmdet3d.models.dense_heads.centerpoint_head import SeparateHead
from mmdet3d.models.layers import nms_bev
from mmdet3d.registry import MODELS
from mmdet3d.structures import xywhr2xyxyr


def clip_sigmoid(x, eps=1e-4):
    """Sigmoid clamped to [eps, 1 - eps].

    Clamping keeps the probabilities strictly inside (0, 1), so the Gaussian
    focal loss on the heatmap never evaluates log(0). The out-of-place
    ``sigmoid()`` is used so the caller's tensor is not mutated.
    """
    y = torch.clamp(x.sigmoid(), min=eps, max=1 - eps)
    return y
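
# Example (values rounded): clip_sigmoid(torch.tensor([-20., 0., 20.]))
# returns tensor([1.0000e-04, 5.0000e-01, 9.9990e-01]).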


@MODELS.register_module()
class ConvFuser(nn.Sequential):
    """Fuse multi-modal features by concatenation and a 3x3 Conv-BN-ReLU."""

    def __init__(self, in_channels: List[int], out_channels: int) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        super().__init__(
            nn.Conv2d(
                sum(in_channels), out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        return super().forward(torch.cat(inputs, dim=1))
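
# Usage sketch for ConvFuser (shapes illustrative, not tied to any config):
#
#   fuser = ConvFuser(in_channels=[80, 256], out_channels=256)
#   img_bev = torch.rand(2, 80, 180, 180)
#   pts_bev = torch.rand(2, 256, 180, 180)
#   fused = fuser([img_bev, pts_bev])  # -> (2, 256, 180, 180)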


@MODELS.register_module()
class TransFusionHead(nn.Module):

    def __init__(
        self,
        num_proposals=128,
        auxiliary=True,
        in_channels=128 * 3,
        hidden_channel=128,
        num_classes=4,
        # config for the transformer decoder
        num_decoder_layers=3,
        decoder_layer=dict(),
        num_heads=8,
        nms_kernel_size=1,
        bn_momentum=0.1,
        # config for the prediction heads
        common_heads=dict(),
        num_heatmap_convs=2,
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d'),
        bias='auto',
        # config for the losses
        loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(type='mmdet.L1Loss', reduction='mean'),
        loss_heatmap=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        # others
        train_cfg=None,
        test_cfg=None,
        bbox_coder=None,
    ):
        super(TransFusionHead, self).__init__()

        self.num_classes = num_classes
        self.num_proposals = num_proposals
        self.auxiliary = auxiliary
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.num_decoder_layers = num_decoder_layers
        self.bn_momentum = bn_momentum
        self.nms_kernel_size = nms_kernel_size
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if not self.use_sigmoid_cls:
            self.num_classes += 1
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_heatmap = MODELS.build(loss_heatmap)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.sampling = False

        # a shared convolution applied to the fused BEV feature
        self.shared_conv = build_conv_layer(
            dict(type='Conv2d'),
            in_channels,
            hidden_channel,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        # dense heatmap head used for query initialization
        layers = []
        layers.append(
            ConvModule(
                hidden_channel,
                hidden_channel,
                kernel_size=3,
                padding=1,
                bias=bias,
                conv_cfg=dict(type='Conv2d'),
                norm_cfg=dict(type='BN2d'),
            ))
        layers.append(
            build_conv_layer(
                dict(type='Conv2d'),
                hidden_channel,
                num_classes,
                kernel_size=3,
                padding=1,
                bias=bias,
            ))
        self.heatmap_head = nn.Sequential(*layers)
        self.class_encoding = nn.Conv1d(num_classes, hidden_channel, 1)

        # transformer decoder layers for refining the object queries
        self.decoder = nn.ModuleList()
        for i in range(self.num_decoder_layers):
            self.decoder.append(MODELS.build(decoder_layer))

        # one prediction head per decoder layer
        self.prediction_heads = nn.ModuleList()
        for i in range(self.num_decoder_layers):
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(self.num_classes, num_heatmap_convs)))
            self.prediction_heads.append(
                SeparateHead(
                    hidden_channel,
                    heads,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=bias,
                ))

        self.init_weights()
        self._init_assigner_sampler()

        # position embedding for cross-attention, reused at every forward
        x_size = self.test_cfg['grid_size'][0] // self.test_cfg[
            'out_size_factor']
        y_size = self.test_cfg['grid_size'][1] // self.test_cfg[
            'out_size_factor']
        self.bev_pos = self.create_2D_grid(x_size, y_size)

        self.img_feat_pos = None
        self.img_feat_collapsed_pos = None

    def create_2D_grid(self, x_size, y_size):
        meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
        # 'ij' indexing matches the historical default of torch.meshgrid
        # (explicit since PyTorch 1.10) and silences the deprecation warning
        batch_x, batch_y = torch.meshgrid(
            *[torch.linspace(it[0], it[1], it[2]) for it in meshgrid],
            indexing='ij')
        # shift by half a cell so the positions are BEV cell centers
        batch_x = batch_x + 0.5
        batch_y = batch_y + 0.5
        coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
        coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
        return coord_base
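
    # For example, create_2D_grid(2, 2) returns a (1, 4, 2) tensor of BEV
    # cell centers: [[0.5, 0.5], [0.5, 1.5], [1.5, 0.5], [1.5, 1.5]].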

    def init_weights(self):
        # initialize the transformer decoder weights with Xavier
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)
        if hasattr(self, 'query'):
            nn.init.xavier_normal_(self.query)
        self.init_bn_momentum()

    def init_bn_momentum(self):
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum

    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        if self.train_cfg is None:
            return

        if self.sampling:
            self.bbox_sampler = build_sampler(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()
        if isinstance(self.train_cfg.assigner, dict):
            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
        elif isinstance(self.train_cfg.assigner, list):
            self.bbox_assigner = [
                build_assigner(res) for res in self.train_cfg.assigner
            ]

    def forward_single(self, inputs, metas):
        """Forward function for a single feature level.

        Args:
            inputs (torch.Tensor): Input feature map with the shape of
                [B, C, H, W] (e.g. [B, 512, 128, 128]).

        Returns:
            list[dict]: Output results for tasks.
        """
        batch_size = inputs.shape[0]
        fusion_feat = self.shared_conv(inputs)

        # flatten the BEV feature to [BS, C, H*W] for attention
        fusion_feat_flatten = fusion_feat.view(batch_size,
                                               fusion_feat.shape[1], -1)
        bev_pos = self.bev_pos.repeat(batch_size, 1, 1).to(fusion_feat.device)

        # query initialization: pick the top heatmap responses after a
        # local-maximum (NMS-like) filtering of the class heatmaps
        with torch.autocast('cuda', enabled=False):
            dense_heatmap = self.heatmap_head(fusion_feat.float())
        heatmap = dense_heatmap.detach().sigmoid()
        padding = self.nms_kernel_size // 2
        local_max = torch.zeros_like(heatmap)
        local_max_inner = F.max_pool2d(
            heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0)
        if padding > 0:
            local_max[:, :, padding:-padding,
                      padding:-padding] = local_max_inner
        else:
            # nms_kernel_size == 1: the pooled map already matches the input
            # size (a `0:-0` slice would be empty), so use it directly
            local_max = local_max_inner

        # keep the raw heatmap for small objects (pedestrian & traffic cone
        # in nuScenes, pedestrian & cyclist in Waymo): a 1x1 max pooling is
        # the identity, i.e. no local-maximum suppression for these classes
        if self.test_cfg['dataset'] == 'nuScenes':
            local_max[:, 8] = F.max_pool2d(
                heatmap[:, 8], kernel_size=1, stride=1, padding=0)
            local_max[:, 9] = F.max_pool2d(
                heatmap[:, 9], kernel_size=1, stride=1, padding=0)
        elif self.test_cfg['dataset'] == 'Waymo':
            local_max[:, 1] = F.max_pool2d(
                heatmap[:, 1], kernel_size=1, stride=1, padding=0)
            local_max[:, 2] = F.max_pool2d(
                heatmap[:, 2], kernel_size=1, stride=1, padding=0)
        heatmap = heatmap * (heatmap == local_max)
        heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)

        # select the top num_proposals responses among all classes
        top_proposals = heatmap.view(batch_size, -1).argsort(
            dim=-1, descending=True)[..., :self.num_proposals]
        top_proposals_class = top_proposals // heatmap.shape[-1]
        top_proposals_index = top_proposals % heatmap.shape[-1]
        query_feat = fusion_feat_flatten.gather(
            index=top_proposals_index[:, None, :].expand(
                -1, fusion_feat_flatten.shape[1], -1),
            dim=-1,
        )
        self.query_labels = top_proposals_class

        # add a category embedding to the query features
        one_hot = F.one_hot(
            top_proposals_class,
            num_classes=self.num_classes).permute(0, 2, 1)
        query_cat_encoding = self.class_encoding(one_hot.float())
        query_feat += query_cat_encoding

        query_pos = bev_pos.gather(
            index=top_proposals_index[:, None, :].permute(0, 2, 1).expand(
                -1, -1, bev_pos.shape[-1]),
            dim=1,
        )

        # transformer decoder: refine the queries layer by layer
        ret_dicts = []
        for i in range(self.num_decoder_layers):
            query_feat = self.decoder[i](
                query_feat,
                key=fusion_feat_flatten,
                query_pos=query_pos,
                key_pos=bev_pos)

            # prediction on top of the refined queries
            res_layer = self.prediction_heads[i](query_feat)
            res_layer['center'] = res_layer['center'] + query_pos.permute(
                0, 2, 1)
            ret_dicts.append(res_layer)

            # use the predicted centers as the positional embedding for the
            # next decoder layer
            query_pos = res_layer['center'].detach().clone().permute(0, 2, 1)

        ret_dicts[0]['query_heatmap_score'] = heatmap.gather(
            index=top_proposals_index[:, None, :].expand(
                -1, self.num_classes, -1),
            dim=-1,
        )  # [bs, num_classes, num_proposals]
        ret_dicts[0]['dense_heatmap'] = dense_heatmap

        if self.auxiliary is False:
            # only return the results of the last decoder layer
            return [ret_dicts[-1]]

        # concatenate every layer's predictions for auxiliary supervision
        new_res = {}
        for key in ret_dicts[0].keys():
            if key not in [
                    'dense_heatmap', 'dense_heatmap_old', 'query_heatmap_score'
            ]:
                new_res[key] = torch.cat(
                    [ret_dict[key] for ret_dict in ret_dicts], dim=-1)
            else:
                new_res[key] = ret_dicts[0][key]
        return [new_res]
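
    # Flat-index decoding sketch: the heatmap is flattened to
    # (bs, num_classes * H * W), so a flat index p decodes to class
    # p // (H * W) and BEV cell p % (H * W); e.g. with H * W = 16384,
    # p = 40000 -> class 2, cell 7232.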

    def forward(self, feats, metas):
        """Forward pass.

        Args:
            feats (list[torch.Tensor]): Multi-level features, e.g. features
                produced by FPN.

        Returns:
            tuple(list[dict]): Output results, indexed first by level and
                then by decoder layer.
        """
        if isinstance(feats, torch.Tensor):
            feats = [feats]
        res = multi_apply(self.forward_single, feats, [metas])
        assert len(res) == 1, 'only support one level of features.'
        return res

    def predict(self, batch_feats, batch_input_metas):
        """Predict bboxes from a batch of features."""
        preds_dicts = self(batch_feats, batch_input_metas)
        res = self.predict_by_feat(preds_dicts, batch_input_metas)
        return res
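
    # Typical call (sketch): `head.predict(bev_feats, input_metas)` returns
    # one InstanceData per sample with `bboxes_3d`, `scores_3d`, `labels_3d`.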

    def predict_by_feat(self,
                        preds_dicts,
                        metas,
                        img=None,
                        rescale=False,
                        for_roi=False):
        """Generate bboxes from bbox head predictions.

        Args:
            preds_dicts (tuple[list[dict]]): Prediction results.

        Returns:
            list[:obj:`InstanceData`]: Decoded results, one per sample,
                each holding `bboxes_3d`, `scores_3d` and `labels_3d`.
        """
        rets = []
        for layer_id, preds_dict in enumerate(preds_dicts):
            batch_size = preds_dict[0]['heatmap'].shape[0]
            batch_score = preds_dict[0]['heatmap'][
                ..., -self.num_proposals:].sigmoid()

            # gate the classification score with the initial heatmap score
            # of the selected queries
            one_hot = F.one_hot(
                self.query_labels,
                num_classes=self.num_classes).permute(0, 2, 1)
            batch_score = batch_score * preds_dict[0][
                'query_heatmap_score'] * one_hot

            batch_center = preds_dict[0]['center'][..., -self.num_proposals:]
            batch_height = preds_dict[0]['height'][..., -self.num_proposals:]
            batch_dim = preds_dict[0]['dim'][..., -self.num_proposals:]
            batch_rot = preds_dict[0]['rot'][..., -self.num_proposals:]
            batch_vel = None
            if 'vel' in preds_dict[0]:
                batch_vel = preds_dict[0]['vel'][..., -self.num_proposals:]

            temp = self.bbox_coder.decode(
                batch_score,
                batch_rot,
                batch_dim,
                batch_center,
                batch_height,
                batch_vel,
                filter=True,
            )

            # per-class NMS configuration: radius <= 0 disables NMS for the
            # classes listed in `indices`
            if self.test_cfg['dataset'] == 'nuScenes':
                self.tasks = [
                    dict(
                        num_class=8,
                        class_names=[],
                        indices=[0, 1, 2, 3, 4, 5, 6, 7],
                        radius=-1,
                    ),
                    dict(
                        num_class=1,
                        class_names=['pedestrian'],
                        indices=[8],
                        radius=0.175,
                    ),
                    dict(
                        num_class=1,
                        class_names=['traffic_cone'],
                        indices=[9],
                        radius=0.175,
                    ),
                ]
            elif self.test_cfg['dataset'] == 'Waymo':
                self.tasks = [
                    dict(
                        num_class=1,
                        class_names=['Car'],
                        indices=[0],
                        radius=0.7),
                    dict(
                        num_class=1,
                        class_names=['Pedestrian'],
                        indices=[1],
                        radius=0.7),
                    dict(
                        num_class=1,
                        class_names=['Cyclist'],
                        indices=[2],
                        radius=0.7),
                ]

            ret_layer = []
            for i in range(batch_size):
                boxes3d = temp[i]['bboxes']
                scores = temp[i]['scores']
                labels = temp[i]['labels']

                if self.test_cfg['nms_type'] is not None:
                    keep_mask = torch.zeros_like(scores)
                    for task in self.tasks:
                        task_mask = torch.zeros_like(scores)
                        for cls_idx in task['indices']:
                            task_mask += labels == cls_idx
                        task_mask = task_mask.bool()
                        if task['radius'] > 0:
                            if self.test_cfg['nms_type'] == 'circle':
                                boxes_for_nms = torch.cat(
                                    [
                                        boxes3d[task_mask][:, :2],
                                        scores[:, None][task_mask],
                                    ],
                                    dim=1,
                                )
                                task_keep_indices = torch.tensor(
                                    circle_nms(
                                        boxes_for_nms.detach().cpu().numpy(),
                                        task['radius'],
                                    ))
                            else:
                                boxes_for_nms = xywhr2xyxyr(
                                    metas[i]['box_type_3d'](
                                        boxes3d[task_mask][:, :7], 7).bev)
                                top_scores = scores[task_mask]
                                # `pre_max_size`/`post_max_size` are the
                                # keyword names expected by nms_bev
                                task_keep_indices = nms_bev(
                                    boxes_for_nms,
                                    top_scores,
                                    thresh=task['radius'],
                                    pre_max_size=self.test_cfg['pre_maxsize'],
                                    post_max_size=self.
                                    test_cfg['post_maxsize'],
                                )
                        else:
                            task_keep_indices = torch.arange(task_mask.sum())
                        if task_keep_indices.shape[0] != 0:
                            keep_indices = torch.where(
                                task_mask != 0)[0][task_keep_indices]
                            keep_mask[keep_indices] = 1
                    keep_mask = keep_mask.bool()
                    ret = dict(
                        bboxes=boxes3d[keep_mask],
                        scores=scores[keep_mask],
                        labels=labels[keep_mask],
                    )
                else:
                    ret = dict(bboxes=boxes3d, scores=scores, labels=labels)

                temp_instances = InstanceData()
                temp_instances.bboxes_3d = metas[0]['box_type_3d'](
                    ret['bboxes'], box_dim=ret['bboxes'].shape[-1])
                temp_instances.scores_3d = ret['scores']
                temp_instances.labels_3d = ret['labels'].int()

                ret_layer.append(temp_instances)

            rets.append(ret_layer)
        assert len(
            rets
        ) == 1, f'only support one layer now, but get {len(rets)} layers'

        return rets[0]

    def get_targets(self, batch_gt_instances_3d: List[InstanceData],
                    preds_dict: List[dict]):
        """Generate training targets.

        Args:
            batch_gt_instances_3d (List[InstanceData]): Ground truth
                instances of each sample.
            preds_dict (list[dict]): The prediction results, indexed by
                decoder layer. The inner dict contains predictions of one
                mini-batch:

                - center: (bs, 2, num_proposals)
                - height: (bs, 1, num_proposals)
                - dim: (bs, 3, num_proposals)
                - rot: (bs, 2, num_proposals)
                - vel: (bs, 2, num_proposals)
                - cls_logit: (bs, num_classes, num_proposals)
                - query_score: (bs, num_classes, num_proposals)
                - heatmap: The original heatmap before being fed into the
                  transformer decoder, with shape (bs, 10, h, w)

        Returns:
            tuple[torch.Tensor]: Tuple of target including
                the following results in order.

                - torch.Tensor: classification target. [BS, num_proposals]
                - torch.Tensor: classification weights (mask)
                  [BS, num_proposals]
                - torch.Tensor: regression target. [BS, num_proposals, 8]
                - torch.Tensor: regression weights. [BS, num_proposals, 8]
                - torch.Tensor: iou target. [BS, num_proposals]
                - int: number of positive proposals
                - float: mean iou of matched proposals
                - torch.Tensor: heatmap targets.
        """
        # regroup the layer-indexed predictions into one dict per sample
        list_of_pred_dict = []
        for batch_idx in range(len(batch_gt_instances_3d)):
            pred_dict = {}
            for key in preds_dict[0].keys():
                preds = []
                for i in range(self.num_decoder_layers):
                    pred_one_layer = preds_dict[i][key][batch_idx:batch_idx +
                                                        1]
                    preds.append(pred_one_layer)
                pred_dict[key] = torch.cat(preds)
            list_of_pred_dict.append(pred_dict)

        assert len(batch_gt_instances_3d) == len(list_of_pred_dict)
        res_tuple = multi_apply(
            self.get_targets_single,
            batch_gt_instances_3d,
            list_of_pred_dict,
            np.arange(len(batch_gt_instances_3d)),
        )
        labels = torch.cat(res_tuple[0], dim=0)
        label_weights = torch.cat(res_tuple[1], dim=0)
        bbox_targets = torch.cat(res_tuple[2], dim=0)
        bbox_weights = torch.cat(res_tuple[3], dim=0)
        ious = torch.cat(res_tuple[4], dim=0)
        num_pos = np.sum(res_tuple[5])
        matched_ious = np.mean(res_tuple[6])
        heatmap = torch.cat(res_tuple[7], dim=0)
        return (
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            ious,
            num_pos,
            matched_ious,
            heatmap,
        )

    def get_targets_single(self, gt_instances_3d, preds_dict, batch_idx):
        """Generate training targets for a single sample.

        Args:
            gt_instances_3d (:obj:`InstanceData`): Ground truth of instances.
            preds_dict (dict): Dict of prediction result for a single sample.

        Returns:
            tuple[torch.Tensor]: Tuple of target including
                the following results in order.

                - torch.Tensor: classification target. [1, num_proposals]
                - torch.Tensor: classification weights (mask)
                  [1, num_proposals]
                - torch.Tensor: regression target. [1, num_proposals, 8]
                - torch.Tensor: regression weights. [1, num_proposals, 8]
                - torch.Tensor: iou target. [1, num_proposals]
                - int: number of positive proposals
                - torch.Tensor: heatmap targets.
        """
        gt_bboxes_3d = gt_instances_3d.bboxes_3d
        gt_labels_3d = gt_instances_3d.labels_3d
        num_proposals = preds_dict['center'].shape[-1]

        # detach and copy the predictions so that target generation cannot
        # affect the computation graph
        score = copy.deepcopy(preds_dict['heatmap'].detach())
        center = copy.deepcopy(preds_dict['center'].detach())
        height = copy.deepcopy(preds_dict['height'].detach())
        dim = copy.deepcopy(preds_dict['dim'].detach())
        rot = copy.deepcopy(preds_dict['rot'].detach())
        if 'vel' in preds_dict.keys():
            vel = copy.deepcopy(preds_dict['vel'].detach())
        else:
            vel = None

        # decode the predictions into real-world boxes for assignment
        boxes_dict = self.bbox_coder.decode(score, rot, dim, center, height,
                                            vel)
        bboxes_tensor = boxes_dict[0]['bboxes']
        gt_bboxes_tensor = gt_bboxes_3d.tensor.to(score.device)

        if self.auxiliary:
            num_layer = self.num_decoder_layers
        else:
            num_layer = 1

        # assign targets separately for each decoder layer's proposals
        assign_result_list = []
        for idx_layer in range(num_layer):
            start = self.num_proposals * idx_layer
            end = self.num_proposals * (idx_layer + 1)
            bboxes_tensor_layer = bboxes_tensor[start:end, :]
            score_layer = score[..., start:end]

            if self.train_cfg.assigner.type == 'HungarianAssigner3D':
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer,
                    gt_bboxes_tensor,
                    gt_labels_3d,
                    score_layer,
                    self.train_cfg,
                )
            elif self.train_cfg.assigner.type == 'HeuristicAssigner':
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer,
                    gt_bboxes_tensor,
                    None,
                    gt_labels_3d,
                    self.query_labels[batch_idx],
                )
            else:
                raise NotImplementedError
            assign_result_list.append(assign_result)

        # combine the assign results of all layers
        assign_result_ensemble = AssignResult(
            num_gts=sum([res.num_gts for res in assign_result_list]),
            gt_inds=torch.cat([res.gt_inds for res in assign_result_list]),
            max_overlaps=torch.cat(
                [res.max_overlaps for res in assign_result_list]),
            labels=torch.cat([res.labels for res in assign_result_list]),
        )

        gt_instances, pred_instances = InstanceData(
            bboxes=gt_bboxes_tensor), InstanceData(priors=bboxes_tensor)
        sampling_result = self.bbox_sampler.sample(assign_result_ensemble,
                                                   pred_instances,
                                                   gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        assert len(pos_inds) + len(neg_inds) == num_proposals

        # create the target tensors
        bbox_targets = torch.zeros([num_proposals, self.bbox_coder.code_size
                                    ]).to(center.device)
        bbox_weights = torch.zeros([num_proposals, self.bbox_coder.code_size
                                    ]).to(center.device)
        ious = assign_result_ensemble.max_overlaps
        ious = torch.clamp(ious, min=0.0, max=1.0)
        labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
        label_weights = bboxes_tensor.new_zeros(
            num_proposals, dtype=torch.long)

        if gt_labels_3d is not None:
            # default every proposal to the background class
            labels += self.num_classes

        # both positive and negative proposals contribute to classification;
        # only positives contribute to regression
        if len(pos_inds) > 0:
            pos_bbox_targets = self.bbox_coder.encode(
                sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            if gt_labels_3d is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels_3d[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight

        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # compute the dense heatmap targets
        device = labels.device
        gt_bboxes_3d = torch.cat(
            [gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]],
            dim=1).to(device)
        grid_size = torch.tensor(self.train_cfg['grid_size'])
        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
        feature_map_size = (grid_size[:2] //
                            self.train_cfg['out_size_factor'])  # [x, y]
        heatmap = gt_bboxes_3d.new_zeros(self.num_classes,
                                         feature_map_size[1],
                                         feature_map_size[0])
        for idx in range(len(gt_bboxes_3d)):
            width = gt_bboxes_3d[idx][3]
            length = gt_bboxes_3d[idx][4]
            width = width / voxel_size[0] / self.train_cfg['out_size_factor']
            length = length / voxel_size[1] / self.train_cfg[
                'out_size_factor']
            if width > 0 and length > 0:
                radius = gaussian_radius(
                    (length, width),
                    min_overlap=self.train_cfg['gaussian_overlap'])
                radius = max(self.train_cfg['min_radius'], int(radius))
                x, y = gt_bboxes_3d[idx][0], gt_bboxes_3d[idx][1]

                # project the box center onto the heatmap grid, e.g. with
                # voxel_size 0.075 m, out_size_factor 8 and a range starting
                # at -54 m, x = 0 m maps to coor_x = 54 / 0.075 / 8 = 90
                coor_x = ((x - pc_range[0]) / voxel_size[0] /
                          self.train_cfg['out_size_factor'])
                coor_y = ((y - pc_range[1]) / voxel_size[1] /
                          self.train_cfg['out_size_factor'])

                center = torch.tensor([coor_x, coor_y],
                                      dtype=torch.float32,
                                      device=device)
                center_int = center.to(torch.int32)

                # note the (x, y) -> (y, x) swap of the center before drawing
                draw_heatmap_gaussian(heatmap[gt_labels_3d[idx]],
                                      center_int[[1, 0]], radius)

        mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
        return (
            labels[None],
            label_weights[None],
            bbox_targets[None],
            bbox_weights[None],
            ious[None],
            int(pos_inds.shape[0]),
            float(mean_iou),
            heatmap[None],
        )

    def loss(self, batch_feats, batch_data_samples):
        """Loss function for TransFusionHead.

        Args:
            batch_feats (torch.Tensor): Features in a batch.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instances_3d`.

        Returns:
            dict[str, torch.Tensor]: Loss of heatmap and bbox of each task.
        """
        batch_input_metas, batch_gt_instances_3d = [], []
        for data_sample in batch_data_samples:
            batch_input_metas.append(data_sample.metainfo)
            batch_gt_instances_3d.append(data_sample.gt_instances_3d)
        preds_dicts = self(batch_feats, batch_input_metas)
        loss = self.loss_by_feat(preds_dicts, batch_gt_instances_3d)

        return loss

    def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
                     batch_gt_instances_3d: List[InstanceData], *args,
                     **kwargs):
        """Compute the heatmap, classification and regression losses."""
        (
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            ious,
            num_pos,
            matched_ious,
            heatmap,
        ) = self.get_targets(batch_gt_instances_3d, preds_dicts[0])
        if hasattr(self, 'on_the_image_mask'):
            label_weights = label_weights * self.on_the_image_mask
            bbox_weights = bbox_weights * self.on_the_image_mask[:, :, None]
            num_pos = bbox_weights.max(-1).values.sum()
        preds_dict = preds_dicts[0][0]
        loss_dict = dict()

        # heatmap loss on the dense heatmap used for query initialization
        loss_heatmap = self.loss_heatmap(
            clip_sigmoid(preds_dict['dense_heatmap']).float(),
            heatmap.float(),
            avg_factor=max(heatmap.eq(1).float().sum().item(), 1),
        )
        loss_dict['loss_heatmap'] = loss_heatmap

        # classification and regression losses, computed per decoder layer
        for idx_layer in range(
                self.num_decoder_layers if self.auxiliary else 1):
            if idx_layer == self.num_decoder_layers - 1 or (
                    idx_layer == 0 and self.auxiliary is False):
                prefix = 'layer_-1'
            else:
                prefix = f'layer_{idx_layer}'

            # slice out this layer's proposals
            start = idx_layer * self.num_proposals
            end = (idx_layer + 1) * self.num_proposals

            layer_labels = labels[..., start:end].reshape(-1)
            layer_label_weights = label_weights[..., start:end].reshape(-1)
            layer_score = preds_dict['heatmap'][..., start:end]
            layer_cls_score = layer_score.permute(0, 2, 1).reshape(
                -1, self.num_classes)
            layer_loss_cls = self.loss_cls(
                layer_cls_score.float(),
                layer_labels,
                layer_label_weights,
                avg_factor=max(num_pos, 1),
            )

            layer_center = preds_dict['center'][..., start:end]
            layer_height = preds_dict['height'][..., start:end]
            layer_rot = preds_dict['rot'][..., start:end]
            layer_dim = preds_dict['dim'][..., start:end]
            preds = torch.cat(
                [layer_center, layer_height, layer_dim, layer_rot],
                dim=1).permute(0, 2, 1)
            if 'vel' in preds_dict.keys():
                layer_vel = preds_dict['vel'][..., start:end]
                preds = torch.cat(
                    [layer_center, layer_height, layer_dim, layer_rot,
                     layer_vel],
                    dim=1).permute(0, 2, 1)
            code_weights = self.train_cfg.get('code_weights', None)
            layer_bbox_weights = bbox_weights[:, start:end, :]
            layer_reg_weights = (layer_bbox_weights *
                                 layer_bbox_weights.new_tensor(code_weights))
            layer_bbox_targets = bbox_targets[:, start:end, :]
            layer_loss_bbox = self.loss_bbox(
                preds,
                layer_bbox_targets,
                layer_reg_weights,
                avg_factor=max(num_pos, 1))

            loss_dict[f'{prefix}_loss_cls'] = layer_loss_cls
            loss_dict[f'{prefix}_loss_bbox'] = layer_loss_bbox

        # log the mean IoU of matched proposals for monitoring
        loss_dict['matched_ious'] = layer_loss_cls.new_tensor(matched_ious)

        return loss_dict
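
    # Training-step sketch (illustrative names): `head.loss(bev_feats,
    # batch_data_samples)` returns e.g. {'loss_heatmap': ...,
    # 'layer_-1_loss_cls': ..., 'layer_-1_loss_bbox': ...,
    # 'matched_ious': ...}.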