diff --git a/docs/references.bib b/docs/references.bib index f2f59d870..35ba274ba 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -106,7 +106,6 @@ @inproceedings{scetbon:21 @article{schiebinger:19, author = {Schiebinger, Geoffrey and Shu, Jian and Tabaka, Marcin and Cleary, Brian and Subramanian, Vidya and Solomon, Aryeh and Gould, Joshua and Liu, Siyan and Lin, Stacie and Berube, Peter and Lee, Lia and Chen, Jenny and Brumbaugh, Justin and Rigollet, Philippe and Hochedlinger, Konrad and Jaenisch, Rudolf and Regev, Aviv and Lander, Eric S.}, publisher = {Elsevier}, - doi = {10.1016/j.cell.2019.01.006}, issn = {0092-8674}, journal = {Cell}, number = {4}, @@ -129,6 +128,18 @@ @article{memoli:11 year = {2011}, } +@article{chernozhukov:17, + author = {Chernozhukov, Victor and Galichon, Alfred and Hallin, Marc and Henry, Marc}, + publisher = {Institute of Mathematical Statistics}, + journal = {The Annals of Statistics}, + keywords = {empirical transport maps,multivariate signs,Statistical depth,uniform convergence of empirical transport,vector quantiles,vector ranks}, + number = {1}, + pages = {223--256}, + title = {{{M}onge–{K}antorovich depth, quantiles, ranks and signs}}, + volume = {45}, + year = {2017}, +} + @inproceedings{scetbon:22, author = {Scetbon, Meyer and Peyré, Gabriel and Cuturi, Marco}, editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan}, diff --git a/docs/tools.rst b/docs/tools.rst index 455a4eb44..966445c24 100644 --- a/docs/tools.rst +++ b/docs/tools.rst @@ -35,6 +35,7 @@ Soft Sorting Algorithms .. autosummary:: :toctree: _autosummary + soft_sort.multivariate_cdf_quantile_maps soft_sort.quantile soft_sort.quantile_normalization soft_sort.quantize diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 6d7b60ba6..c2d03afe5 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -16,17 +16,22 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np -from ott.geometry import pointcloud +from ott import utils +from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem +from ott.solvers import linear from ott.solvers.linear import sinkhorn __all__ = [ "sort", "ranks", "sort_with", "quantile", "quantile_normalization", - "quantize", "topk_mask" + "quantize", "topk_mask", "multivariate_cdf_quantile_maps" ] +Func_t = Callable[[jnp.ndarray], jnp.ndarray] + def transport_for_sort( inputs: jnp.ndarray, @@ -450,6 +455,84 @@ def _quantile( return apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs) +def multivariate_cdf_quantile_maps( + inputs: jnp.ndarray, + target_sampler: Optional[Callable[[jax.random.PRNGKey, Tuple[int, int]], + jnp.ndarray]] = None, + rng: Optional[jax.random.PRNGKey] = None, + num_target_samples: Optional[int] = None, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + input_weights: Optional[jnp.ndarray] = None, + target_weights: Optional[jnp.ndarray] = None, + **kwargs: Any +) -> Tuple[Func_t, Func_t]: + r"""Returns multivariate CDF and quantile maps, given input samples. + + Implements the multivariate generalizations for CDF and quantiles proposed in + :cite:`chernozhukov:17`. The reference measure is assumed to be the uniform + measure by default, but can be modified. For consistency, the reference + measure should be symmetrically centered around + :math:`(\tfrac{1}{2},\cdots,\tfrac{1}{2})` and supported on :math:`[0, 1]^d`. + + The implementation return two entropic map estimators, one for the CDF map, + the other for the quantiles map. + + Args: + inputs: 2D array of ``[n, d]`` vectors. + target_sampler: Callable that takes a ``rng`` and ``[m, d]`` shape. + ``m`` is passed on as ``target_num_samples``, dimension ``d`` is inferred + directly from the shape passed in ``inputs``. This is assumed by default + to be :func:`~jax.random.uniform`, and could be any other random sampler + properly wrapped to have the signature above. + rng: rng key used by ``target_sampler``. + num_target_samples: number ``m`` of points generated in the target + distribution. + cost_fn: Cost function, used to compare ``inputs`` and ``targets``. + Passed on to instantiate a + :class:`~ott.geometry.pointcloud.PointCloud` object. If :obj:`None`, + :class:`~ott.geometry.costs.SqEuclidean` is used. + epsilon: entropic regularization parameter used to instantiate the + :class:`~ott.geometry.pointcloud.PointCloud` object. + input_weights: ``[n,]`` vector of weights for input measure. Assumed to + be uniform by default. + target_weights: ``[m,]`` vector of weights for target measure. Assumed + to be uniform by default. + kwargs: keyword arguments passed on to the :func:`~ott.solvers.linear.solve` + function, which solves the OT problem between ``inputs`` and ``targets`` + using the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm. + + Returns: + - The multivariate CDF map, taking a ``[b, d]`` batch of vectors in the + range of the ``inputs``, and mapping each vector within the range + of the reference measure (assumed by default to be :math:`[0, 1]^d`). + - The quantile map, mapping a batch ``[b, d]`` of multivariate quantile + vectors onto ``[b, d]`` vectors in :math:`[0, 1]^d`, the range of + the reference measure. + """ + n, d = inputs.shape + rng = utils.default_prng_key(rng) + + if num_target_samples is None: + num_target_samples = n + if target_sampler is None: + target_sampler = jax.random.uniform + + targets = target_sampler(rng, (num_target_samples, d)) + geom = pointcloud.PointCloud( + inputs, targets, cost_fn=cost_fn, epsilon=epsilon + ) + + out = linear.solve(geom, a=input_weights, b=target_weights, **kwargs) + potentials = out.to_dual_potentials() + + cdf_map = jtu.Partial(lambda x, p: p.transport(x), p=potentials) + quantile_map = jtu.Partial( + lambda x, p: p.transport(x, forward=False), p=potentials + ) + return cdf_map, quantile_map + + def _quantile_normalization( inputs: jnp.ndarray, targets: jnp.ndarray, weights: float, **kwargs: Any ) -> jnp.ndarray: diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index a7d34c967..2b5b1fc1c 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -474,7 +474,7 @@ def test_restart(self, lse_mode: bool): assert num_iter_restarted == 1 @pytest.mark.cpu() - @pytest.mark.limit_memory("35 MB") + @pytest.mark.limit_memory("36 MB") @pytest.mark.fast() def test_sinkhorn_online_memory_jit(self): # test that full matrix is not materialized. diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 28c2c1ad6..55700f357 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -86,6 +86,48 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int): np.testing.assert_array_equal(xs.shape, expected_shape) np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True) + def test_multivariate_cdf_quantiles(self, rng: jax.random.PRNGKeyArray): + n, d = 512, 3 + key1, key2, key3 = jax.random.split(rng, 3) + + # Set central point in sampled input measure + z = jax.random.uniform(key1, (1, d)) + + # Sample inputs symmetrically centered on z + inputs = 0.34 * jax.random.normal(key2, (n, d)) + z + + # Set central point in target distribution. + q = 0.5 * jnp.ones((1, d)) + + # Set tolerance for quantile / cdf comparisons to ground truth. + atol = 0.1 + + # Check approximate correctness of naked call to API + cdf, qua = soft_sort.multivariate_cdf_quantile_maps(inputs) + np.testing.assert_allclose(cdf(z), q, atol=atol) + np.testing.assert_allclose(z, qua(q), atol=atol) + + # Check passing custom sampler, must be still symmetric / centered on {.5}^d + # Check passing custom epsilon also works. + def ball_sampler(k: jax.random.PRNGKey, s: Tuple[int, int]) -> jnp.ndarray: + return 0.5 * (jax.random.ball(k, d=s[1], p=4, shape=(s[0],)) + 1.) + + num_target_samples = 473 + + @functools.partial(jax.jit, static_argnums=[1]) + def mv_c_q(inputs, num_target_samples, rng, epsilon): + return soft_sort.multivariate_cdf_quantile_maps( + inputs, + target_sampler=ball_sampler, + num_target_samples=num_target_samples, + rng=rng, + epsilon=epsilon + ) + + cdf, qua = mv_c_q(inputs, num_target_samples, key3, 0.05) + np.testing.assert_allclose(cdf(z), q, atol=atol) + np.testing.assert_allclose(z, qua(q), atol=atol) + @pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0) def test_ranks(self, axis, rng: jax.random.PRNGKeyArray, jit: bool): rng1, rng2 = jax.random.split(rng, 2)