Source code for tiramisu_brulee.experiment.util

"""Miscellaneous tools for experiments
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: May 16, 2021
"""

# public names of this module (what ``from ... import *`` exposes)
__all__ = [
    "append_num_to_filename",
    "BoundingBox3D",
    "image_one_hot",
    "minmax_scale_batch",
    "reshape_for_broadcasting",
    "to_np",
    "setup_log",
    "split_filename",
]

import builtins
import logging
import pathlib
import typing

import numpy as np
import torch
import torch.nn.functional as F

from tiramisu_brulee.experiment.type import Indices

# type variable bound to BoundingBox3D; it annotates the ``cls`` argument and the
# return value of the alternate constructors ``from_image`` and ``from_batch``
T = typing.TypeVar("T", bound="BoundingBox3D")


def minmax_scale_batch(x: torch.Tensor) -> torch.Tensor:
    """rescale a batch of image PyTorch tensors to be between 0 and 1

    Every element along the first (batch) dimension is rescaled independently
    with its own minimum and maximum, taken over all remaining dimensions.

    Args:
        x: batch of images; the first dimension indexes the batch

    Returns:
        tensor with the same shape as ``x`` in which each batch element spans
        [0, 1]. A batch element that is constant (maximum equals minimum) is
        mapped to all zeros instead of NaN.
    """
    dims = list(range(1, x.dim()))
    xmin = x.amin(dim=dims, keepdim=True)
    xmax = x.amax(dim=dims, keepdim=True)
    value_range = xmax - xmin
    # a constant image has a zero range; divide by one there so the result is
    # 0 (the numerator is already 0) rather than 0 / 0 = NaN. Non-constant
    # images are divided by exactly the same range as before.
    safe_range = torch.where(
        value_range > 0, value_range, torch.ones_like(value_range)
    )
    return (x - xmin) / safe_range
def to_np(x: torch.Tensor) -> np.ndarray:
    """return the values of a PyTorch tensor as a numpy array

    The tensor is detached from the autograd graph and moved to the CPU
    before the conversion, so tensors that require grad or live on another
    device are accepted.
    """
    host_tensor = x.detach().cpu()
    array = host_tensor.numpy()
    # narrows the type for static checkers
    assert isinstance(array, np.ndarray)
    return array
def image_one_hot(image: torch.Tensor, *, num_classes: builtins.int) -> torch.Tensor:
    """one-hot encode a single-channel label image along dimension 1

    Args:
        image: label image whose dimension 1 is the channel dimension and has
            (at most) one entry; values are cast to integers before encoding
        num_classes: number of classes, forwarded to ``F.one_hot``

    Returns:
        tensor of the same type as ``image`` in which dimension 1 indexes the
        class instead of the channel

    Raises:
        RuntimeError: if ``image`` has more than one channel
    """
    num_channels = image.shape[1]
    if num_channels > 1:
        raise RuntimeError(
            f"Image must only have one channel. Got {num_channels} channels."
        )
    labels = image.long()
    # one_hot appends the class dimension at the end
    class_last = F.one_hot(labels, num_classes)
    # swap the channel dimension with the class dimension ...
    class_first = torch.transpose(class_last, 1, -1)
    # ... and drop the channel dimension, which is now last
    encoded: torch.Tensor = class_first[..., 0]
    return encoded.type(image.type())
