Add quadratic layers and enhance ICNNs, update tutorial #477

Merged: 30 commits, Dec 20, 2023

Commits:
b2b7fa6  fix a bug when bias is False (nvesseron, Nov 19, 2023)
b7674b8  update the PosDefPotentials class (nvesseron, Nov 21, 2023)
3bdfbc0  update PosDefPotentials (nvesseron, Nov 21, 2023)
4724bf0  added icnn adjustments (lucaeyring, Nov 21, 2023)
cd5c573  neuraldual fix freezee weights (lucaeyring, Nov 22, 2023)
3b6bb61  Merge branch 'ott-jax:main' into fix_bug_quad_layer (nvesseron, Nov 28, 2023)
051b222  use relu by default as activation function and rectifier_fn (nvesseron, Nov 30, 2023)
b47cb66  updates (nvesseron, Dec 1, 2023)
c86e135  solved conflicts (nvesseron, Dec 1, 2023)
73e4599  Update neural layers (michalk8, Dec 19, 2023)
deda6a2  Clean ICNN impl. (michalk8, Dec 19, 2023)
ada8983  Revert changes in the potentials (michalk8, Dec 19, 2023)
7aa580f  Fix D102 (michalk8, Dec 19, 2023)
9afdd4f  Fix indentation (michalk8, Dec 19, 2023)
a36a014  Remove `;` (michalk8, Dec 19, 2023)
6b5f73b  Use tensordot (michalk8, Dec 19, 2023)
ad84878  Update docs (michalk8, Dec 19, 2023)
1f9d886  First rounds of test fixing (michalk8, Dec 19, 2023)
95844b6  Fix rest of the tests (michalk8, Dec 19, 2023)
f971859  Revert assertion (michalk8, Dec 19, 2023)
092906c  Polish more docs (michalk8, Dec 19, 2023)
9e8fe14  Fix docs linter (michalk8, Dec 19, 2023)
e05d54d  Fix links in neuraldual notebook (michalk8, Dec 19, 2023)
b9650a9  Fix links in the rest of the neural docs (michalk8, Dec 19, 2023)
a5febbb  Update docs (michalk8, Dec 19, 2023)
c453cb9  Allow ranks to be a tuple (michalk8, Dec 19, 2023)
b702b27  Remvoe note (michalk8, Dec 19, 2023)
353d0b4  Fix MetaMLP (michalk8, Dec 20, 2023)
94df50f  Rerun neural notebooks (michalk8, Dec 20, 2023)
20bfa12  Fix rendering (michalk8, Dec 20, 2023)
83 changes: 53 additions & 30 deletions docs/tutorials/MetaOT.ipynb

Large diffs are not rendered by default.

92 changes: 47 additions & 45 deletions docs/tutorials/Monge_Gap.ipynb

Large diffs are not rendered by default.

143 changes: 110 additions & 33 deletions docs/tutorials/icnn_inits.ipynb

Large diffs are not rendered by default.

174 changes: 109 additions & 65 deletions docs/tutorials/neural_dual.ipynb

Large diffs are not rendered by default.

44 changes: 23 additions & 21 deletions docs/tutorials/tracking_progress.ipynb
@@ -125,7 +125,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Converged: True, #iters: 7, cost: 1.2429015636444092\n"
"Converged: True, #iters: 70, cost: 1.2429015636444092\n"
]
}
],
@@ -170,14 +170,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"10 / 2000 -- 0.049124784767627716\n",
"20 / 2000 -- 0.019962385296821594\n",
"30 / 2000 -- 0.00910455733537674\n",
"40 / 2000 -- 0.004339158535003662\n",
"50 / 2000 -- 0.002111591398715973\n",
"60 / 2000 -- 0.001037590205669403\n",
"70 / 2000 -- 0.0005124583840370178\n",
"Converged: True, #iters: 7, cost: 1.2429015636444092\n"
"10 / 2000 -- 0.04912472516298294\n",
"20 / 2000 -- 0.019962534308433533\n",
"30 / 2000 -- 0.009104534983634949\n",
"40 / 2000 -- 0.004339255392551422\n",
"50 / 2000 -- 0.0021116361021995544\n",
"60 / 2000 -- 0.001037605106830597\n",
"70 / 2000 -- 0.0005124807357788086\n",
"Converged: True, #iters: 70, cost: 1.2429015636444092\n"
]
}
],
@@ -219,7 +219,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|██████▌ | 7/200 [00:00<00:08, 23.28it/s, error: 5.124584e-04]\n"
" 4%| | 7/200 [00:00<00:22, 8.57it/s, error: 5.124807e-04]\n"
]
}
],
@@ -240,7 +240,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Converged: True, #iters: 7, cost: 1.2429015636444092\n"
"Converged: True, #iters: 70, cost: 1.2429015636444092\n"
]
}
],
@@ -270,7 +270,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|██████▌ | 7/200 [00:00<00:08, 23.53it/s, error: 5.124584e-04]\n"
" 4%| | 7/200 [00:00<00:23, 8.27it/s, error: 5.124807e-04]\n"
]
}
],
@@ -293,7 +293,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Converged: True, #iters: 7, cost: 1.2429015636444092\n"
"Converged: True, #iters: 70, cost: 1.2429015636444092\n"
]
}
],
@@ -323,7 +323,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|██████████████▉ | 16/200 [00:00<00:07, 23.11it/s, error: 3.191899e-04]\n"
" 8%| | 16/200 [00:02<00:25, 7.10it/s, error: 3.223309e-04]\n"
]
}
],
@@ -347,7 +347,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Converged: True, cost: 1.7340879440307617\n"
"Converged: True, cost: 1.7340872287750244\n"
]
}
],
@@ -373,6 +373,8 @@
"outputs": [],
"source": [
"# Samples spiral\n",
"\n",
"\n",
"def sample_spiral(\n",
" n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0\n",
"):\n",
@@ -463,14 +465,14 @@
"output_type": "stream",
"text": [
"1 / 20 -- -1.0\n",
"2 / 20 -- 0.1304362416267395\n",
"3 / 20 -- 0.0898154005408287\n",
"4 / 20 -- 0.06759566068649292\n",
"5 / 20 -- 0.05465700849890709\n",
"2 / 20 -- 0.13043621182441711\n",
"3 / 20 -- 0.08981533348560333\n",
"4 / 20 -- 0.06759564578533173\n",
"5 / 20 -- 0.0546572208404541\n",
"\n",
"5 outer iterations were needed\n",
"The outer loop of Gromov Wasserstein has converged: True\n",
"The final regularized GW cost is: 1183.613\n"
"The final regularized GW cost is: 1183.617\n"
]
}
],
@@ -520,7 +522,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.10"
}
},
"nbformat": 4,
224 changes: 149 additions & 75 deletions src/ott/neural/layers.py
@@ -22,118 +22,192 @@
PRNGKey = jax.Array
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
Array = jnp.ndarray

# wrap to silence docs linter
DEFAULT_KERNEL_INIT = lambda *a, **k: nn.initializers.lecun_normal()(*a, **k)
DEFAULT_BIAS_INIT = nn.initializers.zeros
DEFAULT_RECTIFIER = nn.activation.relu


class PositiveDense(nn.Module):
"""A linear transformation using a weight matrix with all entries positive.
"""A linear transformation using a matrix with all entries non-negative.

Args:
dim_hidden: the number of output dim_hidden.
rectifier_fn: choice of rectifier function (default: softplus function).
inv_rectifier_fn: choice of inverse rectifier function
(default: inverse softplus function).
dtype: the dtype of the computation (default: float32).
precision: numerical precision of computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
dim_hidden: Number of output dimensions.
rectifier_fn: Rectifier function. The default is
:func:`~flax.linen.activation.relu`.
use_bias: Whether to add bias to the output.
kernel_init: Initializer for the matrix. The default is
:func:`~flax.linen.initializers.lecun_normal`.
bias_init: Initializer for the bias. The default is
:func:`~flax.linen.initializers.zeros`.
precision: Numerical precision of the computation.
"""

dim_hidden: int
rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus
inv_rectifier_fn: Callable[[jnp.ndarray],
jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1)
rectifier_fn: Callable[[Array], Array] = DEFAULT_RECTIFIER
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None,
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = DEFAULT_KERNEL_INIT
Review comment (Contributor): I think it is counterproductive to use a default initializer for this layer that has symmetric values: half of the entries will be below 0, and their gradients will likely vanish quite quickly, see e.g. https://openreview.net/pdf?id=pWZ97hUQtQ. I am not sure what we should use instead; initializing by default with the absolute value of the default initializer seems more appropriate. Another legitimate option would be to normalize the kernel matrix so that its rows sum to 1.

Reply (Collaborator): How about, for simplicity, a truncated normal with low=0.0?

bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = DEFAULT_BIAS_INIT
precision: Optional[jax.lax.Precision] = None

@nn.compact
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"""Applies a linear transformation to inputs along the last dimension.
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies a linear transformation to x along the last dimension.

Args:
inputs: Array to be transformed.
x: Array of shape ``[batch, ..., features]``.

Returns:
The transformed input.
Array of shape ``[batch, ..., dim_hidden]``.
"""
kernel_init = nn.initializers.lecun_normal(
) if self.kernel_init is None else self.kernel_init
# TODO(michalk8): update when refactoring neuraldual
# assert x.ndim > 1, x.ndim

inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
"kernel", kernel_init, (inputs.shape[-1], self.dim_hidden)
"kernel", self.kernel_init, (x.shape[-1], self.dim_hidden)
)
kernel = self.rectifier_fn(kernel)
kernel = jnp.asarray(kernel, self.dtype)
y = jax.lax.dot_general(
inputs,
kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision
)

x = jnp.tensordot(x, kernel, axes=(-1, 0), precision=self.precision)
if self.use_bias:
bias = self.param("bias", self.bias_init, (self.dim_hidden,))
bias = jnp.asarray(bias, self.dtype)
return y + bias
return y
x = x + self.param("bias", self.bias_init, (self.dim_hidden,))

return x
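
# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of this PR. It assumes the module path
# shown in this diff (src/ott/neural/layers.py -> ott.neural.layers) and uses
# the layer as a standard flax module; shapes are arbitrary.
import jax
import jax.numpy as jnp
from flax import linen as nn

from ott.neural import layers

layer = layers.PositiveDense(dim_hidden=16)
x = jnp.ones((8, 4))  # [batch, features]
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # kernel is rectified, so effective weights are >= 0
assert y.shape == (8, 16)


# One way to follow the review suggestion above: pass a non-negative
# initializer explicitly, e.g. the absolute value of the default one.
# `nonneg_lecun_normal` is a hypothetical helper, not part of the PR.
def nonneg_lecun_normal(key, shape, dtype=jnp.float32):
  return jnp.abs(nn.initializers.lecun_normal()(key, shape, dtype))


layer = layers.PositiveDense(dim_hidden=16, kernel_init=nonneg_lecun_normal)
# ---------------------------------------------------------------------------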


class PosDefPotentials(nn.Module):
r"""A layer to output :math:`\frac{1}{2} ||A_i^T (x - b_i)||^2_i` potentials.
r""":math:`\frac{1}{2} x^T (A_i A_i^T + \text{Diag}(d_i)) x + b_i^T x^2 + c_i` potentials.

This class implements a layer that maps (batched) ``d``-dimensional vectors
``x`` to a ``num_potentials``-dimensional output. Each entry of that output
is a positive definite quadratic form evaluated at ``x``; each quadratic
term is parameterized as a low-rank plus diagonal matrix. The low-rank term
is :math:`A_i A_i^T`, where each factor :math:`A_i` has shape ``(d, rank)``;
stacked together, these factors form a tensor of shape
``(num_potentials, d, rank)``. The diagonal terms :math:`d_i` form a
``(num_potentials, d)`` matrix of non-negative values; the linear terms
:math:`b_i` form a ``(num_potentials, d)`` matrix. Finally, the offsets
:math:`c_i` are stored in a vector of size ``(num_potentials,)``.

Args:
use_bias: whether to add a bias to the output.
dtype: the dtype of the computation.
precision: numerical precision of computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
"""
dim_data: int
num_potentials: Dimension of the output.
rank: Rank of the matrices :math:`A_i` used as low-rank factors
for the quadratic potentials.
rectifier_fn: Rectifier function to ensure non-negativity of the diagonals
:math:`d_i`. The default is :func:`~flax.linen.activation.relu`.
use_linear: Whether to add linear terms :math:`b_i` to the outputs.
use_bias: Whether to add biases :math:`c_i` to the outputs.
kernel_lr_init: Initializer for the matrices :math:`A_i`
of the quadratic potentials when ``rank > 0``.
The default is :func:`~flax.linen.initializers.lecun_normal`.
kernel_diag_init: Initializer for the diagonals :math:`d_i`.
The default is :func:`~flax.linen.initializers.ones`.
kernel_linear_init: Initializer for the linear layers :math:`b_i`.
The default is :func:`~flax.linen.initializers.lecun_normal`.
bias_init: Initializer for the bias. The default is
:func:`~flax.linen.initializers.zeros`.
precision: Numerical precision of the computation.
""" # noqa: E501

num_potentials: int
rank: int = 0
rectifier_fn: Callable[[Array], Array] = DEFAULT_RECTIFIER
use_linear: bool = True
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros
kernel_lr_init: Callable[[PRNGKey, Shape, Dtype], Array] = DEFAULT_KERNEL_INIT
kernel_diag_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.ones
kernel_linear_init: Callable[[PRNGKey, Shape, Dtype],
Array] = DEFAULT_KERNEL_INIT
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = DEFAULT_BIAS_INIT
precision: Optional[jax.lax.Precision] = None

@nn.compact
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"""Apply a few quadratic forms.
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Compute quadratic forms of the input.

Args:
inputs: Array to be transformed (possibly batched).
x: Array of shape ``[batch, ..., features]``.

Returns:
The transformed input.
Array of shape ``[batch, ..., num_potentials]``.
"""
kernel_init = nn.initializers.lecun_normal(
) if self.kernel_init is None else self.kernel_init
inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
"kernel", kernel_init,
(self.num_potentials, inputs.shape[-1], inputs.shape[-1])
# TODO(michalk8): update when refactoring neuraldual
# assert x.ndim > 1, x.ndim

dim_data = x.shape[-1]
x = x.reshape((-1, dim_data))

diag_kernel = self.param(
"diag_kernel", self.kernel_diag_init, (dim_data, self.num_potentials)
)
# ensures the diag_kernel parameter stays non-negative
diag_kernel = self.rectifier_fn(diag_kernel)

if self.use_bias:
bias = self.param(
"bias", self.bias_init, (self.num_potentials, self.dim_data)
)
bias = jnp.asarray(bias, self.dtype)
# (batch, dim_data, 1), (1, dim_data, num_potentials)
y = 0.5 * jnp.sum(((x ** 2)[..., None] * diag_kernel[None]), axis=1)

y = inputs.reshape((-1, inputs.shape[-1])) if inputs.ndim == 1 else inputs
y = y[..., None] - bias.T[None, ...]
y = jax.lax.dot_general(
y, kernel, (((1,), (1,)), ((2,), (0,))), precision=self.precision
if self.rank > 0:
quad_kernel = self.param(
"quad_kernel", self.kernel_lr_init,
(self.num_potentials, dim_data, self.rank)
)
else:
y = jax.lax.dot_general(
inputs,
kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision
# (batch, num_potentials, rank)
quad = 0.5 * jnp.tensordot(
x, quad_kernel, axes=(-1, 1), precision=self.precision
) ** 2
y = y + jnp.sum(quad, axis=-1)

if self.use_linear:
linear_kernel = self.param(
"lin_kernel", self.kernel_linear_init,
(dim_data, self.num_potentials)
)
y = y + jnp.dot(x, linear_kernel, precision=self.precision)

if self.use_bias:
y = y + self.param("bias", self.bias_init, (self.num_potentials,))

return y

@classmethod
def init_from_samples(
cls, source: jnp.ndarray, target: jnp.ndarray, **kwargs: Any
) -> "PosDefPotentials":
"""Initialize the layer using Gaussian approximation :cite:`bunne:22`.

Args:
source: Samples from the source distribution, array of shape ``[n, d]``.
target: Samples from the target distribution, array of shape ``[m, d]``.
kwargs: Keyword arguments for initialization. Note that ``use_linear``
will always be set to :obj:`True`.

Returns:
The layer with fixed linear and quadratic initialization.
"""
factor, mean = _compute_gaussian_map_params(source, target)

kwargs["use_linear"] = True
return cls(
kernel_lr_init=lambda *_, **__: factor,
kernel_linear_init=lambda *_, **__: mean.T,
**kwargs,
)
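
# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of this PR; hyper-parameters are
# arbitrary. It exercises the shapes and the quadratic-form formula
# documented above.
import jax
import jax.numpy as jnp
from flax import linen as nn

from ott.neural import layers

potentials = layers.PosDefPotentials(num_potentials=3, rank=2)
x = jnp.ones((8, 5))  # [batch, d]
params = potentials.init(jax.random.PRNGKey(0), x)
y = potentials.apply(params, x)
assert y.shape == (8, 3)  # [batch, num_potentials]

# Cross-check the first potential on the first input against
# 0.5 * x^T (A A^T + Diag(d)) x + b^T x + c, using the parameter names above.
p = params["params"]
A = p["quad_kernel"][0]  # (d, rank)
d_diag = nn.relu(p["diag_kernel"][:, 0])  # (d,), rectified as in __call__
b, c = p["lin_kernel"][:, 0], p["bias"][0]
x0 = x[0]
manual = 0.5 * x0 @ (A @ A.T + jnp.diag(d_diag)) @ x0 + b @ x0 + c
assert jnp.allclose(y[0, 0], manual, rtol=1e-4, atol=1e-4)
# ---------------------------------------------------------------------------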


def _compute_gaussian_map_params(
source: jnp.ndarray, target: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
from ott.math import matrix_square_root
from ott.tools.gaussian_mixture import gaussian

g_s = gaussian.Gaussian.from_samples(source)
g_t = gaussian.Gaussian.from_samples(target)
lin_op = g_s.scale.gaussian_map(g_t.scale)
b = jnp.squeeze(g_t.loc) - lin_op @ jnp.squeeze(g_s.loc)
lin_op = matrix_square_root.sqrtm_only(lin_op)

y = 0.5 * y * y
return jnp.sum(y.reshape((-1, self.num_potentials, self.dim_data)), axis=2)
return jnp.expand_dims(lin_op, 0), jnp.expand_dims(b, 0)
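
# ---------------------------------------------------------------------------
# Illustrative sketch of `init_from_samples`, not part of this PR; the data
# is synthetic. Based on the parameter shapes above, the fixed initializers
# require num_potentials=1 and rank=d (an assumption drawn from the code,
# not from the PR's documentation).
import jax
import jax.numpy as jnp

from ott.neural import layers

key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
source = jax.random.normal(key_s, (256, 5))
target = 2.0 * jax.random.normal(key_t, (256, 5)) + 1.0

pot = layers.PosDefPotentials.init_from_samples(
    source, target, num_potentials=1, rank=source.shape[-1]
)
params = pot.init(jax.random.PRNGKey(1), source)
values = pot.apply(params, source)  # shape [256, 1]
# ---------------------------------------------------------------------------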