Add quadratic layers and enhance ICNNs, update tutorial #477

Merged: 30 commits, Dec 20, 2023
Changes from 9 commits

Commits (30)
b2b7fa6
fix a bug when bias is False
nvesseron Nov 19, 2023
b7674b8
update the PosDefPotentials class
nvesseron Nov 21, 2023
3bdfbc0
update PosDefPotentials
nvesseron Nov 21, 2023
4724bf0
added icnn adjustments
lucaeyring Nov 21, 2023
cd5c573
neuraldual fix freezee weights
lucaeyring Nov 22, 2023
3b6bb61
Merge branch 'ott-jax:main' into fix_bug_quad_layer
nvesseron Nov 28, 2023
051b222
use relu by default as activation function and rectifier_fn
nvesseron Nov 30, 2023
b47cb66
updates
nvesseron Dec 1, 2023
c86e135
solved conflicts
nvesseron Dec 1, 2023
73e4599
Update neural layers
michalk8 Dec 19, 2023
deda6a2
Clean ICNN impl.
michalk8 Dec 19, 2023
ada8983
Revert changes in the potentials
michalk8 Dec 19, 2023
7aa580f
Fix D102
michalk8 Dec 19, 2023
9afdd4f
Fix indentation
michalk8 Dec 19, 2023
a36a014
Remove `;`
michalk8 Dec 19, 2023
6b5f73b
Use tensordot
michalk8 Dec 19, 2023
ad84878
Update docs
michalk8 Dec 19, 2023
1f9d886
First rounds of test fixing
michalk8 Dec 19, 2023
95844b6
Fix rest of the tests
michalk8 Dec 19, 2023
f971859
Revert assertion
michalk8 Dec 19, 2023
092906c
Polish more docs
michalk8 Dec 19, 2023
9e8fe14
Fix docs linter
michalk8 Dec 19, 2023
e05d54d
Fix links in neuraldual notebook
michalk8 Dec 19, 2023
b9650a9
Fix links in the rest of the neural docs
michalk8 Dec 19, 2023
a5febbb
Update docs
michalk8 Dec 19, 2023
c453cb9
Allow ranks to be a tuple
michalk8 Dec 19, 2023
b702b27
Remvoe note
michalk8 Dec 19, 2023
353d0b4
Fix MetaMLP
michalk8 Dec 20, 2023
94df50f
Rerun neural notebooks
michalk8 Dec 20, 2023
20bfa12
Fix rendering
michalk8 Dec 20, 2023
73 changes: 37 additions & 36 deletions docs/tutorials/neural_dual.ipynb

Large diffs are not rendered by default.

118 changes: 69 additions & 49 deletions src/ott/neural/layers.py
@@ -29,20 +29,17 @@ class PositiveDense(nn.Module):
"""A linear transformation using a weight matrix with all entries positive.

Args:
dim_hidden: the number of output dim_hidden.
rectifier_fn: choice of rectifier function (default: softplus function).
inv_rectifier_fn: choice of inverse rectifier function
(default: inverse softplus function).
dtype: the dtype of the computation (default: float32).
precision: numerical precision of computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
dim_hidden: the number of output dim_hidden.
rectifier_fn: choice of rectifier function (default: softplus function).
dtype: the dtype of the computation (default: float32).
precision: numerical precision of computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
"""

dim_hidden: int
rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus
inv_rectifier_fn: Callable[[jnp.ndarray],
jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1)
rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
@@ -54,10 +51,10 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"""Applies a linear transformation to inputs along the last dimension.

Args:
inputs: Array to be transformed.
inputs: Array to be transformed.

Returns:
The transformed input.
The transformed input.
"""
kernel_init = nn.initializers.lecun_normal(
) if self.kernel_init is None else self.kernel_init
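The change above removes the inverse-rectifier parametrization and makes ReLU the default rectifier_fn. For reference, below is a minimal flax sketch of a dense layer that keeps its effective weights nonnegative by rectifying the kernel at call time; the class name, initializers, and forward pass are illustrative assumptions, not the library's actual implementation.

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable

class PositiveDenseSketch(nn.Module):
  """Dense layer with nonnegative effective weights (illustrative sketch)."""
  dim_hidden: int
  rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
  use_bias: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    kernel = self.param(
        "kernel", nn.initializers.lecun_normal(),
        (x.shape[-1], self.dim_hidden)
    )
    # Rectify the kernel so every weight actually applied is >= 0.
    y = x @ self.rectifier_fn(kernel)
    if self.use_bias:
      bias = self.param("bias", nn.initializers.zeros, (self.dim_hidden,))
      y = y + bias
    return y

# Nonnegative weights on the hidden path are what let ICNNs stay convex in
# their input when combined with convex, nondecreasing activations.
layer = PositiveDenseSketch(dim_hidden=8)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((4, 3)))
out = layer.apply(params, jnp.ones((4, 3)))  # shape (4, 8)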
@@ -81,59 +78,82 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:


class PosDefPotentials(nn.Module):
r"""A layer to output :math:`\frac{1}{2} ||A_i^T (x - b_i)||^2_i` potentials.
"""A layer to output 0.5 x^T(A_i A_i^T + Diag(d_i^2))x + b_i^T x + c_i potentials.

Args:
use_bias: whether to add a bias to the output.
dtype: the dtype of the computation.
precision: numerical precision of computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
"""
dim_data: int
num_potentials: the dimension of the output
rank: the rank of the matrix used for the quadratic layer
use_linear: whether to add a linear layer to the output
use_bias: whether to add a bias to the output.
dtype: the dtype of the computation.
precision: numerical precision of computation see `jax.lax.Precision` for details.
kernel_quadratic_init: initializer function for the weight matrix of the quadratic layer.
kernel_linear_init: initializer function for the weight matrix of the linear layer.
bias_init: initializer function for the bias.
""" # noqa: E501

num_potentials: int
rank: int = 0
use_linear: bool = True
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros
kernel_quadratic_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.lecun_normal()
kernel_diagonal_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.ones
kernel_linear_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.lecun_normal()
bias_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.lecun_normal()

@nn.compact
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"""Apply a few quadratic forms.

Args:
inputs: Array to be transformed (possibly batched).
inputs: Array to be transformed (possibly batched).

Returns:
The transformed input.
The transformed input.
"""
kernel_init = nn.initializers.lecun_normal(
) if self.kernel_init is None else self.kernel_init
dim_data = inputs.shape[-1]
inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
"kernel", kernel_init,
(self.num_potentials, inputs.shape[-1], inputs.shape[-1])
)
inputs = inputs.reshape((-1, dim_data))

if self.use_bias:
bias = self.param(
"bias", self.bias_init, (self.num_potentials, self.dim_data)
diag_kernel = self.param(
"diag_kernel", self.kernel_diagonal_init,
(1, dim_data, self.num_potentials)
)
# ensures the diag_kernel parameter stays non-negative
diag_kernel = nn.activation.relu(diag_kernel)
y = 0.5 * jnp.sum(jnp.multiply(inputs[..., None], diag_kernel) ** 2, axis=1)

if self.rank > 0:
quadratic_kernel = self.param(
"quad_kernel", self.kernel_quadratic_init,
(self.num_potentials, dim_data, self.rank)
)
bias = jnp.asarray(bias, self.dtype)

y = inputs.reshape((-1, inputs.shape[-1])) if inputs.ndim == 1 else inputs
y = y[..., None] - bias.T[None, ...]
y = jax.lax.dot_general(
y, kernel, (((1,), (1,)), ((2,), (0,))), precision=self.precision
y += jnp.sum(
0.5 * jnp.tensordot(
inputs,
quadratic_kernel,
axes=((inputs.ndim - 1,), (1,)),
precision=self.precision
) ** 2,
axis=2,
)
else:
y = jax.lax.dot_general(
inputs,
kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision

if self.use_linear:
linear_kernel = self.param(
"lin_kernel", self.kernel_linear_init,
(dim_data, self.num_potentials)
)
y = y + jnp.dot(inputs, linear_kernel, precision=self.precision)

y = 0.5 * y * y
return jnp.sum(y.reshape((-1, self.num_potentials, self.dim_data)), axis=2)
if self.use_bias:
bias = self.param("bias", self.bias_init, (1, self.num_potentials))
bias = jnp.asarray(bias, self.dtype)
y = y + bias

return y
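Put together, the reworked PosDefPotentials evaluates, for each potential i, 0.5 x^T (A_i A_i^T + Diag(d_i^2)) x + b_i^T x + c_i: a rank-constrained factor A_i plus a squared diagonal keep the quadratic part positive semi-definite by construction. Below is a minimal pure-JAX sketch of that computation, checked against the closed form; the parameter names (d, A, b, c) and shapes are illustrative, and the ReLU the layer applies to its diagonal factor is omitted.

import jax
import jax.numpy as jnp

def pos_def_potentials(x, d, A, b, c):
  """x: (batch, dim), d: (dim, p), A: (p, dim, rank), b: (dim, p), c: (p,)."""
  # 0.5 * x^T Diag(d_i^2) x for every potential i.
  diag_term = 0.5 * jnp.sum((x[..., None] * d[None, ...]) ** 2, axis=1)
  # A_i^T x for every potential i, shape (batch, p, rank).
  low_rank = jnp.tensordot(x, A, axes=((1,), (1,)))
  # 0.5 * ||A_i^T x||^2 = 0.5 * x^T A_i A_i^T x.
  quad_term = 0.5 * jnp.sum(low_rank ** 2, axis=2)
  linear_term = x @ b
  return diag_term + quad_term + linear_term + c[None, :]

# Sanity check against the closed form for potential i = 0.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
batch, dim, p, rank = 4, 3, 2, 2
x = jax.random.normal(keys[0], (batch, dim))
d = jax.random.normal(keys[1], (dim, p))
A = jax.random.normal(keys[2], (p, dim, rank))
b = jax.random.normal(keys[3], (dim, p))
c = jax.random.normal(keys[4], (p,))

M0 = A[0] @ A[0].T + jnp.diag(d[:, 0] ** 2)  # positive semi-definite
ref = 0.5 * jnp.einsum("bi,ij,bj->b", x, M0, x) + x @ b[:, 0] + c[0]
assert jnp.allclose(pos_def_potentials(x, d, A, b, c)[:, 0], ref, atol=1e-4)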