class BoundingBox3D:
    """bounding box indices and crop/uncrop func for 3d vols

    The box is stored as three slices (``i``, ``j``, ``k``) that are applied
    to the last three dimensions of a tensor. As with any Python slice, the
    lower index is inclusive and the upper index is exclusive.

    Args:
        i_low: first index kept along the first spatial dimension
        i_high: one past the last index kept along the first spatial dimension
        j_low: first index kept along the second spatial dimension
        j_high: one past the last index kept along the second spatial dimension
        k_low: first index kept along the third spatial dimension
        k_high: one past the last index kept along the third spatial dimension
        original_shape: shape of the uncropped 3D volume; required by
            ``uncrop`` and ``uncrop_batch``
    """

    def __init__(
        self,
        i_low: builtins.int,
        i_high: builtins.int,
        j_low: builtins.int,
        j_high: builtins.int,
        k_low: builtins.int,
        k_high: builtins.int,
        *,
        original_shape: typing.Optional[
            typing.Tuple[builtins.int, builtins.int, builtins.int]
        ] = None,
    ):
        self.i = slice(i_low, i_high)
        self.j = slice(j_low, j_high)
        self.k = slice(k_low, k_high)
        self.original_shape = original_shape

    def crop_to_bbox(self, tensor: torch.Tensor) -> torch.Tensor:
        """returns the tensor cropped around the saved bbox

        The crop is applied to the last three dimensions, so any number of
        leading (e.g., batch and channel) dimensions is accepted.
        """
        return tensor[..., self.i, self.j, self.k]

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """shorthand for ``crop_to_bbox``"""
        return self.crop_to_bbox(tensor)

    def uncrop(self, tensor: torch.Tensor) -> torch.Tensor:
        """places a tensor back into the saved original shape

        Args:
            tensor: cropped 3D tensor (HxWxD)

        Returns:
            zero tensor of shape ``original_shape`` (same dtype and device as
            ``tensor``) with ``tensor`` written into the bounding box
        """
        assert tensor.ndim == 3, "expects tensors with shape HxWxD"
        assert self.original_shape is not None
        out = torch.zeros(self.original_shape, dtype=tensor.dtype, device=tensor.device)
        out[self.i, self.j, self.k] = tensor
        return out

    def uncrop_batch(self, batch: torch.Tensor) -> torch.Tensor:
        """places a batch back into the saved original shape

        Args:
            batch: cropped 5D tensor (NxCxHxWxD)

        Returns:
            zero tensor of shape ``(N, C) + original_shape`` (same dtype and
            device as ``batch``) with ``batch`` written into the bounding box
        """
        assert batch.ndim == 5, "expects tensors with shape NxCxHxWxD"
        assert self.original_shape is not None
        batch_size, channel_size = batch.shape[:2]
        out_shape = (batch_size, channel_size) + tuple(self.original_shape)
        out = torch.zeros(out_shape, dtype=batch.dtype, device=batch.device)
        out[..., self.i, self.j, self.k] = batch
        return out

    @staticmethod
    def find_bbox(mask: torch.Tensor, *, pad: builtins.int = 0) -> Indices:
        """find the (padded) bounding box of the True region of a 3D mask

        Args:
            mask: 3D boolean tensor marking the foreground
            pad: number of voxels added on every side of the tight box; the
                result is clamped to the extent of ``mask``

        Returns:
            ``(i_low, i_high, j_low, j_high, k_low, k_high)`` where each low
            index is inclusive and each high index is exclusive, i.e., the
            values can be used directly as slice bounds

        Raises:
            IndexError: if ``mask`` contains no True element
        """
        h = torch.where(torch.any(torch.any(mask, dim=1), dim=1))[0]
        w = torch.where(torch.any(torch.any(mask, dim=0), dim=1))[0]
        d = torch.where(torch.any(torch.any(mask, dim=0), dim=0))[0]
        # torch.where yields the (inclusive) index of the last foreground
        # voxel; the result is consumed as an exclusive slice bound, so add
        # one to keep the last foreground plane inside the box
        h_low, h_high = h[0].item(), h[-1].item() + 1
        w_low, w_high = w[0].item(), w[-1].item() + 1
        d_low, d_high = d[0].item(), d[-1].item() + 1
        i, j, k = mask.shape
        return (
            builtins.int(max(h_low - pad, 0)),
            builtins.int(min(h_high + pad, i)),
            builtins.int(max(w_low - pad, 0)),
            builtins.int(min(w_high + pad, j)),
            builtins.int(max(d_low - pad, 0)),
            builtins.int(min(d_high + pad, k)),
        )

    @classmethod
    def from_image(
        cls: typing.Type[T],
        image: torch.Tensor,
        *,
        pad: builtins.int = 0,
        foreground_min: builtins.float = 1e-4,
    ) -> T:
        """find a bounding box for a 3D tensor (with optional padding)

        Args:
            image: 3D tensor
            pad: padding forwarded to ``find_bbox``
            foreground_min: voxels strictly greater than this value are
                treated as foreground

        Returns:
            bounding box around the foreground, remembering the shape of
            ``image`` as ``original_shape``
        """
        foreground_mask = image > foreground_min
        assert isinstance(foreground_mask, torch.Tensor)
        bbox_idxs = cls.find_bbox(foreground_mask, pad=pad)
        original_shape = cls.get_shape(image)
        return cls(*bbox_idxs, original_shape=original_shape)

    @classmethod
    def from_batch(
        cls: typing.Type[T],
        batch: torch.Tensor,
        *,
        pad: builtins.int = 0,
        channel: builtins.int = 0,
        foreground_min: builtins.float = 1e-4,
    ) -> T:
        """create bbox that works for a batch of 3d vols

        The result is the union of the per-image bounding boxes, so it
        contains the foreground of every image in the batch.

        Args:
            batch: 5D tensor (NxCxHxWxD) with at least one image
            pad: padding forwarded to ``find_bbox`` for every image
            channel: channel used to determine the foreground
            foreground_min: voxels strictly greater than this value are
                treated as foreground

        Returns:
            bounding box covering the foreground of all images, remembering
            the spatial shape of the batch as ``original_shape``
        """
        assert batch.ndim == 5, "expects tensors with shape NxCxHxWxD"
        batch_size = batch.shape[0]
        assert batch_size > 0
        image_shape = batch.shape[2:]
        # start from an "empty" box (low at the far end, high below zero) so
        # the first image always replaces the initial values
        h_low, h_high = image_shape[0], -1
        w_low, w_high = image_shape[1], -1
        d_low, d_high = image_shape[2], -1
        for image in batch[:, channel, ...]:
            hl, hh, wl, wh, dl, dh = cls.find_bbox(image > foreground_min, pad=pad)
            h_low, h_high = min(hl, h_low), max(hh, h_high)
            w_low, w_high = min(wl, w_low), max(wh, w_high)
            d_low, d_high = min(dl, d_low), max(dh, d_high)
        # all images in the batch share the same spatial shape
        original_shape = cls.get_shape(batch[0, channel, ...])
        return cls(
            h_low,
            h_high,
            w_low,
            w_high,
            d_low,
            d_high,
            original_shape=original_shape,
        )

    @staticmethod
    def get_shape(
        image: torch.Tensor,
    ) -> typing.Tuple[builtins.int, builtins.int, builtins.int]:
        """return the shape of a 3D tensor as a 3-tuple"""
        assert image.ndim == 3
        orig_x, orig_y, orig_z = tuple(image.shape)
        return (orig_x, orig_y, orig_z)
