# Source code for tiramisu_brulee.loss

"""Segmentation loss functions

See Also:
    https://gitlab.com/shan-deep-networks/pytorch-metrics/

    https://github.com/catalyst-team/catalyst/

    https://github.com/facebookresearch/fvcore

    S.A. Taghanaki et al. "Combo loss: Handling input and
    output imbalance in multi-organ segmentation." Computerized
    Medical Imaging and Graphics 75 (2019): 24-33.

Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: Jul 01, 2020
"""

# Names exported by a star-import of this module.
# NOTE(review): `combo_loss` is defined in this module but is not listed
# here -- confirm whether leaving it out of the public API is intentional.
__all__ = [
    "binary_combo_loss",
    "binary_focal_loss",
    "deeply_supervised_loss",
    "dice_loss",
    "l1_segmentation_loss",
    "mse_segmentation_loss",
]

import builtins
import typing

import torch
import torch.nn.functional as F


def per_channel_dice(
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    *,
    eps: builtins.float = 1e-3,
    keepdim: builtins.bool = False,
) -> torch.Tensor:
    """compute a smoothed dice score over every axis after the first two

    No reduction over the first two axes is performed here; the caller
    decides how to aggregate the returned scores.

    Args:
        tensor1: first tensor; presumably laid out as (batch, channel,
            *spatial) -- TODO confirm against callers.
        tensor2: second tensor, multiplied element-wise with ``tensor1``.
        eps: smoothing constant added to numerator and denominator.
        keepdim: forwarded to the sums, so the summed axes are kept with
            size one when true.

    Returns:
        ``(2 * sum(t1 * t2) + eps) / (sum(t1) + sum(t2) + eps)`` where the
        sums run over all axes after the first two.
    """
    # Negative indices of all trailing axes, e.g. (-2, -1) for a 4-D input.
    # NOTE(review): for inputs with two or fewer dims this tuple is empty;
    # check how torch.sum treats an empty ``dim`` in your PyTorch version.
    reduce_axes = tuple(range(2 - tensor1.ndim, 0))

    def _total(values: torch.Tensor) -> torch.Tensor:
        return values.sum(dim=reduce_axes, keepdim=keepdim)

    overlap = _total(tensor1 * tensor2)
    denominator = _total(tensor1) + _total(tensor2) + eps
    return (2 * overlap + eps) / denominator


def weighted_channel_avg(tensor: torch.Tensor, *, weight: torch.Tensor) -> torch.Tensor:
    """scale ``tensor`` by per-channel weights and average all elements

    Args:
        tensor: values to average; its first axis length sets how many times
            ``weight`` is tiled (presumably shape (batch, channel) -- TODO
            confirm against callers).
        weight: weights tiled along a new leading axis; the two-entry repeat
            pattern implies a 1-D tensor with one entry per channel.

    Returns:
        The plain mean of ``weight * tensor``; it is not divided by the sum
        of the weights.
    """
    batch_size = tensor.shape[0]
    tiled_weight = weight.unsqueeze(0).repeat(batch_size, 1)
    return (tiled_weight * tensor).mean()


def dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    weight: typing.Optional[torch.Tensor] = None,
    reduction: builtins.str = "mean",
    eps: builtins.float = 1e-3,
) -> torch.Tensor:
    """sorensen-dice coefficient loss function

    Args:
        pred: predicted values, passed unchanged to ``per_channel_dice``
            (no sigmoid/softmax is applied here).
        target: reference values with which ``pred`` is compared.
        weight: optional per-channel weights; only used when ``reduction``
            is ``"mean"``.
        reduction: ``"mean"`` averages the per-channel scores into a single
            value, ``"none"`` returns them unreduced (with the summed axes
            kept as size-one axes).
        eps: smoothing constant forwarded to ``per_channel_dice``.

    Returns:
        ``1 - dice``.

    Raises:
        NotImplementedError: if ``reduction`` is neither ``"mean"`` nor
            ``"none"``.
    """
    # Reject unsupported reductions before doing any tensor work.
    if reduction not in ("mean", "none"):
        raise NotImplementedError(f"{reduction} not implemented.")
    # Keep the summed axes for "none" so the result has the same rank as pred.
    keepdim = reduction != "mean"
    pc_dice = per_channel_dice(pred, target, eps=eps, keepdim=keepdim)
    dice: torch.Tensor
    if reduction == "none":
        dice = pc_dice
    elif weight is None:
        dice = torch.mean(pc_dice)
    else:
        dice = weighted_channel_avg(pc_dice, weight=weight)
    one_minus_dice: torch.Tensor = 1.0 - dice
    return one_minus_dice
