from typing import Tuple

import torch
from torch import nn

from mmdet3d.registry import MODELS
from .ops import bev_pool


def gen_dx_bx(xbound, ybound, zbound):
    """Derive voxel sizes (dx), first-voxel centers (bx) and grid shape (nx)
    from per-axis (lower, upper, step) bounds."""
    dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
    bx = torch.Tensor(
        [row[0] + row[2] / 2.0 for row in [xbound, ybound, zbound]])
    nx = torch.LongTensor([(row[1] - row[0]) / row[2]
                           for row in [xbound, ybound, zbound]])
    return dx, bx, nx
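
# For illustration (hypothetical bounds, not a shipped config): with
# xbound=(-54.0, 54.0, 0.3), gen_dx_bx yields dx[0] = 0.3 (voxel size),
# bx[0] = -53.85 (center of the first voxel) and nx[0] = 360 (grid cells).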


class BaseViewTransform(nn.Module):
    """Lift-Splat-Shoot style base class that lifts multi-view image
    features into a bird's-eye-view (BEV) grid."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.image_size = image_size
        self.feature_size = feature_size
        self.xbound = xbound
        self.ybound = ybound
        self.zbound = zbound
        self.dbound = dbound

        dx, bx, nx = gen_dx_bx(self.xbound, self.ybound, self.zbound)
        self.dx = nn.Parameter(dx, requires_grad=False)
        self.bx = nn.Parameter(bx, requires_grad=False)
        self.nx = nn.Parameter(nx, requires_grad=False)

        # C: feature channels per BEV cell; D: number of depth bins.
        self.C = out_channels
        self.frustum = self.create_frustum()
        self.D = self.frustum.shape[0]
        self.fp16_enabled = False

    def create_frustum(self):
        iH, iW = self.image_size
        fH, fW = self.feature_size

        # Candidate depths along each camera ray, one plane per depth bin.
        ds = (
            torch.arange(*self.dbound,
                         dtype=torch.float).view(-1, 1, 1).expand(-1, fH, fW))
        D, _, _ = ds.shape

        # Pixel coordinates of each feature-map cell in input-image space.
        xs = (
            torch.linspace(0, iW - 1, fW,
                           dtype=torch.float).view(1, 1, fW).expand(D, fH, fW))
        ys = (
            torch.linspace(0, iH - 1, fH,
                           dtype=torch.float).view(1, fH, 1).expand(D, fH, fW))

        frustum = torch.stack((xs, ys, ds), -1)
        return nn.Parameter(frustum, requires_grad=False)
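
    # The frustum has shape (D, fH, fW, 3); each entry stores (u, v, d), a
    # pixel location in the input image paired with one candidate depth.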

    def get_geometry(
        self,
        camera2lidar_rots,
        camera2lidar_trans,
        intrins,
        post_rots,
        post_trans,
        **kwargs,
    ):
        B, N, _ = camera2lidar_trans.shape

        # undo image-space augmentation (post-rotation/translation);
        # points: B x N x D x H x W x 3
        points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
        points = (
            torch.inverse(post_rots).view(B, N, 1, 1, 1, 3,
                                          3).matmul(points.unsqueeze(-1)))

        # lift to 3D: scale (u, v) by depth so that applying the inverse
        # intrinsics recovers metric camera coordinates
        points = torch.cat(
            (
                points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
                points[:, :, :, :, :, 2:3],
            ),
            5,
        )
        combine = camera2lidar_rots.matmul(torch.inverse(intrins))
        points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
        points += camera2lidar_trans.view(B, N, 1, 1, 1, 3)

        # apply lidar-space augmentation, if provided
        if 'extra_rots' in kwargs:
            extra_rots = kwargs['extra_rots']
            points = (
                extra_rots.view(B, 1, 1, 1, 1, 3,
                                3).repeat(1, N, 1, 1, 1, 1, 1).matmul(
                                    points.unsqueeze(-1)).squeeze(-1))
        if 'extra_trans' in kwargs:
            extra_trans = kwargs['extra_trans']
            points += extra_trans.view(B, 1, 1, 1, 1,
                                       3).repeat(1, N, 1, 1, 1, 1)

        return points
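
    # get_geometry produces one lidar-frame 3-D point per frustum cell, i.e.
    # a (B, N, D, fH, fW, 3) tensor that bev_pool() scatters into the grid.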

    def get_cam_feats(self, x):
        raise NotImplementedError

    def bev_pool(self, geom_feats, x):
        B, N, D, H, W, C = x.shape
        Nprime = B * N * D * H * W

        # flatten features
        x = x.reshape(Nprime, C)

        # flatten indices: metric coordinates -> integer voxel indices
        geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) /
                      self.dx).long()
        geom_feats = geom_feats.view(Nprime, 3)
        batch_ix = torch.cat([
            torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long)
            for ix in range(B)
        ])
        geom_feats = torch.cat((geom_feats, batch_ix), 1)

        # filter out points that fall outside the BEV grid
        kept = ((geom_feats[:, 0] >= 0)
                & (geom_feats[:, 0] < self.nx[0])
                & (geom_feats[:, 1] >= 0)
                & (geom_feats[:, 1] < self.nx[1])
                & (geom_feats[:, 2] >= 0)
                & (geom_feats[:, 2] < self.nx[2]))
        x = x[kept]
        geom_feats = geom_feats[kept]

        x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])

        # collapse the Z axis into the channel dimension
        final = torch.cat(x.unbind(dim=2), 1)

        return final
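
    # As used here, bev_pool returns (B, C, nz, nx[0], nx[1]); unbinding Z
    # folds it into channels, giving a (B, C * nz, nx[0], nx[1]) BEV map.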

    def forward(
        self,
        img,
        points,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        **kwargs,
    ):
        intrins = camera_intrinsics[..., :3, :3]
        post_rots = img_aug_matrix[..., :3, :3]
        post_trans = img_aug_matrix[..., :3, 3]
        camera2lidar_rots = camera2lidar[..., :3, :3]
        camera2lidar_trans = camera2lidar[..., :3, 3]

        extra_rots = lidar_aug_matrix[..., :3, :3]
        extra_trans = lidar_aug_matrix[..., :3, 3]

        geom = self.get_geometry(
            camera2lidar_rots,
            camera2lidar_trans,
            intrins,
            post_rots,
            post_trans,
            extra_rots=extra_rots,
            extra_trans=extra_trans,
        )

        x = self.get_cam_feats(img)
        x = self.bev_pool(geom, x)
        return x


@MODELS.register_module()
class LSSTransform(BaseViewTransform):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        downsample: int = 1,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
        )
        # A single 1x1 conv predicts D depth logits plus C context channels.
        self.depthnet = nn.Conv2d(in_channels, self.D + self.C, 1)
        if downsample > 1:
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            self.downsample = nn.Identity()

    def get_cam_feats(self, x):
        B, N, C, fH, fW = x.shape

        x = x.view(B * N, C, fH, fW)

        x = self.depthnet(x)
        # Split the prediction into a depth distribution and context features,
        # then take their outer product: (B*N, C, D, fH, fW).
        depth = x[:, :self.D].softmax(dim=1)
        x = depth.unsqueeze(1) * x[:, self.D:(self.D + self.C)].unsqueeze(2)

        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x

    def forward(self, *args, **kwargs):
        x = super().forward(*args, **kwargs)
        x = self.downsample(x)
        return x
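
# Minimal usage sketch (hypothetical values, for orientation only):
#
#   view_transform = LSSTransform(
#       in_channels=256,
#       out_channels=80,
#       image_size=(256, 704),
#       feature_size=(32, 88),
#       xbound=(-54.0, 54.0, 0.3),
#       ybound=(-54.0, 54.0, 0.3),
#       zbound=(-10.0, 10.0, 20.0),
#       dbound=(1.0, 60.0, 0.5),
#   )
#
# With dbound=(1.0, 60.0, 0.5), torch.arange yields D = 118 depth bins.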


class BaseDepthTransform(BaseViewTransform):

    def forward(
        self,
        img,
        points,
        lidar2image,
        cam_intrinsic,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        **kwargs,
    ):
        intrins = cam_intrinsic[..., :3, :3]
        post_rots = img_aug_matrix[..., :3, :3]
        post_trans = img_aug_matrix[..., :3, 3]
        camera2lidar_rots = camera2lidar[..., :3, :3]
        camera2lidar_trans = camera2lidar[..., :3, 3]

        # Rasterize the lidar points into one sparse depth map per camera.
        batch_size = len(points)
        depth = torch.zeros(batch_size, img.shape[1], 1,
                            *self.image_size).to(points[0].device)

        for b in range(batch_size):
            # Copy the xyz slice so the in-place updates below do not mutate
            # the caller's point cloud.
            cur_coords = points[b][:, :3].clone()
            cur_img_aug_matrix = img_aug_matrix[b]
            cur_lidar_aug_matrix = lidar_aug_matrix[b]
            cur_lidar2image = lidar2image[b]

            # undo lidar-space augmentation
            cur_coords -= cur_lidar_aug_matrix[:3, 3]
            cur_coords = torch.inverse(cur_lidar_aug_matrix[:3, :3]).matmul(
                cur_coords.transpose(1, 0))

            # lidar -> image projection
            cur_coords = cur_lidar2image[:, :3, :3].matmul(cur_coords)
            cur_coords += cur_lidar2image[:, :3, 3].reshape(-1, 3, 1)

            # perspective division to get 2d pixel coordinates
            dist = cur_coords[:, 2, :]
            cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
            cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]

            # apply image-space augmentation
            cur_coords = cur_img_aug_matrix[:, :3, :3].matmul(cur_coords)
            cur_coords += cur_img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
            cur_coords = cur_coords[:, :2, :].transpose(1, 2)

            # (x, y) -> (row, col) for indexing into the depth map
            cur_coords = cur_coords[..., [1, 0]]

            on_img = ((cur_coords[..., 0] < self.image_size[0])
                      & (cur_coords[..., 0] >= 0)
                      & (cur_coords[..., 1] < self.image_size[1])
                      & (cur_coords[..., 1] >= 0))
            for c in range(on_img.shape[0]):
                masked_coords = cur_coords[c, on_img[c]].long()
                masked_dist = dist[c, on_img[c]]
                depth = depth.to(masked_dist.dtype)
                depth[b, c, 0, masked_coords[:, 0],
                      masked_coords[:, 1]] = masked_dist

        extra_rots = lidar_aug_matrix[..., :3, :3]
        extra_trans = lidar_aug_matrix[..., :3, 3]
        geom = self.get_geometry(
            camera2lidar_rots,
            camera2lidar_trans,
            intrins,
            post_rots,
            post_trans,
            extra_rots=extra_rots,
            extra_trans=extra_trans,
        )

        x = self.get_cam_feats(img, depth)
        x = self.bev_pool(geom, x)
        return x
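
# Subclasses implement get_cam_feats(img, depth) to consume the sparse depth
# map; DepthLSSTransform below encodes it and fuses it with image features.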


@MODELS.register_module()
class DepthLSSTransform(BaseDepthTransform):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_size: Tuple[int, int],
        feature_size: Tuple[int, int],
        xbound: Tuple[float, float, float],
        ybound: Tuple[float, float, float],
        zbound: Tuple[float, float, float],
        dbound: Tuple[float, float, float],
        downsample: int = 1,
    ) -> None:
        """Compared with `LSSTransform`, `DepthLSSTransform` adds sparse depth
        information from lidar points into the inputs of the `depthnet`."""
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            image_size=image_size,
            feature_size=feature_size,
            xbound=xbound,
            ybound=ybound,
            zbound=zbound,
            dbound=dbound,
        )
        # Encode the sparse depth map. The two strided convs reduce its
        # resolution by 8x overall; get_cam_feats concatenates the result
        # with the image features, so feature_size must be image_size / 8.
        self.dtransform = nn.Sequential(
            nn.Conv2d(1, 8, 1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.Conv2d(8, 32, 5, stride=4, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        # Predict D depth logits plus C context channels from the fused
        # image + depth features.
        self.depthnet = nn.Sequential(
            nn.Conv2d(in_channels + 64, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Conv2d(in_channels, self.D + self.C, 1),
        )
        if downsample > 1:
            assert downsample == 2, downsample
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    3,
                    stride=downsample,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
                nn.Conv2d(
                    out_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            self.downsample = nn.Identity()

    def get_cam_feats(self, x, d):
        B, N, C, fH, fW = x.shape

        d = d.view(B * N, *d.shape[2:])
        x = x.view(B * N, C, fH, fW)

        # Encode depth, concatenate with image features, then predict the
        # depth distribution and context jointly.
        d = self.dtransform(d)
        x = torch.cat([d, x], dim=1)
        x = self.depthnet(x)

        depth = x[:, :self.D].softmax(dim=1)
        x = depth.unsqueeze(1) * x[:, self.D:(self.D + self.C)].unsqueeze(2)

        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x

    def forward(self, *args, **kwargs):
        x = super().forward(*args, **kwargs)
        x = self.downsample(x)
        return x
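
# Shape walk-through (hypothetical sizes): with image_size=(256, 704) and
# feature_size=(32, 88), the (B, N, 1, 256, 704) sparse depth map passes
# through dtransform to (B*N, 64, 32, 88), is concatenated with the image
# features, and get_cam_feats returns (B, N, D, 32, 88, C) for bev_pool.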
|