Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce multivariate cdf / quantiles #447

Merged
merged 10 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ @inproceedings{scetbon:21
@article{schiebinger:19,
author = {Schiebinger, Geoffrey and Shu, Jian and Tabaka, Marcin and Cleary, Brian and Subramanian, Vidya and Solomon, Aryeh and Gould, Joshua and Liu, Siyan and Lin, Stacie and Berube, Peter and Lee, Lia and Chen, Jenny and Brumbaugh, Justin and Rigollet, Philippe and Hochedlinger, Konrad and Jaenisch, Rudolf and Regev, Aviv and Lander, Eric S.},
publisher = {Elsevier},
doi = {10.1016/j.cell.2019.01.006},
issn = {0092-8674},
journal = {Cell},
number = {4},
Expand All @@ -129,6 +128,18 @@ @article{memoli:11
year = {2011},
}

@article{chernozhukov:17,
author = {Chernozhukov, Victor and Galichon, Alfred and Hallin, Marc and Henry, Marc},
publisher = {Institute of Mathematical Statistics},
journal = {The Annals of Statistics},
keywords = {empirical transport maps,multivariate signs,Statistical depth,uniform convergence of empirical transport,vector quantiles,vector ranks},
number = {1},
pages = {223--256},
title = {{{M}onge–{K}antorovich depth, quantiles, ranks and signs}},
volume = {45},
year = {2017},
}

@inproceedings{scetbon:22,
author = {Scetbon, Meyer and Peyré, Gabriel and Cuturi, Marco},
editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
Expand Down
1 change: 1 addition & 0 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Soft Sorting Algorithms
.. autosummary::
:toctree: _autosummary

soft_sort.multivariate_cdf_quantile_maps
soft_sort.quantile
soft_sort.quantile_normalization
soft_sort.quantize
Expand Down
87 changes: 85 additions & 2 deletions src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from ott.geometry import pointcloud
from ott import utils
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers import linear
from ott.solvers.linear import sinkhorn

__all__ = [
"sort", "ranks", "sort_with", "quantile", "quantile_normalization",
"quantize", "topk_mask"
"quantize", "topk_mask", "multivariate_cdf_quantile_maps"
]

Func_t = Callable[[jnp.ndarray], jnp.ndarray]


def transport_for_sort(
inputs: jnp.ndarray,
Expand Down Expand Up @@ -450,6 +455,84 @@ def _quantile(
return apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs)


def multivariate_cdf_quantile_maps(
inputs: jnp.ndarray,
target_sampler: Optional[Callable[[jax.random.PRNGKey, Tuple[int, int]],
jnp.ndarray]] = None,
rng: Optional[jax.random.PRNGKey] = None,
num_target_samples: Optional[int] = None,
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
input_weights: Optional[jnp.ndarray] = None,
target_weights: Optional[jnp.ndarray] = None,
**kwargs: Any
) -> Tuple[Func_t, Func_t]:
r"""Returns multivariate CDF and quantile maps, given input samples.

Implements the multivariate generalizations for CDF and quantiles proposed in
:cite:`chernozhukov:17`. The reference measure is assumed to be the uniform
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
measure by default, but can be modified. For consistency, the reference
measure should be symmetrically centered around
:math:`(\tfrac{1}{2},\cdots,\tfrac{1}{2})` and supported on :math:`[0, 1]^d`.

The implementation return two entropic map estimators, one for the CDF map,
the other for the quantiles map.

Args:
inputs: 2D array of ``[n, d]`` vectors.
target_sampler: Callable that takes a ``rng`` and ``[m, d]`` shape.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry about this one!!!

``m`` is passed on as ``target_num_samples``, dimension ``d`` is inferred
directly from the shape passed in ``inputs``. This is assumed by default
to be :func:`~jax.random.uniform`, and could be any other random sampler
properly wrapped to have the signature above.
rng: rng key used by ``target_sampler``.
num_target_samples: number ``m`` of points generated in the target
distribution.
cost_fn: Cost function, used to compare ``inputs`` and ``targets``.
Passed on to instantiate a
:class:`~ott.geometry.pointcloud.PointCloud` object. If :obj:`None`,
:class:`~ott.geometry.costs.SqEuclidean` is used.
epsilon: entropic regularization parameter used to instantiate the
:class:`~ott.geometry.pointcloud.PointCloud` object.
input_weights: ``[n,]`` vector of weights for input measure. Assumed to
be uniform by default.
target_weights: ``[m,]`` vector of weights for target measure. Assumed
to be uniform by default.
kwargs: keyword arguments passed on to the :func:`~ott.solvers.linear.solve`
function, which solves the OT problem between ``inputs`` and ``targets``
using the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here i wasn't sure about the reference, because we use solve... that being said, probably not a good idea to use LR on this, because would crash :) so maybe an instance where we should force kwargs to only refer to sinkhorn...


Returns:
- The multivariate CDF map, taking a ``[b, d]`` batch of vectors in the
range of the ``inputs``, and mapping each vector within the range
of the reference measure (assumed by default to be :math:`[0, 1]^d`).
- The quantile map, mapping a batch ``[b, d]`` of multivariate quantile
vectors onto ``[b, d]`` vectors in :math:`[0, 1]^d`, the range of
the reference measure.
"""
n, d = inputs.shape
rng = utils.default_prng_key(rng)

if num_target_samples is None:
num_target_samples = n
if target_sampler is None:
target_sampler = jax.random.uniform

targets = target_sampler(rng, (num_target_samples, d))
geom = pointcloud.PointCloud(
inputs, targets, cost_fn=cost_fn, epsilon=epsilon
)

out = linear.solve(geom, a=input_weights, b=target_weights, **kwargs)
potentials = out.to_dual_potentials()

cdf_map = jtu.Partial(lambda x, p: p.transport(x), p=potentials)
quantile_map = jtu.Partial(
lambda x, p: p.transport(x, forward=False), p=potentials
)
return cdf_map, quantile_map


def _quantile_normalization(
inputs: jnp.ndarray, targets: jnp.ndarray, weights: float, **kwargs: Any
) -> jnp.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def test_restart(self, lse_mode: bool):
assert num_iter_restarted == 1

@pytest.mark.cpu()
@pytest.mark.limit_memory("35 MB")
@pytest.mark.limit_memory("36 MB")
@pytest.mark.fast()
def test_sinkhorn_online_memory_jit(self):
# test that full matrix is not materialized.
Expand Down
42 changes: 42 additions & 0 deletions tests/tools/soft_sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,48 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int):
np.testing.assert_array_equal(xs.shape, expected_shape)
np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True)

def test_multivariate_cdf_quantiles(self, rng: jax.random.PRNGKeyArray):
n, d = 512, 3
key1, key2, key3 = jax.random.split(rng, 3)

# Set central point in sampled input measure
z = jax.random.uniform(key1, (1, d))

# Sample inputs symmetrically centered on z
inputs = 0.34 * jax.random.normal(key2, (n, d)) + z

# Set central point in target distribution.
q = 0.5 * jnp.ones((1, d))

# Set tolerance for quantile / cdf comparisons to ground truth.
atol = 0.1

# Check approximate correctness of naked call to API
cdf, qua = soft_sort.multivariate_cdf_quantile_maps(inputs)
np.testing.assert_allclose(cdf(z), q, atol=atol)
np.testing.assert_allclose(z, qua(q), atol=atol)

# Check passing custom sampler, must be still symmetric / centered on {.5}^d
# Check passing custom epsilon also works.
def ball_sampler(k: jax.random.PRNGKey, s: Tuple[int, int]) -> jnp.ndarray:
return 0.5 * (jax.random.ball(k, d=s[1], p=4, shape=(s[0],)) + 1.)

num_target_samples = 473

@functools.partial(jax.jit, static_argnums=[1])
def mv_c_q(inputs, num_target_samples, rng, epsilon):
return soft_sort.multivariate_cdf_quantile_maps(
inputs,
target_sampler=ball_sampler,
num_target_samples=num_target_samples,
rng=rng,
epsilon=epsilon
)

cdf, qua = mv_c_q(inputs, num_target_samples, key3, 0.05)
np.testing.assert_allclose(cdf(z), q, atol=atol)
np.testing.assert_allclose(z, qua(q), atol=atol)

@pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0)
def test_ranks(self, axis, rng: jax.random.PRNGKeyArray, jit: bool):
rng1, rng2 = jax.random.split(rng, 2)
Expand Down