# Source code for tiramisu_brulee.experiment.cli.predict

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
tiramisu_brulee.experiment.cli.predict

command-line interface functions for predicting
lesion segmentations with Tiramisu neural network

Author: Jacob Reinhold (jcreinhold@gmail.com)
Created on: May 25, 2021
"""

# public API of this module: the two prediction entry points defined below
__all__ = [
    "predict",
    "predict_image",
]

import argparse
import builtins
import collections
import functools
import gc
import inspect
import logging
import os
import pathlib
import tempfile
import typing
from concurrent.futures import ProcessPoolExecutor
from operator import add, or_

import jsonargparse
import numpy as np
import pandas as pd
import torch
import torchio as tio
from pytorch_lightning import Trainer, seed_everything

from tiramisu_brulee.experiment.cli.common import (
    check_patch_size,
    handle_fast_dev_run,
    pseudo3d_dims_setup,
)
from tiramisu_brulee.experiment.data import (
    LesionSegDataModulePredictBase,
    LesionSegDataModulePredictPatches,
    LesionSegDataModulePredictWhole,
    csv_to_subjectlist,
)
from tiramisu_brulee.experiment.lesion_tools import clean_segmentation
from tiramisu_brulee.experiment.parse import (
    dict_to_csv,
    fix_type_funcs,
    generate_predict_config_yaml,
    get_experiment_directory,
    none_string_to_none,
    parse_unknown_to_dict,
    path_to_str,
    remove_args,
)
from tiramisu_brulee.experiment.seg import LesionSegLightningTiramisu
from tiramisu_brulee.experiment.type import (
    ArgParser,
    ArgType,
    ModelNum,
    Namespace,
    file_path,
)
from tiramisu_brulee.experiment.util import append_num_to_filename, setup_log


def predict_parser(use_python_argparse: builtins.bool = True) -> ArgParser:
    """build the ``lesion-predict`` argument parser

    Args:
        use_python_argparse: when true, build a standard-library
            ``argparse.ArgumentParser`` whose ``--config`` option uses the
            default action; otherwise build a ``jsonargparse.ArgumentParser``
            whose ``--config`` option uses ``jsonargparse.ActionConfigFile``.

    Returns:
        the parser with the shared prediction arguments (including the
        csv arguments) registered.
    """
    if use_python_argparse:
        parser_cls, config_action = argparse.ArgumentParser, None
    else:
        parser_cls = jsonargparse.ArgumentParser
        config_action = jsonargparse.ActionConfigFile
    parser = parser_cls(
        prog="lesion-predict",
        description="Use a Tiramisu CNN to segment lesions",
    )
    parser.add_argument(
        "--config",
        action=config_action,  # type: ignore[arg-type]
        help="path to a configuration file in json or yaml format",
    )
    # the only Trainer options that stay exposed on the command line
    kept_trainer_args = {
        "accelerator",
        "benchmark",
        "devices",
        "enable_progress_bar",
        "gpus",
        "precision",
        "progress_bar_refresh_rate",
        "strategy",
    }
    return _predict_parser_shared(parser, kept_trainer_args, True)


def predict_image_parser() -> argparse.ArgumentParser:
    """build the ``lesion-predict-image`` argument parser

    Returns:
        a standard-library parser with the shared prediction arguments
        registered, without the csv arguments (``add_csv=False``).
    """
    parser = argparse.ArgumentParser(
        prog="lesion-predict-image",
        description=(
            "Use a Tiramisu CNN to segment lesions for a single time-point prediction"
        ),
    )
    # the only Trainer options that stay exposed on the command line
    kept_trainer_args = {
        "accelerator",
        "benchmark",
        "devices",
        "enable_progress_bar",
        "gpus",
        "precision",
        "progress_bar_refresh_rate",
        "strategy",
    }
    return _predict_parser_shared(parser, kept_trainer_args, False)


def _predict_parser_shared(
    parser: ArgParser,
    necessary_trainer_args: builtins.set,
    add_csv: builtins.bool,
) -> ArgParser:
    """register the arguments common to both prediction command-line tools

    Args:
        parser: parser to which the arguments are added.
        necessary_trainer_args: names of ``Trainer`` parameters that should
            remain available as command-line options; every other parameter
            in the ``Trainer`` signature is removed from the parser (subject
            to whatever ``handle_fast_dev_run`` does to the removal set).
        add_csv: forwarded to ``LesionSegDataModulePredictBase.add_arguments``.

    Returns:
        the parser with experiment, model, data and trainer arguments added.
    """
    exp_parser = parser.add_argument_group("Experiment")
    exp_parser.add_argument(
        "-mp",
        "--model-path",
        type=file_path(),
        nargs="+",
        required=True,
        # NOTE(review): placeholder default on a required option; presumably
        # it is what shows up in generated config files -- confirm
        default=["SET ME!"],
        help="path(s) to the trained model checkpoint(s) used for prediction",
    )
    exp_parser.add_argument(
        "-sd",
        "--seed",
        type=int,
        default=0,
        help="set seed for reproducibility",
    )
    exp_parser.add_argument(
        "-oa",
        "--only-aggregate",
        action="store_true",
        default=False,
        help="only aggregate results (useful to test different thresholds)",
    )
    exp_parser.add_argument(
        "-at",
        "--aggregation-type",
        default="mean",
        choices=("mean", "vote", "union"),
        help="aggregate results with this method",
    )
    exp_parser.add_argument(
        "-v",
        "--verbosity",
        action="count",
        default=0,
        help="increase output verbosity (e.g., -vv is more than -v)",
    )
    parser = LesionSegLightningTiramisu.add_other_arguments(parser)
    parser = LesionSegLightningTiramisu.add_testing_arguments(parser)
    parser = LesionSegDataModulePredictBase.add_arguments(parser, add_csv=add_csv)
    # add every Trainer option, then strip the ones not explicitly requested
    parser = Trainer.add_argparse_args(parser)
    trainer_args = set(inspect.signature(Trainer).parameters.keys())
    unnecessary_args = trainer_args - necessary_trainer_args
    unnecessary_args = handle_fast_dev_run(unnecessary_args)
    remove_args(parser, unnecessary_args)
    fix_type_funcs(parser)
    return parser


def predict(args: ArgType = None) -> builtins.int:
    """run prediction with one or more Tiramisu CNNs

    Args:
        args: ``None`` to parse the process command line, a list of strings
            to parse, or anything else, which is handed to the prediction
            routine unchanged (i.e., treated as already parsed).

    Returns:
        0 once prediction has finished.
    """
    parser = predict_parser(use_python_argparse=False)
    needs_parsing = args is None or isinstance(args, list)
    if needs_parsing:
        argv = () if args is None else (args,)
        args = parser.parse_args(  # type: ignore[call-overload]
            *argv, _skip_check=True
        )
    _predict(args, parser, True)
    return 0


def predict_image(args: ArgType = None) -> builtins.int:
    """use a Tiramisu CNN for prediction for a single time-point

    The options the parser does not recognize are converted to a dictionary
    by ``parse_unknown_to_dict`` and written to a temporary csv file, which
    is then passed to the prediction routine as ``args.predict_csv``.

    Args:
        args: ``None`` to parse the process command line, or a list of
            strings to parse.

    Returns:
        0 once prediction has finished.

    Raises:
        ValueError: if ``args`` is neither ``None`` nor a list.
    """
    parser = predict_image_parser()
    if args is None:
        args, unknown = parser.parse_known_args()
    elif isinstance(args, list):
        args, unknown = parser.parse_known_args(args)
    else:
        raise ValueError("input args must be None or a list of strings to parse")
    modality_paths = parse_unknown_to_dict(unknown)
    # delete=False: the file has to outlive the ``with`` block so that the
    # prediction code can reopen it by name
    csv_file = tempfile.NamedTemporaryFile("w", delete=False)
    try:
        with csv_file:
            dict_to_csv(modality_paths, csv_file)  # type: ignore[arg-type]
        args.predict_csv = csv_file.name
        _predict(args, parser, False)
    finally:
        # remove the temporary csv even when writing it or predicting fails
        os.remove(csv_file.name)
    return 0


def _predict_whole_image(
    args: Namespace,
    model_path: pathlib.Path,
    model_num: ModelNum,
) -> None:
    """predict a whole image volume as opposed to patches

    Args:
        args: parsed arguments; used to build the trainer and, as keyword
            arguments, the whole-image data module.
        model_path: checkpoint from which the model is loaded.
        model_num: passed to the model as ``_model_num``.
    """
    dict_args = vars(args)
    pp = args.predict_probability
    trainer = Trainer.from_argparse_args(args)
    dm = LesionSegDataModulePredictWhole.from_csv(**dict_args)
    model = LesionSegLightningTiramisu.load_from_checkpoint(
        str(model_path),
        predict_probability=pp,
        _model_num=model_num,
    )
    logging.debug(model)
    trainer.predict(model, datamodule=dm)
    # kill multiprocessing workers for next iteration
    dm.teardown()
    trainer.teardown()


def _predict_patch_image(
    args: Namespace,
    model_path: pathlib.Path,
    model_num: ModelNum,
    pseudo3d_dim: typing.Union[None, builtins.int],
) -> None:
    """predict a volume with patches as opposed to a whole volume

    The model is loaded once; a fresh trainer and data module are created
    for every subject listed in ``args.predict_csv``.

    Args:
        args: parsed arguments; used to build each trainer and, as keyword
            arguments, each patch data module.
        model_path: checkpoint from which the model is loaded.
        model_num: passed to the model as ``_model_num``.
        pseudo3d_dim: value stored under ``"pseudo3d_dim"`` before the data
            modules are built.
    """
    # NOTE(review): vars() returns the namespace's own dict for an
    # argparse.Namespace, so the assignment below also changes
    # ``args.pseudo3d_dim`` for the caller -- confirm that this is intended
    dict_args = vars(args)
    dict_args["pseudo3d_dim"] = pseudo3d_dim
    pp = args.predict_probability
    subject_list = csv_to_subjectlist(args.predict_csv)
    model = LesionSegLightningTiramisu.load_from_checkpoint(
        str(model_path),
        predict_probability=pp,
        _model_num=model_num,
    )
    logging.debug(model)
    for subject in subject_list:
        trainer = Trainer.from_argparse_args(args)
        dm = LesionSegDataModulePredictPatches(subject=subject, **dict_args)
        trainer.predict(model, datamodule=dm)
        # kill multiprocessing workers, free memory for the next iteration
        dm.teardown()
        trainer.teardown()
        del dm, trainer
        gc.collect()


def _predict(
    args: Namespace, parser: ArgParser, use_multiprocessing: builtins.bool
) -> None:
    """run every model, aggregate the outputs and write the config files

    Args:
        args: parsed arguments.
        parser: parser that produced ``args``; forwarded to the config-file
            generation step.
        use_multiprocessing: when false, aggregation is run with zero
            workers (i.e., sequentially in this process) regardless of
            ``args.num_workers``.
    """
    args = none_string_to_none(args)
    args = path_to_str(args)
    setup_log(args.verbosity)
    logger = logging.getLogger(__name__)
    seed_everything(args.seed, workers=True)
    n_models = len(args.model_path)
    if not args.only_aggregate:
        # with several models, probabilities are always predicted so that
        # the aggregation step has something to threshold
        args.predict_probability = n_models > 1 or args.predict_probability
        patch_predict = args.patch_size is not None
        use_pseudo3d = args.pseudo3d_dim is not None
        if patch_predict:
            check_patch_size(args.patch_size, use_pseudo3d)
        pseudo3d_dims = pseudo3d_dims_setup(args.pseudo3d_dim, n_models, "predict")
        predict_iter_data = zip(args.model_path, pseudo3d_dims)
        for i, (model_path, p3d) in enumerate(predict_iter_data, 1):
            model_num = ModelNum(num=i, out_of=n_models)
            nth_model = f" ({i}/{n_models})"
            if patch_predict:
                _predict_patch_image(args, model_path, model_num, p3d)
            else:
                _predict_whole_image(args, model_path, model_num)
            logger.info("Finished prediction" + nth_model)
            # force garbage collection for the next iteration
            gc.collect()
            torch.cuda.empty_cache()
    if n_models > 1:
        num_workers = args.num_workers if use_multiprocessing else 0
        aggregate(
            args.predict_csv,
            n_models,
            threshold=args.threshold,
            fill_holes=args.fill_holes,
            min_lesion_size=args.min_lesion_size,
            aggregation_type=args.aggregation_type,
            num_workers=num_workers,
        )
    _generate_config_yamls_in_predict(args, parser)


def _generate_config_yamls_in_predict(args: ArgType, parser: ArgParser) -> None:
    """write prediction config files next to the models, when applicable

    Nothing is written for a fast-dev run, for an aggregate-only run, or
    when ``parser`` has no ``get_defaults`` attribute.
    """
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    is_fast_dev_run = getattr(args, "fast_dev_run", False)
    if is_fast_dev_run or args.only_aggregate:
        return
    if not hasattr(parser, "get_defaults"):
        return
    exp_dirs = [get_experiment_directory(mp) for mp in args.model_path]
    generate_predict_config_yaml(exp_dirs, parser, vars(args))


def aggregate(
    predict_csv: builtins.str,
    n_models: builtins.int,
    *,
    threshold: builtins.float = 0.5,
    fill_holes: builtins.bool = False,
    min_lesion_size: builtins.int = 3,
    aggregation_type: builtins.str = "mean",
    num_workers: typing.Optional[builtins.int] = None,
) -> None:
    """aggregate output from multiple model predictions

    Args:
        predict_csv: csv file whose ``out`` column lists the output
            filenames to aggregate.
        n_models: number of per-model predictions per output filename.
        threshold: value a prediction must exceed to count as positive.
        fill_holes: forwarded to ``clean_segmentation``.
        min_lesion_size: forwarded to ``clean_segmentation`` as
            ``minimum_lesion_size``.
        aggregation_type: one of ``"mean"``, ``"vote"`` or ``"union"``.
        num_workers: size of the process pool; ``None`` is passed to
            ``ProcessPoolExecutor`` unchanged, and ``0`` (or less) runs the
            aggregation sequentially in the current process.

    Raises:
        Exception: whatever a file's aggregation raises, in both the
            sequential and the process-pool code path.
    """
    df = pd.read_csv(predict_csv)
    out_fns = df["out"]
    n_fns = len(out_fns)
    out_fn_iter = enumerate(out_fns, 1)
    _aggregator = functools.partial(
        _aggregate,
        threshold=threshold,
        n_models=n_models,
        n_fns=n_fns,
        fill_holes=fill_holes,
        min_lesion_size=min_lesion_size,
        aggregation_type=aggregation_type,
    )
    use_multiprocessing = num_workers is None or num_workers > 0
    if use_multiprocessing:
        with ProcessPoolExecutor(num_workers) as executor:
            # drain the result iterator: an exception raised in a worker is
            # only re-raised here when its result is retrieved
            collections.deque(executor.map(_aggregator, out_fn_iter), maxlen=0)
    else:
        collections.deque(map(_aggregator, out_fn_iter), maxlen=0)


# noinspection PyUnboundLocalVariable
def _aggregate(
    n_fn: typing.Tuple[builtins.int, builtins.str],
    *,
    threshold: builtins.float,
    n_models: builtins.int,
    n_fns: builtins.int,
    fill_holes: builtins.bool,
    min_lesion_size: builtins.int,
    aggregation_type: builtins.str = "mean",
) -> None:
    """aggregate helper for concurrent/parallel processing

    Loads the ``n_models`` numbered variants of one output filename,
    combines them into a single binary segmentation, cleans it and saves
    it under the un-numbered filename.

    Args:
        n_fn: ``(index, filename)`` pair; the index is only used for logging.
        threshold: value a prediction must exceed to count as positive.
        n_models: number of numbered files to load (must be at least 1).
        n_fns: total number of filenames; only used for logging.
        fill_holes: forwarded to ``clean_segmentation``.
        min_lesion_size: forwarded to ``clean_segmentation`` as
            ``minimum_lesion_size``.
        aggregation_type: ``"mean"`` thresholds the average prediction,
            ``"vote"`` keeps voxels positive in more than half of the
            models, ``"union"`` keeps voxels positive in any model.

    Raises:
        ValueError: if ``aggregation_type`` is not a supported method.
    """
    assert n_models >= 1
    n, fn = n_fn
    data = []
    for i in range(1, n_models + 1):
        _fn = append_num_to_filename(fn, num=i)
        image = tio.ScalarImage(_fn)
        array = image.numpy()
        data.append(array.squeeze())
    if aggregation_type == "mean":
        agg = np.mean(data, axis=0) > threshold
    elif aggregation_type == "vote":
        # count the positive votes with an integer sum; combining boolean
        # arrays with ``+`` is a logical OR in numpy, so the result could
        # never exceed a majority threshold of 1 or more
        votes = np.sum([d > threshold for d in data], axis=0)
        agg = votes > (len(data) // 2)
    elif aggregation_type == "union":
        agg = functools.reduce(or_, [d > threshold for d in data])
    else:
        raise ValueError(
            "aggregation_type should be one of {mean, vote, union}. "
            f"Got {aggregation_type}."
        )
    agg = clean_segmentation(
        agg, fill_holes=fill_holes, minimum_lesion_size=min_lesion_size
    )
    # reuse the last loaded image (and its dtype) as the container for the
    # aggregated result; the data gets a new leading axis before it is stored
    agg = agg.astype(array.dtype)
    image.set_data(agg[np.newaxis])
    image.save(fn)
    logging.info(f"Save aggregated prediction: {fn} ({n}/{n_fns})")