FrEIA.modules package#

Subclasses of torch.nn.Module that are reversible and can be used in the nodes of the GraphINN class. The only additions needed compared to the base class are an output_dims method and the rev argument of the forward method.

Abstract template:

  • InvertibleModule

Coupling blocks:

  • AllInOneBlock

  • NICECouplingBlock

  • RNVPCouplingBlock

  • GLOWCouplingBlock

  • GINCouplingBlock

  • AffineCouplingOneSided

  • ConditionalAffineTransform

  • RationalQuadraticSpline

Reshaping:

  • IRevNetDownsampling

  • IRevNetUpsampling

  • HaarDownsampling

  • HaarUpsampling

  • Flatten

  • Reshape

Graph topology:

  • Split

  • Concat

Other learned transforms:

  • ActNorm

  • IResNetLayer

  • InvAutoAct

  • InvAutoActFixed

  • InvAutoActTwoSided

  • InvAutoConv2D

  • InvAutoFC

  • LearnedElementwiseScaling

  • OrthogonalTransform

  • HouseholderPerm

  • ElementwiseRationalQuadraticSpline

Fixed (non-learned) transforms:

  • PermuteRandom

  • FixedLinearTransform

  • Fixed1x1Conv

  • InvertibleSigmoid

Abstract template#

class FrEIA.modules.InvertibleModule(dims_in: List[Tuple[int]], dims_c: List[Tuple[int]] | None = None)[source]#

Base class for all invertible modules in FrEIA.

Given module, an instance of some InvertibleModule, this module shall be invertible in its input dimensions, so that the input can be recovered by applying the module in backward mode (rev=True). This is not to be confused with torch.Tensor.backward(), which computes the gradient of an operation:

x = torch.randn(BATCH_SIZE, DIM_COUNT)
c = torch.randn(BATCH_SIZE, CONDITION_DIM)

# Forward mode
z, jac = module([x], [c], jac=True)

# Backward mode
x_rev, jac_rev = module(z, [c], rev=True)

The module returns \(\log \det J = \log \left| \det \frac{\partial f}{\partial x} \right|\) of the operation in forward mode, and \(-\log | \det J | = \log \left| \det \frac{\partial f^{-1}}{\partial z} \right| = -\log \left| \det \frac{\partial f}{\partial x} \right|\) in backward mode (rev=True).

Then, torch.allclose(x, x_rev) == True and torch.allclose(jac, -jac_rev) == True.

__init__(dims_in: List[Tuple[int]], dims_c: List[Tuple[int]] | None = None)[source]#
Parameters:
  • dims_in – list of tuples specifying the shape of the inputs to this operator: dims_in = [shape_x_0, shape_x_1, ...]

  • dims_c – list of tuples specifying the shape of the conditions to this operator.

forward(x_or_z: Iterable[Tensor], c: Iterable[Tensor] | None = None, rev: bool = False, jac: bool = True) Tuple[Tuple[Tensor], Tensor][source]#

Perform a forward (default, rev=False) or backward pass (rev=True) through this module/operator.

Note to implementers:

  • Subclasses MUST return a Jacobian when jac=True, and MAY also return a valid Jacobian when jac=False (this is not penalized). The latter is only recommended if computing the Jacobian is trivial.

  • Subclasses MUST follow the convention that the returned Jacobian be consistent with the evaluation direction. Let’s make this more precise: Let \(f\) be the function that the subclass represents. Then:

    \[\begin{split}J &= \log \det \frac{\partial f}{\partial x} \\ -J &= \log \det \frac{\partial f^{-1}}{\partial z}.\end{split}\]

    Any subclass MUST return \(J\) for forward evaluation (rev=False), and \(-J\) for backward evaluation (rev=True).

Parameters:
  • x_or_z – input data (array-like of one or more tensors)

  • c – conditioning data (array-like of none or more tensors)

  • rev – perform backward pass

  • jac – return Jacobian associated to the direction

log_jacobian(*args, **kwargs)[source]#

This method is deprecated, and does nothing except raise a warning.

output_dims(input_dims: List[Tuple[int]]) List[Tuple[int]][source]#

Used for shape inference during construction of the graph. MUST be implemented for each subclass of InvertibleModule.

Parameters:

input_dims – A list with one entry for each input to the module. Even if the module only has one input, this must be a list with one entry. Each entry is a tuple giving the shape of that input, excluding the batch dimension. For example, for a module with one input that receives a 32x32 pixel RGB image, input_dims would be [(3, 32, 32)].

Returns:

A list structured in the same way as input_dims. Each entry represents one output of the module, and the entry is a tuple giving the shape of that output. For example if the module splits the image into a right and a left half, the return value should be [(3, 16, 32), (3, 16, 32)]. It is up to the implementor of the subclass to ensure that the total number of elements in all inputs and all outputs is consistent.
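
To make this contract concrete, a minimal, hypothetical subclass (a fixed element-wise scaling; the class name and factor argument are illustrative, not part of FrEIA) could look as follows:

import math
import torch
from FrEIA.modules import InvertibleModule

class FixedScaling(InvertibleModule):
    """Illustrative module: multiply every element by a fixed factor."""

    def __init__(self, dims_in, dims_c=None, factor: float = 2.0):
        super().__init__(dims_in, dims_c)
        self.factor = factor
        # number of elements per sample, needed for the log-det below
        self.n_elements = int(torch.prod(torch.tensor(dims_in[0])))

    def forward(self, x_or_z, c=None, rev=False, jac=True):
        x, = x_or_z
        # log |det J| of an element-wise scaling: n_elements * log(factor)
        log_jac = x.new_full((x.shape[0],), self.n_elements * math.log(self.factor))
        if not rev:
            return (x * self.factor,), log_jac    # J in forward mode
        return (x / self.factor,), -log_jac       # -J in backward mode

    def output_dims(self, input_dims):
        # element-wise operation: output shapes equal input shapes
        return input_dims

Calling the module on [x] and feeding the result back with rev=True recovers x, and the two returned Jacobians sum to zero, matching the convention above.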

Coupling blocks#

class FrEIA.modules.AllInOneBlock(dims_in, dims_c=[], subnet_constructor: Callable | None = None, affine_clamping: float = 2.0, gin_block: bool = False, global_affine_init: float = 1.0, global_affine_type: str = 'SOFTPLUS', permute_soft: bool = False, learned_householder_permutation: int = 0, reverse_permutation: bool = False)[source]#

Module combining the most common operations in a normalizing flow or similar model.

It combines affine coupling, permutation, and a global affine transformation (‘ActNorm’). It can also be used as a GIN coupling block, perform learned householder permutations, and use an inverted pre-permutation. The affine transformation includes a soft clamping mechanism, first used in Real-NVP. The block as a whole performs the following computation:

\[y = V\,R \; \Psi(s_\mathrm{global}) \odot \mathrm{Coupling}\Big(R^{-1} V^{-1} x\Big)+ t_\mathrm{global}\]
  • The inverse pre-permutation of x (i.e. \(R^{-1} V^{-1}\)) is optional (see reverse_permutation below).

  • The learned householder reflection matrix \(V\) is also optional altogether (see learned_householder_permutation below).

  • For the coupling, the input is split into \(x_1, x_2\) along the channel dimension. The output of the coupling operation is then the concatenation of the two halves, \(u = \mathrm{concat}(u_1, u_2)\):

    \[\begin{split}u_1 &= x_1 \odot \exp \Big( \alpha \; \mathrm{tanh}\big( s(x_2) \big)\Big) + t(x_2) \\ u_2 &= x_2\end{split}\]

    Because \(\mathrm{tanh}(s) \in [-1, 1]\), this clamping mechanism prevents exploding values in the exponential. The hyperparameter \(\alpha\) can be adjusted.

__init__(dims_in, dims_c=[], subnet_constructor: Callable | None = None, affine_clamping: float = 2.0, gin_block: bool = False, global_affine_init: float = 1.0, global_affine_type: str = 'SOFTPLUS', permute_soft: bool = False, learned_householder_permutation: int = 0, reverse_permutation: bool = False)[source]#
Parameters:
  • subnet_constructor – class or callable f, called as f(channels_in, channels_out), which should return a torch.nn.Module that predicts the coupling coefficients \(s, t\).

  • affine_clamping – clamp the output of the multiplicative coefficients before exponentiation to +/- affine_clamping (see \(\alpha\) above).

  • gin_block – Turn the block into a GIN block from Sorrenson et al, 2019. This makes the coupling operation as a whole volume preserving.

  • global_affine_init – Initial value for the global affine scaling \(s_\mathrm{global}\).

  • global_affine_type – 'SIGMOID', 'SOFTPLUS', or 'EXP'. Defines the activation applied to the underlying learned parameter for the global affine scaling (\(\Psi\) above).

  • permute_soft – bool, whether to sample the permutation matrix \(R\) from \(SO(N)\), or to use hard permutations instead. Note, permute_soft=True is very slow when working with >512 dimensions.

  • learned_householder_permutation – Int; if >0, turns on the matrix \(V\) above, representing that many learned householder reflections. Slow for a large number of reflections. It is dubious whether this actually helps network performance.

  • reverse_permutation – Reverse the permutation before the block, as introduced by Putzky et al, 2019. Turns on the \(R^{-1} V^{-1}\) pre-multiplication above.
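
As a usage sketch (assuming the FrEIA.framework.SequenceINN container from the same library; data dimension, subnetwork width, and block count are illustrative):

import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

def subnet_fc(dims_in, dims_out):
    # predicts the coupling coefficients s, t from one half of the input
    return nn.Sequential(nn.Linear(dims_in, 128), nn.ReLU(),
                         nn.Linear(128, dims_out))

inn = Ff.SequenceINN(8)                  # 8-dimensional data
for _ in range(4):
    inn.append(Fm.AllInOneBlock,
               subnet_constructor=subnet_fc,
               affine_clamping=2.0,
               permute_soft=True)

x = torch.randn(16, 8)
z, log_jac_det = inn(x)                  # forward pass
x_rev, _ = inn(z, rev=True)              # inverse pass

Here z has the same shape as x, and log_jac_det holds one log-Jacobian-determinant entry per batch element.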

class FrEIA.modules.NICECouplingBlock(dims_in, dims_c=[], subnet_constructor: callable | None = None, split_len: float | int = 0.5)[source]#

Coupling Block following the NICE (Dinh et al, 2015) design. The inputs are split in two halves. For 2D, 3D, 4D inputs, the split is performed along the channel dimension. Residual coefficients, predicted by two subnetworks, are then added to each half in turn.

__init__(dims_in, dims_c=[], subnet_constructor: callable | None = None, split_len: float | int = 0.5)[source]#

Additional args in docstring of base class.

Parameters:

subnet_constructor – Callable function, class, or factory object, with signature constructor(dims_in, dims_out). The result should be a torch.nn.Module that takes dims_in input channels and dims_out output channels. See the tutorial for examples. Two of these subnetworks will be initialized inside the block.
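
For image-shaped inputs, a convolutional constructor along the following lines could be passed instead (hidden width and kernel size are illustrative):

import torch.nn as nn

def subnet_conv(channels_in, channels_out):
    # small convolutional network predicting the residual for one half of the channels
    return nn.Sequential(nn.Conv2d(channels_in, 64, 3, padding=1), nn.ReLU(),
                         nn.Conv2d(64, channels_out, 3, padding=1))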

class FrEIA.modules.RNVPCouplingBlock(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Coupling Block following the RealNVP design (Dinh et al, 2017) with some minor differences. The inputs are split in two halves. For 2D, 3D, 4D inputs, the split is performed along the channel dimension. For checkerboard splitting, prepend an IRevNetDownsampling module. Two affine coupling operations are performed in turn on both halves of the input.

__init__(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Additional args in docstring of base class.

Parameters:
  • subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch.nn.Module that takes dims_in input channels and dims_out output channels. See the tutorial for examples. Four of these subnetworks will be initialized in the block.

  • clamp – Soft clamping for the multiplicative component. The amplification or attenuation of each input dimension can be at most exp(±clamp).

  • clamp_activation – Function to perform the clamping. String values “ATAN”, “TANH”, and “SIGMOID” are recognized, or a function or object can be passed. TANH behaves like the original RealNVP paper. A custom function should take tensors and map -inf to -1 and +inf to +1.
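
As an illustrative custom clamp_activation (not part of FrEIA), any smooth monotone map of the reals onto (-1, 1) satisfies the requirement above; a soft-sign function is one such choice:

import torch

def softsign_clamp(s):
    # maps -inf -> -1 and +inf -> +1, as required for clamp_activation
    return s / (1 + s.abs())

The block then scales the clamped output by clamp, so the amplification of each dimension stays within exp(±clamp) as described above.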

class FrEIA.modules.GLOWCouplingBlock(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Coupling Block following the GLOW design. Note that this is only the coupling part itself, and does not include ActNorm, invertible 1x1 convolutions, etc. See AllInOneBlock for a block combining these functions at once. The only difference from the RNVPCouplingBlock is that it uses a single subnetwork to jointly predict [s_i, t_i], instead of two separate subnetworks. This reduces computational cost and speeds up learning.

__init__(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Additional args in docstring of base class.

Parameters:
  • subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch.nn.Module that takes dims_in input channels and dims_out output channels. See the tutorial for examples. Two of these subnetworks will be initialized in the block.

  • clamp – Soft clamping for the multiplicative component. The amplification or attenuation of each input dimension can be at most exp(±clamp).

  • clamp_activation – Function to perform the clamping. String values “ATAN”, “TANH”, and “SIGMOID” are recognized, or a function or object can be passed. TANH behaves like the original RealNVP paper. A custom function should take tensors and map -inf to -1 and +inf to +1.

class FrEIA.modules.GINCouplingBlock(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Coupling Block following the GIN design. The difference from GLOWCouplingBlock (and other affine coupling blocks) is that the Jacobian determinant is constrained to be 1. This constrains the block to be volume-preserving. Volume preservation is achieved by subtracting the mean of the output of the s subnetwork from itself. While volume preserving, GIN is still more powerful than NICE, as GIN is not volume preserving within each dimension. Note: this implementation differs slightly from the originally published implementation, which scales the final component of the s subnetwork so the sum of the outputs of s is zero. There was no difference found between the implementations in practice, but subtracting the mean guarantees that all outputs of s are at most ±exp(clamp), which might be more stable in certain cases.

__init__(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Additional args in docstring of base class.

Parameters:
  • subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch.nn.Module that takes dims_in input channels and dims_out output channels. See the tutorial for examples. Two of these subnetworks will be initialized in the block.

  • clamp – Soft clamping for the multiplicative component. The amplification or attenuation of each input dimension can be at most exp(±clamp).

  • clamp_activation – Function to perform the clamping. String values “ATAN”, “TANH”, and “SIGMOID” are recognized, or a function or object can be passed. TANH behaves like the original RealNVP paper. A custom function should take tensors and map -inf to -1 and +inf to +1.

class FrEIA.modules.AffineCouplingOneSided(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Half of a coupling block following the GLOWCouplingBlock design. This means only one affine transformation on half the inputs. In the case where random permutations or orthogonal transforms are used after every block, this is not a restriction and simplifies the design.

__init__(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Additional args in docstring of base class.

Parameters:
  • subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch.nn.Module that takes dims_in input channels and dims_out output channels. See the tutorial for examples. One subnetwork will be initialized in the block.

  • clamp – Soft clamping for the multiplicative component. The amplification or attenuation of each input dimension can be at most exp(±clamp).

  • clamp_activation – Function to perform the clamping. String values “ATAN”, “TANH”, and “SIGMOID” are recognized, or a function or object can be passed. TANH behaves like the original RealNVP paper. A custom function should take tensors and map -inf to -1 and +inf to +1.

class FrEIA.modules.ConditionalAffineTransform(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Similar to the conditioning layers from SPADE (Park et al, 2019): Perform an affine transformation on the whole input, where the affine coefficients are predicted from only the condition.

__init__(dims_in, dims_c=[], subnet_constructor: Callable | None = None, clamp: float = 2.0, clamp_activation: str | Callable = 'ATAN', split_len: float | int = 0.5)[source]#

Additional args in docstring of base class.

Parameters:
  • subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch.nn.Module that takes dims_in input channels and dims_out output channels. See the tutorial for examples. One subnetwork will be initialized in the block.

  • clamp – Soft clamping for the multiplicative component. The amplification or attenuation of each input dimension can be at most exp(±clamp).

  • clamp_activation – Function to perform the clamping. String values “ATAN”, “TANH”, and “SIGMOID” are recognized, or a function or object can be passed. TANH behaves like the original RealNVP paper. A custom function should take tensors and map -inf to -1 and +inf to +1.
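
A direct-call sketch of this module (the dimensions and the fully connected subnetwork are illustrative); the whole input is transformed, with the affine coefficients predicted from the condition alone:

import torch
import torch.nn as nn
import FrEIA.modules as Fm

def subnet_fc(dims_in, dims_out):
    return nn.Sequential(nn.Linear(dims_in, 128), nn.ReLU(),
                         nn.Linear(128, dims_out))

cat = Fm.ConditionalAffineTransform(dims_in=[(16,)], dims_c=[(8,)],
                                    subnet_constructor=subnet_fc)

x = torch.randn(32, 16)                        # data
cond = torch.randn(32, 8)                      # condition, e.g. from another network
(z,), log_jac_det = cat([x], c=[cond])         # forward
(x_rev,), _ = cat([z], c=[cond], rev=True)     # inverse, given the same condition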

Reshaping#

class FrEIA.modules.IRevNetDownsampling(dims_in, dims_c=None, legacy_backend: bool = False)[source]#

The invertible spatial downsampling used in i-RevNet. Each group of four neighboring pixels is reordered into one pixel with four times the channels, in a checkerboard-like pattern. See i-RevNet, Jacobsen et al, 2018.

__init__(dims_in, dims_c=None, legacy_backend: bool = False)[source]#

See docstring of base class (FrEIA.modules.InvertibleModule) for more.

Parameters:

legacy_backend

If True, uses the splitting and concatenating method, adapted from github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py for use in FrEIA. This is usually slower on GPU. If False, uses a 2d strided convolution with a kernel representing the downsampling. Note that the ordering of the output channels will be different. If the pixels in each patch in channel 1 are a1, b1, ..., and in channel 2 are a2, b2, ..., then the output channels will be the following:

legacy_backend=True: a1, a2, ..., b1, b2, ..., c1, c2, ...

legacy_backend=False: a1, b1, ..., a2, b2, ..., a3, b3, ...

(see also order_by_wavelet in module HaarDownsampling) Generally this difference is completely irrelevant, unless a certain subset of pixels or channels is supposed to be split off or extracted.

class FrEIA.modules.IRevNetUpsampling(dims_in, dims_c=None, legacy_backend: bool = False)[source]#

The inverted operation of IRevNetDownsampling (see that docstring for details).

__init__(dims_in, dims_c=None, legacy_backend: bool = False)[source]#

See docstring of base class (FrEIA.modules.InvertibleModule) for more.

Parameters:

legacy_backend

If True, uses the splitting and concatenating method, adapted from github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py for use in FrEIA. This is usually slower on GPU. If False, uses a 2d strided transposed convolution with a kernel representing the upsampling. Note that the expected ordering of the input channels will be different. If the pixels in each output patch in channel 1 are a1, b1, ..., and in channel 2 are a2, b2, ..., then the expected input channels are the following:

legacy_backend=True: a1, a2, ..., b1, b2, ..., c1, c2, ...

legacy_backend=False: a1, b1, ..., a2, b2, ..., a3, b3, ...

(see also order_by_wavelet in module HaarDownsampling) Generally this difference is completely irrelevant, unless a certain subset of pixels or channels is supposed to be split off or extracted.

class FrEIA.modules.HaarDownsampling(dims_in, dims_c=None, order_by_wavelet: bool = False, rebalance: float = 1.0)[source]#

Uses Haar wavelets to split each channel into 4 channels, with half the width and height dimensions.

__init__(dims_in, dims_c=None, order_by_wavelet: bool = False, rebalance: float = 1.0)[source]#

See docstring of base class (FrEIA.modules.InvertibleModule) for more.

Parameters:
  • order_by_wavelet

    Whether to group the output by original channels or by wavelet. I.e. if the average, vertical, horizontal and diagonal wavelets for channel 1 are a1, v1, h1, d1, those for channel 2 are a2, v2, h2, d2, etc, then the output channels will be structured as follows:

    set to True: a1, a2, ..., v1, v2, ..., h1, h2, ..., d1, d2, ...

    set to False: a1, v1, h1, d1, a2, v2, h2, d2, ...

    The True option is slightly slower to compute than the False option. The option is useful if e.g. the average channels should be split off by a FrEIA.modules.Split. Then, setting order_by_wavelet=True allows splitting off the first quarter of channels to isolate the average wavelets only (see the sketch after this parameter list).

  • rebalance – Must be != 0. There exist different conventions for how to define the Haar wavelets. The wavelet components in the forward direction are multiplied by this factor, and those in the inverse direction are adjusted accordingly, so that the module as a whole is invertible. Stability of the network may be increased for rebalance < 1 (e.g. 0.5).
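
A sketch of the pattern described for order_by_wavelet (shapes are illustrative), calling the modules directly rather than through a graph:

import torch
import FrEIA.modules as Fm

x = torch.randn(4, 3, 32, 32)                               # NCHW input
haar = Fm.HaarDownsampling(dims_in=[(3, 32, 32)], order_by_wavelet=True)
(y,), _ = haar([x])                                         # y: (4, 12, 16, 16)

# with order_by_wavelet=True, the first quarter of channels are the average wavelets
split = Fm.Split(dims_in=haar.output_dims([(3, 32, 32)]), section_sizes=[3, 9])
(avg, detail), _ = split([y])                               # avg: (4, 3, 16, 16)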

class FrEIA.modules.HaarUpsampling(dims_in, dims_c=None, order_by_wavelet: bool = False, rebalance: float = 1.0)[source]#

The inverted operation of HaarDownsampling (see that docstring for details).

__init__(dims_in, dims_c=None, order_by_wavelet: bool = False, rebalance: float = 1.0)[source]#

See docstring of base class (FrEIA.modules.InvertibleModule) for more.

Parameters:
  • order_by_wavelet

    Expected grouping of the input channels by wavelet or by output channel. I.e. if the average, vertical, horizontal and diagonal wavelets for channel 1 are a1, v1, h1, d1, those for channel 2 are a2, v2, h2, d2, etc, then the input channels are taken as follows:

    set to True: a1, a2, ..., v1, v2, ..., h1, h2, ..., d1, d2, ...

    set to False: a1, v1, h1, d1, a2, v2, h2, d2, ...

    The True option is slightly slower to compute than the False option. The option is useful if e.g. the input has been concatenated from the average channels and the higher-frequency channels. Then, setting order_by_wavelet=True means the first quarter of input channels is taken to be the average wavelets only.

  • rebalance – Must be != 0. There exist different conventions for how to define the Haar wavelets. The wavelet components in the forward direction are multiplied by this factor, and those in the inverse direction are adjusted accordingly, so that the module as a whole is invertible. Stability of the network may be increased for rebalance < 1 (e.g. 0.5).

class FrEIA.modules.Flatten(dims_in, dims_c=None)[source]#

Flattens N-D tensors into 1-D tensors.

__init__(dims_in, dims_c=None)[source]#

See docstring of base class (FrEIA.modules.InvertibleModule).

class FrEIA.modules.Reshape(dims_in, dims_c=None, output_dims: Iterable[int] | None = None, target_dim=None)[source]#

Reshapes N-D tensors into tensors of the given target shape. Note that the reshape resulting from e.g. (3, 32, 32) -> (12, 16, 16) will not necessarily be spatially sensible. See IRevNetDownsampling, IRevNetUpsampling, HaarDownsampling, HaarUpsampling for spatially meaningful reshaping operations.

__init__(dims_in, dims_c=None, output_dims: Iterable[int] | None = None, target_dim=None)[source]#

See docstring of base class (FrEIA.modules.InvertibleModule) for more.

Parameters:
  • output_dims – The shape the reshaped output is supposed to have (not including batch dimension)

  • target_dim – Deprecated name for output_dims
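
For example, reusing the shapes mentioned in the description above:

import torch
import FrEIA.modules as Fm

reshape = Fm.Reshape(dims_in=[(3, 32, 32)], output_dims=(12, 16, 16))
x = torch.randn(4, 3, 32, 32)
(y,), _ = reshape([x])      # y has shape (4, 12, 16, 16); rev=True restores (4, 3, 32, 32)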

Graph topology#

class FrEIA.modules.Split(dims_in: Sequence[Sequence[int]], section_sizes: int | Sequence[int] | None = None, n_sections: int = 2, dim: int = 0)[source]#

Invertible split operation.

Splits the incoming tensor along the given dimension, and returns a list of separate output tensors. The inverse is the corresponding merge operation.

__init__(dims_in: Sequence[Sequence[int]], section_sizes: int | Sequence[int] | None = None, n_sections: int = 2, dim: int = 0)[source]#

Inits the Split module with the attributes described above and checks that split sizes and dimensionality are compatible.

Parameters:
  • dims_in – A list of tuples containing the non-batch dimensionality of all incoming tensors. Handled automatically during compute graph setup. Split only takes one input tensor.

  • section_sizes – If set, takes precedence over n_sections and behaves like the argument in torch.split(), except that if a list of section sizes is given that does not add up to the size of dim, an additional split section is created to take up the slack. Defaults to None.

  • n_sections – If section_sizes is None, the tensor is split into n_sections parts of equal size or close to it. This mode behaves like numpy.array_split(). Defaults to 2, i.e. splitting the data into two equal halves.

  • dim – Index of the dimension along which to split, not counting the batch dimension. Defaults to 0, i.e. the channel dimension in structured data.

class FrEIA.modules.Concat(dims_in: Sequence[Sequence[int]], dim: int = 0)[source]#

Invertible merge operation.

Concatenates a list of incoming tensors along a given dimension and passes on the result. Inverse is the corresponding split operation.

__init__(dims_in: Sequence[Sequence[int]], dim: int = 0)[source]#

Inits the Concat module with the attributes described above and checks that all dimensions are compatible.

Parameters:
  • dims_in – A list of tuples containing the non-batch dimensionality of all incoming tensors. Handled automatically during compute graph setup. Dimensionality of incoming tensors must be identical, except in the merge dimension dim. Concat only makes sense with multiple input tensors.

  • dim – Index of the dimension along which to concatenate, not counting the batch dimension. Defaults to 0, i.e. the channel dimension in structured data.

Other learned transforms#

class FrEIA.modules.ActNorm(dims_in, dims_c=None, init_data: Tensor | None = None)[source]#

A technique to achieve stable flow initialization.

First introduced in Kingma et al. 2018: https://arxiv.org/abs/1807.03039. The module is similar to a traditional batch normalization layer, but the data mean and standard deviation are initialized from the first batch that is passed through the module. They are treated as learnable parameters from there on.

Using ActNorm layers interspersed throughout an INN ensures that intermediate outputs of the INN have standard deviation 1 and mean 0, so that the training is stable at the start, avoiding exploding or zeroed outputs. Just as with standard batch normalization layers, ActNorm contains additional channel-wise scaling and bias parameters.

__init__(dims_in, dims_c=None, init_data: Tensor | None = None)[source]#
Parameters:
  • dims_in – list of tuples specifying the shape of the inputs to this operator: dims_in = [shape_x_0, shape_x_1, ...]

  • dims_c – list of tuples specifying the shape of the conditions to this operator.

initialize(batch: Tensor)[source]#
load_state_dict(state_dict, strict=True)[source]#

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

property scale#
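
A usage sketch (shapes are illustrative); the first batch passed through the module fixes the initial mean and standard deviation, unless init_data is supplied to the constructor:

import torch
import FrEIA.modules as Fm

actnorm = Fm.ActNorm(dims_in=[(3, 32, 32)])

x = torch.randn(64, 3, 32, 32)
(z,), log_jac_det = actnorm([x])       # first call: statistics taken from this batch
(x_rev,), _ = actnorm([z], rev=True)   # afterwards, scale and bias are ordinary learnable parameters
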
class FrEIA.modules.InvAutoAct(dims_in, dims_c=None, slope_init=2.0, learnable=True)[source]#

A nonlinear invertible activation analogous to Leaky ReLU, with learned slopes.

The slope is symmetric between the positive and negative side, i.e.

\[\begin{split}x \geq 0 &\implies g(x) = x \odot \exp(\alpha) \\ x < 0 &\implies g(x) = x \oslash \exp(\alpha)\end{split}\]

A separate slope is learned for each entry along the first input dimension (after the batch dimension). I.e. element-wise for flattened inputs, channel-wise for image inputs, etc.

__init__(dims_in, dims_c=None, slope_init=2.0, learnable=True)[source]#
Parameters:
  • slope_init – The initial value of the slope on the positive side. Accounts for the exp-activation, i.e. \(\exp(\alpha) =\) slope_init.

  • learnable – If False, the slopes are fixed at their initial value, and not learned.

class FrEIA.modules.InvAutoActTwoSided(dims_in, dims_c=None, init_pos: float = 2.0, init_neg: float = 0.5, learnable: bool = True)[source]#

A nonlinear invertible activation analogous to Leaky ReLU, with learned slopes.

The slopes are learned separately for each entry along the first input dimension (after the batch dimension). I.e. element-wise for flattened inputs, channel-wise for image inputs, etc. Internally, the slopes are learned in log-space, to ensure they stay strictly > 0:

\[\begin{split}x \geq 0 &\implies g(x) = x \odot \exp(\alpha_+) \\ x < 0 &\implies g(x) = x \odot \exp(\alpha_-)\end{split}\]
__init__(dims_in, dims_c=None, init_pos: float = 2.0, init_neg: float = 0.5, learnable: bool = True)[source]#
Parameters:
  • init_pos – The initial slope for the positive half of the activation. Must be > 0. Note that the initial value accounts for the exp-activation, meaning \(\exp(\alpha_+) =\) init_pos.

  • init_neg – The initial slope for the negative half of the activation. Must be > 0. The initial value accounts for the exp-activation in the same way as init_pos.

  • learnable – If False, the slopes are fixed at their initial value, and not learned.

class FrEIA.modules.LearnedElementwiseScaling(dims_in, dims_c=None, init_scale=1.0)[source]#

Scale each element of the input by a learned, non-negative factor. Unlike most other FrEIA modules, the scaling is not e.g. channel-wise for images, but really scales each individual element. To ensure positivity, the scaling is learned in log-space:

\[g(x) = x \odot \exp(s)\]
__init__(dims_in, dims_c=None, init_scale=1.0)[source]#
Parameters:

init_scale – The initial scaling value. It accounts for the exp-activation, i.e. \(\exp(s) =\) init_scale.

class FrEIA.modules.OrthogonalTransform(dims_in, dims_c=None, correction_interval: int = 256, clamp: float = 5.0)[source]#

Learnable orthogonal matrix, with additional scaling and bias term.

The matrix is learned as a completely free weight matrix, and projected back to the Stiefel manifold (set of all orthogonal matrices) in regular intervals. With input x, the output z is computed as

\[z = \Psi(s) \odot Rx + b\]

R is the orthogonal matrix, b the bias, s the scaling, and \(\Psi\) is a clamped scaling activation \(\Psi(\cdot) = \exp(\frac{2 \alpha}{\pi} \mathrm{atan}(\cdot))\).

__init__(dims_in, dims_c=None, correction_interval: int = 256, clamp: float = 5.0)[source]#
Parameters:
  • correction_interval – After this many gradient steps, the matrix is projected back to the Stiefel manifold to make it perfectly orthogonal.

  • clamp – clamps the log scaling for stability. Corresponds to \(\alpha\) above.

class FrEIA.modules.HouseholderPerm(dims_in, dims_c=None, n_reflections: int = 1, fixed: bool = False)[source]#

Fast product of a series of learned Householder matrices. This implementation is based on work by Mathiesen et al, 2020: https://invertibleworkshop.github.io/accepted_papers/pdfs/10.pdf Only works for flattened 1D input tensors.

The module can be used in one of two ways:

  • Without a condition, the reflection vectors that form the householder matrices are learned as free parameters

  • Used as a conditional module, the condition contains the reflection vectors. The module does not have any learnable parameters in that case, but the condition can be backpropagated (e.g. to predict the reflection vectors by some other network). The condition must have the shape (input size, n_reflections).

__init__(dims_in, dims_c=None, n_reflections: int = 1, fixed: bool = False)[source]#
Parameters:
  • n_reflections – How many subsequent householder reflections to perform. Each householder reflection is learned independently. Must be >= 2 for implementation reasons.

  • fixed – If True, the householder matrices are initialized randomly, computed only once, and then kept fixed from there on.

Fixed (non-learned) transforms#

class FrEIA.modules.PermuteRandom(dims_in, dims_c=None, seed: int | None = None)[source]#

Constructs a random permutation that stays fixed during training. Permutes along the first (channel) dimension for multi-dimensional tensors.

__init__(dims_in, dims_c=None, seed: int | None = None)[source]#

Additional args in docstring of base class FrEIA.modules.InvertibleModule.

Parameters:

seed – Int seed for the permutation (numpy is used for RNG). If seed is None, do not reseed RNG.

class FrEIA.modules.FixedLinearTransform(dims_in, dims_c=None, M: Tensor | None = None, b: None | Tensor = None)[source]#

Fixed linear transformation for 1D input tensors. The transformation is \(y = Mx + b\). With d input dimensions, M must be an invertible d x d tensor, and b is an optional offset vector of length d.

__init__(dims_in, dims_c=None, M: Tensor | None = None, b: None | Tensor = None)[source]#

Additional args in docstring of base class FrEIA.modules.InvertibleModule.

Parameters:
  • M – Square, invertible matrix, with which each input is multiplied. Shape (d, d).

  • b – Optional vector which is added element-wise. Shape (d,).
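
A sketch using a random orthogonal matrix as M (the QR decomposition is just one convenient way to obtain an invertible matrix; dimensions are illustrative):

import torch
import FrEIA.modules as Fm

d = 16
M, _ = torch.linalg.qr(torch.randn(d, d))       # random orthogonal (hence invertible) matrix
b = torch.zeros(d)

lin = Fm.FixedLinearTransform(dims_in=[(d,)], M=M, b=b)
x = torch.randn(8, d)
(y,), log_jac_det = lin([x])                    # per the docstring, y = M x + b for each sample
(x_rev,), _ = lin([y], rev=True)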

class FrEIA.modules.Fixed1x1Conv(dims_in, dims_c=None, M: Tensor | None = None)[source]#

Given an invertible matrix M, a 1x1 convolution is performed using M as the convolution kernel. Effectively, a matrix multiplication along the channel dimension is performed at each pixel.

__init__(dims_in, dims_c=None, M: Tensor | None = None)[source]#

Additional args in docstring of base class FrEIA.modules.InvertibleModule.

Parameters:

M – Square, invertible matrix, with which each input is multiplied. Shape (d, d).

class FrEIA.modules.InvertibleSigmoid(dims_in, **kwargs)[source]#

Applies the sigmoid function element-wise across all batches, and the associated inverse function in the reverse pass. Contains no trainable parameters. The sigmoid function S(x) and its inverse are given by

\[\begin{split}S(x) &= \frac{1}{1 + \exp(-x)} \\ S^{-1}(x) &= \log{\frac{x}{1-x}}.\end{split}\]

The returned Jacobian is computed as

\[J = \log \det \frac{1}{(1+\exp(x))(1+\exp(-x))}.\]
__init__(dims_in, **kwargs)[source]#
Parameters:
  • dims_in – list of tuples specifying the shape of the inputs to this operator: dims_in = [shape_x_0, shape_x_1, ...]

  • dims_c – list of tuples specifying the shape of the conditions to this operator.

Approximately- or semi-invertible transforms#

class FrEIA.modules.InvAutoFC(dims_in, dims_c=None, dims_out=None)[source]#

Fully connected ‘Invertible Autoencoder’ layer (see arxiv.org/pdf/1802.06869.pdf). The weight matrix of the inverse is the transposed weight matrix of the forward pass. If a reconstruction loss between forward and inverse is used, the layer converges to an invertible, orthogonal, linear transformation.

__init__(dims_in, dims_c=None, dims_out=None)[source]#
Parameters:

dims_out – If None, the output dimension equals the input dimension. However, because InvAuto is only asymptotically invertible, there is no strict requirement to have the same number of input and output dimensions. If dims_out is an integer instead of None, that number of output dimensions is used.

class FrEIA.modules.InvAutoConv2D(dims_in, dims_c=None, dims_out=None, kernel_size=3, padding=1)[source]#

Convolutional variant of the ‘Invertible Autoencoder’ layer (see arxiv.org/pdf/1802.06869.pdf). The inverse is a transposed convolution with the same kernel as the forward pass. If a reconstruction loss between forward and inverse is used, the layer converges to an invertible, orthogonal, linear transformation.

__init__(dims_in, dims_c=None, dims_out=None, kernel_size=3, padding=1)[source]#
Parameters:
  • kernel_size – Spatial size of the convolution kernel.

  • padding – Padding of the input. Choosing padding = kernel_size // 2 retains the image shape between in- and output.

  • dims_out – If None, the output dimension equals the input dimension. However, because InvAuto is only asymptotically invertible, there is no strict requirement to have the same number of input and output dimensions. Therefore dims_out can also be a tuple of length 3: (channels, width, height). The channels are the output channels of the convolution. The user is responsible for making the width and height match the actual output, depending on kernel_size and padding.

class FrEIA.modules.IResNetLayer(dims_in, dims_c=[], internal_size=None, n_internal_layers=1, jacobian_iterations=20, hutchinson_samples=1, fixed_point_iterations=50, lipschitz_iterations=10, lipschitz_batchsize=10, spectral_norm_max=0.8)[source]#

Implementation of the i-ResNet architecture as proposed in https://arxiv.org/pdf/1811.00995.pdf

__init__(dims_in, dims_c=[], internal_size=None, n_internal_layers=1, jacobian_iterations=20, hutchinson_samples=1, fixed_point_iterations=50, lipschitz_iterations=10, lipschitz_batchsize=10, spectral_norm_max=0.8)[source]#
Parameters:
  • dims_in – list of tuples specifying the shape of the inputs to this operator: dims_in = [shape_x_0, shape_x_1, ...]

  • dims_c – list of tuples specifying the shape of the conditions to this operator.

lipschitz_correction()[source]#
class FrEIA.modules.GaussianMixtureModel(dims_in, dims_c)[source]#

An invertible Gaussian mixture model. The weights, means, covariance parameterization and component index must be supplied as conditional inputs to the module and can come from an external feed-forward network, which may be trained by backpropagating through the GMM. Weights should first be normalized via GaussianMixtureModel.normalize_weights(w) and component indices can be sampled via GaussianMixtureModel.pick_mixture_component(w). If component indices are specified, the model reduces to that Gaussian mixture component and maps between data x and standard normal latent variable z. Components can also be chosen consistently at random, by supplying an integer random seed instead of indices. If a None value is supplied instead of indices, the model maps between K data points x and K latent codes z simultaneously, where K is the number of mixture components. Mathematical derivations are found in the technical report “Training Mixture Density Networks with full covariance matrices” on arXiv.

__init__(dims_in, dims_c)[source]#
Parameters:
  • dims_in – list of tuples specifying the shape of the inputs to this operator: dims_in = [shape_x_0, shape_x_1, ...]

  • dims_c – list of tuples specifying the shape of the conditions to this operator.

static nll_loss(w, z, log_jacobian)[source]#

Negative log-likelihood loss for training a Mixture Density Network.

w: Mixture component weights, must be positive and sum to one. Tensor must be of size [batch_size, n_components].

z: Latent codes for all mixture components. Tensor must be of size [batch, n_components, n_dims].

log_jacobian: Jacobian log-determinants for each precision matrix. Tensor size must be [batch_size, n_components].

static nll_upper_bound(w, z, log_jacobian)[source]#

Numerically more stable upper bound of the negative log-likelihood loss for training a Mixture Density Network.

w: Mixture component weights, must be positive and sum to one. Tensor must be of size [batch_size, n_components].

z: Latent codes for all mixture components. Tensor must be of size [batch, n_components, n_dims].

log_jacobian: Jacobian log-determinants for each precision matrix. Tensor size must be [batch_size, n_components].

static normalize_weights(w)[source]#

Apply softmax to ensure component weights are positive and sum to one. Works on batches of component weights.

w: Unnormalized weights for Gaussian mixture components, must be of size [batch_size, n_components].

static pick_mixture_component(w, seed=None)[source]#

Randomly choose mixture component indices with probability given by the component weights w. Works on batches of component weights.

w: Weights of the mixture components, must be positive and sum to one.

seed: Optional RNG seed for consistent decisions.
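
A sketch of the two static helpers above (shapes follow the documented sizes; the raw weights are random placeholders):

import torch
from FrEIA.modules import GaussianMixtureModel

batch_size, n_components = 32, 10
raw_w = torch.randn(batch_size, n_components)

# softmax-normalize so the weights are positive and sum to one
w = GaussianMixtureModel.normalize_weights(raw_w)

# sample one component index per batch element, proportional to w
idx = GaussianMixtureModel.pick_mixture_component(w, seed=0)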