Addition of ICNN Initialization Schemes (#90)
* Add ICNN formulation with new initialization schemes.

* Adapt ICNN test.

* Add notebook comparing both ICNN initialization schemes.

* Update notebook on neural dual.

* Add notebook to documentation.

* Integration of comments by Marco.

* Integration of comments by Marco.

* Integration of comments by Marco.

* Integration of comments by Marco.

* 😵

* 🤯
bunnech authored Jun 30, 2022
1 parent 81646ed commit 597eee5
Showing 5 changed files with 324 additions and 142 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
@@ -61,6 +61,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin
notebooks/gromov_wasserstein_multiomics.ipynb
notebooks/fairness.ipynb
notebooks/neural_dual.ipynb
notebooks/icnn_inits.ipynb

.. toctree::
:maxdepth: 1
196 changes: 114 additions & 82 deletions ott/core/icnn.py
@@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,65 +20,21 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.nn import initializers

from ott.core.layers import PosDefPotentials, PositiveDense
from ott.geometry.matrix_square_root import sqrtm, sqrtm_only

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any # this could be a real type?
Dtype = Any
Array = Any


class PositiveDense(nn.Module):
"""A linear transformation using a weight matrix with all entries positive.
Args:
dim_hidden: the number of output dimensions.
beta: inverse temperature parameter of the softplus function (default: 1).
use_bias: whether to add a bias to the output (default: True).
dtype: the dtype of the computation (default: float32).
precision: numerical precision of the computation, see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
"""

dim_hidden: int
beta: float = 1.0
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.lecun_normal()
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros

@nn.compact
def __call__(self, inputs):
"""Apply a linear transformation to inputs along the last dimension.
Args:
inputs: The nd-array to be transformed.
Returns:
The transformed input.
"""
inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
'kernel', self.kernel_init, (inputs.shape[-1], self.dim_hidden)
)
scaled_kernel = self.beta * kernel
kernel = jnp.asarray(1 / self.beta * nn.softplus(scaled_kernel), self.dtype)
y = jax.lax.dot_general(
inputs,
kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision
)
if self.use_bias:
bias = self.param('bias', self.bias_init, (self.dim_hidden,))
bias = jnp.asarray(bias, self.dtype)
y = y + bias
return y
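# Note: this PositiveDense definition is removed from icnn.py in this commit;
# the class is now imported from ott.core.layers instead (see the new imports above).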


class ICNN(nn.Module):
"""Input convex neural network (ICNN) architeture.
"""Input convex neural network (ICNN) architeture with initialization.
Containing initialization schemes introduced in Bunne+(2022).
Args:
dim_hidden: sequence specifying size of hidden dimensions. The
@@ -88,66 +44,142 @@ class ICNN(nn.Module):
`jax.nn.initializers.normal`).
act_fn: choice of activation function used in network architecture
(needs to be convex and non-decreasing, default: `nn.relu`).
pos_weights: whether to enforce positivity of the weights (via `PositiveDense`
layers) or to leave them unconstrained and use a regularizer instead.
dim_data: data dimensionality (default: 2).
gaussian_map: samples from the source and target measures, used by the
initialization scheme based on Gaussian approximations of these two
measures (if None, identity initialization is used).
"""

dim_hidden: Sequence[int]
init_std: float = 0.1
init_std: float = 1e-1
init_fn: Callable = jax.nn.initializers.normal
act_fn: Callable = nn.leaky_relu
act_fn: Callable = nn.relu
pos_weights: bool = True
dim_data: int = 2
gaussian_map: Tuple[jnp.ndarray, jnp.ndarray] = None

def setup(self):
num_hidden = len(self.dim_hidden)

w_zs = []
self.num_hidden = len(self.dim_hidden)

if self.pos_weights:
Dense = PositiveDense
hid_dense = PositiveDense
# this function needs to be the inverse of the rectifier function
# used in the PositiveDense layers
rescale = hid_dense.inv_rectifier_fn
else:
Dense = nn.Dense
hid_dense = nn.Dense
rescale = lambda x: x
self.use_init = False
# check if Gaussian map was provided
if self.gaussian_map is not None:
factor, mean = self.compute_gaussian_map(self.gaussian_map)
else:
factor, mean = self.compute_identity_map(self.dim_data)

for i in range(1, num_hidden):
w_zs = []
# keep track of previous size to normalize accordingly
normalization = 1
# subsequent layers propagate the value of the potential provided by the
# first layer in x; the normalization factor is rescaled accordingly
for i in range(0, self.num_hidden):
w_zs.append(
Dense(
hid_dense(
self.dim_hidden[i],
kernel_init=self.init_fn(self.init_std),
use_bias=False
kernel_init=initializers.constant(rescale(1.0 / normalization)),
use_bias=False,
)
)
normalization = self.dim_hidden[i]
# final layer computes average, still with normalized rescaling
w_zs.append(
Dense(1, kernel_init=self.init_fn(self.init_std), use_bias=False)
hid_dense(
1,
kernel_init=initializers.constant(rescale(1.0 / normalization)),
use_bias=False,
)
)
self.w_zs = w_zs

w_xs = []
for i in range(num_hidden):
# first quadratic layer, initialized from `factor` and `mean` (identity or Gaussian map)
w_xs.append(
PosDefPotentials(
self.dim_data,
num_potentials=1,
kernel_init=lambda *args, **kwargs: factor,
bias_init=lambda *args, **kwargs: mean,
use_bias=True,
)
)

# subsequent layers re-inject the input x into the convex function
for i in range(self.num_hidden):
w_xs.append(
nn.Dense(
self.dim_hidden[i],
kernel_init=self.init_fn(self.init_std),
use_bias=True
bias_init=self.init_fn(self.init_std),
use_bias=True,
)
)
# final layer, outputs a scalar
w_xs.append(
nn.Dense(1, kernel_init=self.init_fn(self.init_std), use_bias=True)
nn.Dense(
1,
kernel_init=self.init_fn(self.init_std),
bias_init=self.init_fn(self.init_std),
use_bias=True,
)
)
self.w_xs = w_xs
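# At initialization, each w_zs layer above has a constant kernel with entries
# 1 / normalization and no bias, so it simply averages its inputs: the z-chain
# therefore passes the value of the first-layer quadratic potential through
# essentially unchanged (the potential is non-negative, so a ReLU barely alters
# it), while the w_xs layers for i >= 1 only add small perturbations of scale
# init_std. The network thus starts out close to the quadratic potential defined
# by `factor` and `mean` (identity map or Gaussian approximation).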

@nn.compact
def __call__(self, x):
"""Apply ICNN module.
def compute_gaussian_map(self, inputs):

def compute_moments(x, reg=1e-4, sqrt_inv=False):
shape = x.shape
z = x.reshape(shape[0], -1)
mu = jnp.expand_dims(jnp.mean(z, axis=0), 0)
z = z - mu
matmul = lambda a, b: jnp.matmul(a, b)
sigma = jax.vmap(matmul)(jnp.expand_dims(z, 2), jnp.expand_dims(z, 1))
# unbiased estimate
sigma = jnp.sum(sigma, axis=0) / (shape[0] - 1)
# regularize
sigma = sigma + reg * jnp.eye(shape[1])

if sqrt_inv:
sigma_sqrt, sigma_inv_sqrt, _ = sqrtm(sigma)
return sigma, sigma_sqrt, sigma_inv_sqrt, mu
else:
return sigma, mu

source, target = inputs
_, covs_sqrt, covs_inv_sqrt, mus = compute_moments(source, sqrt_inv=True)
covt, mut = compute_moments(target, sqrt_inv=False)

Args:
x: jnp.ndarray<float>[batch_size, n_features]: input to the ICNN.
mo = sqrtm_only(jnp.dot(jnp.dot(covs_sqrt, covt), covs_sqrt))
A = jnp.dot(jnp.dot(covs_inv_sqrt, mo), covs_inv_sqrt)
b = jnp.squeeze(mus) - jnp.linalg.solve(A, jnp.squeeze(mut))
A = sqrtm_only(A)

Returns:
jnp.ndarray<float>[1]: output of ICNN.
"""
z = self.act_fn(self.w_xs[0](x))
z = jnp.multiply(z, z)
return jnp.expand_dims(A, 0), jnp.expand_dims(b, 0)
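# The map computed above is the closed-form Monge map between the Gaussian
# approximations N(mu_s, cov_s) and N(mu_t, cov_t) of the source and target:
#   T(x) = A (x - mu_s) + mu_t,
#   A = cov_s^{-1/2} (cov_s^{1/2} cov_t cov_s^{1/2})^{1/2} cov_s^{-1/2}.
# The returned factor is A^{1/2} and the bias is b = mu_s - A^{-1} mu_t; assuming
# PosDefPotentials builds the quadratic potential 0.5 * ||factor (x - b)||^2, its
# gradient is A (x - b) = T(x), i.e. the first layer is initialized to a Brenier
# potential of this Gaussian map.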

for Wz, Wx in zip(self.w_zs[:-1], self.w_xs[1:-1]):
z = self.act_fn(jnp.add(Wz(z), Wx(x)))
y = jnp.add(self.w_zs[-1](z), self.w_xs[-1](x))
def compute_identity_map(self, input_dim):
A = jnp.eye(input_dim).reshape((1, input_dim, input_dim))
b = jnp.zeros((1, input_dim))

return jnp.squeeze(y)
return A, b

@nn.compact
def __call__(self, x):
for i in range(self.num_hidden + 2):
if i == 0:
z = self.w_xs[i](x)
# apply both transforms, on the hidden state z and on x; the w_xs chain
# is one step ahead since it has one extra layer acting on x
else:
z = jnp.add(self.w_zs[i - 1](z), self.w_xs[i](x))
# no activation on the first (quadratic) layer or on the final scalar output
if i != 0 and i != self.num_hidden + 1:
z = self.act_fn(z)
return jnp.squeeze(z)
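A minimal usage sketch of the two initialization schemes added in this commit (a sketch only: it assumes ``ICNN`` is imported from ``ott.core.icnn`` as defined in this diff, that the layers act on the last axis, and uses illustrative shapes and data):

import jax
import jax.numpy as jnp

from ott.core.icnn import ICNN

rng = jax.random.PRNGKey(0)
key_source, key_target, key_init = jax.random.split(rng, 3)
# illustrative samples from a source measure and a shifted target measure
source = jax.random.normal(key_source, (128, 2))
target = jax.random.normal(key_target, (128, 2)) + 2.0

# identity initialization (gaussian_map=None)
icnn_identity = ICNN(dim_hidden=[64, 64, 64], dim_data=2)
params_identity = icnn_identity.init(key_init, jnp.ones((1, 2)))

# Gaussian-map initialization computed from source/target samples
icnn_gaussian = ICNN(
    dim_hidden=[64, 64, 64], dim_data=2, gaussian_map=(source, target)
)
params_gaussian = icnn_gaussian.init(key_init, jnp.ones((1, 2)))

# evaluate the (convex) potential on a batch of points
values = icnn_gaussian.apply(params_gaussian, source)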