def binary_focal_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    pos_weight: typing.Optional[typing.Union[builtins.float, torch.Tensor]] = None,
    reduction: builtins.str = "mean",
    gamma: builtins.float = 2.0,
) -> torch.Tensor:
    """focal loss for binary classification or segmentation

    Args:
        pred: raw logits (a sigmoid is applied internally).
        target: binary targets, same shape as ``pred``.
        pos_weight: optional weight for the positive class. With
            ``gamma <= 0`` it is passed to
            ``binary_cross_entropy_with_logits`` as ``pos_weight``; with
            ``gamma > 0`` it is converted to ``w = pos_weight / (1 + pos_weight)``
            and positives are scaled by ``w``, negatives by ``1 - w``.
        reduction: with ``gamma > 0`` one of ``"mean"``, ``"sum"``,
            ``"none"`` or ``"batchwise_mean"`` (which sums over axis 0);
            with ``gamma <= 0`` it is forwarded unchanged to
            ``binary_cross_entropy_with_logits``.
        gamma: focusing exponent; a value ``<= 0`` disables the focal term
            and yields plain (optionally weighted) binary cross-entropy.

    Returns:
        The loss, reduced as requested.

    Raises:
        ValueError: if ``gamma <= 0`` and ``pos_weight`` is not ``None``, a
            number or a tensor.
        NotImplementedError: if ``gamma > 0`` and ``reduction`` is unknown.
    """
    use_focal = gamma > 0.0

    if not use_focal:
        # Plain BCE: let torch handle both the class weighting and reduction.
        bce_pos_weight: typing.Optional[torch.Tensor]
        is_number = isinstance(
            pos_weight, (builtins.int, builtins.float)
        ) and not isinstance(pos_weight, builtins.bool)
        if is_number:
            # Plain ints are accepted too, as they already are on the focal path.
            bce_pos_weight = torch.tensor(
                [pos_weight], dtype=pred.dtype, device=pred.device
            )
        elif pos_weight is None or isinstance(pos_weight, torch.Tensor):
            bce_pos_weight = pos_weight
        else:
            msg = (
                f"pos_weight must be a none, float, or tensor. Got {type(pos_weight)}."
            )
            raise ValueError(msg)
        return F.binary_cross_entropy_with_logits(
            pred,
            target,
            reduction=reduction,
            pos_weight=bce_pos_weight,
        )

    # Focal path: element-wise BCE, modulated and reduced manually below.
    bce_loss = F.binary_cross_entropy_with_logits(
        pred,
        target,
        reduction="none",
        pos_weight=None,
    )
    p = torch.sigmoid(pred)
    # Probability assigned to the true class of every element.
    p_t = p * target + (1 - p) * (1 - target)
    loss_val: torch.Tensor = bce_loss * ((1 - p_t) ** gamma)

    if pos_weight is not None:
        weight = pos_weight / (1.0 + pos_weight)
        weight_t = weight * target + (1 - weight) * (1 - target)
        loss_val = weight_t * loss_val

    if reduction == "mean":
        loss_val = loss_val.mean()
    elif reduction == "sum":
        loss_val = loss_val.sum()
    elif reduction == "batchwise_mean":
        # Despite the name this sums (does not average) over the first axis.
        loss_val = loss_val.sum(0)
    elif reduction == "none":
        pass
    else:
        raise NotImplementedError(f"{reduction} not implemented.")
    return loss_val
def binary_combo_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    reduction: builtins.str = "mean",
    pos_weight: typing.Optional[builtins.float] = None,
    focal_gamma: builtins.float = 0.0,
    combo_weight: builtins.float = 0.5,
) -> torch.Tensor:
    """combo loss (dice + focal weighted by combo_weight) for binary labels

    Args:
        pred: raw logits; the focal term consumes them directly and the dice
            term is computed on ``sigmoid(pred)``.
        target: binary targets compared against ``pred``.
        reduction: forwarded to both ``binary_focal_loss`` and ``dice_loss``,
            so it must be a value both of them accept.
        pos_weight: optional positive-class weight forwarded to
            ``binary_focal_loss``.
        focal_gamma: focusing exponent forwarded to ``binary_focal_loss``;
            the default ``0.0`` makes that term plain binary cross-entropy.
        combo_weight: weight in ``[0, 1]`` of the focal term; the dice term
            gets ``1 - combo_weight``.

    Returns:
        ``combo_weight * focal + (1 - combo_weight) * dice``.
    """
    assert 0.0 <= combo_weight <= 1.0, "combo_weight must be in [0, 1]"
    assert 0.0 <= focal_gamma, "focal_gamma must be non-negative"
    f_loss = binary_focal_loss(
        pred,
        target,
        pos_weight=pos_weight,
        reduction=reduction,
        gamma=focal_gamma,
    )
    # Dice is computed on probabilities rather than logits.
    p = torch.sigmoid(pred)
    d_loss = dice_loss(p, target, reduction=reduction)
    loss = combo_weight * f_loss + (1 - combo_weight) * d_loss
    return loss
