Source code for FrEIA.framework.sequence_inn

from typing import Iterable, Tuple, List

import torch.nn as nn
import torch
from torch import Tensor

from FrEIA.modules import InvertibleModule


class SequenceINN(InvertibleModule):
    """
    Simpler than FrEIA.framework.GraphINN: Only supports a sequential series
    of modules (no splitting, merging, branching off).

    Has an append() method, to add new blocks in a simpler way than the
    computation-graph based approach of GraphINN. For example:

    ```
    inn = SequenceINN(channels, dims_H, dims_W)
    for i in range(n_blocks):
        inn.append(FrEIA.modules.AllInOneBlock, clamp=2.0, permute_soft=True)
    inn.append(FrEIA.modules.HaarDownsampling)
    # and so on
    ```
    """
    def __init__(self, *dims: int, force_tuple_output=False):
        super().__init__([dims])
        self.shapes = [tuple(dims)]
        self.conditions = []
        self.module_list = nn.ModuleList()
        self.force_tuple_output = force_tuple_output
    def append(self, module_class, cond=None, cond_shape=None, **kwargs):
        """
        Append a reversible block from FrEIA.modules to the network.

        module_class: Class from FrEIA.modules.
        cond (int): index of which condition to use (conditions will be
            passed as a list to forward()). Conditioning nodes are not needed
            for SequenceINN.
        cond_shape (tuple[int]): the shape of the condition tensor.
        **kwargs: Further keyword arguments that are passed to the
            constructor of module_class (see example).
        """
        dims_in = [self.shapes[-1]]
        self.conditions.append(cond)

        if cond is not None:
            kwargs['dims_c'] = [cond_shape]

        module = module_class(dims_in, **kwargs)
        self.module_list.append(module)
        output_dims = module.output_dims(dims_in)
        assert len(output_dims) == 1, "Module has more than one output"
        self.shapes.append(output_dims[0])
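    # A minimal sketch of a conditioned append (illustrative only; the block
    # type GLOWCouplingBlock, the subnet factory `subnet_fc`, and the
    # 10-dimensional condition are assumptions, not part of this module):
    #
    #     inn = SequenceINN(8)
    #     inn.append(FrEIA.modules.GLOWCouplingBlock,
    #                subnet_constructor=subnet_fc,
    #                cond=0, cond_shape=(10,), clamp=2.0)
    #
    # cond=0 selects the first tensor of the condition list later given to
    # forward(x, c=[cond_tensor]); cond_shape is forwarded to the module
    # constructor as dims_c=[cond_shape].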
    def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
        if not self.force_tuple_output:
            raise ValueError("You can only call output_dims on a SequenceINN "
                             "when setting force_tuple_output=True.")
        return input_dims
    def forward(self, x_or_z: Tensor, c: Iterable[Tensor] = None,
                rev: bool = False, jac: bool = True) -> Tuple[Tensor, Tensor]:
        """
        Executes the sequential INN in forward or inverse (rev=True) direction.

        Arguments:
            x_or_z: input tensor (in contrast to GraphINN, a list of tensors
                is not supported, as SequenceINN only has one input).
            c: list of conditions.
            rev: whether to compute the network forward or reversed.
            jac: whether to compute the log Jacobian determinant.

        Returns:
            z_or_x (Tensor): network output.
            log_jac_det (Tensor): log-Jacobian-determinant.
        """
        iterator = range(len(self.module_list))
        # accumulate the log-Jacobian-determinant in its own variable, so the
        # boolean `jac` argument is not overwritten
        log_jac_det = 0

        if rev:
            iterator = reversed(iterator)

        if torch.is_tensor(x_or_z):
            x_or_z = (x_or_z,)
        for i in iterator:
            if self.conditions[i] is None:
                x_or_z, j = self.module_list[i](x_or_z, jac=jac, rev=rev)
            else:
                # modules expect their conditions as a list of tensors
                x_or_z, j = self.module_list[i](x_or_z,
                                                c=[c[self.conditions[i]]],
                                                jac=jac, rev=rev)
            log_jac_det = j + log_jac_det

        return x_or_z if self.force_tuple_output else x_or_z[0], log_jac_det
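
# A minimal end-to-end sketch (not part of the library source): build a small
# conditioned SequenceINN, run a forward pass, then invert it again. The block
# type (GLOWCouplingBlock), the subnet factory `subnet_fc`, and all sizes are
# illustrative assumptions.
if __name__ == "__main__":
    import FrEIA.modules as Fm

    def subnet_fc(dims_in, dims_out):
        # small fully-connected subnetwork used inside each coupling block
        return nn.Sequential(nn.Linear(dims_in, 128), nn.ReLU(),
                             nn.Linear(128, dims_out))

    inn = SequenceINN(8)
    for _ in range(4):
        inn.append(Fm.GLOWCouplingBlock, subnet_constructor=subnet_fc,
                   cond=0, cond_shape=(10,), clamp=2.0)

    x = torch.randn(16, 8)      # batch of 16 samples with 8 features
    cond = torch.randn(16, 10)  # condition tensor at index 0 of the list

    z, log_jac_det = inn(x, c=[cond])       # forward pass
    x_rec, _ = inn(z, c=[cond], rev=True)   # inverse pass
    print(torch.allclose(x, x_rec, atol=1e-5))  # reconstruction check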