Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/neural docs #473

Merged
merged 6 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Install ``OTT`` from `PyPI <https://pypi.org/project/ott-jax/>`_ as:

pip install ott-jax

or with the :mod:`neural OT solvers <ott.solvers.nn>` dependencies:
or with the :mod:`neural OT <ott.neural>` dependencies:

.. code-block:: bash

Expand Down Expand Up @@ -90,6 +90,8 @@ Packages
When the problem is *not* convex, which is the case for most other uses of
this toolbox, the initialization can play a decisive role to reach a useful
solution.
- :mod:`ott.neural` .. TODO(marcocuturi): add some nice text here please.
michalk8 marked this conversation as resolved.
Show resolved Hide resolved

- :mod:`ott.tools` provides an interface to exploit OT solutions, as produced by
solvers from the :mod:`ott.solvers` module. Such tasks include computing
approximations to Wasserstein distances :cite:`genevay:18,sejourne:19`,
Expand All @@ -114,6 +116,7 @@ Packages
problems/index
solvers/index
initializers/index
neural/index
tools
math
utils
Expand Down
7 changes: 3 additions & 4 deletions docs/initializers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ solved with a :class:`~ott.solvers.linear.discrete_barycenter.FixedBarycenter`
solver.

When the problem is *not* convex, which describes pretty much all other pairings
of problems/solvers in ``OTT``, notably quadratic and neural-network based
below, initializers play a more important role: different initializations will
very likely result in different end solutions.
of problems/solvers in ``OTT``, notably the quadratic problem , initializers
play a more important role: different initializations will very likely result
in different end solutions.

.. toctree::
:maxdepth: 2

linear
quadratic
nn
17 changes: 0 additions & 17 deletions docs/initializers/nn.rst

This file was deleted.

36 changes: 36 additions & 0 deletions docs/neural/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
ott.neural
==========
.. module:: ott.neural
.. currentmodule:: ott.neural

Under reconstruction. .. TODO(marcocuturi): add some nice text here please.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcocuturi can you please add something here?

Copy link
Contributor

@marcocuturi marcocuturi Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wasnt sure you wanted it now :)

In contrast to most methods presented in :mod:~ott.solvers, which output vectors or matrices, the goal of the :mod:~ott.neural module is to parameterize optimal transport maps and couplings as neural networks. These neural networks can generalize to new samples, in the sense that they can be conveniently evaluated outside training samples. This module implements layers, models and solvers to estimate such neural networks.


.. toctree::
:maxdepth: 2

solvers

Models
------
.. autosummary::
:toctree: _autosummary

models.ICNN
models.MLP
models.MetaInitializer

Losses
------
.. autosummary::
:toctree: _autosummary

losses.monge_gap
losses.monge_gap_from_samples

Layers
------
.. autosummary::
:toctree: _autosummary

layers.PositiveDense
layers.PosDefPotentials
24 changes: 24 additions & 0 deletions docs/neural/solvers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
ott.neural.solvers
==================
.. module:: ott.neural.solvers
.. currentmodule:: ott.neural.solvers

Under reconstruction. .. TODO(marcocuturi): add some nice text here please.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcocuturi also here, thank you!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This module implements various solvers to estimate optimal transport between two probability measures, through samples, parameterized as neural networks. These neural networks are described in :mod:~ott.neural.models, borrowing lower-level components from :mod:~ott.neural.layers, building from the
flax <https://flax.readthedocs.io/en/latest/examples.html>__ library.


Solvers
-------
.. autosummary::
:toctree: _autosummary

map_estimator.MapEstimator
neuraldual.W2NeuralDual
neuraldual.BaseW2NeuralDual

Conjugate Solvers
-----------------
.. autosummary::
:toctree: _autosummary

conjugate.FenchelConjugateLBFGS
conjugate.FenchelConjugateSolver
conjugate.ConjugateResults
7 changes: 1 addition & 6 deletions docs/problems/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,10 @@ ott.problems
.. module:: ott.problems

The :mod:`ott.problems` module describes the low level optimal transport
problems that are solved by :mod:`ott.solvers`. These problems are
loosely divided into two categories, first finite-sample based problems, as in
:mod:`ott.problems.linear` and :mod:`ott.problems.quadratic`, or relying on
iterators. In that latter category, :mod:`ott.problems.nn` contains synthetic
data iterators.
problems that are solved by :mod:`ott.solvers`.

.. toctree::
:maxdepth: 2

linear
quadratic
nn
16 changes: 0 additions & 16 deletions docs/problems/nn.rst

