|
|
|
from typing import Any, Dict |
|
|
|
import numpy as np |
|
import torch |
|
from mmcv.transforms import BaseTransform |
|
from PIL import Image |
|
|
|
from mmdet3d.datasets import GlobalRotScaleTrans |
|
from mmdet3d.registry import TRANSFORMS |
|
|
|
|
|
@TRANSFORMS.register_module()
class ImageAug3D(BaseTransform):
    """Resize, crop, flip and rotate multi-view images, while recording the
    equivalent 4x4 image augmentation matrix for each view.

    Args:
        final_dim (tuple): Target (height, width) of the augmented images.
        resize_lim (tuple): Lower/upper bound of the random resize ratio.
        bot_pct_lim (tuple): Lower/upper bound of the fraction cropped away
            from the image bottom.
        rot_lim (tuple): Lower/upper bound (degrees) of the random rotation.
        rand_flip (bool): Whether horizontal flipping may be applied.
        is_train (bool): Sample random parameters when True; use the
            deterministic midpoint parameters otherwise.
    """

    def __init__(self, final_dim, resize_lim, bot_pct_lim, rot_lim, rand_flip,
                 is_train):
        self.final_dim = final_dim
        self.resize_lim = resize_lim
        self.bot_pct_lim = bot_pct_lim
        self.rand_flip = rand_flip
        self.rot_lim = rot_lim
        self.is_train = is_train

    def sample_augmentation(self, results):
        """Sample resize/crop/flip/rotate parameters for one image.

        Args:
            results (dict): Must contain ``'ori_shape'`` as (H, W).

        Returns:
            tuple: ``(resize, resize_dims, crop, flip, rotate)`` where
            ``crop`` is a (left, top, right, bottom) box in resized-image
            pixels and ``rotate`` is in degrees.
        """
        src_h, src_w = results['ori_shape']
        out_h, out_w = self.final_dim
        if self.is_train:
            resize = np.random.uniform(*self.resize_lim)
            resize_dims = (int(src_w * resize), int(src_h * resize))
            new_w, new_h = resize_dims
            # Crop so the bottom `bot_pct` fraction is discarded; the crop
            # box has exactly the final output size.
            crop_h = int(
                (1 - np.random.uniform(*self.bot_pct_lim)) * new_h) - out_h
            crop_w = int(np.random.uniform(0, max(0, new_w - out_w)))
            crop = (crop_w, crop_h, crop_w + out_w, crop_h + out_h)
            flip = bool(self.rand_flip and np.random.choice([0, 1]))
            rotate = np.random.uniform(*self.rot_lim)
        else:
            # Deterministic evaluation: midpoint resize, centered crop,
            # no flip, no rotation.
            resize = np.mean(self.resize_lim)
            resize_dims = (int(src_w * resize), int(src_h * resize))
            new_w, new_h = resize_dims
            crop_h = int((1 - np.mean(self.bot_pct_lim)) * new_h) - out_h
            crop_w = int(max(0, new_w - out_w) / 2)
            crop = (crop_w, crop_h, crop_w + out_w, crop_h + out_h)
            flip = False
            rotate = 0
        return resize, resize_dims, crop, flip, rotate

    def img_transform(self, img, rotation, translation, resize, resize_dims,
                      crop, flip, rotate):
        """Apply the sampled augmentation to one image and accumulate the
        matching 2D rotation/translation in pixel coordinates.

        Returns:
            tuple: ``(pil_img, rotation, translation)`` — the augmented PIL
            image plus the updated 2x2 rotation and 2-vector translation.
        """
        pil_img = Image.fromarray(img.astype('uint8'), mode='RGB')
        pil_img = pil_img.resize(resize_dims)
        pil_img = pil_img.crop(crop)
        if flip:
            pil_img = pil_img.transpose(method=Image.FLIP_LEFT_RIGHT)
        pil_img = pil_img.rotate(rotate)

        # Mirror each image operation in the accumulated homography.
        rotation *= resize
        translation -= torch.Tensor(crop[:2])
        if flip:
            flip_mat = torch.Tensor([[-1, 0], [0, 1]])
            flip_off = torch.Tensor([crop[2] - crop[0], 0])
            rotation = flip_mat.matmul(rotation)
            translation = flip_mat.matmul(translation) + flip_off
        # PIL rotates about the image center, so shift, rotate, shift back.
        theta = rotate / 180 * np.pi
        rot_mat = torch.Tensor([
            [np.cos(theta), np.sin(theta)],
            [-np.sin(theta), np.cos(theta)],
        ])
        center = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        center = rot_mat.matmul(-center) + center
        rotation = rot_mat.matmul(rotation)
        translation = rot_mat.matmul(translation) + center

        return pil_img, rotation, translation

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Augment every image in ``data['img']`` and store one 4x4
        augmentation matrix per view under ``data['img_aug_matrix']``."""
        aug_imgs = []
        aug_mats = []
        for img in data['img']:
            resize, resize_dims, crop, flip, rotate = \
                self.sample_augmentation(data)
            post_rot = torch.eye(2)
            post_tran = torch.zeros(2)
            new_img, rotation, translation = self.img_transform(
                img,
                post_rot,
                post_tran,
                resize=resize,
                resize_dims=resize_dims,
                crop=crop,
                flip=flip,
                rotate=rotate,
            )
            # Embed the 2D pixel transform into a homogeneous 4x4 matrix.
            mat = torch.eye(4)
            mat[:2, :2] = rotation
            mat[:2, 3] = translation
            aug_imgs.append(np.array(new_img).astype(np.float32))
            aug_mats.append(mat.numpy())
        data['img'] = aug_imgs
        data['img_aug_matrix'] = aug_mats
        return data
|
|
|
|
|
@TRANSFORMS.register_module()
class BEVFusionRandomFlip3D:
    """Compared with `RandomFlip3D`, this class directly records the lidar
    augmentation matrix in the `data`."""

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Independently flip points, boxes and BEV masks along each BEV
        axis with probability 0.5, folding the combined flip into
        ``data['lidar_aug_matrix']``."""
        do_horizontal = np.random.choice([0, 1])
        do_vertical = np.random.choice([0, 1])

        rot = np.eye(3)
        if do_horizontal:
            # Horizontal flip negates the y-axis.
            rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ rot
            if 'points' in data:
                data['points'].flip('horizontal')
            if 'gt_bboxes_3d' in data:
                data['gt_bboxes_3d'].flip('horizontal')
            if 'gt_masks_bev' in data:
                data['gt_masks_bev'] = data['gt_masks_bev'][:, :, ::-1].copy()

        if do_vertical:
            # Vertical flip negates the x-axis.
            rot = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ rot
            if 'points' in data:
                data['points'].flip('vertical')
            if 'gt_bboxes_3d' in data:
                data['gt_bboxes_3d'].flip('vertical')
            if 'gt_masks_bev' in data:
                data['gt_masks_bev'] = data['gt_masks_bev'][:, ::-1, :].copy()

        # Left-compose the flip onto any existing lidar augmentation.
        if 'lidar_aug_matrix' not in data:
            data['lidar_aug_matrix'] = np.eye(4)
        data['lidar_aug_matrix'][:3, :] = \
            rot @ data['lidar_aug_matrix'][:3, :]
        return data