def combo_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    num_classes: builtins.int,
    reduction: builtins.str = "mean",
    combo_weight: builtins.float = 0.5,
) -> torch.Tensor:
    """combo loss (dice + cross-entropy weighted by combo_weight) for multi-class labels

    Args:
        pred: raw class logits with the class axis at position 1.
        target: integer class labels, either with a singleton channel axis
            at position 1 (same rank as ``pred``) or with that axis removed.
        num_classes: number of classes used to one-hot encode ``target``.
        reduction: forwarded to both ``F.cross_entropy`` and ``dice_loss``,
            so it must be a value both of them accept.
        combo_weight: weight in ``[0, 1]`` of the cross-entropy term; the
            dice term gets ``1 - combo_weight``.

    Returns:
        ``combo_weight * cross_entropy + (1 - combo_weight) * dice``.

    Raises:
        ValueError: if ``target`` has the same rank as ``pred`` but more
            than one channel.
    """
    assert 0.0 <= combo_weight <= 1.0, "combo_weight must be in [0, 1]"
    assert 2 <= num_classes, "num_classes must be at least 2"
    channel_not_removed = pred.ndim == target.ndim
    if channel_not_removed and target.shape[1] > 1:
        msg = f"Channel size must be 1 or 0. Got {target.shape[1]}"
        raise ValueError(msg)
    _target = target[:, 0, ...] if channel_not_removed else target
    _target = _target.long()
    f_loss = F.cross_entropy(pred, _target, reduction=reduction)
    if reduction == "none":
        # Unreduced cross-entropy has no class axis, (N, *spatial), while the
        # unreduced dice term keeps one, (N, C, 1, ...). Insert the axis so the
        # two broadcast batch-to-batch instead of batch-to-class.
        f_loss = f_loss.unsqueeze(1)
    p = torch.softmax(pred, dim=1)
    # one_hot appends the class axis last; move it to position 1 to match pred.
    target_one_hot = F.one_hot(_target, num_classes)
    target_one_hot = torch.movedim(target_one_hot, -1, 1)
    target_one_hot = target_one_hot.float()
    d_loss = dice_loss(p, target_one_hot, reduction=reduction)
    loss = combo_weight * f_loss + (1 - combo_weight) * d_loss
    return loss
def deeply_supervised_loss(  # type: ignore[no-untyped-def]
    preds: typing.List[torch.Tensor],
    target: torch.Tensor,
    *,
    loss_func: typing.Callable,
    level_weights: typing.Union[builtins.float, typing.List[builtins.float]] = 1.0,
    **loss_func_kwargs,
) -> torch.Tensor:
    """compute loss_func by comparing multiple same-shape preds to target

    Args:
        preds: predictions, each compared against the same ``target``.
        target: reference tensor passed to every ``loss_func`` call.
        loss_func: callable invoked as ``loss_func(pred, target, **kwargs)``.
        level_weights: one weight per prediction, or a single number applied
            to all of them. A list is paired with ``preds`` via ``zip``, so
            surplus entries on either side are ignored.
        **loss_func_kwargs: extra keyword arguments forwarded to ``loss_func``.

    Returns:
        The weighted sum of the per-prediction losses; a zero scalar with the
        dtype and device of ``target`` when ``preds`` is empty.
    """
    if isinstance(level_weights, (builtins.int, builtins.float)):
        level_weights = [level_weights] * len(preds)
    loss_val = torch.tensor(0.0, dtype=target.dtype, device=target.device)
    for lw, x in zip(level_weights, preds):
        # Out-of-place add: an in-place ``+=`` fails when ``target`` has an
        # integer dtype (float loss cannot be cast into it) or when
        # ``loss_func`` returns a non-scalar tensor.
        loss_val = loss_val + lw * loss_func(x, target, **loss_func_kwargs)
    return loss_val
def l1_segmentation_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    reduction: builtins.str = "mean",
) -> torch.Tensor:
    """l1 loss for segmentation by applying sigmoid to pred -> l1

    Args:
        pred: raw logits; a sigmoid is applied before the comparison.
        target: reference values compared against ``sigmoid(pred)``.
        reduction: forwarded unchanged to ``F.l1_loss``.

    Returns:
        ``F.l1_loss(sigmoid(pred), target)`` with the requested reduction.
    """
    return F.l1_loss(torch.sigmoid(pred), target, reduction=reduction)
def mse_segmentation_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    reduction: builtins.str = "mean",
) -> torch.Tensor:
    """mse loss for segmentation by applying sigmoid to pred -> mse

    Args:
        pred: raw logits; a sigmoid is applied before the comparison.
        target: reference values compared against ``sigmoid(pred)``.
        reduction: forwarded unchanged to ``F.mse_loss``.

    Returns:
        ``F.mse_loss(sigmoid(pred), target)`` with the requested reduction.
    """
    return F.mse_loss(torch.sigmoid(pred), target, reduction=reduction)