import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md


class DecoderBlock(nn.Module):
    """One UNet++ decoder node: 2x nearest-neighbor upsample, optional skip
    concatenation with attention, then two 3x3 Conv2dReLU layers.

    Args:
        in_channels: channels of the tensor coming from the previous decoder node.
        skip_channels: total channels of the concatenated skip features (0 if none).
        out_channels: channels produced by this block.
        use_batchnorm: forwarded to md.Conv2dReLU.
        attention_type: forwarded to md.Attention (None -> identity, per md.Attention).
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        # attention1 is sized for the concatenated (x + skip) tensor, so it is
        # only applied when a skip connection is actually provided in forward().
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        # Upsample to match the spatial resolution of the skip features.
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class CenterBlock(nn.Sequential):
    """Optional bottleneck block (two 3x3 Conv2dReLU) applied at the encoder head."""

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = md.Conv2dReLU(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        super().__init__(conv1, conv2)


class UnetPlusPlusDecoder(nn.Module):
    """UNet++ decoder: a dense grid of DecoderBlocks x_{depth}_{layer} where each
    node consumes the upsampled output of the node below and the concatenation of
    all same-resolution intermediate outputs plus the encoder skip (nested dense
    skip connections).

    Args:
        encoder_channels: per-stage encoder output channels, shallow to deep.
        decoder_channels: output channels for each decoder stage, deep to shallow.
        n_blocks: decoder depth; must equal len(decoder_channels).
        use_batchnorm: forwarded to every DecoderBlock / CenterBlock.
        attention_type: forwarded to every DecoderBlock.
        center: if True, apply a CenterBlock to the deepest encoder feature.
            NOTE(review): self.center is built but not invoked in forward() —
            matches the code as written; confirm against upstream intent.

    Raises:
        ValueError: if n_blocks != len(decoder_channels).
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        # Computing blocks' input and output channels.
        head_channels = encoder_channels[0]
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels
        if center:
            self.center = CenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

        # Combine decoder keyword arguments shared by every block.
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)

        # Build the dense grid of blocks. Key x_{depth_idx}_{layer_idx}:
        # depth_idx is the position along a layer, layer_idx the resolution level.
        blocks = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    # First node of the layer: fed by the deeper decoder output,
                    # skip is the concatenation of (layer_idx + 1) same-size features.
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    # Intermediate nodes keep the encoder skip channel width and
                    # see one fewer concatenated feature per extra depth step.
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
        # Final, shallowest block has no skip connection.
        blocks[f'x_{0}_{len(self.in_channels)-1}'] = \
            DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
        self.blocks = nn.ModuleDict(blocks)
        self.depth = len(self.in_channels) - 1

    def forward(self, *features):
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        # Start building dense connections; dense_x caches every intermediate node.
        dense_x = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    # First column: decode directly from adjacent encoder features.
                    output = self.blocks[f'x_{depth_idx}_{depth_idx}'](
                        features[depth_idx], features[depth_idx + 1]
                    )
                    dense_x[f'x_{depth_idx}_{depth_idx}'] = output
                else:
                    dense_l_i = depth_idx + layer_idx
                    # Gather all previously computed same-resolution outputs plus
                    # the encoder feature, and concatenate along channels.
                    cat_features = [
                        dense_x[f'x_{idx}_{dense_l_i}']
                        for idx in range(depth_idx + 1, dense_l_i + 1)
                    ]
                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
                    dense_x[f'x_{depth_idx}_{dense_l_i}'] = \
                        self.blocks[f'x_{depth_idx}_{dense_l_i}'](
                            dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features
                        )
        # Final block consumes the last dense node with no skip connection.
        dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](
            dense_x[f'x_{0}_{self.depth-1}']
        )
        return dense_x[f'x_{0}_{self.depth}']