|
|
|
|
|
@TRANSFORMS.register_module()
class BEVFusionGlobalRotScaleTrans(GlobalRotScaleTrans):
    """Compared with `GlobalRotScaleTrans`, the augmentation order in this
    class is rotation, translation and scaling (RTS)."""

    def transform(self, input_dict: dict) -> dict:
        """Rotate, translate and scale points and boxes (in that order) and
        record the combined transform in ``input_dict['lidar_aug_matrix']``.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after augmentation; 'points', 'pcd_rotation',
            'pcd_scale_factor', 'pcd_trans' and 'gt_bboxes_3d' are updated
            in the result dict.
        """
        input_dict.setdefault('transformation_3d_flow', [])

        # Apply R -> T -> S (the parent class applies R -> S -> T).
        self._rot_bbox_points(input_dict)
        if 'pcd_scale_factor' not in input_dict:
            self._random_scale(input_dict)
        self._trans_bbox_points(input_dict)
        self._scale_bbox_points(input_dict)

        input_dict['transformation_3d_flow'].extend(['R', 'T', 'S'])

        # Fold rotation+scale into the upper-left 3x3 and the scaled
        # translation into the last column of a homogeneous 4x4 matrix.
        scale = input_dict['pcd_scale_factor']
        aug = np.eye(4)
        aug[:3, :3] = input_dict['pcd_rotation'].T * scale
        aug[:3, 3] = input_dict['pcd_trans'] * scale

        if 'lidar_aug_matrix' not in input_dict:
            input_dict['lidar_aug_matrix'] = np.eye(4)
        input_dict['lidar_aug_matrix'] = \
            aug @ input_dict['lidar_aug_matrix']

        return input_dict
