# File: mm3dtest/projects/BEVFusion/bevfusion/transfusion_head.py
# Page metadata captured with this file: giantmonkeyTC, 2344, revision
# 34d1f8b, 36.5 kB
# modify from https://github.com/mit-han-lab/bevfusion
import copy
from typing import List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer
from mmdet.models.task_modules import (AssignResult, PseudoSampler,
build_assigner, build_bbox_coder,
build_sampler)
from mmdet.models.utils import multi_apply
from mmengine.structures import InstanceData
from torch import nn
from mmdet3d.models import circle_nms, draw_heatmap_gaussian, gaussian_radius
from mmdet3d.models.dense_heads.centerpoint_head import SeparateHead
from mmdet3d.models.layers import nms_bev
from mmdet3d.registry import MODELS
from mmdet3d.structures import xywhr2xyxyr
def clip_sigmoid(x, eps=1e-4):
    """Apply a sigmoid and clamp the result into ``[eps, 1 - eps]``.

    The clamp keeps downstream log-based losses (e.g. focal loss) away from
    ``log(0)``.

    Args:
        x (torch.Tensor): Input logits. The tensor is NOT modified.
        eps (float): Lower clamp bound; the upper bound is ``1 - eps``.

    Returns:
        torch.Tensor: Clamped probabilities with the same shape as ``x``.
    """
    # Out-of-place sigmoid: the in-place ``sigmoid_`` would silently overwrite
    # the caller's logits and fails on leaf tensors that require grad.
    return torch.clamp(x.sigmoid(), min=eps, max=1 - eps)
