Convolutional INN with invertible downsampling

For the following architecture (which works e.g. for CIFAR10), 3/4 of the outputs are split off after some convolutions, which encode the local details, and the rest are transformed further to encode semantic content. This is important because, even for moderately sized images, it becomes infeasible to transform all dimensions through the full depth of the INN. Many dimensions will just encode image noise, so we can split them off early. Because the computational graph contains multiple outputs, we have to use the full GraphINN machinery.

# Input is a 3 x 32 x 32 image (e.g. CIFAR10); ndim_x is its flattened size
# and is reused further down to size the split.
ndim_x = 3 * 32 * 32
nodes = [Ff.InputNode(3, 32, 32, name='input')]

# High-resolution stage: four GLOW coupling blocks, each followed by a
# PermuteRandom node seeded with the loop index.
for k in range(4):
    coupling = Ff.Node(
        nodes[-1],
        Fm.GLOWCouplingBlock,
        {'subnet_constructor': subnet_conv, 'clamp': 1.2},
        name=f'conv_high_res_{k}',
    )
    shuffle = Ff.Node(
        coupling,
        Fm.PermuteRandom,
        {'seed': k},
        name=f'permute_high_res_{k}',
    )
    # Keep the graph order: coupling block first, then its permutation.
    nodes += [coupling, shuffle]

# Invertible downsampling between the high- and low-resolution convolutional
# parts. Given an explicit name, like every other node in this graph, so it
# can be identified by a readable label instead of an auto-generated one.
nodes.append(Ff.Node(nodes[-1],
                     Fm.IRevNetDownsampling,
                     {},
                     name='downsample'))

# Low-resolution stage: twelve GLOW coupling blocks whose subnetwork
# constructor alternates between subnet_conv_1x1 (even k) and subnet_conv
# (odd k), each followed by a PermuteRandom node seeded with the loop index.
for k in range(12):
    subnet = subnet_conv_1x1 if k % 2 == 0 else subnet_conv

    coupling = Ff.Node(
        nodes[-1],
        Fm.GLOWCouplingBlock,
        {'subnet_constructor': subnet, 'clamp': 1.2},
        name=f'conv_low_res_{k}',
    )
    shuffle = Ff.Node(
        coupling,
        Fm.PermuteRandom,
        {'seed': k},
        name=f'permute_low_res_{k}',
    )
    # Keep the graph order: coupling block first, then its permutation.
    nodes += [coupling, shuffle]

# Flatten the feature maps into a vector, then split that vector into two
# sections: one quarter of the dimensions for the fully connected part and
# the remaining three quarters, which are concatenated back in at the end.
flatten_node = Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten')

fc_dims = ndim_x // 4
skip_dims = 3 * ndim_x // 4
split_node = Ff.Node(
    flatten_node,
    Fm.Split,
    {'section_sizes': (fc_dims, skip_dims), 'dim': 0},
    name='split',
)

nodes += [flatten_node, split_node]

# Fully connected stage: twelve GLOW coupling blocks built with subnet_fc
# (with a looser clamp of 2.0), each followed by a PermuteRandom node seeded
# with the loop index.
for k in range(12):
    coupling = Ff.Node(
        nodes[-1],
        Fm.GLOWCouplingBlock,
        {'subnet_constructor': subnet_fc, 'clamp': 2.0},
        name=f'fully_connected_{k}',
    )
    shuffle = Ff.Node(
        coupling,
        Fm.PermuteRandom,
        {'seed': k},
        name=f'permute_{k}',
    )
    # Keep the graph order: coupling block first, then its permutation.
    nodes += [coupling, shuffle]

# Join the end of the fully connected branch with the second output of the
# split (the skip connection) so that the graph has a single output.
concat_node = Ff.Node(
    [nodes[-1].out0, split_node.out1],
    Fm.Concat1d,
    {'dim': 0},
    name='concat',
)
output_node = Ff.OutputNode(concat_node, name='output')
nodes += [concat_node, output_node]

# Assemble the invertible network from the complete node list.
conv_inn = Ff.GraphINN(nodes)