|
|
|
|
|
@TRANSFORMS.register_module()
class GridMask(BaseTransform):
    """GridMask augmentation: zero out a regular grid of stripes on every
    image in ``results['img']``.

    Args:
        use_h (bool): Mask horizontal stripes of the grid.
        use_w (bool): Mask vertical stripes of the grid.
        max_epoch (int): Total number of epochs, used to linearly ramp the
            application probability up to ``prob``.
        rotate (int): Exclusive upper bound (degrees) of the random grid
            rotation. Defaults to 1 (i.e. no rotation).
        offset (bool): If True, fill masked pixels with random noise in
            [-1, 1) instead of zeros. Defaults to False.
        ratio (float): Ratio of stripe length to grid cell size.
            Defaults to 0.5.
        mode (int): If 1, invert the mask (drop everything except the
            grid stripes). Defaults to 0.
        prob (float): Maximum probability of applying the augmentation.
            Defaults to 1.0.
        fixed_prob (bool): If True, keep the probability fixed at ``prob``
            instead of ramping it with the epoch. Defaults to False.
    """

    def __init__(
        self,
        use_h,
        use_w,
        max_epoch,
        rotate=1,
        offset=False,
        ratio=0.5,
        mode=0,
        prob=1.0,
        fixed_prob=False,
    ):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        # `st_prob` is the target (maximum) probability; `prob` is the
        # currently active one, possibly ramped by `set_epoch`.
        self.st_prob = prob
        self.prob = prob
        self.epoch = None
        self.max_epoch = max_epoch
        self.fixed_prob = fixed_prob

    def set_epoch(self, epoch):
        """Record the current epoch and, unless ``fixed_prob`` is set,
        ramp the application probability towards ``st_prob``."""
        self.epoch = epoch
        if not self.fixed_prob:
            self.set_prob(self.epoch, self.max_epoch)

    def set_prob(self, epoch, max_epoch):
        """Linearly scale the probability with training progress.

        Bug fix: this previously ignored both arguments and read
        ``self.epoch`` / ``self.max_epoch`` instead, which raised a
        TypeError (``None / int``) when called before ``set_epoch`` and
        silently ignored any caller-supplied values.
        """
        self.prob = self.st_prob * epoch / max_epoch

    def transform(self, results):
        """Apply GridMask to every image in ``results['img']`` with
        probability ``self.prob``; otherwise return ``results`` unchanged.
        """
        if np.random.rand() > self.prob:
            return results
        imgs = results['img']
        h = imgs[0].shape[0]
        w = imgs[0].shape[1]
        # Bounds for the random grid cell size.
        self.d1 = 2
        self.d2 = min(h, w)
        # Build the mask on a 1.5x canvas so rotation does not pull
        # out-of-bounds values into the final center crop.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(self.d1, self.d2)  # grid cell size
        if self.ratio == 1:
            self.length = np.random.randint(1, d)
        else:
            # Stripe length, clamped to [1, d - 1].
            self.length = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        # Random phase offsets of the grid.
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.length, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.length, ww)
                mask[:, s:t] *= 0

        # Rotate the grid by a random angle in [0, self.rotate) degrees,
        # then crop back to the original image size.
        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        mask = mask[(hh - h) // 2:(hh - h) // 2 + h,
                    (ww - w) // 2:(ww - w) // 2 + w]

        mask = mask.astype(np.float32)
        mask = mask[:, :, None]
        if self.mode == 1:
            # Inverted mode: keep only the grid stripes.
            mask = 1 - mask

        if self.offset:
            # NOTE(review): `offset` is a torch.Tensor while `mask` and the
            # images are numpy arrays; this mixes torch/numpy arithmetic —
            # confirm this path is actually exercised before relying on it.
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float()
            offset = (1 - mask) * offset
            imgs = [x * mask + offset for x in imgs]
        else:
            imgs = [x * mask for x in imgs]

        results.update(img=imgs)
        return results
|
|