Source code for tiramisu_brulee.model.tiramisu

"""PyTorch Tiramisu network

PyTorch implementation of the Tiramisu network architecture.
Implementation based on `pytorch_tiramisu`.

Changes from `pytorch_tiramisu` include:
  1) removal of bias from conv layers,
  2) change zero padding to replication padding,
  3) cosmetic changes for brevity, clarity, consistency

References:
  Jégou, Simon, et al. "The one hundred layers tiramisu:
  Fully convolutional densenets for semantic segmentation."
  CVPR. 2017.

  https://github.com/bfortuner/pytorch_tiramisu

Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: Jul 01, 2020
"""

__all__ = [
    "Tiramisu2d",
    "Tiramisu3d",
]

import builtins
import typing

import torch
import torch.nn as nn

from tiramisu_brulee.model.dense import (
    Bottleneck2d,
    Bottleneck3d,
    DenseBlock2d,
    DenseBlock3d,
    ResizeMethod,
    TransitionDown2d,
    TransitionDown3d,
    TransitionUp2d,
    TransitionUp3d,
)


class Tiramisu(nn.Module):
    _bottleneck: typing.ClassVar[
        typing.Union[typing.Type[Bottleneck2d], typing.Type[Bottleneck3d]]
    ]
    _conv: typing.ClassVar[typing.Union[typing.Type[nn.Conv2d], typing.Type[nn.Conv3d]]]
    _denseblock: typing.ClassVar[
        typing.Union[typing.Type[DenseBlock2d], typing.Type[DenseBlock3d]]
    ]
    _trans_down: typing.ClassVar[
        typing.Union[typing.Type[TransitionDown2d], typing.Type[TransitionDown3d]]
    ]
    _trans_up: typing.ClassVar[
        typing.Union[typing.Type[TransitionUp2d], typing.Type[TransitionUp3d]]
    ]
    _first_kernel_size: typing.ClassVar[
        typing.Union[
            typing.Tuple[builtins.int, builtins.int],
            typing.Tuple[builtins.int, builtins.int, builtins.int],
        ]
    ]
    _final_kernel_size: typing.ClassVar[
        typing.Union[
            typing.Tuple[builtins.int, builtins.int],
            typing.Tuple[builtins.int, builtins.int, builtins.int],
        ]
    ]
    _padding_mode: typing.ClassVar[builtins.str] = "replicate"

    # flake8: noqa: E501
    def __init__(
        self,
        *,
        in_channels: builtins.int = 3,
        out_channels: builtins.int = 1,
        down_blocks: typing.Collection[builtins.int] = (5, 5, 5, 5, 5),
        up_blocks: typing.Collection[builtins.int] = (5, 5, 5, 5, 5),
        bottleneck_layers: builtins.int = 5,
        growth_rate: builtins.int = 16,
        first_conv_out_channels: builtins.int = 48,
        dropout_rate: builtins.float = 0.2,
        resize_method: ResizeMethod = ResizeMethod.CROP,
        input_shape: typing.Optional[typing.Tuple[builtins.int, ...]] = None,
        static_upsample: builtins.bool = False,
    ):
        """
        Base class for Tiramisu convolutional neural network

        See Also:
            Jégou, Simon, et al. "The one hundred layers tiramisu: Fully
            convolutional densenets for semantic segmentation." CVPR. 2017.

            Based on: https://github.com/bfortuner/pytorch_tiramisu

        Args:
            in_channels (builtins.int): number of input channels
            out_channels (builtins.int): number of output channels
            down_blocks (typing.Collection[builtins.int]): number of layers in each block in down path
            up_blocks (typing.Collection[builtins.int]): number of layers in each block in up path
            bottleneck_layers (builtins.int): number of layers in the bottleneck
            growth_rate (builtins.int): number of channels to grow by in each layer
            first_conv_out_channels (builtins.int): number of output channels in first conv
            dropout_rate (builtins.float): dropout rate/probability
            resize_method (ResizeMethod): method to resize the image in upsample branch
            input_shape: optionally provide shape of the input image (for onnx)
            static_upsample: use static upsampling when capable if input_shape provided
                (doesn't check upsampled size matches)
        """
        super().__init__()
        assert len(down_blocks) == len(up_blocks)

        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        skip_connection_channel_counts: typing.List[builtins.int] = []
        if input_shape is not None:
            tensor_shape = torch.as_tensor(input_shape)
            shapes = [tuple(map(int, input_shape))]

        self.first_conv = nn.Sequential(
            self._conv(
                in_channels,
                first_conv_out_channels,
                self._first_kernel_size,  # type: ignore[arg-type]
                bias=False,
                padding="same",
                padding_mode=self._padding_mode,
            ),
        )
        cur_channels_count: builtins.int = first_conv_out_channels

        # Downsampling path
        self.dense_down = nn.ModuleList([])
        self.trans_down = nn.ModuleList([])
        for i, n_layers in enumerate(down_blocks, 1):
            denseblock = self._denseblock(
                in_channels=cur_channels_count,
                growth_rate=growth_rate,
                n_layers=n_layers,
                upsample=False,
                dropout_rate=dropout_rate,
            )
            self.dense_down.append(denseblock)
            cur_channels_count += growth_rate * n_layers
            skip_connection_channel_counts.insert(0, cur_channels_count)
            trans_down_block = self._trans_down(
                in_channels=cur_channels_count,
                out_channels=cur_channels_count,
                dropout_rate=dropout_rate,
            )
            self.trans_down.append(trans_down_block)
            if i < len(down_blocks) and input_shape is not None:
                tensor_shape = torch.div(tensor_shape, 2, rounding_mode="floor")
                shapes.append(tuple(map(int, tensor_shape)))

        # Bottleneck
        self.bottleneck = self._bottleneck(
            in_channels=cur_channels_count,
            growth_rate=growth_rate,
            n_layers=bottleneck_layers,
            dropout_rate=dropout_rate,
        )
        prev_block_channels = growth_rate * bottleneck_layers
        cur_channels_count += prev_block_channels

        # Upsampling path
        self.dense_up = nn.ModuleList([])
        self.trans_up = nn.ModuleList([])
        up_info = zip(up_blocks, skip_connection_channel_counts)
        for i, (n_layers, sccc) in enumerate(up_info, 1):
            resize_shape = None if input_shape is None else shapes.pop()
            if resize_shape is not None and static_upsample:
                _static_upsample = all(x % 2 == 0 for x in resize_shape)
            else:
                _static_upsample = False
            trans_up_block = self._trans_up(
                in_channels=prev_block_channels,
                out_channels=prev_block_channels,
                resize_method=resize_method,
                resize_shape=resize_shape,
                static=_static_upsample,
            )
            self.trans_up.append(trans_up_block)
            cur_channels_count = prev_block_channels + sccc
            upsample = i < len(up_blocks)  # do not upsample on last block
            denseblock = self._denseblock(
                in_channels=cur_channels_count,
                growth_rate=growth_rate,
                n_layers=n_layers,
                upsample=upsample,
                dropout_rate=dropout_rate,
            )
            self.dense_up.append(denseblock)
            prev_block_channels = growth_rate * n_layers
            cur_channels_count += prev_block_channels

        self.final_conv = self._conv(
            in_channels=cur_channels_count,
            out_channels=out_channels,
            kernel_size=self._final_kernel_size,  # type: ignore[arg-type]
            bias=True,
            padding="same",
            padding_mode=self._padding_mode,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.first_conv(x)
        skip_connections = []
        for dbd, tdb in zip(self.dense_down, self.trans_down):
            out = dbd(out)
            skip_connections.append(out)
            out = tdb(out)
        out = self.bottleneck(out)
        for ubd, tub in zip(self.dense_up, self.trans_up):
            skip = skip_connections.pop()
            out = tub(out, skip=skip)
            out = ubd(out)
        out = self.final_conv(out)
        assert isinstance(out, torch.Tensor)
        return out


[docs]class Tiramisu2d(Tiramisu): _bottleneck = Bottleneck2d _conv = nn.Conv2d _denseblock = DenseBlock2d _pad = nn.ReplicationPad2d _trans_down = TransitionDown2d _trans_up = TransitionUp2d _first_kernel_size = (3, 3) _final_kernel_size = (1, 1)
[docs]class Tiramisu3d(Tiramisu): _bottleneck = Bottleneck3d _conv = nn.Conv3d _denseblock = DenseBlock3d _pad = nn.ReplicationPad3d _trans_down = TransitionDown3d _trans_up = TransitionUp3d _first_kernel_size = (3, 3, 3) _final_kernel_size = (1, 1, 1)