|
|
|
import numpy as np |
|
import torch |
|
|
|
from mmdet3d.structures.bbox_3d.utils import limit_period |
|
|
|
|
|
def normalize_bbox(bboxes, pc_range):
    """Pack 3D boxes into the normalized regression layout.

    The last dimension of ``bboxes`` is sliced as
    ``(cx, cy, cz, length, width, height, rot[, vx, vy])`` (column meaning
    taken from the slicing below -- confirm against callers). Any number of
    leading dimensions is accepted because every slice uses ``...``.

    Args:
        bboxes (torch.Tensor): Boxes whose last dim holds the columns above.
        pc_range: Not used by this function; kept for interface compatibility.

    Returns:
        torch.Tensor: Last dim ordered as
        ``(cx, cy, log(length), log(width), cz, log(height), sin(r), cos(r)``
        ``[, vx, vy])`` with ``r = limit_period(-rot - pi/2, period=2*pi)``.
        The velocity columns are appended only when ``bboxes.size(-1) > 7``.
    """
    # Sizes are stored in log space; non-positive sizes yield -inf/nan here.
    log_sizes = bboxes[..., 3:6].log()
    # Negate the yaw, shift it by -pi/2, then wrap it with limit_period.
    heading = limit_period(-bboxes[..., 6:7] - np.pi / 2, period=np.pi * 2)

    pieces = [
        bboxes[..., 0:2],     # cx, cy
        log_sizes[..., 0:2],  # log(length), log(width)
        bboxes[..., 2:3],     # cz
        log_sizes[..., 2:3],  # log(height)
        heading.sin(),
        heading.cos(),
    ]
    if bboxes.size(-1) > 7:
        pieces.append(bboxes[..., 7:9])  # vx, vy
    return torch.cat(pieces, dim=-1)
|
|
|
|
|
def denormalize_bbox(normalized_bboxes, pc_range):
    """Unpack boxes from the layout produced by ``normalize_bbox``.

    The last dimension of ``normalized_bboxes`` is read as
    ``(cx, cy, log(length), log(width), cz, log(height), sin, cos[, vx, vy])``.
    Any number of leading dimensions is accepted: every slice indexes the last
    dimension via ``...``.

    Args:
        normalized_bboxes (torch.Tensor): Encoded boxes as described above.
        pc_range: Not used by this function; kept for interface compatibility.

    Returns:
        torch.Tensor: Last dim ordered as
        ``(cx, cy, cz, length, width, height, rot[, vx, vy])``. The velocity
        columns are appended only when ``normalized_bboxes.size(-1) > 8``.
    """
    rot_sine = normalized_bboxes[..., 6:7]
    rot_cosine = normalized_bboxes[..., 7:8]
    rot = torch.atan2(rot_sine, rot_cosine)
    # Invert the ``-rot - pi/2`` mapping used in normalize_bbox, then wrap.
    rot = -rot - np.pi / 2
    rot = limit_period(rot, period=np.pi * 2)

    cx = normalized_bboxes[..., 0:1]
    cy = normalized_bboxes[..., 1:2]
    cz = normalized_bboxes[..., 4:5]

    # Sizes were stored as logarithms; exponentiate to recover them.
    length = normalized_bboxes[..., 2:3].exp()
    width = normalized_bboxes[..., 3:4].exp()
    height = normalized_bboxes[..., 5:6].exp()

    parts = [cx, cy, cz, length, width, height, rot]
    if normalized_bboxes.size(-1) > 8:
        # Slice the LAST dim with ``...`` (previously ``[:, ...]``, which
        # indexed dim 1 and broke inputs with more or fewer than 2 dims).
        vx = normalized_bboxes[..., 8:9]
        vy = normalized_bboxes[..., 9:10]
        parts += [vx, vy]
    return torch.cat(parts, dim=-1)
|
|