|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import ConvModule |
|
from mmengine.model import BaseModule |
|
|
|
from mmdet3d.registry import MODELS |
|
|
|
|
|
@MODELS.register_module()
class GeneralizedLSSFPN(BaseModule):
    """Top-down FPN neck that fuses adjacent levels by concatenation.

    Starting from the coarsest used level, each level's feature map is
    upsampled to the spatial size of the level below it with
    ``F.interpolate``. It is then concatenated with that lower level along
    ``dim=1`` and passed through a 1x1 ``ConvModule`` followed by a 3x3
    ``ConvModule``, both of which output ``out_channels`` channels. The fused
    result becomes the upper input for the next, finer level.

    Args:
        in_channels (list[int]): Channel count of each input level.
        out_channels (int): Channel count of every fused output level.
        num_outs (int): Expected number of outputs. Only validated when
            ``end_level`` is given explicitly.
        start_level (int): Index of the first input level to fuse.
            Defaults to 0.
        end_level (int): Index of the topmost input level that is used. That
            level is fused into the level below it but is not emitted
            itself. ``-1`` means the last input level. Defaults to -1.
        no_norm_on_lateral (bool): If True, the 1x1 convolutions are built
            with ``norm_cfg=None``. Defaults to False.
        conv_cfg (dict, optional): Forwarded to every ``ConvModule``.
        norm_cfg (dict): Normalization config forwarded to ``ConvModule``.
        act_cfg (dict): Activation config forwarded to ``ConvModule``.
        upsample_cfg (dict): Extra keyword arguments for ``F.interpolate``.
            A shallow copy is stored on the instance.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_outs,
        start_level=0,
        end_level=-1,
        no_norm_on_lateral=False,
        conv_cfg=None,
        norm_cfg=dict(type='BN2d'),
        act_cfg=dict(type='ReLU'),
        upsample_cfg=dict(mode='bilinear', align_corners=True),
    ) -> None:
        super().__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        # Copy so instances never alias the shared default dict.
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            # The last input level is only consumed as the upper input of
            # the level below it, so fusion stops one level short of it.
            self.backbone_end_level = self.num_ins - 1
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            # The topmost pair concatenates the raw input of level i + 1.
            # Every lower pair concatenates an already-fused map, which has
            # ``out_channels`` channels.
            upper_channels = (
                in_channels[i + 1]
                if i == self.backbone_end_level - 1 else out_channels)
            l_conv = ConvModule(
                in_channels[i] + upper_channels,
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False,
            )
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False,
            )

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, inputs):
        """Fuse the multi-level inputs top-down.

        Args:
            inputs (list[Tensor] | tuple[Tensor]): One feature map per entry
                of ``in_channels``. Presumably NCHW (only ``shape[2:]`` is
                used, as the interpolation target size); TODO confirm with
                callers.

        Returns:
            tuple[Tensor]: Fused maps for levels ``start_level`` through
            ``backbone_end_level - 1``, finest first, each with
            ``out_channels`` channels.
        """
        assert len(inputs) == len(self.in_channels)

        # Take exactly the levels the convs were built for:
        # start_level .. backbone_end_level (inclusive). Slicing, rather than
        # offsetting every index of the full input list by start_level,
        # stays in range when start_level > 0 and honors an explicit
        # end_level. lateral_convs[i] then pairs with laterals[i].
        laterals = list(inputs[self.start_level:self.backbone_end_level + 1])

        # Walk from the coarsest pair down, so each level is fused with the
        # already-fused level above it.
        used_backbone_levels = len(laterals) - 1
        for i in range(used_backbone_levels - 1, -1, -1):
            x = F.interpolate(
                laterals[i + 1],
                size=laterals[i].shape[2:],
                **self.upsample_cfg,
            )
            laterals[i] = torch.cat([laterals[i], x], dim=1)
            laterals[i] = self.lateral_convs[i](laterals[i])
            laterals[i] = self.fpn_convs[i](laterals[i])

        # The topmost level served only as an upper input and is not emitted.
        return tuple(laterals[:used_backbone_levels])
|
|