duynhm's picture
Initial commit
be2715b
raw
history blame
5.21 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True,
attention_type=None,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x
class CenterBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
super().__init__(conv1, conv2)
class UnetPlusPlusDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
use_batchnorm=True,
attention_type=None,
center=False,
):
super().__init__()
if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)
encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
# computing blocks input and output channels
head_channels = encoder_channels[0]
self.in_channels = [head_channels] + list(decoder_channels[:-1])
self.skip_channels = list(encoder_channels[1:]) + [0]
self.out_channels = decoder_channels
if center:
self.center = CenterBlock(
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()
# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
blocks = {}
for layer_idx in range(len(self.in_channels) - 1):
for depth_idx in range(layer_idx+1):
if depth_idx == 0:
in_ch = self.in_channels[layer_idx]
skip_ch = self.skip_channels[layer_idx] * (layer_idx+1)
out_ch = self.out_channels[layer_idx]
else:
out_ch = self.skip_channels[layer_idx]
skip_ch = self.skip_channels[layer_idx] * (layer_idx+1-depth_idx)
in_ch = self.skip_channels[layer_idx - 1]
blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
blocks[f'x_{0}_{len(self.in_channels)-1}'] =\
DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
self.blocks = nn.ModuleDict(blocks)
self.depth = len(self.in_channels) - 1
def forward(self, *features):
features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder
# start building dense connections
dense_x = {}
for layer_idx in range(len(self.in_channels)-1):
for depth_idx in range(self.depth-layer_idx):
if layer_idx == 0:
output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx+1])
dense_x[f'x_{depth_idx}_{depth_idx}'] = output
else:
dense_l_i = depth_idx + layer_idx
cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx+1, dense_l_i+1)]
cat_features = torch.cat(cat_features + [features[dense_l_i+1]], dim=1)
dense_x[f'x_{depth_idx}_{dense_l_i}'] =\
self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features)
dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
return dense_x[f'x_{0}_{self.depth}']