# Source code for tiramisu_brulee.model.dense

"""Blocks/layers for densely-connected networks
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: Jul 02, 2020
"""

# Public names exported by a star-import of this module.
__all__ = [
    "ACTIVATION",
    "Bottleneck2d",
    "Bottleneck3d",
    "DenseBlock2d",
    "DenseBlock3d",
    "TransitionDown2d",
    "TransitionDown3d",
    "TransitionUp2d",
    "TransitionUp3d",
]

import builtins
import enum
import functools
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F

# Activation factory: calling ``ACTIVATION()`` builds an ``nn.ReLU`` with
# ``inplace=True``. ``ConvLayer`` uses it for its "act" stage.
ACTIVATION = functools.partial(nn.ReLU, inplace=True)


# partial not supported well by mypy; avoid to type check in class vars below
# https://github.com/python/mypy/issues/1484
class Dropout2d(nn.Dropout2d):
    """``nn.Dropout2d`` whose ``inplace`` flag is keyword-only and defaults to ``True``.

    A real subclass (rather than a ``functools.partial``) so that it can be
    assigned to typed class variables.
    """

    def __init__(
        self, p: builtins.float = 0.5, *, inplace: builtins.bool = True
    ) -> None:
        super().__init__(p=p, inplace=inplace)


class Dropout3d(nn.Dropout3d):
    """``nn.Dropout3d`` whose ``inplace`` flag is keyword-only and defaults to ``True``.

    A real subclass (rather than a ``functools.partial``) so that it can be
    assigned to typed class variables.
    """

    def __init__(
        self, p: builtins.float = 0.5, *, inplace: builtins.bool = True
    ) -> None:
        super().__init__(p=p, inplace=inplace)


class ConvLayer(nn.Sequential):
    """Sequential ``norm -> act -> conv [-> drop] [-> maxpool]`` stack.

    The concrete module classes and the kernel size are supplied by
    subclasses through the class attributes declared below. The "drop" stage
    is only added when ``dropout_rate`` is positive, and the "maxpool" stage
    only when ``_maxpool`` is not ``None``.

    Args:
        in_channels: number of channels fed to the normalization and
            convolution stages.
        out_channels: number of channels produced by the convolution.
        dropout_rate: value passed to the dropout class; a value that is not
            greater than zero disables the dropout stage.
    """

    _conv: typing.ClassVar[typing.Union[typing.Type[nn.Conv2d], typing.Type[nn.Conv3d]]]
    _dropout: typing.ClassVar[
        typing.Union[typing.Type[nn.Dropout2d], typing.Type[nn.Dropout3d]]
    ]
    _kernel_size: typing.ClassVar[
        typing.Union[
            typing.Tuple[builtins.int, builtins.int],
            typing.Tuple[builtins.int, builtins.int, builtins.int],
        ]
    ]
    _maxpool: typing.ClassVar[
        typing.Union[None, typing.Type[nn.MaxPool2d], typing.Type[nn.MaxPool3d]]
    ]
    _norm: typing.ClassVar[
        typing.Union[typing.Type[nn.BatchNorm2d], typing.Type[nn.BatchNorm3d]]
    ]
    _padding_mode: typing.ClassVar[builtins.str] = "replicate"

    def __init__(
        self,
        *,
        in_channels: builtins.int,
        out_channels: builtins.int,
        dropout_rate: builtins.float = 0.2,
    ):
        super().__init__()
        self.dropout_rate = dropout_rate
        # Stages are created in the same order in which they are registered.
        stages: typing.List[typing.Tuple[builtins.str, nn.Module]] = [
            ("norm", self._norm(in_channels)),
            ("act", ACTIVATION()),
            (
                "conv",
                self._conv(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=self._kernel_size,  # type: ignore[arg-type]
                    bias=False,
                    # "same" padding only when some kernel extent exceeds 2
                    padding="same" if self._use_padding() else 0,
                    padding_mode=self._padding_mode,
                ),
            ),
        ]
        if self._use_dropout():
            stages.append(("drop", self._dropout(dropout_rate)))
        if self._maxpool is not None:
            stages.append(("maxpool", self._maxpool(2)))
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

    def _use_dropout(self) -> builtins.bool:
        """Return whether the stored dropout rate is strictly positive."""
        return self.dropout_rate > 0.0

    def _use_padding(self) -> builtins.bool:
        """Return whether any kernel extent is larger than 2."""
        return any(extent > 2 for extent in self._kernel_size)


class ConvLayer2d(ConvLayer):
    """``ConvLayer`` configured with 2D modules, a 3x3 kernel and no pooling."""

    _kernel_size = (3, 3)
    _conv = nn.Conv2d
    _norm = nn.BatchNorm2d
    _dropout = Dropout2d
    _maxpool = None


