# modified from https://github.com/mit-han-lab/bevfusion
from typing import Any, Dict

import numpy as np
import torch
from mmcv.transforms import BaseTransform
from PIL import Image

from mmdet3d.datasets import GlobalRotScaleTrans
from mmdet3d.registry import TRANSFORMS


@TRANSFORMS.register_module()
class ImageAug3D(BaseTransform):
    """Resize, crop, flip and rotate multi-view images, recording the
    resulting 2D affine of each view in a 4x4 `img_aug_matrix`."""

    def __init__(self, final_dim, resize_lim, bot_pct_lim, rot_lim, rand_flip,
                 is_train):
        self.final_dim = final_dim
        self.resize_lim = resize_lim
        self.bot_pct_lim = bot_pct_lim
        self.rand_flip = rand_flip
        self.rot_lim = rot_lim
        self.is_train = is_train

    def sample_augmentation(self, results):
        H, W = results['ori_shape']
        fH, fW = self.final_dim
        if self.is_train:
            # randomly sample scale, crop window, flip and rotation
            resize = np.random.uniform(*self.resize_lim)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int(
                (1 - np.random.uniform(*self.bot_pct_lim)) * newH) - fH
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            if self.rand_flip and np.random.choice([0, 1]):
                flip = True
            rotate = np.random.uniform(*self.rot_lim)
        else:
            # deterministic centre values at test time
            resize = np.mean(self.resize_lim)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int((1 - np.mean(self.bot_pct_lim)) * newH) - fH
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            rotate = 0
        return resize, resize_dims, crop, flip, rotate

    def img_transform(self, img, rotation, translation, resize, resize_dims,
                      crop, flip, rotate):
        # adjust image
        img = Image.fromarray(img.astype('uint8'), mode='RGB')
        img = img.resize(resize_dims)
        img = img.crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)

        # post-homography transformation: keep the 2x2 rotation and the
        # translation vector consistent with the image edits above
        rotation *= resize
        translation -= torch.Tensor(crop[:2])
        if flip:
            # mirror about the vertical centre line of the cropped image
            A = torch.Tensor([[-1, 0], [0, 1]])
            b = torch.Tensor([crop[2] - crop[0], 0])
            rotation = A.matmul(rotation)
            translation = A.matmul(translation) + b
        theta = rotate / 180 * np.pi
        A = torch.Tensor([
            [np.cos(theta), np.sin(theta)],
            [-np.sin(theta), np.cos(theta)],
        ])
        # rotate about the crop centre c: x' = A(x - c) + c = Ax + (c - Ac)
        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        b = A.matmul(-b) + b
        rotation = A.matmul(rotation)
        translation = A.matmul(translation) + b
        return img, rotation, translation

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        imgs = data['img']
        new_imgs = []
        transforms = []
        for img in imgs:
            resize, resize_dims, crop, flip, rotate = self.sample_augmentation(
                data)
            post_rot = torch.eye(2)
            post_tran = torch.zeros(2)
            new_img, rotation, translation = self.img_transform(
                img,
                post_rot,
                post_tran,
                resize=resize,
                resize_dims=resize_dims,
                crop=crop,
                flip=flip,
                rotate=rotate,
            )
            transform = torch.eye(4)
            transform[:2, :2] = rotation
            transform[:2, 3] = translation
            new_imgs.append(np.array(new_img).astype(np.float32))
            transforms.append(transform.numpy())
        data['img'] = new_imgs
        # update the calibration matrices
        data['img_aug_matrix'] = transforms
        return data
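

# Configuration sketch for ImageAug3D (the values are assumptions modelled on
# common BEVFusion-style nuScenes pipelines, not defined in this file):
#   dict(type='ImageAug3D', final_dim=(256, 704), resize_lim=(0.38, 0.55),
#        bot_pct_lim=(0.0, 0.0), rot_lim=(-5.4, 5.4), rand_flip=True,
#        is_train=True)
# Each view gets its own `img_aug_matrix`, so downstream view transforms can
# account for the augmentation when lifting image features.
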
@TRANSFORMS.register_module()
class BEVFusionRandomFlip3D:
    """Compared with `RandomFlip3D`, this class directly records the lidar
    augmentation matrix in `data`."""

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        flip_horizontal = np.random.choice([0, 1])
        flip_vertical = np.random.choice([0, 1])

        rotation = np.eye(3)
        if flip_horizontal:
            rotation = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ rotation
            if 'points' in data:
                data['points'].flip('horizontal')
            if 'gt_bboxes_3d' in data:
                data['gt_bboxes_3d'].flip('horizontal')
            if 'gt_masks_bev' in data:
                data['gt_masks_bev'] = data['gt_masks_bev'][:, :, ::-1].copy()

        if flip_vertical:
            rotation = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ rotation
            if 'points' in data:
                data['points'].flip('vertical')
            if 'gt_bboxes_3d' in data:
                data['gt_bboxes_3d'].flip('vertical')
            if 'gt_masks_bev' in data:
                data['gt_masks_bev'] = data['gt_masks_bev'][:, ::-1, :].copy()

        if 'lidar_aug_matrix' not in data:
            data['lidar_aug_matrix'] = np.eye(4)
        data['lidar_aug_matrix'][:3, :] = rotation @ data[
            'lidar_aug_matrix'][:3, :]
        return data
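

# Usage sketch: this transform takes no constructor arguments, so a pipeline
# entry is simply dict(type='BEVFusionRandomFlip3D'). Each call independently
# flips the scene across the horizontal and vertical BEV axes with
# probability 0.5 and left-multiplies the matching reflection into
# data['lidar_aug_matrix'].
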
@TRANSFORMS.register_module()
class BEVFusionGlobalRotScaleTrans(GlobalRotScaleTrans):
    """Compared with `GlobalRotScaleTrans`, the augmentation order in this
    class is rotation, translation and scaling (RTS)."""

    def transform(self, input_dict: dict) -> dict:
        """Rotate, translate and scale bounding boxes and points.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after the augmentation; 'points', 'pcd_rotation',
            'pcd_scale_factor', 'pcd_trans' and `gt_bboxes_3d` are updated
            in the result dict.
        """
        if 'transformation_3d_flow' not in input_dict:
            input_dict['transformation_3d_flow'] = []

        # apply rotation, then translation, then scaling (RTS)
        self._rot_bbox_points(input_dict)
        if 'pcd_scale_factor' not in input_dict:
            self._random_scale(input_dict)
        self._trans_bbox_points(input_dict)
        self._scale_bbox_points(input_dict)
        input_dict['transformation_3d_flow'].extend(['R', 'T', 'S'])

        # fold the sampled rotation, translation and scale into one 4x4
        # matrix; mmdet3d rotates points as `points @ R`, so the matrix
        # acting on column vectors is R.T
        lidar_augs = np.eye(4)
        lidar_augs[:3, :3] = input_dict['pcd_rotation'].T * input_dict[
            'pcd_scale_factor']
        lidar_augs[:3, 3] = input_dict['pcd_trans'] * \
            input_dict['pcd_scale_factor']

        if 'lidar_aug_matrix' not in input_dict:
            input_dict['lidar_aug_matrix'] = np.eye(4)
        input_dict[
            'lidar_aug_matrix'] = lidar_augs @ input_dict['lidar_aug_matrix']
        return input_dict
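

# Configuration sketch (the ranges below are assumptions modelled on common
# BEVFusion nuScenes settings, not defined in this file):
#   dict(type='BEVFusionGlobalRotScaleTrans',
#        rot_range=[-0.78539816, 0.78539816],
#        scale_ratio_range=[0.9, 1.1], translation_std=0.5)
# Since scaling is applied last, the recorded matrix maps a raw lidar point x
# to s * (R @ x + t): rotation, then translation, then scale.
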
@TRANSFORMS.register_module()
class GridMask(BaseTransform):
    """GridMask augmentation: zero out (or, with mode=1, keep only) a regular
    grid of square regions in each image."""

    def __init__(
        self,
        use_h,
        use_w,
        max_epoch,
        rotate=1,
        offset=False,
        ratio=0.5,
        mode=0,
        prob=1.0,
        fixed_prob=False,
    ):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.prob = prob
        self.epoch = None
        self.max_epoch = max_epoch
        self.fixed_prob = fixed_prob

    def set_epoch(self, epoch):
        self.epoch = epoch
        if not self.fixed_prob:
            self.set_prob(self.epoch, self.max_epoch)

    def set_prob(self, epoch, max_epoch):
        # ramp the apply probability linearly from 0 to `prob` over training
        self.prob = self.st_prob * epoch / max_epoch

    def transform(self, results):
        if np.random.rand() > self.prob:
            return results

        imgs = results['img']
        h = imgs[0].shape[0]
        w = imgs[0].shape[1]

        # draw the mask on a 1.5x canvas so that a random rotation still
        # covers the whole image after the centre crop below
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        # grid period d and the masked length inside each period
        d = np.random.randint(2, min(h, w))
        if self.ratio == 1:
            length = np.random.randint(1, d)
        else:
            length = min(max(int(d * self.ratio + 0.5), 1), d - 1)

        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + length, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + length, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        # centre-crop the rotated mask back to the image size
        mask = mask[(hh - h) // 2:(hh - h) // 2 + h,
                    (ww - w) // 2:(ww - w) // 2 + w]

        mask = mask.astype(np.float32)
        mask = mask[:, :, None]
        if self.mode == 1:
            # mode 1 keeps the grid cells and masks the background instead
            mask = 1 - mask

        if self.offset:
            # stay in NumPy so the offset broadcasts against the (H, W, C)
            # images; masked pixels get uniform noise in [-1, 1) instead of
            # zeros
            offset = (2 * (np.random.rand(h, w, 1) - 0.5)).astype(np.float32)
            offset = (1 - mask) * offset
            imgs = [x * mask + offset for x in imgs]
        else:
            imgs = [x * mask for x in imgs]

        results.update(img=imgs)
        return results
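

# Configuration sketch for GridMask (values are assumptions for illustration,
# not defaults defined in this file):
#   dict(type='GridMask', use_h=True, use_w=True, max_epoch=24, rotate=1,
#        offset=False, ratio=0.5, mode=1, prob=0.7)
# `prob` only ramps with training progress if something (e.g. a custom hook)
# calls set_epoch() each epoch; with fixed_prob=True the ramp is disabled.


if __name__ == '__main__':
    # Minimal smoke test with dummy data; the shapes and parameter values
    # are assumptions for illustration, not part of the original module.
    sample = {
        'img': [np.zeros((900, 1600, 3), dtype=np.uint8) for _ in range(6)],
        'ori_shape': (900, 1600),
    }
    aug = ImageAug3D(
        final_dim=(256, 704),
        resize_lim=(0.38, 0.55),
        bot_pct_lim=(0.0, 0.0),
        rot_lim=(-5.4, 5.4),
        rand_flip=True,
        is_train=True)
    sample = aug.transform(sample)
    assert sample['img'][0].shape == (256, 704, 3)
    assert sample['img_aug_matrix'][0].shape == (4, 4)

    grid = GridMask(use_h=True, use_w=True, max_epoch=24, mode=1, prob=1.0)
    sample = grid.transform(sample)
    assert sample['img'][0].shape == (256, 704, 3)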