Source code for tiramisu_brulee.util
"""Miscellaneous functions
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: Jul 01, 2020
"""
__all__ = ["InitType", "init_weights"]
import builtins
import enum
import torch
import torch.nn as nn
def is_conv(layer: nn.Module) -> builtins.bool:
classname = layer.__class__.__name__
return hasattr(layer, "weight") and "Conv" in classname
def is_norm(layer: nn.Module) -> builtins.bool:
classname = layer.__class__.__name__
return hasattr(layer, "weight") and "Norm" in classname
[docs]@enum.unique
class InitType(enum.Enum):
NORMAL = "normal"
XAVIER_NORMAL = "xavier_normal"
HE_NORMAL = "he_normal"
HE_UNIFORM = "he_uniform"
ORTHOGONAL = "orthogonal"
[docs] @classmethod
def from_string(cls, string: builtins.str) -> "InitType":
if string.lower() == "normal":
return cls.NORMAL
elif string.lower() == "xavier_normal":
return cls.XAVIER_NORMAL
elif string.lower() == "he_normal":
return cls.HE_NORMAL
elif string.lower() == "he_uniform":
return cls.HE_UNIFORM
elif string.lower() == "orthogonal":
return cls.ORTHOGONAL
else:
raise ValueError("Invalid init type.")
[docs]def init_weights(
net: nn.Module,
*,
init_type: InitType = InitType.NORMAL,
gain: builtins.float = 0.02,
) -> None:
def init_func(layer: nn.Module) -> None:
_is_conv = is_conv(layer)
_is_norm = is_norm(layer)
if not _is_conv and not _is_norm:
return
assert isinstance(layer.weight, torch.Tensor)
weight = layer.weight
assert weight is layer.weight
has_bias = hasattr(layer, "bias") and layer.bias is not None
if has_bias:
assert isinstance(layer.bias, torch.Tensor)
bias = layer.bias
assert bias is layer.bias
if _is_conv:
if init_type == InitType.NORMAL:
nn.init.normal_(weight, 0.0, gain)
elif init_type == InitType.XAVIER_NORMAL:
nn.init.xavier_normal_(weight, gain=gain)
elif init_type == InitType.HE_NORMAL:
nn.init.kaiming_normal_(weight, a=0.0, mode="fan_in")
elif init_type == InitType.HE_UNIFORM:
nn.init.kaiming_uniform_(weight, a=0.0, mode="fan_in")
elif init_type == InitType.ORTHOGONAL:
nn.init.orthogonal_(weight, gain=gain)
else:
err_msg = f"initialization type [{init_type}] not implemented"
raise NotImplementedError(err_msg)
if has_bias:
# noinspection PyUnboundLocalVariable
nn.init.constant_(bias, 0.0)
elif _is_norm:
nn.init.normal_(weight, 1.0, gain)
if has_bias:
# noinspection PyUnboundLocalVariable
nn.init.constant_(bias, 0.0)
net.apply(init_func)