import warnings
from collections import deque, defaultdict
from typing import List, Tuple, Iterable, Union, Optional
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from ...modules.base import InvertibleModule
from .nodes import AbstractNode, Node, ConditionNode, InputNode, OutputNode, FeedForwardNode
class GraphINN(InvertibleModule):
    """
    This class represents the invertible net itself. It is a subclass of
    InvertibleModule and supports the same methods.

    The forward method has an additional option 'rev', with which the net can be
    computed in reverse. Passing `jac` to the forward method additionally
    computes the log determinant of the (inverse) Jacobian of the forward
    (backward) pass.
    """

    def __init__(self, node_list: Iterable[AbstractNode], force_tuple_output=False, verbose=False):
        """
        Build the network from a collection of graph nodes.

        Args:
            node_list: All nodes of the graph (input, output, condition and
                module nodes). Any iterable is accepted; it is materialized
                into a list because it is traversed several times.
            force_tuple_output: If True, `forward` always returns a tuple of
                outputs, even when there is only a single output node.
            verbose: If True, print the constructed module.

        Raises:
            ValueError: if some node references a node that is not part of
                `node_list`, or if the graph cannot be topologically sorted.
        """
        # A one-shot iterator (e.g. a generator) would be exhausted after the
        # first comprehension below, so materialize it once.
        node_list = list(node_list)
        # Nodes are hashable (they are used as dict keys), so use a set for
        # O(1) membership checks instead of scanning the list.
        node_set = set(node_list)

        # Gather lists of input, output and condition nodes
        in_nodes = [node for node in node_list if isinstance(node, InputNode)]
        out_nodes = [node for node in node_list if isinstance(node, OutputNode)]
        condition_nodes = [node for node in node_list if isinstance(node, ConditionNode)]
        ff_nodes = [node for node in node_list if isinstance(node, FeedForwardNode)]

        # Check that every node referenced by some node is part of the graph
        for node in node_list:
            for in_node, idx in node.inputs:
                if in_node not in node_set:
                    raise ValueError(f"{node} gets input from {in_node}, "
                                     f"but the latter is not in the node_list "
                                     f"passed to GraphINN.")
            for out_node, idx in node.outputs:
                if out_node not in node_set:
                    raise ValueError(f"{out_node} gets input from {node}, "
                                     f"but it's not in the node_list "
                                     f"passed to GraphINN.")
            for cond_node, idx in node.conditions:
                if cond_node not in node_set:
                    raise ValueError(f"{node} is conditioned on {cond_node}, "
                                     f"but the latter is not in the node_list "
                                     f"passed to GraphINN.")

        # Global in- and output shapes
        global_in_shapes = [node.output_dims[0] for node in in_nodes]
        global_out_shapes = [node.input_dims[0] for node in out_nodes]
        global_cond_shapes = [node.output_dims[0] for node in condition_nodes]

        # Only now we can set out shapes
        super().__init__(global_in_shapes, global_cond_shapes)

        self.node_list_fwd = topological_order(node_list, in_nodes, out_nodes, rev=False)
        self.node_list_rev = topological_order(node_list, in_nodes, out_nodes, rev=True)

        # Attributes are stored only after the super constructor has run:
        # nn.Module does not allow assigning them before that.
        self.in_nodes = in_nodes
        self.condition_nodes = condition_nodes
        self.out_nodes = out_nodes
        self.ff_nodes = ff_nodes
        self.global_out_shapes = global_out_shapes
        self.force_tuple_output = force_tuple_output
        # Register the node modules so their parameters belong to this module.
        self.module_list = nn.ModuleList([n.module for n in self.node_list_fwd
                                          if n.module is not None])

        if verbose:
            print(self)

    @property
    def node_list(self):
        """All nodes of the graph in forward topological order."""
        return self.node_list_fwd

    def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
        """
        Return the shapes of the global outputs.

        `input_dims` is not used; the shapes are fixed when the graph is built.

        Raises:
            ValueError: if the graph has exactly one output and
                `force_tuple_output` is False.
        """
        if len(self.global_out_shapes) == 1 and not self.force_tuple_output:
            raise ValueError("You can only call output_dims on a "
                             "GraphINN with more than one output "
                             "or when setting force_tuple_output=True.")
        return self.global_out_shapes

    def forward(self, x_or_z: Union[Tensor, Iterable[Tensor]],
                c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True,
                intermediate_outputs: bool = False, x: None = None) \
            -> Tuple[Tuple[Tensor], Tensor]:
        """
        Forward or backward computation of the whole net.

        Args:
            x_or_z: A tensor, or an iterable of tensors with one entry per
                input node (forward) or per output node (``rev=True``).
            c: A tensor, or an iterable of tensors with one entry per
                condition node.
            rev: If True, run the network in reverse.
            jac: If True, sum the log-Jacobian-determinants returned by the
                nodes.
            intermediate_outputs: If True, return the dict of all node outputs
                (keyed by ``(node, output_index)``) and the per-node Jacobians
                instead of the final outputs.
            x: Deprecated alias for ``x_or_z``.

        Returns:
            ``(outputs, log_jac_det)``. ``outputs`` is a single tensor when
            there is exactly one output and ``force_tuple_output`` is False,
            otherwise a tuple. ``log_jac_det`` stays a zero tensor with one
            entry per batch element when ``jac`` is False. With
            ``intermediate_outputs=True``, returns ``(outs, jacobian_dict)``
            instead; ``jacobian_dict`` is None when ``jac`` is False.

        Raises:
            ValueError: if the number of inputs or conditions does not match
                the number of start / condition nodes.
            RuntimeError: wrapping any exception raised while executing a node.
        """
        if x is not None:
            x_or_z = x
            warnings.warn("You called GraphINN(x=...). x is now called x_or_z, "
                          "please pass input as positional argument.")
        # Normalize inputs and conditions to tuples (accepts any iterable).
        x_or_z = (x_or_z,) if torch.is_tensor(x_or_z) else tuple(x_or_z)
        if torch.is_tensor(c):
            c = (c,)
        c = () if c is None else tuple(c)

        # Validate counts before indexing into x_or_z below.
        start_nodes = self.out_nodes if rev else self.in_nodes
        if len(x_or_z) != len(start_nodes):
            raise ValueError(f"Got {len(x_or_z)} inputs, but expected "
                             f"{len(start_nodes)}.")
        if len(c) != len(self.condition_nodes):
            raise ValueError(f"Got {len(c)} conditions, but expected "
                             f"{len(self.condition_nodes)}.")

        # Explicitly set conditions and starts
        outs = {}
        for tensor, start_node in zip(x_or_z, start_nodes):
            outs[start_node, 0] = tensor
        for tensor, condition_node in zip(c, self.condition_nodes):
            outs[condition_node, 0] = tensor

        jacobian = torch.zeros(x_or_z[0].shape[0]).to(x_or_z[0])
        jacobian_dict = {} if jac else None

        # Input/output/condition nodes are handled outside the loop; build the
        # lookup set once instead of concatenating lists for every node.
        boundary_nodes = set(self.in_nodes + self.out_nodes + self.condition_nodes)

        # Go backwards through nodes if rev=True
        node_list = self.node_list_rev if rev else self.node_list_fwd
        for node in node_list:
            if node in boundary_nodes:
                continue

            # In reverse, a node consumes what it produces in the forward pass.
            mod_in = tuple(outs[prev_node, channel] for prev_node, channel
                           in (node.outputs if rev else node.inputs))
            mod_c = tuple(outs[cond_node, channel] for cond_node, channel
                          in (node.rev_conditions() if rev else node.conditions))

            try:
                # Execute node
                out, mod_jac = node.forward(x_or_z=mod_in, c=mod_c, rev=rev, jac=jac)
                if jac and mod_jac is not None:
                    jacobian = jacobian + mod_jac
                    jacobian_dict[node] = mod_jac
            except Exception as e:
                raise RuntimeError(f"{node} encountered an error.") from e

            for out_idx, out_value in enumerate(out):
                outs[node, out_idx] = out_value

        # Each terminal node copies the single tensor it is connected to.
        end_nodes = self.in_nodes if rev else self.out_nodes
        for out_node in end_nodes:
            outs[out_node, 0] = outs[(out_node.outputs if rev
                                      else out_node.inputs)[0]]

        if intermediate_outputs:
            return outs, jacobian_dict

        out_list = [outs[out_node, 0] for out_node in end_nodes]
        if len(out_list) == 1 and not self.force_tuple_output:
            return out_list[0], jacobian
        return tuple(out_list), jacobian

    def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04):
        """
        Approximate log Jacobian determinant via finite differences.

        Every flattened input dimension is shifted by ``+h`` and ``-h`` and the
        net is evaluated for both (central differences). This costs
        ``2 * D`` passes for ``D`` input dimensions, so it is only practical
        for small networks, e.g. in tests.

        Args:
            x: A tensor whose first axis is the batch axis, or a list/tuple of
                such tensors.
            c: Conditions, passed through to ``forward``.
            rev: If True, differentiate the reverse pass.
            h: Finite-difference step size.

        Returns:
            A tensor with one ``log|det J|`` value per batch element.
        """
        is_multi = isinstance(x, (list, tuple))
        if is_multi:
            batch_size = x[0].shape[0]
            # int(...): np.prod of an empty shape (1-D inputs) is the float 1.0
            ndim_x_separate = [int(np.prod(x_i.shape[1:])) for x_i in x]
            ndim_x_total = sum(ndim_x_separate)
            # reshape instead of view: inputs need not be contiguous
            x_flat = torch.cat([x_i.reshape(batch_size, -1) for x_i in x], dim=1)
        else:
            batch_size = x.shape[0]
            ndim_x_total = int(np.prod(x.shape[1:]))
            x_flat = x.reshape(batch_size, -1)

        def _unflatten(flat):
            # Restore the original input structure from a flat (batch, D) tensor.
            if is_multi:
                parts = torch.split(flat, ndim_x_separate, dim=1)
                return [part.reshape(x_i.shape) for part, x_i in zip(parts, x)]
            return flat.reshape(x.shape)

        def _flatten(y):
            # Flatten a network output (tensor or sequence of tensors) to (batch, -1).
            if isinstance(y, (list, tuple)):
                return torch.cat([y_i.reshape(batch_size, -1) for y_i in y], dim=1)
            return y.reshape(batch_size, -1)

        # Allocate with the input's dtype/device instead of float32 on the CPU.
        J_num = x_flat.new_zeros(batch_size, ndim_x_total, ndim_x_total)
        for i in range(ndim_x_total):
            offset = x_flat.new_zeros(batch_size, ndim_x_total)
            offset[:, i] = h
            y_upper, _ = self.forward(_unflatten(x_flat + offset), c=c, rev=rev, jac=False)
            y_lower, _ = self.forward(_unflatten(x_flat - offset), c=c, rev=rev, jac=False)
            J_num[:, :, i] = (_flatten(y_upper) - _flatten(y_lower)) / (2 * h)

        # slogdet operates on batches of matrices; index 1 is log|det|.
        return torch.slogdet(J_num)[1]

    def get_node_by_name(self, name) -> Optional[Node]:
        """
        Return the first node in the graph with the provided name, or None if
        no node has that name.
        """
        for node in self.node_list:
            if node.name == name:
                return node
        return None

    def get_module_by_name(self, name) -> Optional[nn.Module]:
        """
        Return module of the first node in the graph with the provided name,
        or None if there is no such node.
        """
        node = self.get_node_by_name(name)
        # When no node matches, `node` is None and `.module` raises
        # AttributeError, which is mapped to None.
        try:
            return node.module
        except AttributeError:
            return None
