Update neural docs
michalk8 committed Nov 28, 2023
1 parent 1abf43a commit ece4f19
Showing 18 changed files with 103 additions and 59 deletions.
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -90,6 +90,8 @@ Packages
When the problem is *not* convex, which is the case for most other uses of
this toolbox, the initialization can play a decisive role to reach a useful
solution.
- :mod:`ott.neural` .. TODO(marcocuturi): add some nice text here please.

- :mod:`ott.tools` provides an interface to exploit OT solutions, as produced by
solvers from the :mod:`ott.solvers` module. Such tasks include computing
approximations to Wasserstein distances :cite:`genevay:18,sejourne:19`,
32 changes: 32 additions & 0 deletions docs/neural/index.rst
@@ -2,3 +2,35 @@ ott.neural
==========
.. module:: ott.neural
.. currentmodule:: ott.neural

Under reconstruction. .. TODO(marcocuturi): add some nice text here please.

.. toctree::
:maxdepth: 2

solvers

Models
------
.. autosummary::
:toctree: _autosummary

models.ICNN
models.MLP
models.MetaInitializer

Losses
------
.. autosummary::
:toctree: _autosummary

losses.monge_gap
losses.monge_gap_from_samples

Layers
------
.. autosummary::
:toctree: _autosummary

layers.PositiveDense
layers.PosDefPotentials
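
With this layout, the former ott.neural.models.models, ott.neural.models.layers and ott.neural.solvers.losses paths collapse into ott.neural.models, ott.neural.layers and ott.neural.losses (see the import changes further down in this commit). A minimal import sketch under the new layout; only the class names come from the autosummary entries above, the constructor arguments are illustrative assumptions:

from ott.neural import layers, losses, models

# Hypothetical instantiations; dimensions are placeholders and the exact
# constructor signatures are assumptions, not taken from this commit.
icnn = models.ICNN(dim_data=2, dim_hidden=[64, 64])
mlp = models.MLP(dim_hidden=[64, 64, 64])
dense = layers.PositiveDense(dim_hidden=32)
gap_fn = losses.monge_gap_from_samples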
24 changes: 24 additions & 0 deletions docs/neural/solvers.rst
@@ -0,0 +1,24 @@
ott.neural.solvers
==================
.. module:: ott.neural.solvers
.. currentmodule:: ott.neural.solvers

Under reconstruction. .. TODO(marcocuturi): add some nice text here please.

Solvers
-------
.. autosummary::
:toctree: _autosummary

map_estimator.MapEstimator
neuraldual.W2NeuralDual
neuraldual.BaseW2NeuralDual

Conjugate Solvers
-----------------
.. autosummary::
:toctree: _autosummary

conjugate.FenchelConjugateLBFGS
conjugate.FenchelConjugateSolver
conjugate.ConjugateResults
2 changes: 1 addition & 1 deletion src/ott/datasets.py
@@ -112,7 +112,7 @@ def create_gaussian_mixture_samplers(
valid_batch_size: int = 2048,
rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
"""Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`.
"""Gaussian samplers.
Args:
name_source: name of the source sampler
2 changes: 1 addition & 1 deletion src/ott/neural/__init__.py
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import models, solvers
from . import layers, losses, models, solvers
19 changes: 11 additions & 8 deletions src/ott/neural/models/layers.py → src/ott/neural/layers.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Tuple
from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp
@@ -46,8 +46,7 @@ class PositiveDense(nn.Module):
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.lecun_normal()
kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None,
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros

@nn.compact
@@ -60,9 +59,12 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
Returns:
The transformed input.
"""
kernel_init = nn.initializers.lecun_normal(
) if self.kernel_init is None else self.kernel_init

inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
"kernel", self.kernel_init, (inputs.shape[-1], self.dim_hidden)
"kernel", kernel_init, (inputs.shape[-1], self.dim_hidden)
)
kernel = self.rectifier_fn(kernel)
kernel = jnp.asarray(kernel, self.dtype)
@@ -79,7 +81,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:


class PosDefPotentials(nn.Module):
"""A layer to output (0.5 || A_i^T (x - b_i)||^2)_i potentials.
r"""A layer to output :math:`\frac{1}{2} ||A_i^T (x - b_i)||^2_i` potentials.
Args:
use_bias: whether to add a bias to the output.
@@ -94,8 +96,7 @@ class PosDefPotentials(nn.Module):
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.lecun_normal()
kernel_init: Optional[Callable[[PRNGKey, Shape, Dtype], Array]] = None
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros

@nn.compact
@@ -108,9 +109,11 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
Returns:
The transformed input.
"""
kernel_init = nn.initializers.lecun_normal(
) if self.kernel_init is None else self.kernel_init
inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
"kernel", self.kernel_init,
"kernel", kernel_init,
(self.num_potentials, inputs.shape[-1], inputs.shape[-1])
)

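
The change above makes kernel_init an optional field that defaults to None and is resolved to nn.initializers.lecun_normal() inside __call__, rather than calling the initializer factory in the dataclass field default. A minimal sketch of the same pattern in a standalone Flax module; the module below is illustrative and not part of ott:

from typing import Callable, Optional

import jax
import jax.numpy as jnp
import flax.linen as nn


class NonNegativeDense(nn.Module):
  """Toy dense layer whose kernel is kept non-negative via softplus."""
  dim_hidden: int
  # Defaulting to None and resolving inside __call__ mirrors the change above.
  kernel_init: Optional[Callable[..., jnp.ndarray]] = None

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    kernel_init = (
        nn.initializers.lecun_normal()
        if self.kernel_init is None else self.kernel_init
    )
    kernel = self.param("kernel", kernel_init, (x.shape[-1], self.dim_hidden))
    return x @ jax.nn.softplus(kernel)


# Usage sketch.
layer = NonNegativeDense(dim_hidden=8)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
out = layer.apply(params, jnp.ones((3, 4)))  # shape (3, 8)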
11 changes: 5 additions & 6 deletions src/ott/neural/solvers/losses.py → src/ott/neural/losses.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Literal, Optional, Tuple, Union

import jax
@@ -48,7 +47,7 @@ def monge_gap(
W_{c, \varepsilon}(\hat{\rho}_n, T \sharp \hat{\rho}_n)
See :cite:`uscidda:23` Eq. (8). This function is a thin wrapper that calls
:func:`~ott.solvers.linear.nn.lossses.monge_gap_from_samples`.
:func:`~ott.neural.losses.monge_gap_from_samples`.
Args:
map_fn: Callable corresponding to map :math:`T` in definition above. The
@@ -69,7 +68,7 @@ def monge_gap(
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`.
kwargs: holds the kwargs to instantiate the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to
compute the regularized OT cost.
@@ -86,6 +85,7 @@
epsilon=epsilon,
relative_epsilon=relative_epsilon,
scale_cost=scale_cost,
return_output=return_output,
**kwargs
)

@@ -104,13 +104,12 @@ def monge_gap_from_samples(
r"""Monge gap, instantiated in terms of samples before / after applying map.
.. math::
\frac{1}{n} \sum_{i=1}^n c(x_i, y_i) -
W_{c, \varepsilon}(\frac{1}{n}\sum_i \delta_{x_i},
\frac{1}{n}\sum_i \delta_{y_i})
where :math:`W_{c, \varepsilon}` is an entropy-regularized optimal transport
cost, :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`
cost, the :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`.
Args:
source: samples from first measure, array of shape ``[n, d]``.
@@ -129,7 +128,7 @@ def monge_gap_from_samples(
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`.
kwargs: holds the kwargs to instantiate the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to
compute the regularized OT cost.
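
Beyond the path updates, the hunks above also forward return_output from monge_gap to monge_gap_from_samples. A short usage sketch of the relocated loss; the sample sizes, epsilon and the affine map are illustrative choices, only the argument names source, target and epsilon come from the docstring:

import jax
import jax.numpy as jnp
from ott.neural import losses

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (128, 2))  # samples before applying the map T
t_x = 2.0 * x + 1.0                   # samples after applying T(x) = 2x + 1

# Non-negative quantity measuring how far T is from being c-optimal between
# the empirical measure on x and its image (Eq. (8) in uscidda:23).
gap = losses.monge_gap_from_samples(source=x, target=t_x, epsilon=1e-2)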
4 changes: 2 additions & 2 deletions src/ott/neural/models/models.py → src/ott/neural/models.py
@@ -26,7 +26,7 @@
from ott.geometry import geometry
from ott.initializers.linear import initializers as lin_init
from ott.math import matrix_square_root
from ott.neural.models import layers
from ott.neural import layers
from ott.neural.solvers import neuraldual
from ott.problems.linear import linear_problem

@@ -177,7 +177,7 @@ def __call__(self, x: jnp.ndarray) -> float:  # noqa: D102


class MLP(neuraldual.BaseW2NeuralDual):
"""A generic, typically not-convex (w.r.t input) MLP.
"""A generic, not-convex MLP.
Args:
dim_hidden: sequence specifying size of hidden dimensions. The output
14 changes: 0 additions & 14 deletions src/ott/neural/models/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/ott/neural/solvers/__init__.py
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import losses, map_estimator, neuraldual
from . import conjugate, map_estimator, neuraldual
src/ott/neural/models/conjugate_solvers.py → src/ott/neural/solvers/conjugate.py
@@ -68,7 +68,7 @@ def solve(

@utils.register_pytree_node
class FenchelConjugateLBFGS(FenchelConjugateSolver):
"""Solve for the conjugate using :class:`jaxopt.LBFGS`.
"""Solve for the conjugate using :class:`~jaxopt.LBFGS`.
Args:
gtol: gradient tolerance
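
Given a potential f and a target point y, the conjugate solver approximates the Fenchel conjugate f*(y) = sup_x <x, y> - f(x) and its maximizer. A sketch, assuming solve(f, y, x_init=...) returns a ConjugateResults with a grad field; only the class names are confirmed by this diff, and jaxopt must be installed:

import jax.numpy as jnp
from ott.neural.solvers import conjugate

# Quadratic potential f(x) = 0.5 * ||x||^2; the conjugate's maximizer at y is y.
f = lambda x: 0.5 * jnp.sum(x ** 2)
y = jnp.array([1.0, -2.0])

solver = conjugate.FenchelConjugateLBFGS()  # wraps jaxopt.LBFGS
res = solver.solve(f, y, x_init=jnp.zeros_like(y))
# res.grad is the approximate argmax; here it should be close to y.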
7 changes: 3 additions & 4 deletions src/ott/neural/solvers/map_estimator.py
@@ -54,7 +54,7 @@ class MapEstimator:
For instance, :math:`\Delta` can be the
:func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`
and :math:`R` the :func:`~ott.solvers.nn.losses.monge_gap_from_samples`
and :math:`R` the :func:`~ott.neural.losses.monge_gap_from_samples`
:cite:`uscidda:23` for a given cost function :math:`c`.
In that case, it estimates a :math:`c`-OT map, i.e. a map :math:`T`
optimal for the Monge problem induced by :math:`c`.
@@ -129,7 +129,7 @@ def setup(
def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
"""Regularizer added to the fitting loss.
Can be e.g. the :func:`~ott.solvers.nn.losses.monge_gap_from_samples`.
Can be, e.g. the :func:`~ott.neural.losses.monge_gap_from_samples`.
If no regularizer is passed for solver instantiation,
or regularization weight :attr:`regularizer_strength` is 0,
return 0 by default along with an empty set of log values.
@@ -142,8 +142,7 @@ def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
def fitting_loss(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
"""Fitting loss to fit the marginal constraint.
Can be for instance the
:func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`.
Can be, e.g. :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`.
If no fitting_loss is passed for solver instantiation, return 0 by default,
and no log values.
"""
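
The MapEstimator docstring above combines a fitting loss Delta with a weighted regularizer R, e.g. a Sinkhorn divergence and the Monge gap. A sketch of two callables that could fill those roles; the (value, logs) return convention is an assumption inferred from the regularizer/fitting_loss properties above:

import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.neural import losses
from ott.tools import sinkhorn_divergence


def fitting_loss(samples: jnp.ndarray, mapped_samples: jnp.ndarray):
  # Sinkhorn divergence between target samples and mapped source samples.
  out = sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud, samples, mapped_samples, epsilon=1e-2
  )
  return out.divergence, None  # assumed (value, logs) convention


def regularizer(x: jnp.ndarray, t_x: jnp.ndarray):
  # Monge gap of the candidate map, evaluated from samples (uscidda:23).
  gap = losses.monge_gap_from_samples(source=x, target=t_x, epsilon=1e-2)
  return gap, None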
22 changes: 11 additions & 11 deletions src/ott/neural/solvers/neuraldual.py
@@ -35,14 +35,14 @@

from ott import utils
from ott.geometry import costs
from ott.neural.models import conjugate_solvers, models
from ott.neural import models
from ott.neural.solvers import conjugate
from ott.problems.linear import potentials

__all__ = ["W2NeuralDual", "BaseW2NeuralDual", "W2NeuralTrainState"]
__all__ = ["W2NeuralTrainState", "BaseW2NeuralDual", "W2NeuralDual"]

Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]]
Callback_t = Callable[[int, potentials.DualPotentials], None]
Conj_t = Optional[conjugate_solvers.FenchelConjugateSolver]

PotentialValueFn_t = Callable[[jnp.ndarray], jnp.ndarray]
PotentialGradientFn_t = Callable[[jnp.ndarray], jnp.ndarray]
@@ -52,8 +52,8 @@ class W2NeuralTrainState(train_state.TrainState):
"""Adds information about the model's value and gradient to the state.
This extends :class:`~flax.training.train_state.TrainState` to include
the potential methods from :class:`~ott.neural.models.models.BaseW2NeuralDual`
used during training.
the potential methods from the
:class:`~ott.neural.solvers.neuraldual.BaseW2NeuralDual` used during training.
Args:
potential_value_fn: the potential's value function
@@ -170,8 +170,7 @@ class W2NeuralDual:
denoted source and target, respectively. This is achieved by parameterizing
a Kantorovich potential :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}`
associated with the :math:`\alpha` measure with an
:class:`~ott.solvers.nn.models.ICNN`, :class:`~ott.solvers.nn.models.MLP`,
or other :class:`~ott.solvers.nn.models.BaseW2NeuralDual`, where
:class:`~ott.neural.models.ICNN` or :class:`~ott.neural.models.MLP`, where
:math:`\nabla f` transports source to target cells. This potential is learned
by optimizing the dual form associated with the negative inner product cost
Expand All @@ -187,10 +186,10 @@ class W2NeuralDual:
transport map from :math:`\beta` to :math:`\alpha`.
This solver estimates the conjugate :math:`f^\star`
with a neural approximation :math:`g` that is fine-tuned
with :class:`~ott.solvers.nn.conjugate_solvers.FenchelConjugateSolver`,
with :class:`~ott.neural.solvers.conjugate.FenchelConjugateSolver`,
which is a combination further described in :cite:`amos:23`.
The :class:`~ott.solvers.nn.models.BaseW2NeuralDual` potentials for
The :class:`~ott.neural.solvers.neuraldual.BaseW2NeuralDual` potentials for
``neural_f`` and ``neural_g`` can
1. both provide the values of the potentials :math:`f` and :math:`g`, or
Expand All @@ -199,7 +198,7 @@ class W2NeuralDual:
via the Fenchel conjugate as discussed in :cite:`amos:23`.
The potential's value or gradient mapping is specified via
:attr:`~ott.solvers.nn.models.BaseW2NeuralDual.is_potential`.
:attr:`~ott.neural.solvers.neuraldual.BaseW2NeuralDual.is_potential`.
Args:
dim_data: input dimensionality of data required for network init
@@ -242,7 +241,8 @@ def __init__(
rng: Optional[jax.Array] = None,
pos_weights: bool = True,
beta: float = 1.0,
conjugate_solver: Conj_t = conjugate_solvers.DEFAULT_CONJUGATE_SOLVER,
conjugate_solver: Optional[conjugate.FenchelConjugateSolver
] = conjugate.DEFAULT_CONJUGATE_SOLVER,
amortization_loss: Literal["objective", "regression"] = "regression",
parallel_updates: bool = True,
):
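
A usage sketch of the dual solver under the new paths; the constructor and call conventions below are assumptions based on the docstring, and the samplers come from ott.datasets.create_gaussian_mixture_samplers, touched earlier in this commit:

from ott import datasets
from ott.neural import models
from ott.neural.solvers import neuraldual

# Toy source/target samplers (names follow ott.datasets; batch sizes are defaults).
train_data, valid_data, dim_data = datasets.create_gaussian_mixture_samplers(
    name_source="simple", name_target="circle"
)

# Both potentials parameterized by MLPs; an ICNN could be used for neural_f.
solver = neuraldual.W2NeuralDual(
    dim_data=dim_data,
    neural_f=models.MLP(dim_hidden=[64, 64, 64]),
    neural_g=models.MLP(dim_hidden=[64, 64, 64]),
    num_train_iters=1_000,
)
# Assumed call convention: (source_train, target_train, source_valid, target_valid).
potentials = solver(*train_data, *valid_data)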
2 changes: 1 addition & 1 deletion tests/neural/icnn_test.py
@@ -15,7 +15,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from ott.neural.models import models
from ott.neural import models


@pytest.mark.fast()
3 changes: 1 addition & 2 deletions tests/neural/losses_test.py
@@ -16,8 +16,7 @@
import numpy as np
import pytest
from ott.geometry import costs
from ott.neural.models import models
from ott.neural.solvers import losses
from ott.neural import losses, models


@pytest.mark.fast()
4 changes: 2 additions & 2 deletions tests/neural/map_estimator_test.py
@@ -17,8 +17,8 @@
import pytest
from ott import datasets
from ott.geometry import pointcloud
from ott.neural.models import models
from ott.neural.solvers import losses, map_estimator
from ott.neural import losses, models
from ott.neural.solvers import map_estimator
from ott.tools import sinkhorn_divergence


2 changes: 1 addition & 1 deletion tests/neural/meta_initializer_test.py
@@ -19,7 +19,7 @@
from flax import linen as nn
from ott.geometry import pointcloud
from ott.initializers.linear import initializers as linear_init
from ott.neural.models import models as nn_init
from ott.neural import models as nn_init
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
