Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce multivariate cdf / quantiles #447

Merged
merged 10 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
import jax.numpy as jnp
import numpy as np

from ott.geometry import pointcloud
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers import linear
from ott.solvers.linear import sinkhorn

__all__ = [
"sort", "ranks", "sort_with", "quantile", "quantile_normalization",
"quantize", "topk_mask"
]

Func_t = Callable[[jnp.ndarray], jnp.ndarray]


def transport_for_sort(
inputs: jnp.ndarray,
Expand Down Expand Up @@ -450,6 +453,80 @@ def _quantile(
return apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs)


def multiv_cdf_quantile_maps(
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
inputs: jnp.ndarray,
target_sampler: Optional[Callable[[jax.random.PRNGKey, Tuple[int, int]],
jnp.ndarray]] = None,
key: Optional[jax.random.PRNGKey] = None,
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
num_target_samples: Optional[int] = None,
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
input_weights: Optional[jnp.ndarray] = None,
target_weights: Optional[jnp.ndarray] = None,
**kwargs: Any
) -> [Func_t, Func_t]:
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
r"""Returns multivariate CDF and quantile maps, given input samples.

Implements the multivariate generalizations for CDF and quantiles proposed in
:cite:`chernozhukov:17`. The reference measure is assumed to be the uniform
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
measure by default, but can be modified. For consistency, the reference
measure should be symmetrically centered around
:math:`(\tfrac{1}{2},\cdots,\tfrac{1}{2})` and suppported on :math:`[0,1]^d`.

The implementation return two entropic map estimators, one for the CDF map,
the other for the quantiles map.

Args:
inputs: 2D array of :math:`[n, d]` vectors.
target_sampler: Callable that takes a ``key`` and ``[m,d]`` shape Tuple.
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
``m`` is passed on as ``target_num_samples``, dimension ``d`` is inferred
directly from the shape passed in ``inputs``. This is assumed by default
to be :func:`jax.random.uniform`, and could be any other random sampler
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
properly wrapped to have the signature above.
key: rng key used by ``target_sampler``
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
num_target_samples: number ``m`` of points generated in the target
distribution.
cost_fn: :class:`~ott.geometry.costs.CostFn` object, used to compare
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
``inputs`` and ``targets``. Passed on to instantiate a
:class:`~ott.geometry.pointcloud.PointCloud` object. This
defaults to the squared-Euclidean distance.
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
epsilon: entropic regularization parameter used to instantiate the
:class:`~ott.geometry.pointcloud.PointCloud` object.
input_weights: :math:`[n,]` vector of weights for input measure. Assumed to
be uniform by default.
target_weights: :math:`[m,]` vector of weights for target measure. Assumed
to be uniform by default.
kwargs: keyword arguments passed on to the :func:`~ott.solvers.linear.solve`
function, which solves the OT problem between ``inputs`` and ``targets``
using the Sinkhorn algorithm.

Returns:
Two callables, vector-to-vector mappings:
- multivariate CDF map, taking values in the range of the reference
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
measure.
- quantile map, going by default from :math:`[0, 1]^d` to the range of the
input measure.

Raises:
A ValueError in case the input and target have not the same dimension.
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
"""
n, d = inputs.shape
key = jax.random.PRNGKey(0) if key is None else key
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
num_target_samples = n if num_target_samples is None else num_target_samples
if target_sampler is None:
target_sampler = jax.random.uniform
targets = target_sampler(key, (num_target_samples, d))

geom = pointcloud.PointCloud(
inputs, targets, cost_fn=cost_fn, epsilon=epsilon
)
out = linear.solve(geom, a=input_weights, b=target_weights, **kwargs)
potentials = out.to_dual_potentials()
return potentials.transport, functools.partial(
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
potentials.transport, forward=False
)


def _quantile_normalization(
inputs: jnp.ndarray, targets: jnp.ndarray, weights: float, **kwargs: Any
) -> jnp.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def test_restart(self, lse_mode: bool):
assert num_iter_restarted == 1

@pytest.mark.cpu()
@pytest.mark.limit_memory("35 MB")
@pytest.mark.limit_memory("36 MB")
@pytest.mark.fast()
def test_sinkhorn_online_memory_jit(self):
# test that full matrix is not materialized.
Expand Down
38 changes: 38 additions & 0 deletions tests/tools/soft_sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,44 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int):
np.testing.assert_array_equal(xs.shape, expected_shape)
np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True)

def test_multiv_cdf_quantiles(self):
n, d = 512, 3
keys = jax.random.split(jax.random.PRNGKey(0), 3)
michalk8 marked this conversation as resolved.
Show resolved Hide resolved

# Set central point in sampled input measure
z = jax.random.uniform(keys[0], (1, d))

# Sample inputs symmetrically centered on z
inputs = 0.34 * jax.random.normal(keys[0], (n, d)) + z

# Set central point in target distribution.
q = 0.5 * jnp.ones((1, d))

# Set tolerance for quantile / cdf comparisons to ground truth.
atol = 0.1

# Check approximate correctness of naked call to API
cdf, qua = soft_sort.multiv_cdf_quantile_maps(inputs)
np.testing.assert_allclose(cdf(z), q, atol=atol)
np.testing.assert_allclose(z, qua(q), atol=atol)

# Check passing custom sampler, must be still symmetric / centered on {.5}^d
# Check passing custom epsilon also works.
def ball_sampler(k: jax.random.PRNGKey, s: Tuple[int, int]) -> jnp.ndarray:
return 0.5 * (jax.random.ball(k, d=s[1], p=4, shape=(s[0],)) + 1.)

num_target_samples = 473

cdf, qua = soft_sort.multiv_cdf_quantile_maps(
inputs,
target_sampler=ball_sampler,
num_target_samples=num_target_samples,
key=keys[2],
epsilon=0.05
)
np.testing.assert_allclose(cdf(z), q, atol=atol)
np.testing.assert_allclose(z, qua(q), atol=atol)

@pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0)
def test_ranks(self, axis, rng: jax.random.PRNGKeyArray, jit: bool):
rng1, rng2 = jax.random.split(rng, 2)
Expand Down