class ConvLayer3d(ConvLayer):
    """``ConvLayer`` configured with 3D modules, a 3x3x3 kernel and no pooling."""

    _kernel_size = (3, 3, 3)
    _conv = nn.Conv3d
    _norm = nn.BatchNorm3d
    _dropout = Dropout3d
    _maxpool = None


class DenseBlock(nn.Module):
    """Densely-connected stack of ``_layer`` modules.

    Every layer is fed the channel-wise concatenation of the block input and
    the outputs of all earlier layers, and is built with
    ``out_channels=growth_rate``. Subclasses select the layer class through
    the ``_layer`` class attribute.

    Args:
        in_channels: channel count of the block input.
        growth_rate: ``out_channels`` given to each layer.
        n_layers: number of layers in the block.
        upsample: if true, ``forward`` returns only the layers' outputs
            instead of the input concatenated with them.
        dropout_rate: forwarded to every layer.
    """

    _layer: typing.ClassVar[
        typing.Union[typing.Type[ConvLayer2d], typing.Type[ConvLayer3d]]
    ]

    def __init__(
        self,
        *,
        in_channels: builtins.int,
        growth_rate: builtins.int,
        n_layers: builtins.int,
        upsample: builtins.bool = False,
        dropout_rate: builtins.float = 0.2,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.n_layers = n_layers
        self.upsample = upsample
        self.dropout_rate = dropout_rate
        # Each layer emits ``growth_rate`` channels, so the input width of
        # successive layers grows by that amount (see ``in_channels_range``).
        self.layers = nn.ModuleList(
            [
                self._layer(
                    in_channels=width,
                    out_channels=self.growth_rate,
                    dropout_rate=self.dropout_rate,
                )
                for width in self.in_channels_range
            ]
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Run the block.

        Returns:
            With ``upsample`` set, the concatenation (channel axis 1) of the
            individual layer outputs only; otherwise the input concatenated
            with every layer output.
        """
        # All previous activations are passed into each layer either way;
        # the per-layer outputs are only remembered when they are the result.
        fresh: typing.Optional[typing.List[torch.Tensor]]
        fresh = [] if self.upsample else None
        for layer in self.layers:
            produced = layer(tensor)
            tensor = torch.cat((tensor, produced), 1)
            if fresh is not None:
                fresh.append(produced)
        if fresh is None:
            return tensor
        return torch.cat(fresh, 1)

    @property
    def in_channels_range(self) -> typing.List[builtins.int]:
        """Input channel count of each layer, in layer order."""
        return [
            self.in_channels + index * self.growth_rate
            for index in range(self.n_layers)
        ]


class DenseBlock2d(DenseBlock):
    """``DenseBlock`` whose layers are ``ConvLayer2d`` instances."""

    _layer = ConvLayer2d
class DenseBlock3d(DenseBlock):
    """``DenseBlock`` whose layers are ``ConvLayer3d`` instances."""

    _layer = ConvLayer3d
class TransitionDown2d(ConvLayer):
    """``ConvLayer`` configured as a 2D transition-down stage.

    Uses a 1x1 convolution and sets ``_maxpool`` so that ``ConvLayer``
    appends an ``nn.MaxPool2d(2)`` stage.
    """

    _conv = nn.Conv2d
    _dropout = Dropout2d
    _kernel_size = (1, 1)
    _maxpool = nn.MaxPool2d
    _norm = nn.BatchNorm2d
class TransitionDown3d(ConvLayer):
    """``ConvLayer`` configured as a 3D transition-down stage.

    Uses a 1x1x1 convolution and sets ``_maxpool`` so that ``ConvLayer``
    appends an ``nn.MaxPool3d(2)`` stage.
    """

    _conv = nn.Conv3d
    _dropout = Dropout3d
    _kernel_size = (1, 1, 1)
    _maxpool = nn.MaxPool3d
    _norm = nn.BatchNorm3d
@enum.unique
class ResizeMethod(enum.Enum):
    """Strategy used to bring an upsampled tensor to the skip tensor's size."""

    CROP: builtins.str = "crop"
    INTERPOLATE: builtins.str = "interpolate"

    @classmethod
    def from_string(cls, string: builtins.str) -> "ResizeMethod":
        """Return the member named by ``string`` (case-insensitive).

        Raises:
            ValueError: if ``string`` is neither "crop" nor "interpolate".
        """
        lowered = string.lower()
        if lowered == "crop":
            return cls.CROP
        if lowered == "interpolate":
            return cls.INTERPOLATE
        msg = f"Only 'crop' and 'interpolate' allowed. Got {string}"
        raise ValueError(msg)


class TransitionUp(nn.Module):
    """Upsample a tensor and concatenate it with a skip connection.

    The forward implementation is chosen once, at construction time:

    * ``ResizeMethod.CROP``: transposed convolution, then
      ``_crop_to_target`` (implemented by subclasses), then concatenation.
    * ``ResizeMethod.INTERPOLATE`` with ``static`` false: interpolate to the
      skip tensor's spatial size, convolve, then concatenate.
    * ``ResizeMethod.INTERPOLATE`` with ``static`` true: interpolate with a
      fixed ``scale_factor`` of 2.0, convolve, then concatenate.

    Subclasses provide the module classes, kernel size, stride and
    interpolation mode through the class attributes declared below.

    Args:
        in_channels: channels of the tensor to be upsampled.
        out_channels: channels produced by the (transposed) convolution.
        resize_method: which of the strategies above to use.
        resize_shape: stored on the instance as ``resize_shape``; this class
            does not read it itself.
        static: selects the fixed-scale variant of the interpolate strategy.

    Raises:
        ValueError: if ``resize_method`` is not a ``ResizeMethod`` member.
    """

    _conv: typing.ClassVar[typing.Union[typing.Type[nn.Conv2d], typing.Type[nn.Conv3d]]]
    _conv_trans: typing.ClassVar[
        typing.Union[typing.Type[nn.ConvTranspose2d], typing.Type[nn.ConvTranspose3d]]
    ]
    _kernel_size: typing.ClassVar[
        typing.Union[
            typing.Tuple[builtins.int, builtins.int],
            typing.Tuple[builtins.int, builtins.int, builtins.int],
        ]
    ]
    _stride: typing.ClassVar[
        typing.Union[
            typing.Tuple[builtins.int, builtins.int],
            typing.Tuple[builtins.int, builtins.int, builtins.int],
        ]
    ]
    _interp_mode: typing.ClassVar[builtins.str]

    def __init__(
        self,
        *,
        in_channels: builtins.int,
        out_channels: builtins.int,
        resize_method: ResizeMethod = ResizeMethod.CROP,
        resize_shape: typing.Optional[typing.Tuple[builtins.int, ...]] = None,
        static: builtins.bool = False,
    ):
        super().__init__()
        self.resize_shape = resize_shape
        # Arguments shared by the plain and the transposed convolution.
        _conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self._kernel_size,
            bias=False,
        )
        conv_kwargs: typing.Dict[builtins.str, typing.Any] = _conv_kwargs.copy()
        conv_kwargs["padding"] = "same"
        conv_kwargs["padding_mode"] = "replicate"
        conv_trans_kwargs: typing.Dict[builtins.str, typing.Any] = _conv_kwargs.copy()
        conv_trans_kwargs["stride"] = self._stride
        self.conv: typing.Union[
            nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d
        ]
        self.resize: typing.Callable[..., torch.Tensor]
        # ``forward`` is bound per instance so that no branching on the
        # resize strategy happens at call time.
        if resize_method == ResizeMethod.CROP:
            self.conv = self._conv_trans(**conv_trans_kwargs)
            self.resize = self._crop_to_target
            setattr(self, "forward", self._forward_dynamic_trans)
        elif resize_method == ResizeMethod.INTERPOLATE and not static:
            self.conv = self._conv(**conv_kwargs)
            setattr(self, "forward", self._forward_dynamic_conv)
            self.resize = self._interpolate_to_target
        elif resize_method == ResizeMethod.INTERPOLATE and static:
            # The static variant never calls ``self.resize``; it is left unset.
            self.conv = self._conv(**conv_kwargs)
            setattr(self, "forward", self._forward_static)
        else:
            msg = f"resize_method needs to be a ResizeMethod. Got {resize_method}"
            raise ValueError(msg)

    def _forward_dynamic_trans(
        self, tensor: torch.Tensor, *, skip: torch.Tensor
    ) -> torch.Tensor:
        """Transposed-convolve, resize to ``skip``, concatenate on axis 1."""
        out: torch.Tensor = self.conv(tensor)
        out = self.resize(out, target=skip)
        out = torch.cat((out, skip), 1)
        return out

    def _forward_dynamic_conv(
        self, tensor: torch.Tensor, *, skip: torch.Tensor
    ) -> torch.Tensor:
        """Resize to ``skip``, convolve, concatenate on axis 1."""
        out: torch.Tensor = self.resize(tensor, target=skip)
        out = self.conv(out)
        out = torch.cat((out, skip), 1)
        return out

    def _forward_static(
        self, tensor: torch.Tensor, *, skip: torch.Tensor
    ) -> torch.Tensor:
        """Interpolate by a fixed factor of 2, convolve, concatenate on axis 1."""
        out: torch.Tensor = self._interpolate(tensor, scale_factor=2.0)
        out = self.conv(out)
        out = torch.cat((out, skip), 1)
        return out

    def _crop_to_target(
        self, tensor: torch.Tensor, *, target: torch.Tensor
    ) -> torch.Tensor:
        """Crop ``tensor`` to the target size; must be provided by subclasses."""
        raise NotImplementedError

    def _interpolate_to_target(
        self, tensor: torch.Tensor, *, target: torch.Tensor
    ) -> torch.Tensor:
        """Interpolate ``tensor`` to ``target.shape[2:]``."""
        return self._interpolate(tensor, size=target.shape[2:])

    def _interpolate(self, tensor: torch.Tensor, **kwargs: typing.Any) -> torch.Tensor:
        """Call ``F.interpolate`` with this class's mode and ``align_corners=True``."""
        interp_kwargs = dict(mode=self._interp_mode, align_corners=True, **kwargs)
        out: torch.Tensor = F.interpolate(tensor, **interp_kwargs)
        return out
class TransitionUp2d(TransitionUp):
    """``TransitionUp`` configured with 2D modules and bilinear interpolation."""

    _conv = nn.Conv2d
    _conv_trans = nn.ConvTranspose2d
    _kernel_size = (3, 3)
    _stride = (2, 2)
    _interp_mode = "bilinear"

    def _crop_to_target(
        self, tensor: torch.Tensor, *, target: torch.Tensor
    ) -> torch.Tensor:
        """Cut a window of the target size out of ``tensor``.

        The window size is ``self.resize_shape`` when that is set, otherwise
        the last two dimensions of ``target``. The window starts at half the
        size difference (truncated) along each of the last two axes.
        """
        if self.resize_shape is None:
            _, _, max_h, max_w = target.shape
        else:
            max_h, max_w = self.resize_shape
        _, _, _h, _w = tensor.size()
        # start offsets: half of the surplus along each axis, truncated
        h = torch.div(_h - max_h, 2, rounding_mode="trunc")
        w = torch.div(_w - max_w, 2, rounding_mode="trunc")
        return tensor[:, :, h : h + max_h, w : w + max_w]  # type: ignore[misc]
class TransitionUp3d(TransitionUp):
    """``TransitionUp`` configured with 3D modules and trilinear interpolation."""

    _conv = nn.Conv3d
    _conv_trans = nn.ConvTranspose3d
    _kernel_size = (3, 3, 3)
    _stride = (2, 2, 2)
    _interp_mode = "trilinear"

    # flake8: noqa: E501
    def _crop_to_target(
        self, tensor: torch.Tensor, *, target: torch.Tensor
    ) -> torch.Tensor:
        """Cut a window of the target size out of ``tensor``.

        The window size is ``self.resize_shape`` when that is set, otherwise
        the last three dimensions of ``target``. The window starts at half
        the size difference (truncated) along each of the last three axes.
        """
        if self.resize_shape is None:
            _, _, max_h, max_w, max_d = target.shape
        else:
            max_h, max_w, max_d = self.resize_shape
        _, _, _h, _w, _d = tensor.size()
        # start offsets: half of the surplus along each axis, truncated
        h = torch.div(_h - max_h, 2, rounding_mode="trunc")
        w = torch.div(_w - max_w, 2, rounding_mode="trunc")
        d = torch.div(_d - max_d, 2, rounding_mode="trunc")
        return tensor[:, :, h : h + max_h, w : w + max_w, d : d + max_d]  # type: ignore[misc]
class Bottleneck(nn.Sequential):
    """Sequential container holding one dense block named "bottleneck".

    The block is an instance of the subclass-provided ``_layer`` class and is
    always constructed with ``upsample=True``.

    Args:
        in_channels: forwarded to the dense block.
        growth_rate: forwarded to the dense block.
        n_layers: forwarded to the dense block.
        dropout_rate: forwarded to the dense block.
    """

    _layer: typing.ClassVar[
        typing.Union[typing.Type[DenseBlock2d], typing.Type[DenseBlock3d]]
    ]

    def __init__(
        self,
        *,
        in_channels: builtins.int,
        growth_rate: builtins.int,
        n_layers: builtins.int,
        dropout_rate: builtins.float = 0.2,
    ):
        super().__init__()
        layer = self._layer(
            in_channels=in_channels,
            growth_rate=growth_rate,
            n_layers=n_layers,
            upsample=True,
            dropout_rate=dropout_rate,
        )
        self.add_module("bottleneck", layer)
class Bottleneck2d(Bottleneck):
    """``Bottleneck`` whose dense block is a ``DenseBlock2d``."""

    _layer = DenseBlock2d
class Bottleneck3d(Bottleneck):
    """``Bottleneck`` whose dense block is a ``DenseBlock3d``."""

    _layer = DenseBlock3d