From 32962a1e07fa1859b94fd53c4e00d3e4ee3edb47 Mon Sep 17 00:00:00 2001 From: michalk8 <46717574+michalk8@users.noreply.github.com> Date: Tue, 22 Nov 2022 18:38:45 +0100 Subject: [PATCH] Misc/project structure (#176) * Reorganize repo structure * Update PC docs * Update imports, fix some types * Fix more types, pet-peeves * Fix tests * Update cost funcs and potentials * Fix LR initializer * Fix k-means initializer * Move `utils` * Update imports in notebooks * Update geometry docs * Update initializers * Update math docs * Update problem docstrings * Update `solvers` docstrings * Update `tools` docstrings * Remove remaining `core` mentions from docstrings * Start updating documentation * Fix typing * Update solvers docs * Add initializers * Update docs * Fix MetaOT links * Fix bibliography links * Fix more links in the notebooks * Follow line length in README.md * Update `tests` structure * Update badges * Add TODOs, fix citation in `index.rst`, move `implicit_diff` * Fix implicit_diff, TODOs in costs * Use `jax.lax.cond` in `UnbalancedBures` * Fix `UnbalancedBures` * Update CI versions * Fix UnbalancedBures's norm --- .editorconfig | 2 +- .github/workflows/lint.yml | 6 +- .github/workflows/notebook_tests.yml | 6 +- .github/workflows/publish_to_pypi.yml | 4 +- .github/workflows/tests.yml | 4 +- .gitignore | 2 +- CONTRIBUTING.md | 2 +- README.md | 49 +- docs/Makefile | 1 + docs/conf.py | 4 + docs/core.rst | 115 ---- docs/geometry.rst | 7 + docs/index.rst | 45 +- docs/initializers/index.rst | 16 + docs/initializers/linear.rst | 25 + docs/initializers/nn.rst | 14 + docs/initializers/quadratic.rst | 14 + docs/math.rst | 30 + docs/problems/index.rst | 15 + docs/problems/linear.rst | 22 + docs/problems/quadratic.rst | 22 + docs/references.bib | 14 +- docs/solvers/index.rst | 23 + docs/solvers/linear.rst | 42 ++ docs/solvers/nn.rst | 21 + docs/solvers/quadratic.rst | 23 + docs/tools.rst | 1 - {ott/examples => examples}/fairness/config.py | 0 {ott/examples => examples}/fairness/data.py | 6 +- {ott/examples => examples}/fairness/losses.py | 0 {ott/examples => examples}/fairness/main.py | 0 {ott/examples => examples}/fairness/models.py | 0 {ott/examples => examples}/fairness/train.py | 0 .../soft_error/config.py | 0 {ott/examples => examples}/soft_error/data.py | 0 .../soft_error/losses.py | 0 {ott/examples => examples}/soft_error/main.py | 0 .../examples => examples}/soft_error/model.py | 0 .../examples => examples}/soft_error/train.py | 0 ott/__init__.py | 2 +- ott/_version.py | 6 +- ott/core/__init__.py | 44 -- ott/core/_math_utils.py | 33 - ott/core/momentum.py | 71 --- ott/core/problems.py | 93 --- ott/geometry/__init__.py | 7 +- ott/geometry/costs.py | 228 ++++--- ott/geometry/epsilon_scheduler.py | 3 + ott/geometry/geometry.py | 19 +- ott/geometry/graph.py | 6 +- ott/geometry/grid.py | 18 +- ott/geometry/low_rank.py | 15 +- ott/geometry/pointcloud.py | 19 +- ott/{core => geometry}/segment.py | 4 +- ott/initializers/__init__.py | 1 + ott/initializers/linear/__init__.py | 1 + .../linear}/initializers.py | 245 +------- .../linear}/initializers_lr.py | 69 +- ott/initializers/nn/__init__.py | 1 + ott/initializers/nn/initializers.py | 228 +++++++ ott/initializers/quadratic/__init__.py | 1 + .../quadratic/initializers.py} | 49 +- ott/math/__init__.py | 7 + ott/{core => math}/decomposition.py | 6 +- ott/{core => math}/fixed_point_loop.py | 7 +- ott/{geometry => math}/matrix_square_root.py | 19 +- ott/{core => math}/unbalanced_functions.py | 0 ott/{geometry/ops.py => math/utils.py} | 60 +- ott/problems/__init__.py | 1 + ott/problems/linear/__init__.py | 1 + ott/problems/linear/barycenter_problem.py | 181 ++++++ .../linear/linear_problem.py} | 3 + ott/{core => problems/linear}/potentials.py | 38 +- ott/problems/quadratic/__init__.py | 1 + .../quadratic/gw_barycenter.py} | 219 +------ ott/problems/quadratic/quadratic_costs.py | 34 + .../quadratic/quadratic_problem.py} | 88 +-- ott/solvers/__init__.py | 1 + ott/solvers/linear/__init__.py | 8 + .../linear/acceleration.py} | 88 ++- .../linear}/continuous_barycenter.py | 27 +- .../linear}/discrete_barycenter.py | 20 +- .../linear}/implicit_differentiation.py | 57 +- ott/{core => solvers/linear}/sinkhorn.py | 87 +-- ott/{core => solvers/linear}/sinkhorn_lr.py | 91 +-- ott/solvers/nn/__init__.py | 1 + ott/{core => solvers/nn}/icnn.py | 8 +- ott/{core => solvers/nn}/layers.py | 28 +- ott/{core => solvers/nn}/neuraldual.py | 33 +- ott/solvers/quadratic/__init__.py | 1 + .../quadratic}/gromov_wasserstein.py | 85 ++- .../quadratic}/gw_barycenter.py | 39 +- ott/{core => solvers}/was_solver.py | 17 +- ott/tools/gaussian_mixture/fit_gmm.py | 15 +- .../gaussian_mixture/gaussian_mixture_pair.py | 2 +- ott/tools/gaussian_mixture/scale_tril.py | 3 +- ott/tools/k_means.py | 2 +- ott/tools/segment_sinkhorn.py | 43 +- ott/tools/sinkhorn_divergence.py | 47 +- ott/tools/transport.py | 100 ++- ott/types.py | 22 + ott/{core/dataclasses.py => utils.py} | 2 + tests/core/initializers_test.py | 590 ------------------ .../{geometry_costs_test.py => costs_test.py} | 0 tests/geometry/graph_test.py | 11 +- .../{geometry_lr_test.py => low_rank_test.py} | 0 ...cloud_apply_test.py => pointcloud_test.py} | 0 tests/geometry/scaling_cost_test.py | 5 +- ...etry_subset_test.py => subsetting_test.py} | 4 +- .../initializers/linear/sinkhorn_init_test.py | 331 ++++++++++ .../linear/sinkhorn_lr_init_test.py | 171 +++++ tests/initializers/quadratic/gw_init_test.py | 132 ++++ .../geometry_lse_test.py => math/lse_test.py} | 4 +- .../matrix_square_root_test.py | 4 +- .../linear}/potentials_test.py | 27 +- .../linear}/continuous_barycenter_test.py | 199 +----- .../linear}/discrete_barycenter_test.py | 4 +- .../linear}/sinkhorn_diff_test.py | 7 +- .../linear}/sinkhorn_grid_test.py | 2 +- .../linear}/sinkhorn_lr_test.py | 7 +- .../linear/sinkhorn_misc_test.py} | 47 +- .../{core => solvers/linear}/sinkhorn_test.py | 5 +- tests/{core => solvers/nn}/icnn_test.py | 18 +- tests/{core => solvers/nn}/neuraldual_test.py | 6 +- .../solvers/quadratic/fgw_barycenter_test.py | 78 +++ .../quadratic/fgw_test.py} | 43 +- tests/solvers/quadratic/gw_barycenter_test.py | 113 ++++ .../quadratic/gw_test.py} | 35 +- .../tools/gaussian_mixture/scale_tril_test.py | 2 +- tests/tools/segment_sinkhorn_test.py | 2 +- tests/tools/sinkhorn_divergence_test.py | 2 +- tests/tools/transport_test.py | 4 +- 132 files changed, 2729 insertions(+), 2314 deletions(-) delete mode 100644 docs/core.rst create mode 100644 docs/initializers/index.rst create mode 100644 docs/initializers/linear.rst create mode 100644 docs/initializers/nn.rst create mode 100644 docs/initializers/quadratic.rst create mode 100644 docs/math.rst create mode 100644 docs/problems/index.rst create mode 100644 docs/problems/linear.rst create mode 100644 docs/problems/quadratic.rst create mode 100644 docs/solvers/index.rst create mode 100644 docs/solvers/linear.rst create mode 100644 docs/solvers/nn.rst create mode 100644 docs/solvers/quadratic.rst rename {ott/examples => examples}/fairness/config.py (100%) rename {ott/examples => examples}/fairness/data.py (98%) rename {ott/examples => examples}/fairness/losses.py (100%) rename {ott/examples => examples}/fairness/main.py (100%) rename {ott/examples => examples}/fairness/models.py (100%) rename {ott/examples => examples}/fairness/train.py (100%) rename {ott/examples => examples}/soft_error/config.py (100%) rename {ott/examples => examples}/soft_error/data.py (100%) rename {ott/examples => examples}/soft_error/losses.py (100%) rename {ott/examples => examples}/soft_error/main.py (100%) rename {ott/examples => examples}/soft_error/model.py (100%) rename {ott/examples => examples}/soft_error/train.py (100%) delete mode 100644 ott/core/__init__.py delete mode 100644 ott/core/_math_utils.py delete mode 100644 ott/core/momentum.py delete mode 100644 ott/core/problems.py rename ott/{core => geometry}/segment.py (99%) create mode 100644 ott/initializers/__init__.py create mode 100644 ott/initializers/linear/__init__.py rename ott/{core => initializers/linear}/initializers.py (52%) rename ott/{core => initializers/linear}/initializers_lr.py (91%) create mode 100644 ott/initializers/nn/__init__.py create mode 100644 ott/initializers/nn/initializers.py create mode 100644 ott/initializers/quadratic/__init__.py rename ott/{core/quad_initializers.py => initializers/quadratic/initializers.py} (79%) create mode 100644 ott/math/__init__.py rename ott/{core => math}/decomposition.py (97%) rename ott/{core => math}/fixed_point_loop.py (98%) rename ott/{geometry => math}/matrix_square_root.py (95%) rename ott/{core => math}/unbalanced_functions.py (100%) rename ott/{geometry/ops.py => math/utils.py} (60%) create mode 100644 ott/problems/__init__.py create mode 100644 ott/problems/linear/__init__.py create mode 100644 ott/problems/linear/barycenter_problem.py rename ott/{core/linear_problems.py => problems/linear/linear_problem.py} (97%) rename ott/{core => problems/linear}/potentials.py (86%) create mode 100644 ott/problems/quadratic/__init__.py rename ott/{core/bar_problems.py => problems/quadratic/gw_barycenter.py} (55%) create mode 100644 ott/problems/quadratic/quadratic_costs.py rename ott/{core/quad_problems.py => problems/quadratic/quadratic_problem.py} (89%) create mode 100644 ott/solvers/__init__.py create mode 100644 ott/solvers/linear/__init__.py rename ott/{core/anderson.py => solvers/linear/acceleration.py} (61%) rename ott/{core => solvers/linear}/continuous_barycenter.py (88%) rename ott/{core => solvers/linear}/discrete_barycenter.py (94%) rename ott/{core => solvers/linear}/implicit_differentiation.py (87%) rename ott/{core => solvers/linear}/sinkhorn.py (94%) rename ott/{core => solvers/linear}/sinkhorn_lr.py (87%) create mode 100644 ott/solvers/nn/__init__.py rename ott/{core => solvers/nn}/icnn.py (97%) rename ott/{core => solvers/nn}/layers.py (82%) rename ott/{core => solvers/nn}/neuraldual.py (92%) create mode 100644 ott/solvers/quadratic/__init__.py rename ott/{core => solvers/quadratic}/gromov_wasserstein.py (88%) rename ott/{core => solvers/quadratic}/gw_barycenter.py (89%) rename ott/{core => solvers}/was_solver.py (87%) create mode 100644 ott/types.py rename ott/{core/dataclasses.py => utils.py} (96%) delete mode 100644 tests/core/initializers_test.py rename tests/geometry/{geometry_costs_test.py => costs_test.py} (100%) rename tests/geometry/{geometry_lr_test.py => low_rank_test.py} (100%) rename tests/geometry/{geometry_pointcloud_apply_test.py => pointcloud_test.py} (100%) rename tests/geometry/{geometry_subset_test.py => subsetting_test.py} (98%) create mode 100644 tests/initializers/linear/sinkhorn_init_test.py create mode 100644 tests/initializers/linear/sinkhorn_lr_init_test.py create mode 100644 tests/initializers/quadratic/gw_init_test.py rename tests/{geometry/geometry_lse_test.py => math/lse_test.py} (96%) rename tests/{geometry => math}/matrix_square_root_test.py (98%) rename tests/{core => problems/linear}/potentials_test.py (90%) rename tests/{core => solvers/linear}/continuous_barycenter_test.py (64%) rename tests/{core => solvers/linear}/discrete_barycenter_test.py (97%) rename tests/{core => solvers/linear}/sinkhorn_diff_test.py (99%) rename tests/{core => solvers/linear}/sinkhorn_grid_test.py (99%) rename tests/{core => solvers/linear}/sinkhorn_lr_test.py (95%) rename tests/{core/sinkhorn_extra_test.py => solvers/linear/sinkhorn_misc_test.py} (89%) rename tests/{core => solvers/linear}/sinkhorn_test.py (99%) rename tests/{core => solvers/nn}/icnn_test.py (80%) rename tests/{core => solvers/nn}/neuraldual_test.py (96%) create mode 100644 tests/solvers/quadratic/fgw_barycenter_test.py rename tests/{core/fused_gromov_wasserstein_test.py => solvers/quadratic/fgw_test.py} (91%) create mode 100644 tests/solvers/quadratic/gw_barycenter_test.py rename tests/{core/gromov_wasserstein_test.py => solvers/quadratic/gw_test.py} (95%) diff --git a/.editorconfig b/.editorconfig index a20daede6..e480103dd 100644 --- a/.editorconfig +++ b/.editorconfig @@ -3,9 +3,9 @@ root = true [*] end_of_line = lf insert_final_newline = true +charset = utf-8 [*py] -charset = utf-8 indent_size = 2 indent_style = space max_line_length = 80 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b280db886..9c091ef1b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -15,13 +15,13 @@ jobs: os: [ubuntu-latest] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: ~/.cache/pre-commit key: precommit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml index b89b301e8..a3f327e5d 100644 --- a/.github/workflows/notebook_tests.yml +++ b/.github/workflows/notebook_tests.yml @@ -12,12 +12,12 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ['3.8'] + python-version: [3.8] os: [ubuntu-latest] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml index b9efc2088..6d26315fd 100644 --- a/.github/workflows/publish_to_pypi.yml +++ b/.github/workflows/publish_to_pypi.yml @@ -13,9 +13,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.x - name: Install dependencies diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9570b6a7d..be1e3dbd5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,8 +18,8 @@ jobs: test_mark: [fast, all] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore index e289be07d..d8f1864bd 100644 --- a/.gitignore +++ b/.gitignore @@ -161,7 +161,7 @@ cython_debug/ # generated documentation docs/html -docs/_autosummary +**/_autosummary # macos **/.DS_Store diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6038bcf66..34eb5da00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,7 +9,7 @@ to the project, participating in discussions or raising issues. 1. fork the repository using the **Fork** button on GitHub or the following [link](https://github.com/ott-jax/ott/fork) 2. ```bash - git clone https://github.com/YOUR_USERNAME/ott + git clone https://github.com//ott cd ott pip install -e .'[dev,test]' pre-commit install diff --git a/README.md b/README.md index 48e8e75f9..cb44680a9 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,44 @@ logo -# Optimal Transport Tools (OTT). - +# Optimal Transport Tools (OTT) +[![Downloads](https://pepy.tech/badge/ott-jax)](https://pypi.org/project/ott-jax/) [![Tests](https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main)](https://github.com/ott-jax/ott/actions/workflows/tests.yml) [![Docs](https://img.shields.io/readthedocs/ott-jax/latest)](https://ott-jax.readthedocs.io/en/latest/) [![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)](https://app.codecov.io/gh/ott-jax/ott) -**See [full documentation](https://ott-jax.readthedocs.io/en/latest/).** +**See the [full documentation](https://ott-jax.readthedocs.io/en/latest/).** ## What is OTT-JAX? - -A JAX powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes the fastest implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, acceleration, initializations) and extensions (low-rank), that can be used directly, or within more advanced problems (Gromov-Wasserstein, barycenters). Some of JAX features, including [JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions), [auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and [implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) work towards the goal of having end-to-end differentiable outputs. OTT-JAX is developed by a team of researchers from Apple, Google, Meta and many academic contributors, including TU München, Oxford, ENSAE/IP Paris and the Hebrew University. +A JAX powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes the fastest +implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, +acceleration, initializations) and extensions (low-rank), that can be used directly, or within more advanced problems +(Gromov-Wasserstein, barycenters). Some of JAX features, including +[JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions), +[auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and +[implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +work towards the goal of having end-to-end differentiable outputs. OTT-JAX is developed by a team of researchers +from Apple, Google, Meta and many academic contributors, including TU München, Oxford, ENSAE/IP Paris and the +Hebrew University. ## What is optimal transport? +Optimal transport can be loosely described as the branch of mathematics and optimization that studies +*matching problems*: given two families of points, and a cost function on pairs of points, find a `good' (low cost) way +to associate bijectively to every point in the first family another in the second. -Optimal transport can be loosely described as the branch of mathematics and optimization that studies *matching problems*: given two families of points, and a cost function on pairs of points, find a `good' (low cost) way to associate bijectively to every point in the first family another in the second. - -Such problems appear in all areas of science, are easy to describe, yet hard to solve. Indeed, while matching optimally two sets of *n* points using a pairwise cost can be solved with the [Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), solving it costs an order of $O(n^3)$ operations, and lacks flexibility, since one may want to couple families of different sizes. +Such problems appear in all areas of science, are easy to describe, yet hard to solve. Indeed, while matching optimally +two sets of *n* points using a pairwise cost can be solved with the +[Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), solving it costs an order of $O(n^3)$ +operations, and lacks flexibility, since one may want to couple families of different sizes. -Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous generalizations that can help it handle weighted sets of different size, partial matchings, and even more evolved so-called quadratic matching problems. +Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous +generalizations that can help it handle weighted sets of different size, partial matchings, and even more evolved +so-called quadratic matching problems. -In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly (2D vectors, compared with the squared Euclidean distance): +In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly +(2D vectors, compared with the squared Euclidean distance): ## Example - -```py +```python import jax import jax.numpy as jnp from ott.tools import transport @@ -41,17 +55,22 @@ ot = transport.solve(x, y, a=a, b=b) P = ot.matrix ``` -The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix (here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/). +The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix +(here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or +more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and +are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives, +and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/). ![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png) -## Citation +## Citation If you have found this work useful, please consider citing this reference: ``` @article{cuturi2022optimal, title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein}, - author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and Davis, Geoff and Teboul, Olivier}, + author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and + Davis, Geoff and Teboul, Olivier}, journal={arXiv preprint arXiv:2201.12324}, year={2022} } diff --git a/docs/Makefile b/docs/Makefile index 3db4deda9..2dab86e59 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -22,3 +22,4 @@ help: clean: @rm -rf $(BUILDDIR)/ @rm -rf $(SOURCEDIR)/_autosummary + @rm -rf $(SOURCEDIR)/**/_autosummary diff --git a/docs/conf.py b/docs/conf.py index 6b8dbd0bc..fc7b55099 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -73,6 +73,10 @@ source_suffix = ['.rst'] autosummary_generate = True +autosummary_filename_map = { + "ott.solvers.linear.sinkhorn.sinkhorn": + "ott.solvers.linear.sinkhorn.sinkhorn-function" +} autodoc_typehints = 'description' diff --git a/docs/core.rst b/docs/core.rst deleted file mode 100644 index 0c62c88e5..000000000 --- a/docs/core.rst +++ /dev/null @@ -1,115 +0,0 @@ -.. _core: - -ott.core package -================ -.. currentmodule:: ott.core -.. automodule:: ott.core - -The core package contains definitions of various OT problems, starting -from the most simple, the linear OT problem, to more advanced problems -such as quadratic, or involving multiple measures, the barycenter problem. -We follow with the classic :class:`~ott.core.sinkhorn.sinkhorn` routine (essentially a -wrapper for the :class:`~ott.core.sinkhorn.Sinkhorn` solver class) :cite:`cuturi:13,sejourne:19`. -We also provide an analogous low-rank Sinkhorn solver :cite:`scetbon:21` to handle very large instances. -Both are used within our Wasserstein barycenter solvers :cite:`benamou:15,janati:20a`, as well as our -Gromov-Wasserstein solver :cite:`memoli:11,scetbon:22`. We also provide an implementation of -input convex neural networks :cite:`amos:17`, a NN that can be used to estimate OT :cite:`makkuva:20`. - -OT Problems ------------ -.. autosummary:: - :toctree: _autosummary - - linear_problems.LinearProblem - quad_problems.QuadraticProblem - bar_problems.BarycenterProblem - bar_problems.GWBarycenterProblem - -Sinkhorn --------- -.. autosummary:: - :toctree: _autosummary - - sinkhorn.sinkhorn - sinkhorn.Sinkhorn - sinkhorn.SinkhornOutput - -Sinkhorn Dual Initializers --------------------------- -.. autosummary:: - :toctree: _autosummary - - initializers.DefaultInitializer - initializers.GaussianInitializer - initializers.SortingInitializer - initializers.MetaInitializer - initializers.MetaMLP - -Low-Rank Sinkhorn ------------------ -.. autosummary:: - :toctree: _autosummary - - sinkhorn_lr.LRSinkhorn - sinkhorn_lr.LRSinkhornOutput - -Low-Rank Sinkhorn Initializers ------------------------------- -.. autosummary:: - :toctree: _autosummary - - initializers_lr.RandomInitializer - initializers_lr.Rank2Initializer - initializers_lr.KMeansInitializer - initializers_lr.GeneralizedKMeansInitializer - -Quadratic Initializers ----------------------- -.. autosummary:: - :toctree: _autosummary - - quad_initializers.QuadraticInitializer - quad_initializers.LRQuadraticInitializer - -Barycenters (Entropic and LR) ------------------------------ -.. autosummary:: - :toctree: _autosummary - - discrete_barycenter.discrete_barycenter - continuous_barycenter.WassersteinBarycenter - continuous_barycenter.BarycenterState - gw_barycenter.GromovWassersteinBarycenter - gw_barycenter.GWBarycenterState - -Gromov-Wasserstein (Entropic and LR) ------------------------------------- -.. autosummary:: - :toctree: _autosummary - - gromov_wasserstein.gromov_wasserstein - gromov_wasserstein.GromovWasserstein - gromov_wasserstein.GWOutput - -Dual Potentials ---------------- -.. autosummary:: - :toctree: _autosummary - - potentials.DualPotentials - potentials.EntropicPotentials - -Neural Dual Potentials ----------------------- -.. autosummary:: - :toctree: _autosummary - - icnn.ICNN - neuraldual.NeuralDualSolver - -Padding Utilities ------------------ -.. autosummary:: - :toctree: _autosummary - - segment.segment_point_cloud diff --git a/docs/geometry.rst b/docs/geometry.rst index edabf6b3b..c857dd8a7 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -58,3 +58,10 @@ Cost Functions costs.Cosine costs.Bures costs.UnbalancedBures + +Utilities +--------- +.. autosummary:: + :toctree: _autosummary + + segment.segment_point_cloud diff --git a/docs/index.rst b/docs/index.rst index 52af7189f..037233463 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,8 @@ +|Downloads| |Tests| |Docs| |Coverage| + Optimal Transport Tools (OTT) documentation =========================================== -`Code `_ on github. +`Code `_ on GitHub. To install, simply run ``pip install ott-jax``. Intro @@ -15,13 +17,12 @@ such as differentiable approximations to ranking or even clustering. To achieve this, `OTT` rests on two families of tools: The first family consists in *discrete* solvers computing transport between point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn :cite:`scetbon:21` algorithms, -and moving up towards Gromov-Wasserstein :cite:`memoli:11`, :cite:`memoli:11`; +and moving up towards Gromov-Wasserstein :cite:`memoli:11,peyre:16`; the second family consists in *continuous* solvers, using suitable neural architectures :cite:`amos:17` coupled -with SGD type estimators :cite:`makkuva:20`, :cite:`korotin:21`. +with SGD type estimators :cite:`makkuva:20,korotin:21`. Design Choices -------------- - `OTT` is designed with the following choices: - Take advantage whenever possible of JAX features, such as `Just-in-time (JIT) compilation`_, @@ -42,20 +43,22 @@ Design Choices automatically in higher level calls (e.g. updates in Gromov-Wasserstein), without requiring any attention from the user. +.. TODO(marcocuturi): add missing package descriptions below + Packages -------- -There are currently three packages, ``geometry``, ``core`` and ``tools``, playing the following roles: - - :ref:`geometry` contains classes to instantiate objects that describe *two point clouds* paired with a *cost* function. Geometry objects are used to - describe OT problems, handled by solvers in ``core``. -- :ref:`core` classes describe OT problems (linear, quadratic, barycenters), and - solver classes, to instantiate algorithms that will output an OT. + describe OT problems, handled by solvers in the :ref:`solvers`. +- :ref:`problems` +- :ref:`solvers` +- :ref:`initializers` - :ref:`tools` provides an interface to exploit OT solutions, as produced by - solvers in the ``core`` package. Such tasks include computing approximations + solvers in the :ref:`solvers`. Such tasks include computing approximations to Wasserstein distances :cite:`genevay:18,sejourne:19`, approximating OT between GMMs, or computing differentiable sort and quantile operations :cite:`cuturi:19`. +- :ref:`math` .. toctree:: :maxdepth: 1 @@ -95,8 +98,11 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin :caption: Public API: ott packages geometry - core + problems/index + solvers/index + initializers/index tools + math .. toctree:: :maxdepth: 1 @@ -104,6 +110,23 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin references + +.. |Downloads| image:: https://pepy.tech/badge/ott-jax + :target: https://pypi.org/project/ott-jax/ + :alt: Documentation + +.. |Tests| image:: https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main + :target: https://github.com/ott-jax/ott/actions/workflows/tests.yml + :alt: Documentation + +.. |Docs| image:: https://img.shields.io/readthedocs/ott-jax/latest + :target: https://ott-jax.readthedocs.io/en/latest/ + :alt: Documentation + +.. |Coverage| image:: https://img.shields.io/codecov/c/github/ott-jax/ott/main + :target: https://app.codecov.io/gh/ott-jax/ott + :alt: Coverage + .. _Just-in-time (JIT) compilation: https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit .. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap .. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation diff --git a/docs/initializers/index.rst b/docs/initializers/index.rst new file mode 100644 index 000000000..b591be3b2 --- /dev/null +++ b/docs/initializers/index.rst @@ -0,0 +1,16 @@ +.. _initializers: + +ott.initializers package +======================== + +.. TODO(cuturi): add some nice text here please + +.. currentmodule:: ott.initializers +.. automodule:: ott.initializers + +.. toctree:: + :maxdepth: 2 + + linear + quadratic + nn diff --git a/docs/initializers/linear.rst b/docs/initializers/linear.rst new file mode 100644 index 000000000..ad3143fa1 --- /dev/null +++ b/docs/initializers/linear.rst @@ -0,0 +1,25 @@ +ott.initializers.linear package +=============================== +.. currentmodule:: ott.initializers.linear +.. automodule:: ott.initializers.linear + +.. TODO(marcocuturi): maybe add some text here + +Sinkhorn Initializers +--------------------- +.. autosummary:: + :toctree: _autosummary + + initializers.DefaultInitializer + initializers.GaussianInitializer + initializers.SinkhornInitializer + +Low-rank Sinkhorn Initializers +------------------------------ +.. autosummary:: + :toctree: _autosummary + + initializers_lr.RandomInitializer + initializers_lr.Rank2Initializer + initializers_lr.KMeansInitializer + initializers_lr.GeneralizedKMeansInitializer diff --git a/docs/initializers/nn.rst b/docs/initializers/nn.rst new file mode 100644 index 000000000..2f88e0999 --- /dev/null +++ b/docs/initializers/nn.rst @@ -0,0 +1,14 @@ +ott.initializers.nn package +=========================== +.. currentmodule:: ott.initializers.nn +.. automodule:: ott.initializers.nn + +.. TODO(marcocuturi): maybe add some text here + +Neural Initializers +------------------- +.. autosummary:: + :toctree: _autosummary + + initializers.MetaInitializer + initializers.MetaMLP diff --git a/docs/initializers/quadratic.rst b/docs/initializers/quadratic.rst new file mode 100644 index 000000000..1929bd380 --- /dev/null +++ b/docs/initializers/quadratic.rst @@ -0,0 +1,14 @@ +ott.initializers.quadratic package +================================== +.. currentmodule:: ott.initializers.quadratic +.. automodule:: ott.initializers.quadratic + +.. TODO(marcocuturi): maybe add some text here + +Gromov-Wasserstein Initializers +------------------------------- +.. autosummary:: + :toctree: _autosummary + + initializers.QuadraticInitializer + initializers.LRQuadraticInitializer diff --git a/docs/math.rst b/docs/math.rst new file mode 100644 index 000000000..20ea2fc3f --- /dev/null +++ b/docs/math.rst @@ -0,0 +1,30 @@ +.. _math: + +ott.math package +================ +.. currentmodule:: ott.math +.. automodule:: ott.math + +.. TODO(marcocuturi): maybe add some text here + +Fixed-point Iteration +--------------------- +.. autosummary:: + :toctree: _autosummary + + fixed_point_loop.fixpoint_iter + +Cholesky Decomposition +---------------------- +.. autosummary:: + :toctree: _autosummary + + decomposition.DenseCholeskySolver + decomposition.SparseCholeskySolver + +Matrix Square Root +------------------ +.. autosummary:: + :toctree: _autosummary + + matrix_square_root.sqrtm diff --git a/docs/problems/index.rst b/docs/problems/index.rst new file mode 100644 index 000000000..16e5ead90 --- /dev/null +++ b/docs/problems/index.rst @@ -0,0 +1,15 @@ +.. _problems: + +ott.problems package +==================== + +.. TODO(marcocuturi): add some nice text here please + +.. currentmodule:: ott.problems +.. automodule:: ott.problems + +.. toctree:: + :maxdepth: 2 + + linear + quadratic diff --git a/docs/problems/linear.rst b/docs/problems/linear.rst new file mode 100644 index 000000000..d8b442e15 --- /dev/null +++ b/docs/problems/linear.rst @@ -0,0 +1,22 @@ +ott.problems.linear package +=========================== +.. currentmodule:: ott.problems.linear +.. automodule:: ott.problems.linear + +.. TODO(marcocuturi): maybe add some text here + +OT Problems +----------- +.. autosummary:: + :toctree: _autosummary + + linear_problem.LinearProblem + barycenter_problem.BarycenterProblem + +Dual Potentials +--------------- +.. autosummary:: + :toctree: _autosummary + + potentials.DualPotentials + potentials.EntropicPotentials diff --git a/docs/problems/quadratic.rst b/docs/problems/quadratic.rst new file mode 100644 index 000000000..e7e8c32d1 --- /dev/null +++ b/docs/problems/quadratic.rst @@ -0,0 +1,22 @@ +ott.problems.quadratic package +============================== +.. currentmodule:: ott.problems.quadratic +.. automodule:: ott.problems.quadratic + +.. TODO(marcocuturi): maybe add some text here + +OT Problems +----------- +.. autosummary:: + :toctree: _autosummary + + quadratic_problem.QuadraticProblem + gw_barycenter.GWBarycenterProblem + +Costs +----- +.. autosummary:: + :toctree: _autosummary + + quadratic_costs.make_square_loss + quadratic_costs.make_kl_loss diff --git a/docs/references.bib b/docs/references.bib index 74abe1adb..3bef70d15 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -29,7 +29,6 @@ @InProceedings{peyre:16 url = {https://proceedings.mlr.press/v48/peyre16.html}, } - @InProceedings{feydy:19, title = {Interpolating between Optimal Transport and MMD using Sinkhorn Divergences}, author = {Feydy, Jean and S\'{e}journ\'{e}, Thibault and Vialard, Fran\c{c}ois-Xavier and Amari, Shun-ichi and Trouve, Alain and Peyr\'{e}, Gabriel}, @@ -161,10 +160,8 @@ @Article{demetci:20 title = {Gromov-Wasserstein optimal transport to align single-cell multi-omics data}, elocation-id = {2020.04.28.066787}, year = {2020}, - doi = {10.1101/2020.04.28.066787}, publisher = {Cold Spring Harbor Laboratory}, URL = {https://www.biorxiv.org/content/early/2020/11/11/2020.04.28.066787}, - eprint = {https://www.biorxiv.org/content/early/2020/11/11/2020.04.28.066787.full.pdf}, journal = {bioRxiv} } @@ -214,8 +211,6 @@ @Article{gelbrich:90 number = {1}, pages = {185-203}, doi = {https://doi.org/10.1002/mana.19901470121}, - url = {https://onlinelibrary.wiley.com/doi/abs/10.1002/mana.19901470121}, - eprint = {https://onlinelibrary.wiley.com/doi/pdf/10.1002/mana.19901470121}, year = {1990} } @@ -294,9 +289,8 @@ @Article{benamou:15 pages = {A1111-A1138}, year = {2015}, doi = {10.1137/141000439}, - URL = {https://doi.org/10.1137/141000439}, - eprint = {https://doi.org/10.1137/141000439} } + @article{brenier:91, title={Polar factorization and monotone rearrangement of vector-valued functions}, author={Brenier, Yann}, @@ -407,8 +401,6 @@ @Article{delon:20 pages = {936-970}, year = {2020}, doi = {10.1137/19M1301047}, - URL = {https://doi.org/10.1137/19M1301047}, - eprint = {https://doi.org/10.1137/19M1301047}, } @InProceedings{janati:20a, @@ -436,8 +428,6 @@ @Article{schmitz:18 pages = {643-678}, year = {2018}, doi = {10.1137/17M1140431}, - URL = {https://doi.org/10.1137/17M1140431}, - eprint = {https://doi.org/10.1137/17M1140431}, } @Article{alvarez-esteban:16, @@ -492,7 +482,7 @@ @inproceedings{chizat:20 year = {2020} } -@Article{higham:1997, +@Article{higham:97, author = "Higham, Nicholas J.", title = "Stable iterations for the matrix square root", journal = "Numerical Algorithms", diff --git a/docs/solvers/index.rst b/docs/solvers/index.rst new file mode 100644 index 000000000..8b7d62532 --- /dev/null +++ b/docs/solvers/index.rst @@ -0,0 +1,23 @@ +.. _solvers: + +ott.solvers package +=================== + +.. TODO(marcocuturi): add some nice text here please + +.. currentmodule:: ott.solvers +.. automodule:: ott.solvers + +.. toctree:: + :maxdepth: 2 + + linear + quadratic + nn + +Wasserstein Solver +------------------ +.. autosummary:: + :toctree: _autosummary + + was_solver.WassersteinSolver diff --git a/docs/solvers/linear.rst b/docs/solvers/linear.rst new file mode 100644 index 000000000..0605bd7e1 --- /dev/null +++ b/docs/solvers/linear.rst @@ -0,0 +1,42 @@ +ott.solvers.linear package +========================== +.. currentmodule:: ott.solvers.linear +.. automodule:: ott.solvers.linear + +.. TODO(marcocuturi): maybe add some text here + +Sinkhorn Solvers +---------------- +.. autosummary:: + :toctree: _autosummary + + sinkhorn.sinkhorn + sinkhorn.Sinkhorn + sinkhorn.SinkhornOutput + sinkhorn_lr.LRSinkhorn + sinkhorn_lr.LRSinkhornOutput + +Barycenter Solvers +------------------ +.. autosummary:: + :toctree: _autosummary + + continuous_barycenter.WassersteinBarycenter + continuous_barycenter.BarycenterState + discrete_barycenter.discrete_barycenter + discrete_barycenter.SinkhornBarycenterOutput + +Sinkhorn Acceleration +--------------------- +.. autosummary:: + :toctree: _autosummary + + acceleration.Momentum + acceleration.AndersonAcceleration + +Implicit Differentiation +------------------------ +.. autosummary:: + :toctree: _autosummary + + implicit_differentiation.ImplicitDiff diff --git a/docs/solvers/nn.rst b/docs/solvers/nn.rst new file mode 100644 index 000000000..08bd7fc1a --- /dev/null +++ b/docs/solvers/nn.rst @@ -0,0 +1,21 @@ +ott.solvers.nn package +====================== +.. currentmodule:: ott.solvers.nn +.. automodule:: ott.solvers.nn + +.. TODO(marcocuturi): maybe add some text here + +Neural Dual +----------- +.. autosummary:: + :toctree: _autosummary + + neuraldual.NeuralDualSolver + +ICNN +---- +.. autosummary:: + :toctree: _autosummary + + icnn.ICNN + layers.PositiveDense diff --git a/docs/solvers/quadratic.rst b/docs/solvers/quadratic.rst new file mode 100644 index 000000000..9f6ea7a38 --- /dev/null +++ b/docs/solvers/quadratic.rst @@ -0,0 +1,23 @@ +ott.solvers.quadratic package +============================= +.. currentmodule:: ott.solvers.quadratic +.. automodule:: ott.solvers.quadratic + +.. TODO(marcocuturi): maybe add some text here + +Gromov-Wasserstein Solvers +-------------------------- +.. autosummary:: + :toctree: _autosummary + + gromov_wasserstein.GromovWasserstein + gromov_wasserstein.GWOutput + gromov_wasserstein.gromov_wasserstein + +Barycenter Solvers +------------------ +.. autosummary:: + :toctree: _autosummary + + gw_barycenter.GWBarycenterState + gw_barycenter.GromovWassersteinBarycenter diff --git a/docs/tools.rst b/docs/tools.rst index 2a890d52a..a0847d506 100644 --- a/docs/tools.rst +++ b/docs/tools.rst @@ -23,7 +23,6 @@ Segmented Sinkhorn segment_sinkhorn.segment_sinkhorn - Sinkhorn Divergence ------------------- .. autosummary:: diff --git a/ott/examples/fairness/config.py b/examples/fairness/config.py similarity index 100% rename from ott/examples/fairness/config.py rename to examples/fairness/config.py diff --git a/ott/examples/fairness/data.py b/examples/fairness/data.py similarity index 98% rename from ott/examples/fairness/data.py rename to examples/fairness/data.py index 075e46367..91547d421 100644 --- a/ott/examples/fairness/data.py +++ b/examples/fairness/data.py @@ -19,8 +19,6 @@ import numpy as np import pandas as pd -open_fn = open - def load_df( data_path: str, @@ -30,12 +28,12 @@ def load_df( **kwargs ): """Load a pandas dataframe from two filenames.""" - with open_fn(data_path, 'r') as fp: + with open(data_path) as fp: df = pd.read_csv(fp, skipinitialspace=True, header=None, **kwargs) headers = [] targets = [] - with open_fn(info_path, 'r') as fp: + with open(info_path) as fp: for line in fp: if line.startswith('|') or not line.strip(): continue diff --git a/ott/examples/fairness/losses.py b/examples/fairness/losses.py similarity index 100% rename from ott/examples/fairness/losses.py rename to examples/fairness/losses.py diff --git a/ott/examples/fairness/main.py b/examples/fairness/main.py similarity index 100% rename from ott/examples/fairness/main.py rename to examples/fairness/main.py diff --git a/ott/examples/fairness/models.py b/examples/fairness/models.py similarity index 100% rename from ott/examples/fairness/models.py rename to examples/fairness/models.py diff --git a/ott/examples/fairness/train.py b/examples/fairness/train.py similarity index 100% rename from ott/examples/fairness/train.py rename to examples/fairness/train.py diff --git a/ott/examples/soft_error/config.py b/examples/soft_error/config.py similarity index 100% rename from ott/examples/soft_error/config.py rename to examples/soft_error/config.py diff --git a/ott/examples/soft_error/data.py b/examples/soft_error/data.py similarity index 100% rename from ott/examples/soft_error/data.py rename to examples/soft_error/data.py diff --git a/ott/examples/soft_error/losses.py b/examples/soft_error/losses.py similarity index 100% rename from ott/examples/soft_error/losses.py rename to examples/soft_error/losses.py diff --git a/ott/examples/soft_error/main.py b/examples/soft_error/main.py similarity index 100% rename from ott/examples/soft_error/main.py rename to examples/soft_error/main.py diff --git a/ott/examples/soft_error/model.py b/examples/soft_error/model.py similarity index 100% rename from ott/examples/soft_error/model.py rename to examples/soft_error/model.py diff --git a/ott/examples/soft_error/train.py b/examples/soft_error/train.py similarity index 100% rename from ott/examples/soft_error/train.py rename to examples/soft_error/train.py diff --git a/ott/__init__.py b/ott/__init__.py index df1438bef..8015c263b 100644 --- a/ott/__init__.py +++ b/ott/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. """OTT library.""" -from . import core, geometry, tools +from . import geometry, initializers, math, problems, solvers, tools, utils from ._version import __version__ diff --git a/ott/_version.py b/ott/_version.py index 23af74edc..689bed779 100644 --- a/ott/_version.py +++ b/ott/_version.py @@ -1,13 +1,11 @@ -from packaging.version import parse - try: from importlib_metadata import PackageNotFoundError, version # Python < 3.8 except ImportError: from importlib.metadata import PackageNotFoundError, version try: - __version__ = str(parse(version("ott-jax"))) + __version__ = version("ott-jax") except PackageNotFoundError: __version__ = "" -del parse, version, PackageNotFoundError +del version, PackageNotFoundError diff --git a/ott/core/__init__.py b/ott/core/__init__.py deleted file mode 100644 index 4a39d58fc..000000000 --- a/ott/core/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""OTT core libraries: the engine behind most computations happening in OTT.""" - -# pytype: disable=import-error # kwargs-checking -from . import ( - anderson, - bar_problems, - continuous_barycenter, - dataclasses, - decomposition, - discrete_barycenter, - gromov_wasserstein, - gw_barycenter, - implicit_differentiation, - initializers, - initializers_lr, - linear_problems, - momentum, - potentials, - quad_initializers, - quad_problems, - sinkhorn, - sinkhorn_lr, -) - -# from . import neuraldual -from .implicit_differentiation import ImplicitDiff -from .linear_problems import LinearProblem -from .sinkhorn import Sinkhorn -from .sinkhorn_lr import LRSinkhorn - -# pytype: enable=import-error # kwargs-checking diff --git a/ott/core/_math_utils.py b/ott/core/_math_utils.py deleted file mode 100644 index 7a269a77b..000000000 --- a/ott/core/_math_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional, Union - -import jax.experimental.sparse as jesp -import jax.numpy as jnp - -__all__ = ["safe_log", "kl", "js"] - -Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO] - - -def safe_log(x: jnp.ndarray, *, eps: Optional[float] = None) -> jnp.ndarray: - if eps is None: - eps = jnp.finfo(x.dtype).tiny - return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) - - -def kl(p: jnp.ndarray, q: jnp.ndarray) -> float: - """Kullback-Leilbler divergence.""" - return jnp.vdot(p, (safe_log(p) - safe_log(q))) - - -def js(p: jnp.ndarray, q: jnp.ndarray, *, c: float = 0.5) -> float: - """Jensen-Shannon divergence.""" - return c * (kl(p, q) + kl(q, p)) - - -def sparse_scale(c: float, mat: Sparse_t) -> Sparse_t: - """Scale a sparse matrix by a constant.""" - if isinstance(mat, jesp.BCOO): - # most feature complete, defer to original impl. - return c * mat - (data, *children), aux_data = mat.tree_flatten() - return type(mat).tree_unflatten(aux_data, [c * data] + children) diff --git a/ott/core/momentum.py b/ott/core/momentum.py deleted file mode 100644 index 380df4fc1..000000000 --- a/ott/core/momentum.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Functions related to momemtum.""" - -from typing import TYPE_CHECKING - -import jax -import jax.numpy as jnp - -from ott.core import dataclasses - -if TYPE_CHECKING: - from ott.core import sinkhorn - - -@dataclasses.register_pytree_node -class Momentum: - """Momentum for Sinkhorn updates, either constant or adaptive.""" - - start: int = 0 - error_threshold: float = jnp.inf - value: float = 1.0 - inner_iterations: int = 1 - - def weight(self, state: "sinkhorn.SinkhornState", iteration: int) -> float: - """Compute momentum term if needed, using previously seen errors.""" - if self.start == 0: - return self.value - idx = self.start // self.inner_iterations - - weight = jax.lax.cond( - jnp.logical_and( - iteration >= self.start, - state.errors[idx - 1, -1] < self.error_threshold - ), lambda state: self.lehmann(state), lambda state: self.value, state - ) - return weight - - def lehmann(self, state: "sinkhorn.SinkhornState") -> float: - """Momentum formula :cite:`lehmann:21`, eq. 5.""" - idx = self.start // self.inner_iterations - error_ratio = jnp.minimum( - state.errors[idx - 1, -1] / state.errors[idx - 2, -1], 0.99 - ) - power = 1.0 / self.inner_iterations - return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power)) - - def __call__( - self, - weight: float, - value: jnp.ndarray, - new_value: jnp.ndarray, - lse_mode: bool = True - ) -> jnp.ndarray: - if lse_mode: - value = jnp.where(jnp.isfinite(value), value, 0.0) - return (1.0 - weight) * value + weight * new_value - else: - value = jnp.where(value > 0.0, value, 1.0) - return value ** (1.0 - weight) * new_value ** weight diff --git a/ott/core/problems.py b/ott/core/problems.py deleted file mode 100644 index e60b4deb1..000000000 --- a/ott/core/problems.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2022 The OTT Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utility to make a problem class from arrays.""" -from typing import Any, Optional, Union - -import jax.numpy as jnp -import numpy as np - -from ott.core import linear_problems, quad_problems -from ott.geometry import geometry, pointcloud - - -def make( - *args: Union[jnp.ndarray, geometry.Geometry, linear_problems.LinearProblem, - quad_problems.QuadraticProblem], - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None, - tau_a: float = 1.0, - tau_b: float = 1.0, - objective: Optional[str] = None, - gw_unbalanced_correction: Optional[bool] = True, - fused_penalty: Optional[float] = None, - scale_cost: Optional[Union[bool, float, str]] = False, - **kwargs: Any, -): - """Make a problem from arrays, assuming PointCloud geometries.""" - if isinstance(args[0], (jnp.ndarray, np.ndarray)): - x = args[0] - y = args[1] if len(args) > 1 else args[0] - if ((objective == 'linear') or - (objective is None and x.shape[1] == y.shape[1])): # noqa: E129 - geom_xy = pointcloud.PointCloud(x, y, **kwargs) - return linear_problems.LinearProblem( - geom_xy, a=a, b=b, tau_a=tau_a, tau_b=tau_b - ) - elif ((objective == 'quadratic') or - (objective is None and x.shape[1] != y.shape[1])): - geom_xx = pointcloud.PointCloud(x, x, **kwargs) - geom_yy = pointcloud.PointCloud(y, y, **kwargs) - return quad_problems.QuadraticProblem( - geom_xx=geom_xx, - geom_yy=geom_yy, - geom_xy=None, - scale_cost=scale_cost, - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, - gw_unbalanced_correction=gw_unbalanced_correction - ) - elif objective == 'fused': - geom_xx = pointcloud.PointCloud(x, x, **kwargs) - geom_yy = pointcloud.PointCloud(y, y, **kwargs) - geom_xy = pointcloud.PointCloud(x, y, **kwargs) - return quad_problems.QuadraticProblem( - geom_xx=geom_xx, - geom_yy=geom_yy, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - scale_cost=scale_cost, - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, - gw_unbalanced_correction=gw_unbalanced_correction - ) - else: - raise ValueError(f'Unknown transport problem `{objective}`') - elif isinstance(args[0], geometry.Geometry): - if len(args) == 1: - return linear_problems.LinearProblem( - *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b - ) - return quad_problems.QuadraticProblem( - *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b, scale_cost=scale_cost - ) - elif isinstance( - args[0], (linear_problems.LinearProblem, quad_problems.QuadraticProblem) - ): - return args[0] - else: - raise ValueError('Cannot instantiate a transport problem.') diff --git a/ott/geometry/__init__.py b/ott/geometry/__init__.py index 38b4ada53..e57c7b3ec 100644 --- a/ott/geometry/__init__.py +++ b/ott/geometry/__init__.py @@ -12,9 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. """OTT ground geometries: Classes and cost functions to instantiate them.""" -from . import costs, low_rank, ops -from .epsilon_scheduler import Epsilon -from .geometry import Geometry -from .graph import Graph -from .grid import Grid -from .pointcloud import PointCloud +from . import costs, epsilon_scheduler, geometry, graph, grid, pointcloud, segment diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index cd59fc0c6..94c06112a 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -17,13 +17,17 @@ import abc import functools import math -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Tuple, Union import jax import jax.numpy as jnp -from ott.core import fixed_point_loop -from ott.geometry import matrix_square_root +from ott.math import fixed_point_loop, matrix_square_root + +__all__ = [ + "PNorm", "SqPNorm", "Euclidean", "SqEuclidean", "Cosine", "Bures", + "UnbalancedBures" +] @jax.tree_util.register_pytree_node_class @@ -32,9 +36,11 @@ class CostFn(abc.ABC): Cost functions evaluate a function on a pair of inputs. For convenience, that function is split into two norms -- evaluated on each input separately -- - followed by a pairwise cost that involves both inputs, as in + followed by a pairwise cost that involves both inputs, as in: + + .. math:: - c(x,y) = norm(x) + norm(y) + pairwise(x,y) + c(x,y) = norm(x) + norm(y) + pairwise(x,y) If the norm function is not implemented, that value is handled as a 0. """ @@ -43,17 +49,34 @@ class CostFn(abc.ABC): norm: Optional[Callable[[jnp.ndarray], Union[float, jnp.ndarray]]] = None @abc.abstractmethod - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: pass - def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: - raise NotImplementedError("Barycenter not yet implemented for this cost.") + def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: + """Barycentric projection. + + Args: + weights: Weights of the points. + xs: Points to project. + + Returns: + The barycentric projection. + """ + raise NotImplementedError("Barycenter is not yet implemented.") @classmethod - def padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jnp.ndarray: + """Create a padding vector for easier jitting. + + Args: + dim: Dimensionality of the data. + + Returns: + The padding vector. + """ return jnp.zeros((1, dim)) - def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: cost = self.pairwise(x, y) if self.norm is None: return cost @@ -98,6 +121,7 @@ class TICost(CostFn): real-values, to be used as: .. math:: + c(x,y) = h(z), z := x-y. If that cost function is used to form an Entropic map using the @@ -108,14 +132,14 @@ class TICost(CostFn): @abc.abstractmethod def h(self, z: jnp.ndarray) -> float: - """RBF function acting on difference of `x-y` to ouput cost.""" + """TI function acting on difference of :math:`x-y` to output cost.""" def h_legendre(self, z: jnp.ndarray) -> float: - """Legendre transform of RBF function `h` (when latter is convex).""" + """Legendre transform of :func:`h` when it is convex.""" raise NotImplementedError("`h_legendre` not implemented.") - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: - """Compute cost as evaluation of :func:`h` on `x-y`.""" + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) @@ -125,12 +149,16 @@ class SqPNorm(TICost): For details on the derivation of the Legendre transform of the norm, see e.g. the reference :cite:`boyd:04`, p.93/94. + + Args: + p: Power of the p-norm. """ def __init__(self, p: float): + super().__init__() assert p >= 1.0, "p parameter in sq. p-norm should be >= 1.0" self.p = p - self.q = 1. / (1 - 1 / self.p) if p > 1.0 else 'inf' + self.q = 1. / (1. - 1. / self.p) if p > 1.0 else "inf" def h(self, z: jnp.ndarray) -> float: return 0.5 * jnp.linalg.norm(z, self.p) ** 2 @@ -149,12 +177,18 @@ def tree_unflatten(cls, aux_data, children): @jax.tree_util.register_pytree_node_class class PNorm(TICost): - """p-norm (to the power p) of the difference of two vectors.""" + """p-norm (to the power p) of the difference of two vectors. + + Args: + p: Power of the p-norm. + """ def __init__(self, p: float): + super().__init__() assert p >= 1.0, "p parameter in p-norm should be >= 1.0" self.p = p - self.q = 1. / (1 - 1 / self.p) + # TODO(marcocuturi): fix case when `p=1` + self.q = 1. / (1. - 1. / self.p) if p > 1. else "inf" def h(self, z: jnp.ndarray) -> float: return jnp.linalg.norm(z, self.p) ** self.p / self.p @@ -175,12 +209,13 @@ def tree_unflatten(cls, aux_data, children): class Euclidean(CostFn): """Euclidean distance. - Note that the Euclidean distance is not cast as a `TICost`, because this - would correspond to `h = jnp.linalg.norm`, whose gradient is not invertible, + Note that the Euclidean distance is not cast as a + :class:`~ott.geometry.costs.TICost`, since this would correspond to :math:`h` + being :func:`jax.numpy.linalg.norm`, whose gradient is not invertible, because the function is not strictly convex (it is linear on rays). """ - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute Euclidean norm.""" return jnp.linalg.norm(x - y) @@ -193,7 +228,7 @@ def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: """Compute squared Euclidean norm for vector.""" return jnp.sum(x ** 2, axis=-1) - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute minus twice the dot-product between vectors.""" return -2. * jnp.vdot(x, y) @@ -210,13 +245,17 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: @jax.tree_util.register_pytree_node_class class Cosine(CostFn): - """Cosine distance CostFn.""" + """Cosine distance cost function. + + Args: + ridge: Ridge regularization. + """ def __init__(self, ridge: float = 1e-8): super().__init__() self._ridge = ridge - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Cosine distance between vectors, denominator regularized with ridge.""" ridge = self._ridge x_norm = jnp.linalg.norm(x, axis=-1) @@ -227,27 +266,32 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: return jnp.clip(cosine_distance, 0., 2.) @classmethod - def padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jnp.ndarray: return jnp.ones((1, dim)) @jax.tree_util.register_pytree_node_class class Bures(CostFn): - """Bures distance between a pair of (mean, cov matrix) raveled as vectors.""" + """Bures distance between a pair of (mean, cov matrix) raveled as vectors. + + Args: + dimension: Dimensionality of the data. + kwargs: Keyword arguments for :func:`ott.math.matrix_square_root.sqrtm`. + """ def __init__(self, dimension: int, **kwargs: Any): super().__init__() self._dimension = dimension self._sqrtm_kw = kwargs - def norm(self, x: jnp.ndarray): + def norm(self, x: jnp.ndarray) -> jnp.ndarray: """Compute norm of Gaussian, sq. 2-norm of mean + trace of covariance.""" mean, cov = x_to_means_and_covs(x, self._dimension) norm = jnp.sum(mean ** 2, axis=-1) norm += jnp.trace(cov, axis1=-2, axis2=-1) return norm - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray): + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Compute - 2 x Bures dot-product.""" mean_x, cov_x = x_to_means_and_covs(x, self._dimension) mean_y, cov_y = x_to_means_and_covs(y, self._dimension) @@ -259,17 +303,6 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray): )[0] return -2 * (mean_dot_prod + jnp.trace(sq__sq_x_y_sq_x, axis1=-2, axis2=-1)) - @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) - def scale_covariances(self, cov_sqrt, cov_i, lambda_i): - """Iterate update needed to compute barycenter of covariances.""" - return lambda_i * matrix_square_root.sqrtm_only( - jnp.matmul(jnp.matmul(cov_sqrt, cov_i), cov_sqrt) - ) - - def relative_diff(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - """Monitor change in two successive estimates of matrices.""" - return jnp.sum(jnp.square(x - y)) / jnp.prod(jnp.array(x.shape)) - def covariance_fixpoint_iter( self, covs: jnp.ndarray, @@ -278,27 +311,40 @@ def covariance_fixpoint_iter( ) -> jnp.ndarray: """Iterate fix-point updates to compute barycenter of Gaussians.""" - def cond_fn(iteration, constants, state): + @functools.partial(jax.vmap, in_axes=[None, 0, 0]) + def scale_covariances( + cov_sqrt: jnp.ndarray, cov_i: jnp.ndarray, lambda_i: jnp.ndarray + ) -> jnp.ndarray: + """Iterate update needed to compute barycenter of covariances.""" + return lambda_i * matrix_square_root.sqrtm_only( + (cov_sqrt @ cov_i) @ cov_sqrt + ) + + def cond_fn(iteration: int, constants: Tuple[Any, ...], state) -> bool: + del iteration, constants _, diff = state - return diff > jnp.array(rtol) + return diff > rtol - def body_fn(iteration, constants, state, compute_error): - del compute_error + def body_fn( + iteration: int, constants: Tuple[Any, ...], + state: Tuple[jnp.ndarray, float], compute_error: bool + ) -> Tuple[jnp.ndarray, float]: + del iteration, constants, compute_error cov, _ = state cov_sqrt, cov_inv_sqrt, _ = matrix_square_root.sqrtm(cov) scaled_cov = jnp.linalg.matrix_power( - jnp.sum(self.scale_covariances(cov_sqrt, covs, lambdas), axis=0), 2 + jnp.sum(scale_covariances(cov_sqrt, covs, lambdas), axis=0), 2 ) - next_cov = jnp.matmul(jnp.matmul(cov_inv_sqrt, scaled_cov), cov_inv_sqrt) - diff = self.relative_diff(next_cov, cov) + next_cov = (cov_inv_sqrt @ scaled_cov) @ cov_inv_sqrt + diff = jnp.sum((next_cov - cov) ** 2) / jnp.prod(jnp.array(cov.shape)) return next_cov, diff - def init_state(): + def init_state() -> Tuple[jnp.ndarray, float]: cov_init = jnp.eye(self._dimension) diff = jnp.inf return cov_init, diff - state = fixed_point_loop.fixpoint_iter( + cov, _ = fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, min_iterations=10, @@ -307,8 +353,6 @@ def init_state(): constants=(), state=init_state() ) - - cov, _ = state return cov def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: @@ -336,7 +380,7 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: return barycenter @classmethod - def padder(cls, dim: int) -> jnp.ndarray: + def _padder(cls, dim: int) -> jnp.ndarray: """Pad with concatenated zero means and \ raveled identity covariance matrix.""" dimension = int((-1 + math.sqrt(1 + 4 * dim)) / 2) @@ -356,37 +400,61 @@ def tree_unflatten(cls, aux_data, children): @jax.tree_util.register_pytree_node_class class UnbalancedBures(CostFn): - """Regularized/unbalanced Bures dist between two triplets of (mass,mean,cov). + """Unbalanced Bures distance between two triplets of `(mass, mean, cov)`. - This cost implements the value defined in :cite:`janati:20`, eq. 37, 39, 40. - We follow their notations. It is assumed inputs are given as - triplets (mass, mean, covariance) raveled as vectors, in that order. + This cost uses the notation defined in :cite:`janati:20`, eq. 37, 39, 40. + + Args: + dimension: Dimensionality of the data. + sigma: Entropic regularization. + gamma: KL-divergence regularization for the marginals. + kwargs: Keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`. """ def __init__( self, dimension: int, - gamma: float = 1.0, + *, sigma: float = 1.0, + gamma: float = 1.0, **kwargs: Any, ): super().__init__() self._dimension = dimension + self._sigma = sigma self._gamma = gamma - self._sigma2 = sigma ** 2 self._sqrtm_kw = kwargs - def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: - """Compute norm of Gaussian for unbalanced Bures.""" - return self._gamma * x[0] + def norm(self, x: jnp.ndarray) -> jnp.ndarray: + """Compute norm of Gaussian for unbalanced Bures. + + Args: + x: Array of shape ``[n_points + n_points + n_dim ** 2,]``, potentially + batched, corresponding to the raveled mass, means and the covariance + matrix. + + Returns: + The norm, array of shape ``[]`` or ``[batch,]`` in the batched case. + """ + return self._gamma * x[..., 0] + + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Compute dot-product for unbalanced Bures. + + Args: + x: Array of shape ``[n_points + n_points + n_dim ** 2,]`` + corresponding to the raveled mass, means and the covariance matrix. + y: Array of shape ``[n_points + n_points + n_dim ** 2,]`` + corresponding to the raveled mass, means and the covariance matrix. - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: - """Compute dot-product for unbalanced Bures.""" + Returns: + The cost. + """ # Sets a few constants gam = self._gamma - sig2 = self._sigma2 - lam = sig2 + gam / 2 - tau = gam / (2 * lam) + sig2 = self._sigma ** 2 + lam = sig2 + gam / 2.0 + tau = gam / (2.0 * lam) # Extracts mass, mean vector, covariance matrices mass_x, mass_y = x[0], y[0] @@ -412,51 +480,51 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: sldet_c, ldet_c = jnp.linalg.slogdet(c_mat) sldet_t_ab, ldet_t_ab = jnp.linalg.slogdet(tilde_a_b) sldet_ab, ldet_ab = jnp.linalg.slogdet(jnp.matmul(cov_x, cov_y)) - sldet_c_ab, ldet_c_ab = jnp.linalg.slogdet(c_mat - 2 * tilde_a_b / gam) + sldet_c_ab, ldet_c_ab = jnp.linalg.slogdet(c_mat - 2.0 * tilde_a_b / gam) # Gathers all these results to compute log total mass of transport log_m_pi = (0.5 * self._dimension * sig2 / (gam + sig2)) * jnp.log(sig2) - - log_m_pi += (1 / (tau + 1)) * ( + log_m_pi += (1.0 / (tau + 1.0)) * ( jnp.log(mass_x) + jnp.log(mass_y) + ldet_c + 0.5 * (tau * ldet_t_ab - ldet_ab) ) - log_m_pi += -jnp.sum( diff_means * jnp.linalg.solve(cov_x + cov_y + lam * iden, diff_means) - ) / (2 * (tau + 1)) - + ) / (2.0 * (tau + 1.0)) log_m_pi += -0.5 * ldet_c_ab - # If all logdet signs are 1, output value, nan otherwise. - return jnp.where( - sldet_c == 1 and sldet_c_ab == 1 and sldet_ab == 1 and sldet_t_ab == 1, - 2 * sig2 * mass_x * mass_y - 2 * (sig2 + gam) * jnp.exp(log_m_pi), - jnp.nan + # if all logdet signs are 1, output value, nan otherwise + pos_signs = (sldet_c + sldet_c_ab + sldet_t_ab + sldet_t_ab) == 4 + + return jax.lax.cond( + pos_signs, lambda: 2 * sig2 * mass_x * mass_y - 2 * + (sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan ) def tree_flatten(self): - return (), (self._dimension, self._gamma, self._sigma2, self._sqrtm_kw) + return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw) @classmethod def tree_unflatten(cls, aux_data, children): del children - return cls(aux_data[0], aux_data[1], aux_data[2], **aux_data[3]) + dim, sigma, gamma, kwargs = aux_data + return cls(dim, sigma=sigma, gamma=gamma, **kwargs) -def x_to_means_and_covs(x: jnp.ndarray, dimension: jnp.ndarray) -> jnp.ndarray: +def x_to_means_and_covs(x: jnp.ndarray, + dimension: int) -> Tuple[jnp.ndarray, jnp.ndarray]: """Extract means and covariance matrices of Gaussians from raveled vector. Args: x: [num_gaussians, dimension, (1 + dimension)] array of concatenated means - and covariances (raveled) dimension: the dimension of the Gaussians. + and covariances (raveled) dimension: the dimension of the Gaussians. Returns: means: [num_gaussians, dimension] array that holds the means. covariances: [num_gaussians, dimension] array that holds the covariances. """ x = jnp.atleast_2d(x) - means = x[:, 0:dimension] + means = x[:, :dimension] covariances = jnp.reshape( x[:, dimension:dimension + dimension ** 2], (-1, dimension, dimension) ) diff --git a/ott/geometry/epsilon_scheduler.py b/ott/geometry/epsilon_scheduler.py index 1e0fc30e7..284fa29f7 100644 --- a/ott/geometry/epsilon_scheduler.py +++ b/ott/geometry/epsilon_scheduler.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +__all__ = ["Epsilon"] + @jax.tree_util.register_pytree_node_class class Epsilon: @@ -42,6 +44,7 @@ class Epsilon: decay: geometric decay factor, smaller than 1. """ + # TODO(michalk8): directly use the defaults instead of `None` def __init__( self, target: Optional[float] = None, diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py index dc59740b7..d76282e51 100644 --- a/ott/geometry/geometry.py +++ b/ott/geometry/geometry.py @@ -25,7 +25,10 @@ import jax.scipy as jsp from typing_extensions import Literal -from ott.geometry import epsilon_scheduler, ops +from ott.geometry import epsilon_scheduler +from ott.math import utils + +__all__ = ["Geometry", "is_linear", "is_affine"] @jax.tree_util.register_pytree_node_class @@ -34,7 +37,7 @@ class Geometry: Optimal transport problems are intrinsically geometric: they compute an optimal way to transport mass from one configuration onto another. To define - what is meant by optimality of a transport requires defining a cost, of moving + what is meant by optimality of transport requires defining a cost, of moving mass from one among several sources, towards one out of multiple targets. These sources and targets can be provided as points in vectors spaces, grids, or more generally exclusively described through a (dissimilarity) cost matrix, @@ -62,11 +65,11 @@ class Geometry: 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. If `True`, use 'mean'. - tgt_mask: Mask specifying valid rows when computing some statistics of + src_mask: Mask specifying valid rows when computing some statistics of :attr:`cost_matrix`, see :attr:`src_mask`. tgt_mask: Mask specifying valid columns when computing some statistics of :attr:`cost_matrix`, see :attr:`tgt_mask`. - kwargs: additional kwargs to epsilon scheduler. + kwargs: additional kwargs for epsilon scheduler. Note: When defining a ``Geometry`` through a ``cost_matrix``, it is important to @@ -410,12 +413,12 @@ def _softmax( if vec is not None: if axis == 0: vec = vec.reshape((-1, 1)) - lse_output = ops.logsumexp( + lse_output = utils.logsumexp( self._center(f, g) / eps, b=vec, axis=axis, return_sign=True ) return eps * lse_output[0], lse_output[1] else: - lse_output = ops.logsumexp( + lse_output = utils.logsumexp( self._center(f, g) / eps, axis=axis, return_sign=False ) return eps * lse_output, jnp.array([1.0]) @@ -639,7 +642,7 @@ def to_LRCGeometry( Useful when this geometry is used in the linear term of fused GW. Returns: - The low-rank geometry. + Low-rank geometry. """ from ott.geometry import low_rank @@ -897,4 +900,4 @@ def is_affine(fn) -> bool: def is_linear(fn) -> bool: """Test heuristically if a function is linear.""" - return fn(0.0) == 0.0 and is_affine(fn) + return jnp.logical_and(fn(0.0) == 0.0, is_affine(fn)) diff --git a/ott/geometry/graph.py b/ott/geometry/graph.py index b7e6e9024..dd361f67a 100644 --- a/ott/geometry/graph.py +++ b/ott/geometry/graph.py @@ -5,9 +5,9 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.core import _math_utils as mu -from ott.core import decomposition, fixed_point_loop from ott.geometry import geometry +from ott.math import decomposition, fixed_point_loop +from ott.math import utils as mu __all__ = ["Graph"] @@ -192,7 +192,7 @@ def laplacian(self) -> Union[jnp.ndarray, Sparse_t]: # in the sparse case, we don't sum duplicates here because # we need to know `nnz` a priori for JIT (could be exposed in `__init__`) - # instead, `ott.core.decomposition._jax_sparse_to_scipy` handles it on host + # instead, `ott.math.decomposition._jax_sparse_to_scipy` handles it on host return D - self.graph @property diff --git a/ott/geometry/grid.py b/ott/geometry/grid.py index 25e9809eb..3cc94c77b 100644 --- a/ott/geometry/grid.py +++ b/ott/geometry/grid.py @@ -21,19 +21,22 @@ import jax.numpy as jnp import numpy as np -from ott.geometry import costs, geometry, ops, pointcloud +from ott.geometry import costs, geometry, pointcloud +from ott.math import utils + +__all__ = ["Grid"] @jax.tree_util.register_pytree_node_class class Grid(geometry.Geometry): - r"""Class describing the geometry of points taken in a cartestian product. + r"""Class describing the geometry of points taken in a Cartesian product. This class implements a geometry in which probability measures are supported on a :math:`d`-dimensional cartesian grid, a cartesian product of :math:`d` lists of values, each list being itself of size :math:`n_i`. The transportation cost between points in the grid is assumed to be separable, - namely a sum of coordinate-wise cost functions, as in + namely a sum of coordinate-wise cost functions, as in: .. math:: @@ -52,7 +55,7 @@ class Grid(geometry.Geometry): Args: x : list of arrays of varying sizes, describing the locations of the grid. - Locations are provided as a list of jnp.ndarrays, that is :math:`d` + Locations are provided as a list of arrays, that is :math:`d` vectors of (possibly varying) size :math:`n_i`. The resulting grid is the Cartesian product of these vectors. grid_size: tuple of integers describing grid sizes, namely @@ -201,14 +204,13 @@ def _apply_lse_kernel_one_dimension(self, dimension, f, g, eps, vec=None): if vec is not None: vec = jnp.transpose(vec, indices) - softmax_res, softmax_sgn = ops.logsumexp( + softmax_res, softmax_sgn = utils.logsumexp( centered_cost, b=vec, axis=1, return_sign=True ) return eps * jnp.transpose(softmax_res, indices), jnp.transpose(softmax_sgn, indices) - else: - softmax_res = ops.logsumexp(centered_cost, axis=1) - return eps * jnp.transpose(softmax_res, indices), None + softmax_res = utils.logsumexp(centered_cost, axis=1) + return eps * jnp.transpose(softmax_res, indices), None def _apply_cost_to_vec( self, vec: jnp.ndarray, axis: int = 0, fn=None diff --git a/ott/geometry/low_rank.py b/ott/geometry/low_rank.py index bf8bfa2a3..eeffa0733 100644 --- a/ott/geometry/low_rank.py +++ b/ott/geometry/low_rank.py @@ -22,6 +22,8 @@ from ott.geometry import geometry +__all__ = ["LRCGeometry"] + @jax.tree_util.register_pytree_node_class class LRCGeometry(geometry.Geometry): @@ -45,10 +47,11 @@ class LRCGeometry(geometry.Geometry): ``cost_matrix /= scale_cost``. If `True`, use 'mean'. batch_size: optional size of the batch to compute online (without instantiating the matrix) the scale factor ``scale_cost`` of the - ``cost_matrix`` when ``scale_cost='max_cost'``. If set to ``None``, the - batch size is set to 1024 or to the largest number of samples between - ``cost_1`` and ``cost_2`` if smaller than `1024`. - kwargs: Additional kwargs to :class:`~ott.geometry.geometry.Geometry`. + :attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If `None`, the batch + size is set to `1024` or to the largest number of samples between + :attr:`cost_1` and :attr:`cost_2` if smaller than `1024`. + kwargs: Additional keyword arguments for + :class:`~ott.geometry.geometry.Geometry`. """ def __init__( @@ -187,10 +190,10 @@ def compute_max_cost(self) -> float: Three cases are taken into account: - If the number of samples of ``cost_1`` and ``cost_2`` are both smaller - than 1024 and if ``batch_size`` is ``None``, the ``cost_matrix`` is + than 1024 and if ``batch_size`` is `None`, the ``cost_matrix`` is computed to obtain its maximum entry. - If one of the number of samples of ``cost_1`` or ``cost_2`` is larger - than 1024 and if ``batch_size`` is ``None``, then the maximum of the + than 1024 and if ``batch_size`` is `None`, then the maximum of the cost matrix is calculated by batch. The batches are created on the longest axis of the cost matrix and their size is fixed to 1024. - If ``batch_size`` is provided as a float, then the maximum of the cost diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 2f3fde9fb..57dedd3a4 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -21,19 +21,22 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.geometry import costs, geometry, low_rank, ops +from ott.geometry import costs, geometry, low_rank +from ott.math import utils as mu + +__all__ = ["PointCloud"] @jax.tree_util.register_pytree_node_class class PointCloud(geometry.Geometry): - """Defines geometry for 2 point clouds (possibly 1 vs itself) using CostFn. + """Defines geometry for 2 point clouds (possibly 1 vs itself). Creates a geometry, specifying a cost function passed as CostFn type object. - When the number of points is large, setting the `online` flag to `True` - implies that cost and kernel matrices used to update potentials or scalings + When the number of points is large, setting the ``batch_size`` flag implies + that cost and kernel matrices used to update potentials or scalings will be recomputed on the fly, rather than stored in memory. More precisely, - when setting `online`, the cost function will be partially cached by storing - norm values for each point in both point clouds, but the pairwise cost + when setting ``batch_size``, the cost function will be partially cached by + storing norm values for each point in both point clouds, but the pairwise cost function evaluations won't be. Args: @@ -596,7 +599,7 @@ def to_LRCGeometry( Useful when this geometry is used in the linear term of fused GW. kwargs: Keyword arguments, such as ``rank``, to :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry` used when - the point cloud does not squared Euclidean cost. + the point cloud does not have squared Euclidean cost. Returns: Returns the unmodified point cloud if :math:`n m \ge (n + m) d`, where @@ -730,7 +733,7 @@ def _apply_lse_kernel_xy( x, y, norm_x, norm_y, f, g, eps, vec, cost_fn, scale_cost ): c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost) - return ops.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1) + return mu.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1) def _transport_from_potentials_xy( diff --git a/ott/core/segment.py b/ott/geometry/segment.py similarity index 99% rename from ott/core/segment.py rename to ott/geometry/segment.py index 82c4c98d3..1051f348a 100644 --- a/ott/core/segment.py +++ b/ott/geometry/segment.py @@ -14,7 +14,9 @@ from typing import Callable, Optional, Tuple import jax -from jax import numpy as jnp +import jax.numpy as jnp + +__all__ = ["segment_point_cloud"] def segment_point_cloud( diff --git a/ott/initializers/__init__.py b/ott/initializers/__init__.py new file mode 100644 index 000000000..15cfac006 --- /dev/null +++ b/ott/initializers/__init__.py @@ -0,0 +1 @@ +from . import linear, nn, quadratic diff --git a/ott/initializers/linear/__init__.py b/ott/initializers/linear/__init__.py new file mode 100644 index 000000000..1ce1a00cd --- /dev/null +++ b/ott/initializers/linear/__init__.py @@ -0,0 +1 @@ +from . import initializers, initializers_lr diff --git a/ott/core/initializers.py b/ott/initializers/linear/initializers.py similarity index 52% rename from ott/core/initializers.py rename to ott/initializers/linear/initializers.py index 700aa47a5..65b91cc52 100644 --- a/ott/core/initializers.py +++ b/ott/initializers/linear/initializers.py @@ -12,44 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn initializers.""" -import functools -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Sequence, Tuple +import abc +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp -import optax -from flax import linen as nn -from flax.training import train_state -from ott.core import linear_problems, sinkhorn -from ott.geometry import geometry, pointcloud +from ott.geometry import pointcloud -__all__ = [ - "DefaultInitializer", "GaussianInitializer", "SortingInitializer", - "MetaInitializer" -] +if TYPE_CHECKING: + from ott.problems.linear import linear_problem + +__all__ = ["DefaultInitializer", "GaussianInitializer", "SortingInitializer"] @jax.tree_util.register_pytree_node_class -class SinkhornInitializer(ABC): +class SinkhornInitializer(abc.ABC): """Base class for Sinkhorn initializers.""" - @abstractmethod + @abc.abstractmethod def init_dual_a( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialization for Sinkhorn potential/scaling f_u.""" - @abstractmethod + @abc.abstractmethod def init_dual_b( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialization for Sinkhorn potential/scaling g_v.""" def __call__( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: 'linear_problem.LinearProblem', a: Optional[jnp.ndarray], b: Optional[jnp.ndarray], lse_mode: bool, @@ -101,7 +96,7 @@ class DefaultInitializer(SinkhornInitializer): """Default initialization of Sinkhorn dual potentials/primal scalings.""" def init_dual_a( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling f_u. @@ -117,7 +112,7 @@ def init_dual_a( return init_dual_a def init_dual_b( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: """Initialize Sinkhorn potential/scaling g_v. @@ -145,7 +140,7 @@ class GaussianInitializer(DefaultInitializer): def init_dual_a( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: 'linear_problem.LinearProblem', lse_mode: bool, ) -> jnp.ndarray: """Gaussian initialization function. @@ -247,7 +242,7 @@ def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool: def init_dual_a( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: 'linear_problem.LinearProblem', lse_mode: bool, init_f: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: @@ -285,212 +280,6 @@ def init_dual_a( return f_u -@jax.tree_util.register_pytree_node_class -class MetaInitializer(DefaultInitializer): - """Meta OT Initializer with a fixed geometry :cite:`amos:22`. - - This initializer consists of a predictive model that outputs the - :math:`f` duals to solve the entropy-regularized OT problem given - input probability weights ``a`` and ``b``, and a given (assumed to be - fixed) geometry ``geom``. - The model's parameters are learned using a training set of OT - instances (multiple pairs of probability weights), that assume the - **same** geometry ``geom`` is used throughout, both for training and - evaluation. The meta model defaults to the MLP in - :class:`~ott.core.initializers.MetaMLP` and, with batched problem - instances passed into :meth:`update`. - - **Sample training usage.** The following code shows a simple - example of using ``update`` to train the model, where - ``a`` and ``b`` are the weights of the measures and - ``geom`` is the fixed geometry. - - .. code-block:: python - - meta_initializer = init_lib.MetaInitializer(geom=geom) - while training(): - a, b = sample_batch() - loss, init_f, meta_initializer.state = meta_initializer.update( - meta_initializer.state, a=a, b=b) - - Args: - geom: The fixed geometry of the problem instances. - meta_model: The model to predict the potential :math:`f` from the measures. - opt: The optimizer to update the parameters. - rng: The PRNG key to use for initializing the model. - state: The training state of the model to start from. - """ - - def __init__( - self, - geom: geometry.Geometry, - meta_model: Optional[nn.Module] = None, - opt: optax.GradientTransformation = optax.adam(learning_rate=1e-3), - rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), - state: Optional[train_state.TrainState] = None - ): - self.geom = geom - self.dtype = geom.x.dtype - self.opt = opt - self.rng = rng - - na, nb = geom.shape - self.meta_model = MetaMLP( - potential_size=na - ) if meta_model is None else meta_model - - if state is None: - # Initialize the model's training state. - a_placeholder = jnp.zeros(na, dtype=self.dtype) - b_placeholder = jnp.zeros(nb, dtype=self.dtype) - params = self.meta_model.init(rng, a_placeholder, b_placeholder)['params'] - self.state = train_state.TrainState.create( - apply_fn=self.meta_model.apply, params=params, tx=opt - ) - else: - self.state = state - - self.update_impl = self._get_update_fn() - - def update( - self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: - r"""Update the meta model with the dual objective. - - The goal is for the model to match the optimal duals, i.e., - :math:`\hat f_\theta \approx f^\star`. - This can be done by training the predictions of :math:`\hat f_\theta` - to optimize the dual objective, which :math:`f^\star` also optimizes for. - The overall learning setup can thus be written as: - - .. math:: - \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; - J(\hat f_\theta(a, b); \alpha, \beta), - - where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`, - :math:`\mathcal{D}` is a meta distribution of optimal transport problems, - - .. math:: - -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - - \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\}\right\rangle - - is the entropic dual objective, - and :math:`K_{i,j} := -C_{i,j}/\varepsilon` is the *Gibbs kernel*. - - Args: - state: Optimizer state of the meta model. - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. - - Returns: - The training loss, :math:`f`, and updated state. - """ - return self.update_impl(state, a, b) - - def init_dual_a( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool - ) -> jnp.ndarray: - # Detect if the problem is batched. - assert ot_prob.a.ndim in (1, 2) and ot_prob.b.ndim in (1, 2) - vmap_a_val = 0 if ot_prob.a.ndim == 2 else None - vmap_b_val = 0 if ot_prob.b.ndim == 2 else None - - if vmap_a_val is not None or vmap_b_val is not None: - compute_f_maybe_batch = jax.vmap( - self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None) - ) - else: - compute_f_maybe_batch = self._compute_f - - init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params) - f_u = init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) - return f_u - - def _get_update_fn(self): - """Return the implementation (and jitted) update function.""" - - def dual_obj_loss_single(params, a, b): - f_pred = self._compute_f(a, b, params) - g_pred = self.geom.update_potential( - f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0 - ) - g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.) - - ot_prob = linear_problems.LinearProblem(geom=self.geom, a=a, b=b) - dual_obj = sinkhorn.ent_reg_cost(f_pred, g_pred, ot_prob, lse_mode=True) - loss = -dual_obj - return loss, f_pred - - def loss_batch(params, a, b): - loss_fn = functools.partial(dual_obj_loss_single, params=params) - loss, f_pred = jax.vmap(loss_fn)(a=a, b=b) - return jnp.mean(loss), f_pred - - @jax.jit - def update(state, a, b): - a = jnp.atleast_2d(a) - b = jnp.atleast_2d(b) - grad_fn = jax.value_and_grad(loss_batch, has_aux=True) - (loss, init_f), grads = grad_fn(state.params, a, b) - return loss, init_f, state.apply_gradients(grads=grads) - - return update - - def _compute_f(self, a, b, params): - r"""Predict the optimal :math:`f` potential. - - Args: - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. - params: The parameters of the Meta model. - - Returns: - The :math:`f` potential. - """ - return self.meta_model.apply({'params': params}, a, b) - - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: - return [self.geom, self.meta_model, self.opt], { - 'rng': self.rng, - 'state': self.state - } - - -class MetaMLP(nn.Module): - r"""A Meta MLP potential for :class:`~ott.core.initializers.MetaInitializer`. - - This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the probabilities - of the measures to the optimal dual potentials :math:`f`. - - Args: - potential_size: The dimensionality of :math:`f`. - num_hidden_units: The number of hidden units in each layer. - num_hidden_layers: The number of hidden layers. - """ - - potential_size: int - num_hidden_units: int = 512 - num_hidden_layers: int = 3 - - @nn.compact - def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - r"""Make a prediction. - - Args: - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. - - Returns: - The :math:`f` potential. - """ - dtype = a.dtype - z = jnp.concatenate((a, b)) - for _ in range(self.num_hidden_layers): - z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z)) - f = nn.Dense(self.potential_size, dtype=dtype)(z) - return f - - def _vectorized_update( f: jnp.ndarray, modified_cost: jnp.ndarray ) -> jnp.ndarray: diff --git a/ott/core/initializers_lr.py b/ott/initializers/linear/initializers_lr.py similarity index 91% rename from ott/core/initializers_lr.py rename to ott/initializers/linear/initializers_lr.py index 641f1b284..8e60ba51a 100644 --- a/ott/core/initializers_lr.py +++ b/ott/initializers/linear/initializers_lr.py @@ -1,5 +1,5 @@ +import abc import functools -from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, @@ -13,36 +13,32 @@ ) import jax +import jax.numpy as jnp import numpy as np -from jax import numpy as jnp from typing_extensions import Literal -from ott.core import _math_utils as mu from ott.geometry import geometry, low_rank, pointcloud +from ott.math import fixed_point_loop +from ott.math import utils as mu + +if TYPE_CHECKING: + from ott.problems.linear import linear_problem + from ott.problems.quadratic import quadratic_problem + from ott.solvers.linear import sinkhorn, sinkhorn_lr + from ott.solvers.quadratic import gromov_wasserstein + +Problem_t = Union["linear_problem.LinearProblem", + "quadratic_problem.QuadraticProblem"] __all__ = [ "RandomInitializer", "Rank2Initializer", "KMeansInitializer", "GeneralizedKMeansInitializer" ] -if TYPE_CHECKING: - from ott.core import ( - gromov_wasserstein, - linear_problems, - quad_problems, - sinkhorn, - sinkhorn_lr, - ) - Problem_t = Union[linear_problems.LinearProblem, - quad_problems.QuadraticProblem] -else: - Problem_t = "Union[linear_problems.LinearProblem, " \ - "quad_problems.QuadraticProblem]" - @jax.tree_util.register_pytree_node_class -class LRInitializer(ABC): - """Low-rank initializer for linear/quadratic problems. +class LRInitializer(abc.ABC): + """Base class for low-rank initializers. Args: rank: Rank of the factorization. @@ -53,7 +49,7 @@ def __init__(self, rank: int, **kwargs: Any): self._rank = rank self._kwargs = kwargs - @abstractmethod + @abc.abstractmethod def init_q( self, ot_prob: Problem_t, @@ -67,13 +63,14 @@ def init_q( Args: ot_prob: OT problem. key: Random key for seeding. + init_g: Initial value for :math:`g` factor. kwargs: Additional keyword arguments. Returns: Array of shape ``[n, rank]``. """ - @abstractmethod + @abc.abstractmethod def init_r( self, ot_prob: Problem_t, @@ -87,13 +84,14 @@ def init_r( Args: ot_prob: Linear OT problem. key: Random key for seeding. + init_g: Initial value for :math:`g` factor. kwargs: Additional keyword arguments. Returns: Array of shape ``[m, rank]``. """ - @abstractmethod + @abc.abstractmethod def init_g( self, ot_prob: Problem_t, @@ -130,7 +128,7 @@ def from_solver( Returns: The low-rank initializer. """ - from ott.core import gromov_wasserstein + from ott.solvers.quadratic import gromov_wasserstein if isinstance(solver, gromov_wasserstein.GromovWasserstein): assert solver.is_low_rank, "GW solver is not low-rank." @@ -344,7 +342,8 @@ class KMeansInitializer(LRInitializer): rank: Rank of the factorization. min_iterations: Minimum number of k-means iterations. max_iterations: Maximum number of k-means iterations. - sinkhorn_kwargs: Keyword arguments for :class:`~ott.core.sinkhorn.Sinkhorn`. + sinkhorn_kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. kwargs: Keyword arguments for :func:`~ott.tools.k_means.k_means`. """ @@ -382,7 +381,9 @@ def _compute_factor( which: Literal["q", "r"], **kwargs: Any, ) -> jnp.ndarray: - from ott.core import linear_problems, quad_problems, sinkhorn + from ott.problems.linear import linear_problem + from ott.problems.quadratic import quadratic_problem + from ott.solvers.linear import sinkhorn from ott.tools import k_means del kwargs @@ -395,7 +396,7 @@ def _compute_factor( ) fn = jax.jit(fn, static_argnames="k") if jit else fn - if isinstance(ot_prob, quad_problems.QuadraticProblem): + if isinstance(ot_prob, quadratic_problem.QuadraticProblem): geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy else: geom = ot_prob.geom @@ -407,7 +408,7 @@ def _compute_factor( arr, centroids, epsilon=0.1, scale_cost="max_cost" ) - prob = linear_problems.LinearProblem(geom, marginals, init_g) + prob = linear_problem.LinearProblem(geom, marginals, init_g) solver = sinkhorn.Sinkhorn(**self._sinkhorn_kwargs) return solver(prob).matrix @@ -466,7 +467,8 @@ class GeneralizedKMeansInitializer(KMeansInitializer): inner_iterations: Number of iterations used by the algorithm before re-evaluating progress. threshold: Convergence threshold. - sinkhorn_kwargs: Keyword arguments for :class:`~ott.core.sinkhorn.Sinkhorn`. + sinkhorn_kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. """ def __init__( @@ -512,7 +514,9 @@ def _compute_factor( which: Literal["q", "r"], **kwargs: Any, ) -> jnp.ndarray: - from ott.core import fixed_point_loop, linear_problems, sinkhorn + from ott.problems.linear import linear_problem + from ott.problems.quadratic import quadratic_problem + from ott.solvers.linear import sinkhorn def init_fn() -> GeneralizedKMeansInitializer.State: n = geom.shape[0] @@ -527,7 +531,7 @@ def init_fn() -> GeneralizedKMeansInitializer.State: crossed_threshold=False ) - # see the explanation in `ott.core.sinkhorn_lr` + # see the explanation in `ott.solvers.linear.sinkhorn_lr` def converged( state: GeneralizedKMeansInitializer.State, consts: GeneralizedKMeansInitializer.Constants, iteration: int @@ -586,7 +590,7 @@ def body_fn( cost_matrix=cost, epsilon=eps, ) - problem = linear_problems.LinearProblem( + problem = linear_problem.LinearProblem( cost, a=consts.marginal, b=consts.g ) @@ -611,9 +615,8 @@ def body_fn( ) del kwargs - from ott.core import quad_problems - if isinstance(ot_prob, quad_problems.QuadraticProblem): + if isinstance(ot_prob, quadratic_problem.QuadraticProblem): geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy else: geom = ot_prob.geom diff --git a/ott/initializers/nn/__init__.py b/ott/initializers/nn/__init__.py new file mode 100644 index 000000000..7ccb321da --- /dev/null +++ b/ott/initializers/nn/__init__.py @@ -0,0 +1 @@ +from . import initializers diff --git a/ott/initializers/nn/initializers.py b/ott/initializers/nn/initializers.py new file mode 100644 index 000000000..f87dce7d0 --- /dev/null +++ b/ott/initializers/nn/initializers.py @@ -0,0 +1,228 @@ +import functools +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple + +import jax +import jax.numpy as jnp +import optax +from flax import linen as nn +from flax.training import train_state + +from ott.geometry import geometry +from ott.initializers.linear import initializers + +if TYPE_CHECKING: + from ott.problems.linear import linear_problem + +# TODO(michalk8): add initializer for NeuralDual? +__all__ = ["MetaInitializer", "MetaMLP"] + + +@jax.tree_util.register_pytree_node_class +class MetaInitializer(initializers.DefaultInitializer): + """Meta OT Initializer with a fixed geometry :cite:`amos:22`. + + This initializer consists of a predictive model that outputs the + :math:`f` duals to solve the entropy-regularized OT problem given + input probability weights ``a`` and ``b``, and a given (assumed to be + fixed) geometry ``geom``. + + The model's parameters are learned using a training set of OT + instances (multiple pairs of probability weights), that assume the + **same** geometry ``geom`` is used throughout, both for training and + evaluation. The meta model defaults to the MLP in + :class:`~ott.initializers.nn.initializers.MetaMLP` and, with batched problem + instances passed into :meth:`update`. + + Args: + geom: The fixed geometry of the problem instances. + meta_model: The model to predict the potential :math:`f` from the measures. + opt: The optimizer to update the parameters. + rng: The PRNG key to use for initializing the model. + state: The training state of the model to start from. + + Examples: + The following code shows a simple + example of using ``update`` to train the model, where + ``a`` and ``b`` are the weights of the measures and + ``geom`` is the fixed geometry. + + .. code-block:: python + + meta_initializer = init_lib.MetaInitializer(geom) + while training(): + a, b = sample_batch() + loss, init_f, meta_initializer.state = meta_initializer.update( + meta_initializer.state, a=a, b=b + ) + """ + + def __init__( + self, + geom: geometry.Geometry, + meta_model: Optional[nn.Module] = None, + opt: optax.GradientTransformation = optax.adam(learning_rate=1e-3), + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), + state: Optional[train_state.TrainState] = None + ): + self.geom = geom + self.dtype = geom.x.dtype + self.opt = opt + self.rng = rng + + na, nb = geom.shape + self.meta_model = MetaMLP( + potential_size=na + ) if meta_model is None else meta_model + + if state is None: + # Initialize the model's training state. + a_placeholder = jnp.zeros(na, dtype=self.dtype) + b_placeholder = jnp.zeros(nb, dtype=self.dtype) + params = self.meta_model.init(rng, a_placeholder, b_placeholder)['params'] + self.state = train_state.TrainState.create( + apply_fn=self.meta_model.apply, params=params, tx=opt + ) + else: + self.state = state + + self.update_impl = self._get_update_fn() + + def update( + self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]: + r"""Update the meta model with the dual objective. + + The goal is for the model to match the optimal duals, i.e., + :math:`\hat f_\theta \approx f^\star`. + This can be done by training the predictions of :math:`\hat f_\theta` + to optimize the dual objective, which :math:`f^\star` also optimizes for. + The overall learning setup can thus be written as: + + .. math:: + \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; + J(\hat f_\theta(a, b); \alpha, \beta), + + where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`, + :math:`\mathcal{D}` is a meta distribution of optimal transport problems, + + .. math:: + -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - + \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\}\right\rangle + + is the entropic dual objective, + and :math:`K_{i,j} := -C_{i,j}/\varepsilon` is the *Gibbs kernel*. + + Args: + state: Optimizer state of the meta model. + a: Probabilites of the :math:`\alpha` measure's atoms. + b: Probabilites of the :math:`\beta` measure's atoms. + + Returns: + The training loss, :math:`f`, and updated state. + """ + return self.update_impl(state, a, b) + + def init_dual_a( + self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool + ) -> jnp.ndarray: + # Detect if the problem is batched. + assert ot_prob.a.ndim in (1, 2) and ot_prob.b.ndim in (1, 2) + vmap_a_val = 0 if ot_prob.a.ndim == 2 else None + vmap_b_val = 0 if ot_prob.b.ndim == 2 else None + + if vmap_a_val is not None or vmap_b_val is not None: + compute_f_maybe_batch = jax.vmap( + self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None) + ) + else: + compute_f_maybe_batch = self._compute_f + + init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params) + f_u = init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) + return f_u + + def _get_update_fn(self): + """Return the implementation (and jitted) update function.""" + from ott.problems.linear import linear_problem + from ott.solvers.linear import sinkhorn + + def dual_obj_loss_single(params, a, b): + f_pred = self._compute_f(a, b, params) + g_pred = self.geom.update_potential( + f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0 + ) + g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.) + + ot_prob = linear_problem.LinearProblem(geom=self.geom, a=a, b=b) + dual_obj = sinkhorn.ent_reg_cost(f_pred, g_pred, ot_prob, lse_mode=True) + loss = -dual_obj + return loss, f_pred + + def loss_batch(params, a, b): + loss_fn = functools.partial(dual_obj_loss_single, params=params) + loss, f_pred = jax.vmap(loss_fn)(a=a, b=b) + return jnp.mean(loss), f_pred + + @jax.jit + def update(state, a, b): + a = jnp.atleast_2d(a) + b = jnp.atleast_2d(b) + grad_fn = jax.value_and_grad(loss_batch, has_aux=True) + (loss, init_f), grads = grad_fn(state.params, a, b) + return loss, init_f, state.apply_gradients(grads=grads) + + return update + + def _compute_f(self, a, b, params): + r"""Predict the optimal :math:`f` potential. + + Args: + a: Probabilites of the :math:`\alpha` measure's atoms. + b: Probabilites of the :math:`\beta` measure's atoms. + params: The parameters of the Meta model. + + Returns: + The :math:`f` potential. + """ + return self.meta_model.apply({'params': params}, a, b) + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + return [self.geom, self.meta_model, self.opt], { + 'rng': self.rng, + 'state': self.state + } + + +class MetaMLP(nn.Module): + r"""Potential for :class:`~ott.initializers.nn.initializers.MetaInitializer`. + + This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the + probabilities of the measures to the optimal dual potentials :math:`f`. + + Args: + potential_size: The dimensionality of :math:`f`. + num_hidden_units: The number of hidden units in each layer. + num_hidden_layers: The number of hidden layers. + """ + + potential_size: int + num_hidden_units: int = 512 + num_hidden_layers: int = 3 + + @nn.compact + def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + r"""Make a prediction. + + Args: + a: Probabilities of the :math:`\alpha` measure's atoms. + b: Probabilities of the :math:`\beta` measure's atoms. + + Returns: + The :math:`f` potential. + """ + dtype = a.dtype + z = jnp.concatenate((a, b)) + for _ in range(self.num_hidden_layers): + z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z)) + f = nn.Dense(self.potential_size, dtype=dtype)(z) + return f diff --git a/ott/initializers/quadratic/__init__.py b/ott/initializers/quadratic/__init__.py new file mode 100644 index 000000000..7ccb321da --- /dev/null +++ b/ott/initializers/quadratic/__init__.py @@ -0,0 +1 @@ +from . import initializers diff --git a/ott/core/quad_initializers.py b/ott/initializers/quadratic/initializers.py similarity index 79% rename from ott/core/quad_initializers.py rename to ott/initializers/quadratic/initializers.py index 6bd66698d..2b09306d9 100644 --- a/ott/core/quad_initializers.py +++ b/ott/initializers/quadratic/initializers.py @@ -1,19 +1,20 @@ -from abc import ABC, abstractmethod +import abc from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple import jax -from ott.core import linear_problems, sinkhorn_lr from ott.geometry import geometry if TYPE_CHECKING: - from ott.core import initializers_lr, quad_problems + from ott.initializers.linear import initializers_lr + from ott.problems.linear import linear_problem + from ott.problems.quadratic import quadratic_problem __all__ = ["QuadraticInitializer", "LRQuadraticInitializer"] @jax.tree_util.register_pytree_node_class -class BaseQuadraticInitializer(ABC): +class BaseQuadraticInitializer(abc.ABC): """Base class for quadratic initializers. Args: @@ -24,8 +25,8 @@ def __init__(self, **kwargs: Any): self._kwargs = kwargs def __call__( - self, quad_prob: 'quad_problems.QuadraticProblem', **kwargs: Any - ) -> linear_problems.LinearProblem: + self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any + ) -> 'linear_problem.LinearProblem': """Compute the initial linearization of a quadratic problem. Args: @@ -35,11 +36,13 @@ def __call__( Returns: Linear problem. """ + from ott.problems.linear import linear_problem + n, m = quad_prob.geom_xx.shape[0], quad_prob.geom_yy.shape[0] geom = self._create_geometry(quad_prob, **kwargs) assert geom.shape == (n, m), f"Expected geometry of shape `{n, m}`, " \ f"found `{geom.shape}`." - return linear_problems.LinearProblem( + return linear_problem.LinearProblem( geom, a=quad_prob.a, b=quad_prob.b, @@ -47,9 +50,9 @@ def __call__( tau_b=quad_prob.tau_b ) - @abstractmethod + @abc.abstractmethod def _create_geometry( - self, quad_prob: 'quad_problems.QuadraticProblem', **kwargs: Any + self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any ) -> geometry.Geometry: """Compute initial geometry for linearization. @@ -58,7 +61,7 @@ def _create_geometry( kwargs: Additional keyword arguments. Returns: - The initial geometry used to initialize a linear problem. + Geometry used to initialize the linearized problem. """ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: @@ -104,9 +107,9 @@ class QuadraticInitializer(BaseQuadraticInitializer): """ def _create_geometry( - self, quad_prob: 'quad_problems.QuadraticProblem', *, epsilon: float, + self, quad_prob: 'quadratic_problem.QuadraticProblem', *, epsilon: float, **kwargs: Any - ) -> linear_problems.LinearProblem: + ) -> geometry.Geometry: """Compute initial geometry for linearization. Args: @@ -115,9 +118,9 @@ def _create_geometry( kwargs: Additional keyword arguments, unused. Returns: - The initial geometry used to initialize a linear problem. + The initial geometry used to initialize the linearized problem. """ - from ott.core.quad_problems import apply_cost, update_epsilon_unbalanced + from ott.problems.quadratic import quadratic_problem del kwargs unbalanced_correction = 0.0 @@ -131,14 +134,18 @@ def _create_geometry( if not quad_prob.is_balanced: transport_mass = marginal_1.sum() # Initialises epsilon for Unbalanced GW according to Sejourne et al (2021) - epsilon = update_epsilon_unbalanced(epsilon, transport_mass) + epsilon = quadratic_problem.update_epsilon_unbalanced( + epsilon=epsilon, transport_mass=transport_mass + ) unbalanced_correction = quad_prob.cost_unbalanced_correction( - tmp, marginal_1, marginal_2, epsilon + tmp, marginal_1, marginal_2, epsilon=epsilon ) h1, h2 = quad_prob.quad_loss - tmp = apply_cost(quad_prob.geom_xx, tmp, axis=1, fn=h1) - tmp = apply_cost(quad_prob.geom_yy, tmp.T, axis=1, fn=h2).T + tmp = quadratic_problem.apply_cost(quad_prob.geom_xx, tmp, axis=1, fn=h1) + tmp = quadratic_problem.apply_cost( + quad_prob.geom_yy, tmp.T, axis=1, fn=h2 + ).T cost_matrix = (marginal_cost.cost_matrix - tmp + unbalanced_correction) cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix @@ -158,18 +165,20 @@ def __init__(self, lr_linear_initializer: 'initializers_lr.LRInitializer'): self._linear_lr_initializer = lr_linear_initializer def _create_geometry( - self, quad_prob: 'quad_problems.QuadraticProblem', **kwargs: Any + self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any ) -> geometry.Geometry: """Compute initial geometry for linearization. Args: quad_prob: Quadratic OT problem. kwargs: Keyword arguments for - :meth:`ott.core.initializers_lr.LRInitializer.__call__`. + :meth:`~ott.initializers.linear.initializers_lr.LRInitializer.__call__`. Returns: The initial geometry used to initialize a linear problem. """ + from ott.solvers.linear import sinkhorn_lr + q, r, g = self._linear_lr_initializer(quad_prob, **kwargs) tmp_out = sinkhorn_lr.LRSinkhornOutput( q=q, r=r, g=g, costs=None, errors=None, ot_prob=None diff --git a/ott/math/__init__.py b/ott/math/__init__.py new file mode 100644 index 000000000..67aca8931 --- /dev/null +++ b/ott/math/__init__.py @@ -0,0 +1,7 @@ +from . import ( + decomposition, + fixed_point_loop, + matrix_square_root, + unbalanced_functions, + utils, +) diff --git a/ott/core/decomposition.py b/ott/math/decomposition.py similarity index 97% rename from ott/core/decomposition.py rename to ott/math/decomposition.py index cbdf61a8b..e05888e81 100644 --- a/ott/core/decomposition.py +++ b/ott/math/decomposition.py @@ -25,7 +25,7 @@ except ImportError: cholmod = None -__all__ = ["CholeskySolver", "DenseCholeskySolver", "SparseCholeskySolver"] +__all__ = ["DenseCholeskySolver", "SparseCholeskySolver"] T = TypeVar("T") @@ -65,12 +65,12 @@ def _solve(self, L: Optional[T], b: jnp.ndarray) -> jnp.ndarray: def create(cls, A: Union[T, sp.spmatrix], **kwargs: Any) -> "CholeskySolver": """Instantiate sparse or dense Cholesky solver. - Optionally converts :class:`scipy.sparse.spmatrix` to its + And optionally convert :class:`scipy.sparse.spmatrix` to its :mod:`jax` equivalent. Args: A: Symmetric positive definite matrix of shape ``[n, n]``. - kwargs: Keyword arguments for the initialization. + kwargs: Keyword arguments for the solver initialization. Returns: Sparse or dense Cholesky solver. diff --git a/ott/core/fixed_point_loop.py b/ott/math/fixed_point_loop.py similarity index 98% rename from ott/core/fixed_point_loop.py rename to ott/math/fixed_point_loop.py index 3482bb031..3b22ad20a 100644 --- a/ott/core/fixed_point_loop.py +++ b/ott/math/fixed_point_loop.py @@ -17,9 +17,10 @@ from typing import Any, Callable import jax +import jax.numpy as jnp import numpy as np -from jax import dtypes -from jax import numpy as jnp + +__all__ = ["fixpoint_iter", "fixpoint_iter_backprop"] def fixpoint_iter( @@ -123,7 +124,7 @@ def fixpoint_iter_fwd( states = jax.tree_util.tree_map( lambda x: jnp.zeros( (max_iterations // inner_iterations + 1,) + jnp.shape(x), - dtype=dtypes.result_type(x) + dtype=jax.dtypes.result_type(x) ), state ) diff --git a/ott/geometry/matrix_square_root.py b/ott/math/matrix_square_root.py similarity index 95% rename from ott/geometry/matrix_square_root.py rename to ott/math/matrix_square_root.py index 22eac70b9..761e67af2 100644 --- a/ott/geometry/matrix_square_root.py +++ b/ott/math/matrix_square_root.py @@ -22,7 +22,9 @@ import jax.numpy as jnp import numpy as np -from ott.core import fixed_point_loop +from ott.math import fixed_point_loop + +__all__ = ["sqrtm", "sqrtm_only", "inv_sqrtm_only"] @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) @@ -33,7 +35,7 @@ def sqrtm( inner_iterations: int = 10, max_iterations: int = 1000, regularization: float = 1e-3 -) -> jnp.ndarray: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Higham algorithm to compute matrix square root of p.d. matrix. See :cite:`higham:97`, eq. 2.6b @@ -171,7 +173,7 @@ def sqrtm_fwd( max_iterations=max_iterations, regularization=regularization, ) - return ((sqrt_x, inv_sqrt_x, errors), (sqrt_x, inv_sqrt_x)) + return (sqrt_x, inv_sqrt_x, errors), (sqrt_x, inv_sqrt_x) def sqrtm_bwd( @@ -226,7 +228,7 @@ def sqrtm_bwd( axis1=-1, axis2=-2 ) - return (vjp_cot_sqrt + vjp_cot_inv_sqrt,) + return vjp_cot_sqrt + vjp_cot_inv_sqrt, sqrtm.defvjp(sqrtm_fwd, sqrtm_bwd) @@ -254,7 +256,7 @@ def sqrtm_only_bwd(sqrt_x: jnp.ndarray, axis1=-2, axis2=-1 ) - return (vjp,) + return vjp, sqrtm_only.defvjp(sqrtm_only_fwd, sqrtm_only_bwd) @@ -270,9 +272,8 @@ def inv_sqrtm_only_fwd(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: return inv_sqrt_x, inv_sqrt_x -def inv_sqrtm_only_bwd( - residual: jnp.ndarray, cotangent: jnp.ndarray -) -> jnp.ndarray: +def inv_sqrtm_only_bwd(residual: jnp.ndarray, + cotangent: jnp.ndarray) -> Tuple[jnp.ndarray]: inv_sqrt_x = residual inv_x = jnp.matmul(inv_sqrt_x, inv_sqrt_x) vjp = jnp.swapaxes( @@ -287,7 +288,7 @@ def inv_sqrtm_only_bwd( axis1=-1, axis2=-2 ) - return (vjp,) + return vjp, inv_sqrtm_only.defvjp(inv_sqrtm_only_fwd, inv_sqrtm_only_bwd) diff --git a/ott/core/unbalanced_functions.py b/ott/math/unbalanced_functions.py similarity index 100% rename from ott/core/unbalanced_functions.py rename to ott/math/unbalanced_functions.py diff --git a/ott/geometry/ops.py b/ott/math/utils.py similarity index 60% rename from ott/geometry/ops.py rename to ott/math/utils.py index 1bf3248a3..fef5ae667 100644 --- a/ott/geometry/ops.py +++ b/ott/math/utils.py @@ -1,23 +1,46 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Low level functions used within the scope of Geometric processing.""" - import functools +from typing import TYPE_CHECKING, Optional, Union import jax +import jax.experimental.sparse as jesp import jax.numpy as jnp +if TYPE_CHECKING: + from ott.geometry import costs + +__all__ = [ + "safe_log", "kl", "js", "sparse_scale", "logsumexp", + "barycentric_projection" +] + +# TODO(michalk8): move to typing.py when refactoring types +Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO] + + +def safe_log(x: jnp.ndarray, *, eps: Optional[float] = None) -> jnp.ndarray: + if eps is None: + eps = jnp.finfo(x.dtype).tiny + return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) + + +def kl(p: jnp.ndarray, q: jnp.ndarray) -> float: + """Kullback-Leilbler divergence.""" + return jnp.vdot(p, (safe_log(p) - safe_log(q))) + + +def js(p: jnp.ndarray, q: jnp.ndarray, *, c: float = 0.5) -> float: + """Jensen-Shannon divergence.""" + return c * (kl(p, q) + kl(q, p)) + + +def sparse_scale(c: float, mat: Sparse_t) -> Sparse_t: + """Scale a sparse matrix by a constant.""" + if isinstance(mat, jesp.BCOO): + # most feature complete, defer to original impl. + return c * mat + (data, *children), aux_data = mat.tree_flatten() + return type(mat).tree_unflatten(aux_data, [c * data] + children) + @functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4)) def logsumexp(mat, axis=None, keepdims=False, b=None, return_sign=False): @@ -69,3 +92,10 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): return (lse, sign), (sign * res, jnp.zeros_like(sign)) else: return lse, res + + +@functools.partial(jax.vmap, in_axes=[0, 0, None]) +def barycentric_projection( + matrix: jnp.ndarray, y: jnp.ndarray, cost_fn: "costs.CostFn" +) -> jnp.ndarray: + return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) diff --git a/ott/problems/__init__.py b/ott/problems/__init__.py new file mode 100644 index 000000000..87714fd6b --- /dev/null +++ b/ott/problems/__init__.py @@ -0,0 +1 @@ +from . import linear, quadratic diff --git a/ott/problems/linear/__init__.py b/ott/problems/linear/__init__.py new file mode 100644 index 000000000..2088e5a3c --- /dev/null +++ b/ott/problems/linear/__init__.py @@ -0,0 +1 @@ +from . import barycenter_problem, linear_problem, potentials diff --git a/ott/problems/linear/barycenter_problem.py b/ott/problems/linear/barycenter_problem.py new file mode 100644 index 000000000..c2f28860f --- /dev/null +++ b/ott/problems/linear/barycenter_problem.py @@ -0,0 +1,181 @@ +# Copyright 2022 Apple Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes defining OT problem(s) (objective function + utilities).""" +from typing import Any, Dict, Optional, Sequence, Tuple + +import jax +import jax.numpy as jnp + +from ott.geometry import costs, segment + +__all__ = ["BarycenterProblem"] + + +@jax.tree_util.register_pytree_node_class +class BarycenterProblem: + """Wasserstein barycenter problem :cite:`cuturi:14`. + + Args: + y: Array of shape ``[num_total_points, ndim]`` merging the points of all + measures. Alternatively, already segmented array of shape + ``[num_measures, max_measure_size, ndim]`` can be passed. + See also :func:`~ott.geometry.segment.segment_point_cloud`. + b: Array of shape ``[num_total_points,]`` containing the weights of all + the points within the measures that define the barycenter problem. + Same as ``y``, pre-segmented array of weights of shape + ``[num_measures, max_measure_size]`` can be passed. + If ``y`` is already pre-segmented, this array must be always specified. + weights: Array of shape ``[num_measures,]`` containing the weights of the + measures. + cost_fn: Cost function used. If `None`, + use the :class:`~ott.geometry.costs.SqEuclidean` cost. + epsilon: Epsilon regularization used to solve reg-OT problems. + debiased: **Currently not implemented.** + Whether the problem is debiased, in the sense that + the regularized transportation cost of barycenter to itself will + be considered when computing gradient. Note that if the debiased option + is used, the barycenter size in + :meth:`~ott.solvers.linear.continuous_barycenter.WassersteinBarycenter.init_state` + needs to be smaller than the maximum measure size for parallelization to + operate efficiently. + kwargs: Keyword arguments :func:`~ott.geometry.segment.segment_point_cloud`. + Only used when ``y`` is not already segmented. When passing + ``segment_ids``, 2 arguments must be specified for jitting to work: + + - ``num_segments`` - the total number of measures. + - ``max_measure_size`` - maximum of support sizes of these measures. + """ + + def __init__( + self, + y: jnp.ndarray, + b: Optional[jnp.ndarray] = None, + weights: Optional[jnp.ndarray] = None, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + debiased: bool = False, + **kwargs: Any, + ): + self._y = y + if y.ndim == 3 and b is None: + raise ValueError("Specify weights if `y` is already segmented.") + self._b = b + self._weights = weights + self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn + self.epsilon = epsilon + self.debiased = debiased + self._kwargs = kwargs + + if self._is_segmented: + # (num_measures, max_measure_size, ndim) + # (num_measures, max_measure_size) + assert self._y.shape[:2] == self._b.shape + else: + # (num_total_points, ndim) + # (num_total_points,) + assert self._b is None or self._y.shape[0] == self._b.shape[0] + + @property + def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Tuple of arrays containing the segmented measures and weights. + + Additional segment may be added when the problem is debiased. + + - Segmented measures of shape ``[num_measures, max_measure_size, ndim]``. + - Segmented weights of shape ``[num_measures, max_measure_size]``. + """ + if self._is_segmented: + y, b = self._y, self._b + else: + y, b = segment.segment_point_cloud( + x=self._y, + a=self._b, + padding_vector=self.cost_fn._padder(self.ndim), + **self._kwargs + ) + + if self.debiased: + return self._add_slice_for_debiased(y, b) + return y, b + + def _add_slice_for_debiased( + self, y: jnp.ndarray, b: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + y, b = self._y, self._b + _, n, ndim = y.shape # (num_measures, max_measure_size, ndim) + # yapf: disable + y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0) + b = jnp.concatenate((b, jnp.zeros((1, n))), axis=0) + # yapf: enable + return y, b + + @property + def flattened_y(self) -> jnp.ndarray: + """Array of shape ``[num_measures * (N_1 + N_2 + ...), ndim]``.""" + if self._is_segmented: + return self._y.reshape((-1, self._y.shape[-1])) + return self._y + + @property + def flattened_b(self) -> Optional[jnp.ndarray]: + """Array of shape ``[num_measures * (N_1 + N_2 + ...),]``.""" + return None if self._b is None else self._b.ravel() + + @property + def num_measures(self) -> int: + """Number of measures.""" + return self.segmented_y_b[0].shape[0] + + @property + def max_measure_size(self) -> int: + """Maximum number of points across all measures.""" + return self.segmented_y_b[0].shape[1] + + @property + def ndim(self) -> int: + """Number of dimensions of ``y``.""" + return self._y.shape[-1] + + @property + def weights(self) -> jnp.ndarray: + """Barycenter weights of shape ``[num_measures,]`` that sum to 1.""" + if self._weights is None: + weights = jnp.ones((self.num_measures,)) / self.num_measures + else: + # Check that the number of measures coincides with the weights' size. + assert self._weights.shape[0] == self.num_measures + # By default, we assume that weights sum to 1, and enforce this if needed. + weights = self._weights / jnp.sum(self._weights) + if self.debiased: + weights = jnp.concatenate((weights, jnp.array([-0.5]))) + return weights + + @property + def _is_segmented(self) -> bool: + return self._y.ndim == 3 + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + return ([self._y, self._b, self._weights], { + 'cost_fn': self.cost_fn, + 'epsilon': self.epsilon, + 'debiased': self.debiased, + **self._kwargs, + }) + + @classmethod + def tree_unflatten( + cls, aux_data: Dict[str, Any], children: Sequence[Any] + ) -> "BarycenterProblem": + y, b, weights = children + return cls(y=y, b=b, weights=weights, **aux_data) diff --git a/ott/core/linear_problems.py b/ott/problems/linear/linear_problem.py similarity index 97% rename from ott/core/linear_problems.py rename to ott/problems/linear/linear_problem.py index 4357267ff..6f7aad57d 100644 --- a/ott/core/linear_problems.py +++ b/ott/problems/linear/linear_problem.py @@ -20,6 +20,9 @@ from ott.geometry import geometry +__all__ = ["LinearProblem"] + +# TODO(michalk8): move to typing.py when refactoring the types MarginalFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] TransportAppFunc = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, int], jnp.ndarray] diff --git a/ott/core/potentials.py b/ott/problems/linear/potentials.py similarity index 86% rename from ott/core/potentials.py rename to ott/problems/linear/potentials.py index 126348b5f..dde7366d6 100644 --- a/ott/core/potentials.py +++ b/ott/problems/linear/potentials.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp @@ -6,7 +6,8 @@ import jax.tree_util as jtu from typing_extensions import Literal -from ott.geometry import costs, pointcloud +if TYPE_CHECKING: + from ott.geometry import costs, pointcloud __all__ = ["DualPotentials", "EntropicPotentials"] Potential_t = Callable[[jnp.ndarray], float] @@ -24,15 +25,16 @@ class DualPotentials: g: The second dual potential function. cost_fn: The cost function used to solve the OT problem. corr: Whether the duals solve the problem in distance form, or correlation - form (as used for instance for ICNNs, see e.g. top right of p.3 in :cite:`makkuva:20`) + form (as used for instance for ICNNs, see, e.g., top right of p.3 in + :cite:`makkuva:20`) """ def __init__( self, f: Potential_t, g: Potential_t, - cost_fn: costs.CostFn, *, + cost_fn: 'costs.CostFn', corr: bool = False ): self._f = f @@ -41,7 +43,7 @@ def __init__( self._corr = corr def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: - r"""Transport ``vec`` according to Brenier formula. + r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`. Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when given the Legendre transform of the dual potentials. @@ -63,6 +65,8 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: Returns: The transported points. """ + from ott.geometry import costs + vec = jnp.atleast_2d(vec) if self._corr and isinstance(self.cost_fn, costs.SqEuclidean): return self._grad_g(vec) if forward else self._grad_f(vec) @@ -124,6 +128,8 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]: @property def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: + from ott.geometry import costs + assert isinstance(self.cost_fn, costs.TICost), ( "Cost must be a `TICost` and " "provide access to Legendre transform of `h`." @@ -148,16 +154,23 @@ class EntropicPotentials(DualPotentials): f: The first dual potential vector of shape ``[n,]``. g: The second dual potential vector of shape ``[m,]``. geom: Geometry used to compute the dual potentials using - :class:`~ott.core.sinkhorn.Sinkhorn`. - a: probability weights for the first measure. - b: probaility weights for the second measure. + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. + a: Probability weights for the first measure. If `None`, use uniform. + b: Probability weights for the second measure. If `None`, use uniform. """ def __init__( - self, f: jnp.ndarray, g: jnp.ndarray, geom: pointcloud.PointCloud, - a: jnp.ndarray, b: jnp.ndarray + self, + f: jnp.ndarray, + g: jnp.ndarray, + geom: "pointcloud.PointCloud", + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, ): n, m = geom.shape + a = jnp.ones(n) / n if a is None else a + b = jnp.ones(m) / m if b is None else b + assert f.shape == (n,) and a.shape == (n,), \ f"Expected `f` and `a` to be of shape `{n,}`, found `{f.shape}`." assert g.shape == (m,) and b.shape == (m,), \ @@ -181,10 +194,13 @@ def g(self) -> Potential_t: def _create_potential_function( self, *, kind: Literal["f", "g"] ) -> Potential_t: + from ott.geometry import pointcloud def callback(x: jnp.ndarray) -> float: cost = pointcloud.PointCloud( - jnp.atleast_2d(x), y, cost_fn=self._geom.cost_fn + jnp.atleast_2d(x), + y, + cost_fn=self.cost_fn, ).cost_matrix z = (potential - cost) / epsilon lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1) diff --git a/ott/problems/quadratic/__init__.py b/ott/problems/quadratic/__init__.py new file mode 100644 index 000000000..18ff1c517 --- /dev/null +++ b/ott/problems/quadratic/__init__.py @@ -0,0 +1 @@ +from . import gw_barycenter, quadratic_costs, quadratic_problem diff --git a/ott/core/bar_problems.py b/ott/problems/quadratic/gw_barycenter.py similarity index 55% rename from ott/core/bar_problems.py rename to ott/problems/quadratic/gw_barycenter.py index dce1e5286..9d582ce1c 100644 --- a/ott/core/bar_problems.py +++ b/ott/problems/quadratic/gw_barycenter.py @@ -1,17 +1,3 @@ -# Copyright 2022 Apple Inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes defining OT problem(s) (objective function + utilities).""" import functools from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -19,183 +5,27 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.core import quad_problems, segment -from ott.geometry import costs, geometry, pointcloud +from ott.geometry import costs, geometry, pointcloud, segment +from ott.math import utils as mu +from ott.problems.linear import barycenter_problem +from ott.problems.quadratic import quadratic_costs, quadratic_problem -__all__ = ["BarycenterProblem", "GWBarycenterProblem", "barycentric_projection"] +__all__ = ["GWBarycenterProblem"] +# TODO(michalk8): better abstraction (common superclass for Wasserstein bary) @jax.tree_util.register_pytree_node_class -class BarycenterProblem: - """Wasserstein barycenter problem :cite:`cuturi:14`. - - Args: - y: Array of shape ``[num_total_points, ndim]`` merging the points of all - measures. Alternatively, already segmented array of shape - ``[num_measures, max_measure_size, ndim]`` can be passed. - See also :func:`~ott.core.segment.segment_point_cloud`. - b: Array of shape ``[num_total_points,]`` containing the weights of all - the points within the measures that define the barycenter problem. - Similarly as ``y``, segmented array of weights of shape - ``[num_measures, max_measure_size]`` can be passed. - If ``y`` is already pre-segmented, this array must be always specified. - weights: Array of shape ``[num_measures,]`` containing the weights of the - measures. - cost_fn: Cost function used. If `None`, - use :class:`~ott.geometry.costs.SqEuclidean` cost. - epsilon: Epsilon regularization used to solve reg-OT problems. - debiased: **Currently not implemented.** - Whether the problem is debiased, in the sense that - the regularized transportation cost of barycenter to itself will - be considered when computing gradient. Note that if the debiased option - is used, the barycenter size in - :meth:`~ott.core.continuous_barycenter.WassersteinBarycenter.init_state` - needs to be smaller than the maximum measure size for parallelization to - operate efficiently. - kwargs: Keyword arguments :func:`~ott.core.segment.segment_point_cloud`. - Only used when ``y`` is not already segmented. When passing - ``segment_ids``, 2 arguments must be specified for jitting to work: - - - ``num_segments`` - the total number of measures. - - ``max_measure_size`` - maximum of support sizes of these measures. - """ - - def __init__( - self, - y: jnp.ndarray, - b: Optional[jnp.ndarray] = None, - weights: Optional[jnp.ndarray] = None, - cost_fn: Optional[costs.CostFn] = None, - epsilon: Optional[float] = None, - debiased: bool = False, - **kwargs: Any, - ): - self._y = y - if y.ndim == 3 and b is None: - raise ValueError("Specify weights if `y` is already segmented.") - self._b = b - self._weights = weights - self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn - self.epsilon = epsilon - self.debiased = debiased - self._kwargs = kwargs - - if self._is_segmented: - # (num_measures, max_measure_size, ndim) - # (num_measures, max_measure_size) - assert self._y.shape[:2] == self._b.shape - else: - # (num_total_points, ndim) - # (num_total_points,) - assert self._b is None or self._y.shape[0] == self._b.shape[0] - - @property - def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Tuple of arrays containing the segmented measures and weights. - - Additional segment may be added when the problem is debiased. - - - Segmented measures of shape ``[num_measures, max_measure_size, ndim]``. - - Segmented weights of shape ``[num_measures, max_measure_size]``. - """ - if self._is_segmented: - y, b = self._y, self._b - else: - y, b = segment.segment_point_cloud( - x=self._y, - a=self._b, - padding_vector=self.cost_fn.padder(self.ndim), - **self._kwargs - ) - - if self.debiased: - return self._add_slice_for_debiased(y, b) - return y, b - - def _add_slice_for_debiased( - self, y: jnp.ndarray, b: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - y, b = self._y, self._b - _, n, ndim = y.shape # (num_measures, max_measure_size, ndim) - # yapf: disable - y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0) - b = jnp.concatenate((b, jnp.zeros((1, n))), axis=0) - # yapf: enable - return y, b - - @property - def flattened_y(self) -> jnp.ndarray: - """Array of shape ``[num_measures * (N_1 + N_2 + ...), ndim]``.""" - if self._is_segmented: - return self._y.reshape((-1, self._y.shape[-1])) - return self._y - - @property - def flattened_b(self) -> Optional[jnp.ndarray]: - """Array of shape ``[num_measures * (N_1 + N_2 + ...),]``.""" - return None if self._b is None else self._b.ravel() - - @property - def num_measures(self) -> int: - """Number of measures.""" - return self.segmented_y_b[0].shape[0] - - @property - def max_measure_size(self) -> int: - """Maximum number of points across all measures.""" - return self.segmented_y_b[0].shape[1] - - @property - def ndim(self) -> int: - """Number of dimensions of ``y``.""" - return self._y.shape[-1] - - @property - def weights(self) -> jnp.ndarray: - """Barycenter weights of shape ``[num_measures,]`` that sum to 1.""" - if self._weights is None: - weights = jnp.ones((self.num_measures,)) / self.num_measures - else: - # Check that the number of measures coincides with the weights' size. - assert self._weights.shape[0] == self.num_measures - # By default, we assume that weights sum to 1, and enforce this if needed. - weights = self._weights / jnp.sum(self._weights) - if self.debiased: - weights = jnp.concatenate((weights, jnp.array([-0.5]))) - return weights - - @property - def _is_segmented(self) -> bool: - return self._y.ndim == 3 - - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: - return ([self._y, self._b, self._weights], { - 'cost_fn': self.cost_fn, - 'epsilon': self.epsilon, - 'debiased': self.debiased, - **self._kwargs, - }) - - @classmethod - def tree_unflatten( - cls, aux_data: Dict[str, Any], children: Sequence[Any] - ) -> "BarycenterProblem": - y, b, weights = children - return cls(y=y, b=b, weights=weights, **aux_data) - - -@jax.tree_util.register_pytree_node_class -class GWBarycenterProblem(BarycenterProblem): +class GWBarycenterProblem(barycenter_problem.BarycenterProblem): """(Fused) Gromov-Wasserstein barycenter problem :cite:`peyre:16,vayer:19`. Args: y: Array of shape ``[num_total_points, ndim]`` merging the points of all measures. Alternatively, already segmented array of shape ``[num_measures, max_measure_size, ndim]`` can be passed. - See also :func:`~ott.core.segment.segment_point_cloud`. + See also :func:`~ott.geometry.segment.segment_point_cloud`. b: Array of shape ``[num_total_points,]`` containing the weights of all the points within the measures that define the barycenter problem. - Similarly as ``y``, segmented array of weights of shape + Same as ``y``, pre-segmented array of weights of shape ``[num_measures, max_measure_size]`` can be passed. If ``y`` is already pre-segmented, this array must be passed. weights: Array of shape ``[num_measures,]`` containing the weights of the @@ -206,14 +36,14 @@ class GWBarycenterProblem(BarycenterProblem): Only one of ``y`` and ``cost`` can be specified. y_fused: Array of shape ``[num_total_points, ndim_fused]`` containing the data of the points of all measures used to define the linear term - in the fused case. Similarly as ``y``, can be specified as a pre-segmented + in the fused case. Same as ``y``, it can be specified as a pre-segmented array of shape ``[num_measures, max_measure_size, ndim_fused]``. gw_loss: Gromov-Wasserstein loss. fused_penalty: Multiplier of the linear term. Only used when ``y_fused != None``. scale_cost: Scaling of cost matrices passed to geometries. kwargs: Keyword arguments for - :class:`~ott.core.bar_problems.BarycenterProblem`. + :class:`~ott.problems.linear.barycenter_problem.BarycenterProblem`. """ def __init__( @@ -277,7 +107,7 @@ def project( y: jnp.ndarray, b: jnp.ndarray, transport: jnp.ndarray, - fn: Optional[quad_problems.Loss], + fn: Optional[quadratic_costs.Loss], ) -> jnp.ndarray: geom = self._create_y_geometry(y, mask=b > 0.) fn, lin = (None, True) if fn is None else (fn.func, fn.is_linear) @@ -309,7 +139,7 @@ def update_features(self, transports: jnp.ndarray, """Update the barycenter features in the fused case :cite:`vayer:19`. Uses :cite:`cuturi:14` eq. 8, and is implemented only - for the squared :class:`~ott.geometry.costs.SqEuclidean` cost. + for the :class:`~ott.geometry.costs.SqEuclidean` cost. Args: transports: Transport maps of shape @@ -332,7 +162,7 @@ def update_features(self, transports: jnp.ndarray, if self._loss_name == "sqeucl": cost_fn = costs.SqEuclidean() return jnp.sum( - weights * barycentric_projection(transports, y_fused, cost_fn), + weights * mu.barycentric_projection(transports, y_fused, cost_fn), axis=0 ) raise NotImplementedError(self._loss_name) @@ -396,8 +226,8 @@ def _create_problem( y: jnp.ndarray, b: jnp.ndarray, f: Optional[jnp.ndarray] = None - ) -> quad_problems.QuadraticProblem: - # TODO(michalk8): in the future, mask in the problem for convenience? + ) -> quadratic_problem.QuadraticProblem: + # TODO(michalk8): in future, mask in the problem for convenience? bary_mask = state.a > 0. y_mask = b > 0. @@ -412,7 +242,7 @@ def _create_problem( else: geom_xy = None - return quad_problems.QuadraticProblem( + return quadratic_problem.QuadraticProblem( geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, @@ -434,7 +264,7 @@ def segmented_y_fused(self) -> Optional[jnp.ndarray]: return self._y_fused y_fused, _ = segment.segment_point_cloud( x=self._y_fused, - padding_vector=self.cost_fn.padder(self.ndim_fused), + padding_vector=self.cost_fn._padder(self.ndim_fused), **self._kwargs ) return y_fused @@ -449,16 +279,16 @@ def ndim_fused(self) -> Optional[int]: return self._y_fused.shape[-1] if self.is_fused else None @property - def gw_loss(self) -> quad_problems.GWLoss: + def gw_loss(self) -> quadratic_costs.GWLoss: """Gromov-Wasserstein loss.""" # TODO(michalk8): custom losses would require inverting some fns; # `https://jax.readthedocs.io/en/latest/notebooks/ some fns; # Writing_custom_interpreters_in_Jax.html#your-first-interpreter-invert` # might be useful if self._loss_name == 'sqeucl': - return quad_problems.make_square_loss() + return quadratic_costs.make_square_loss() if self._loss_name == 'kl': - return quad_problems.make_kl_loss() + return quadratic_costs.make_kl_loss() raise NotImplementedError( f"Loss `{self._loss_name}` is not yet implemented." ) @@ -482,10 +312,3 @@ def tree_unflatten( return cls( y=y, b=b, weights=weights, costs=costs, y_fused=y_fused, **aux_data ) - - -@functools.partial(jax.vmap, in_axes=[0, 0, None]) -def barycentric_projection( - matrix: jnp.ndarray, y: jnp.ndarray, cost_fn -) -> jnp.ndarray: - return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) diff --git a/ott/problems/quadratic/quadratic_costs.py b/ott/problems/quadratic/quadratic_costs.py new file mode 100644 index 000000000..8de1b398d --- /dev/null +++ b/ott/problems/quadratic/quadratic_costs.py @@ -0,0 +1,34 @@ +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +__all__ = ["make_square_loss", "make_kl_loss"] + + +class Loss(NamedTuple): + func: Callable[[jnp.ndarray], jnp.ndarray] + is_linear: bool + + +class GWLoss(NamedTuple): + f1: Loss + f2: Loss + h1: Loss + h2: Loss + + +def make_square_loss() -> GWLoss: + f1 = Loss(lambda x: x ** 2, is_linear=False) + f2 = Loss(lambda y: y ** 2, is_linear=False) + h1 = Loss(lambda x: x, is_linear=True) + h2 = Loss(lambda y: 2.0 * y, is_linear=True) + return GWLoss(f1, f2, h1, h2) + + +def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: + f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False) + f2 = Loss(lambda y: y, is_linear=True) + h1 = Loss(lambda x: x, is_linear=True) + h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False) + return GWLoss(f1, f2, h1, h2) diff --git a/ott/core/quad_problems.py b/ott/problems/quadratic/quadratic_problem.py similarity index 89% rename from ott/core/quad_problems.py rename to ott/problems/quadratic/quadratic_problem.py index 16b8182ec..2bf975c5d 100644 --- a/ott/core/quad_problems.py +++ b/ott/problems/quadratic/quadratic_problem.py @@ -13,64 +13,21 @@ # limitations under the License. """Classes defining OT problem(s) (objective function + utilities).""" -from typing import Callable, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union import jax import jax.numpy as jnp +from typing_extensions import Literal -# Because Protocol is not available in Python < 3.8 -from typing_extensions import Literal, Protocol - -from ott.core import _math_utils as mu -from ott.core import linear_problems, sinkhorn_lr from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud +from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_costs +from ott.types import Transport +if TYPE_CHECKING: + from ott.solvers.linear import sinkhorn_lr -class Transport(Protocol): - """Interface for the solution of a transport problem. - - Classes implementing those function do not have to inherit from it, the - class can however be used in type hints to support duck typing. - """ - - @property - def matrix(self) -> jnp.ndarray: - ... - - def apply(self, inputs: jnp.ndarray, axis: int) -> jnp.ndarray: - ... - - def marginal(self, axis: int = 0) -> jnp.ndarray: - ... - - -class Loss(NamedTuple): - func: Callable[[jnp.ndarray], jnp.ndarray] - is_linear: bool - - -class GWLoss(NamedTuple): - f1: Loss - f2: Loss - h1: Loss - h2: Loss - - -def make_square_loss() -> GWLoss: - f1 = Loss(lambda x: x ** 2, is_linear=False) - f2 = Loss(lambda y: y ** 2, is_linear=False) - h1 = Loss(lambda x: jnp.sqrt(2) * x, is_linear=True) - h2 = Loss(lambda y: jnp.sqrt(2) * y, is_linear=True) - return GWLoss(f1, f2, h1, h2) - - -def make_kl_loss(clipping_value: Optional[float] = None) -> GWLoss: - assert clipping_value is None, "Clipping deprecated in KL definition." - f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False) - f2 = Loss(lambda y: y, is_linear=True) - h1 = Loss(lambda x: x, is_linear=True) - h2 = Loss(lambda y: mu.safe_log(y), is_linear=False) - return GWLoss(f1, f2, h1, h2) +__all__ = ["QuadraticProblem"] @jax.tree_util.register_pytree_node_class @@ -144,7 +101,7 @@ def __init__( scale_cost: Optional[Union[bool, float, str]] = False, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, - loss: Union[Literal['sqeucl', 'kl'], GWLoss] = 'sqeucl', + loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl', tau_a: Optional[float] = 1.0, tau_b: Optional[float] = 1.0, gw_unbalanced_correction: bool = True, @@ -169,9 +126,9 @@ def __init__( self._loss_name = loss if self._loss_name == 'sqeucl': - self.loss = make_square_loss() + self.loss = quadratic_costs.make_square_loss() elif loss == 'kl': - self.loss = make_kl_loss() + self.loss = quadratic_costs.make_kl_loss() else: self.loss = loss @@ -283,7 +240,7 @@ def init_transport_mass(self) -> float: return a.sum() * b.sum() def update_lr_geom( - self, lr_sink: sinkhorn_lr.LRSinkhornOutput + self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput' ) -> geometry.Geometry: """Recompute (possibly LRC) linearization using LR Sinkhorn output.""" marginal_1 = lr_sink.marginal(1) @@ -314,7 +271,7 @@ def update_linearization( transport: Transport, epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None, old_transport_mass: float = 1.0 - ) -> linear_problems.LinearProblem: + ) -> linear_problem.LinearProblem: """Update linearization of GW problem by updating cost matrix. If the problem is balanced (``tau_a = 1.0 and tau_b = 1.0``), the equation @@ -365,15 +322,15 @@ def update_linearization( cost_matrix += self.fused_penalty * self._fused_cost_matrix * rescale_factor geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon) - return linear_problems.LinearProblem( + return linear_problem.LinearProblem( geom, self.a, self.b, tau_a=self.tau_a, tau_b=self.tau_b ) def update_lr_linearization( - self, lr_sink: sinkhorn_lr.LRSinkhornOutput - ) -> linear_problems.LinearProblem: + self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput' + ) -> linear_problem.LinearProblem: """Update a Quad problem linearization using a LR Sinkhorn.""" - return linear_problems.LinearProblem( + return linear_problem.LinearProblem( self.update_lr_geom(lr_sink), self.a, self.b, @@ -493,12 +450,12 @@ def is_low_rank(self) -> bool: ) @property - def linear_loss(self) -> Tuple[Loss, Loss]: + def linear_loss(self) -> Tuple[quadratic_costs.Loss, quadratic_costs.Loss]: """Linear part of the Gromov-Wasserstein loss.""" return self.loss.f1, self.loss.f2 @property - def quad_loss(self) -> Tuple[Loss, Loss]: + def quad_loss(self) -> Tuple[quadratic_costs.Loss, quadratic_costs.Loss]: """Quadratic part of the Gromov-Wasserstein loss.""" return self.loss.h1, self.loss.h2 @@ -526,7 +483,9 @@ def tree_unflatten(cls, aux_data, children): return cls(*geoms, a=a, b=b, **aux_data) -def update_epsilon_unbalanced(epsilon, transport_mass): +def update_epsilon_unbalanced( + epsilon: Union[float, epsilon_scheduler.Epsilon], transport_mass: float +) -> epsilon_scheduler.Epsilon: updated_epsilon = epsilon_scheduler.Epsilon.make(epsilon) updated_epsilon._scale_epsilon = ( updated_epsilon._scale_epsilon * transport_mass @@ -535,6 +494,7 @@ def update_epsilon_unbalanced(epsilon, transport_mass): def apply_cost( - geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, fn: Loss + geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, + fn: quadratic_costs.Loss ) -> jnp.ndarray: return geom.apply_cost(arr, axis=axis, fn=fn.func, is_linear=fn.is_linear) diff --git a/ott/solvers/__init__.py b/ott/solvers/__init__.py new file mode 100644 index 000000000..15cfac006 --- /dev/null +++ b/ott/solvers/__init__.py @@ -0,0 +1 @@ +from . import linear, nn, quadratic diff --git a/ott/solvers/linear/__init__.py b/ott/solvers/linear/__init__.py new file mode 100644 index 000000000..40034b929 --- /dev/null +++ b/ott/solvers/linear/__init__.py @@ -0,0 +1,8 @@ +from . import ( + acceleration, + continuous_barycenter, + discrete_barycenter, + implicit_differentiation, + sinkhorn, + sinkhorn_lr, +) diff --git a/ott/core/anderson.py b/ott/solvers/linear/acceleration.py similarity index 61% rename from ott/core/anderson.py rename to ott/solvers/linear/acceleration.py index e20b6d0ec..a22ac9ab6 100644 --- a/ott/core/anderson.py +++ b/ott/solvers/linear/acceleration.py @@ -1,28 +1,17 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tools for Anderson acceleration.""" -from typing import Any +from typing import TYPE_CHECKING import jax import jax.numpy as jnp -from ott.core import dataclasses +from ott import utils -SinkhornState = Any +if TYPE_CHECKING: + from ott.solvers.linear import sinkhorn +__all__ = ["AndersonAcceleration", "Momentum"] -@dataclasses.register_pytree_node + +@utils.register_pytree_node class AndersonAcceleration: """Implements Anderson acceleration for Sinkhorn.""" @@ -30,7 +19,7 @@ class AndersonAcceleration: refresh_every: int = 1 # Recompute interpolation periodically. ridge_identity: float = 1e-2 # Ridge used in the linear system. - def extrapolation(self, xs, fxs): + def extrapolation(self, xs: jnp.ndarray, fxs: jnp.ndarray) -> jnp.ndarray: """Compute Anderson extrapolation from past observations.""" # Remove -inf values to instantiate quadratic problem. All others # remain since they might be caused by a valid issue. @@ -49,11 +38,13 @@ def extrapolation(self, xs, fxs): # Recover linear combination and return it with NaN (caused # by 0 weights leading to -jnp.inf potentials, mixed with weights - # coefficiences of different signs), disambiguated to -inf. + # coefficients of different signs), disambiguated to -inf. combination = jnp.sum(fxs * weights[None, :], axis=1) return jnp.where(jnp.isfinite(combination), combination, -jnp.inf) - def update(self, state: SinkhornState, iteration: int, pb, lse_mode: bool): + def update( + self, state: 'sinkhorn.SinkhornState', iteration: int, pb, lse_mode: bool + ) -> 'sinkhorn.SinkhornState': """Anderson acceleration update. When using Anderson acceleration, first update the dual variable f_u with @@ -101,12 +92,63 @@ def update(self, state: SinkhornState, iteration: int, pb, lse_mode: bool): ) return state.set(fu=fu, old_fus=old_fus) - def init_maps(self, pb, state): + def init_maps( + self, pb, state: 'sinkhorn.SinkhornState' + ) -> 'sinkhorn.SinkhornState': """Initialize log matrix used in Anderson acceleration with nan values.""" fus = jnp.ones((pb.geom.shape[0], self.memory)) * jnp.nan return state.set(old_fus=fus, old_mapped_fus=fus) - def update_history(self, state, pb, lse_mode: bool): + def update_history( + self, state: 'sinkhorn.SinkhornState', pb, lse_mode: bool + ) -> 'sinkhorn.SinkhornState': f = state.fu if lse_mode else pb.geom.potential_from_scaling(state.fu) mapped = jnp.concatenate((state.old_mapped_fus[:, 1:], f[:, None]), axis=1) return state.set(old_mapped_fus=mapped) + + +@utils.register_pytree_node +class Momentum: + """Momentum for Sinkhorn updates, either constant or adaptive.""" + + start: int = 0 + error_threshold: float = jnp.inf + value: float = 1.0 + inner_iterations: int = 1 + + def weight(self, state: 'sinkhorn.SinkhornState', iteration: int) -> float: + """Compute momentum term if needed, using previously seen errors.""" + if self.start == 0: + return self.value + idx = self.start // self.inner_iterations + + weight = jax.lax.cond( + jnp.logical_and( + iteration >= self.start, + state.errors[idx - 1, -1] < self.error_threshold + ), lambda state: self.lehmann(state), lambda state: self.value, state + ) + return weight + + def lehmann(self, state: 'sinkhorn.SinkhornState') -> float: + """Momentum formula :cite:`lehmann:21`, eq. 5.""" + idx = self.start // self.inner_iterations + error_ratio = jnp.minimum( + state.errors[idx - 1, -1] / state.errors[idx - 2, -1], 0.99 + ) + power = 1.0 / self.inner_iterations + return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power)) + + def __call__( + self, + weight: float, + value: jnp.ndarray, + new_value: jnp.ndarray, + lse_mode: bool = True + ) -> jnp.ndarray: + if lse_mode: + value = jnp.where(jnp.isfinite(value), value, 0.0) + return (1.0 - weight) * value + weight * new_value + else: + value = jnp.where(value > 0.0, value, 1.0) + return value ** (1.0 - weight) * new_value ** weight diff --git a/ott/core/continuous_barycenter.py b/ott/solvers/linear/continuous_barycenter.py similarity index 88% rename from ott/core/continuous_barycenter.py rename to ott/solvers/linear/continuous_barycenter.py index 8a328264c..715a72384 100644 --- a/ott/core/continuous_barycenter.py +++ b/ott/solvers/linear/continuous_barycenter.py @@ -20,8 +20,11 @@ import jax import jax.numpy as jnp -from ott.core import bar_problems, fixed_point_loop, linear_problems, was_solver from ott.geometry import pointcloud +from ott.math import fixed_point_loop +from ott.math import utils as mu +from ott.problems.linear import barycenter_problem, linear_problem +from ott.solvers import was_solver __all__ = ["BarycenterState", "WassersteinBarycenter"] @@ -36,8 +39,6 @@ class BarycenterState(NamedTuple): inner Sinkhorn iterations. errors: Holds sequence of vectors of errors of the Sinkhorn algorithm at each iteration. - linear_states: State used to solve and store solutions to the OT problems - from the barycenter to the measures. x: barycenter points. a: barycenter weights. """ @@ -53,7 +54,7 @@ def set(self, **kwargs: Any) -> 'BarycenterState': return self._replace(**kwargs) def update( - self, iteration: int, bar_prob: bar_problems.BarycenterProblem, + self, iteration: int, bar_prob: barycenter_problem.BarycenterProblem, linear_ot_solver: Any, store_errors: bool ) -> 'BarycenterState': seg_y, seg_b = bar_prob.segmented_y_b @@ -63,7 +64,7 @@ def solve_linear_ot( a: Optional[jnp.ndarray], x: jnp.ndarray, b: jnp.ndarray, y: jnp.ndarray ): out = linear_ot_solver( - linear_problems.LinearProblem( + linear_problem.LinearProblem( pointcloud.PointCloud( x, y, @@ -101,7 +102,7 @@ def solve_linear_ot( # Approximation of barycenter as barycenter of barycenters per measure. - barycenters_per_measure = bar_problems.barycentric_projection( + barycenters_per_measure = mu.barycentric_projection( matrices, seg_y, bar_prob.cost_fn ) @@ -123,7 +124,7 @@ class WassersteinBarycenter(was_solver.WassersteinSolver): def __call__( self, - bar_prob: bar_problems.BarycenterProblem, + bar_prob: barycenter_problem.BarycenterProblem, bar_size: int = 100, x_init: Optional[jnp.ndarray] = None, rng: int = 0 @@ -134,7 +135,7 @@ def __call__( def init_state( self, - bar_prob: bar_problems.BarycenterProblem, + bar_prob: barycenter_problem.BarycenterProblem, bar_size: int, x_init: Optional[jnp.ndarray] = None, # TODO(michalk8): change the API to pass the PRNG key directly @@ -148,7 +149,7 @@ def init_state( x_init: Initial barycenter estimate of shape ``[bar_size, ndim]``. If `None`, ``bar_size`` points will be sampled from the input measures according to their weights - :attr:`~ott.core.bar_problems.BarycenterProblem.flattened_y`. + :attr:`~ott.problems.linear.barycenter_problem.BarycenterProblem.flattened_y`. rng: Seed for :func:`jax.random.PRNGKey`. Returns: @@ -183,18 +184,20 @@ def init_state( ) def output_from_state(self, state: BarycenterState) -> BarycenterState: + # TODO(michalk8): create an output variable to match rest of the framework return state def iterations( solver: WassersteinBarycenter, bar_size: int, - bar_prob: bar_problems.BarycenterProblem, x_init: jnp.ndarray, rng: int + bar_prob: barycenter_problem.BarycenterProblem, x_init: jnp.ndarray, + rng: int ) -> BarycenterState: """Jittable Wasserstein barycenter outer loop.""" def cond_fn( iteration: int, constants: Tuple[WassersteinBarycenter, - bar_problems.BarycenterProblem], + barycenter_problem.BarycenterProblem], state: BarycenterState ) -> bool: solver, _ = constants @@ -202,7 +205,7 @@ def cond_fn( def body_fn( iteration, constants: Tuple[WassersteinBarycenter, - bar_problems.BarycenterProblem], + barycenter_problem.BarycenterProblem], state: BarycenterState, compute_error: bool ) -> BarycenterState: del compute_error # Always assumed True diff --git a/ott/core/discrete_barycenter.py b/ott/solvers/linear/discrete_barycenter.py similarity index 94% rename from ott/core/discrete_barycenter.py rename to ott/solvers/linear/discrete_barycenter.py index dd53806df..044295129 100644 --- a/ott/core/discrete_barycenter.py +++ b/ott/solvers/linear/discrete_barycenter.py @@ -13,23 +13,29 @@ # limitations under the License. # Lint as: python3 -"""Implementation of Janati+(2020) Wasserstein barycenter algorithm.""" +"""Implementation of :cite:`janati:20` Wasserstein barycenter algorithm.""" -import collections import functools -from typing import Optional, Sequence +from typing import NamedTuple, Optional, Sequence import jax import jax.numpy as jnp -from ott.core import fixed_point_loop, sinkhorn from ott.geometry import geometry +from ott.math import fixed_point_loop +from ott.solvers.linear import sinkhorn -SinkhornBarycenterOutput = collections.namedtuple( - 'Barycenter', ['f', 'g', 'histogram', 'errors'] -) +__all__ = ["SinkhornBarycenterOutput", "discrete_barycenter"] +class SinkhornBarycenterOutput(NamedTuple): + f: jnp.ndarray + g: jnp.ndarray + histogram: jnp.ndarray + errors: jnp.ndarray + + +# TODO(michalk8): refactor as a solver? def discrete_barycenter( geom: geometry.Geometry, a: jnp.ndarray, diff --git a/ott/core/implicit_differentiation.py b/ott/solvers/linear/implicit_differentiation.py similarity index 87% rename from ott/core/implicit_differentiation.py rename to ott/solvers/linear/implicit_differentiation.py index 8b68f10ef..8e5cf30c8 100644 --- a/ott/core/implicit_differentiation.py +++ b/ott/solvers/linear/implicit_differentiation.py @@ -13,42 +13,50 @@ # limitations under the License. """Functions entering the implicit differentiation of Sinkhorn.""" -from typing import Callable, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Optional, Tuple import jax import jax.numpy as jnp -from ott.core import dataclasses, linear_problems, unbalanced_functions +from ott import utils +from ott.math import unbalanced_functions +if TYPE_CHECKING: + from ott.problems.linear import linear_problem -@dataclasses.register_pytree_node +__all__ = ["ImplicitDiff"] + + +@utils.register_pytree_node class ImplicitDiff: """Implicit differentiation of Sinkhorn algorithm. - Attributes: - implicit_solver_fun: Callable, should return (solution, ...) + Args: + solver_fun: Callable, should return (solution, ...) ridge_kernel: promotes zero-sum solutions. only used if tau_a = tau_b = 1.0 ridge_identity: handles rank deficient transport matrices (this happens - typically when rows/cols in cost/kernel matrices are colinear, or, + typically when rows/cols in cost/kernel matrices are collinear, or, equivalently when two points from either measure are close). symmetric: flag used to figure out whether the linear system solved in the implicit function theorem is symmetric or not. This happens when either ``a == b`` or the precondition_fun is the identity. False by default, and, at the moment, needs to be set manually by the user in the more favorable case where the system is guaranteed to be symmetric. + precondition_fun: TODO(marcocuturi) """ - solver_fun: Callable = jax.scipy.sparse.linalg.cg # pylint: disable=g-bare-generic + solver_fun: Callable[[jnp.ndarray, jnp.ndarray], + Tuple[jnp.ndarray, ...]] = jax.scipy.sparse.linalg.cg ridge_kernel: float = 0.0 ridge_identity: float = 0.0 symmetric: bool = False - precondition_fun: Optional[Callable[[float], float]] = None + precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None def solve( self, gr: Tuple[jnp.ndarray, - jnp.ndarray], ot_prob: linear_problems.LinearProblem, + jnp.ndarray], ot_prob: "linear_problem.LinearProblem", f: jnp.ndarray, g: jnp.ndarray, lse_mode: bool - ): + ) -> jnp.ndarray: r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``]. This function is used to carry out implicit differentiation of ``sinkhorn`` @@ -58,9 +66,11 @@ def solve( Given a ``precondition_fun``, written here for short as :math:`h`, the first order conditions for the dual energy - :math:`E(K, \epsilon, a, b, f, g) :=- + - \langle\exp^{f/\epsilon}, K - \exp^{g/\epsilon}>` + + .. math:: + + E(K, \epsilon, a, b, f, g) :=- + - \langle\exp^{f/\epsilon}, K \exp^{g/\epsilon}> form the basis of the Sinkhorn algorithm. To differentiate optimal solutions to that problem, we exploit the fact that :math:`h(\nabla E = 0)` and @@ -91,7 +101,7 @@ def solve( application elementwise of :math:`h'` to the row (respectively column) marginal sum of the transport. - Note that we take great care in not instantiatiating these transport + Note that we take great care in not instantiating these transport matrices, to rely instead on calls to the ``app_transport`` method from the ``Geometry`` object ``geom`` (which will either use potentials or scalings, depending on ``lse_mode``) @@ -106,22 +116,19 @@ def solve( that subspace to enforce solutions have zero sum. The Schur complement can also be rank deficient if two lines or columns of T - are colinear. This will typically happen it two rows or columns of the cost + are collinear. This will typically happen it two rows or columns of the cost or kernel matrix are numerically close. To avoid this, we add a more global ``ridge_identity * z`` regularizer to achieve better conditioning. - These linear systems are solved using the user defined - ``implicit_solver_fun``, + These linear systems are solved using the user defined ``solver_fun``, which is set by default to ``cg``. When the system is symmetric (as detected by the corresponding flag ``symmetric``), ``cg`` is applied directly. When - it - is not, normal equations are used (i.e. the Schur complement is multiplied - by - its transpose before solving the system). + it is not, normal equations are used (i.e. the Schur complement is + multiplied by its transpose before solving the system). Args: gr: 2-tuple, (vector of size ``n``, vector of size ``m``). - ot_prob: the instantiation of the regularizad transport problem. + ot_prob: the instantiation of the regularized transport problem. f: potential, w.r.t marginal a. g: potential, w.r.t marginal b. lse_mode: bool, log-sum-exp mode if True, kernel else. @@ -269,9 +276,9 @@ def first_order_conditions( return jnp.concatenate((result_a, result_b)) def gradient( - self, prob: linear_problems.LinearProblem, f: jnp.ndarray, g: jnp.ndarray, - lse_mode: bool, gr: Tuple[jnp.ndarray, jnp.ndarray] - ) -> linear_problems.LinearProblem: + self, prob: "linear_problem.LinearProblem", f: jnp.ndarray, + g: jnp.ndarray, lse_mode: bool, gr: Tuple[jnp.ndarray, jnp.ndarray] + ) -> "linear_problem.LinearProblem": """Apply vjp to recover gradient in reverse mode differentiation.""" # Applies first part of vjp to gr: inverse part of implicit function theorem vjp_gr = self.solve(gr, prob, f, g, lse_mode) diff --git a/ott/core/sinkhorn.py b/ott/solvers/linear/sinkhorn.py similarity index 94% rename from ott/core/sinkhorn.py rename to ott/solvers/linear/sinkhorn.py index aee00d797..aebee0038 100644 --- a/ott/core/sinkhorn.py +++ b/ott/solvers/linear/sinkhorn.py @@ -21,14 +21,14 @@ import numpy as np from typing_extensions import Literal -from ott.core import anderson as anderson_lib -from ott.core import fixed_point_loop -from ott.core import implicit_differentiation as implicit_lib -from ott.core import initializers as init_lib -from ott.core import linear_problems -from ott.core import momentum as momentum_lib -from ott.core import potentials, unbalanced_functions from ott.geometry import geometry +from ott.initializers.linear import initializers as init_lib +from ott.math import fixed_point_loop, unbalanced_functions +from ott.problems.linear import linear_problem, potentials +from ott.solvers.linear import acceleration +from ott.solvers.linear import implicit_differentiation as implicit_lib + +__all__ = ["Sinkhorn", "SinkhornOutput"] class SinkhornState(NamedTuple): @@ -45,19 +45,19 @@ def set(self, **kwargs: Any) -> 'SinkhornState': return self._replace(**kwargs) def solution_error( - self, ot_prob: linear_problems.LinearProblem, norm_error: Sequence[int], + self, ot_prob: linear_problem.LinearProblem, norm_error: Sequence[int], lse_mode: bool ) -> jnp.ndarray: return solution_error(self.fu, self.gv, ot_prob, norm_error, lse_mode) def ent_reg_cost( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> float: return ent_reg_cost(self.fu, self.gv, ot_prob, lse_mode) def solution_error( - f_u: jnp.ndarray, g_v: jnp.ndarray, ot_prob: linear_problems.LinearProblem, + f_u: jnp.ndarray, g_v: jnp.ndarray, ot_prob: linear_problem.LinearProblem, norm_error: Sequence[int], lse_mode: bool ) -> jnp.ndarray: """Given two potential/scaling solutions, computes deviation to optimality. @@ -142,7 +142,7 @@ def marginal_error( def ent_reg_cost( - f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problems.LinearProblem, + f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> float: r"""Compute objective of Sinkhorn for OT problem given dual solutions. @@ -209,14 +209,14 @@ class SinkhornOutput(NamedTuple): g: Optional[jnp.ndarray] = None errors: Optional[jnp.ndarray] = None reg_ot_cost: Optional[float] = None - ot_prob: Optional[linear_problems.LinearProblem] = None + ot_prob: Optional[linear_problem.LinearProblem] = None def set(self, **kwargs: Any) -> 'SinkhornOutput': """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) def set_cost( - self, ot_prob: linear_problems.LinearProblem, lse_mode: bool, + self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool ) -> 'SinkhornOutput': f = jax.lax.stop_gradient(self.f) if use_danskin else self.f @@ -225,7 +225,7 @@ def set_cost( @property def linear(self) -> bool: - return isinstance(self.ot_prob, linear_problems.LinearProblem) + return isinstance(self.ot_prob, linear_problem.LinearProblem) @property def geom(self) -> geometry.Geometry: @@ -294,7 +294,7 @@ def transport_mass(self) -> float: def to_dual_potentials(self) -> potentials.EntropicPotentials: """Return the entropic map estimator.""" return potentials.EntropicPotentials( - self.f, self.g, self.geom, self.a, self.b + self.f, self.g, geom=self.geom, a=self.a, b=self.b ) @@ -304,7 +304,8 @@ class Sinkhorn: A Sinkhorn solver takes a linear OT problem object as an input and returns a SinkhornOutput object that contains all the information required to compute - transports. See :func:`~ott.core.sinkhorn.sinkhorn` for a functional wrapper. + transports. See :func:`~ott.solvers.linear.sinkhorn.sinkhorn` + for a functional wrapper. Args: lse_mode: ``True`` for log-sum-exp computations, ``False`` for kernel @@ -325,8 +326,8 @@ class Sinkhorn: unroll-able :func:`jax.lax.while_loop` that monitors convergence. In that case the error is not monitored and the ``converged`` flag will return ``False`` as a consequence. - momentum: a Momentum instance. See ott.core.momentum - anderson: an AndersonAcceleration instance. See ott.core.anderson. + momentum: Momentum instance. + anderson: AndersonAcceleration instance. implicit_diff: instance used to solve implicit differentiation. Unrolls iterations if None. parallel_dual_updates: updates potentials or scalings in parallel if True, @@ -352,8 +353,8 @@ def __init__( inner_iterations: int = 10, min_iterations: int = 0, max_iterations: int = 2000, - momentum: Optional[momentum_lib.Momentum] = None, - anderson: Optional[anderson_lib.AndersonAcceleration] = None, + momentum: Optional[acceleration.Momentum] = None, + anderson: Optional[acceleration.AndersonAcceleration] = None, parallel_dual_updates: bool = False, use_danskin: Optional[bool] = None, implicit_diff: Optional[implicit_lib.ImplicitDiff @@ -373,20 +374,20 @@ def __init__( self.implicit_diff = implicit_diff if momentum is not None: - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( momentum.start, momentum.error_threshold, momentum.value, self.inner_iterations ) else: # Use no momentum if using Anderson or unrolling. if self.anderson is not None or self.implicit_diff is None: - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( inner_iterations=self.inner_iterations ) # Use adaptive momentum from 300th iteration. Only do so # if error is already below threshold below. else: - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( start=300, error_threshold=1e-2, inner_iterations=self.inner_iterations @@ -404,7 +405,7 @@ def __init__( implicit_lib.ImplicitDiff() if self.implicit_diff is None else self.implicit_diff ) - self.momentum = momentum_lib.Momentum( + self.momentum = acceleration.Momentum( inner_iterations=self.inner_iterations ) @@ -415,7 +416,7 @@ def __init__( def __call__( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None), ) -> SinkhornOutput: """Run Sinkhorn algorithm. @@ -436,7 +437,7 @@ def __call__( return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) def lse_step( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, iteration: int ) -> SinkhornState: """Sinkhorn LSE update.""" @@ -458,7 +459,7 @@ def lse_step( return state.set(fu=fu, gv=gv) def kernel_step( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, iteration: int ) -> SinkhornState: """Sinkhorn multiplicative update.""" @@ -478,7 +479,7 @@ def kernel_step( return state.set(fu=fu, gv=gv) def one_iteration( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, iteration: int, compute_error: bool ) -> SinkhornState: """Carries out sinkhorn iteration. @@ -544,8 +545,8 @@ def outer_iterations(self) -> int: return np.ceil(self.max_iterations / self.inner_iterations).astype(int) def init_state( - self, ot_prob: linear_problems.LinearProblem, init: Tuple[jnp.ndarray, - jnp.ndarray] + self, ot_prob: linear_problem.LinearProblem, init: Tuple[jnp.ndarray, + jnp.ndarray] ) -> SinkhornState: """Return the initial state of the loop.""" fu, gv = init @@ -555,7 +556,7 @@ def init_state( return self.anderson.init_maps(ot_prob, state) if self.anderson else state def output_from_state( - self, ot_prob: linear_problems.LinearProblem, state: SinkhornState + self, ot_prob: linear_problem.LinearProblem, state: SinkhornState ) -> SinkhornOutput: """Create an output from a loop state. @@ -625,7 +626,7 @@ def tree_unflatten(cls, aux_data, children): def run( - ot_prob: linear_problems.LinearProblem, solver: Sinkhorn, + ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, init: Tuple[jnp.ndarray, ...] ) -> SinkhornOutput: """Run loop of the solver, outputting a state upgraded to an output.""" @@ -638,20 +639,20 @@ def run( def iterations( - ot_prob: linear_problems.LinearProblem, solver: Sinkhorn, + ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, init: Tuple[jnp.ndarray, ...] ) -> SinkhornOutput: """Jittable Sinkhorn loop. args contain initialization variables.""" def cond_fn( - iteration: int, const: Tuple[linear_problems.LinearProblem, Sinkhorn], + iteration: int, const: Tuple[linear_problem.LinearProblem, Sinkhorn], state: SinkhornState ) -> bool: _, solver = const return solver._continue(state, iteration) def body_fn( - iteration: int, const: Tuple[linear_problems.LinearProblem, Sinkhorn], + iteration: int, const: Tuple[linear_problem.LinearProblem, Sinkhorn], state: SinkhornState, compute_error: bool ) -> SinkhornState: ot_prob, solver = const @@ -675,10 +676,10 @@ def body_fn( def _iterations_taped( - ot_prob: linear_problems.LinearProblem, solver: Sinkhorn, + ot_prob: linear_problem.LinearProblem, solver: Sinkhorn, init: Tuple[jnp.ndarray, ...] ) -> Tuple[SinkhornOutput, Tuple[jnp.ndarray, jnp.ndarray, - linear_problems.LinearProblem, Sinkhorn]]: + linear_problem.LinearProblem, Sinkhorn]]: """Run forward pass of the Sinkhorn algorithm storing side information.""" state = iterations(ot_prob, solver, init) return state, (state.f, state.g, ot_prob, solver) @@ -750,16 +751,16 @@ def make( ) # If no params are passed, align default with that provide in Sinkhorn solver. if momentum is None and chg_momentum_from is None: - mom = momentum_lib.Momentum(start=300, error_threshold=1e-2) + mom = acceleration.Momentum(start=300, error_threshold=1e-2) elif momentum is None: - mom = momentum_lib.Momentum(start=chg_momentum_from) + mom = acceleration.Momentum(start=chg_momentum_from) elif chg_momentum_from is None: - mom = momentum_lib.Momentum(value=momentum) + mom = acceleration.Momentum(value=momentum) else: - mom = momentum_lib.Momentum(start=chg_momentum_from, value=momentum) + mom = acceleration.Momentum(start=chg_momentum_from, value=momentum) if anderson_acceleration > 0: - anderson = anderson_lib.AndersonAcceleration( + anderson = acceleration.AndersonAcceleration( memory=anderson_acceleration, refresh_every=refresh_anderson_frequency ) else: @@ -1100,5 +1101,5 @@ def sinkhorn( by the user. """ sink = make(**kwargs) - ot_prob = linear_problems.LinearProblem(geom, a, b, tau_a, tau_b) + ot_prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b) return sink(ot_prob, (init_dual_a, init_dual_b)) diff --git a/ott/core/sinkhorn_lr.py b/ott/solvers/linear/sinkhorn_lr.py similarity index 87% rename from ott/core/sinkhorn_lr.py rename to ott/solvers/linear/sinkhorn_lr.py index e27989f44..13f014818 100644 --- a/ott/core/sinkhorn_lr.py +++ b/ott/solvers/linear/sinkhorn_lr.py @@ -21,11 +21,14 @@ import jax.scipy as jsp from typing_extensions import Literal -from ott.core import _math_utils as mu -from ott.core import fixed_point_loop -from ott.core import initializers_lr as init_lib -from ott.core import linear_problems, sinkhorn from ott.geometry import geometry, low_rank, pointcloud +from ott.initializers.linear import initializers_lr as init_lib +from ott.math import fixed_point_loop +from ott.math import utils as mu +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + +__all__ = ["LRSinkhorn", "LRSinkhornOutput"] class LRSinkhornState(NamedTuple): @@ -48,13 +51,13 @@ def compute_error(self, previous_state: "LRSinkhornState") -> float: def reg_ot_cost( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, use_danskin: bool = False ) -> float: return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin) def solution_error( - self, ot_prob: linear_problems.LinearProblem, norm_error: Tuple[int, ...], + self, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...], lse_mode: bool ) -> jnp.ndarray: return solution_error(self.q, self.r, ot_prob, norm_error, lse_mode) @@ -68,7 +71,7 @@ def compute_reg_ot_cost( q: jnp.ndarray, r: jnp.ndarray, g: jnp.ndarray, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, use_danskin: bool = False ) -> float: q = jax.lax.stop_gradient(q) if use_danskin else q @@ -78,7 +81,7 @@ def compute_reg_ot_cost( def solution_error( - q: jnp.ndarray, r: jnp.ndarray, ot_prob: linear_problems.LinearProblem, + q: jnp.ndarray, r: jnp.ndarray, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...], lse_mode: bool ) -> jnp.ndarray: """Compute solution error. @@ -122,7 +125,7 @@ class LRSinkhornOutput(NamedTuple): # TODO(michalk8): must be called `errors`, because of `store_inner_errors` # in future, enforce via class hierarchy errors: jnp.ndarray - ot_prob: linear_problems.LinearProblem + ot_prob: linear_problem.LinearProblem # TODO(michalk8): Optional is an artifact of the current impl., refactor reg_ot_cost: Optional[float] = None @@ -132,7 +135,7 @@ def set(self, **kwargs: Any) -> 'LRSinkhornOutput': def set_cost( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool = False ) -> 'LRSinkhornOutput': @@ -141,14 +144,14 @@ def set_cost( def compute_reg_ot_cost( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, use_danskin: bool = False, ) -> float: return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin) @property def linear(self) -> bool: - return isinstance(self.ot_prob, linear_problems.LinearProblem) + return isinstance(self.ot_prob, linear_problem.LinearProblem) @property def geom(self) -> geometry.Geometry: @@ -214,38 +217,38 @@ class LRSinkhorn(sinkhorn.Sinkhorn): case. Args: - rank: the rank constraint on the coupling to minimize the linear OT problem - gamma: the (inverse of) gradient step size used by mirror descent. + rank: The rank constraint on the coupling to minimize the linear OT problem + gamma: The (inverse of) gradient step size used by mirror descent. gamma_rescale: Whether to rescale :math:`\gamma` every iteration as described in :cite:`scetbon:22b`. - epsilon: entropic regularization added on top of low-rank problem. + epsilon: Entropic regularization added on top of low-rank problem. initializer: How to initialize the :math:`Q`, :math:`R` and :math:`g` factors. Valid options are: - - `'random'` - :class:`~ott.core.initializers_lr.RandomInitializer`. - - `'rank2'` - :class:`~ott.core.initializers_lr.Rank2Initializer`. - - `'k-means'` - :class:`~ott.core.initializers_lr.KMeansInitializer`. - - `'generalized-k-means'` - - :class:`~ott.core.initializers_lr.GeneralizedKMeansInitializer`. + - `'random'` - :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. + - `'rank2'` - :class:`~ott.initializers.linear.initializers_lr.Rank2Initializer`. + - `'k-means'` - :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`. + - `'generalized-k-means'` - :class:`~ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer`. - If `None`, :class:`~ott.core.initializers_lr.KMeansInitializer` + If `None`, :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer` is used when the linear problem's geometry is :class:`~ott.geometry.pointcloud.PointCloud` or :class:`~ott.geometry.low_rank.LRCGeometry`. - Otherwise, use :class:`~ott.core.initializers_lr.RandomInitializer`. + Otherwise, use :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. - lse_mode: whether to run computations in lse or kernel mode. At the moment, + lse_mode: Whether to run computations in lse or kernel mode. At the moment, only ``lse_mode = True`` is implemented. - inner_iterations: number of inner iterations used by the algorithm before + inner_iterations: Number of inner iterations used by the algorithm before re-evaluating progress. - use_danskin: use Danskin theorem to evaluate gradient of objective w.r.t. + use_danskin: Use Danskin theorem to evaluate gradient of objective w.r.t. input parameters. Only `True` handled at this moment. implicit_diff: Whether to use implicit differentiation. Currently, only ``implicit_diff = False`` is implemented. - kwargs_dys: keyword arguments passed to :meth:`dykstra_update`. - kwargs_init: keyword arguments for - :class:`~ott.core.initializers_lr.LRInitializer`. - kwargs: Keyword arguments for :class:`~ott.core.sinkhorn.Sinkhorn`. + kwargs_dys: Keyword arguments passed to :meth:`dykstra_update`. + kwargs_init: Keyword arguments for + :class:`~ott.initializers.linear.initializers_lr.LRInitializer`. + kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. """ def __init__( @@ -265,8 +268,8 @@ def __init__( kwargs_init: Optional[Mapping[str, Any]] = None, **kwargs: Any, ): - assert lse_mode, "Kernel mode not yet implemented for LRSinkhorn." - assert not implicit_diff, "Implicit diff. not yet implemented for LRSink." + assert lse_mode, "Kernel mode not yet implemented." + assert not implicit_diff, "Implicit diff. not yet implemented." super().__init__( lse_mode=lse_mode, inner_iterations=inner_iterations, @@ -285,7 +288,7 @@ def __init__( def __call__( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None, None), key: Optional[jnp.ndarray] = None, @@ -297,9 +300,9 @@ def __call__( ot_prob: Linear OT problem. init: Initial values for the low-rank factors: - - :attr:`~ott.core.sinkhorn_lr.LRSinkhornOutput.q`. - - :attr:`~ott.core.sinkhorn_lr.LRSinkhornOutput.r`. - - :attr:`~ott.core.sinkhorn_lr.LRSinkhornOutput.g`. + - :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.q`. + - :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.r`. + - :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.g`. Any `None` values will be initialized using the initializer. key: Random key for seeding. @@ -316,7 +319,7 @@ def __call__( def _lr_costs( self, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]: log_q, log_r, log_g = ( @@ -353,7 +356,7 @@ def dykstra_update( c_r: jnp.ndarray, h: jnp.ndarray, gamma: float, - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, min_entry_value: float = 1e-6, tolerance: float = 1e-3, min_iter: int = 0, @@ -465,7 +468,7 @@ def recompute_couplings( return recompute_couplings(f1, g1_old, c_q, f2, g2_old, c_r, h_old, gamma) def lse_step( - self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, iteration: int ) -> LRSinkhornState: """LR Sinkhorn LSE update.""" @@ -476,7 +479,7 @@ def lse_step( return state.set(q=q, g=g, r=r, gamma=gamma) def kernel_step( - self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, iteration: int ) -> NoReturn: """Not implemented.""" @@ -484,7 +487,7 @@ def kernel_step( raise NotImplementedError("Not implemented.") def one_iteration( - self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState, + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState, iteration: int, compute_error: bool ) -> LRSinkhornState: """Carries out one LR sinkhorn iteration. @@ -539,7 +542,7 @@ def is_entropic(self) -> bool: return self.epsilon > 0. def create_initializer( - self, prob: linear_problems.LinearProblem + self, prob: linear_problem.LinearProblem ) -> init_lib.LRInitializer: """Create a low-rank Sinkhorn initializer. @@ -569,7 +572,7 @@ def create_initializer( return initializer def init_state( - self, ot_prob: linear_problems.LinearProblem, + self, ot_prob: linear_problem.LinearProblem, init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> LRSinkhornState: """Return the initial state of the loop.""" @@ -585,7 +588,7 @@ def init_state( ) def output_from_state( - self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState + self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState ) -> LRSinkhornOutput: """Create an output from a loop state. @@ -641,7 +644,7 @@ def _diverged(self, state: LRSinkhornState, iteration: int) -> bool: def run( - ot_prob: linear_problems.LinearProblem, + ot_prob: linear_problem.LinearProblem, solver: LRSinkhorn, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]], diff --git a/ott/solvers/nn/__init__.py b/ott/solvers/nn/__init__.py new file mode 100644 index 000000000..a695e215c --- /dev/null +++ b/ott/solvers/nn/__init__.py @@ -0,0 +1 @@ +from . import icnn, layers, neuraldual diff --git a/ott/core/icnn.py b/ott/solvers/nn/icnn.py similarity index 97% rename from ott/core/icnn.py rename to ott/solvers/nn/icnn.py index 8be945f4e..4e93bc91c 100644 --- a/ott/core/icnn.py +++ b/ott/solvers/nn/icnn.py @@ -13,7 +13,7 @@ # limitations under the License. # Lint as: python3 -"""Implementation of Amos+(2017) input convex neural networks (ICNN).""" +"""Implementation of :cite:`amos:17` input convex neural networks (ICNN).""" from typing import Any, Callable, Sequence, Tuple, Union @@ -24,8 +24,10 @@ from flax.training import train_state from jax.nn import initializers -from ott.core.layers import PosDefPotentials, PositiveDense -from ott.geometry import matrix_square_root +from ott.math import matrix_square_root +from ott.solvers.nn.layers import PosDefPotentials, PositiveDense + +__all__ = ["ICNN"] PRNGKey = Any Shape = Tuple[int] diff --git a/ott/core/layers.py b/ott/solvers/nn/layers.py similarity index 82% rename from ott/core/layers.py rename to ott/solvers/nn/layers.py index 5a4b0a396..770d966dd 100644 --- a/ott/core/layers.py +++ b/ott/solvers/nn/layers.py @@ -11,7 +11,7 @@ # limitations under the License. # Lint as: python3 -"""Layers used in input convex neural networks (Amos+(2017), Bunne+(2022)).""" +"""Layers used in input convex neural networks :cite:`amos:17,bunne:22`.""" from typing import Any, Callable, Tuple @@ -19,6 +19,8 @@ import jax.numpy as jnp from flax import linen as nn +__all__ = ["PositiveDense", "PosDefPotentials"] + PRNGKey = Any Shape = Tuple[int] Dtype = Any @@ -30,8 +32,8 @@ class PositiveDense(nn.Module): Args: dim_hidden: the number of output dim_hidden. - rectifier_fn: choice of rectiver function (default: softplus function). - inv_rectifier_fn: choice of inverse rectiver function + rectifier_fn: choice of rectifier function (default: softplus function). + inv_rectifier_fn: choice of inverse rectifier function (default: inverse softplus function). dtype: the dtype of the computation (default: float32). precision: numerical precision of computation see `jax.lax.Precision` @@ -40,8 +42,9 @@ class PositiveDense(nn.Module): bias_init: initializer function for the bias. """ dim_hidden: int - rectifier_fn: Callable = nn.softplus - inv_rectifier_fn: Callable = lambda x: jnp.log(jnp.exp(x) - 1) + rectifier_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.softplus + inv_rectifier_fn: Callable[[jnp.ndarray], + jnp.ndarray] = lambda x: jnp.log(jnp.exp(x) - 1) use_bias: bool = True dtype: Any = jnp.float32 precision: Any = None @@ -50,11 +53,11 @@ class PositiveDense(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs): + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """Applies a linear transformation to inputs along the last dimension. Args: - inputs: The nd-array to be transformed. + inputs: Array to be transformed. Returns: The transformed input. """ @@ -79,8 +82,8 @@ class PosDefPotentials(nn.Module): """A layer to output (0.5 [A_i A_i^T] (x - b_i)_i potentials. Args: - use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). + use_bias: whether to add a bias to the output. + dtype: the dtype of the computation. precision: numerical precision of computation see `jax.lax.Precision` for details. kernel_init: initializer function for the weight matrix. @@ -96,11 +99,12 @@ class PosDefPotentials(nn.Module): bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros @nn.compact - def __call__(self, inputs): - """Applies a few quadratic forms. + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Apply a few quadratic forms. Args: - inputs: The nd-array to be transformed (possibly batched) + inputs: Array to be transformed (possibly batched). + Returns: The transformed input. """ diff --git a/ott/core/neuraldual.py b/ott/solvers/nn/neuraldual.py similarity index 92% rename from ott/core/neuraldual.py rename to ott/solvers/nn/neuraldual.py index 615aca270..306afa5c2 100644 --- a/ott/core/neuraldual.py +++ b/ott/solvers/nn/neuraldual.py @@ -14,7 +14,7 @@ """A Jax implementation of the ICNN based Kantorovich dual.""" import warnings -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import flax.linen as nn import jax @@ -23,11 +23,13 @@ from flax import core from typing_extensions import Literal -from ott.core import icnn, potentials from ott.geometry import costs +from ott.problems.linear import potentials +from ott.solvers.nn import icnn -Train_t = Dict[Literal["training_logs", "validation_logs"], List[float]] -Potentials_t = potentials.DualPotentials +__all__ = ["NeuralDualSolver"] + +Train_t = Dict[Literal["train_logs", "valid_logs"], Dict[str, List[float]]] class NeuralDualSolver: @@ -133,11 +135,12 @@ def setup( def __call__( self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], - ) -> Union[Potentials_t, Tuple[Potentials_t, Train_t]]: + trainloader_source: Iterable[jnp.ndarray], + trainloader_target: Iterable[jnp.ndarray], + validloader_source: Iterable[jnp.ndarray], + validloader_target: Iterable[jnp.ndarray], + ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials, + Train_t]]: logs = self.train_neuraldual( trainloader_source, trainloader_target, @@ -150,10 +153,10 @@ def __call__( def train_neuraldual( self, - trainloader_source, - trainloader_target, - validloader_source, - validloader_target, + trainloader_source: Iterable[jnp.ndarray], + trainloader_target: Iterable[jnp.ndarray], + validloader_source: Iterable[jnp.ndarray], + validloader_target: Iterable[jnp.ndarray], ) -> Train_t: """Implementation of the training and validation script.""" # noqa: D401 try: @@ -307,7 +310,9 @@ def to_dual_potentials(self) -> potentials.DualPotentials: """Return the Kantorovich dual potentials from the trained potentials.""" f = lambda x: self.state_f.apply_fn({"params": self.state_f.params}, x) g = lambda x: self.state_g.apply_fn({"params": self.state_g.params}, x) - return potentials.DualPotentials(f, g, costs.SqEuclidean(), corr=True) + return potentials.DualPotentials( + f, g, cost_fn=costs.SqEuclidean(), corr=True + ) @staticmethod def _clip_weights_icnn(params): diff --git a/ott/solvers/quadratic/__init__.py b/ott/solvers/quadratic/__init__.py new file mode 100644 index 000000000..af9e5d01e --- /dev/null +++ b/ott/solvers/quadratic/__init__.py @@ -0,0 +1 @@ +from . import gromov_wasserstein, gw_barycenter diff --git a/ott/core/gromov_wasserstein.py b/ott/solvers/quadratic/gromov_wasserstein.py similarity index 88% rename from ott/core/gromov_wasserstein.py rename to ott/solvers/quadratic/gromov_wasserstein.py index 0f9f61305..0489b9680 100644 --- a/ott/core/gromov_wasserstein.py +++ b/ott/solvers/quadratic/gromov_wasserstein.py @@ -20,17 +20,16 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.core import ( - fixed_point_loop, - initializers_lr, - linear_problems, - quad_initializers, - quad_problems, - sinkhorn, - sinkhorn_lr, - was_solver, -) from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud +from ott.initializers.linear import initializers_lr +from ott.initializers.quadratic import initializers as quad_initializers +from ott.math import fixed_point_loop +from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_costs, quadratic_problem +from ott.solvers import was_solver +from ott.solvers.linear import sinkhorn, sinkhorn_lr + +__all__ = ["GWOutput", "GromovWasserstein", "gromov_wasserstein"] LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput] @@ -107,7 +106,7 @@ class GWState(NamedTuple): costs: jnp.ndarray linear_convergence: jnp.ndarray linear_state: LinearOutput - linear_pb: linear_problems.LinearProblem + linear_pb: linear_problem.LinearProblem old_transport_mass: float keys: Optional[jnp.ndarray] = None errors: Optional[jnp.ndarray] = None @@ -118,7 +117,7 @@ def set(self, **kwargs: Any) -> 'GWState': def update( self, iteration: int, linear_sol: LinearOutput, - linear_pb: linear_problems.LinearProblem, store_errors: bool, + linear_pb: linear_problem.LinearProblem, store_errors: bool, old_transport_mass: float ) -> 'GWState': costs = self.costs.at[iteration].set(linear_sol.reg_ot_cost) @@ -145,33 +144,33 @@ class GromovWasserstein(was_solver.WassersteinSolver): Args: args: Positional_arguments for - :class:`~ott.core.was_solver.WassersteinSolver`. + :class:`~ott.solvers.was_solver.WassersteinSolver`. warm_start: Whether to initialize (low-rank) Sinkhorn calls using values from the previous iteration. If `None`, warm starts are not used for standard Sinkhorn, but used for low-rank Sinkhorn. quad_initializer: Quadratic initializer. If the solver is entropic, - :class:`~ott.core.quad_initializers.QuadraticInitializer` is always used. - Otherwise, the quadratic initializer wraps low-rank Sinkhorn initializers: + :class:`~ott.initializers.quadratic.initializers.QuadraticInitializer` + is always used. Otherwise, the quadratic initializer wraps the low-rank + Sinkhorn initializers: - - `'random'` - :class:`~ott.core.initializers_lr.RandomInitializer`. - - `'rank2'` - :class:`~ott.core.initializers_lr.Rank2Initializer`. - - `'k-means'` - :class:`~ott.core.initializers_lr.KMeansInitializer`. - - `'generalized-k-means'` - - :class:`~ott.core.initializers_lr.GeneralizedKMeansInitializer`. + - `'random'` - :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. + - `'rank2'` - :class:`~ott.initializers.linear.initializers_lr.Rank2Initializer`. + - `'k-means'` - :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`. + - `'generalized-k-means'` - :class:`~ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer`. If `None`, the low-rank initializer will be selected in a problem-specific manner: - - if both :attr:`~ott.core.quad_problems.QuadraticProblem.geom_xx` and - :attr:`~ott.core.quad_problems.QuadraticProblem.geom_yy` are - :class:`~ott.geometry.pointcloud.PointCloud` or - :class:`~ott.geometry.low_rank.LRCGeometry`, - :class:`~ott.core.initializers_lr.KMeansInitializer` is used. - - otherwise, use :class:`~ott.core.initializers_lr.RandomInitializer`. + - if both :attr:`~ott.problems.quadratic.quadratic_problem.QuadraticProblem.geom_xx` + and :attr:`~ott.problems.quadratic.quadratic_problem.QuadraticProblem.geom_yy` + are :class:`~ott.geometry.pointcloud.PointCloud` or :class:`~ott.geometry.low_rank.LRCGeometry`, + :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer` + is used. + - otherwise, use :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. kwargs_init: Keyword arguments when creating the initializer. kwargs: Keyword arguments for - :class:`~ott.core.was_solver.WassersteinSolver`. + :class:`~ott.solvers.was_solver.WassersteinSolver`. """ def __init__( @@ -191,8 +190,8 @@ def __init__( def __call__( self, - prob: quad_problems.QuadraticProblem, - init: Optional[linear_problems.LinearProblem] = None, + prob: quadratic_problem.QuadraticProblem, + init: Optional[linear_problem.LinearProblem] = None, key: Optional[jnp.ndarray] = None, **kwargs: Any, ) -> GWOutput: @@ -240,8 +239,8 @@ def __call__( def init_state( self, - prob: quad_problems.QuadraticProblem, - init: linear_problems.LinearProblem, + prob: quadratic_problem.QuadraticProblem, + init: linear_problem.LinearProblem, key: jnp.ndarray, ) -> GWState: """Initialize the state of the Gromov-Wasserstein iterations. @@ -250,7 +249,7 @@ def init_state( prob: Quadratic OT problem. init: Initial linearization of the quadratic problem. key: Random key for low-rank initializers. Only used when - :attr:`warm_start` is `False`. + :attr:`warm_start` is `False`. Returns: The initial Gromov-Wasserstein state. @@ -292,7 +291,7 @@ def output_from_state(self, state: GWState) -> GWOutput: ) def create_initializer( - self, prob: quad_problems.QuadraticProblem + self, prob: quadratic_problem.QuadraticProblem ) -> quad_initializers.BaseQuadraticInitializer: """Create quadratic, possibly low-rank initializer. @@ -309,7 +308,7 @@ def create_initializer( assert isinstance( self.quad_initializer, quad_initializers.LRQuadraticInitializer ), f"Expected quadratic initializer to be low rank, " \ - f"found `{type(self.quad_initializer).__name___}`." + f"found `{type(self.quad_initializer).__name__}`." assert self.quad_initializer.rank == self.rank, \ f"Expected quadratic initializer of rank `{self.rank}`, " \ f"found `{self.quad_initializer.rank}`." @@ -345,8 +344,8 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: def iterations( solver: GromovWasserstein, - prob: quad_problems.QuadraticProblem, - init: linear_problems.LinearProblem, + prob: quadratic_problem.QuadraticProblem, + init: linear_problem.LinearProblem, key: jnp.ndarray, ) -> GWOutput: """Jittable Gromov-Wasserstein outer loop.""" @@ -469,7 +468,7 @@ def gromov_wasserstein( scale_cost: Optional[Union[bool, float, str]] = False, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, - loss: Union[Literal['sqeucl', 'kl'], quad_problems.GWLoss] = 'sqeucl', + loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl', tau_a: Optional[float] = 1.0, tau_b: Optional[float] = 1.0, gw_unbalanced_correction: bool = True, @@ -494,16 +493,14 @@ def gromov_wasserstein( - if `True`, use the default for each geometry. - if `False`, keep the original scaling in geometries. - if :class:`str`, use a specific method available in - :class:`ott.geometry.geometry.Geometry` or - :class:`ott.geometry.pointcloud.PointCloud`. + :class:`~ott.geometry.geometry.Geometry` or + :class:`~ott.geometry.pointcloud.PointCloud`. - if `None`, do not scale the cost matrices. a: jnp.ndarray[num_a,] or jnp.ndarray[batch,num_a] weights. b: jnp.ndarray[num_b,] or jnp.ndarray[batch,num_b] weights. loss: defaults to the square Euclidean distance. Can also pass 'kl' to define the GW loss as KL loss. - See :class:`~ott.core.gromov_wasserstein.GromovWasserstein` on how to pass - custom loss. tau_a: float between 0 and 1.0, parameter that controls the strength of the KL divergence constraint between the weights and marginals of the transport for the first view. If set to 1.0, then it is equivalent to a @@ -529,13 +526,13 @@ def gromov_wasserstein( geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with `'sqeucl'` cost. If :class:`float`, that tolerance is shared across all 3 geometries. - kwargs: Keyword arguments to - :class:`~ott.core.gromov_wasserstein.GromovWasserstein`. + kwargs: Keyword arguments for + :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`. Returns: A GromovWassersteinState named tuple. """ - prob = quad_problems.QuadraticProblem( + prob = quadratic_problem.QuadraticProblem( geom_xx, geom_yy, geom_xy=geom_xy, diff --git a/ott/core/gw_barycenter.py b/ott/solvers/quadratic/gw_barycenter.py similarity index 89% rename from ott/core/gw_barycenter.py rename to ott/solvers/quadratic/gw_barycenter.py index 18891afd1..f88e175cd 100644 --- a/ott/core/gw_barycenter.py +++ b/ott/solvers/quadratic/gw_barycenter.py @@ -4,20 +4,19 @@ import jax import jax.numpy as jnp -from ott.core import ( - bar_problems, - fixed_point_loop, - gromov_wasserstein, - linear_problems, - was_solver, -) from ott.geometry import pointcloud +from ott.math import fixed_point_loop +from ott.problems.linear import linear_problem +from ott.problems.quadratic import gw_barycenter +from ott.solvers import was_solver +from ott.solvers.quadratic import gromov_wasserstein __all__ = ["GWBarycenterState", "GromovWassersteinBarycenter"] class GWBarycenterState(NamedTuple): - """Holds the state of the :class:`~ott.core.bar_problems.GWBarycenterProblem`. + """Holds the state of the \ + :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`. Args: c: Barycenter cost matrix of shape ``[bar_size, bar_size]``. @@ -46,7 +45,7 @@ def set(self, **kwargs: Any) -> 'GWBarycenterState': @jax.tree_util.register_pytree_node_class class GromovWassersteinBarycenter(was_solver.WassersteinSolver): """Gromov-Wasserstein barycenter solver of the \ - :class:`~ott.core.bar_problems.GWBarycenterProblem`. + :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`. Args: epsilon: Entropy regulariser. @@ -58,7 +57,7 @@ class GromovWassersteinBarycenter(was_solver.WassersteinSolver): as its linear solver, at each iteration for each measure. quad_solver: The GW solver. kwargs: Keyword argument for - :class:`~ott.core.gromov_wasserstein.GromovWasserstein`. + :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`. Only used when ``quad_solver = None``. """ @@ -93,7 +92,7 @@ def __init__( self._quad_solver = gromov_wasserstein.GromovWasserstein(**kwargs) def __call__( - self, problem: bar_problems.GWBarycenterProblem, bar_size: int, + self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int, **kwargs: Any ) -> GWBarycenterState: """Solver the (fused) GW barycenter problem. @@ -113,7 +112,7 @@ def __call__( def init_state( self, - problem: bar_problems.GWBarycenterProblem, + problem: gw_barycenter.GWBarycenterProblem, bar_size: int, bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, @@ -131,9 +130,9 @@ def init_state( - :class:`jax.numpy.ndarray` - barycenter cost matrix of shape ``[bar_size, bar_size]``. Only used in the non-fused case. - - 2- :class:`tuple` of :class:`jax.numpy.ndarray` - the 1st array - corresponds to ``[bar_size, bar_size]`` cost matrix, - the 2nd array is ``[bar_size, ndim_fused]`` a feature matrix used in + - :class:`tuple` of :class:`jax.numpy.ndarray` - the 1st array + corresponds to a cost matrix of shape ``[bar_size, bar_size]``, + the 2nd array is a ``[bar_size, ndim_fused]`` feature matrix used in the fused case. a: An array of shape ``[bar_size,]`` containing the barycenter weights. @@ -188,7 +187,7 @@ def update_state( self, state: GWBarycenterState, iteration: int, - problem: bar_problems.GWBarycenterProblem, + problem: gw_barycenter.GWBarycenterProblem, store_errors: bool = True, ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: @@ -236,7 +235,7 @@ def solve_gw( def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState: """No-op.""" # TODO(michalk8): just for consistency with continuous barycenter - # will be refactored in the future + # will be refactored in the future to create an output return state def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: @@ -279,13 +278,13 @@ def init_transports( geom = pointcloud.PointCloud( x, y, epsilon=epsilon, src_mask=a > 0, tgt_mask=b > 0 ) - problem = linear_problems.LinearProblem(geom, a=a, b=b) + problem = linear_problem.LinearProblem(geom, a=a, b=b) return solver(problem).matrix def iterations( solver: GromovWassersteinBarycenter, - problem: bar_problems.GWBarycenterProblem, init_state: GWBarycenterState + problem: gw_barycenter.GWBarycenterProblem, init_state: GWBarycenterState ) -> GWBarycenterState: def cond_fn( @@ -297,7 +296,7 @@ def cond_fn( def body_fn( iteration, constants: Tuple[GromovWassersteinBarycenter, - bar_problems.GWBarycenterProblem], + gw_barycenter.GWBarycenterProblem], state: GWBarycenterState, compute_error: bool ) -> GWBarycenterState: del compute_error # always assumed true diff --git a/ott/core/was_solver.py b/ott/solvers/was_solver.py similarity index 87% rename from ott/core/was_solver.py rename to ott/solvers/was_solver.py index eab88c19a..c823ef1d4 100644 --- a/ott/core/was_solver.py +++ b/ott/solvers/was_solver.py @@ -14,17 +14,21 @@ # Lint as: python3 """A Jax version of the regularised GW Solver (Peyre et al. 2016).""" -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp -from ott.core import sinkhorn, sinkhorn_lr +if TYPE_CHECKING: + from ott.solvers.linear import continuous_barycenter, sinkhorn, sinkhorn_lr -State = Union[sinkhorn.SinkhornState, sinkhorn_lr.LRSinkhornState, - "continuous_barycenter.BarycenterState"] # noqa: F821 +__all__ = ["WassersteinSolver"] +State = Union["sinkhorn.SinkhornState", "sinkhorn_lr.LRSinkhornState", + "continuous_barycenter.BarycenterState"] + +# TODO(michalk8): refactor to have generic nested solver API @jax.tree_util.register_pytree_node_class class WassersteinSolver: """A generic solver for problems that use a linear reg-OT pb in inner loop.""" @@ -33,8 +37,8 @@ def __init__( self, epsilon: Optional[float] = None, rank: int = -1, - linear_ot_solver: Optional[Union[sinkhorn.Sinkhorn, - sinkhorn_lr.LRSinkhorn]] = None, + linear_ot_solver: Optional[Union["sinkhorn.Sinkhorn", + "sinkhorn_lr.LRSinkhorn"]] = None, min_iterations: int = 5, max_iterations: int = 50, threshold: float = 1e-3, @@ -42,6 +46,7 @@ def __init__( store_inner_errors: bool = False, **kwargs: Any, ): + from ott.solvers.linear import sinkhorn, sinkhorn_lr default_epsilon = 1.0 # Set epsilon value to default if needed, but keep track of whether None was # passed to handle the case where a linear_ot_solver is passed or not. diff --git a/ott/tools/gaussian_mixture/fit_gmm.py b/ott/tools/gaussian_mixture/fit_gmm.py index d09cf6288..112ac44d7 100644 --- a/ott/tools/gaussian_mixture/fit_gmm.py +++ b/ott/tools/gaussian_mixture/fit_gmm.py @@ -185,7 +185,7 @@ def fit_model_em( def _get_dist_sq(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: """Get the squared distance from each point to each loc.""" - def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray): + def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: return jnp.sum((points - loc[None]) ** 2., axis=-1) dist_sq_fn = jax.vmap(_dist_sq_one_loc, in_axes=(None, 0), out_axes=1) @@ -266,9 +266,9 @@ def initialize( key: jnp.ndarray, points: jnp.ndarray, point_weights: Optional[jnp.ndarray], - n_components: jnp.ndarray, - n_attempts=50, - verbose=False + n_components: int, + n_attempts: int = 50, + verbose: bool = False ) -> gaussian_mixture.GaussianMixture: """Initialize a GMM via K-means++ with retries on failure. @@ -289,7 +289,7 @@ def initialize( for attempt in range(n_attempts): key, subkey = jax.random.split(key) try: - gmm = from_kmeans_plusplus( + return from_kmeans_plusplus( key=subkey, points=points, point_weights=point_weights, @@ -297,6 +297,5 @@ def initialize( ) except ValueError: if verbose: - print(f'Failed to initialize, attempt {attempt}', flush=True) - return gmm - raise ValueError('Failed to initialize') + print(f'Failed to initialize, attempt {attempt}.', flush=True) + raise ValueError('Failed to initialize.') diff --git a/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index c5b8968c1..eccd3da68 100644 --- a/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -16,8 +16,8 @@ import jax import jax.numpy as jnp -from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.solvers.linear import sinkhorn from ott.tools.gaussian_mixture import gaussian_mixture diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py index c4f7ea077..53a8df1fc 100644 --- a/ott/tools/gaussian_mixture/scale_tril.py +++ b/ott/tools/gaussian_mixture/scale_tril.py @@ -18,7 +18,8 @@ import jax import jax.numpy as jnp -from ott.geometry import costs, matrix_square_root +from ott.geometry import costs +from ott.math import matrix_square_root from ott.tools.gaussian_mixture import linalg diff --git a/ott/tools/k_means.py b/ott/tools/k_means.py index 335d16dc8..4ad02723d 100644 --- a/ott/tools/k_means.py +++ b/ott/tools/k_means.py @@ -19,8 +19,8 @@ import jax.numpy as jnp from typing_extensions import Literal -from ott.core import fixed_point_loop from ott.geometry import costs, pointcloud +from ott.math import fixed_point_loop __all__ = ["k_means", "KMeansOutput"] diff --git a/ott/tools/segment_sinkhorn.py b/ott/tools/segment_sinkhorn.py index b770e06be..886430d07 100644 --- a/ott/tools/segment_sinkhorn.py +++ b/ott/tools/segment_sinkhorn.py @@ -15,10 +15,10 @@ from types import MappingProxyType from typing import Any, Mapping, Optional, Tuple -from jax import numpy as jnp +import jax.numpy as jnp -from ott.core import segment, sinkhorn -from ott.geometry import costs, pointcloud +from ott.geometry import costs, pointcloud, segment +from ott.solvers.linear import sinkhorn def segment_sinkhorn( @@ -29,9 +29,9 @@ def segment_sinkhorn( cost_fn: Optional[costs.CostFn] = None, segment_ids_x: Optional[jnp.ndarray] = None, segment_ids_y: Optional[jnp.ndarray] = None, - indices_are_sorted: Optional[bool] = None, - num_per_segment_x: Tuple[int] = None, - num_per_segment_y: Tuple[int] = None, + indices_are_sorted: bool = False, + num_per_segment_x: Optional[Tuple[int, ...]] = None, + num_per_segment_y: Optional[Tuple[int, ...]] = None, weights_x: Optional[jnp.ndarray] = None, weights_y: Optional[jnp.ndarray] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), @@ -56,22 +56,23 @@ def segment_sinkhorn( parallel. Args: - x: Array of input points, of shape [num_x, feature]. Multiple segments are - held in this single array. - y: Array of target points, of shape [num_y, feature]. - num_segments: Number of segments contained in x and y. Providing this number - is required for JIT compilation to work, see also - :func:`~ott.core.segment.segment_point_cloud`. + x: Array of input points, of shape `[num_x, feature]`. + Multiple segments are held in this single array. + y: Array of target points, of shape `[num_y, feature]`. + num_segments: Number of segments contained in `x` and `y`. + Providing this is required for JIT compilation to work, + see also :func:`~ott.geometry.segment.segment_point_cloud`. max_measure_size: Total size of measures after padding. Should ideally be set to an upper bound on points clouds processed with the segment - interface. Providing this number is required for JIT compilation to work. - cost_fn: Cost function, defaults to :class:`~ott.core.costs.SqEuclidean`. - segment_ids_x: **1st interface** The segment ID for which each row of x + interface. Providing this is required for JIT compilation to work. + cost_fn: Cost function, defaults to + :class:`~ott.geometry.costs.SqEuclidean`. + segment_ids_x: **1st interface** The segment ID for which each row of `x` belongs. This is a similar interface to `jax.ops.segment_sum`. - segment_ids_y: **1st interface** The segment ID for which each row of y + segment_ids_y: **1st interface** The segment ID for which each row of `y` belongs. indices_are_sorted: **1st interface** Whether `segment_ids_x` and - `segment_ids_y` are sorted. Default false. + `segment_ids_y` are sorted. num_per_segment_x: **2nd interface** Number of points in each segment in `x`. For example, [100, 20, 30] would imply that `x` is segmented into three arrays of length `[100]`, `[20]`, and `[30]` respectively. @@ -87,9 +88,9 @@ def segment_sinkhorn( `y`/`y` (except when `static_b` is `True`, in which case `y`/`y` is not evaluated). kwargs: keywords arguments passed to form - :class:`ott.geometry.pointcloud.PointCloud` geometry objects from the + :class:`~ott.geometry.pointcloud.PointCloud` geometry objects from the subsets of points and masses selected in `x` and `y`, possibly a - :class:`ott.geometry.costs.CostFn` or an entropy regularizer. + :class:`~ott.geometry.costs.CostFn` or an entropy regularizer. Returns: An array of sinkhorn reg_ot_cost for each segment. @@ -98,9 +99,9 @@ def segment_sinkhorn( dim = x.shape[1] if cost_fn is None: # default padder - padding_vector = costs.CostFn.padder(dim=dim) + padding_vector = costs.CostFn._padder(dim=dim) else: - padding_vector = cost_fn.padder(dim=dim) + padding_vector = cost_fn._padder(dim=dim) def eval_fn( padded_x: jnp.ndarray, diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index faa3e1ba0..28c32fa82 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -17,8 +17,9 @@ import jax.numpy as jnp -from ott.core import potentials, segment, sinkhorn -from ott.geometry import costs, geometry, pointcloud +from ott.geometry import costs, geometry, pointcloud, segment +from ott.problems.linear import potentials +from ott.solvers.linear import sinkhorn __all__ = [ "sinkhorn_divergence", "segment_sinkhorn_divergence", @@ -69,8 +70,9 @@ def sinkhorn_divergence( match that of `b` to converge. b: the weight of each target point. The sum of all elements of `b` must match that of `a` to converge. - sinkhorn_kwargs: keywords arguments for :func:`~ott.core.sinkhorn.sinkhorn` - that is called twice if ``static_b = True`` else 3 times. + sinkhorn_kwargs: keywords arguments for + :func:`~ott.solvers.linear.sinkhorn.sinkhorn` that is called twice + if ``static_b = True`` else 3 times. static_b: if True, divergence of measure `b` against itself is **not** computed. share_epsilon: if True, enforces that the same epsilon regularizer is shared @@ -137,7 +139,7 @@ def _sinkhorn_divergence( all elements of b must match that of a to converge. symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for symmetric terms comparing x/x and y/y. - kwargs: Keyword arguments to :func:`ott.core.sinkhorn.sinkhorn`. + kwargs: Keyword arguments to :func:`~ott.solvers.linear.sinkhorn.sinkhorn`. Returns: SinkhornDivergenceOutput named tuple. @@ -189,9 +191,9 @@ def segment_sinkhorn_divergence( cost_fn: Optional[costs.CostFn] = None, segment_ids_x: Optional[jnp.ndarray] = None, segment_ids_y: Optional[jnp.ndarray] = None, - indices_are_sorted: Optional[bool] = None, - num_per_segment_x: Optional[jnp.ndarray] = None, - num_per_segment_y: Optional[jnp.ndarray] = None, + indices_are_sorted: bool = False, + num_per_segment_x: Optional[Tuple[int, ...]] = None, + num_per_segment_y: Optional[Tuple[int, ...]] = None, weights_x: Optional[jnp.ndarray] = None, weights_y: Optional[jnp.ndarray] = None, sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), @@ -218,23 +220,24 @@ def segment_sinkhorn_divergence( a tensor, and `vmap` used to evaluate sinkhorn divergences in parallel. Args: - x: Array of input points, of shape [num_x, feature]. Multiple segments are - held in this single array. - y: Array of target points, of shape [num_y, feature]. - num_segments: Number of segments contained in x and y. Providing this number - is required for JIT compilation to work, see also - :func:`~ott.core.segment.segment_point_cloud`. + x: Array of input points, of shape `[num_x, feature]`. + Multiple segments are held in this single array. + y: Array of target points, of shape `[num_y, feature]`. + num_segments: Number of segments contained in `x` and `y`. + Providing this is required for JIT compilation to work, + see also :func:`~ott.geometry.segment.segment_point_cloud`. max_measure_size: Total size of measures after padding. Should ideally be set to an upper bound on points clouds processed with the segment interface. Should also be smaller than total length of `x` or `y`. - Providing this number is required for JIT compilation to work. - cost_fn: Cost function, defaults to :class:`~ott.core.costs.SqEuclidean`. - segment_ids_x: **1st interface** The segment ID for which each row of x + Providing this is required for JIT compilation to work. + cost_fn: Cost function, + defaults to :class:`~ott.geometry.costs.SqEuclidean`. + segment_ids_x: **1st interface** The segment ID for which each row of `x` belongs. This is a similar interface to :func:`jax.ops.segment_sum`. - segment_ids_y: **1st interface** The segment ID for which each row of y + segment_ids_y: **1st interface** The segment ID for which each row of `y` belongs. indices_are_sorted: **1st interface** Whether `segment_ids_x` and - `segment_ids_y` are sorted. Default false. + `segment_ids_y` are sorted. num_per_segment_x: **2nd interface** Number of points in each segment in `x`. For example, [100, 20, 30] would imply that `x` is segmented into three arrays of length `[100]`, `[20]`, and `[30]` respectively. @@ -259,7 +262,7 @@ def segment_sinkhorn_divergence( symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for symmetric terms comparing x/x and y/y. kwargs: keywords arguments passed to form - :class:`ott.geometry.pointcloud.PointCloud` geometry objects from the + :class:`~ott.geometry.pointcloud.PointCloud` geometry objects from the subsets of points and masses selected in `x` and `y`, this could be for instance entropy regularization float, scheduler or normalization. Returns: @@ -269,9 +272,9 @@ def segment_sinkhorn_divergence( dim = x.shape[1] if cost_fn is None: # default padder - padding_vector = costs.CostFn.padder(dim=dim) + padding_vector = costs.CostFn._padder(dim=dim) else: - padding_vector = cost_fn.padder(dim=dim) + padding_vector = cost_fn._padder(dim=dim) def eval_fn( padded_x: jnp.ndarray, diff --git a/ott/tools/transport.py b/ott/tools/transport.py index 452bff1dc..1f25a8d8f 100644 --- a/ott/tools/transport.py +++ b/ott/tools/transport.py @@ -21,30 +21,35 @@ >>> ot = ott.transport.solve(x, y) >>> Tz = ot.apply(z) -Even if the transport.solve sole function can support many complex use cases, we -suggest more advanced users to instantiate directly their problem (see -ott.core.problems) and their solvers (see ott.core.sinkhorn and -ott.core.gromov_wasserstein) for better control over the parameters. +Even if the `transport.solve` sole function can support many complex use cases, +we suggest more advanced users to instantiate directly their :mod:`ott.problems` +and their :mod:`ott.solvers` for better control over the parameters. """ -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple, Optional, Union import jax.numpy as jnp +import numpy as np from typing_extensions import Literal -from ott.core import gromov_wasserstein, linear_problems, problems, sinkhorn -from ott.geometry import geometry +from ott.geometry import geometry, pointcloud +from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_problem +from ott.solvers.linear import sinkhorn +from ott.solvers.quadratic import gromov_wasserstein + +__all__ = ["Transport"] class Transport(NamedTuple): - """Implement a core.problems.Transport interface to transport solutions.""" + """Transport interface to transport solutions.""" problem: Any = None solver_output: Any = None @property def linear(self) -> bool: - return isinstance(self.problem, linear_problems.LinearProblem) + return isinstance(self.problem, linear_problem.LinearProblem) @property def geom(self) -> geometry.Geometry: @@ -109,7 +114,7 @@ def solve( fused_penalty = kwargs.pop('fused_penalty', None) eps_keys = ['epsilon', 'init', 'target', 'decay'] pb_kwargs = {k: v for k, v in kwargs.items() if k in eps_keys} - pb = problems.make( + pb = make( *args, objective=objective, a=a, @@ -120,7 +125,7 @@ def solve( fused_penalty=fused_penalty, **pb_kwargs ) - linear = isinstance(pb, linear_problems.LinearProblem) + linear = isinstance(pb, linear_problem.LinearProblem) solver_fn = sinkhorn.make if linear else gromov_wasserstein.make geom_keys = ['cost_fn', 'online'] @@ -130,3 +135,76 @@ def solve( solver = solver_fn(**kwargs) output = solver(pb, (init_dual_a, init_dual_b)) return Transport(pb, output) + + +def make( + *args: Union[jnp.ndarray, geometry.Geometry, linear_problem.LinearProblem, + quadratic_problem.QuadraticProblem], + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, + tau_a: float = 1.0, + tau_b: float = 1.0, + objective: Optional[str] = None, + gw_unbalanced_correction: Optional[bool] = True, + fused_penalty: Optional[float] = None, + scale_cost: Optional[Union[bool, float, str]] = False, + **kwargs: Any, +): + """Make a problem from arrays, assuming PointCloud geometries.""" + if isinstance(args[0], (jnp.ndarray, np.ndarray)): + x = args[0] + y = args[1] if len(args) > 1 else args[0] + if ((objective == 'linear') or + (objective is None and x.shape[1] == y.shape[1])): # noqa: E129 + geom_xy = pointcloud.PointCloud(x, y, **kwargs) + return linear_problem.LinearProblem( + geom_xy, a=a, b=b, tau_a=tau_a, tau_b=tau_b + ) + elif ((objective == 'quadratic') or + (objective is None and x.shape[1] != y.shape[1])): + geom_xx = pointcloud.PointCloud(x, x, **kwargs) + geom_yy = pointcloud.PointCloud(y, y, **kwargs) + return quadratic_problem.QuadraticProblem( + geom_xx=geom_xx, + geom_yy=geom_yy, + geom_xy=None, + scale_cost=scale_cost, + a=a, + b=b, + tau_a=tau_a, + tau_b=tau_b, + gw_unbalanced_correction=gw_unbalanced_correction + ) + elif objective == 'fused': + geom_xx = pointcloud.PointCloud(x, x, **kwargs) + geom_yy = pointcloud.PointCloud(y, y, **kwargs) + geom_xy = pointcloud.PointCloud(x, y, **kwargs) + return quadratic_problem.QuadraticProblem( + geom_xx=geom_xx, + geom_yy=geom_yy, + geom_xy=geom_xy, + fused_penalty=fused_penalty, + scale_cost=scale_cost, + a=a, + b=b, + tau_a=tau_a, + tau_b=tau_b, + gw_unbalanced_correction=gw_unbalanced_correction + ) + else: + raise ValueError(f'Unknown transport problem `{objective}`') + elif isinstance(args[0], geometry.Geometry): + if len(args) == 1: + return linear_problem.LinearProblem( + *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b + ) + return quadratic_problem.QuadraticProblem( + *args, a=a, b=b, tau_a=tau_a, tau_b=tau_b, scale_cost=scale_cost + ) + elif isinstance( + args[0], + (linear_problem.LinearProblem, quadratic_problem.QuadraticProblem) + ): + return args[0] + else: + raise ValueError('Cannot instantiate a transport problem.') diff --git a/ott/types.py b/ott/types.py new file mode 100644 index 000000000..28e746713 --- /dev/null +++ b/ott/types.py @@ -0,0 +1,22 @@ +import jax.numpy as jnp +from typing_extensions import Protocol + +# TODO(michalk8): introduce additional types here + + +class Transport(Protocol): + """Interface for the solution of a transport problem. + + Classes implementing those function do not have to inherit from it, the + class can however be used in type hints to support duck typing. + """ + + @property + def matrix(self) -> jnp.ndarray: + ... + + def apply(self, inputs: jnp.ndarray, axis: int) -> jnp.ndarray: + ... + + def marginal(self, axis: int = 0) -> jnp.ndarray: + ... diff --git a/ott/core/dataclasses.py b/ott/utils.py similarity index 96% rename from ott/core/dataclasses.py rename to ott/utils.py index 78aebac51..5f3de2ef6 100644 --- a/ott/core/dataclasses.py +++ b/ott/utils.py @@ -17,6 +17,8 @@ import jax +__all__ = ["register_pytree_node"] + def register_pytree_node(cls: type) -> type: """Register dataclasses as pytree_nodes.""" diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py deleted file mode 100644 index 01df6ad6d..000000000 --- a/tests/core/initializers_test.py +++ /dev/null @@ -1,590 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Tests for Sinkhorn initializers.""" -import functools - -import jax -import jax.numpy as jnp -import numpy as np -import pytest - -from ott.core import gromov_wasserstein -from ott.core import initializers as init_lib -from ott.core import ( - initializers_lr, - linear_problems, - quad_initializers, - quad_problems, - sinkhorn, - sinkhorn_lr, -) -from ott.geometry import geometry, low_rank, pointcloud - - -def create_sorting_problem(rng, n, epsilon=0.01, online=False): - # define ot problem - x_init = jnp.array([-1., 0, .22]) - y_init = jnp.array([0., 0, 1.1]) - x_rng, y_rng = jax.random.split(rng) - - x = jnp.concatenate([x_init, 10 + jnp.abs(jax.random.normal(x_rng, (n,)))]) - y = jnp.concatenate([y_init, 10 + jnp.abs(jax.random.normal(y_rng, (n,)))]) - - x = jnp.sort(x) - y = jnp.sort(y) - - n = len(x) - m = len(y) - a = jnp.ones(n) / n - b = jnp.ones(m) / m - - batch_size = 3 if online else None - geom = pointcloud.PointCloud( - x.reshape(-1, 1), - y.reshape(-1, 1), - epsilon=epsilon, - batch_size=batch_size - ) - ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) - - return ot_problem - - -def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): - # define ot problem - x_rng, y_rng = jax.random.split(rng) - - mu_a = jnp.array([-1, 1]) * 5 - mu_b = jnp.array([0, 0]) - - x = jax.random.normal(x_rng, (n, d)) + mu_a - y = jax.random.normal(y_rng, (m, d)) + mu_b - - a = jnp.ones(n) / n - b = jnp.ones(m) / m - - batch_size = 3 if online else None - geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size) - - ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b) - return ot_problem - - -# define sinkhorn functions -@functools.partial(jax.jit, static_argnames=['lse_mode', 'vector_min']) -def run_sinkhorn_sort_init( - x, y, a=None, b=None, epsilon=0.01, vector_min=True, lse_mode=True -): - geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - sort_init = init_lib.SortingInitializer(vectorized_update=vector_min) - out = sinkhorn.sinkhorn( - geom, a=a, b=b, jit=True, initializer=sort_init, lse_mode=lse_mode - ) - return out - - -@functools.partial(jax.jit, static_argnames=['lse_mode']) -def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): - geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - out = sinkhorn.sinkhorn(geom, a=a, b=b, jit=True, lse_mode=lse_mode) - return out - - -@functools.partial(jax.jit, static_argnames=['lse_mode']) -def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): - geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - out = sinkhorn.sinkhorn( - geom, - a=a, - b=b, - jit=True, - initializer=init_lib.GaussianInitializer(), - lse_mode=lse_mode - ) - return out - - -@pytest.mark.fast -class TestSinkhornInitializers: - - def test_init_pytree(self): - - @jax.jit - def init_sort(): - init = init_lib.SortingInitializer() - return init - - @jax.jit - def init_gaus(): - init = init_lib.GaussianInitializer() - return init - - _ = init_gaus() - _ = init_sort() - - @pytest.mark.parametrize( - "init", [ - "default", "gaussian", "sorting", - init_lib.DefaultInitializer(), "non-existent" - ] - ) - def test_create_initializer(self, init: str): - solver = sinkhorn.Sinkhorn(initializer=init) - expected_types = { - "default": init_lib.DefaultInitializer, - "gaussian": init_lib.GaussianInitializer, - "sorting": init_lib.SortingInitializer, - } - - if isinstance(init, init_lib.SinkhornInitializer): - assert solver.create_initializer() is init - elif init == "non-existent": - with pytest.raises(NotImplementedError, match=r""): - _ = solver.create_initializer() - else: - actual = solver.create_initializer() - expected_type = expected_types[init] - assert isinstance(actual, expected_type) - - @pytest.mark.parametrize( - "vector_min, lse_mode", [(True, True), (True, False), (False, True)] - ) - def test_sorting_init(self, vector_min: bool, lse_mode: bool): - """Tests sorting dual initializer.""" - rng = jax.random.PRNGKey(42) - n = 500 - epsilon = 0.01 - - ot_problem = create_sorting_problem( - rng=rng, n=n, epsilon=epsilon, online=False - ) - # run sinkhorn - sink_out_base = run_sinkhorn( - x=ot_problem.geom.x, - y=ot_problem.geom.y, - a=ot_problem.a, - b=ot_problem.b, - epsilon=epsilon - ) - base_num_iter = jnp.sum(sink_out_base.errors > -1) - - sink_out_init = run_sinkhorn_sort_init( - x=ot_problem.geom.x, - y=ot_problem.geom.y, - a=ot_problem.a, - b=ot_problem.b, - epsilon=epsilon, - vector_min=vector_min, - lse_mode=lse_mode - ) - sort_num_iter = jnp.sum(sink_out_init.errors > -1) - - # check initializer is better or equal - if lse_mode: - assert base_num_iter >= sort_num_iter - - def test_sorting_init_online(self, rng: jnp.ndarray): - n = 100 - epsilon = 0.01 - - ot_problem = create_sorting_problem( - rng=rng, n=n, epsilon=epsilon, online=True - ) - sort_init = init_lib.SortingInitializer(vectorized_update=True) - with pytest.raises(AssertionError, match=r"online"): - sort_init.init_dual_a(ot_problem, lse_mode=True) - - def test_sorting_init_square_cost(self, rng: jnp.ndarray): - n = 100 - m = 150 - d = 1 - epsilon = 0.01 - - ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - sort_init = init_lib.SortingInitializer(vectorized_update=True) - with pytest.raises(AssertionError, match=r"square"): - sort_init.init_dual_a(ot_problem, lse_mode=True) - - def test_default_initializer(self, rng: jnp.ndarray): - """Tests default initializer""" - n = 200 - m = 200 - d = 2 - epsilon = 0.01 - - ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - - default_potential_a = init_lib.DefaultInitializer().init_dual_a( - ot_problem, lse_mode=True - ) - default_potential_b = init_lib.DefaultInitializer().init_dual_b( - ot_problem, lse_mode=True - ) - - # check default is 0 - np.testing.assert_array_equal(0., default_potential_a) - np.testing.assert_array_equal(0., default_potential_b) - - def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): - n = 200 - m = 200 - d = 2 - epsilon = 0.01 - - ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - - gaus_init = init_lib.GaussianInitializer() - new_geom = geometry.Geometry( - cost_matrix=ot_problem.geom.cost_matrix, epsilon=epsilon - ) - ot_problem = linear_problems.LinearProblem( - geom=new_geom, a=ot_problem.a, b=ot_problem.b - ) - - with pytest.raises(AssertionError, match=r"point cloud"): - gaus_init.init_dual_a(ot_problem, lse_mode=True) - - @pytest.mark.parametrize('lse_mode', [True, False]) - def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray): - """Tests Gaussian initializer""" - # define OT problem - n = 200 - m = 200 - d = 2 - epsilon = 0.01 - - ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - - # run sinkhorn - sink_out = run_sinkhorn( - x=ot_problem.geom.x, - y=ot_problem.geom.y, - a=ot_problem.a, - b=ot_problem.b, - epsilon=epsilon, - lse_mode=lse_mode - ) - base_num_iter = jnp.sum(sink_out.errors > -1) - sink_out = run_sinkhorn_gaus_init( - x=ot_problem.geom.x, - y=ot_problem.geom.y, - a=ot_problem.a, - b=ot_problem.b, - epsilon=epsilon, - lse_mode=lse_mode - ) - gaus_num_iter = jnp.sum(sink_out.errors > -1) - - # check initializer is better - if lse_mode: - assert base_num_iter >= gaus_num_iter - - @pytest.mark.parametrize('lse_mode', [True, False]) - def test_meta_initializer(self, lse_mode, rng: jnp.ndarray): - """Tests Meta initializer""" - # define OT problem - n = 200 - m = 200 - d = 2 - epsilon = 0.01 - - ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - a = ot_problem.a - b = ot_problem.b - geom = ot_problem.geom - - # run sinkhorn - sink_out = run_sinkhorn( - x=ot_problem.geom.x, - y=ot_problem.geom.y, - a=ot_problem.a, - b=ot_problem.b, - epsilon=epsilon, - lse_mode=lse_mode - ) - base_num_iter = jnp.sum(sink_out.errors > -1) - - # Overfit the initializer to the problem. - meta_initializer = init_lib.MetaInitializer(geom) - for _ in range(100): - _, _, meta_initializer.state = meta_initializer.update( - meta_initializer.state, a=a, b=b - ) - - sink_out = sinkhorn.sinkhorn( - geom, - a=a, - b=b, - jit=True, - initializer=meta_initializer, - lse_mode=lse_mode - ) - meta_num_iter = jnp.sum(sink_out.errors > -1) - - # check initializer is better - if lse_mode: - assert base_num_iter >= meta_num_iter - - -class TestLRInitializers: - - @pytest.mark.fast.with_args("kind", ["pc", "lrc", "geom"], only_fast=0) - def test_create_default_initializer(self, rng: jnp.ndarray, kind: str): - n, d, rank = 110, 2, 3 - x = jax.random.normal(rng, (n, d)) - geom = pointcloud.PointCloud(x) - - if kind == "pc": - pass - elif kind == "lrc": - geom = geom.to_LRCGeometry() - assert isinstance(geom, low_rank.LRCGeometry) - elif kind == "geom": - geom = geometry.Geometry(geom.cost_matrix) - else: - raise NotImplementedError(geom) - prob = linear_problems.LinearProblem(geom) - - solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=None) - initializer = solver.create_initializer(prob) - - assert initializer.rank == rank - if kind in ("pc", "lrc"): - assert isinstance(initializer, initializers_lr.KMeansInitializer) - else: - assert isinstance(initializer, initializers_lr.RandomInitializer) - - q, r, g = initializer(prob) - - assert q.shape == (n, rank) - assert r.shape == (n, rank) - assert g.shape == (rank,) - - def test_explicitly_passing_initializer(self): - rank = 2 - initializer = initializers_lr.RandomInitializer(rank=rank) - solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) - - assert solver.create_initializer(prob="not used") is initializer - - @pytest.mark.parametrize( - "initializer", ["random", "rank2", "k-means", "generalized-k-means"] - ) - @pytest.mark.parametrize("partial_init", ["q", "r", "g"]) - def test_partial_initialization( - self, rng: jnp.ndarray, initializer: str, partial_init: str - ): - n, d, rank = 100, 10, 6 - key1, key2, key3, key4 = jax.random.split(rng, 4) - x = jax.random.normal(key1, (n, d)) - pc = pointcloud.PointCloud(x, epsilon=5e-1) - prob = linear_problems.LinearProblem(pc) - q_init = jax.random.normal(key2, (n, rank)) - r_init = jax.random.normal(key2, (n, rank)) - g_init = jax.random.normal(key2, (rank,)) - - solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) - initializer = solver.create_initializer(prob) - - if partial_init == "q": - q, _, _ = initializer(prob, q=q_init) - np.testing.assert_array_equal(q, q_init) - elif partial_init == "r": - _, r, _ = initializer(prob, r=r_init) - np.testing.assert_array_equal(r, r_init) - elif partial_init == "g": - _, _, g = initializer(prob, g=g_init) - np.testing.assert_array_equal(g, g_init) - else: - raise NotImplementedError(partial_init) - - @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True) - def test_generalized_k_means_has_correct_rank( - self, rng: jnp.ndarray, rank: int - ): - n, d = 100, 10 - x = jax.random.normal(rng, (n, d)) - pc = pointcloud.PointCloud(x, epsilon=5e-1) - prob = linear_problems.LinearProblem(pc) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="generalized-k-means" - ) - initializer = solver.create_initializer(prob) - - q, r, g = initializer(prob) - - assert jnp.linalg.matrix_rank(q) == rank - assert jnp.linalg.matrix_rank(r) == rank - - def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): - n, d, rank = 120, 15, 5 - eps = 1e-1 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = jax.random.normal(key1, (n, d)) - - pc = pointcloud.PointCloud(x, y, epsilon=eps) - geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps) - pc_problem = linear_problems.LinearProblem(pc) - geom_problem = linear_problems.LinearProblem(geom) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="k-means", max_iterations=5000 - ) - pc_out = solver(pc_problem) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="generalized-k-means", max_iterations=5000 - ) - geom_out = solver(geom_problem) - - with pytest.raises(AssertionError): - np.testing.assert_allclose(pc_out.costs, geom_out.costs) - - np.testing.assert_allclose( - pc_out.reg_ot_cost, geom_out.reg_ot_cost, atol=0.5, rtol=0.02 - ) - - @pytest.mark.parametrize("epsilon", [0., 1e-1]) - def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): - n, d, rank = 81, 13, 3 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = jax.random.normal(key2, (n, d)) - pc = pointcloud.PointCloud(x, y, epsilon=5e-1) - prob = linear_problems.LinearProblem(pc) - - solver_random = sinkhorn_lr.LRSinkhorn( - rank=rank, epsilon=epsilon, initializer="random", max_iterations=10000 - ) - solver_init = sinkhorn_lr.LRSinkhorn( - rank=rank, epsilon=epsilon, initializer="k-means", max_iterations=10000 - ) - - out_random = solver_random(prob) - out_init = solver_init(prob) - - assert out_random.converged - assert out_init.converged - # converged earlier - assert (out_init.errors > -1).sum() < (out_random.errors > -1).sum() - # converged to a better solution - assert out_init.reg_ot_cost < out_random.reg_ot_cost - - -class TestQuadraticInitializers: - - @pytest.mark.parametrize("kind", ["pc", "lrc", "geom"]) - def test_create_default_lr_initializer(self, rng: jnp.ndarray, kind: str): - n, d1, d2, rank = 150, 2, 3, 5 - eps = 1e-1 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d1)) - y = jax.random.normal(key1, (n, d2)) - kwargs_init = {"foo": "bar"} - - geom_x = pointcloud.PointCloud(x, epsilon=eps) - geom_y = pointcloud.PointCloud(y, epsilon=eps) - if kind == "pc": - pass - elif kind == "lrc": - geom_x = geom_x.to_LRCGeometry() - geom_y = geom_y.to_LRCGeometry() - elif kind == "geom": - geom_x = geometry.Geometry(geom_x.cost_matrix, epsilon=eps) - geom_y = geometry.Geometry(geom_y.cost_matrix, epsilon=eps) - else: - raise NotImplementedError(kind) - prob = quad_problems.QuadraticProblem(geom_x, geom_y) - - solver = gromov_wasserstein.GromovWasserstein( - rank=rank, quad_initializer=None, kwargs_init=kwargs_init - ) - initializer = solver.create_initializer(prob) - - assert isinstance(initializer, quad_initializers.LRQuadraticInitializer) - assert initializer.rank == rank - linear_init = initializer._linear_lr_initializer - if kind in ("pc", "lrc"): - assert isinstance(linear_init, initializers_lr.KMeansInitializer) - else: - assert isinstance(linear_init, initializers_lr.RandomInitializer) - assert linear_init._kwargs == kwargs_init - - def test_non_lr_initializer(self): - solver = gromov_wasserstein.GromovWasserstein( - rank=-1, quad_initializer="not used" - ) - initializer = solver.create_initializer(prob="not used") - assert isinstance(initializer, quad_initializers.QuadraticInitializer) - - @pytest.mark.parametrize("rank", [-1, 2]) - def test_explicitly_passing_initializer(self, rank: int): - if rank == -1: - linear_init = init_lib.SortingInitializer() - quad_init = quad_initializers.QuadraticInitializer() - else: - linear_init = initializers_lr.Rank2Initializer(rank) - quad_init = quad_initializers.LRQuadraticInitializer(linear_init) - - solver = gromov_wasserstein.GromovWasserstein( - initializer=linear_init, - quad_initializer=quad_init, - ) - - assert solver.linear_ot_solver.initializer is linear_init - assert solver.quad_initializer is quad_init - if solver.is_low_rank: - assert solver.quad_initializer.rank == rank - - @pytest.mark.parametrize("eps", [0., 1e-2]) - def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): - n, m, d1, d2, rank = 123, 124, 12, 10, 5 - key1, key2, key3, key4 = jax.random.split(rng, 4) - - geom_x = pointcloud.PointCloud( - jax.random.normal(key1, (n, d1)), - jax.random.normal(key2, (n, d1)), - epsilon=eps, - ) - geom_y = pointcloud.PointCloud( - jax.random.normal(key3, (m, d2)), - jax.random.normal(key4, (m, d2)), - epsilon=eps, - ) - problem = quad_problems.QuadraticProblem(geom_x, geom_y) - solver_random = gromov_wasserstein.GromovWasserstein( - rank=rank, - initializer="random", - quad_initializer="random", - epsilon=eps, - store_inner_errors=True, - ) - solver_kmeans = gromov_wasserstein.GromovWasserstein( - rank=rank, - initializer="k-means", - quad_initializer="k-means", - epsilon=eps, - store_inner_errors=True - ) - - out_random = solver_random(problem) - out_kmeans = solver_kmeans(problem) - - assert out_random.reg_gw_cost - out_kmeans.reg_gw_cost >= 1. - random_errors = out_random.errors[out_random.errors > -1] - kmeans_errors = out_kmeans.errors[out_kmeans.errors > -1] - np.testing.assert_array_equal(random_errors >= 0., True) - np.testing.assert_array_equal(kmeans_errors >= 0., True) diff --git a/tests/geometry/geometry_costs_test.py b/tests/geometry/costs_test.py similarity index 100% rename from tests/geometry/geometry_costs_test.py rename to tests/geometry/costs_test.py diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 3bf976e9f..58f766890 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -11,10 +11,11 @@ from networkx.generators import balanced_tree, random_graphs from typing_extensions import Literal -from ott.core import decomposition -from ott.core import implicit_differentiation as implicit_lib -from ott.core import linear_problems, sinkhorn from ott.geometry import geometry, graph +from ott.math import decomposition +from ott.problems.linear import linear_problem +from ott.solvers.linear import implicit_differentiation as implicit_lib +from ott.solvers.linear import sinkhorn # we mix both dense/sparse tests sksparse = pytest.importorskip("sksparse") @@ -373,7 +374,7 @@ def test_graph_sinkhorn( def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: solver = sinkhorn.Sinkhorn(lse_mode=False) - problem = linear_problems.LinearProblem(geom) + problem = linear_problem.LinearProblem(geom) return solver(problem) n, eps, tol = 11, 1e-5, 1e-3 @@ -422,7 +423,7 @@ def callback( geom = graph.Graph(G, t=1.) solver = sinkhorn.Sinkhorn(lse_mode=False, **kwargs) - problem = linear_problems.LinearProblem(geom) + problem = linear_problem.LinearProblem(geom) return solver(problem).reg_ot_cost diff --git a/tests/geometry/geometry_lr_test.py b/tests/geometry/low_rank_test.py similarity index 100% rename from tests/geometry/geometry_lr_test.py rename to tests/geometry/low_rank_test.py diff --git a/tests/geometry/geometry_pointcloud_apply_test.py b/tests/geometry/pointcloud_test.py similarity index 100% rename from tests/geometry/geometry_pointcloud_apply_test.py rename to tests/geometry/pointcloud_test.py diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index 2767c1ea5..8a6752608 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -19,8 +19,9 @@ import numpy as np import pytest -from ott.core import linear_problems, sinkhorn, sinkhorn_lr from ott.geometry import geometry, low_rank, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn, sinkhorn_lr class TestScaleCost: @@ -154,7 +155,7 @@ def test_scale_cost_low_rank(self, scale: Union[str, float]): def apply_sinkhorn(cost1, cost2, scale_cost): geom = low_rank.LRCGeometry(cost1, cost2, scale_cost=scale_cost) - ot_prob = linear_problems.LinearProblem(geom, self.a, self.b) + ot_prob = linear_problem.LinearProblem(geom, self.a, self.b) solver = sinkhorn_lr.LRSinkhorn(rank=5, threshold=1e-3) out = solver(ot_prob) return geom, out diff --git a/tests/geometry/geometry_subset_test.py b/tests/geometry/subsetting_test.py similarity index 98% rename from tests/geometry/geometry_subset_test.py rename to tests/geometry/subsetting_test.py index 5d57e6eb9..2369e4b28 100644 --- a/tests/geometry/geometry_subset_test.py +++ b/tests/geometry/subsetting_test.py @@ -11,7 +11,9 @@ @pytest.fixture() -def pc_masked(rng: jnp.ndarray) -> Tuple[pointcloud.PointCloud, Tuple]: +def pc_masked( + rng: jnp.ndarray +) -> Tuple[pointcloud.PointCloud, pointcloud.PointCloud]: n, m = 20, 30 key1, key2 = jax.random.split(rng, 2) # x = jnp.full((n,), fill_value=1.) diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py new file mode 100644 index 000000000..27e3f347d --- /dev/null +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -0,0 +1,331 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Sinkhorn initializers.""" +import functools + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import ott.initializers.nn.initializers +from ott.geometry import geometry, pointcloud +from ott.initializers.linear import initializers as lin_init +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + + +def create_sorting_problem(rng, n, epsilon=0.01, online=False): + # define ot problem + x_init = jnp.array([-1., 0, .22]) + y_init = jnp.array([0., 0, 1.1]) + x_rng, y_rng = jax.random.split(rng) + + x = jnp.concatenate([x_init, 10 + jnp.abs(jax.random.normal(x_rng, (n,)))]) + y = jnp.concatenate([y_init, 10 + jnp.abs(jax.random.normal(y_rng, (n,)))]) + + x = jnp.sort(x) + y = jnp.sort(y) + + n = len(x) + m = len(y) + a = jnp.ones(n) / n + b = jnp.ones(m) / m + + batch_size = 3 if online else None + geom = pointcloud.PointCloud( + x.reshape(-1, 1), + y.reshape(-1, 1), + epsilon=epsilon, + batch_size=batch_size + ) + ot_problem = linear_problem.LinearProblem(geom=geom, a=a, b=b) + + return ot_problem + + +def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): + # define ot problem + x_rng, y_rng = jax.random.split(rng) + + mu_a = jnp.array([-1, 1]) * 5 + mu_b = jnp.array([0, 0]) + + x = jax.random.normal(x_rng, (n, d)) + mu_a + y = jax.random.normal(y_rng, (m, d)) + mu_b + + a = jnp.ones(n) / n + b = jnp.ones(m) / m + + batch_size = 3 if online else None + geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size) + + ot_problem = linear_problem.LinearProblem(geom=geom, a=a, b=b) + return ot_problem + + +# define sinkhorn functions +@functools.partial(jax.jit, static_argnames=['lse_mode', 'vector_min']) +def run_sinkhorn_sort_init( + x, y, a=None, b=None, epsilon=0.01, vector_min=True, lse_mode=True +): + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + sort_init = lin_init.SortingInitializer(vectorized_update=vector_min) + out = sinkhorn.sinkhorn( + geom, a=a, b=b, jit=True, initializer=sort_init, lse_mode=lse_mode + ) + return out + + +@functools.partial(jax.jit, static_argnames=['lse_mode']) +def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + out = sinkhorn.sinkhorn(geom, a=a, b=b, jit=True, lse_mode=lse_mode) + return out + + +@functools.partial(jax.jit, static_argnames=['lse_mode']) +def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + out = sinkhorn.sinkhorn( + geom, + a=a, + b=b, + jit=True, + initializer=lin_init.GaussianInitializer(), + lse_mode=lse_mode + ) + return out + + +@pytest.mark.fast +class TestSinkhornInitializers: + + def test_init_pytree(self): + + @jax.jit + def init_sort(): + init = lin_init.SortingInitializer() + return init + + @jax.jit + def init_gaus(): + init = lin_init.GaussianInitializer() + return init + + _ = init_gaus() + _ = init_sort() + + @pytest.mark.parametrize( + "init", [ + "default", "gaussian", "sorting", + lin_init.DefaultInitializer(), "non-existent" + ] + ) + def test_create_initializer(self, init: str): + solver = sinkhorn.Sinkhorn(initializer=init) + expected_types = { + "default": lin_init.DefaultInitializer, + "gaussian": lin_init.GaussianInitializer, + "sorting": lin_init.SortingInitializer, + } + + if isinstance(init, lin_init.SinkhornInitializer): + assert solver.create_initializer() is init + elif init == "non-existent": + with pytest.raises(NotImplementedError, match=r""): + _ = solver.create_initializer() + else: + actual = solver.create_initializer() + expected_type = expected_types[init] + assert isinstance(actual, expected_type) + + @pytest.mark.parametrize( + "vector_min, lse_mode", [(True, True), (True, False), (False, True)] + ) + def test_sorting_init(self, vector_min: bool, lse_mode: bool): + """Tests sorting dual initializer.""" + rng = jax.random.PRNGKey(42) + n = 500 + epsilon = 0.01 + + ot_problem = create_sorting_problem( + rng=rng, n=n, epsilon=epsilon, online=False + ) + # run sinkhorn + sink_out_base = run_sinkhorn( + x=ot_problem.geom.x, + y=ot_problem.geom.y, + a=ot_problem.a, + b=ot_problem.b, + epsilon=epsilon + ) + base_num_iter = jnp.sum(sink_out_base.errors > -1) + + sink_out_init = run_sinkhorn_sort_init( + x=ot_problem.geom.x, + y=ot_problem.geom.y, + a=ot_problem.a, + b=ot_problem.b, + epsilon=epsilon, + vector_min=vector_min, + lse_mode=lse_mode + ) + sort_num_iter = jnp.sum(sink_out_init.errors > -1) + + # check initializer is better or equal + if lse_mode: + assert base_num_iter >= sort_num_iter + + def test_sorting_init_online(self, rng: jnp.ndarray): + n = 100 + epsilon = 0.01 + + ot_problem = create_sorting_problem( + rng=rng, n=n, epsilon=epsilon, online=True + ) + sort_init = lin_init.SortingInitializer(vectorized_update=True) + with pytest.raises(AssertionError, match=r"online"): + sort_init.init_dual_a(ot_problem, lse_mode=True) + + def test_sorting_init_square_cost(self, rng: jnp.ndarray): + n = 100 + m = 150 + d = 1 + epsilon = 0.01 + + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) + sort_init = lin_init.SortingInitializer(vectorized_update=True) + with pytest.raises(AssertionError, match=r"square"): + sort_init.init_dual_a(ot_problem, lse_mode=True) + + def test_default_initializer(self, rng: jnp.ndarray): + """Tests default initializer""" + n = 200 + m = 200 + d = 2 + epsilon = 0.01 + + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) + + default_potential_a = lin_init.DefaultInitializer().init_dual_a( + ot_problem, lse_mode=True + ) + default_potential_b = lin_init.DefaultInitializer().init_dual_b( + ot_problem, lse_mode=True + ) + + # check default is 0 + np.testing.assert_array_equal(0., default_potential_a) + np.testing.assert_array_equal(0., default_potential_b) + + def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): + n = 200 + m = 200 + d = 2 + epsilon = 0.01 + + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) + + gaus_init = lin_init.GaussianInitializer() + new_geom = geometry.Geometry( + cost_matrix=ot_problem.geom.cost_matrix, epsilon=epsilon + ) + ot_problem = linear_problem.LinearProblem( + geom=new_geom, a=ot_problem.a, b=ot_problem.b + ) + + with pytest.raises(AssertionError, match=r"point cloud"): + gaus_init.init_dual_a(ot_problem, lse_mode=True) + + @pytest.mark.parametrize('lse_mode', [True, False]) + def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray): + """Tests Gaussian initializer""" + # define OT problem + n = 200 + m = 200 + d = 2 + epsilon = 0.01 + + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) + + # run sinkhorn + sink_out = run_sinkhorn( + x=ot_problem.geom.x, + y=ot_problem.geom.y, + a=ot_problem.a, + b=ot_problem.b, + epsilon=epsilon, + lse_mode=lse_mode + ) + base_num_iter = jnp.sum(sink_out.errors > -1) + sink_out = run_sinkhorn_gaus_init( + x=ot_problem.geom.x, + y=ot_problem.geom.y, + a=ot_problem.a, + b=ot_problem.b, + epsilon=epsilon, + lse_mode=lse_mode + ) + gaus_num_iter = jnp.sum(sink_out.errors > -1) + + # check initializer is better + if lse_mode: + assert base_num_iter >= gaus_num_iter + + @pytest.mark.parametrize('lse_mode', [True, False]) + def test_meta_initializer(self, lse_mode, rng: jnp.ndarray): + """Tests Meta initializer""" + # define OT problem + n = 200 + m = 200 + d = 2 + epsilon = 0.01 + + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) + a = ot_problem.a + b = ot_problem.b + geom = ot_problem.geom + + # run sinkhorn + sink_out = run_sinkhorn( + x=ot_problem.geom.x, + y=ot_problem.geom.y, + a=ot_problem.a, + b=ot_problem.b, + epsilon=epsilon, + lse_mode=lse_mode + ) + base_num_iter = jnp.sum(sink_out.errors > -1) + + # Overfit the initializer to the problem. + meta_initializer = ott.initializers.nn.initializers.MetaInitializer(geom) + for _ in range(100): + _, _, meta_initializer.state = meta_initializer.update( + meta_initializer.state, a=a, b=b + ) + + sink_out = sinkhorn.sinkhorn( + geom, + a=a, + b=b, + jit=True, + initializer=meta_initializer, + lse_mode=lse_mode + ) + meta_num_iter = jnp.sum(sink_out.errors > -1) + + # check initializer is better + if lse_mode: + assert base_num_iter >= meta_num_iter diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py new file mode 100644 index 000000000..bff23ec2e --- /dev/null +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -0,0 +1,171 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Sinkhorn initializers.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import geometry, low_rank, pointcloud +from ott.initializers.linear import initializers_lr +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn_lr + + +class TestLRInitializers: + + @pytest.mark.fast.with_args("kind", ["pc", "lrc", "geom"], only_fast=0) + def test_create_default_initializer(self, rng: jnp.ndarray, kind: str): + n, d, rank = 110, 2, 3 + x = jax.random.normal(rng, (n, d)) + geom = pointcloud.PointCloud(x) + + if kind == "pc": + pass + elif kind == "lrc": + geom = geom.to_LRCGeometry() + assert isinstance(geom, low_rank.LRCGeometry) + elif kind == "geom": + geom = geometry.Geometry(geom.cost_matrix) + else: + raise NotImplementedError(geom) + prob = linear_problem.LinearProblem(geom) + + solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=None) + initializer = solver.create_initializer(prob) + + assert initializer.rank == rank + if kind in ("pc", "lrc"): + assert isinstance(initializer, initializers_lr.KMeansInitializer) + else: + assert isinstance(initializer, initializers_lr.RandomInitializer) + + q, r, g = initializer(prob) + + assert q.shape == (n, rank) + assert r.shape == (n, rank) + assert g.shape == (rank,) + + def test_explicitly_passing_initializer(self): + rank = 2 + initializer = initializers_lr.RandomInitializer(rank=rank) + solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) + + assert solver.create_initializer(prob="not used") is initializer + + @pytest.mark.parametrize( + "initializer", ["random", "rank2", "k-means", "generalized-k-means"] + ) + @pytest.mark.parametrize("partial_init", ["q", "r", "g"]) + def test_partial_initialization( + self, rng: jnp.ndarray, initializer: str, partial_init: str + ): + n, d, rank = 100, 10, 6 + key1, key2, key3, key4 = jax.random.split(rng, 4) + x = jax.random.normal(key1, (n, d)) + pc = pointcloud.PointCloud(x, epsilon=5e-1) + prob = linear_problem.LinearProblem(pc) + q_init = jax.random.normal(key2, (n, rank)) + r_init = jax.random.normal(key2, (n, rank)) + g_init = jax.random.normal(key2, (rank,)) + + solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) + initializer = solver.create_initializer(prob) + + if partial_init == "q": + q, _, _ = initializer(prob, q=q_init) + np.testing.assert_array_equal(q, q_init) + elif partial_init == "r": + _, r, _ = initializer(prob, r=r_init) + np.testing.assert_array_equal(r, r_init) + elif partial_init == "g": + _, _, g = initializer(prob, g=g_init) + np.testing.assert_array_equal(g, g_init) + else: + raise NotImplementedError(partial_init) + + @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True) + def test_generalized_k_means_has_correct_rank( + self, rng: jnp.ndarray, rank: int + ): + n, d = 100, 10 + x = jax.random.normal(rng, (n, d)) + pc = pointcloud.PointCloud(x, epsilon=5e-1) + prob = linear_problem.LinearProblem(pc) + + solver = sinkhorn_lr.LRSinkhorn( + rank=rank, initializer="generalized-k-means" + ) + initializer = solver.create_initializer(prob) + + q, r, g = initializer(prob) + + assert jnp.linalg.matrix_rank(q) == rank + assert jnp.linalg.matrix_rank(r) == rank + + def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): + n, d, rank = 120, 15, 5 + eps = 1e-1 + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d)) + y = jax.random.normal(key1, (n, d)) + + pc = pointcloud.PointCloud(x, y, epsilon=eps) + geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps) + pc_problem = linear_problem.LinearProblem(pc) + geom_problem = linear_problem.LinearProblem(geom) + + solver = sinkhorn_lr.LRSinkhorn( + rank=rank, initializer="k-means", max_iterations=5000 + ) + pc_out = solver(pc_problem) + + solver = sinkhorn_lr.LRSinkhorn( + rank=rank, initializer="generalized-k-means", max_iterations=5000 + ) + geom_out = solver(geom_problem) + + with pytest.raises(AssertionError): + np.testing.assert_allclose(pc_out.costs, geom_out.costs) + + np.testing.assert_allclose( + pc_out.reg_ot_cost, geom_out.reg_ot_cost, atol=0.5, rtol=0.02 + ) + + @pytest.mark.parametrize("epsilon", [0., 1e-1]) + def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): + n, d, rank = 81, 13, 3 + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d)) + y = jax.random.normal(key2, (n, d)) + pc = pointcloud.PointCloud(x, y, epsilon=5e-1) + prob = linear_problem.LinearProblem(pc) + + solver_random = sinkhorn_lr.LRSinkhorn( + rank=rank, epsilon=epsilon, initializer="random", max_iterations=10000 + ) + solver_init = sinkhorn_lr.LRSinkhorn( + rank=rank, epsilon=epsilon, initializer="k-means", max_iterations=10000 + ) + + out_random = solver_random(prob) + out_init = solver_init(prob) + + assert out_random.converged + assert out_init.converged + # converged earlier + assert (out_init.errors > -1).sum() < (out_random.errors > -1).sum() + # converged to a better solution + assert out_init.reg_ot_cost < out_random.reg_ot_cost diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py new file mode 100644 index 000000000..b900b60fc --- /dev/null +++ b/tests/initializers/quadratic/gw_init_test.py @@ -0,0 +1,132 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Gromov-Wasserstein initializers.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import geometry, pointcloud +from ott.initializers.linear import initializers as lin_init +from ott.initializers.linear import initializers_lr +from ott.initializers.quadratic import initializers as quad_init +from ott.problems.quadratic import quadratic_problem +from ott.solvers.quadratic import gromov_wasserstein + + +class TestQuadraticInitializers: + + @pytest.mark.parametrize("kind", ["pc", "lrc", "geom"]) + def test_create_default_lr_initializer(self, rng: jnp.ndarray, kind: str): + n, d1, d2, rank = 150, 2, 3, 5 + eps = 1e-1 + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d1)) + y = jax.random.normal(key1, (n, d2)) + kwargs_init = {"foo": "bar"} + + geom_x = pointcloud.PointCloud(x, epsilon=eps) + geom_y = pointcloud.PointCloud(y, epsilon=eps) + if kind == "pc": + pass + elif kind == "lrc": + geom_x = geom_x.to_LRCGeometry() + geom_y = geom_y.to_LRCGeometry() + elif kind == "geom": + geom_x = geometry.Geometry(geom_x.cost_matrix, epsilon=eps) + geom_y = geometry.Geometry(geom_y.cost_matrix, epsilon=eps) + else: + raise NotImplementedError(kind) + prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) + + solver = gromov_wasserstein.GromovWasserstein( + rank=rank, quad_initializer=None, kwargs_init=kwargs_init + ) + initializer = solver.create_initializer(prob) + + assert isinstance(initializer, quad_init.LRQuadraticInitializer) + assert initializer.rank == rank + linear_init = initializer._linear_lr_initializer + if kind in ("pc", "lrc"): + assert isinstance(linear_init, initializers_lr.KMeansInitializer) + else: + assert isinstance(linear_init, initializers_lr.RandomInitializer) + assert linear_init._kwargs == kwargs_init + + def test_non_lr_initializer(self): + solver = gromov_wasserstein.GromovWasserstein( + rank=-1, quad_initializer="not used" + ) + initializer = solver.create_initializer(prob="not used") + assert isinstance(initializer, quad_init.QuadraticInitializer) + + @pytest.mark.parametrize("rank", [-1, 2]) + def test_explicitly_passing_initializer(self, rank: int): + if rank == -1: + linear_init = lin_init.SortingInitializer() + q_init = quad_init.QuadraticInitializer() + else: + linear_init = initializers_lr.Rank2Initializer(rank) + q_init = quad_init.LRQuadraticInitializer(linear_init) + + solver = gromov_wasserstein.GromovWasserstein( + initializer=linear_init, + quad_initializer=q_init, + ) + + assert solver.linear_ot_solver.initializer is linear_init + assert solver.quad_initializer is q_init + if solver.is_low_rank: + assert solver.quad_initializer.rank == rank + + @pytest.mark.parametrize("eps", [0., 1e-2]) + def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): + n, m, d1, d2, rank = 123, 124, 12, 10, 5 + key1, key2, key3, key4 = jax.random.split(rng, 4) + + geom_x = pointcloud.PointCloud( + jax.random.normal(key1, (n, d1)), + jax.random.normal(key2, (n, d1)), + epsilon=eps, + ) + geom_y = pointcloud.PointCloud( + jax.random.normal(key3, (m, d2)), + jax.random.normal(key4, (m, d2)), + epsilon=eps, + ) + problem = quadratic_problem.QuadraticProblem(geom_x, geom_y) + solver_random = gromov_wasserstein.GromovWasserstein( + rank=rank, + initializer="random", + quad_initializer="random", + epsilon=eps, + store_inner_errors=True, + ) + solver_kmeans = gromov_wasserstein.GromovWasserstein( + rank=rank, + initializer="k-means", + quad_initializer="k-means", + epsilon=eps, + store_inner_errors=True + ) + + out_random = solver_random(problem) + out_kmeans = solver_kmeans(problem) + + assert out_random.reg_gw_cost - out_kmeans.reg_gw_cost >= 1. + random_errors = out_random.errors[out_random.errors > -1] + kmeans_errors = out_kmeans.errors[out_kmeans.errors > -1] + np.testing.assert_array_equal(random_errors >= 0., True) + np.testing.assert_array_equal(kmeans_errors >= 0., True) diff --git a/tests/geometry/geometry_lse_test.py b/tests/math/lse_test.py similarity index 96% rename from tests/geometry/geometry_lse_test.py rename to tests/math/lse_test.py index d7a030107..f0b076147 100644 --- a/tests/geometry/geometry_lse_test.py +++ b/tests/math/lse_test.py @@ -20,7 +20,7 @@ import numpy as np import pytest -from ott.geometry import ops +from ott.math import utils as mu @pytest.mark.fast @@ -36,7 +36,7 @@ def test_lse(self, rng: jnp.ndarray): b_1 = jax.random.normal(keys[2], (n, 1)) def lse_(x, axis, b, return_sign): - out = ops.logsumexp(x, axis, False, b, return_sign) + out = mu.logsumexp(x, axis, False, b, return_sign) return jnp.sum(out[0] if return_sign else out) lse = jax.value_and_grad(lse_, argnums=(0, 2)) diff --git a/tests/geometry/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py similarity index 98% rename from tests/geometry/matrix_square_root_test.py rename to tests/math/matrix_square_root_test.py index 0b48fb706..b5d3c08d8 100644 --- a/tests/geometry/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -21,7 +21,7 @@ import numpy as np import pytest -from ott.geometry import matrix_square_root +from ott.math import matrix_square_root def _get_random_spd_matrix(dim: int, key: jnp.ndarray): @@ -56,7 +56,7 @@ def _get_test_fn( unit = jax.random.normal(key=subkey3, shape=(dim, dim)) unit /= jnp.sqrt(jnp.sum(unit ** 2.)) - def _test_fn(x: float) -> float: + def _test_fn(x: jnp.ndarray) -> jnp.ndarray: # m is the product of 2 symmetric, positive definite matrices # so it will be positive definite but not necessarily symmetric m = jnp.matmul(m0, m1 + x * dx) diff --git a/tests/core/potentials_test.py b/tests/problems/linear/potentials_test.py similarity index 90% rename from tests/core/potentials_test.py rename to tests/problems/linear/potentials_test.py index f5b73d7cf..5210b0e8e 100644 --- a/tests/core/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -3,8 +3,9 @@ import numpy as np import pytest -from ott.core import Sinkhorn, linear_problems from ott.geometry import costs, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn from ott.tools import sinkhorn_divergence from ott.tools.gaussian_mixture import gaussian @@ -24,8 +25,8 @@ def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float): y = g2.sample(key2, n2) geom = pointcloud.PointCloud(x, y, epsilon=eps) - prob = linear_problems.LinearProblem(geom) - out = Sinkhorn()(prob) + prob = linear_problem.LinearProblem(geom) + out = sinkhorn.Sinkhorn()(prob) assert out.converged potentials = out.to_dual_potentials() @@ -50,8 +51,8 @@ def test_entropic_potentials_displacement( y = g2.sample(key2, n2) geom = pointcloud.PointCloud(x, y, epsilon=eps) - prob = linear_problems.LinearProblem(geom) - out = Sinkhorn()(prob) + prob = linear_problem.LinearProblem(geom) + out = sinkhorn.Sinkhorn()(prob) assert out.converged potentials = out.to_dual_potentials() @@ -82,8 +83,8 @@ def test_entropic_potentials_sqpnorm( y = jax.random.normal(keys[1], (n2, d)) + 2 geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn) - prob = linear_problems.LinearProblem(geom) - out = Sinkhorn()(prob) + prob = linear_problem.LinearProblem(geom) + out = sinkhorn.Sinkhorn()(prob) assert out.converged potentials = out.to_dual_potentials() @@ -118,8 +119,8 @@ def test_entropic_potentials_pnorm( y = jax.random.normal(keys[1], (n2, d)) + 2 geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn) - prob = linear_problems.LinearProblem(geom) - out = Sinkhorn()(prob) + prob = linear_problem.LinearProblem(geom) + out = sinkhorn.Sinkhorn()(prob) assert out.converged potentials = out.to_dual_potentials() @@ -147,11 +148,11 @@ def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool): x = jax.random.normal(key1, (n, d)) y = jax.random.normal(key2, (m, d)) - prob = linear_problems.LinearProblem(pointcloud.PointCloud(x, y)) + prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y)) v_x = jax.random.normal(key3, shape=x.shape) v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * 1e-3 - pots = Sinkhorn()(prob).to_dual_potentials() + pots = sinkhorn.Sinkhorn()(prob).to_dual_potentials() grad_dist = jax.grad(pots.distance) if jit: @@ -175,9 +176,9 @@ def test_potentials_sinkhorn_divergence( y = jax.random.normal(key2, (m, d)) + mu1 x_test = jax.random.normal(key3, (n, d)) + mu0 geom = pointcloud.PointCloud(x, y, epsilon=eps) - prob = linear_problems.LinearProblem(geom) + prob = linear_problem.LinearProblem(geom) - sink_pots = Sinkhorn()(prob).to_dual_potentials() + sink_pots = sinkhorn.Sinkhorn()(prob).to_dual_potentials() div_pots = sinkhorn_divergence.sinkhorn_divergence( type(geom), x, y, epsilon=eps ).to_dual_potentials() diff --git a/tests/core/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py similarity index 64% rename from tests/core/continuous_barycenter_test.py rename to tests/solvers/linear/continuous_barycenter_test.py index cc25bf747..36fdf1d97 100644 --- a/tests/core/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -12,17 +12,18 @@ # limitations under the License. # Lint as: python3 -"""Tests for Continuous barycenters.""" +"""Tests for continuous barycenter.""" import functools -from typing import Any, Optional, Sequence, Tuple +from typing import Tuple import jax import jax.numpy as jnp import numpy as np import pytest -from ott.core import bar_problems, continuous_barycenter, gw_barycenter, segment -from ott.geometry import costs, pointcloud +from ott.geometry import costs, segment +from ott.problems.linear import barycenter_problem +from ott.solvers.linear import continuous_barycenter as cb from ott.tools.gaussian_mixture import gaussian_mixture means_and_covs_to_x = jax.vmap(costs.mean_and_cov_to_x, in_axes=[0, 0, None]) @@ -70,7 +71,7 @@ def test_euclidean_barycenter( b.append(c / jnp.sum(c)) b = jnp.concatenate(b, axis=0) # Set a barycenter problem with 8 measures, of irregular sizes. - bar_prob = bar_problems.BarycenterProblem( + bar_prob = barycenter_problem.BarycenterProblem( y, b, epsilon=epsilon, @@ -84,9 +85,7 @@ def test_euclidean_barycenter( # Define solver threshold = 1e-3 - solver = continuous_barycenter.WassersteinBarycenter( - rank=rank, threshold=threshold, jit=jit - ) + solver = cb.WassersteinBarycenter(rank=rank, threshold=threshold, jit=jit) # Set barycenter size to 31. bar_size = 31 @@ -119,19 +118,21 @@ def test_barycenter_jit(self, rng: jnp.ndarray, segment_before: bool): @functools.partial(jax.jit, static_argnums=(2, 3)) def barycenter( - y: jnp.ndarray, b: jnp.ndarray, segment_before: bool, - num_per_segment: int - ) -> continuous_barycenter.BarycenterState: + y: jnp.ndarray, + b: jnp.ndarray, + segment_before: bool, + num_per_segment: Tuple[int, ...], + ) -> cb.BarycenterState: if segment_before: y, b = segment.segment_point_cloud( x=y, a=b, num_per_segment=num_per_segment ) - bar_prob = bar_problems.BarycenterProblem(y, b, epsilon=1e-1) + bar_prob = barycenter_problem.BarycenterProblem(y, b, epsilon=1e-1) else: - bar_prob = bar_problems.BarycenterProblem( + bar_prob = barycenter_problem.BarycenterProblem( y, b, epsilon=1e-1, num_per_segment=num_per_segment ) - solver = continuous_barycenter.WassersteinBarycenter(threshold=threshold) + solver = cb.WassersteinBarycenter(threshold=threshold) return solver(bar_prob) rngs = jax.random.split(rng, 20) @@ -218,9 +219,9 @@ def test_bures_barycenter( num_segments=num_measures, max_measure_size=num_components, num_per_segment=(num_components, num_components), - padding_vector=bures_cost.padder(y.shape[1]), + padding_vector=bures_cost._padder(y.shape[1]), ) - bar_p = bar_problems.BarycenterProblem( + bar_p = barycenter_problem.BarycenterProblem( seg_y, seg_b, weights=barycentric_weights, @@ -231,9 +232,7 @@ def test_bures_barycenter( assert bar_p.max_measure_size == seg_y.shape[1] assert bar_p.ndim == seg_y.shape[2] - solver = continuous_barycenter.WassersteinBarycenter( - lse_mode=lse_mode, jit=jit - ) + solver = cb.WassersteinBarycenter(lse_mode=lse_mode, jit=jit) out = solver(bar_p, bar_size=bar_size, x_init=x_init) barycenter = out.x @@ -318,17 +317,17 @@ def test_bures_barycenter_different_number_of_components( for i in range(num_measures)] # positions of mass of the measures - ys = jnp.vstack( + ys = jnp.vstack([ means_and_covs_to_x(means_covs[i][0], means_covs[i][1], dim) for i in range(num_measures) - ) + ]) # mass distribution of the measures weights = [ gmm_generators[i].component_weight_ob.probs() for i in range(num_measures) ] - bs = jnp.hstack(jnp.array(weights[i]) for i in range(num_measures)) + bs = jnp.hstack([jnp.array(weights[i]) for i in range(num_measures)]) # random initialization of the barycenter gmm_generator = gaussian_mixture.GaussianMixture.from_random( @@ -340,7 +339,7 @@ def test_bures_barycenter_different_number_of_components( # test second interface for segmentation seg_ids = jnp.repeat(jnp.arange(num_measures), n_components) - bar_p = bar_problems.BarycenterProblem( + bar_p = barycenter_problem.BarycenterProblem( y=ys, b=bs, weights=barycentric_weights, @@ -354,7 +353,7 @@ def test_bures_barycenter_different_number_of_components( assert bar_p.num_measures == num_measures assert bar_p.ndim == ys.shape[-1] - solver = continuous_barycenter.WassersteinBarycenter(lse_mode=True, jit=jit) + solver = cb.WassersteinBarycenter(lse_mode=True, jit=jit) # Compute the barycenter. out = solver(bar_p, bar_size=bar_size, x_init=x_init) @@ -377,155 +376,3 @@ def test_bures_barycenter_different_number_of_components( jax.vmap(is_positive_semidefinite, in_axes=0, out_axes=0)(covs_bary), True ) - - -class TestGWBarycenter: - ndim = 3 - ndim_f = 4 - - @staticmethod - def random_pc( - n: int, - d: int, - rng: jnp.ndarray, - m: Optional[int] = None, - **kwargs: Any - ) -> pointcloud.PointCloud: - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = x if m is None else jax.random.normal(key2, (m, d)) - return pointcloud.PointCloud(x, y, batch_size=None, **kwargs) - - @staticmethod - def pad_cost_matrices( - costs: Sequence[jnp.ndarray], - shape: Optional[Tuple[int, int]] = None - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - if shape is None: - shape = jnp.asarray([arr.shape for arr in costs]).max() - shape = (shape, shape) - else: - assert shape[0] == shape[1], shape - - cs, weights = [], [] - for cost in costs: - r, c = cost.shape - cs.append(jnp.zeros(shape).at[:r, :c].set(cost)) - w = jnp.ones(r) / r - weights.append(jnp.concatenate([w, jnp.zeros(shape[0] - r)])) - return jnp.stack(cs), jnp.stack(weights) - - # TODO(cuturi) add back KL test when KL cost GW is fixed. - @pytest.mark.parametrize( - "gw_loss,bar_size,epsilon", - [("sqeucl", 17, None)] #, ("kl", 22, 1e-2)] - ) - def test_gw_barycenter( - self, rng: jnp.ndarray, gw_loss: str, bar_size: int, - epsilon: Optional[float] - ): - tol = 1e-3 if gw_loss == "sqeucl" else 1e-1 - num_per_segment = (13, 15, 21) - rngs = jax.random.split(rng, len(num_per_segment)) - pcs = [ - self.random_pc(n, d=self.ndim, rng=rng) - for n, rng in zip(num_per_segment, rngs) - ] - costs = [pc._compute_cost_matrix() for pc, n in zip(pcs, num_per_segment)] - costs, cbs = self.pad_cost_matrices(costs) - ys = jnp.concatenate([pc.x for pc in pcs]) - bs = jnp.concatenate([jnp.ones(n) / n for n in num_per_segment]) - kwargs = { - "gw_loss": gw_loss, - "num_per_segment": num_per_segment, - "epsilon": epsilon - } - - problem_pc = bar_problems.GWBarycenterProblem(y=ys, b=bs, **kwargs) - problem_cost = bar_problems.GWBarycenterProblem( - costs=costs, - b=cbs, - **kwargs, - ) - for prob in [problem_pc, problem_cost]: - assert not prob.is_fused - assert prob.ndim_fused is None - assert prob.num_measures == len(num_per_segment) - assert prob.max_measure_size == max(num_per_segment) - assert prob._loss_name == gw_loss - assert problem_pc.ndim == self.ndim - assert problem_cost.ndim is None - - solver = gw_barycenter.GromovWassersteinBarycenter(jit=True) - out_pc = solver(problem_pc, bar_size=bar_size) - out_cost = solver(problem_cost, bar_size=bar_size) - - assert out_pc.x is None - assert out_cost.x is None - assert out_pc.cost.shape == (bar_size, bar_size) - np.testing.assert_allclose(out_pc.cost, out_cost.cost, rtol=tol, atol=tol) - np.testing.assert_allclose(out_pc.costs, out_cost.costs, rtol=tol, atol=tol) - - @pytest.mark.fast( - "jit,fused_penalty,scale_cost", [(False, 1.5, "mean"), - (True, 3.1, "max_cost")], - only_fast=0 - ) - def test_fgw_barycenter( - self, - rng: jnp.ndarray, - jit: bool, - fused_penalty: float, - scale_cost: str, - ): - - def barycenter( - y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] - ) -> gw_barycenter.GWBarycenterState: - prob = bar_problems.GWBarycenterProblem( - y=y, - y_fused=y_fused, - num_per_segment=num_per_segment, - fused_penalty=fused_penalty, - scale_cost=scale_cost, - ) - assert prob.is_fused - assert prob.fused_penalty == fused_penalty - assert not prob._y_as_costs - assert prob.max_measure_size == max(num_per_segment) - assert prob.num_measures == len(num_per_segment) - assert prob.ndim == self.ndim - assert prob.ndim_fused == self.ndim_f - - solver = gw_barycenter.GromovWassersteinBarycenter( - jit=False, store_inner_errors=True, epsilon=epsilon - ) - - x_init = jax.random.normal(rng, (bar_size, self.ndim_f)) - cost_init = pointcloud.PointCloud(x_init).cost_matrix - - return solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init)) - - bar_size, epsilon, = 10, 1e-1 - num_per_segment = (7, 12) - - key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) - y = jnp.concatenate([ - self.random_pc(n, d=self.ndim, rng=rng).x - for n, rng in zip(num_per_segment, rngs) - ]) - rngs = jax.random.split(key1, len(num_per_segment)) - y_fused = jnp.concatenate([ - self.random_pc(n, d=self.ndim_f, rng=rng).x - for n, rng in zip(num_per_segment, rngs) - ]) - - fn = jax.jit(barycenter, static_argnums=2) if jit else barycenter - out = fn(y, y_fused, num_per_segment) - - assert out.cost.shape == (bar_size, bar_size) - assert out.x.shape == (bar_size, self.ndim_f) - np.testing.assert_array_equal(jnp.isfinite(out.cost), True) - np.testing.assert_array_equal(jnp.isfinite(out.x), True) - np.testing.assert_array_equal(jnp.isfinite(out.costs), True) - np.testing.assert_array_equal(jnp.isfinite(out.errors), True) diff --git a/tests/core/discrete_barycenter_test.py b/tests/solvers/linear/discrete_barycenter_test.py similarity index 97% rename from tests/core/discrete_barycenter_test.py rename to tests/solvers/linear/discrete_barycenter_test.py index c17d031d5..8f2bcfa0b 100644 --- a/tests/core/discrete_barycenter_test.py +++ b/tests/solvers/linear/discrete_barycenter_test.py @@ -13,13 +13,11 @@ # limitations under the License. # Lint as: python3 -"""Tests for the Policy.""" - import jax.numpy as jnp import pytest -from ott.core import discrete_barycenter as db from ott.geometry import grid, pointcloud +from ott.solvers.linear import discrete_barycenter as db class TestDiscreteBarycenter: diff --git a/tests/core/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py similarity index 99% rename from tests/core/sinkhorn_diff_test.py rename to tests/solvers/linear/sinkhorn_diff_test.py index 5e22e1933..80d9e62c8 100644 --- a/tests/core/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -22,9 +22,10 @@ import numpy as np import pytest -from ott.core import implicit_differentiation as implicit_lib -from ott.core import linear_problems, sinkhorn from ott.geometry import costs, geometry, grid, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import implicit_differentiation as implicit_lib +from ott.solvers.linear import sinkhorn from ott.tools import transport @@ -763,7 +764,7 @@ def test_hessian_sinkhorn( def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - prob = linear_problems.LinearProblem(geom, a, b, tau_a, tau_b) + prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b) implicit_diff = ( None if not implicit else implicit_lib.ImplicitDiff(ridge_kernel=ridge, ridge_identity=ridge) diff --git a/tests/core/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py similarity index 99% rename from tests/core/sinkhorn_grid_test.py rename to tests/solvers/linear/sinkhorn_grid_test.py index 0c4180561..7937ce717 100644 --- a/tests/core/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -20,8 +20,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import grid, pointcloud +from ott.solvers.linear import sinkhorn class TestSinkhornGrid: diff --git a/tests/core/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py similarity index 95% rename from tests/core/sinkhorn_lr_test.py rename to tests/solvers/linear/sinkhorn_lr_test.py index 72569a7b5..d84e371da 100644 --- a/tests/core/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -19,8 +19,9 @@ import numpy as np import pytest -from ott.core import linear_problems, sinkhorn_lr from ott.geometry import low_rank, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn_lr class TestLRSinkhorn: @@ -58,7 +59,7 @@ def test_euclidean_point_cloud_lr( if use_lrcgeom: geom = geom.to_LRCGeometry() assert isinstance(geom, low_rank.LRCGeometry) - ot_prob = linear_problems.LinearProblem(geom, self.a, self.b) + ot_prob = linear_problem.LinearProblem(geom, self.a, self.b) # Start with a low rank parameter solver = sinkhorn_lr.LRSinkhorn( @@ -131,7 +132,7 @@ def test_output_apply_batch_size(self, axis: int): data = self.a if axis == 0 else self.b geom = pointcloud.PointCloud(self.x, self.y) - ot_prob = linear_problems.LinearProblem(geom, self.a, self.b) + ot_prob = linear_problem.LinearProblem(geom, self.a, self.b) solver = sinkhorn_lr.LRSinkhorn( threshold=threshold, rank=10, diff --git a/tests/core/sinkhorn_extra_test.py b/tests/solvers/linear/sinkhorn_misc_test.py similarity index 89% rename from tests/core/sinkhorn_extra_test.py rename to tests/solvers/linear/sinkhorn_misc_test.py index 6bdc0b258..96da5c6bc 100644 --- a/tests/core/sinkhorn_extra_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -13,9 +13,9 @@ # limitations under the License. # Lint as: python3 -"""Tests Anderson acceleration for sinkhorn.""" +"""Tests Anderson acceleration for Sinkhorn.""" import functools -from typing import Any, Callable, Tuple +from typing import Callable, Tuple import chex import jax @@ -23,8 +23,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.solvers.linear import sinkhorn non_jitted_sinkhorn = functools.partial(sinkhorn.sinkhorn, jit=False) @@ -131,25 +131,40 @@ def initialize(self): self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - def test_bures_point_cloud(self): + @pytest.mark.parametrize("lse_mode", [False, True]) + @pytest.mark.parametrize("unbalanced,thresh", [(False, 1e-3), (True, 1e-4)]) + def test_bures_point_cloud( + self, rng: jnp.ndarray, lse_mode: bool, unbalanced: bool, thresh: float + ): """Two point clouds of Gaussians, tested with various parameters.""" - threshold = 1e-3 - geom = pointcloud.PointCloud( - self.x, - self.y, - cost_fn=costs.Bures(dimension=self.dim, regularization=1e-4), - epsilon=self.eps + if unbalanced: + rng1, rng2 = jax.random.split(rng, 2) + ws_x = jnp.abs(jax.random.normal(rng1, (self.x.shape[0], 1))) + 1e-1 + ws_y = jnp.abs(jax.random.normal(rng2, (self.y.shape[0], 1))) + 1e-1 + ws_x = ws_x.at[0].set(0.) + x = jnp.concatenate([ws_x, self.x], axis=1) + y = jnp.concatenate([ws_y, self.y], axis=1) + cost_fn = costs.UnbalancedBures(dimension=self.dim, gamma=0.9, sigma=0.98) + else: + x, y = self.x, self.y + cost_fn = costs.Bures(dimension=self.dim, regularization=1e-4) + + geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=self.eps) + + out = sinkhorn.sinkhorn( + geom, a=self.a, b=self.b, lse_mode=lse_mode, threshold=thresh ) - errors = sinkhorn.sinkhorn(geom, a=self.a, b=self.b, lse_mode=False).errors - err = errors[errors > -1][-1] - assert threshold > err + err = out.errors[out.errors > -1][-1] + + assert out.converged + assert thresh > err - def test_regularized_unbalanced_bures(self): + def test_regularized_unbalanced_bures_cost(self): """Tests Regularized Unbalanced Bures.""" x = jnp.concatenate((jnp.array([0.9]), self.x[0, :])) y = jnp.concatenate((jnp.array([1.1]), self.y[0, :])) - rub = costs.UnbalancedBures(self.dim, 1, 0.8) + rub = costs.UnbalancedBures(self.dim, gamma=1.0, sigma=0.8) assert not jnp.any(jnp.isnan(rub(x, y))) assert not jnp.any(jnp.isnan(rub(y, x))) np.testing.assert_allclose(rub(x, y), rub(y, x), rtol=5e-3, atol=5e-3) @@ -330,7 +345,7 @@ def f( def test_jit_vs_non_jit_bwd(self, implicit: bool): def loss( - a: jnp.ndarray, x: jnp.ndarray, fun: Callable[[Any], + a: jnp.ndarray, x: jnp.ndarray, fun: Callable[..., sinkhorn.SinkhornOutput] ): out = fun( diff --git a/tests/core/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py similarity index 99% rename from tests/core/sinkhorn_test.py rename to tests/solvers/linear/sinkhorn_test.py index e639e6dce..70f793854 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -20,8 +20,9 @@ import numpy as np import pytest -from ott.core import linear_problems, sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn class TestSinkhorn: @@ -460,7 +461,7 @@ def test_sinkhorn_online_memory(self, batch_size: int): x = jax.random.uniform(rngs[0], (n, 2)) y = jax.random.uniform(rngs[1], (m, 2)) geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1) - problem = linear_problems.LinearProblem(geom) + problem = linear_problem.LinearProblem(geom) solver = sinkhorn.Sinkhorn() out = solver(problem) diff --git a/tests/core/icnn_test.py b/tests/solvers/nn/icnn_test.py similarity index 80% rename from tests/core/icnn_test.py rename to tests/solvers/nn/icnn_test.py index 43b5be23c..b45d2c521 100644 --- a/tests/core/icnn_test.py +++ b/tests/solvers/nn/icnn_test.py @@ -20,7 +20,7 @@ import numpy as np import pytest -from ott.core.icnn import ICNN +from ott.solvers.nn import icnn @pytest.mark.fast @@ -32,22 +32,22 @@ def test_icnn_convexity(self, rng: jnp.ndarray): dim_hidden = (64, 64) # define icnn model - icnn = ICNN(dim_hidden) + model = icnn.ICNN(dim_hidden) # initialize model key1, key2, key3 = jax.random.split(rng, 3) - params = icnn.init(key1, jnp.ones(n_features))['params'] + params = model.init(key1, jnp.ones(n_features))['params'] # check convexity x = jax.random.normal(key1, (n_samples, n_features)) * 0.1 y = jax.random.normal(key2, (n_samples, n_features)) - out_x = icnn.apply({'params': params}, x) - out_y = icnn.apply({'params': params}, y) + out_x = model.apply({'params': params}, x) + out_y = model.apply({'params': params}, y) out = list() for t in jnp.linspace(0, 1): - out_xy = icnn.apply({'params': params}, t * x + (1 - t) * y) + out_xy = model.apply({'params': params}, t * x + (1 - t) * y) out.append((t * out_x + (1 - t) * out_y) - out_xy) np.testing.assert_array_equal(jnp.asarray(out) >= 0, True) @@ -58,17 +58,17 @@ def test_icnn_hessian(self, rng: jnp.ndarray): # define icnn model n_samples = 2 dim_hidden = (64, 64) - icnn = ICNN(dim_hidden) + model = icnn.ICNN(dim_hidden) # initialize model key1, key2 = jax.random.split(rng) - params = icnn.init(key1, jnp.ones(n_samples))['params'] + params = model.init(key1, jnp.ones(n_samples))['params'] # check if Hessian is positive-semidefinite via eigenvalues data = jax.random.normal(key2, (n_samples,)) # compute Hessian - hessian = jax.jacfwd(jax.jacrev(icnn.apply, argnums=1), argnums=1) + hessian = jax.jacfwd(jax.jacrev(model.apply, argnums=1), argnums=1) icnn_hess = hessian({'params': params}, data) # compute eigenvalues diff --git a/tests/core/neuraldual_test.py b/tests/solvers/nn/neuraldual_test.py similarity index 96% rename from tests/core/neuraldual_test.py rename to tests/solvers/nn/neuraldual_test.py index 289fc59b5..4fa9c009b 100644 --- a/tests/core/neuraldual_test.py +++ b/tests/solvers/nn/neuraldual_test.py @@ -21,7 +21,7 @@ import pytest from typing_extensions import Literal -from ott.core.neuraldual import NeuralDualSolver +from ott.solvers.nn import neuraldual class ToyDataset: @@ -94,7 +94,7 @@ def decreasing(losses: Sequence[float]) -> bool: dataloader_source, dataloader_target = toy_dataset # initialize neural dual - neural_dual_solver = NeuralDualSolver( + neural_dual_solver = neuraldual.NeuralDualSolver( input_dim=2, num_train_iters=num_train_iters, logging=True, @@ -113,7 +113,7 @@ def test_neural_dual_jit(self, toy_dataset: Tuple[ToyDataset, ToyDataset]): num_train_iters = 10 dataloader_source, dataloader_target = toy_dataset # initialize neural dual - neural_dual_solver = NeuralDualSolver( + neural_dual_solver = neuraldual.NeuralDualSolver( input_dim=2, num_train_iters=num_train_iters ) neural_dual = neural_dual_solver( diff --git a/tests/solvers/quadratic/fgw_barycenter_test.py b/tests/solvers/quadratic/fgw_barycenter_test.py new file mode 100644 index 000000000..d3dca9ad6 --- /dev/null +++ b/tests/solvers/quadratic/fgw_barycenter_test.py @@ -0,0 +1,78 @@ +"""Tests for Fused Gromov-Wasserstein barycenter.""" +from typing import Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import pointcloud +from ott.problems.quadratic import gw_barycenter as gwb +from ott.solvers.quadratic import gw_barycenter as gwb_solver + + +class FGWBarycenterTest: + + @pytest.mark.fast( + "jit,fused_penalty,scale_cost", [(False, 1.5, "mean"), + (True, 3.1, "max_cost")], + only_fast=0 + ) + def test_fgw_barycenter( + self, + rng: jnp.ndarray, + jit: bool, + fused_penalty: float, + scale_cost: str, + ): + + def barycenter( + y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] + ) -> gwb_solver.GWBarycenterState: + prob = gwb.GWBarycenterProblem( + y=y, + y_fused=y_fused, + num_per_segment=num_per_segment, + fused_penalty=fused_penalty, + scale_cost=scale_cost, + ) + assert prob.is_fused + assert prob.fused_penalty == fused_penalty + assert not prob._y_as_costs + assert prob.max_measure_size == max(num_per_segment) + assert prob.num_measures == len(num_per_segment) + assert prob.ndim == self.ndim + assert prob.ndim_fused == self.ndim_f + + solver = gwb_solver.GromovWassersteinBarycenter( + jit=False, store_inner_errors=True, epsilon=epsilon + ) + + x_init = jax.random.normal(rng, (bar_size, self.ndim_f)) + cost_init = pointcloud.PointCloud(x_init).cost_matrix + + return solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init)) + + bar_size, epsilon, = 10, 1e-1 + num_per_segment = (7, 12) + + key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) + y = jnp.concatenate([ + self.random_pc(n, d=self.ndim, rng=rng).x + for n, rng in zip(num_per_segment, rngs) + ]) + rngs = jax.random.split(key1, len(num_per_segment)) + y_fused = jnp.concatenate([ + self.random_pc(n, d=self.ndim_f, rng=rng).x + for n, rng in zip(num_per_segment, rngs) + ]) + + fn = jax.jit(barycenter, static_argnums=2) if jit else barycenter + out = fn(y, y_fused, num_per_segment) + + assert out.cost.shape == (bar_size, bar_size) + assert out.x.shape == (bar_size, self.ndim_f) + np.testing.assert_array_equal(jnp.isfinite(out.cost), True) + np.testing.assert_array_equal(jnp.isfinite(out.x), True) + np.testing.assert_array_equal(jnp.isfinite(out.costs), True) + np.testing.assert_array_equal(jnp.isfinite(out.errors), True) diff --git a/tests/core/fused_gromov_wasserstein_test.py b/tests/solvers/quadratic/fgw_test.py similarity index 91% rename from tests/core/fused_gromov_wasserstein_test.py rename to tests/solvers/quadratic/fgw_test.py index a2333c96e..338ffafb1 100644 --- a/tests/core/fused_gromov_wasserstein_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -21,8 +21,9 @@ import numpy as np import pytest -from ott.core import gromov_wasserstein, quad_problems from ott.geometry import geometry, low_rank, pointcloud +from ott.problems.quadratic import quadratic_problem +from ott.solvers.quadratic import gromov_wasserstein as gw_solver class TestFusedGromovWasserstein: @@ -49,13 +50,13 @@ def initialize(self, rng: jnp.ndarray): self.cy = jax.random.uniform(keys[5], (self.m, self.m)) self.cxy = jax.random.uniform(keys[6], (self.n, self.m)) - def test_flag_store_errors_fused(self): + def test_fgw_flag_store_errors_fused(self): """Tests whether errors are properly stored if requested.""" threshold_sinkhorn = 1e-2 geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) - out = gromov_wasserstein.gromov_wasserstein( + out = gw_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -66,7 +67,7 @@ def test_flag_store_errors_fused(self): ).errors assert out is None - out = gromov_wasserstein.gromov_wasserstein( + out = gw_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -86,7 +87,7 @@ def test_flag_store_errors_fused(self): assert out.ndim == 2 @pytest.mark.fast.with_args(jit=[False, True], only_fast=1) - def test_gradient_marginals_fused_gromov_wasserstein(self, jit: bool): + def test_gradient_marginals_fgw_solver(self, jit: bool): """Test gradient w.r.t. probability weights.""" geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) @@ -98,7 +99,7 @@ def reg_gw(a, b, implicit): 'implicit_differentiation': implicit, 'max_iterations': 1001 } - out = gromov_wasserstein.gromov_wasserstein( + out = gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -136,14 +137,14 @@ def reg_gw(a, b, implicit): ) @pytest.mark.fast.with_args(lse_mode=[False, True], only_fast=1) - def test_fused_gromov_wasserstein_pointcloud(self, lse_mode: bool): + def test_fgw_solver_pointcloud(self, lse_mode: bool): """Test basic computations pointclouds.""" def reg_gw(x, y, x_2, y_2, fused_penalty, a, b): geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x_2, y_2) - return gromov_wasserstein.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -163,7 +164,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b): assert cost is not None @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gromov_wasserstein_pointcloud(self, lse_mode: bool): + def test_gradient_fgw_solver_pointcloud(self, lse_mode: bool): """Test gradient w.r.t. pointclouds.""" def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): @@ -175,7 +176,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gromov_wasserstein.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -206,7 +207,7 @@ def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): ) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gromov_wasserstein_geometry(self, lse_mode: bool): + def test_gradient_fgw_solver_geometry(self, lse_mode: bool): """Test gradient w.r.t. cost matrices.""" def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): @@ -218,7 +219,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gromov_wasserstein.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -251,7 +252,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): grad_matrices[0][2], grad_matrices[1][2], rtol=1e-02, atol=1e-02 ) - def test_adaptive_threshold_fused(self): + def test_fgw_adaptive_threshold(self): """Checking solution is improved with smaller threshold for convergence.""" geom_x = pointcloud.PointCloud(self.x, self.x) geom_y = pointcloud.PointCloud(self.y, self.y) @@ -259,7 +260,7 @@ def test_adaptive_threshold_fused(self): # without warm start for calls to sinkhorn def loss_thre(threshold: float) -> float: - return gromov_wasserstein.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_xx=geom_x, geom_yy=geom_y, geom_xy=geom_xy, @@ -274,7 +275,7 @@ def loss_thre(threshold: float) -> float: assert loss_thre(1e-3) > loss_thre(1e-5) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fused_gromov_wasserstein_penalty(self, lse_mode: bool): + def test_gradient_fgw_solver_penalty(self, lse_mode: bool): """Test gradient w.r.t. penalty.""" def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): @@ -286,7 +287,7 @@ def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): 'max_iterations': 1001, 'lse_mode': lse_mode } - return gromov_wasserstein.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -318,7 +319,7 @@ def reg_fgw(x, y, x_2, y_2, fused_penalty, a, b): geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x_2, y_2) sinkhorn_kwargs = {'max_iterations': 1001} - return gromov_wasserstein.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy=geom_xy, @@ -333,7 +334,7 @@ def reg_gw(x, y, a, b): geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) sinkhorn_kwargs = {'max_iterations': 1001} - return gromov_wasserstein.gromov_wasserstein( + return gw_solver.gromov_wasserstein( geom_x, geom_y, a=a, @@ -366,7 +367,7 @@ def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(xx, yy) - ot_gwlr = gromov_wasserstein.gromov_wasserstein( + ot_gwlr = gw_solver.gromov_wasserstein( geom_x, geom_y, geom_xy, rank=5, jit=jit ) res0 = ot_gwlr.apply(x.T, axis=0) @@ -391,14 +392,14 @@ def test_fgw_lr_generic_cost_matrix( geom_y = geometry.Geometry(cost_matrix=y @ y.T) geom_xy = geometry.Geometry(cost_matrix=xx @ yy.T) - problem = quad_problems.QuadraticProblem( + problem = quadratic_problem.QuadraticProblem( geom_x, geom_y, geom_xy, ranks=cost_rank, tolerances=5e-1 ) assert problem._is_low_rank_convertible lr_prob = problem.to_low_rank() assert lr_prob.is_low_rank - solver = gromov_wasserstein.GromovWasserstein(rank=5, epsilon=1) + solver = gw_solver.GromovWasserstein(rank=5, epsilon=1) out = solver(problem) assert solver.rank == 5 diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py new file mode 100644 index 000000000..94cd5759b --- /dev/null +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -0,0 +1,113 @@ +# Copyright 2022 Apple +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Gromov-Wasserstein barycenter.""" +from typing import Any, Optional, Sequence, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from ott.geometry import pointcloud +from ott.problems.quadratic import gw_barycenter as gwb +from ott.solvers.quadratic import gw_barycenter as gwb_solver + + +class TestGWBarycenter: + ndim = 3 + ndim_f = 4 + + @staticmethod + def random_pc( + n: int, + d: int, + rng: jnp.ndarray, + m: Optional[int] = None, + **kwargs: Any + ) -> pointcloud.PointCloud: + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, (n, d)) + y = x if m is None else jax.random.normal(key2, (m, d)) + return pointcloud.PointCloud(x, y, batch_size=None, **kwargs) + + @staticmethod + def pad_cost_matrices( + costs: Sequence[jnp.ndarray], + shape: Optional[Tuple[int, int]] = None + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + if shape is None: + shape = jnp.asarray([arr.shape for arr in costs]).max() + shape = (shape, shape) + else: + assert shape[0] == shape[1], shape + + cs, weights = [], [] + for cost in costs: + r, c = cost.shape + cs.append(jnp.zeros(shape).at[:r, :c].set(cost)) + w = jnp.ones(r) / r + weights.append(jnp.concatenate([w, jnp.zeros(shape[0] - r)])) + return jnp.stack(cs), jnp.stack(weights) + + # TODO(cuturi) add back KL test when KL cost GW is fixed. + @pytest.mark.parametrize( + "gw_loss,bar_size,epsilon", + [("sqeucl", 17, None)] # , ("kl", 22, 1e-2)] + ) + def test_gw_barycenter( + self, rng: jnp.ndarray, gw_loss: str, bar_size: int, + epsilon: Optional[float] + ): + tol = 1e-3 if gw_loss == "sqeucl" else 1e-1 + num_per_segment = (13, 15, 21) + rngs = jax.random.split(rng, len(num_per_segment)) + pcs = [ + self.random_pc(n, d=self.ndim, rng=rng) + for n, rng in zip(num_per_segment, rngs) + ] + costs = [pc._compute_cost_matrix() for pc, n in zip(pcs, num_per_segment)] + costs, cbs = self.pad_cost_matrices(costs) + ys = jnp.concatenate([pc.x for pc in pcs]) + bs = jnp.concatenate([jnp.ones(n) / n for n in num_per_segment]) + kwargs = { + "gw_loss": gw_loss, + "num_per_segment": num_per_segment, + "epsilon": epsilon + } + + problem_pc = gwb.GWBarycenterProblem(y=ys, b=bs, **kwargs) + problem_cost = gwb.GWBarycenterProblem( + costs=costs, + b=cbs, + **kwargs, + ) + for prob in [problem_pc, problem_cost]: + assert not prob.is_fused + assert prob.ndim_fused is None + assert prob.num_measures == len(num_per_segment) + assert prob.max_measure_size == max(num_per_segment) + assert prob._loss_name == gw_loss + assert problem_pc.ndim == self.ndim + assert problem_cost.ndim is None + + solver = gwb_solver.GromovWassersteinBarycenter(jit=True) + out_pc = solver(problem_pc, bar_size=bar_size) + out_cost = solver(problem_cost, bar_size=bar_size) + + assert out_pc.x is None + assert out_cost.x is None + assert out_pc.cost.shape == (bar_size, bar_size) + np.testing.assert_allclose(out_pc.cost, out_cost.cost, rtol=tol, atol=tol) + np.testing.assert_allclose(out_pc.costs, out_cost.costs, rtol=tol, atol=tol) diff --git a/tests/core/gromov_wasserstein_test.py b/tests/solvers/quadratic/gw_test.py similarity index 95% rename from tests/core/gromov_wasserstein_test.py rename to tests/solvers/quadratic/gw_test.py index 0ded70d48..0e37a9c6d 100644 --- a/tests/core/gromov_wasserstein_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -21,8 +21,9 @@ import numpy as np import pytest -from ott.core import gromov_wasserstein, quad_problems from ott.geometry import geometry, low_rank, pointcloud +from ott.problems.quadratic import quadratic_problem +from ott.solvers.quadratic import gromov_wasserstein @pytest.mark.fast @@ -48,7 +49,9 @@ def test_quad_to_low_rank( geom_yy = geometry.Geometry(geom_yy.cost_matrix) geom_xy = geometry.Geometry(geom_xy.cost_matrix) - prob = quad_problems.QuadraticProblem(geom_xx, geom_yy, geom_xy, ranks=rank) + prob = quadratic_problem.QuadraticProblem( + geom_xx, geom_yy, geom_xy, ranks=rank + ) assert not prob.is_low_rank # point clouds are always converted, if possible @@ -85,7 +88,7 @@ def test_quad_to_low_rank( assert lr_prob._is_low_rank_convertible assert lr_prob.to_low_rank() is lr_prob - def test_implicit_conversion_mixed_input(self, rng: jnp.ndarray): + def test_gw_implicit_conversion_mixed_input(self, rng: jnp.ndarray): n, m, d1, d2 = 200, 300, 20, 25 k1, k2 = jax.random.split(rng, 2) x = jax.random.normal(k1, (n, d1)) @@ -94,7 +97,7 @@ def test_implicit_conversion_mixed_input(self, rng: jnp.ndarray): geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y).to_LRCGeometry() - prob = quad_problems.QuadraticProblem(geom_xx, geom_yy, ranks=-1) + prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy, ranks=-1) lr_prob = prob.to_low_rank() assert prob._is_low_rank_convertible @@ -149,7 +152,7 @@ def test_flag_store_errors(self): assert out.ndim == 2 @pytest.mark.parametrize("jit", [False, True]) - def test_gradient_marginals_gromov_wasserstein(self, jit: bool): + def test_gradient_marginals_gw(self, jit: bool): """Test gradient w.r.t. probability weights.""" geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) @@ -196,7 +199,7 @@ def reg_gw(a, b, implicit): ) @pytest.mark.fast - def test_gromov_wasserstein_pointcloud(self): + def test_gw_pointcloud(self): """Test basic computations pointclouds.""" def reg_gw(x, y, a, b): @@ -209,7 +212,7 @@ def reg_gw(x, y, a, b): assert not jnp.isnan(reg_gw(self.x, self.y, self.a, self.b)) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_gromov_wasserstein_pointcloud(self, lse_mode: bool): + def test_gradient_gw_pointcloud(self, lse_mode: bool): """Test gradient w.r.t. pointclouds.""" def reg_gw(x, y, a, b, implicit): @@ -251,7 +254,7 @@ def reg_gw(x, y, a, b, implicit): ) @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_gromov_wasserstein_geometry(self, lse_mode: bool): + def test_gradient_gw_geometry(self, lse_mode: bool): """Test gradient w.r.t. cost matrices.""" def reg_gw(cx, cy, a, b, implicit): @@ -293,7 +296,7 @@ def reg_gw(cx, cy, a, b, implicit): grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 ) - def test_adaptive_threshold(self): + def test_gw_adaptive_threshold(self): """Checking solution is improved with smaller threshold for convergence.""" geom_x = pointcloud.PointCloud(self.x, self.x) geom_y = pointcloud.PointCloud(self.y, self.y) @@ -325,7 +328,7 @@ def test_gw_lr(self, rng: jnp.ndarray): geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y) - prob = quad_problems.QuadraticProblem(geom_xx, geom_yy, a=a, b=b) + prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy, a=a, b=b) solver = gromov_wasserstein.GromovWasserstein(rank=5, epsilon=0.2) ot_gwlr = solver(prob) solver = gromov_wasserstein.GromovWasserstein(epsilon=0.2) @@ -347,7 +350,7 @@ def test_gw_lr_matches_fused(self, rng: jnp.ndarray): geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(x, z) # only used to compute n x m matrix - prob = quad_problems.QuadraticProblem( + prob = quadratic_problem.QuadraticProblem( geom_xx, geom_yy, geom_xy=geom_xy, fused_penalty=1.3, a=a, b=b ) solver = gromov_wasserstein.GromovWasserstein(rank=6) @@ -463,7 +466,7 @@ def initialize(self, rng: jnp.ndarray): self.tau_b = 0.9 @pytest.mark.fast - def test_gromov_wasserstein_pointcloud(self): + def test_gw_pointcloud(self): """Test basic computations pointclouds.""" def reg_gw(x, y, a, b): @@ -484,9 +487,7 @@ def reg_gw(x, y, a, b): assert not jnp.isnan(cost) @pytest.mark.parametrize("gw_unbalanced_correction", [False, True]) - def test_gradient_gromov_wasserstein_pointcloud( - self, gw_unbalanced_correction: bool - ): + def test_gradient_gw_pointcloud(self, gw_unbalanced_correction: bool): """Test gradient w.r.t. pointclouds.""" def reg_gw(x, y, a, b, implicit): @@ -530,9 +531,7 @@ def reg_gw(x, y, a, b, implicit): ) @pytest.mark.parametrize("gw_unbalanced_correction", [False, True]) - def test_gradient_gromov_wasserstein_geometry( - self, gw_unbalanced_correction: bool - ): + def test_gradient_gw_geometry(self, gw_unbalanced_correction: bool): """Test gradient w.r.t. cost matrices.""" def reg_gw(cx, cy, a, b, implicit): diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index ef37cff8f..32f2f93f3 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from ott.geometry import matrix_square_root +from ott.math import matrix_square_root from ott.tools.gaussian_mixture import scale_tril diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index e7269d2da..5a1e81c1e 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -20,8 +20,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import costs, pointcloud +from ott.solvers.linear import sinkhorn from ott.tools import segment_sinkhorn from ott.tools.gaussian_mixture import gaussian_mixture diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index a6110e4cd..dc8eacab2 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -21,8 +21,8 @@ import numpy as np import pytest -from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud +from ott.solvers.linear import sinkhorn from ott.tools import sinkhorn_divergence from ott.tools.gaussian_mixture import gaussian_mixture diff --git a/tests/tools/transport_test.py b/tests/tools/transport_test.py index 44c157261..4954c85bd 100644 --- a/tests/tools/transport_test.py +++ b/tests/tools/transport_test.py @@ -18,8 +18,8 @@ import numpy as np import pytest -from ott.core import linear_problems from ott.geometry import pointcloud +from ott.problems.linear import linear_problem from ott.tools import transport @@ -60,7 +60,7 @@ def test_transport_from_problem(self, rng: jnp.ndarray): geom = pointcloud.PointCloud(x, y, batch_size=9) b = jax.random.uniform(rngs[2], (num_b,)) b /= jnp.sum(b) - pb = linear_problems.LinearProblem(geom, b=b) + pb = linear_problem.LinearProblem(geom, b=b) ot = transport.solve(pb) np.testing.assert_array_equal(ot.matrix.shape, (num_a, num_b))