"""Data handling classes for training/prediction
load and process data for training/prediction
for segmentation tasks
Author: Jacob Reinhold (jcreinhold@gmail.com)
Created on: May 17, 2021
"""
# Names exported by this module via `from <module> import *`.
__all__ = [
"csv_to_subjectlist",
"LesionSegDataModulePredictBase",
"LesionSegDataModulePredictPatches",
"LesionSegDataModulePredictWhole",
"LesionSegDataModuleTrain",
"Mixup",
]
import builtins
import logging
import pathlib
import types
import typing
import warnings
from multiprocessing import Manager
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.distributions as D
import torchio as tio
from jsonargparse import ArgumentParser
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tiramisu_brulee.experiment.type import (
Batch,
PatchShape,
PatchShapeOption,
PathLike,
file_path,
nonnegative_int,
nonnegative_int_or_none_or_all,
positive_float,
positive_int,
positive_int_or_none,
positive_odd_int_or_none,
probability_float,
)
from tiramisu_brulee.experiment.util import reshape_for_broadcasting
# Lower-case CSV column names that `_get_type` accepts without warning;
# any other column name triggers a warning and is treated as a
# non-label (intensity) image.
RECOGNIZED_NAMES = (
"cect",
"ct",
"flair",
"pd",
"pet",
"t1",
"t1c",
"t2",
"label",
"weight",
"div",
"out",
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Type variables bound to the data-module classes; they annotate the
# `from_csv` classmethods so that the return type follows `cls`.
TrainDataModule = typing.TypeVar("TrainDataModule", bound="LesionSegDataModuleTrain")
PredictDataModule = typing.TypeVar(
"PredictDataModule", bound="LesionSegDataModulePredictBase"
)
class SubjectsDataset(tio.SubjectsDataset):
    """`tio.SubjectsDataset` whose subject list is a `multiprocessing.Manager` list.

    The data modules select this class instead of `tio.SubjectsDataset`
    when `use_memory_saving_dataset` is set.
    """

    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        super().__init__(*args, **kwargs)
        # replace the list built by the parent constructor with a
        # manager-backed proxy list holding the same subjects
        mp_manager = Manager()
        shared_subjects = mp_manager.list(self._subjects)
        self._subjects = shared_subjects
class NonRandomLabelSampler(tio.LabelSampler):
    """Label sampler that replays previously drawn patch indices.

    The first index drawn for a given subject name and draw position is
    stored; later calls for the same subject name reuse the stored index
    at each position instead of sampling again.
    """

    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        super().__init__(*args, **kwargs)
        # maps subject name -> {draw position -> patch index}
        self._cache = {}

    def _generate_patches(
        self,
        subject: tio.Subject,
        num_patches: typing.Optional[builtins.int] = None,
    ) -> typing.Generator[tio.Subject, None, None]:
        """Yield cropped patches of `subject`, reusing cached indices.

        Args:
            subject: subject to crop; its ``"name"`` entry keys the cache.
            num_patches: number of patches to yield; ``None`` yields
                patches without end.

        Yields:
            The subject cropped at each (cached or newly drawn) index.
        """
        # `True` never reaches zero, so `None` means "keep going forever"
        remaining = True if num_patches is None else num_patches
        draw = 0
        drawn = self._cache.setdefault(subject["name"], dict())
        raw_map = self.get_probability_map(subject)
        prob_map = self.process_probability_map(raw_map, subject)
        cdf = self.get_cumulative_distribution_function(prob_map)
        while remaining:
            if draw not in drawn:
                # first time this position is requested: sample and remember
                drawn[draw] = self.get_random_index_ini(prob_map, cdf)  # noqa
            index_ini = drawn[draw]
            draw += 1
            yield self.crop(subject, index_ini, self.patch_size)
            if num_patches is not None:
                remaining -= 1
class LesionSegDataModuleBase(pl.LightningDataModule):
    """Common configuration and collation helpers for the data modules.

    Args:
        batch_size: number of samples per batch given to the data loaders.
        patch_size: patch shape; ``None`` leaves ``self.patch_size`` as
            ``None``. With pseudo3d it must contain exactly two values and
            ``pseudo3d_size`` is inserted at the pseudo3d axis.
        num_workers: worker count stored for the data loaders.
        pseudo3d_dim: axis along which images are concatenated for a 2D
            network; ``None`` disables pseudo3d. A non-int value (e.g.
            ``"all"``) is mapped to axis 0 for internal indexing.
        pseudo3d_size: size of the pseudo3d axis; required whenever
            ``pseudo3d_dim`` is provided.
        reorient_to_canonical: flag stored for subclasses, which read it
            when building their transforms.
        use_memory_saving_dataset: flag stored for subclasses, which read
            it to choose between ``SubjectsDataset`` and
            ``tio.SubjectsDataset``.

    Raises:
        ValueError: if ``pseudo3d_dim`` is given without ``pseudo3d_size``
            or if pseudo3d is used with a patch size not of length 2.
    """

    def __init__(
        self,
        *,
        batch_size: builtins.int,
        patch_size: typing.Optional[PatchShapeOption] = None,
        num_workers: builtins.int = 16,
        pseudo3d_dim: typing.Optional[builtins.int] = None,
        pseudo3d_size: typing.Optional[builtins.int] = None,
        reorient_to_canonical: builtins.bool = True,
        use_memory_saving_dataset: builtins.bool = False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pseudo3d_dim = pseudo3d_dim
        # integer axis used for indexing; non-int values fall back to 0
        self._pseudo3d_dim_internal = (
            pseudo3d_dim if isinstance(pseudo3d_dim, builtins.int) else 0
        )
        self.pseudo3d_size = pseudo3d_size
        if self._use_pseudo3d and self.pseudo3d_size is None:
            raise ValueError(
                "If pseudo3d_dim provided, pseudo3d_size must be provided."
            )
        # must run after the pseudo3d attributes above are set
        self.patch_size = self._determine_patch_size(patch_size)
        self.reorient_to_canonical = reorient_to_canonical
        self.use_memory_saving_dataset = use_memory_saving_dataset

    def _determine_input(
        self,
        subjects: typing.Union[tio.Subject, typing.List[tio.Subject]],
        *,
        other_subjects: typing.Optional[typing.List[tio.Subject]] = None,
    ) -> None:
        """Work out which subject fields are network inputs.

        Assume all columns except `name`, `label`, `div`, `weight`, or
        `out` are some type of non-categorical image. Sets
        ``self._input_fields`` (sorted tuple of input keys) and
        ``self._use_div``.

        Args:
            subjects: one subject or a list of subjects; only the first
                element of a list is inspected.
            other_subjects: optional second list (the validation subjects
                in the training module) whose first element must be
                consistent with ``subjects``.

        Raises:
            ValueError: if no input fields are found, or if the two
                subjects disagree on input fields, `label`, or `div`.
        """
        exclude = ("name", "label", "div", "weight", "out")
        if isinstance(subjects, list):
            subject = subjects[0]  # arbitrarily pick the first element
        else:
            subject = subjects
        inputs = [key for key in subject if key not in exclude]
        if not inputs:
            msg = (
                "No inputs detected in CSV. Expect columns like "
                "`t1` with corresponding paths to NIfTI files."
            )
            raise ValueError(msg)
        if other_subjects is not None:
            other_subject = other_subjects[0]
            for key in inputs:
                if key not in other_subject:
                    msg = "Validation CSV fields not the same as training CSV"
                    raise ValueError(msg)
            if "label" not in subject or "label" not in other_subject:
                msg = "`label` field expected in both " "training and validation CSV."
                raise ValueError(msg)
            # `div` must be present in both or in neither
            if ("div" in subject) ^ ("div" in other_subject):
                msg = (
                    "If `div` present in one of the training "
                    "or validation CSVs, it is expected in "
                    "both training and validation CSV."
                )
                raise ValueError(msg)
        self._use_div = "div" in subject
        self._input_fields = tuple(sorted(inputs))

    @staticmethod
    def _div_image_batch(
        image_batch: torch.Tensor, *, div: torch.Tensor
    ) -> torch.Tensor:
        """Divide ``image_batch`` in place by ``div`` and return it.

        ``div`` is first passed through ``reshape_for_broadcasting`` with
        the number of dimensions of ``image_batch``.
        """
        with torch.no_grad():
            image_batch /= reshape_for_broadcasting(div, ndim=image_batch.ndim)
        return image_batch

    def _default_collate_fn(
        self,
        batch: Batch,
        *,
        cat_dim: typing.Optional[builtins.int] = None,
    ) -> torch.Tensor:
        """Concatenate the input-field tensors of a batch into one tensor.

        Args:
            batch: a collated batch, or a list that is collated here.
            cat_dim: dimension to concatenate along; ``None`` concatenates
                along dim 1. When given, the pseudo3d path is taken: that
                dimension is swapped with dim 1 and singleton dims are
                squeezed so that the result has 4 dimensions.

        Returns:
            The concatenated tensor, divided by the batch's ``div`` entry
            when ``self._use_div`` is set.
        """
        if isinstance(batch, list):
            batch = default_collate(batch)
        inputs: typing.List[torch.Tensor] = [
            batch[field][tio.DATA] for field in self._input_fields
        ]
        cat_dim_ = cat_dim or 1
        src: torch.Tensor = torch.cat(inputs, dim=cat_dim_)
        if cat_dim is not None:
            # if axis is not None, use pseudo3d images
            src.swapaxes_(1, cat_dim_)
            src.squeeze_()
            if src.ndim == 3:  # batch size of 1
                src.unsqueeze_(0)
            assert src.ndim == 4
        if self._use_div:
            assert isinstance(batch["div"], torch.Tensor)
            div_factor: torch.Tensor = batch["div"]
            src = self._div_image_batch(src, div=div_factor)
        return src

    @staticmethod
    def _pseudo3d_label(
        label: torch.Tensor, pseudo3d_dim: builtins.int
    ) -> torch.Tensor:
        """Return the middle slice of a 5-D label along a spatial axis.

        Args:
            label: tensor with shape NxCxHxWxD.
            pseudo3d_dim: spatial axis (0, 1, or 2) to take the slice from.

        Raises:
            ValueError: if ``pseudo3d_dim`` is not 0, 1, or 2.
        """
        assert label.ndim == 5, "expects label with shape NxCxHxWxD"
        if pseudo3d_dim == 0:
            mid_channel = label.shape[2] // 2
            label = label[..., mid_channel, :, :]
        elif pseudo3d_dim == 1:
            mid_channel = label.shape[3] // 2
            label = label[..., :, mid_channel, :]
        elif pseudo3d_dim == 2:
            mid_channel = label.shape[4] // 2
            label = label[..., :, :, mid_channel]
        else:
            raise ValueError(f"pseudo3d_dim must be 0, 1, or 2. Got {pseudo3d_dim}.")
        return label

    def _determine_patch_size(
        self, patch_size: typing.Optional[PatchShapeOption]
    ) -> typing.Optional[PatchShapeOption]:
        """Return the patch size to store, expanded for pseudo3d if used.

        Raises:
            ValueError: if pseudo3d is used and ``patch_size`` does not
                contain exactly two values.
        """
        if patch_size is None:
            return None
        patch_size_list = list(patch_size)
        if self._use_pseudo3d and len(patch_size_list) != 2:
            raise ValueError(
                "If using pseudo3d, patch size must contain only 2 values."
            )
        if self._use_pseudo3d:
            # turn the 2D patch into a 3D one by inserting the slab size
            patch_size_list.insert(self._pseudo3d_dim_internal, self.pseudo3d_size)
        return tuple(patch_size_list)  # type: ignore[return-value]

    @property
    def _use_pseudo3d(self) -> builtins.bool:
        """Whether a pseudo3d dimension was provided."""
        return self.pseudo3d_dim is not None

    @staticmethod
    def _add_common_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Register the options shared by training and prediction.

        Args:
            parser: parser (or argument group) to add the options to.

        Returns:
            The same ``parser`` object.
        """
        parser.add_argument(
            "-bs",
            "--batch-size",
            type=positive_int(),
            default=1,
            help="training/validation batch size",
        )
        parser.add_argument(
            "-nw",
            "--num-workers",
            type=nonnegative_int(),
            default=16,
            help="number of CPUs to use for loading data",
        )
        parser.add_argument(
            "-rtc",
            "--reorient-to-canonical",
            action="store_true",
            default=False,
            help="reorient inputs images to canonical orientation "
            "(useful when using data from heterogeneous sources "
            "or using pseudo3d_dim == all; otherwise, e.g., the "
            "axis corresponding to left-right in one image might "
            "be anterior-posterior in another.)",
        )
        # NOTE: each help fragment must end with a space, otherwise the
        # implicit string concatenation glues two words together.
        parser.add_argument(
            "-p3d",
            "--pseudo3d-dim",
            type=nonnegative_int_or_none_or_all(),
            nargs="+",
            choices=(0, 1, 2, "all"),
            default=None,
            help="dim on which to concatenate the images for input "
            "to a 2D network. If provided, either provide 1 value "
            "to be used for each CSV or provide N values "
            "corresponding to the N CSVs. If not provided, "
            "use 3D network.",
        )
        parser.add_argument(
            "-p3s",
            "--pseudo3d-size",
            type=positive_odd_int_or_none(),
            default=None,
            help="size of the pseudo3d dimension (if -p3d provided)",
        )
        parser.add_argument(
            "-nsa",
            "--non-strict-affine",
            dest="strict_affine",
            action="store_false",
            default=True,
            help="if images have different affine matrices, "
            "resample the images to be consistent; avoid using "
            "this by coregistering your images within a subject.",
        )
        parser.add_argument(
            "-cd",
            "--check-dicom",
            action="store_true",
            default=False,
            help="check DICOM images to see if they have uniform "
            "spacing between slices; warn the user if not.",
        )
        parser.add_argument(
            "--use-memory-saving-dataset",
            action="store_true",
            default=False,
            help="default dataset can leak memory when num_workers > 1, "
            "use this if you encounter non-GPU OOM errors",
        )
        return parser

    def __repr__(self) -> builtins.str:
        """Return a one-line summary of the stored configuration."""
        desc = (
            f"batch size: {self.batch_size}; "
            f"patch size: {self.patch_size}; "
            f"num workers: {self.num_workers}; "
            f"pseudo3d dim: {self.pseudo3d_dim}; "
            f"pseudo3d size: {self.pseudo3d_size}; "
            f"reorient to canonical: {self.reorient_to_canonical}"
        )
        return desc
class LesionSegDataModuleTrain(LesionSegDataModuleBase):
    """Data module for training and validation for lesion segmentation

    Args:
        train_subject_list (typing.List[tio.Subject]):
            list of torchio.Subject for training
        val_subject_list (typing.List[tio.Subject]):
            list of torchio.Subject for validation
        batch_size (int): batch size for training/validation
        patch_size (PatchShape): patch size for training/validation
        queue_length (int): Maximum number of patches that can be
            stored in the queue. Using a large number means that
            the queue needs to be filled less often, but more CPU
            memory is needed to store the patches.
        samples_per_volume (int): Number of patches to extract from
            each volume. A small number of patches ensures a large
            variability in the queue, but training will be slower.
        num_workers (int): number of subprocesses for data loading
        label_sampler (bool): use a `tio.LabelSampler` instead of a
            `tio.UniformSampler` for the training patches
        spatial_augmentation (bool): use random affine and elastic
            data augmentation for training
        pseudo3d_dim (typing.Optional[int]): concatenate images along this
            axis and swap it for channel dimension; the value ``"all"``
            additionally enables random transposes and 90-degree rotations
        pseudo3d_size (typing.Optional[int]): size of the pseudo3d axis
            (required if pseudo3d_dim is provided)
        reorient_to_canonical (bool): prepend `tio.ToCanonical` to the
            training and validation transforms
        num_classes (int): must be at least 1, otherwise building the
            transforms raises a ValueError
        pos_sampling_weight (float): label-sampler probability ``p`` for
            label 1; label 0 gets ``1.0 - p``
        use_memory_saving_dataset (bool): use `SubjectsDataset` instead
            of `tio.SubjectsDataset`
        random_validation_patches (bool): sample validation patches with
            `tio.LabelSampler` rather than `NonRandomLabelSampler`; forced
            to True (with a warning) when pseudo3d_dim is ``"all"``
        **kwargs: accepted and ignored, so that a full argument namespace
            can be passed through `from_csv`
    """

    # noinspection PyUnusedLocal
    def __init__(  # type: ignore[no-untyped-def]
        self,
        *,
        train_subject_list: typing.List[tio.Subject],
        val_subject_list: typing.List[tio.Subject],
        batch_size: builtins.int = 2,
        patch_size: PatchShape = (96, 96, 96),
        queue_length: builtins.int = 200,
        samples_per_volume: builtins.int = 10,
        num_workers: builtins.int = 16,
        label_sampler: builtins.bool = False,
        spatial_augmentation: builtins.bool = False,
        pseudo3d_dim: typing.Optional[builtins.int] = None,
        pseudo3d_size: typing.Optional[builtins.int] = None,
        reorient_to_canonical: builtins.bool = True,
        num_classes: builtins.int = 1,
        pos_sampling_weight: float = 1.0,
        use_memory_saving_dataset: builtins.bool = False,
        random_validation_patches: builtins.bool = False,
        **kwargs,
    ):
        super().__init__(
            batch_size=batch_size,
            patch_size=patch_size,
            num_workers=num_workers,
            pseudo3d_dim=pseudo3d_dim,
            pseudo3d_size=pseudo3d_size,
            reorient_to_canonical=reorient_to_canonical,
            use_memory_saving_dataset=use_memory_saving_dataset,
        )
        self.train_subject_list = train_subject_list
        self.val_subject_list = val_subject_list
        self.queue_length = queue_length
        self.samples_per_volume = samples_per_volume
        self.label_sampler = label_sampler
        self.spatial_augmentation = spatial_augmentation
        self.num_classes = num_classes
        self.pos_sampling_weight = pos_sampling_weight
        if not random_validation_patches and pseudo3d_dim == "all":
            # fall back to random validation patches in this configuration
            msg = "Deterministic val patches not implemented w/ pseudo3d_dim = all"
            warnings.warn(msg)
            random_validation_patches = True
        self.random_validation_patches = random_validation_patches

    @classmethod
    def from_csv(  # type: ignore[no-untyped-def]
        cls: typing.Type[TrainDataModule],
        *,
        train_csv: builtins.str,
        valid_csv: builtins.str,
        **kwargs,
    ) -> TrainDataModule:
        """Build the data module from a training and a validation CSV.

        Args:
            train_csv: CSV file converted to the training subject list.
            valid_csv: CSV file converted to the validation subject list.
            **kwargs: forwarded to the constructor; ``strict_affine``
                (default True) and ``check_dicom`` (default False) are
                also read here and passed to `csv_to_subjectlist`.
        """
        strict_affine = kwargs.get("strict_affine", True)
        check_dicom = kwargs.get("check_dicom", False)
        tsl = csv_to_subjectlist(
            train_csv,
            strict=strict_affine,
            check_dicom=check_dicom,
        )
        vsl = csv_to_subjectlist(
            valid_csv,
            strict=strict_affine,
            check_dicom=check_dicom,
        )
        return cls(train_subject_list=tsl, val_subject_list=vsl, **kwargs)

    def setup(self, stage: typing.Optional[builtins.str] = None) -> None:
        """Determine the input fields and create both datasets."""
        super().setup(stage)
        self._determine_input(
            self.train_subject_list,
            other_subjects=self.val_subject_list,
        )
        self._setup_train_dataset()
        self._setup_val_dataset()

    def train_dataloader(self) -> DataLoader:
        """Return a loader over a shuffled `tio.Queue` of training patches."""
        sampler = self._get_train_sampler()
        patches_queue = tio.Queue(
            self.train_dataset,
            self.queue_length,
            self.samples_per_volume,
            sampler,
            num_workers=self.num_workers,
            shuffle_subjects=True,
            shuffle_patches=True,
        )
        train_dataloader = DataLoader(
            patches_queue,
            batch_size=self.batch_size,
            collate_fn=self._collate_fn,
        )
        return train_dataloader

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader.

        With pseudo3d, patches are drawn through an unshuffled `tio.Queue`;
        otherwise the validation dataset is loaded directly.
        """
        if self._use_pseudo3d:
            sampler = self._get_val_sampler()
            patches_queue = tio.Queue(
                self.val_dataset,
                self.queue_length,
                self.samples_per_volume,
                sampler,
                num_workers=self.num_workers,
                shuffle_subjects=False,
                shuffle_patches=False,
            )
            val_dataloader = DataLoader(
                patches_queue,
                batch_size=self.batch_size,
                collate_fn=self._collate_fn,
            )
        else:
            val_dataloader = DataLoader(
                self.val_dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=True,
                collate_fn=self._collate_fn,
            )
        return val_dataloader

    def _get_train_augmentation(self) -> typing.Callable:
        """Compose the transforms applied to training subjects.

        Raises:
            ValueError: if ``num_classes`` is less than 1.
        """
        transforms = []
        if self.reorient_to_canonical:
            transforms.append(tio.ToCanonical())
        if self.num_classes >= 1:
            transforms.append(image_to_float())
        else:
            msg = f"num_classes must be positive. Got {self.num_classes}."
            raise ValueError(msg)
        if self.spatial_augmentation:
            spatial = tio.OneOf(
                {tio.RandomAffine(): 0.8, tio.RandomElasticDeformation(): 0.2},
                p=0.75,
            )
            transforms.insert(1, spatial)
        # noinspection PyTypeChecker
        flip = tio.RandomFlip(axes=("LR",))
        transforms.append(flip)
        if self.pseudo3d_dim == "all":
            transforms.insert(1, RandomTranspose())
            transforms.append(RandomRot90())
        transform: typing.Callable = tio.Compose(transforms)
        return transform

    def _get_train_sampler(self) -> typing.Union[tio.LabelSampler, tio.UniformSampler]:
        """Return the patch sampler selected by ``label_sampler``."""
        if self.label_sampler:
            p = self.pos_sampling_weight
            lps = {0: 1.0 - p, 1: p}
            return tio.LabelSampler(self.patch_size, label_probabilities=lps)
        else:
            return tio.UniformSampler(self.patch_size)

    def _setup_train_dataset(self) -> None:
        """Create ``self.train_dataset`` from the training subjects."""
        transform = self._get_train_augmentation()
        ds = SubjectsDataset if self.use_memory_saving_dataset else tio.SubjectsDataset
        subjects_dataset = ds(
            self.train_subject_list,
            transform=transform,
        )
        self.train_dataset = subjects_dataset

    def _get_val_augmentation(self) -> typing.Callable:
        """Compose the transforms applied to validation subjects.

        Raises:
            ValueError: if ``num_classes`` is less than 1.
        """
        transforms = []
        if self.reorient_to_canonical:
            transforms.append(tio.ToCanonical())
        if self.num_classes >= 1:
            transforms.append(image_to_float())
        else:
            msg = f"num_classes must be positive. Got {self.num_classes}."
            raise ValueError(msg)
        if not self._use_pseudo3d:
            # without a patch queue, bring whole images to the patch size
            transforms.insert(1, tio.CropOrPad(self.patch_size))
        if self.pseudo3d_dim == "all":
            transforms.insert(1, RandomTranspose())
        transform: typing.Callable = tio.Compose(transforms)
        return transform

    def _get_val_sampler(self) -> tio.LabelSampler:
        """Return a random or index-caching label sampler for validation."""
        p = self.pos_sampling_weight
        lps = {0: 1.0 - p, 1: p}
        if self.random_validation_patches:
            sampler = tio.LabelSampler(self.patch_size, label_probabilities=lps)
        else:
            sampler = NonRandomLabelSampler(self.patch_size, label_probabilities=lps)
        return sampler

    def _setup_val_dataset(self) -> None:
        """Create ``self.val_dataset`` from the validation subjects."""
        transform = self._get_val_augmentation()
        ds = SubjectsDataset if self.use_memory_saving_dataset else tio.SubjectsDataset
        subjects_dataset = ds(
            self.val_subject_list,
            transform=transform,
        )
        self.val_dataset = subjects_dataset

    # flake8: noqa: E501
    def _collate_fn(
        self, batch: typing.List[tio.Subject]
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Collate subjects into a ``(source, target)`` tensor pair."""
        collated_batch: Batch = default_collate(batch)
        p3d = self._pseudo3d_dim_internal if self._use_pseudo3d else None
        # offset by batch/channel dims if pseudo3d used
        p3d_with_offset = (p3d + 2) if self._use_pseudo3d else None  # type: ignore[operator]
        src = self._default_collate_fn(collated_batch, cat_dim=p3d_with_offset)
        tgt = collated_batch["label"][tio.DATA]
        if self._use_pseudo3d:
            tgt = self._pseudo3d_label(tgt, self._pseudo3d_dim_internal)
        return src, tgt

    @staticmethod
    def add_arguments(parent_parser: ArgumentParser) -> ArgumentParser:
        """Add the training data options in a "Data" argument group.

        Returns:
            ``parent_parser``, with the group added.
        """
        parser = parent_parser.add_argument_group("Data")
        parser.add_argument(
            "--train-csv",
            type=file_path(),
            nargs="+",
            required=True,
            default=["SET ME!"],
            help="path(s) to CSV(s) with training images",
        )
        parser.add_argument(
            "--valid-csv",
            type=file_path(),
            nargs="+",
            required=True,
            default=["SET ME!"],
            help="path(s) to CSV(s) with validation images",
        )
        parser.add_argument(
            "-ps",
            "--patch-size",
            type=positive_int(),
            nargs="+",
            default=[96, 96, 96],
            help="training/validation patch size extracted from image",
        )
        parser.add_argument(
            "-ql",
            "--queue-length",
            type=positive_int(),
            default=200,
            help="queue length for torchio sampler",
        )
        parser.add_argument(
            "-spv",
            "--samples-per-volume",
            type=positive_int(),
            default=10,
            help="samples per volume for torchio sampler",
        )
        parser.add_argument(
            "-ls",
            "--label-sampler",
            action="store_true",
            default=False,
            help="use label sampler instead of uniform",
        )
        parser.add_argument(
            "-sa",
            "--spatial-augmentation",
            action="store_true",
            default=False,
            help="use spatial (affine and elastic) data augmentation",
        )
        parser.add_argument(
            "-psw",
            "--pos-sampling-weight",
            type=probability_float(),
            default=1.0,
            help="sample positive voxels with this weight/negative voxels with 1.0-this",
        )
        parser.add_argument(
            "--random-validation-patches",
            action="store_true",
            default=False,
            help="randomly sample patches in the validation set",
        )
        LesionSegDataModuleBase._add_common_arguments(parser)
        return parent_parser

    def __repr__(self) -> builtins.str:
        """Extend the base summary with the training-specific settings."""
        desc = super().__repr__()
        desc += (
            f"; queue length: {self.queue_length}; "
            f"samples per volume: {self.samples_per_volume}; "
            f"label sampler: {self.label_sampler}; "
            f"spatial aug: {self.spatial_augmentation}; "
            f"num classes: {self.num_classes}; "
            f"pos sampling weight: {self.pos_sampling_weight}"
        )
        return desc
class WholeImagePredictBatch:
    """Bundle of tensors and paths for one whole-image prediction batch.

    All arguments are stored as attributes of the same name and checked
    by `validate` at construction time.
    """

    def __init__(
        self,
        *,
        src: torch.Tensor,
        affine: torch.Tensor,
        path: typing.List[builtins.str],
        out: typing.List[builtins.str],
        reorient: builtins.bool,
    ):
        self.src, self.affine = src, affine
        self.path, self.out = path, out
        self.reorient = reorient
        self.validate()

    def validate(self) -> None:
        """Assert that the fields have matching lengths and usable paths."""
        assert len(self.src) == len(self.affine)
        assert len(self.affine) == len(self.path)
        assert len(self.path) == len(self.out)
        # every source path must point at an existing file or directory
        for entry in self.path:
            candidate = pathlib.Path(entry)
            assert candidate.is_file() or candidate.is_dir()
        assert isinstance(self.reorient, builtins.bool)

    def to(self, device: typing.Union[builtins.str, torch.device], **kwargs) -> None:  # type: ignore[no-untyped-def]
        """Move `src` (only) to `device`, replacing the stored tensor."""
        moved = self.src.to(device=device, **kwargs)
        self.src = moved
class PatchesImagePredictBatch:
    """Bundle of tensors and metadata for one batch of prediction patches.

    All arguments are stored as attributes of the same name and checked
    by `validate` at construction time.
    """

    def __init__(
        self,
        *,
        src: torch.Tensor,
        affine: torch.Tensor,
        path: builtins.str,
        out: builtins.str,
        locations: torch.Tensor,
        grid_obj: types.SimpleNamespace,
        pseudo3d_dim: typing.Optional[builtins.int],
        total_batches: builtins.int,
        reorient: builtins.bool,
    ):
        self.src, self.affine = src, affine
        self.path, self.out = path, out
        self.locations = locations
        self.grid_obj = grid_obj
        self.pseudo3d_dim = pseudo3d_dim
        self.total_batches = total_batches
        self.reorient = reorient
        self.validate()

    def validate(self) -> None:
        """Assert that the stored fields are mutually consistent."""
        source_path = pathlib.Path(self.path)
        assert len(self.src) == len(self.affine)
        assert source_path.is_file() or source_path.is_dir()
        assert self.pseudo3d_dim is None or (0 <= self.pseudo3d_dim <= 2)
        assert self.total_batches >= 1
        # the grid object must expose the attributes read from it downstream
        for attribute in ("subject", "padding_mode", "patch_overlap"):
            assert hasattr(self.grid_obj, attribute)
        assert hasattr(self.grid_obj.subject, "spatial_shape")
        assert isinstance(self.reorient, builtins.bool)

    def to(self, device: typing.Union[builtins.str, torch.device], **kwargs) -> None:  # type: ignore[no-untyped-def]
        """Move `src` (only) to `device`, replacing the stored tensor."""
        moved = self.src.to(device=device, **kwargs)
        self.src = moved
class LesionSegDataModulePredictBase(LesionSegDataModuleBase):
    """Shared prediction logic; subclasses supply the dataset and collation.

    Args:
        subjects: one subject or a list of subjects to predict on.
        batch_size: batch size given to the prediction data loader.
        patch_size: patch shape forwarded to the base class.
        num_workers: worker count for the prediction data loader.
        pseudo3d_dim: pseudo3d axis forwarded to the base class.
        pseudo3d_size: pseudo3d axis size forwarded to the base class.
        reorient_to_canonical: flag forwarded to the base class.
        use_memory_saving_dataset: flag forwarded to the base class.
        **kwargs: accepted and ignored.
    """

    def __init__(  # type: ignore[no-untyped-def]
        self,
        subjects: typing.Union[tio.Subject, typing.List[tio.Subject]],
        batch_size: builtins.int,
        patch_size: typing.Optional[PatchShapeOption] = None,
        num_workers: builtins.int = 16,
        pseudo3d_dim: typing.Optional[builtins.int] = None,
        pseudo3d_size: typing.Optional[builtins.int] = None,
        reorient_to_canonical: builtins.bool = True,
        use_memory_saving_dataset: builtins.bool = False,
        **kwargs,
    ):
        super().__init__(
            batch_size=batch_size,
            patch_size=patch_size,
            num_workers=num_workers,
            pseudo3d_dim=pseudo3d_dim,
            pseudo3d_size=pseudo3d_size,
            reorient_to_canonical=reorient_to_canonical,
            use_memory_saving_dataset=use_memory_saving_dataset,
        )
        # assigned by the subclass's `_setup_predict_dataset`
        self.predict_dataset: tio.SubjectsDataset
        self.subjects = subjects

    def setup(self, stage: typing.Optional[builtins.str] = None) -> None:
        """Determine the input fields and create the prediction dataset."""
        super().setup(stage)
        self._determine_input(self.subjects)
        self._setup_predict_dataset()

    def predict_dataloader(self) -> DataLoader:
        """Return the prediction loader and record its number of batches.

        The batch count is stored in ``self.total_batches``.
        """
        pred_dataloader: DataLoader = DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )
        self.total_batches = len(pred_dataloader)
        return pred_dataloader

    def _setup_predict_dataset(self) -> None:
        """Create ``self.predict_dataset``; must be overridden."""
        raise NotImplementedError

    def _collate_fn(
        self, batch: typing.List[tio.Subject]
    ) -> typing.Union[WholeImagePredictBatch, PatchesImagePredictBatch]:
        """Collate subjects into a prediction batch; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def add_arguments(
        parent_parser: ArgumentParser,
        add_csv: builtins.bool = True,
    ) -> ArgumentParser:
        """Add the prediction data options in a "Data" argument group.

        Args:
            parent_parser: parser that receives the argument group.
            add_csv: whether to add the required ``--predict-csv`` option.

        Returns:
            ``parent_parser``, with the group added.
        """
        parser = parent_parser.add_argument_group("Data")
        if add_csv:
            parser.add_argument(
                "--predict-csv",
                type=file_path(),
                required=True,
                default="SET ME!",
                help="path to csv of prediction images",
            )
        parser.add_argument(
            "-ps",
            "--patch-size",
            type=positive_int_or_none(),
            nargs="+",
            default=None,
            help="shape of patches (None -> crop image to foreground)",
        )
        parser.add_argument(
            "-po",
            "--patch-overlap",
            type=nonnegative_int(),
            nargs=3,
            default=None,
            help="patches will overlap by this much (None -> patch-size // 2)",
        )
        LesionSegDataModuleBase._add_common_arguments(parser)
        return parent_parser
class LesionSegDataModulePredictWhole(LesionSegDataModulePredictBase):
    """Data module for whole-image prediction for lesion segmentation

    Args:
        subjects (typing.List[tio.Subject]):
            list of torchio.Subject for prediction
        batch_size (int): number of images to predict at a time
        num_workers (int):
            number of subprocesses to use for data loading
        reorient_to_canonical (bool): prepend `tio.ToCanonical` to the
            prediction transforms
        use_memory_saving_dataset (bool): use `SubjectsDataset` instead
            of `tio.SubjectsDataset`
        **kwargs: accepted and ignored, so that a full argument namespace
            can be passed through `from_csv`
    """

    # noinspection PyUnusedLocal
    def __init__(  # type: ignore[no-untyped-def]
        self,
        *,
        subjects: typing.Union[tio.Subject, typing.List[tio.Subject]],
        batch_size: builtins.int,
        num_workers: builtins.int = 16,
        reorient_to_canonical: builtins.bool = True,
        use_memory_saving_dataset: builtins.bool = False,
        **kwargs,
    ):
        # `use_memory_saving_dataset` must be forwarded explicitly; left in
        # `**kwargs` it would be dropped and the option silently ignored.
        super().__init__(
            subjects=subjects,
            batch_size=batch_size,
            patch_size=None,
            num_workers=num_workers,
            pseudo3d_dim=None,
            pseudo3d_size=None,
            reorient_to_canonical=reorient_to_canonical,
            use_memory_saving_dataset=use_memory_saving_dataset,
        )

    @classmethod
    def from_csv(  # type: ignore[no-untyped-def]
        cls: typing.Type[PredictDataModule],
        predict_csv: builtins.str,
        **kwargs,
    ) -> PredictDataModule:
        """Build the data module from a CSV of images to predict on.

        Args:
            predict_csv: CSV file converted to the subject list.
            **kwargs: forwarded to the constructor; ``strict_affine``
                (default True) and ``check_dicom`` (default False) are
                also read here and passed to `csv_to_subjectlist`.
        """
        strict_affine = kwargs.get("strict_affine", True)
        check_dicom = kwargs.get("check_dicom", False)
        subject_list = csv_to_subjectlist(
            predict_csv,
            strict=strict_affine,
            check_dicom=check_dicom,
        )
        return cls(subjects=subject_list, **kwargs)

    def setup(self, stage: typing.Optional[builtins.str] = None) -> None:
        """Delegate to the base-class setup."""
        super().setup(stage)

    def _setup_predict_dataset(self) -> None:
        """Create ``self.predict_dataset`` over the whole images."""
        transforms = []
        if self.reorient_to_canonical:
            transforms.append(tio.ToCanonical())
        transforms.append(image_to_float())
        transform = tio.Compose(transforms)
        ds = SubjectsDataset if self.use_memory_saving_dataset else tio.SubjectsDataset
        subjects_dataset = ds(self.subjects, transform=transform)
        self.predict_dataset = subjects_dataset

    def _collate_fn(self, batch: typing.List[tio.Subject]) -> WholeImagePredictBatch:
        """Collate subjects into a `WholeImagePredictBatch`."""
        collated_batch: Batch = default_collate(batch)
        src: torch.Tensor = self._default_collate_fn(collated_batch)
        # assume all input images are co-registered
        # so arbitrarily choose first
        field: builtins.str = self._input_fields[0]
        first_field = collated_batch[field]
        assert isinstance(first_field, dict)
        affine: torch.Tensor = first_field[tio.AFFINE]
        path: typing.List[builtins.str] = [
            str(filepath) for filepath in first_field["path"]
        ]
        out_path = collated_batch["out"]  # path to save the prediction
        assert isinstance(out_path, list)
        out = WholeImagePredictBatch(
            src=src,
            affine=affine,
            path=path,
            out=out_path,
            reorient=self.reorient_to_canonical,
        )
        return out
class LesionSegDataModulePredictPatches(LesionSegDataModulePredictBase):
    """Data module for patch-based prediction for lesion segmentation

    Args:
        subject (tio.Subject):
            a torchio.Subject for prediction
        batch_size (int): number of patches to predict at a time
        patch_size (PatchShapeOption): patch size for training/validation
            if any element is None, use the corresponding image dim
        patch_overlap (typing.Optional[PatchShape]):
            overlap of each patch; if None a default derived from the
            patch size is used (see `_default_overlap`)
        num_workers (int):
            number of subprocesses to use for data loading
        pseudo3d_dim (typing.Optional[int]): concatenate images along this
            axis and swap it for channel dimension
        pseudo3d_size (typing.Optional[int]): number of slices to concatenate
            if pseudo3d_dim provided, must be an odd (usually small) integer
        reorient_to_canonical (bool): apply `tio.ToCanonical` to the
            subject before building the grid sampler
        use_memory_saving_dataset (bool): forwarded to the base class
        **kwargs: accepted and ignored
    """

    # noinspection PyUnusedLocal
    def __init__(  # type: ignore[no-untyped-def]
        self,
        *,
        subject: tio.Subject,
        batch_size: builtins.int = 1,
        patch_size: PatchShapeOption = (96, 96, 96),
        patch_overlap: typing.Optional[PatchShape] = None,
        num_workers: builtins.int = 16,
        pseudo3d_dim: typing.Optional[builtins.int] = None,
        pseudo3d_size: typing.Optional[builtins.int] = None,
        reorient_to_canonical: builtins.bool = True,
        use_memory_saving_dataset: builtins.bool = False,
        **kwargs,
    ):
        super().__init__(
            subjects=subject,
            batch_size=batch_size,
            patch_size=patch_size,
            num_workers=num_workers,
            pseudo3d_dim=pseudo3d_dim,
            pseudo3d_size=pseudo3d_size,
            reorient_to_canonical=reorient_to_canonical,
            use_memory_saving_dataset=use_memory_saving_dataset,
        )
        assert self.patch_size is not None
        # self.patch_size is the result from _determine_patch_size
        ps: PatchShapeOption = self.patch_size
        self._set_patch_size(subject, ps)
        # the default overlap uses `ps`, i.e. the size before None entries
        # were replaced by image dimensions
        self.patch_overlap: PatchShape = patch_overlap or self._default_overlap(ps)

    def _set_patch_size(
        self,
        subject: tio.Subject,
        patch_size: PatchShapeOption,
    ) -> None:
        """Store ``patch_size`` with falsy entries replaced by image dims.

        Raises:
            ValueError: if ``patch_size`` or the subject's spatial shape
                does not have length 3.
        """
        if len(patch_size) != 3:
            raise ValueError(
                "Patch size must have length 3 here. "
                f"Got {len(patch_size)}. Something went wrong."
            )
        image_dim = subject.spatial_shape
        if len(image_dim) != 3:
            raise ValueError(
                "Input image must be three-dimensional. "
                f"Got image dim of {len(image_dim)}."
            )
        patch_size_no_none = [ps or dim for ps, dim in zip(patch_size, image_dim)]
        ps_x, ps_y, ps_z = patch_size_no_none
        self.patch_size = (ps_x, ps_y, ps_z)

    def _default_overlap(self, patch_size: PatchShapeOption) -> PatchShape:
        """Compute the per-axis overlap used when none is given.

        Per axis: 0 where the patch size is None; ``pseudo3d_size - 1`` on
        the pseudo3d axis; otherwise half the patch size, rounded up to an
        even number.

        Raises:
            ValueError: if ``patch_size`` does not have length 2 or 3.
        """
        patch_overlap_list = []
        for i, ps in enumerate(patch_size):
            if ps is None:
                patch_overlap_list.append(0)
                continue
            if i == self.pseudo3d_dim:
                assert self.pseudo3d_size is not None
                patch_overlap_list.append(self.pseudo3d_size - 1)
                continue
            overlap = ps // 2
            if overlap % 2:
                overlap += 1
            patch_overlap_list.append(overlap)
        patch_overlap: PatchShape
        if len(patch_size) == 2:
            po_x, po_y = patch_overlap_list
            patch_overlap = (po_x, po_y)
        elif len(patch_size) == 3:
            po_x, po_y, po_z = patch_overlap_list
            patch_overlap = (po_x, po_y, po_z)
        else:
            raise ValueError(
                f"patch_size must have length 2 or 3. Got {len(patch_size)}."
            )
        return patch_overlap

    def _setup_predict_dataset(self) -> None:
        """Transform the subject and wrap it in a `tio.GridSampler`.

        Also records the path of the first input field in ``self.path``
        and a lightweight stand-in for the sampler in ``self.grid_obj``.
        """
        field = self._input_fields[0]
        self.path: builtins.str = self.subjects[field]["path"]
        # `subjects` is only one subject in this class
        if self.reorient_to_canonical:
            self.subjects = tio.ToCanonical()(self.subjects)
        self.subjects = image_to_float()(self.subjects)
        grid_sampler = tio.GridSampler(
            self.subjects,
            self.patch_size,
            self.patch_overlap,
            padding_mode="edge",
        )
        # need to create aggregator in LesionSegLightning* module, which expects
        # the grid sampler we don't want to send the whole sampler over though,
        # so create a makeshift object with the relevant attributes that duck types
        self.grid_obj = types.SimpleNamespace(
            subject=types.SimpleNamespace(
                spatial_shape=grid_sampler.subject.spatial_shape
            ),
            padding_mode=grid_sampler.padding_mode,
            patch_overlap=grid_sampler.patch_overlap,
        )
        self.predict_dataset = grid_sampler

    # flake8: noqa: E501
    def _collate_fn(self, batch: typing.List[tio.Subject]) -> PatchesImagePredictBatch:
        """Collate patches into a `PatchesImagePredictBatch`.

        Reads ``self.total_batches``, which `predict_dataloader` sets.
        """
        collated_batch = default_collate(batch)
        p3d = self._pseudo3d_dim_internal if self._use_pseudo3d else None
        # offset by batch/channel dims if pseudo3d used
        p3d_with_offset = (p3d + 2) if self._use_pseudo3d else None  # type: ignore[operator]
        src: torch.Tensor = self._default_collate_fn(
            collated_batch, cat_dim=p3d_with_offset
        )
        # assume input images are co-registered so arbitrarily choose first
        field: builtins.str = self._input_fields[0]
        affine: torch.Tensor = collated_batch[field][tio.AFFINE]
        out_path: builtins.str = collated_batch["out"]  # path to save the prediction
        locations: torch.Tensor = collated_batch[tio.LOCATION]
        out = PatchesImagePredictBatch(
            src=src,
            affine=affine,
            path=self.path,
            out=out_path,  # path to save the prediction
            locations=locations,
            grid_obj=self.grid_obj,
            pseudo3d_dim=p3d,
            total_batches=self.total_batches,
            reorient=self.reorient_to_canonical,
        )
        return out
class Mixup:
    """mixup for data augmentation

    See Also:
        Zhang, Hongyi, et al. "mixup: Beyond empirical risk minimization."
        arXiv preprint arXiv:1710.09412 (2017).

    Args:
        alpha (float): parameter for beta distribution
    """

    def __init__(self, alpha: float):
        self.alpha = alpha

    def _mixup_dist(self, device: torch.device) -> D.Distribution:
        """Return a symmetric ``Beta(alpha, alpha)`` distribution on `device`."""
        alpha = torch.tensor(self.alpha, device=device)
        dist = D.Beta(alpha, alpha)
        return dist

    def _mixup_coef(
        self, batch_size: builtins.int, device: torch.device
    ) -> torch.Tensor:
        """Sample one mixing coefficient per batch element."""
        dist = self._mixup_dist(device)
        lam: torch.Tensor = dist.sample((batch_size,))  # convex combination coef
        return lam

    def __call__(
        self, src: torch.Tensor, tgt: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Mix each sample with a randomly chosen sample of the same batch.

        Args:
            src: batch of inputs; the first dimension is the batch.
            tgt: batch of targets; cast to float before mixing.

        Returns:
            The mixed ``(src, tgt)`` pair, computed without gradients.
            The same coefficients and permutation are used for both.
        """
        with torch.no_grad():
            batch_size = src.shape[0]
            perm = torch.randperm(batch_size)
            lam = self._mixup_coef(batch_size, src.device)
            lam = reshape_for_broadcasting(lam, ndim=src.ndim)
            src = lam * src + (1 - lam) * src[perm]
            tgt = tgt.float()
            tgt = lam * tgt + (1 - lam) * tgt[perm]
        return src, tgt

    @staticmethod
    def add_arguments(parent_parser: ArgumentParser) -> ArgumentParser:
        """Add the mixup options in a "Mixup" argument group.

        Returns:
            ``parent_parser``, with the group added.
        """
        parser = parent_parser.add_argument_group("Mixup")
        parser.add_argument(
            "-mu",
            "--mixup",
            action="store_true",
            default=False,
            help="use mixup during training",
        )
        parser.add_argument(
            "-ma",
            "--mixup-alpha",
            type=positive_float(),
            default=0.4,
            help="mixup alpha parameter for beta distribution",
        )
        return parent_parser
def _to_float(tensor: torch.Tensor) -> torch.Tensor:
"""create separate func b/c lambda not pickle-able"""
return tensor.float()
def image_to_float() -> tio.Transform:
    """Build a transform casting intensity and label images to float32."""
    image_types = [tio.INTENSITY, tio.LABEL]
    caster = tio.Lambda(_to_float, types_to_apply=image_types)
    return caster
class RandomTranspose(
    tio.transforms.augmentation.RandomTransform,
    tio.SpatialTransform,
):
    """Permute the axes of every image in a subject with a random choice.

    One of the entries of `transposes` is drawn per subject and applied
    to all of its images.
    """

    transposes = ((0, 1, 2, 3), (0, 2, 1, 3), (0, 3, 1, 2))

    def apply_transform(self, subject: tio.Subject) -> tio.Subject:
        # draw once so all images of the subject share the same permutation
        axes = self.transposes[self.get_params()]
        for image in self.get_images(subject):
            image.set_data(image.data.permute(*axes))
        return subject

    @staticmethod
    def get_params() -> builtins.int:
        """Draw a random index (0, 1, or 2) into `transposes`."""
        return int(torch.randint(0, 3, (1,)).item())
class RandomRot90(
    tio.transforms.augmentation.RandomTransform,
    tio.SpatialTransform,
):
    """Rotate every image in a subject by a random multiple of 90 degrees.

    The rotation is applied in the plane of the last two dimensions and
    the same number of rotations is used for all images of the subject.
    """

    def apply_transform(self, subject: tio.Subject) -> tio.Subject:
        n_quarter_turns = self.get_params()
        for image in self.get_images(subject):
            tensor = image.data
            assert tensor.ndim == 4
            image.set_data(tensor.rot90(n_quarter_turns, (-2, -1)))
        return subject

    @staticmethod
    def get_params() -> builtins.int:
        """Draw the number of quarter turns (0, 1, 2, or 3)."""
        return int(torch.randint(0, 4, size=(1,)).item())
def _get_type(name: builtins.str) -> builtins.str:
    """Classify a csv column header by the kind of value its column holds.

    The comparison is case-insensitive.

    Args:
        name: column header.

    Returns:
        ``tio.LABEL`` for "label", ``"float"`` for "weight"/"div",
        ``"path"`` for "out", and ``tio.INTENSITY`` for everything else.
        A header that is not in ``RECOGNIZED_NAMES`` triggers a warning
        and is treated as a non-label image.
    """
    key = name.lower()
    if key == "label":
        return tio.LABEL
    if key in ("weight", "div"):
        return "float"
    if key == "out":
        return "path"
    if key not in RECOGNIZED_NAMES:
        warnings.warn(
            f"{name} not in known {RECOGNIZED_NAMES}. "
            f"Assuming {name} is a non-label image."
        )
    return tio.INTENSITY
def csv_to_subjectlist(
    filename: builtins.str,
    *,
    strict: builtins.bool = True,
    check_dicom: builtins.bool = False,
) -> typing.List[tio.Subject]:
    """Convert a csv file to a list of torchio subjects

    Args:
        filename: path to a csv file formatted with
            `subject` in a column, describing the
            id/name of the subject (must be unique).
            Row will fill in the filenames per type.
            Other column headers should be one of:
            cect, ct, flair, pd, pet, t1, t1c, t2,
            label, weight, div, out
            (`label` should correspond to a segmentation mask)
            (`weight` and `div` should correspond to a float)
            (`out` is stored on the subject unchanged)
            Any other header triggers a warning and its
            column is treated as a non-label image.
        strict: if affine matrices are different enough
            (according to torchio tolerance), raise a
            runtime error. Otherwise, resample the images
            of the subject to the first image. Also passed
            on to the DICOM spacing check when `check_dicom`
            is set.
        check_dicom: if true, check dicom images for uniform
            spacing and warn the user about image if there
            is serious non-uniformity in slice distances

    Returns:
        subject_list (typing.List[torchio.Subject]): list of torchio Subjects
    """
    df = pd.read_csv(filename, index_col="subject")
    # the kind of value in a column depends only on its header, so resolve
    # it once per column instead of once per row
    column_types = {name: _get_type(name) for name in df.columns.to_list()}
    subject_list = []
    for subject_name, row in df.iterrows():
        data = {}
        for name, val_type in column_types.items():
            val = row[name]
            if val_type == "float":
                data[name] = torch.tensor(val, dtype=torch.float32)
            elif val_type == "path":
                data[name] = val
            else:
                image_path = pathlib.Path(val)
                # only a directory is treated as a DICOM series here
                if image_path.is_dir() and check_dicom:
                    _check_spacing_between_dicom_slices(image_path, strict)
                data[name] = tio.Image(image_path, type=val_type)
        subject = tio.Subject(name=subject_name, **data)
        subject = _check_consistent_space_and_resample(subject, strict)
        subject_list.append(subject)
    return subject_list
def _check_consistent_space_and_resample(
    subject: tio.Subject,
    strict: builtins.bool = True,
) -> tio.Subject:
    """Check space of images in subject consistent; if not strict, resample.

    Args:
        subject: subject whose images are checked (and possibly modified
            in place).
        strict: if true, an inconsistent affine is left to raise from
            ``subject.check_consistent_affine()``. Otherwise a RuntimeError
            from that check is turned into a warning, and every image whose
            affine differs from that of the first image is resampled onto
            the first image (nearest-neighbor for label images).

    Returns:
        The same ``subject``, with mismatching images replaced by their
        resampled versions when ``strict`` is false.
    """
    # spatial shape always needs to be the same
    subject.check_consistent_spatial_shape()
    if strict:
        subject.check_consistent_affine()
        return subject
    # Compact number formatting while checking (presumably so the matrices in
    # torchio's error message stay readable). The context manager restores the
    # caller's print options even when the check or the resampling raises.
    with np.printoptions(precision=5, suppress=True):
        try:
            subject.check_consistent_affine()
        except RuntimeError as e:
            warnings.warn(f"{subject['name']} has inconsistent affine matrices.")
            logger.info(e)
            logger.info("Attempting to resample the images to be consistent.")
            affine = None
            first_image = None
            first_image_name = None
            iterable = subject.get_images_dict(intensity_only=False).items()
            for image_name, image in iterable:
                if affine is None:
                    # the first image defines the reference space
                    affine = image.affine
                    first_image = image
                    first_image_name = image_name
                elif not np.allclose(affine, image.affine, rtol=1e-6, atol=1e-6):
                    aff_mtx_dist = np.linalg.norm(affine - image.affine)
                    logger.info(
                        f"Frobenius dist. between {first_image_name} and {image_name} "
                        f"affine matrices {aff_mtx_dist:0.4e}"
                    )
                    if aff_mtx_dist >= 1e-4:
                        msg = (
                            "Distance between affine matrices is large. "
                            "Consider aborting and registering the images manually."
                        )
                        warnings.warn(msg)
                    # nearest-neighbor for label images, Resample's default otherwise
                    if image.type == tio.LABEL:
                        resampler = tio.Resample(first_image, "nearest")
                    else:
                        resampler = tio.Resample(first_image)
                    subject[image_name] = resampler(image)
    return subject
def _check_spacing_between_dicom_slices(
    dicom_dir: PathLike,
    strict: builtins.bool = True,
) -> None:
    """Check that the slices of a DICOM series are uniformly spaced.

    Reads every ``*.dcm`` file in ``dicom_dir``, orders the slices by
    ``InStackPositionNumber`` and compares the distances between consecutive
    ``ImagePositionPatient`` values with each other and with the
    ``SliceThickness`` of the first file read.

    If pydicom is not installed, or fewer than two ``*.dcm`` files are found,
    a warning is issued and nothing is validated.

    Args:
        dicom_dir: directory containing the ``*.dcm`` files of one series.
        strict: if true, raise on a detected problem; otherwise only warn.

    Raises:
        RuntimeError: if ``strict`` and the slice thickness disagrees with the
            median slice distance, or the slice distances vary by more
            than 5e-4.
    """
    try:
        import pydicom  # type: ignore[import]
    except ImportError:  # also covers its subclass ModuleNotFoundError
        warnings.warn("pydicom not found. Cannot validate DICOM image.")
        return
    images = [pydicom.dcmread(path) for path in pathlib.Path(dicom_dir).glob("*.dcm")]
    if len(images) < 2:
        # no distance between slices can be computed from zero or one slice
        warnings.warn(
            f"Found {len(images)} '*.dcm' file(s) in {dicom_dir}; "
            "need at least two slices to validate DICOM slice spacing."
        )
        return
    slice_thickness = float(images[0].SliceThickness)

    def get_stack_position(image: pydicom.dataset.Dataset) -> builtins.int:
        stack_position: builtins.int = image.InStackPositionNumber
        return stack_position

    sorted_images = sorted(images, key=get_stack_position)
    positions = np.array([img.ImagePositionPatient for img in sorted_images])
    # Euclidean distance between the positions of consecutive slices
    dist_between_slices = np.linalg.norm(np.diff(positions, axis=0), axis=1)
    median_dist_between_slices = np.median(dist_between_slices)
    diff_in_dist = np.abs(np.diff(dist_between_slices))
    # two slices give a single distance, hence no variation to measure
    max_diff_in_dist = diff_in_dist.max() if diff_in_dist.size else 0.0

    problems = []
    if not np.isclose(slice_thickness, median_dist_between_slices):
        problems.append(
            f"Slice thickness: {slice_thickness:0.6f} != "
            f"(Median) computed slice thickness {median_dist_between_slices:0.6f}"
        )
    if max_diff_in_dist > 5e-4:
        # TODO: why is max_diff_in_dist different from ITK "Maximum nonuniformity"
        problems.append(
            "Maximum difference in distance between slices: "
            f"{max_diff_in_dist:0.5e}."
        )
    if not problems:
        return
    msg = "\n".join(problems)
    if strict:
        raise RuntimeError(msg)
    warnings.warn(msg)