Skip to content

Commit

Permalink
Misc/project structure (#176)
Browse files Browse the repository at this point in the history
* Reorganize repo structure

* Update PC docs

* Update imports, fix some types

* Fix more types, pet-peeves

* Fix tests

* Update cost funcs and potentials

* Fix LR initializer

* Fix k-means initializer

* Move `utils`

* Update imports in notebooks

* Update geometry docs

* Update initializers

* Update math docs

* Update problem docstrings

* Update `solvers` docstrings

* Update `tools` docstrings

* Remove remaining `core` mentions from docstrings

* Start updating documentation

* Fix typing

* Update solvers docs

* Add initializers

* Update docs

* Fix MetaOT links

* Fix bibliography links

* Fix more links in the notebooks

* Follow line length in README.md

* Update `tests` structure

* Update badges

* Add TODOs, fix citation in `index.rst`, move `implicit_diff`

* Fix implicit_diff, TODOs in costs

* Use `jax.lax.cond` in `UnbalancedBures`

* Fix `UnbalancedBures`

* Update CI versions

* Fix UnbalancedBures's norm
  • Loading branch information
michalk8 authored Nov 22, 2022
1 parent e4e5bac commit 32962a1
Show file tree
Hide file tree
Showing 132 changed files with 2,729 additions and 2,314 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ root = true
[*]
end_of_line = lf
insert_final_newline = true
charset = utf-8

[*py]
charset = utf-8
indent_size = 2
indent_style = space
max_line_length = 80
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ jobs:
os: [ubuntu-latest]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}

- uses: actions/cache@v2
- uses: actions/cache@v3
with:
path: ~/.cache/pre-commit
key: precommit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }}
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/notebook_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.8']
python-version: [3.8]
os: [ubuntu-latest]

steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.x
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
test_mark: [fast, all]

steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ cython_debug/

# generated documentation
docs/html
docs/_autosummary
**/_autosummary

# macos
**/.DS_Store
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ to the project, participating in discussions or raising issues.
1. fork the repository using the **Fork** button on GitHub or the following
[link](https://github.com/ott-jax/ott/fork)
2. ```bash
git clone https://github.com/YOUR_USERNAME/ott
git clone https://github.com/<YOUR_USERNAME>/ott
cd ott
pip install -e .'[dev,test]'
pre-commit install
Expand Down
49 changes: 34 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,30 +1,44 @@
<img src="https://raw.githubusercontent.com/ott-jax/ott/main/docs/_static/images/logoOTT.png" width="10%" alt="logo">

# Optimal Transport Tools (OTT).

# Optimal Transport Tools (OTT)
[![Downloads](https://pepy.tech/badge/ott-jax)](https://pypi.org/project/ott-jax/)
[![Tests](https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main)](https://github.com/ott-jax/ott/actions/workflows/tests.yml)
[![Docs](https://img.shields.io/readthedocs/ott-jax/latest)](https://ott-jax.readthedocs.io/en/latest/)
[![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)](https://app.codecov.io/gh/ott-jax/ott)

**See [full documentation](https://ott-jax.readthedocs.io/en/latest/).**
**See the [full documentation](https://ott-jax.readthedocs.io/en/latest/).**

## What is OTT-JAX?

A JAX powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes the fastest implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, acceleration, initializations) and extensions (low-rank), that can be used directly, or within more advanced problems (Gromov-Wasserstein, barycenters). Some of JAX features, including [JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions), [auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and [implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) work towards the goal of having end-to-end differentiable outputs. OTT-JAX is developed by a team of researchers from Apple, Google, Meta and many academic contributors, including TU München, Oxford, ENSAE/IP Paris and the Hebrew University.
A JAX powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes the fastest
implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling,
acceleration, initializations) and extensions (low-rank), that can be used directly, or within more advanced problems
(Gromov-Wasserstein, barycenters). Some of JAX features, including
[JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions),
[auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and
[implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)
work towards the goal of having end-to-end differentiable outputs. OTT-JAX is developed by a team of researchers
from Apple, Google, Meta and many academic contributors, including TU München, Oxford, ENSAE/IP Paris and the
Hebrew University.

## What is optimal transport?
Optimal transport can be loosely described as the branch of mathematics and optimization that studies
*matching problems*: given two families of points, and a cost function on pairs of points, find a `good' (low cost) way
to associate bijectively to every point in the first family another in the second.

Optimal transport can be loosely described as the branch of mathematics and optimization that studies *matching problems*: given two families of points, and a cost function on pairs of points, find a `good' (low cost) way to associate bijectively to every point in the first family another in the second.

Such problems appear in all areas of science, are easy to describe, yet hard to solve. Indeed, while matching optimally two sets of *n* points using a pairwise cost can be solved with the [Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), solving it costs an order of $O(n^3)$ operations, and lacks flexibility, since one may want to couple families of different sizes.
Such problems appear in all areas of science, are easy to describe, yet hard to solve. Indeed, while matching optimally
two sets of *n* points using a pairwise cost can be solved with the
[Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), solving it costs an order of $O(n^3)$
operations, and lacks flexibility, since one may want to couple families of different sizes.

Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous generalizations that can help it handle weighted sets of different size, partial matchings, and even more evolved so-called quadratic matching problems.
Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous
generalizations that can help it handle weighted sets of different size, partial matchings, and even more evolved
so-called quadratic matching problems.

In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly (2D vectors, compared with the squared Euclidean distance):
In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly
(2D vectors, compared with the squared Euclidean distance):

## Example

```py
```python
import jax
import jax.numpy as jnp
from ott.tools import transport
Expand All @@ -41,17 +55,22 @@ ot = transport.solve(x, y, a=a, b=b)
P = ot.matrix
```

The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix (here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/).
The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix
(here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or
more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and
are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives,
and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/).

![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png)
## Citation

## Citation
If you have found this work useful, please consider citing this reference:

```
@article{cuturi2022optimal,
title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein},
author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and Davis, Geoff and Teboul, Olivier},
author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and
Davis, Geoff and Teboul, Olivier},
journal={arXiv preprint arXiv:2201.12324},
year={2022}
}
Expand Down
1 change: 1 addition & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ help:
clean:
@rm -rf $(BUILDDIR)/
@rm -rf $(SOURCEDIR)/_autosummary
@rm -rf $(SOURCEDIR)/**/_autosummary
4 changes: 4 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@
source_suffix = ['.rst']

autosummary_generate = True
autosummary_filename_map = {
"ott.solvers.linear.sinkhorn.sinkhorn":
"ott.solvers.linear.sinkhorn.sinkhorn-function"
}

autodoc_typehints = 'description'

Expand Down
115 changes: 0 additions & 115 deletions docs/core.rst

This file was deleted.

7 changes: 7 additions & 0 deletions docs/geometry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,10 @@ Cost Functions
costs.Cosine
costs.Bures
costs.UnbalancedBures

Utilities
---------
.. autosummary::
:toctree: _autosummary

segment.segment_point_cloud
Loading

0 comments on commit 32962a1

Please sign in to comment.