"""Experiment-specific types
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: May 28, 2021
"""
__all__ = [
"ArgParser",
"ArgType",
"Batch",
"BatchElement",
"file_path",
"Indices",
"ModelNum",
"Namespace",
"new_parse_type",
"nonnegative_float",
"nonnegative_int",
"nonnegative_int_or_none_or_all",
"PatchShapeOption",
"PatchShape",
"PathLike",
"positive_float",
"positive_float_or_none",
"positive_int",
"positive_int_or_none",
"positive_odd_int_or_none",
"probability_float",
"probability_float_or_none",
"TiramisuBruleeInfo",
]
import argparse
import builtins
import collections
import os
import pathlib
import typing
import jsonargparse
import torch
BatchElement = typing.Union[
torch.Tensor, typing.Dict[builtins.str, typing.Any], typing.List[typing.Any]
]
Batch = typing.Dict[builtins.str, BatchElement]
Indices = typing.Tuple[
builtins.int, builtins.int, builtins.int, builtins.int, builtins.int, builtins.int
]
ModelNum = collections.namedtuple("ModelNum", ["num", "out_of"])
Namespace = typing.Union[argparse.Namespace, jsonargparse.Namespace]
PatchShape2D = typing.Tuple[builtins.int, builtins.int]
PatchShape3D = typing.Tuple[builtins.int, builtins.int, builtins.int]
PatchShape = typing.Union[PatchShape2D, PatchShape3D]
PatchShape2DOption = typing.Tuple[
typing.Optional[builtins.int], typing.Optional[builtins.int]
]
PatchShape3DOption = typing.Tuple[
typing.Optional[builtins.int],
typing.Optional[builtins.int],
typing.Optional[builtins.int],
]
PatchShapeOption = typing.Union[PatchShape2DOption, PatchShape3DOption]
ArgType = typing.Optional[typing.Union[Namespace, typing.Iterable[builtins.str]]]
ArgParser = typing.Union[argparse.ArgumentParser, jsonargparse.ArgumentParser]
TiramisuBruleeInfo = collections.namedtuple("TiramisuBruleeInfo", ["version", "commit"])
PathLike = typing.Union[builtins.str, os.PathLike]
# flake8: noqa: E501
def return_none(func: typing.Callable) -> typing.Callable:
def new_func(self, string: typing.Any) -> typing.Any: # type: ignore[no-untyped-def]
if string is None:
return None
elif isinstance(string, builtins.str):
if string.lower() in ("none", "null"):
return None
return func(self, string)
return new_func
# flake8: noqa: E501
def return_str(match_string: builtins.str) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
def new_func(self, string: typing.Any) -> typing.Any: # type: ignore[no-untyped-def]
if isinstance(string, builtins.str):
if string.lower() == match_string:
return match_string
return func(self, string)
return new_func
return decorator
class _ParseType:
@property
def __name__(self) -> builtins.str:
name = self.__class__.__name__
assert isinstance(name, builtins.str)
return name
def __str__(self) -> builtins.str:
return self.__name__
[docs]class file_path(_ParseType):
def __call__(self, string: builtins.str) -> builtins.str:
path = pathlib.Path(string)
if not path.is_file():
msg = f"{string} is not a valid path."
raise argparse.ArgumentTypeError(msg)
return str(path.resolve())
[docs]class positive_float(_ParseType):
def __call__(self, string: builtins.str) -> builtins.float:
num = float(string)
if num <= 0.0:
msg = f"{string} needs to be a positive float."
raise argparse.ArgumentTypeError(msg)
return num
[docs]class positive_float_or_none(_ParseType):
@return_none
def __call__(self, string: builtins.str) -> typing.Union[builtins.float, None]:
return positive_float()(string)
[docs]class positive_int(_ParseType):
def __call__(self, string: builtins.str) -> builtins.int:
num = int(string)
if num <= 0:
msg = f"{string} needs to be a positive integer."
raise argparse.ArgumentTypeError(msg)
return num
[docs]class positive_odd_int_or_none(_ParseType):
@return_none
def __call__(self, string: builtins.str) -> typing.Union[builtins.int, None]:
num = int(string)
if num <= 0 or not (num % 2):
msg = f"{string} needs to be a positive odd integer."
raise argparse.ArgumentTypeError(msg)
return num
[docs]class positive_int_or_none(_ParseType):
@return_none
def __call__(self, string: builtins.str) -> typing.Union[builtins.int, None]:
return positive_int()(string)
[docs]class nonnegative_int(_ParseType):
def __call__(self, string: builtins.str) -> builtins.int:
num = int(string)
if num < 0:
msg = f"{string} needs to be a nonnegative integer."
raise argparse.ArgumentTypeError(msg)
return num
[docs]class nonnegative_int_or_none_or_all(_ParseType):
@return_none
@return_str("all")
def __call__(
self, string: builtins.str
) -> typing.Union[builtins.int, None, builtins.str]:
return nonnegative_int()(string)
[docs]class nonnegative_float(_ParseType):
def __call__(self, string: builtins.str) -> builtins.float:
num = float(string)
if num < 0.0:
msg = f"{string} needs to be a nonnegative float."
raise argparse.ArgumentTypeError(msg)
return num
[docs]class probability_float(_ParseType):
def __call__(self, string: builtins.str) -> builtins.float:
num = float(string)
if num < 0.0 or num > 1.0:
msg = f"{string} needs to be between 0 and 1."
raise argparse.ArgumentTypeError(msg)
return num
[docs]class probability_float_or_none(_ParseType):
@return_none
def __call__(self, string: builtins.str) -> typing.Union[builtins.float, None]:
return probability_float()(string)
class NewParseType:
def __init__(self, func: typing.Callable, name: builtins.str):
self.name = name
self.func = func
def __str__(self) -> builtins.str:
return self.name
def __call__(self, val: typing.Any) -> typing.Any:
return self.func(val)
[docs]def new_parse_type(func: typing.Callable, name: builtins.str) -> NewParseType:
return NewParseType(func, name)