This file was deleted.

5 changes: 1 addition & 4 deletions docs/solvers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@ problems. More advanced solvers, notably quadratic in
:mod:`ott.solvers.quadratic`, rely on calls to linear solvers as subroutines.
That property itself is implemented in the more abstract
:class:`~ott.solvers.was_solver.WassersteinSolver` class, which provides a
lower-level template at the interface between the two. Neural based solvers in
:mod:`ott.solvers.nn` live on a different category of their own, since they
typically solve the Monge formulation of OT.
lower-level template at the interface between the two.

.. toctree::
:maxdepth: 2

linear
quadratic
nn

Wasserstein Solver
------------------
Expand Down
3 changes: 1 addition & 2 deletions docs/solvers/linear.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ their own, either the Sinkhorn
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or Low-Rank
:class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solvers, to match two
datasets. They also appear as subroutines for more advanced solvers in the
:mod:`ott.solvers` module, notably :mod:`ott.solvers.quadratic` or
:mod:`ott.solvers.nn`.
:mod:`ott.solvers` module, notably :mod:`ott.solvers.quadratic`.

Sinkhorn Solvers
----------------
Expand Down
42 changes: 0 additions & 42 deletions docs/solvers/nn.rst

This file was deleted.

2 changes: 2 additions & 0 deletions docs/spelling/technical.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,12 @@ subpopulation
subpopulations
subsample
subsampled
subsamples
subsampling
thresholding
transcriptome
undirected
univariate
unscaled
url
vectorized
Expand Down
7 changes: 0 additions & 7 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ Clustering
k_means.k_means
k_means.KMeansOutput

Mapping Estimation
------------------
.. autosummary::
:toctree: _autosummary

map_estimator.MapEstimator

ott.tools.gaussian_mixture package
----------------------------------
.. currentmodule:: ott.tools.gaussian_mixture
Expand Down
2 changes: 1 addition & 1 deletion src/ott/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def create_gaussian_mixture_samplers(
valid_batch_size: int = 2048,
rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
"""Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`.
"""Gaussian samplers.

Args:
name_source: name of the source sampler
Expand Down
28 changes: 16 additions & 12 deletions src/ott/geometry/distrib_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from typing import Any, Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
Expand All @@ -25,20 +25,23 @@
]


@jax.tree_util.register_pytree_node_class
@jtu.register_pytree_node_class
class UnivariateWasserstein(costs.CostFn):
"""1D Wasserstein cost for two 1D distributions.

This ground cost between considers vectors as a family of values. The
Wasserstein distance between them is the 1D OT cost, using a user-defined
This ground cost between considers vectors as a family of values.
The Wasserstein distance between them is the 1D OT cost, using a user-defined
ground cost.

Args:
kwargs: arguments passed on when calling the
ground_cost: Cost used to compute the 1D optimal transport between vector,
should be a translation-invariant (TI) cost for correctness.
If :obj:`None`, defaults to :class:`~ott.geometry.costs.SqEuclidean`.
solver: 1D optimal transport solver.
kwargs: Arguments passed on when calling the
:class:`~ott.solvers.linear.univariate.UnivariateSolver`. May include
random key, or specific instructions to subsample or compute using
quantiles.

"""

def __init__(
Expand All @@ -47,22 +50,21 @@ def __init__(
solver: Optional[univariate.UnivariateSolver] = None,
**kwargs: Any
):
from ott.solvers.linear import univariate
super().__init__()

self.ground_cost = (
costs.SqEuclidean() if ground_cost is None else ground_cost
)

self._solver = univariate.UnivariateSolver() if solver is None else solver
self._kwargs_solve = kwargs
self._kwargs_solve["return_transport"] = False
michalk8 marked this conversation as resolved.
Show resolved Hide resolved

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist.

Args:
x: vector, array of shape ``[n,]``
y: vector, array of shape ``[m,]``
x: Array of shape ``[n,]``.
y: Array of shape ``[m,]``.

Returns:
The transport cost.
Expand All @@ -77,8 +79,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
return jnp.squeeze(out.ot_costs)

def tree_flatten(self): # noqa: D102
return (self.ground_cost,), (self._solver,)
return (self.ground_cost,), (self._solver, self._kwargs_solve)

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*children, *aux_data)
ground_cost, = children
solver, solve_kwargs = aux_data
return cls(ground_cost, solver, **solve_kwargs)
2 changes: 1 addition & 1 deletion src/ott/neural/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import models, solvers
from . import layers, losses, models, solvers
Loading
Loading