Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implements sliced W #576

Merged
merged 13 commits into from
Sep 13, 2024
9 changes: 9 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -1027,3 +1027,12 @@ @article{lin:22
volume = {23},
year = {2022},
}

@inproceedings{rabin:12,
author = {Rabin, Julien and Peyr{\'e}, Gabriel and Delon, Julie and Bernot, Marc},
title = {Wasserstein barycenter and its application to texture mixing},
booktitle = {Scale Space and Variational Methods in Computer Vision: Third International Conference, SSVM 2011, Ein-Gedi, Israel, May 29--June 2, 2011, Revised Selected Papers 3},
pages = {435--446},
year = {2012},
organization = {Springer}
}
21 changes: 17 additions & 4 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ ott.tools
.. module:: ott.tools
.. currentmodule:: ott.tools

The tools package contains high level functions that build on outputs produced
by core functions. They can be used to compute Sinkhorn divergences
:cite:`sejourne:19`, instantiate transport matrices, provide differentiable
approximations to ranks and quantile functions :cite:`cuturi:19`, etc.
The :mod:`~ott.tools` package contains high level functions that build on
outputs produced by lower-level components in the toolbox, such as
:mod:`~ott.solvers`.

In particular, we provide user-friendly APIs to compute Sinkhorn divergences
:cite:`genevay:18,sejourne:19`, sliced Wasserstein distances :cite:`rabin:12`,
differentiable approximations to ranks and quantile functions :cite:`cuturi:19`,
and various tools to study Gaussians with the 2-Wasserstein metric
:cite:`gelbrich:90,delon:20`, etc.

Segmented Sinkhorn
------------------
Expand All @@ -23,6 +28,14 @@ Sinkhorn Divergence
sinkhorn_divergence.sinkhorn_divergence
sinkhorn_divergence.segment_sinkhorn_divergence

Sliced Wasserstein Distance
---------------------------
.. autosummary::
:toctree: _autosummary

sliced.random_proj_sphere
sliced.sliced_wasserstein

ProgOT
------
.. autosummary::
Expand Down
1 change: 1 addition & 0 deletions src/ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
progot,
segment_sinkhorn,
sinkhorn_divergence,
sliced,
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
soft_sort,
)
109 changes: 109 additions & 0 deletions src/ott/tools/sliced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.solvers.linear import univariate

__all__ = ["random_proj_sphere", "sliced_wasserstein"]

Projector = Callable[[jnp.ndarray, int, jax.Array], jnp.ndarray]


def random_proj_sphere(
x: jnp.ndarray,
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
n_proj: int = 1000,
rng: Optional[jax.Array] = None
) -> jnp.ndarray:
"""Project data on directions sampled randomly from sphere.

marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
Args:
x: Array of size ``[n, dim]``.
n_proj: Number of randomly generated projections.
rng: Key used to sample feature extractors.

Returns:
Array of size ``[n, n_proj]`` features.
"""
rng = utils.default_prng_key(rng)
dim = x.shape[-1]
proj_m = jax.random.normal(rng, (n_proj, dim))
proj_m /= jnp.linalg.norm(proj_m, axis=1, keepdims=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nitpicking, but are we sure that this division is safe? In theory the probability of getting the null vector is zero, but in practice I'm not sure of what is happening in the worst case.

return x @ proj_m.T


def sliced_wasserstein(
x: jnp.ndarray,
y: jnp.ndarray,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
proj_fn: Optional[Projector] = None,
weights: Optional[jnp.ndarray] = None,
return_transport: bool = False,
return_dual_variables: bool = False,
**kwargs: Any,
) -> Tuple[jnp.ndarray, univariate.UnivariateOutput]:
r"""Compute the Sliced Wasserstein distance between two weighted point clouds.

Follows the approach outlined in :cite:`rabin:12` to compute a proxy for OT
distances that relies on creating features (possibly randomly) for data,
through e.g., projections, and then sum the 1D Wasserstein distances between
these features' univariate distributions on both source and target samples.

Args:
x: Array of shape ``[n, dim]`` of source points' coordinates.
y: Array of shape ``[m, dim]`` of target points' coordinates.
a: Array of shape ``[n,]`` of source probability weights.
b: Array of shape ``[m,]`` of target probability weights.
cost_fn: Cost function. Must be a submodular function of two real arguments,
i.e. such that :math:`\partial c(x,y)/\partial x \partial y <0`. If
:obj:`None`, use :class:`~ott.geometry.costs.SqEuclidean`.
proj_fn: Projection function, mapping any ``[b, dim]`` matrix of coordinates
to ``[b, n_proj]`` matrix of features, on which 1D transports (for
``n_proj`` directions) are subsequently computed independently.
By default, use :func:`~ott.tools.sliced.random_proj_sphere`.
weights: Array of shape ``[n_proj,]`` of weights used to average the
``n_proj`` 1D Wasserstein contributions (one for each feature) and form
the sliced Wasserstein distance. Uniform by default, resulting in average
of all these values.
return_transport: Whether to store ``n_proj`` transport plans in the output.
return_dual_variables: Whether to store ``n_proj`` pairs of dual vectors
in the output.
kwargs: Keyword arguments to ``proj_fn``. Could for instance
include, as done with default projector, number of ``n_proj`` projections,
as well as a ``rng`` key to sample as many directions.

Returns:
The sliced Wasserstein distance with the corresponding output object.
"""
if proj_fn is None:
proj_fn = random_proj_sphere

x_proj, y_proj = proj_fn(x, **kwargs), proj_fn(y, **kwargs),
geom = pointcloud.PointCloud(x_proj, y_proj, cost_fn=cost_fn)

out = linear.solve_univariate(
geom,
a,
b,
return_transport=return_transport,
return_dual_variables=return_dual_variables
)
return jnp.average(out.ot_costs, weights=weights), out
113 changes: 113 additions & 0 deletions tests/tools/sliced_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple

import pytest

marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.tools import sliced

Projector = Callable[[jnp.ndarray, int, jax.Array], jnp.ndarray]


def custom_proj(
x: jnp.ndarray,
rng: Optional[jax.Array] = None,
n_proj: int = 27
) -> jnp.ndarray:
dim = x.shape[1]
rng = jax.random.PRNGKey(42) if rng is None else rng
proj_m = jax.random.uniform(rng, (n_proj, dim))
return (x @ proj_m.T) ** 2


def gen_data(
rng: jax.Array, n: int, m: int, dim: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
rngs = jax.random.split(rng, 4)
x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (m, dim))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a /= jnp.sum(a)
b /= jnp.sum(b)
return a, x, b, y


class TestSliced:
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize("proj_fn", [None, custom_proj])
@pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
def test_random_projs(
self, rng: jax.Array, cost_fn: Optional[costs.CostFn],
proj_fn: Optional[Projector]
):
n, m, dim, n_proj = 12, 17, 5, 13
rng1, rng2 = jax.random.split(rng, 2)
a, x, b, y = gen_data(rng1, n, m, dim)
weights = jax.random.uniform(rng2, n_proj)

# Test non-negative and returns output as needed.
cost, out = sliced.sliced_wasserstein(
x,
y,
a,
b,
cost_fn=cost_fn,
proj_fn=proj_fn,
n_proj=n_proj,
rng=rng2,
weights=weights
)
assert cost > 0.0
np.testing.assert_array_equal(
cost, jnp.average(out.ot_costs, weights=weights)
)

@pytest.mark.parametrize("cost_fn", [costs.SqPNorm(1.4), None])
def test_consistency_with_id(
self, rng: jax.Array, cost_fn: Optional[costs.CostFn]
):
n, m, dim = 11, 12, 4
a, x, b, y = gen_data(rng, n, m, dim)

# Test matches standard implementation when using identity.
cost, _ = sliced.sliced_wasserstein(
x, y, proj_fn=lambda x: x, cost_fn=cost_fn
)
geom = pointcloud.PointCloud(x=x, y=y, cost_fn=cost_fn)
out_lin = jnp.mean(linear.solve_univariate(geom).ot_costs)
np.testing.assert_allclose(out_lin, cost, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("proj_fn", [None, custom_proj])
def test_diff(self, rng: jax.Array, proj_fn: Optional[Projector]):
eps = 1e-4
n, m, dim = 13, 16, 7
a, x, b, y = gen_data(rng, n, m, dim)

# Test differentiability. We assume uniform samples because makes diff
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
# more accurate (avoiding ties, making computations a lot more sensitive).
dx = jax.random.uniform(rng, (n, dim)) - 0.5
cost_p, _ = sliced.sliced_wasserstein(x + eps * dx, y)
cost_m, _ = sliced.sliced_wasserstein(x - eps * dx, y)
g, _ = jax.jit(jax.grad(sliced.sliced_wasserstein, has_aux=True))(x, y)

np.testing.assert_allclose(
jnp.sum(g * dx), (cost_p - cost_m) / (2 * eps), atol=1e-3, rtol=1e-3
)
Loading