Source code for tiramisu_brulee.experiment.cli.train

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
tiramisu_brulee.experiment.cli.train

command-line interface functions for training
lesion segmentation Tiramisu neural networks

Author: Jacob Reinhold (jcreinhold@gmail.com)
Created on: May 25, 2021
"""

# public API of this module: ``train`` is the only name listed for export
__all__ = [
    "train",
]

import argparse
import builtins
import copy
import gc
import itertools
import logging
import os
import pathlib
import time
import typing
import warnings

import jsonargparse
import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin

from tiramisu_brulee.experiment.cli.common import (
    EXPERIMENT_NAME,
    check_patch_size,
    handle_fast_dev_run,
    pseudo3d_dims_setup,
    tiramisu_brulee_info,
)
from tiramisu_brulee.experiment.cli.predict import predict_parser
from tiramisu_brulee.experiment.data import LesionSegDataModuleTrain, Mixup
from tiramisu_brulee.experiment.parse import (
    fix_type_funcs,
    generate_predict_config_yaml,
    generate_train_config_yaml,
    get_best_model_path,
    get_experiment_directory,
    none_string_to_none,
    path_to_str,
    remove_args,
)
from tiramisu_brulee.experiment.seg import LesionSegLightningTiramisu
from tiramisu_brulee.experiment.type import ArgParser, ArgType
from tiramisu_brulee.experiment.util import setup_log

# num of dataloader workers is set to 0 for compatibility w/ torchio, so ignore warning
# (the message argument is a regex that must match the start of the warning text)
warnings.filterwarnings("ignore", ".*does not have many workers*", category=UserWarning)


def train_parser(use_python_argparse: builtins.bool = True) -> ArgParser:
    """argument parser for training a Tiramisu CNN

    Args:
        use_python_argparse: build the parser with the standard-library
            ``argparse`` when True; otherwise build it with ``jsonargparse``,
            in which case ``--config`` is handled by
            ``jsonargparse.ActionConfigFile`` and ``n_epochs`` is linked to
            both ``min_epochs`` and ``max_epochs``.

    Returns:
        parser holding the experiment options defined here plus whatever the
        model, data module, mixup and ``Trainer`` helpers register, minus the
        arguments listed as unnecessary below.
    """
    if use_python_argparse:
        ArgumentParser = argparse.ArgumentParser
        config_action = None
    else:
        ArgumentParser = jsonargparse.ArgumentParser
        config_action = jsonargparse.ActionConfigFile
    desc = "Train a Tiramisu CNN to segment lesions"
    parser = ArgumentParser(
        prog="lesion-train",
        description=desc,
    )
    parser.add_argument(
        "--config",
        action=config_action,  # type: ignore[arg-type]
        help="path to a configuration file in json or yaml format",
    )
    exp_parser = parser.add_argument_group("Experiment")
    exp_parser.add_argument(
        "-sd",
        "--seed",
        type=int,
        default=0,
        help="set seed for reproducibility",
    )
    exp_parser.add_argument(
        "-v",
        "--verbosity",
        action="count",
        default=0,
        help="increase output verbosity (e.g., -vv is more than -v)",
    )
    exp_parser.add_argument(
        "-uri",
        "--tracking-uri",
        type=str,
        default=None,
        help="use this URI for tracking metrics and artifacts with an MLFlow server",
    )
    exp_parser.add_argument(
        "-md",
        "--model-dir",
        type=str,
        default=None,
        help="save models to this directory if provided, "
        "otherwise save in default_root_dir",
    )
    exp_parser.add_argument(
        "-en",
        "--experiment-name",
        type=str,
        default=EXPERIMENT_NAME,
        help="name of the experiment/output directory",
    )
    exp_parser.add_argument(
        "-tn",
        "--trial-name",
        type=str,
        default=None,
        help="name of the trial/output version directory",
    )
    exp_parser.add_argument(
        "-stk",
        "--save-top-k",
        type=int,
        default=3,
        help="save the top k models in checkpoints",
    )
    parser = LesionSegLightningTiramisu.add_io_arguments(parser)
    parser = LesionSegLightningTiramisu.add_model_arguments(parser)
    parser = LesionSegLightningTiramisu.add_other_arguments(parser)
    parser = LesionSegLightningTiramisu.add_training_arguments(parser)
    parser = LesionSegDataModuleTrain.add_arguments(parser)
    parser = Mixup.add_arguments(parser)
    parser = Trainer.add_argparse_args(parser)
    unnecessary_args = {
        "enable_checkpointing",
        "in_channels",
        "logger",
        "max_steps",
        "min_steps",
        "out_channels",
        "weights_save_path",
    }
    if use_python_argparse:
        # set.union() returns a new set and leaves the receiver untouched, so
        # the previous call was a no-op; update in place so that these two
        # arguments are really scheduled for removal
        unnecessary_args |= {"min_epochs", "max_epochs"}
    else:
        parser.link_arguments("n_epochs", "min_epochs")  # type: ignore[attr-defined]
        parser.link_arguments("n_epochs", "max_epochs")  # type: ignore[attr-defined]
    unnecessary_args = handle_fast_dev_run(unnecessary_args)
    remove_args(parser, unnecessary_args)
    fix_type_funcs(parser)
    return parser


def train(
    args: ArgType = None,
    *,
    return_best_model_paths: builtins.bool = False,
) -> typing.Union[typing.List[pathlib.Path], builtins.int]:
    """train a Tiramisu CNN for segmentation

    One model is trained for every (train_csv, valid_csv, pseudo3d_dim)
    triple; a fresh trainer and checkpoint callback are created per model.

    Args:
        args: ``None`` to let the parser read its default argument source,
            a list of argument strings to parse, or an already-parsed
            namespace, which is passed on without parsing.
        return_best_model_paths: return the best checkpoint paths instead
            of an exit status.

    Returns:
        the list of best model paths (one per trained model) when
        ``return_best_model_paths`` is True, otherwise ``0``.
    """
    parser = train_parser(False)
    if args is None:
        args = parser.parse_args(_skip_check=True)  # type: ignore[call-overload]
    elif isinstance(args, list):
        args = parser.parse_args(args, _skip_check=True)  # type: ignore[call-overload]
    args = none_string_to_none(args)
    setup_log(args.verbosity)
    logger = logging.getLogger(__name__)
    seed_everything(args.seed, workers=True)
    args = path_to_str(args)
    n_models_to_train = _compute_num_models_to_train(args)
    best_model_paths: typing.List[pathlib.Path] = []
    use_pseudo3d = args.pseudo3d_dim is not None
    check_patch_size(args.patch_size, use_pseudo3d)
    pseudo3d_dims = pseudo3d_dims_setup(args.pseudo3d_dim, n_models_to_train, "train")
    # per-model keyword arguments; copied so the parsed args stay untouched
    individual_run_args = copy.deepcopy(vars(args))
    individual_run_args["network_dim"] = 2 if use_pseudo3d else 3
    channels_per_image = args.pseudo3d_size if use_pseudo3d else 1
    individual_run_args["in_channels"] = args.num_input * channels_per_image
    train_iter_data = zip(args.train_csv, args.valid_csv, pseudo3d_dims)
    trainer: typing.Optional[Trainer] = None
    for i, (train_csv, valid_csv, p3d) in enumerate(train_iter_data, 1):
        trainer, checkpoint_callback = _setup_trainer_and_checkpoint(args)
        nth_model = f" ({i}/{n_models_to_train})"
        individual_run_args["train_csv"] = train_csv
        individual_run_args["valid_csv"] = valid_csv
        individual_run_args["pseudo3d_dim"] = p3d
        dm = LesionSegDataModuleTrain.from_csv(**individual_run_args)
        model = LesionSegLightningTiramisu(**individual_run_args)
        logger.debug(model)
        if args.auto_scale_batch_size or args.auto_lr_find:
            tuning_output = trainer.tune(model, datamodule=dm)
            msg = f"Tune output{nth_model}:\n{tuning_output}"
            logger.info(msg)
        trainer.fit(model, datamodule=dm)
        best_model_path = get_best_model_path(checkpoint_callback)
        best_model_paths.append(best_model_path)
        msg = f"Best model path: {best_model_path}" + nth_model + "\n"
        msg += "Finished training" + nth_model
        logger.info(msg)
        # kill multiprocessing workers, free memory for the next iteration
        dm.teardown()
        trainer.teardown()
        del dm, model, checkpoint_callback
        # the trainer of the last model is kept: its logger is needed below
        if i != n_models_to_train:
            del trainer
            trainer = None
        # NOTE(review): nesting of the three statements below was rebuilt from
        # a whitespace-collapsed source; placed at loop level -- please confirm
        gc.collect()
        torch.cuda.empty_cache()
        if n_models_to_train > 1 and args.num_workers > 0:
            time.sleep(5.0)
    # trainer is still None when there was nothing to train
    if trainer is not None:
        assert isinstance(trainer.logger, (TensorBoardLogger, MLFlowLogger))
        _generate_config_yamls_in_train(args, best_model_paths, trainer.logger)
    return best_model_paths if return_best_model_paths else 0


def _compute_num_models_to_train(args: ArgType) -> builtins.int:
    """number of models to train, i.e., the number of training CSVs

    Raises:
        ValueError: if the numbers of training and validation CSVs differ.
    """
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    n_models_to_train = len(args.train_csv)
    if n_models_to_train != len(args.valid_csv):
        raise ValueError(
            "Number of training and validation CSVs must be equal.\n"
            f"Got {n_models_to_train} != {len(args.valid_csv)}"
        )
    return n_models_to_train


def _format_checkpoints(args: ArgType) -> typing.Dict[builtins.str, typing.Any]:
    """keyword arguments for the checkpoint callback built from the parsed args

    The tracked metric is monitored as ``val_<track_metric>``; it is minimized
    when it is the loss and maximized otherwise.
    """
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    checkpoint_format = "{epoch}-{val_loss:.3f}"
    if args.track_metric != "loss":
        # also put the tracked (non-loss) metric into the checkpoint filename
        checkpoint_format += f"-{{val_{args.track_metric}:.3f}}"
    checkpoint_kwargs = dict(
        dirpath=args.model_dir,
        filename=checkpoint_format,
        monitor=f"val_{args.track_metric}",
        save_top_k=args.save_top_k,
        save_last=True,
        mode="max" if args.track_metric != "loss" else "min",
        every_n_epochs=args.checkpoint_every_n_epochs,
    )
    return checkpoint_kwargs


def _artifact_directory(args: ArgType) -> builtins.str:
    """resolved ``default_root_dir`` if given, else the current working directory"""
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    if args.default_root_dir is not None:
        artifact_dir = pathlib.Path(args.default_root_dir).resolve()
    else:
        artifact_dir = pathlib.Path.cwd()
    return str(artifact_dir)


def _setup_experiment_logger(
    args: ArgType,
) -> typing.Union[TensorBoardLogger, MLFlowLogger]:
    """create the experiment logger

    An ``MLFlowLogger`` is used when a tracking URI is given. Otherwise a
    ``TensorBoardLogger`` is used, writing to ``/opt/ml/output/tensorboard``
    if that directory exists and ``TIRAMISU_IGNORE_TB_DIR`` is not set to a
    true value, and below the artifact directory in every other case.
    """
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    exp_logger: typing.Union[TensorBoardLogger, MLFlowLogger]
    if args.tracking_uri is not None:
        exp_logger = MLFlowLogger(
            experiment_name=args.experiment_name,
            run_name=args.trial_name,
            tracking_uri=args.tracking_uri,
        )
    else:
        artifact_dir = _artifact_directory(args)
        # os.getenv returns a string, so bool() of it was True for any
        # non-empty value, including "false" and "0"; treat the common
        # negative spellings like an unset variable instead
        ignore_env = os.getenv("TIRAMISU_IGNORE_TB_DIR", "")
        ignore_tensorboard_dir = ignore_env.strip().lower() not in (
            "",
            "0",
            "false",
            "no",
            "off",
        )
        tensorboard_dir = pathlib.Path(
            "/opt/ml/output/tensorboard"
        ).resolve()  # for SageMaker
        if not tensorboard_dir.is_dir() or ignore_tensorboard_dir:
            exp_logger = TensorBoardLogger(
                artifact_dir,
                name=args.experiment_name,
                version=args.trial_name,
                sub_dir="tensorboard",
            )
        else:
            logging.info(f"Saving tensorboard logs to {tensorboard_dir}.")
            logging.info("Set an env variable TIRAMISU_IGNORE_TB_DIR=true to prevent.")
            exp_logger = TensorBoardLogger(str(tensorboard_dir), name="", version="")
    return exp_logger


def _generate_config_yamls_in_train(
    args: ArgType,
    best_model_paths: typing.List[pathlib.Path],
    logger: typing.Union[TensorBoardLogger, MLFlowLogger],
) -> None:
    """write train and predict config files for the trained models

    Nothing is generated when ``args.fast_dev_run`` exists and is truthy.
    When a tracking URI is set, the generated configs are also logged as
    artifacts of the MLFlow run.
    """
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    has_fdr = hasattr(args, "fast_dev_run")
    generate_config_yaml = (not args.fast_dev_run) if has_fdr else True
    if generate_config_yaml:
        n_models_to_train = _compute_num_models_to_train(args)
        # shallow copy: vars() exposes the namespace's own dict, and the
        # pseudo3d_dim override below must not leak back into ``args``
        dict_args = dict(vars(args))
        exp_dirs = [get_experiment_directory(bmp) for bmp in best_model_paths]
        if args.pseudo3d_dim == "all":
            dict_args["pseudo3d_dim"] = [0, 1, 2] * n_models_to_train
            # repeat every model path once per pseudo3d axis
            best_model_paths = [bmp for bmp in best_model_paths for _ in range(3)]
        else:
            dict_args["pseudo3d_dim"] = args.pseudo3d_dim
        train_cfgs = generate_train_config_yaml(
            exp_dirs=exp_dirs,
            dict_args=dict_args,
            best_model_paths=best_model_paths,
            parser=train_parser(False),
        )
        predict_cfgs = generate_predict_config_yaml(
            exp_dirs=exp_dirs,
            dict_args=dict_args,
            best_model_paths=best_model_paths,
            parser=predict_parser(False),
        )
        if args.tracking_uri is not None:
            assert isinstance(logger, MLFlowLogger)
            run_id = logger.run_id
            for cfg in itertools.chain(train_cfgs, predict_cfgs):
                logger.experiment.log_artifact(run_id, cfg)


class MLFlowModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that logs the best model path to MLFlow after each save"""

    def __init__(  # type: ignore[no-untyped-def]
        self,
        mlflow_logger: MLFlowLogger,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # logger whose run receives the checkpoint artifacts
        self.mlflow_logger = mlflow_logger

    # flake8: noqa: E501
    def save_checkpoint(
        self,
        trainer: Trainer,
        unused: typing.Optional[LightningModule] = None,
    ) -> None:
        """save the checkpoint, then log the current best model as an artifact"""
        # the parent signature may or may not accept ``unused``
        # (presumably it differs between pytorch-lightning versions)
        try:
            super().save_checkpoint(trainer=trainer, unused=unused)  # type: ignore[call-arg]
        except TypeError:
            super().save_checkpoint(trainer=trainer)
        run_id = self.mlflow_logger.run_id
        self.mlflow_logger.experiment.log_artifact(run_id, self.best_model_path)


def _create_checkpoint_callback(
    args: ArgType,
    logger: typing.Union[TensorBoardLogger, MLFlowLogger],
) -> typing.Union[ModelCheckpoint, MLFlowModelCheckpoint]:
    """checkpoint callback; the MLFlow-aware variant when a tracking URI is set"""
    checkpoint_kwargs = _format_checkpoints(args)
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    if args.tracking_uri is None:
        return ModelCheckpoint(**checkpoint_kwargs)
    else:
        assert isinstance(logger, MLFlowLogger)
        return MLFlowModelCheckpoint(mlflow_logger=logger, **checkpoint_kwargs)


def _setup_trainer_and_checkpoint(
    args: ArgType,
) -> typing.Tuple[Trainer, ModelCheckpoint]:
    """create a Trainer together with its logger and checkpoint callback

    When a tracking URI is set, the run is additionally tagged with the
    tiramisu-brulee version and commit.
    """
    assert isinstance(args, (argparse.Namespace, jsonargparse.Namespace))
    # NOTE(review): the comparison assumes ``args.gpus`` is None or a number -- confirm
    use_multigpus = not (args.gpus is None or args.gpus <= 1)
    exp_logger = _setup_experiment_logger(args)
    checkpoint_callback = _create_checkpoint_callback(args, exp_logger)
    plugins = args.plugins
    if use_multigpus and args.accelerator == "ddp":
        plugins = DDPPlugin(find_unused_parameters=False)
    trainer = Trainer.from_argparse_args(
        args,
        logger=exp_logger,
        callbacks=[checkpoint_callback],
        plugins=plugins,
    )
    if args.tracking_uri is not None:
        tb_info = tiramisu_brulee_info()
        run_id = trainer.logger.run_id
        trainer.logger.experiment.set_tag(run_id, "Version", tb_info.version)
        trainer.logger.experiment.set_tag(run_id, "Commit", tb_info.commit)
    return trainer, checkpoint_callback