# Source code for tiramisu_brulee.experiment.cli.common

"""Common functions for predict and train CLIs
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: Jul 30, 2021
"""

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "check_patch_size",
    "EXPERIMENT_NAME",
    "handle_fast_dev_run",
    "pseudo3d_dims_setup",
    "tiramisu_brulee_info",
]

import builtins
import pathlib
import subprocess  # nosec
import sys
import typing

from tiramisu_brulee.experiment.type import TiramisuBruleeInfo

# Experiment name constant exported via ``__all__``; presumably shared by the
# train/predict CLIs (see module docstring) -- confirm at call sites.
EXPERIMENT_NAME = "lesion_tiramisu_experiment"


def check_patch_size(
    patch_size: typing.List[builtins.int], use_pseudo3d: builtins.bool
) -> None:
    """Validate the number of elements in ``patch_size``.

    A pseudo-3D (2D) network requires exactly 2 patch size elements and a
    3D network requires exactly 3.

    Args:
        patch_size: sequence of patch dimensions to validate.
        use_pseudo3d: True for a pseudo-3D (2D) network, False for 3D.

    Raises:
        ValueError: if the element count does not match the network type.
    """
    n_patch_elems = len(patch_size)
    if use_pseudo3d and n_patch_elems != 2:
        raise ValueError(
            "Number of patch size elements must be 2 for "
            f"pseudo-3D (2D) network. Got {n_patch_elems}."
        )
    if not use_pseudo3d and n_patch_elems != 3:
        raise ValueError(
            "Number of patch size elements must be 3 for "
            f"a 3D network. Got {n_patch_elems}."
        )
def handle_fast_dev_run(
    unnecessary_args: typing.Set[builtins.str],
) -> typing.Set[builtins.str]:
    """Flag ``fast_dev_run`` as unnecessary when running on Python 3.6.

    fast_dev_run is problematic with py36, so on that interpreter the name
    ``"fast_dev_run"`` is added to ``unnecessary_args``. The set is modified
    in place and the same object is returned.
    """
    major, minor = sys.version_info[:2]
    assert major == 3
    running_py36 = minor == 6
    if running_py36:
        unnecessary_args.add("fast_dev_run")
    return unnecessary_args
def pseudo3d_dims_setup(
    pseudo3d_dim: typing.Optional[typing.List[builtins.int]],
    n_models: builtins.int,
    stage: builtins.str,
) -> typing.Union[typing.List[None], typing.List[builtins.int]]:
    """Expand ``pseudo3d_dim`` into one entry per model.

    Args:
        pseudo3d_dim: ``None`` (3D network for every model), a single value
            shared by all models, or exactly ``n_models`` values (one per
            model).
        n_models: number of models being trained or used for prediction.
        stage: ``"train"`` or ``"predict"``; only affects the wording of the
            error message.

    Returns:
        A new list with one entry per model (``[None] * n_models`` when
        ``pseudo3d_dim`` is None). The caller's list is never returned
        directly.

    Raises:
        ValueError: if ``stage`` is not ``"train"``/``"predict"``, or if the
            number of ``pseudo3d_dim`` values is neither 1 nor ``n_models``.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if stage not in ("train", "predict"):
        raise ValueError(f"stage must be 'train' or 'predict'. Got {stage!r}.")
    if pseudo3d_dim is None:
        return [None] * n_models
    # Verb stem for the error message: "trained" / "used".
    verb = "us" if stage == "predict" else stage
    n_p3d = len(pseudo3d_dim)
    if n_p3d == 1:
        return pseudo3d_dim * n_models
    if n_p3d == n_models:
        # Copy so the result does not alias the caller's list.
        return list(pseudo3d_dim)
    raise ValueError(
        "pseudo3d_dim must be None (for 3D network), 1 value to be used "
        f"across all models to be {verb}ed, or N values corresponding to each "
        f"of the N models to be {verb}ed. Got {n_p3d} != {n_models}."
    )
def tiramisu_brulee_info(*, short: builtins.bool = True) -> TiramisuBruleeInfo:
    """Get the git commit hash and version for tiramisu-brulee.

    The commit is obtained by running ``git rev-parse HEAD`` (``git
    rev-parse --short HEAD`` when ``short`` is true) in the parent directory
    of the imported ``tiramisu_brulee`` package. If the command fails or
    ``git`` cannot be executed at all, the commit is reported as
    ``"unknown"``.

    Args:
        short: request the abbreviated commit hash.

    Returns:
        TiramisuBruleeInfo holding ``tiramisu_brulee.__version__`` and the
        commit string.
    """
    import tiramisu_brulee

    tb_path = str(pathlib.Path(tiramisu_brulee.__file__).parents[1])
    cmd = ["git", "rev-parse", "HEAD"]
    if short:
        cmd.insert(2, "--short")
    try:
        out = subprocess.run(  # nosec
            cmd, cwd=tb_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
    except OSError:
        # subprocess.run raises (e.g. FileNotFoundError) when the git
        # executable is missing; fall back like any other git failure.
        commit = "unknown"
    else:
        if out.returncode == 0:
            commit = out.stdout.decode("ascii").strip()
        else:
            commit = "unknown"
    return TiramisuBruleeInfo(version=tiramisu_brulee.__version__, commit=commit)