Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of ICNN Initialization Schemes #90

Merged
merged 22 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin
notebooks/gromov_wasserstein_multiomics.ipynb
notebooks/fairness.ipynb
notebooks/neural_dual.ipynb
notebooks/icnn_inits.ipynb

.. toctree::
:maxdepth: 1
Expand Down
519 changes: 519 additions & 0 deletions docs/notebooks/icnn_inits.ipynb

Large diffs are not rendered by default.

82 changes: 47 additions & 35 deletions docs/notebooks/neural_dual.ipynb

Large diffs are not rendered by default.

196 changes: 114 additions & 82 deletions ott/core/icnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -20,65 +20,21 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.nn import initializers

from ott.core.layers import PosDefPotentials, PositiveDense
from ott.geometry.matrix_square_root import sqrtm, sqrtm_only

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any # this could be a real type?
Dtype = Any
Array = Any


class PositiveDense(nn.Module):
"""A linear transformation using a weight matrix with all entries positive.

Args:
dim_hidden: the number of output dim_hidden.
beta: inverse temperature parameter of the softplus function (default: 1).
use_bias: whether to add a bias to the output (default: True).
dtype: the dtype of the computation (default: float32).
precision: numerical precision of computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
"""

dim_hidden: int
beta: float = 1.0
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.lecun_normal()
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros

@nn.compact
def __call__(self, inputs):
"""Apply a linear transformation to inputs along the last dimension.

Args:
inputs: The nd-array to be transformed.
Returns:
The transformed input.
"""
inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
'kernel', self.kernel_init, (inputs.shape[-1], self.dim_hidden)
)
scaled_kernel = self.beta * kernel
kernel = jnp.asarray(1 / self.beta * nn.softplus(scaled_kernel), self.dtype)
y = jax.lax.dot_general(
inputs,
kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision
)
if self.use_bias:
bias = self.param('bias', self.bias_init, (self.dim_hidden,))
bias = jnp.asarray(bias, self.dtype)
y = y + bias
return y


class ICNN(nn.Module):
"""Input convex neural network (ICNN) architeture.
"""Input convex neural network (ICNN) architeture with initialization.

Containing initialization schemes introduced in Bunne+(2022).

Args:
dim_hidden: sequence specifying size of hidden dimensions. The
Expand All @@ -88,66 +44,142 @@ class ICNN(nn.Module):
`jax.nn.initializers.normal`).
act_fn: choice of activation function used in network architecture
(needs to be convex, default: `nn.leaky_relu`).
pos_weights: choice to enforce positivity of weight or use regularizer.
dim_data: data dimensionality (default: 2).
gaussian_map: data inputs of source and target measures for
initialization scheme based on Gaussian approximation of input and
target measure (if None, identity initialization is used).
"""

dim_hidden: Sequence[int]
init_std: float = 0.1
init_std: float = 1e-1
init_fn: Callable = jax.nn.initializers.normal
act_fn: Callable = nn.leaky_relu
act_fn: Callable = nn.relu
pos_weights: bool = True
dim_data: int = 2
gaussian_map: Tuple[jnp.ndarray, jnp.ndarray] = None

def setup(self):
num_hidden = len(self.dim_hidden)

w_zs = []
self.num_hidden = len(self.dim_hidden)

if self.pos_weights:
Dense = PositiveDense
hid_dense = PositiveDense
# this function needs to be the inverse map of function
# used in PositiveDense layers
rescale = hid_dense.inv_rectifier_fn
else:
Dense = nn.Dense
hid_dense = nn.Dense
rescale = lambda x: x
self.use_init = False
# check if Gaussian map was provided
if self.gaussian_map is not None:
factor, mean = self.compute_gaussian_map(self.gaussian_map)
else:
factor, mean = self.compute_identity_map(self.dim_data)

for i in range(1, num_hidden):
w_zs = []
# keep track of previous size to normalize accordingly
normalization = 1
# subsequent layers propagate value of potential provided by
# first layer in x normalization factor is rescaled accordingly
for i in range(0, self.num_hidden):
w_zs.append(
Dense(
hid_dense(
self.dim_hidden[i],
kernel_init=self.init_fn(self.init_std),
use_bias=False
kernel_init=initializers.constant(rescale(1.0 / normalization)),
use_bias=False,
)
)
normalization = self.dim_hidden[i]
# final layer computes average, still with normalized rescaling
w_zs.append(
Dense(1, kernel_init=self.init_fn(self.init_std), use_bias=False)
hid_dense(
1,
kernel_init=initializers.constant(rescale(1.0 / normalization)),
use_bias=False,
)
)
self.w_zs = w_zs

w_xs = []
for i in range(num_hidden):
# first square layer, initialized to identity
w_xs.append(
PosDefPotentials(
self.dim_data,
num_potentials=1,
kernel_init=lambda *args, **kwargs: factor,
bias_init=lambda *args, **kwargs: mean,
use_bias=True,
)
)

# subsequent layers reinjected into convex functions
for i in range(self.num_hidden):
w_xs.append(
nn.Dense(
self.dim_hidden[i],
kernel_init=self.init_fn(self.init_std),
use_bias=True
bias_init=self.init_fn(self.init_std),
use_bias=True,
)
)
# final layer, to output number
w_xs.append(
nn.Dense(1, kernel_init=self.init_fn(self.init_std), use_bias=True)
nn.Dense(
1,
kernel_init=self.init_fn(self.init_std),
bias_init=self.init_fn(self.init_std),
use_bias=True,
)
)
self.w_xs = w_xs

@nn.compact
def __call__(self, x):
"""Apply ICNN module.
def compute_gaussian_map(self, inputs):