def reshape_for_broadcasting(
    tensor: torch.Tensor, *, ndim: builtins.int
) -> torch.Tensor:
    """view a 0- or 1-dimensional tensor with ``ndim`` dimensions for broadcast ops

    The values stay in the first dimension; every following dimension has
    size one.
    """
    assert tensor.ndim <= 1
    trailing_ones = (1,) * (ndim - 1)
    return tensor.view(-1, *trailing_ones)
def split_filename(
    filepath: typing.Union[builtins.str, pathlib.Path]
) -> typing.Tuple[pathlib.Path, builtins.str, builtins.str]:
    """split a filepath into the directory, base, and extension

    The path is resolved first, so the returned directory is absolute. A
    trailing ``.gz`` is merged with the suffix in front of it, e.g.,
    ``img.nii.gz`` yields the base ``img`` and the extension ``.nii.gz``.
    """
    resolved = pathlib.Path(filepath).resolve()
    directory = pathlib.Path(resolved.parent)
    name_without_ext = pathlib.Path(resolved.stem)
    last_ext = resolved.suffix
    if last_ext != ".gz":
        return directory, str(name_without_ext), last_ext
    # compressed file: also peel off the suffix in front of ".gz"
    inner_ext = name_without_ext.suffix
    return directory, str(name_without_ext.stem), inner_ext + last_ext
def append_num_to_filename(
    filepath: typing.Union[builtins.str, pathlib.Path], *, num: builtins.int
) -> pathlib.Path:
    """return filepath with ``_<num>`` inserted between its base name and extension"""
    directory, stem, extension = split_filename(filepath)
    return directory / f"{stem}_{num}{extension}"
def setup_log(verbosity: builtins.int) -> None:
    """configure logging with a level chosen by verbosity

    A verbosity of 1 selects INFO, 2 or more selects DEBUG, and anything else
    selects WARNING. Messages issued through ``warnings`` are redirected to
    the logging system as well.
    """
    if verbosity == 1:
        level_name = "INFO"
    elif verbosity >= 2:
        level_name = "DEBUG"
    else:
        level_name = "WARNING"
    # given a level name, getLevelName returns the numeric level
    level = logging.getLevelName(level_name)
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=level,
    )
    logging.captureWarnings(True)