@MODELS.register_module()
class ConvFuser(nn.Sequential):
    """Fuse several BEV feature maps with concat + Conv3x3-BN-ReLU.

    Args:
        in_channels (list[int]): Channel count of each input feature map.
            The maps are concatenated along dim 1, so the convolution sees
            ``sum(in_channels)`` input channels.
        out_channels (int): Channel count of the fused output.
    """

    def __init__(self, in_channels: List[int], out_channels: int) -> None:
        # Initialise nn.Module first so attribute assignment goes through a
        # fully constructed module.
        super().__init__(
            nn.Conv2d(
                sum(in_channels), out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Concatenate ``inputs`` along dim 1 and run the conv stack.

        Args:
            inputs (list[torch.Tensor]): Feature maps whose channel counts
                sum to ``sum(self.in_channels)``; all other dims must match.

        Returns:
            torch.Tensor: Fused feature map with ``out_channels`` channels.
        """
        return super().forward(torch.cat(inputs, dim=1))
@MODELS.register_module()
class TransFusionHead(nn.Module):
    """TransFusion detection head (as used by BEVFusion).

    A dense class heatmap is predicted from the fused BEV feature map. The
    top ``num_proposals`` local maxima become object queries. These are
    refined by a stack of decoder layers, each followed by a
    :class:`SeparateHead` that predicts box parameters and class logits.

    Note:
        ``forward_single`` stores the class of every selected proposal in
        ``self.query_labels``. ``predict_by_feat`` and
        ``get_targets_single`` read it back, so they must run after a
        forward pass on the same batch.

    ``test_cfg`` must provide ``grid_size`` and ``out_size_factor``; it is
    also read for ``dataset``, ``nms_type``, ``pre_maxsize`` and
    ``post_maxsize``.
    """

    def __init__(
        self,
        num_proposals=128,
        auxiliary=True,
        in_channels=128 * 3,
        hidden_channel=128,
        num_classes=4,
        # config for Transformer
        num_decoder_layers=3,
        decoder_layer=dict(),
        num_heads=8,
        nms_kernel_size=1,
        bn_momentum=0.1,
        # config for FFN
        common_heads=dict(),
        num_heatmap_convs=2,
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d'),
        bias='auto',
        # loss
        loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(type='mmdet.L1Loss', reduction='mean'),
        loss_heatmap=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        # others
        train_cfg=None,
        test_cfg=None,
        bbox_coder=None,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_proposals = num_proposals
        self.auxiliary = auxiliary
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.num_decoder_layers = num_decoder_layers
        self.bn_momentum = bn_momentum
        self.nms_kernel_size = nms_kernel_size
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if not self.use_sigmoid_cls:
            # Softmax classification needs an extra background class.
            self.num_classes += 1
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_heatmap = MODELS.build(loss_heatmap)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.sampling = False

        # a shared convolution
        self.shared_conv = build_conv_layer(
            dict(type='Conv2d'),
            in_channels,
            hidden_channel,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        # NOTE(review): the heatmap head and class encoding use the raw
        # ``num_classes`` argument, while the prediction heads and one-hot
        # encodings use ``self.num_classes``. The two differ when
        # ``use_sigmoid`` is False. Confirm that path before relying on it.
        self.heatmap_head = nn.Sequential(
            ConvModule(
                hidden_channel,
                hidden_channel,
                kernel_size=3,
                padding=1,
                bias=bias,
                conv_cfg=dict(type='Conv2d'),
                norm_cfg=dict(type='BN2d'),
            ),
            build_conv_layer(
                dict(type='Conv2d'),
                hidden_channel,
                num_classes,
                kernel_size=3,
                padding=1,
                bias=bias,
            ),
        )
        self.class_encoding = nn.Conv1d(num_classes, hidden_channel, 1)

        # transformer decoder layers for object query with LiDAR feature
        self.decoder = nn.ModuleList(
            [MODELS.build(decoder_layer) for _ in range(num_decoder_layers)])

        # Prediction Head: one SeparateHead per decoder layer.
        self.prediction_heads = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(self.num_classes, num_heatmap_convs)))
            self.prediction_heads.append(
                SeparateHead(
                    hidden_channel,
                    heads,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=bias,
                ))

        self.init_weights()
        self._init_assigner_sampler()

        # Position embedding for cross-attention, re-used during training.
        x_size = self.test_cfg['grid_size'][0] // self.test_cfg[
            'out_size_factor']
        y_size = self.test_cfg['grid_size'][1] // self.test_cfg[
            'out_size_factor']
        self.bev_pos = self.create_2D_grid(x_size, y_size)

        self.img_feat_pos = None
        self.img_feat_collapsed_pos = None

    def create_2D_grid(self, x_size, y_size):
        """Build cell-center coordinates of an ``x_size`` x ``y_size`` grid.

        Returns:
            torch.Tensor: Shape ``[1, x_size * y_size, 2]``. Entry
            ``ix * y_size + iy`` holds ``(ix + 0.5, iy + 0.5)``.
        """
        meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
        # NOTE: modified
        batch_x, batch_y = torch.meshgrid(
            *[torch.linspace(it[0], it[1], it[2]) for it in meshgrid])
        batch_x = batch_x + 0.5
        batch_y = batch_y + 0.5
        coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
        coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
        return coord_base

    def init_weights(self):
        """Xavier-init decoder weights and set BatchNorm momentum."""
        # initialize transformer
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)
        if hasattr(self, 'query'):
            nn.init.xavier_normal_(self.query)
        self.init_bn_momentum()

    def init_bn_momentum(self):
        """Apply ``self.bn_momentum`` to every BatchNorm1d/2d submodule."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum

    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        if self.train_cfg is None:
            return

        if self.sampling:
            self.bbox_sampler = build_sampler(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()
        if isinstance(self.train_cfg.assigner, dict):
            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
        elif isinstance(self.train_cfg.assigner, list):
            self.bbox_assigner = [
                build_assigner(res) for res in self.train_cfg.assigner
            ]

    def _local_max(self, heatmap):
        """Compute the local-maximum map used to pick heatmap peaks.

        A ``nms_kernel_size`` max-pool (stride 1, no padding) fills the
        interior. A border of width ``nms_kernel_size // 2`` stays zero, so
        border cells are suppressed. Small-object classes (nuScenes
        pedestrian/traffic cone, Waymo pedestrian/cyclist) skip suppression
        entirely. Assumes an odd ``nms_kernel_size``.

        Args:
            heatmap (torch.Tensor): Class scores of shape ``[B, C, H, W]``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``heatmap``.
        """
        padding = self.nms_kernel_size // 2
        height, width = heatmap.shape[-2:]
        local_max = torch.zeros_like(heatmap)
        # equals to nms radius = voxel_size * out_size_factor * kernel_size
        local_max_inner = F.max_pool2d(
            heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0)
        # Explicit end indices: with padding == 0 a ``padding:-padding``
        # slice would be ``0:0`` (empty) and the assignment would fail.
        local_max[:, :, padding:height - padding,
                  padding:width - padding] = local_max_inner

        dataset = self.test_cfg['dataset']
        if dataset == 'nuScenes':
            exempt_classes = (8, 9)  # Pedestrian & Traffic_cone
        elif dataset == 'Waymo':
            exempt_classes = (1, 2)  # Pedestrian & Cyclist
        else:
            exempt_classes = ()
        for cls_idx in exempt_classes:
            # A 1x1 max-pool is the identity: every cell is its own maximum.
            local_max[:, cls_idx] = heatmap[:, cls_idx]
        return local_max

    def forward_single(self, inputs, metas):
        """Run the head on one BEV feature map.

        Args:
            inputs (torch.Tensor): Fused BEV features of shape
                ``[B, in_channels, H, W]``.
            metas (list[dict]): Per-sample meta information (not used here).

        Returns:
            list[dict]: A one-element list with the prediction dict.
            ``dense_heatmap`` and ``query_heatmap_score`` are always
            present. With ``auxiliary`` enabled, the per-proposal outputs of
            all decoder layers are concatenated along the last dim.
            Otherwise only the last layer is returned.
        """
        batch_size = inputs.shape[0]
        fusion_feat = self.shared_conv(inputs)
        # [BS, C, H*W]
        fusion_feat_flatten = fusion_feat.view(batch_size,
                                               fusion_feat.shape[1], -1)
        bev_pos = self.bev_pos.repeat(batch_size, 1, 1).to(fusion_feat.device)

        #################################
        # query initialization
        #################################
        with torch.autocast('cuda', enabled=False):
            dense_heatmap = self.heatmap_head(fusion_feat.float())
        heatmap = dense_heatmap.detach().sigmoid()
        heatmap = heatmap * (heatmap == self._local_max(heatmap))
        heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)

        # top num_proposals among all classes (flattened as class * H*W)
        num_cells = heatmap.shape[-1]
        top_proposals = heatmap.view(batch_size, -1).argsort(
            dim=-1, descending=True)[..., :self.num_proposals]
        top_proposals_class = torch.div(
            top_proposals, num_cells, rounding_mode='floor')
        top_proposals_index = top_proposals % num_cells
        query_feat = fusion_feat_flatten.gather(
            index=top_proposals_index[:, None, :].expand(
                -1, fusion_feat_flatten.shape[1], -1),
            dim=-1,
        )
        self.query_labels = top_proposals_class

        # add category embedding
        one_hot = F.one_hot(
            top_proposals_class,
            num_classes=self.num_classes).permute(0, 2, 1)
        query_feat += self.class_encoding(one_hot.float())

        query_pos = bev_pos.gather(
            index=top_proposals_index[:, :, None].expand(
                -1, -1, bev_pos.shape[-1]),
            dim=1,
        )

        #################################
        # transformer decoder layer (Fusion feature as K,V)
        #################################
        ret_dicts = []
        for decoder_layer, prediction_head in zip(self.decoder,
                                                  self.prediction_heads):
            # query: [B, C, Pq]; query_pos: [B, Pq, 2]
            query_feat = decoder_layer(
                query_feat,
                key=fusion_feat_flatten,
                query_pos=query_pos,
                key_pos=bev_pos)
            res_layer = prediction_head(query_feat)
            res_layer['center'] = res_layer['center'] + query_pos.permute(
                0, 2, 1)
            ret_dicts.append(res_layer)
            # refined centers become the next layer's positional embedding
            query_pos = res_layer['center'].detach().clone().permute(0, 2, 1)

        # [bs, num_classes, num_proposals]
        query_heatmap_score = heatmap.gather(
            index=top_proposals_index[:, None, :].expand(
                -1, self.num_classes, -1),
            dim=-1,
        )
        ret_dicts[0]['query_heatmap_score'] = query_heatmap_score
        ret_dicts[0]['dense_heatmap'] = dense_heatmap

        if not self.auxiliary:
            # Only the last decoder layer is returned; attach the shared
            # outputs to it as well, since loss/predict read them from the
            # returned dict.
            last = ret_dicts[-1]
            last['query_heatmap_score'] = query_heatmap_score
            last['dense_heatmap'] = dense_heatmap
            return [last]

        # Auxiliary supervision: concatenate every layer's per-proposal
        # outputs along the proposal (last) dimension.
        shared_keys = ('dense_heatmap', 'dense_heatmap_old',
                       'query_heatmap_score')
        new_res = {}
        for key in ret_dicts[0].keys():
            if key in shared_keys:
                new_res[key] = ret_dicts[0][key]
            else:
                new_res[key] = torch.cat(
                    [ret_dict[key] for ret_dict in ret_dicts], dim=-1)
        return [new_res]

    def forward(self, feats, metas):
        """Forward pass.

        Args:
            feats (torch.Tensor | list[torch.Tensor]): BEV feature map(s).
                A single tensor is wrapped into a list.
            metas (list[dict]): Meta information of the batch.

        Returns:
            tuple(list[dict]): Output results. First index by level, second
            index by layer.
        """
        if isinstance(feats, torch.Tensor):
            feats = [feats]
        # NOTE(review): ``[metas]`` has length 1, so multi_apply only pairs
        # it with the first entry of ``feats``.
        res = multi_apply(self.forward_single, feats, [metas])
        assert len(res) == 1, 'only support one level features.'
        return res

    def predict(self, batch_feats, batch_input_metas):
        """Run the forward pass and decode the predicted boxes."""
        preds_dicts = self(batch_feats, batch_input_metas)
        res = self.predict_by_feat(preds_dicts, batch_input_metas)
        return res

    def _get_nms_tasks(self):
        """Return per-task NMS groups for ``self.test_cfg['dataset']``.

        Each task lists the class ``indices`` it covers and an NMS
        ``radius``; a non-positive radius keeps all boxes of that task. The
        result is also stored in ``self.tasks``.

        Raises:
            ValueError: If the dataset has no built-in table and
                ``self.tasks`` was not assigned beforehand.
        """
        dataset = self.test_cfg['dataset']
        if dataset == 'nuScenes':
            self.tasks = [
                dict(
                    num_class=8,
                    class_names=[],
                    indices=[0, 1, 2, 3, 4, 5, 6, 7],
                    radius=-1,
                ),
                dict(
                    num_class=1,
                    class_names=['pedestrian'],
                    indices=[8],
                    radius=0.175,
                ),
                dict(
                    num_class=1,
                    class_names=['traffic_cone'],
                    indices=[9],
                    radius=0.175,
                ),
            ]
        elif dataset == 'Waymo':
            self.tasks = [
                dict(
                    num_class=1,
                    class_names=['Car'],
                    indices=[0],
                    radius=0.7),
                dict(
                    num_class=1,
                    class_names=['Pedestrian'],
                    indices=[1],
                    radius=0.7),
                dict(
                    num_class=1,
                    class_names=['Cyclist'],
                    indices=[2],
                    radius=0.7),
            ]
        elif not hasattr(self, 'tasks'):
            raise ValueError(
                f'No NMS tasks are defined for dataset {dataset!r}; set '
                "test_cfg['nms_type'] to None or assign `self.tasks`.")
        return self.tasks

    def _nms_keep_mask(self, boxes3d, scores, labels, tasks, box_type_3d):
        """Compute which decoded boxes of one sample survive per-task NMS.

        Args:
            boxes3d (torch.Tensor): Decoded boxes, ``[N, box_dim]``.
            scores (torch.Tensor): Box scores, ``[N]``.
            labels (torch.Tensor): Box class indices, ``[N]``.
            tasks (list[dict]): Task groups with ``indices`` and ``radius``.
            box_type_3d (type): Box class used to build BEV boxes for
                ``nms_bev``.

        Returns:
            torch.Tensor: Boolean keep mask of shape ``[N]``.
        """
        nms_type = self.test_cfg['nms_type']
        keep_mask = torch.zeros_like(scores, dtype=torch.bool)
        for task in tasks:
            task_mask = torch.zeros_like(scores, dtype=torch.bool)
            for cls_idx in task['indices']:
                task_mask |= labels == cls_idx
            if task['radius'] > 0:
                if nms_type == 'circle':
                    boxes_for_nms = torch.cat(
                        [boxes3d[task_mask][:, :2], scores[task_mask][:, None]],
                        dim=1,
                    )
                    task_keep_indices = torch.tensor(
                        circle_nms(
                            boxes_for_nms.detach().cpu().numpy(),
                            task['radius'],
                        ))
                else:
                    boxes_for_nms = xywhr2xyxyr(
                        box_type_3d(boxes3d[task_mask][:, :7], 7).bev)
                    task_keep_indices = nms_bev(
                        boxes_for_nms,
                        scores[task_mask],
                        thresh=task['radius'],
                        pre_max_size=self.test_cfg['pre_maxsize'],
                        post_max_size=self.test_cfg['post_maxsize'],
                    )
            else:
                # Non-positive radius: keep every box of this task.
                task_keep_indices = torch.arange(task_mask.sum())
            if task_keep_indices.shape[0] != 0:
                keep_indices = torch.where(task_mask)[0][task_keep_indices]
                keep_mask[keep_indices] = True
        return keep_mask

    def predict_by_feat(self,
                        preds_dicts,
                        metas,
                        img=None,
                        rescale=False,
                        for_roi=False):
        """Generate bboxes from bbox head predictions.

        Only the last ``num_proposals`` entries (the last decoder layer) are
        decoded. Scores are the sigmoid class logits multiplied by the
        initial heatmap score and masked to each query's initial class.

        Args:
            preds_dicts (tuple[list[dict]]): Prediction results.
            metas (list[dict]): Per-sample meta info; ``box_type_3d`` is read.

        Returns:
            list[InstanceData]: One entry per sample with ``bboxes_3d``,
            ``scores_3d`` and ``labels_3d``.
        """
        nms_type = self.test_cfg['nms_type']
        tasks = self._get_nms_tasks() if nms_type is not None else None

        rets = []
        for preds_dict in preds_dicts:
            pred = preds_dict[0]
            batch_size = pred['heatmap'].shape[0]
            batch_score = pred['heatmap'][...,
                                          -self.num_proposals:].sigmoid()
            one_hot = F.one_hot(
                self.query_labels,
                num_classes=self.num_classes).permute(0, 2, 1)
            batch_score = batch_score * pred['query_heatmap_score'] * one_hot

            batch_center = pred['center'][..., -self.num_proposals:]
            batch_height = pred['height'][..., -self.num_proposals:]
            batch_dim = pred['dim'][..., -self.num_proposals:]
            batch_rot = pred['rot'][..., -self.num_proposals:]
            batch_vel = None
            if 'vel' in pred:
                batch_vel = pred['vel'][..., -self.num_proposals:]

            temp = self.bbox_coder.decode(
                batch_score,
                batch_rot,
                batch_dim,
                batch_center,
                batch_height,
                batch_vel,
                filter=True,
            )

            ret_layer = []
            for i in range(batch_size):
                boxes3d = temp[i]['bboxes']
                scores = temp[i]['scores']
                labels = temp[i]['labels']
                box_type_3d = metas[i]['box_type_3d']
                if tasks is not None:
                    keep_mask = self._nms_keep_mask(boxes3d, scores, labels,
                                                    tasks, box_type_3d)
                    boxes3d = boxes3d[keep_mask]
                    scores = scores[keep_mask]
                    labels = labels[keep_mask]

                temp_instances = InstanceData()
                temp_instances.bboxes_3d = box_type_3d(
                    boxes3d, box_dim=boxes3d.shape[-1])
                temp_instances.scores_3d = scores
                temp_instances.labels_3d = labels.int()
                ret_layer.append(temp_instances)

            rets.append(ret_layer)
        assert len(
            rets
        ) == 1, f'only support one layer now, but get {len(rets)} layers'

        return rets[0]

    def get_targets(self, batch_gt_instances_3d: List[InstanceData],
                    preds_dict: List[dict]):
        """Generate training targets.

        Args:
            batch_gt_instances_3d (List[InstanceData]): Ground truth per
                sample.
            preds_dict (list[dict]): Output of ``forward_single``. It is a
                one-element list whose dict holds batched predictions; with
                ``auxiliary`` enabled, all decoder layers are concatenated
                along the proposal dimension:

                - center: (bs, 2, num_proposals)
                - height: (bs, 1, num_proposals)
                - dim: (bs, 3, num_proposals)
                - rot: (bs, 2, num_proposals)
                - vel: (bs, 2, num_proposals)
                - heatmap: (bs, num_classes, num_proposals)
                - query_heatmap_score: (bs, num_classes, num_proposals)
                - dense_heatmap: dense heatmap before the decoder.

        Returns:
            tuple: ``(labels, label_weights, bbox_targets, bbox_weights,
            ious, num_pos, matched_ious, heatmap)``. The tensors are
            concatenated over the batch; ``num_pos`` is the total positive
            count, and ``matched_ious`` is the mean of per-sample matched
            IoUs.
        """
        # All decoder layers live in preds_dict[0]; split it per sample.
        merged = preds_dict[0]
        list_of_pred_dict = [{
            key: value[batch_idx:batch_idx + 1]
            for key, value in merged.items()
        } for batch_idx in range(len(batch_gt_instances_3d))]

        res_tuple = multi_apply(
            self.get_targets_single,
            batch_gt_instances_3d,
            list_of_pred_dict,
            np.arange(len(batch_gt_instances_3d)),
        )
        labels = torch.cat(res_tuple[0], dim=0)
        label_weights = torch.cat(res_tuple[1], dim=0)
        bbox_targets = torch.cat(res_tuple[2], dim=0)
        bbox_weights = torch.cat(res_tuple[3], dim=0)
        ious = torch.cat(res_tuple[4], dim=0)
        num_pos = np.sum(res_tuple[5])
        matched_ious = np.mean(res_tuple[6])
        heatmap = torch.cat(res_tuple[7], dim=0)
        return (
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            ious,
            num_pos,
            matched_ious,
            heatmap,
        )

    def _heatmap_targets(self, gt_bboxes_3d, gt_labels_3d):
        """Render gaussian heatmap targets for one sample.

        Args:
            gt_bboxes_3d (torch.Tensor): GT boxes as gravity-center xyz
                followed by the remaining box parameters, ``[M, box_dim]``.
            gt_labels_3d (torch.Tensor): GT class indices, ``[M]``.

        Returns:
            torch.Tensor: Heatmap of shape
            ``[num_classes, feature_map_size[1], feature_map_size[0]]``.
        """
        device = gt_bboxes_3d.device
        out_size_factor = self.train_cfg['out_size_factor']
        grid_size = torch.tensor(self.train_cfg['grid_size'])
        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
        feature_map_size = grid_size[:2] // out_size_factor  # [x_len, y_len]
        # NOTE(review): the gaussian is drawn at row=x, col=y (see the
        # swapped center below) but the map is allocated as [y_len, x_len].
        # This only agrees for square grids; confirm the BEV layout before
        # using non-square grids.
        heatmap = gt_bboxes_3d.new_zeros(self.num_classes, feature_map_size[1],
                                         feature_map_size[0])
        for idx in range(len(gt_bboxes_3d)):
            box = gt_bboxes_3d[idx]
            width = box[3] / voxel_size[0] / out_size_factor
            length = box[4] / voxel_size[1] / out_size_factor
            if not (width > 0 and length > 0):
                continue
            radius = gaussian_radius(
                (length, width),
                min_overlap=self.train_cfg['gaussian_overlap'])
            radius = max(self.train_cfg['min_radius'], int(radius))
            coor_x = (box[0] - pc_range[0]) / voxel_size[0] / out_size_factor
            coor_y = (box[1] - pc_range[1]) / voxel_size[1] / out_size_factor
            center_int = torch.tensor([coor_x, coor_y],
                                      dtype=torch.float32,
                                      device=device).to(torch.int32)
            # NOTE: fix -- center passed as (y, x) to match the BEV layout
            # implied by ``create_2D_grid``.
            draw_heatmap_gaussian(heatmap[gt_labels_3d[idx]],
                                  center_int[[1, 0]], radius)
        return heatmap

    def get_targets_single(self, gt_instances_3d, preds_dict, batch_idx):
        """Generate training targets for a single sample.

        Args:
            gt_instances_3d (:obj:`InstanceData`): ground truth of instances.
            preds_dict (dict): dict of prediction result for a single sample.
            batch_idx (int): Index of the sample in the batch; used to read
                ``self.query_labels`` for the HeuristicAssigner.

        Returns:
            tuple: Tuple of targets in order:

            - torch.Tensor: classification target ``[1, num_proposals]``.
            - torch.Tensor: classification weights ``[1, num_proposals]``.
            - torch.Tensor: regression target ``[1, num_proposals, code]``.
            - torch.Tensor: regression weights ``[1, num_proposals, code]``.
            - torch.Tensor: iou target ``[1, num_proposals]``.
            - int: number of positive proposals.
            - float: mean IoU of positive proposals.
            - torch.Tensor: heatmap targets ``[1, num_classes, ...]``.
        """
        gt_bboxes_3d = gt_instances_3d.bboxes_3d
        gt_labels_3d = gt_instances_3d.labels_3d
        num_proposals = preds_dict['center'].shape[-1]

        # get pred boxes, carefully ! don't change the network outputs
        score = copy.deepcopy(preds_dict['heatmap'].detach())
        center = copy.deepcopy(preds_dict['center'].detach())
        height = copy.deepcopy(preds_dict['height'].detach())
        dim = copy.deepcopy(preds_dict['dim'].detach())
        rot = copy.deepcopy(preds_dict['rot'].detach())
        vel = None
        if 'vel' in preds_dict:
            vel = copy.deepcopy(preds_dict['vel'].detach())

        # decode the prediction to real world metric bbox
        boxes_dict = self.bbox_coder.decode(score, rot, dim, center, height,
                                            vel)
        bboxes_tensor = boxes_dict[0]['bboxes']
        gt_bboxes_tensor = gt_bboxes_3d.tensor.to(score.device)

        # 1. Assignment: each decoder layer is assigned separately.
        num_layer = self.num_decoder_layers if self.auxiliary else 1
        assigner_type = self.train_cfg.assigner.type
        assign_result_list = []
        for idx_layer in range(num_layer):
            start = self.num_proposals * idx_layer
            end = start + self.num_proposals
            bboxes_tensor_layer = bboxes_tensor[start:end, :]
            score_layer = score[..., start:end]

            if assigner_type == 'HungarianAssigner3D':
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer,
                    gt_bboxes_tensor,
                    gt_labels_3d,
                    score_layer,
                    self.train_cfg,
                )
            elif assigner_type == 'HeuristicAssigner':
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer,
                    gt_bboxes_tensor,
                    None,
                    gt_labels_3d,
                    self.query_labels[batch_idx],
                )
            else:
                raise NotImplementedError
            assign_result_list.append(assign_result)

        # combine assign result of each layer
        assign_result_ensemble = AssignResult(
            num_gts=sum([res.num_gts for res in assign_result_list]),
            gt_inds=torch.cat([res.gt_inds for res in assign_result_list]),
            max_overlaps=torch.cat(
                [res.max_overlaps for res in assign_result_list]),
            labels=torch.cat([res.labels for res in assign_result_list]),
        )

        # 2. Sampling. Compatible with the interface of `PseudoSampler` in
        # mmdet.
        gt_instances = InstanceData(bboxes=gt_bboxes_tensor)
        pred_instances = InstanceData(priors=bboxes_tensor)
        sampling_result = self.bbox_sampler.sample(assign_result_ensemble,
                                                   pred_instances,
                                                   gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        assert len(pos_inds) + len(neg_inds) == num_proposals

        # 3. Create target for loss computation
        code_size = self.bbox_coder.code_size
        bbox_targets = torch.zeros([num_proposals,
                                    code_size]).to(center.device)
        bbox_weights = torch.zeros([num_proposals,
                                    code_size]).to(center.device)
        ious = torch.clamp(
            assign_result_ensemble.max_overlaps, min=0.0, max=1.0)
        labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
        # Floating-point weights so a fractional ``pos_weight`` is not
        # truncated to an integer.
        label_weights = bboxes_tensor.new_zeros(num_proposals)

        if gt_labels_3d is not None:
            # Unmatched proposals get the background label num_classes.
            labels += self.num_classes

        # both pos and neg have classification loss, only pos has regression
        # and iou loss
        if len(pos_inds) > 0:
            pos_bbox_targets = self.bbox_coder.encode(
                sampling_result.pos_gt_bboxes)
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            if gt_labels_3d is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels_3d[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight

        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # compute dense heatmap targets
        device = labels.device
        gt_boxes_centered = torch.cat(
            [gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]],
            dim=1).to(device)
        heatmap = self._heatmap_targets(gt_boxes_centered, gt_labels_3d)

        mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
        return (
            labels[None],
            label_weights[None],
            bbox_targets[None],
            bbox_weights[None],
            ious[None],
            int(pos_inds.shape[0]),
            float(mean_iou),
            heatmap[None],
        )

    def loss(self, batch_feats, batch_data_samples):
        """Loss function for TransFusionHead.

        Args:
            batch_feats (torch.Tensor | list[torch.Tensor]): BEV features of
                the batch.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.

        Returns:
            dict[str, torch.Tensor]: Heatmap loss, per-layer classification
            and bbox losses, and ``matched_ious``.
        """
        batch_input_metas, batch_gt_instances_3d = [], []
        for data_sample in batch_data_samples:
            batch_input_metas.append(data_sample.metainfo)
            batch_gt_instances_3d.append(data_sample.gt_instances_3d)
        preds_dicts = self(batch_feats, batch_input_metas)
        loss = self.loss_by_feat(preds_dicts, batch_gt_instances_3d)

        return loss

    def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
                     batch_gt_instances_3d: List[InstanceData], *args,
                     **kwargs):
        """Compute heatmap, classification and box-regression losses.

        Args:
            preds_dicts (tuple[list[dict]]): Output of ``forward``.
            batch_gt_instances_3d (list[InstanceData]): Ground truth.

        Returns:
            dict[str, torch.Tensor]: ``loss_heatmap``, then
            ``{prefix}_loss_cls`` / ``{prefix}_loss_bbox`` per supervised
            layer (the last layer uses ``layer_-1``), and ``matched_ious``.
        """
        (
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            ious,
            num_pos,
            matched_ious,
            heatmap,
        ) = self.get_targets(batch_gt_instances_3d, preds_dicts[0])
        if hasattr(self, 'on_the_image_mask'):
            label_weights = label_weights * self.on_the_image_mask
            bbox_weights = bbox_weights * self.on_the_image_mask[:, :, None]
            num_pos = bbox_weights.max(-1).values.sum()
        preds_dict = preds_dicts[0][0]
        loss_dict = dict()

        # compute heatmap loss, normalized by the number of gaussian peaks
        loss_dict['loss_heatmap'] = self.loss_heatmap(
            clip_sigmoid(preds_dict['dense_heatmap']).float(),
            heatmap.float(),
            avg_factor=max(heatmap.eq(1).float().sum().item(), 1),
        )

        code_weights = self.train_cfg.get('code_weights', None)
        reg_keys = ['center', 'height', 'dim', 'rot']
        if 'vel' in preds_dict:
            reg_keys.append('vel')

        # compute loss for each layer
        num_layers = self.num_decoder_layers if self.auxiliary else 1
        for idx_layer in range(num_layers):
            if idx_layer == self.num_decoder_layers - 1 or (
                    idx_layer == 0 and not self.auxiliary):
                prefix = 'layer_-1'
            else:
                prefix = f'layer_{idx_layer}'
            start = idx_layer * self.num_proposals
            end = start + self.num_proposals

            # classification
            layer_labels = labels[..., start:end].reshape(-1)
            layer_label_weights = label_weights[..., start:end].reshape(-1)
            layer_cls_score = preds_dict['heatmap'][..., start:end].permute(
                0, 2, 1).reshape(-1, self.num_classes)
            layer_loss_cls = self.loss_cls(
                layer_cls_score.float(),
                layer_labels,
                layer_label_weights,
                avg_factor=max(num_pos, 1),
            )

            # regression: [BS, num_proposals, code_size]
            preds = torch.cat(
                [preds_dict[key][..., start:end] for key in reg_keys],
                dim=1).permute(0, 2, 1)
            layer_bbox_weights = bbox_weights[:, start:end, :]
            if code_weights is None:
                layer_reg_weights = layer_bbox_weights
            else:
                layer_reg_weights = (
                    layer_bbox_weights *
                    layer_bbox_weights.new_tensor(code_weights))
            layer_bbox_targets = bbox_targets[:, start:end, :]
            layer_loss_bbox = self.loss_bbox(
                preds,
                layer_bbox_targets,
                layer_reg_weights,
                avg_factor=max(num_pos, 1))

            loss_dict[f'{prefix}_loss_cls'] = layer_loss_cls
            loss_dict[f'{prefix}_loss_bbox'] = layer_loss_bbox

        loss_dict['matched_ious'] = layer_loss_cls.new_tensor(matched_ious)

        return loss_dict