def compute_moments(x, reg=1e-4, sqrt_inv=False):
shape = x.shape
z = x.reshape(shape[0], -1)
mu = jnp.expand_dims(jnp.mean(z, axis=0), 0)
z = z - mu
matmul = lambda a, b: jnp.matmul(a, b)
sigma = jax.vmap(matmul)(jnp.expand_dims(z, 2), jnp.expand_dims(z, 1))
# unbiased estimate
sigma = jnp.sum(sigma, axis=0) / (shape[0] - 1)
# regularize
sigma = sigma + reg * jnp.eye(shape[1])

if sqrt_inv:
sigma_sqrt, sigma_inv_sqrt, _ = sqrtm(sigma)
return sigma, sigma_sqrt, sigma_inv_sqrt, mu
else:
return sigma, mu

source, target = inputs
_, covs_sqrt, covs_inv_sqrt, mus = compute_moments(source, sqrt_inv=True)
covt, mut = compute_moments(target, sqrt_inv=False)

Args:
x: jnp.ndarray<float>[batch_size, n_features]: input to the ICNN.
mo = sqrtm_only(jnp.dot(jnp.dot(covs_sqrt, covt), covs_sqrt))
A = jnp.dot(jnp.dot(covs_inv_sqrt, mo), covs_inv_sqrt)
b = jnp.squeeze(mus) - jnp.linalg.solve(A, jnp.squeeze(mut))
A = sqrtm_only(A)

Returns:
jnp.ndarray<float>[1]: output of ICNN.
"""
z = self.act_fn(self.w_xs[0](x))
z = jnp.multiply(z, z)
return jnp.expand_dims(A, 0), jnp.expand_dims(b, 0)

for Wz, Wx in zip(self.w_zs[:-1], self.w_xs[1:-1]):
z = self.act_fn(jnp.add(Wz(z), Wx(x)))
y = jnp.add(self.w_zs[-1](z), self.w_xs[-1](x))
def compute_identity_map(self, input_dim):
A = jnp.eye(input_dim).reshape((1, input_dim, input_dim))
b = jnp.zeros((1, input_dim))

return jnp.squeeze(y)
return A, b

@nn.compact
def __call__(self, x):
for i in range(self.num_hidden + 2):
if i == 0:
z = self.w_xs[i](x)
# apply both transform on hidden state and x
# x is one step ahead as there is one more hidden layer for x
else:
z = jnp.add(self.w_zs[i - 1](z), self.w_xs[i](x))
if i != 0 or i != self.num_hidden + 1:
z = self.act_fn(z)
return jnp.squeeze(z)
Loading