def topological_order(all_nodes: List[AbstractNode], in_nodes: List[InputNode],
                      out_nodes: List[OutputNode], rev: bool) -> List[AbstractNode]:
    """
    Sort the graph's nodes so that each node comes after everything it needs.

    Parameters:
        all_nodes: Every node of the computation graph.
        in_nodes: Input nodes (must also be contained in `all_nodes`).
        out_nodes: Output nodes (must also be contained in `all_nodes`).
        rev: Whether to order for the reverse pass. In reverse, a node depends
            on its forward outputs and on `rev_conditions()` rather than on its
            forward inputs and `conditions`, so the order is different.

    Returns:
        A list of nodes in which the inputs of every node are produced by
        nodes appearing earlier in the list.

    Raises:
        ValueError: if a start node (an input node going forward, an output
            node in reverse) is not reached, or if unresolved edges remain
            after sorting ("Graph is cyclic.").
    """
    if rev:
        def dependencies(node):
            return node.outputs + node.rev_conditions()
        first_nodes, last_nodes = out_nodes, in_nodes
    else:
        def dependencies(node):
            return node.inputs + node.conditions
        first_nodes, last_nodes = in_nodes, out_nodes

    # consumer -> set of producers it still waits on
    depends_on = {}
    for consumer in all_nodes:
        depends_on[consumer] = {producer for producer, _ in dependencies(consumer)}

    # producer -> set of consumers that still need it
    consumers_of = defaultdict(set)
    for consumer, producers in depends_on.items():
        for producer in producers:
            consumers_of[producer].add(consumer)

    # Kahn's algorithm, walking backwards from the end nodes: a producer is
    # ready once every one of its consumers has been placed.
    order = []
    ready = deque(last_nodes)
    while ready:
        current = ready.popleft()
        order.append(current)
        for producer in list(depends_on[current]):
            depends_on[current].remove(producer)
            waiting = consumers_of[producer]
            waiting.remove(current)
            if not waiting:
                ready.append(producer)

    direction = 'in' if rev else 'out'
    for start in first_nodes:
        if start not in order:
            raise ValueError(f"Error in graph: {start} is not connected "
                             f"to any {direction}put.")

    # Any edge that was never consumed means the sort could not finish.
    if any(consumers_of.values()):
        raise ValueError("Graph is cyclic.")
    return order[::-1]