Source code for tiramisu_brulee.experiment.lesion_tools
"""Functions specific to handling/processing lesion segmentations
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: May 16, 2021
"""
__all__ = [
"almost_isbi15_score",
"clean_segmentation",
]
import builtins
import typing
import numpy as np
import torch
from scipy.ndimage.morphology import binary_fill_holes, generate_binary_structure
from skimage.morphology import remove_small_objects
from torchmetrics.functional import dice_score, pearson_corrcoef, precision
from tiramisu_brulee.experiment.util import image_one_hot
def clean_segmentation(
    label: np.ndarray,
    *,
    fill_holes: builtins.bool = True,
    minimum_lesion_size: builtins.int = 3,
) -> np.ndarray:
    """Clean a binary segmentation by filling holes and removing small objects.

    Args:
        label: binary segmentation array of any dimensionality; nonzero
            entries are treated as foreground.
        fill_holes: if True, fill enclosed holes with
            ``scipy.ndimage.binary_fill_holes`` using a full-connectivity
            structuring element (``generate_binary_structure(ndim, ndim)``).
        minimum_lesion_size: passed as ``min_size`` to
            ``skimage.morphology.remove_small_objects`` (connectivity equal to
            ``label.ndim``); a value ``<= 0`` skips this step.

    Returns:
        The cleaned segmentation. It is a boolean array whenever at least one
        cleaning step runs; if both steps are disabled, ``label`` itself is
        returned unchanged.
    """
    ndim = label.ndim
    if fill_holes:
        structure = generate_binary_structure(ndim, ndim)
        label = binary_fill_holes(label, structure=structure)
    if minimum_lesion_size > 0:
        # remove_small_objects treats integer input as an already *labeled*
        # image: a 0/1 integer mask would be seen as one single object and no
        # small lesion would be removed (float input is rejected outright).
        # Coercing to bool makes it compute connected components itself.
        # This is a no-op when label is already boolean (e.g. after filling).
        label = remove_small_objects(
            np.asarray(label, dtype=bool),
            min_size=minimum_lesion_size,
            connectivity=ndim,
        )
    return label
def almost_isbi15_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    return_dice_ppv: builtins.bool = False,
) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """ISBI 15 MS challenge score excluding the LTPR & LFPR components.

    The score is ``0.5 * dice + 0.5 * ppv``. For single-channel (binary)
    predictions with a batch size greater than one, it is further averaged
    50/50 with the Pearson correlation between the per-sample sums (volumes)
    of ``pred`` and ``target``.

    Args:
        pred: prediction whose dim 0 is the batch and dim 1 the classes;
            ``pred.shape[1] == 1`` selects the binary path. It is cast to int
            for the precision computation, so a binarized / one-hot prediction
            is presumably expected (TODO confirm with callers).
        target: ground truth. In the multiclass case, if its shape differs
            from ``pred`` it gets a class dim inserted at dim 1 (when it has
            one fewer dim than ``pred``) and is then converted with
            ``image_one_hot``. The caller's tensor is not modified.
        return_dice_ppv: if True, also return the dice and PPV components.

    Returns:
        The score tensor, or ``(score, dice, ppv)`` if ``return_dice_ppv``.
    """
    batch_size, num_classes = pred.shape[0:2]
    multiclass = num_classes > 1
    # dice_score needs an explicit class dimension, so single-channel
    # predictions are expanded to 2 classes via image_one_hot.
    one_hot_classes = num_classes if multiclass else 2
    pred_one_hot = (
        pred if multiclass else image_one_hot(pred, num_classes=one_hot_classes)
    )
    # NOTE: dice uses the target exactly as passed in, i.e. before the
    # multiclass one-hot conversion below.
    dice = dice_score(pred_one_hot, target.int())
    if multiclass and pred.shape != target.shape:
        if pred.ndim != target.ndim:
            # Integer label map without a class dim: add one at dim 1.
            # Out-of-place on purpose -- ``unsqueeze_`` would reshape the
            # caller's tensor as a side effect.
            target = target.unsqueeze(1)
        assert pred.ndim == target.ndim
        target = image_one_hot(target.long(), num_classes=num_classes)
    ppv = precision(
        pred.int(),
        target.int(),
        num_classes=num_classes if multiclass else None,
        mdmc_average="samplewise",
        multiclass=multiclass or None,
    )
    isbi15_score = 0.5 * dice + 0.5 * ppv
    if batch_size > 1 and not multiclass:
        # Per-sample sums over every non-batch dim, computed from each
        # tensor's own shape so a target without a class dim still works.
        pred_volumes = pred.flatten(start_dim=1).sum(dim=1).float()
        target_volumes = target.flatten(start_dim=1).sum(dim=1).float()
        corr = pearson_corrcoef(pred_volumes, target_volumes)
        isbi15_score = 0.5 * isbi15_score + 0.5 * corr
    if return_dice_ppv:
        return isbi15_score, dice, ppv
    return isbi15_score