# modify from https://github.com/mit-han-lab/bevfusion
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet3d.registry import MODELS


@MODELS.register_module()
class GeneralizedLSSFPN(BaseModule):
    """Top-down feature fusion neck: upsample -> cat -> conv1x1 -> conv3x3.

    Starting from the last selected level, each level's feature map is
    concatenated (along dim 1) with the next level's map resized to its
    spatial size, then passed through a 1x1 ``ConvModule`` followed by a 3x3
    ``ConvModule``. The last selected level only serves as the source of the
    top-down path and is not part of the output.

    Args:
        in_channels (list[int]): Channel count of every input level. Must be
            a ``list``.
        out_channels (int): Channel count produced by every fused level.
        num_outs (int): Stored on the module. It is only validated when
            ``end_level != -1``, where it must equal
            ``end_level - start_level``.
        start_level (int): Index of the first input level to use.
        end_level (int): Index of the last input level to use; ``-1`` means
            the last entry of ``in_channels``.
        no_norm_on_lateral (bool): If True, the 1x1 convs are built without
            a norm layer.
        conv_cfg (dict | None): Forwarded to ``ConvModule``.
        norm_cfg (dict | None): Forwarded to ``ConvModule``.
        act_cfg (dict | None): Forwarded to ``ConvModule``.
        upsample_cfg (dict): Keyword arguments for ``F.interpolate``; copied
            so the caller's dict is not shared.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_outs,
        start_level=0,
        end_level=-1,
        no_norm_on_lateral=False,
        conv_cfg=None,
        norm_cfg=dict(type='BN2d'),
        act_cfg=dict(type='ReLU'),
        upsample_cfg=dict(mode='bilinear', align_corners=True),
    ) -> None:
        super().__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            self.backbone_end_level = self.num_ins - 1
            # assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        # One (1x1, 3x3) conv pair per fused level. Convs are appended in
        # level order, so conv index j belongs to level ``start_level + j``.
        for i in range(self.start_level, self.backbone_end_level):
            # The level right below the last one is concatenated with the raw
            # last-level input; every other level is concatenated with an
            # already fused map that has ``out_channels`` channels.
            is_top = i == self.backbone_end_level - 1
            top_down_channels = in_channels[i + 1] if is_top else out_channels
            l_conv = ConvModule(
                in_channels[i] + top_down_channels,
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False,
            )
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False,
            )

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, inputs):
        """Fuse the selected input levels top-down.

        Args:
            inputs (Sequence[torch.Tensor]): One feature map per entry of
                ``in_channels``. Assumed to be (N, C, H, W) tensors, since
                ``shape[2:]`` is used as the resize target -- TODO confirm
                against the backbone.

        Returns:
            tuple[torch.Tensor]: The fused maps for levels ``start_level`` up
            to (but excluding) the last selected level, in level order.

        Raises:
            AssertionError: If ``len(inputs) != len(self.in_channels)``.
        """
        # upsample -> cat -> conv1x1 -> conv3x3
        assert len(inputs) == len(self.in_channels)

        # build laterals
        # Keep only the configured level range so that local index j lines up
        # with ``lateral_convs[j]`` / ``fpn_convs[j]``. Indexing
        # ``inputs[i + start_level]`` over ``range(len(inputs))`` ran past the
        # end of ``inputs`` whenever ``start_level > 0``.
        laterals = list(inputs[self.start_level:self.backbone_end_level + 1])

        # build top-down path
        used_backbone_levels = len(laterals) - 1
        for i in range(used_backbone_levels - 1, -1, -1):
            x = F.interpolate(
                laterals[i + 1],
                size=laterals[i].shape[2:],
                **self.upsample_cfg,
            )
            laterals[i] = torch.cat([laterals[i], x], dim=1)
            laterals[i] = self.lateral_convs[i](laterals[i])
            laterals[i] = self.fpn_convs[i](laterals[i])

        # build outputs
        # The last selected level is only a top-down source; it is not emitted.
        return tuple(laterals[:max(used_backbone_levels, 0)])