Add soft topk_mask operator, fix ranks docs and tests (#396)
* fix ranks

* indent

* fix test

* fix normalization of target_weights
marcocuturi authored Jul 14, 2023
1 parent d56cf63 commit 7e21ae6
Showing 2 changed files with 181 additions and 72 deletions.
199 changes: 132 additions & 67 deletions src/ott/tools/soft_sort.py
@@ -82,7 +82,7 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray:
op: a differentiable operator (can be ranks, quantile, etc.)
inputs: Array of any shape.
axis: the axis (int) or tuple of ints on which to apply the operator. If
- several axes are passed the operator, those are merged as a single
+ several axes are passed to the operator, those are merged as a single
dimension.
args: other positional arguments to the operator.
kwargs: other keyword arguments to the operator.
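
For illustration, a minimal usage sketch of the merged-axes behavior described above, assuming (as this docstring states) that public operators such as `ranks` forward a tuple of axes down to `apply_on_axis`:

```python
import jax
from ott.tools import soft_sort

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (4, 3, 2))

# Rank over the last two axes: per the docstring, they are merged into a
# single dimension of size 3 * 2 = 6, ranked jointly, then reshaped back.
r = soft_sort.ranks(x, axis=(1, 2))
assert r.shape == x.shape
```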
@@ -152,7 +152,6 @@ def sort(
x = jax.random.uniform(rng, (100,))
x_sorted = sort(x)
will output sorted convex combinations of values contained in ``x`` that are
differentiable approximations to the sorted vector of entries in ``x``.
These can be compared with the values produced by :func:`jax.numpy.sort`,
@@ -169,46 +168,55 @@
the top-k values. This also reduces the complexity of soft-sorting, since
the number of target points to which the slice of the ``inputs`` tensor
will be mapped will be equal to ``topk+1``.
- num_targets: if ``topk`` is not specified, ``num_targets`` defines the
-   number of (composite) sorted values computed from the inputs (each value
-   is a convex combination of values recorded in the inputs, provided in
-   increasing order). If neither ``topk`` nor ``num_targets`` are specified,
-   ``num_targets`` defaults to the size of the slices of the input that are
-   sorted, i.e. ``inputs.shape[axis]``, and the number of composite sorted
-   values is equal to the slice of the inputs that are sorted.
+ num_targets: if ``topk`` is not specified, a vector of size ``num_targets``
+   is returned. This defines the number of (composite) sorted values computed
+   from the inputs (each value is a convex combination of values recorded in
+   the inputs, provided in increasing order). If neither ``topk`` nor
+   ``num_targets`` are specified, ``num_targets`` defaults to the size of the
+   slices of the input that are sorted, i.e. ``inputs.shape[axis]``, and the
+   number of composite sorted values is equal to the slice of the inputs that
+   are sorted. As a result, the output is of the same size as ``inputs``.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
- attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
+ :class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
- cost function to transport from ``inputs`` to the ``num_targets`` target
- values ; ``epsilon`` regularization
- parameter. Remaining ``kwargs`` are passed on to defined the
+ 1D cost function to transport from ``inputs`` to the ``num_targets``
+ target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
+ are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
- An Array of the same shape as the input with soft-sorted values on the
- given axis.
+ An Array of the same shape as the input, except on ``axis``, where that size
+ will be equal to ``topk`` or ``num_targets``, with soft-sorted values on the
+ given axis. Same size as ``inputs`` if both these parameters are ``None``.
"""
return apply_on_axis(_sort, inputs, axis, topk, num_targets, **kwargs)
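
As a sketch of the shape behavior documented above (output sizes follow the Returns description; ``epsilon`` and other kwargs are left at their defaults):

```python
import jax
from ott.tools import soft_sort

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))

x_all = soft_sort.sort(x)                  # same size as the input: (100,)
x_top = soft_sort.sort(x, topk=5)          # only the 5 largest soft values: (5,)
x_few = soft_sort.sort(x, num_targets=10)  # 10 composite sorted values: (10,)
```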


- def _ranks(inputs: jnp.ndarray, num_targets, **kwargs: Any) -> jnp.ndarray:
+ def _ranks(
+     inputs: jnp.ndarray, num_targets, target_weights, **kwargs: Any
+ ) -> jnp.ndarray:
"""Apply the soft ranks operator on a one dimensional array."""
num_points = inputs.shape[0]
- num_targets = num_points if num_targets is None else num_targets
+ if target_weights is None:
+   num_targets = num_points if num_targets is None else num_targets
+   target_weights = jnp.ones((num_targets,)) / num_targets
+ else:
+   num_targets = target_weights.shape[0]
a = jnp.ones((num_points,)) / num_points
- b = jnp.ones((num_targets,)) / num_targets
- ot = transport_for_sort(inputs, a, b, **kwargs)
+ ot = transport_for_sort(inputs, a, target_weights, **kwargs)
out = 1.0 / a * ot.apply(jnp.arange(num_targets), axis=1)
+ out *= (num_points - 1.0) / (num_targets - 1.0)
return jnp.reshape(out, inputs.shape)
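
To see what the arithmetic in `_ranks` does, here is the same formula applied to a hand-built, perfectly sharp coupling (an illustrative stand-in for the plan computed by `transport_for_sort`, not a Sinkhorn output):

```python
import jax.numpy as jnp

# n = 4 inputs, m = 2 targets; each row carries mass a_i = 1/4. The two
# smallest inputs send all their mass to target 0, the two largest to
# target 1, mimicking a converged, sharp transport plan.
P = jnp.array([
    [0.25, 0.00],
    [0.25, 0.00],
    [0.00, 0.25],
    [0.00, 0.25],
])
n, m = P.shape
a = jnp.ones((n,)) / n

# Mirrors `out = 1.0 / a * ot.apply(jnp.arange(num_targets), axis=1)`
# followed by the rescaling, so ranks land in [0, n - 1].
out = (1.0 / a) * (P @ jnp.arange(m))
out *= (n - 1.0) / (m - 1.0)
print(out)  # [0. 0. 3. 3.]
```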


def ranks(
    inputs: jnp.ndarray,
    axis: int = -1,
    num_targets: Optional[int] = None,
+   target_weights: Optional[jnp.ndarray] = None,
    **kwargs: Any,
) -> jnp.ndarray:
r"""Apply the soft rank operator on input tensor.
@@ -220,46 +228,103 @@ def ranks(
x = jax.random.uniform(rng, (100,))
x_ranks = ranks(x)
- will output fractional values, between 0 and 1, that are differentiable
- approximations to the normalized ranks of entries in ``x``. These should be
- compared to the non-differentiable rank vectors, namely the normalized inverse
- permutation produced by :func:`jax.numpy.argsort`, which can be obtained as:
+ will output values that are differentiable approximations to the ranks of
+ entries in ``x``. These should be compared to the non-differentiable rank
+ vectors, namely the inverse permutation produced by
+ :func:`jax.numpy.argsort`, which can be obtained as:

.. code-block:: python

- x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
+ x_ranks = jax.numpy.argsort(jax.numpy.argsort(x))
Args:
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
- topk: if set to a positive value, the returned vector will only contain
-   the top-k values. This also reduces the complexity of soft-sorting, since
-   the number of target points to which the slice of the ``inputs`` tensor
-   will be mapped to will be equal to ``topk+1``.
- num_targets: if ``topk`` is not specified, ``num_targets`` defines the
-   number of (composite) sorted values computed from the inputs (each value
-   is a convex combination of values recorded in the inputs, provided in
-   increasing order). If neither ``topk`` nor ``num_targets`` are specified,
-   ``num_targets`` defaults to the size of the slices of the input that are
-   sorted, i.e. ``inputs.shape[axis]``, and the number of composite sorted
-   values is equal to the slice of the inputs that are sorted.
+ target_weights: vector of weights (summing to 1) that describes the amount
+   of mass shipped to each target.
+ num_targets: if ``target_weights`` is ``None``, ``num_targets`` is considered
+   to define the number of targets used to rank inputs. Each normalized rank
+   returned in the output will be a convex combination of
+   ``{1, .., num_targets}/num_targets``. The weight of each of these points
+   is assumed to be uniform. If neither ``num_targets`` nor
+   ``target_weights`` are specified, ``num_targets`` defaults to the size
+   of the slices of the input that are sorted, i.e. ``inputs.shape[axis]``.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
- attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
+ :class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
- cost function to transport from ``inputs`` to the ``num_targets`` target
- values ; ``epsilon`` regularization
- parameter. Remaining ``kwargs`` are passed on to defined the
+ 1D cost function to transport from ``inputs`` to the ``num_targets``
+ target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
+ are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
An Array of the same shape as the input with soft-rank values
- normalized to be in :math:`[0,1]`, replacing the original ones.
+ scaled to be in :math:`[0, n-1]`, where ``n`` is ``inputs.shape[axis]``.
"""
- return apply_on_axis(_ranks, inputs, axis, num_targets, **kwargs)
+ return apply_on_axis(
+     _ranks, inputs, axis, num_targets, target_weights, **kwargs
+ )
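
A usage sketch of the new `target_weights` argument, relying only on the behavior documented above: uniform weights reproduce `num_targets`, and a non-uniform vector (summing to 1) skews how much rank mass each target receives:

```python
import jax
import jax.numpy as jnp
from ott.tools import soft_sort

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))

# Uniform weights over 10 targets: equivalent to passing num_targets=10.
w_uniform = jnp.ones((10,)) / 10.0
r1 = soft_sort.ranks(x, target_weights=w_uniform)

# Non-uniform weights: half the mass on the first target, the rest spread
# over the remaining 9; num_targets is then inferred from the vector size.
w_skewed = jnp.concatenate([jnp.array([0.5]), jnp.full((9,), 0.5 / 9)])
r2 = soft_sort.ranks(x, target_weights=w_skewed)
assert r1.shape == r2.shape == x.shape
```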


+ def topk_mask(
+     inputs: jnp.ndarray,
+     axis: int = -1,
+     k: int = 1,
+     **kwargs: Any,
+ ) -> jnp.ndarray:
+   r"""Soft top-:math:`k` selection mask.
+
+   For instance:
+
+   .. code-block:: python
+
+     k = 5
+     x = jax.random.uniform(rng, (100,))
+     mask = topk_mask(x, k=k)
+
+   will output a vector of shape ``x.shape``, with values in :math:`[0,1]`,
+   that are differentiable approximations to the binary mask selecting the
+   top :math:`k` entries in ``x``. These should be compared to the
+   non-differentiable mask obtained with :func:`jax.numpy.sort`, which can
+   be obtained as:
+
+   .. code-block:: python
+
+     mask = x >= jax.numpy.flip(jax.numpy.sort(x))[k - 1]
+
+   Args:
+     inputs: Array of any shape.
+     axis: the axis on which to apply the soft-sorting operator.
+     k: topk parameter. Should be smaller than ``inputs.shape[axis]``.
+     kwargs: keyword arguments passed on to lower level functions. Of interest
+       to the user are ``squashing_fun``, which will redistribute the values
+       in ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by
+       default) to solve the optimal transport problem;
+       :class:`cost_fn <ott.geometry.costs.CostFn>` object of
+       :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
+       1D cost function to transport from ``inputs`` to the target values;
+       ``epsilon`` regularization parameter. Remaining ``kwargs`` are passed
+       on to parameterize the
+       :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
+   """
+   num_points = inputs.shape[axis]
+   assert k < num_points, (
+       f"`k` must be smaller than `inputs.shape[axis]`, yet {k} >= {num_points}."
+   )
+   target_weights = jnp.array([1.0 - k / num_points, k / num_points])
+   out = apply_on_axis(
+       _ranks,
+       inputs,
+       axis,
+       num_targets=None,
+       target_weights=target_weights,
+       **kwargs
+   )
+   return out / (num_points - 1)
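
A usage sketch of the new operator. Since the construction above puts mass `k / num_points` on the "selected" target, the soft mask sums (approximately, up to Sinkhorn convergence) to `k`:

```python
import jax
import jax.numpy as jnp
from ott.tools import soft_sort

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))
k = 5

soft_mask = soft_sort.topk_mask(x, k=k)        # same shape as x, values in [0, 1]
hard_mask = x >= jnp.flip(jnp.sort(x))[k - 1]  # non-differentiable reference

print(soft_mask.sum())  # ~5.0, by conservation of transported mass
print(hard_mask.sum())  # 5
```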


def quantile(
Expand Down Expand Up @@ -299,21 +364,21 @@ def quantile(
These values should all lie in :math:`[0,1]`.
axis: the axis on which to apply the operator.
weight: the weight assigned to each quantile target value in the OT problem.
- This weight should be small, typically of the order of ``1/n``, where ``n``
- is the size of ``x``. Note: Since the size of ``q`` times ``weight``
- must be strictly smaller than ``1``, in order to leave enough mass to set
- other target values in the transport problem, the algorithm might ensure
- this by setting, when needed, a lower value.
+   This weight should be small, typically of the order of ``1/n``, where ``n``
+   is the size of ``x``. Note: Since the size of ``q`` times ``weight``
+   must be strictly smaller than ``1``, in order to leave enough mass to set
+   other target values in the transport problem, the algorithm might ensure
+   this by setting, when needed, a lower value.
kwargs: keyword arguments passed on to lower level functions. Of interest
- to the user are ``squashing_fun``, which will redistribute the values in
- ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
- to solve the optimal transport problem;
- attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
- :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
- cost function to transport from ``inputs`` to the ``num_targets`` target
- values ; ``epsilon`` regularization
- parameter. Remaining ``kwargs`` are passed on to defined the
- :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
+ to the user are ``squashing_fun``, which will redistribute the values in
+ ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
+ to solve the optimal transport problem;
+ :class:`cost_fn <ott.geometry.costs.CostFn>` object of
+ :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
+ 1D cost function to transport from ``inputs`` to the ``num_targets``
+ target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
+ are passed on to parameterize the
+ :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
An Array, which has the same shape as ``inputs``, except on the ``axis``
@@ -420,11 +485,11 @@ def quantile_normalization(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
- attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
+ :class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
- cost function to transport from ``inputs`` to the ``num_targets`` target
- values ; ``epsilon`` regularization
- parameter. Remaining ``kwargs`` are passed on to defined the
+ 1D cost function to transport from ``inputs`` to the ``num_targets``
+ target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
+ are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
@@ -473,11 +538,11 @@ def sort_with(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
- attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
+ :class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
- cost function to transport from ``inputs`` to the ``num_targets`` target
- values ; ``epsilon`` regularization
- parameter. Remaining ``kwargs`` are passed on to defined the
+ 1D cost function to transport from ``inputs`` to the ``num_targets``
+ target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
+ are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
@@ -541,11 +606,11 @@ def quantize(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
- attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
+ :class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
- cost function to transport from ``inputs`` to the ``num_targets`` target
- values ; ``epsilon`` regularization
- parameter. Remaining ``kwargs`` are passed on to defined the
+ 1D cost function to transport from ``inputs`` to the ``num_targets``
+ target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
+ are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
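
The `**kwargs` plumbing described repeatedly in the docstrings above can be exercised as follows; the keyword names (`squashing_fun`, `epsilon`, `max_iterations`) are the ones used in the updated tests below:

```python
import jax
from ott.tools import soft_sort

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))

# Sharper, closer-to-hard soft-sorting: identity squashing, small entropic
# regularization, and a larger iteration budget for the Sinkhorn solver.
x_sorted = soft_sort.sort(
    x,
    squashing_fun=lambda u: u,  # skip the default sigmoid of whitened values
    epsilon=1e-3,               # entropic regularization of the OT problem
    max_iterations=5_000,       # forwarded to the Sinkhorn solver
)
```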
54 changes: 49 additions & 5 deletions tests/tools/soft_sort_test.py
@@ -86,14 +86,58 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int):
np.testing.assert_array_equal(xs.shape, expected_shape)
np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True)

- def test_rank_one_array(self, rng: jax.random.PRNGKeyArray):
-   x = jax.random.uniform(rng, (20,))
-   expected_ranks = jnp.argsort(jnp.argsort(x, axis=0), axis=0).astype(float)
+ @pytest.mark.fast.with_args("axis", [0, 1], only_fast=0)
+ def test_ranks(self, axis, rng: jax.random.PRNGKeyArray):
+   rng1, rng2 = jax.random.split(rng, 2)
+   num_targets = 13
+   x = jax.random.uniform(rng1, (8, 5, 2))
+
+   # Define a custom version of ranks suited to recover ranks that are
+   # close to true ranks. This requires a notably small epsilon and a
+   # large number of iterations.
+   my_ranks = functools.partial(
+       soft_sort.ranks,
+       squashing_fun=lambda x: x,
+       epsilon=1e-4,
+       axis=axis,
+       max_iterations=5000
+   )
+   expected_ranks = jnp.argsort(
+       jnp.argsort(x, axis=axis), axis=axis
+   ).astype(float)
+   ranks = my_ranks(x)
+   np.testing.assert_array_equal(x.shape, ranks.shape)
+   np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

- ranks = soft_sort.ranks(x, epsilon=0.001)
+ ranks = my_ranks(x, num_targets=num_targets)
np.testing.assert_array_equal(x.shape, ranks.shape)
np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

+ target_weights = jax.random.uniform(rng2, (num_targets,))
+ target_weights /= jnp.sum(target_weights)
+ ranks = my_ranks(x, target_weights=target_weights)
+ np.testing.assert_array_equal(x.shape, ranks.shape)
- np.testing.assert_allclose(ranks, expected_ranks, atol=0.9, rtol=0.1)
+ np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

+ @pytest.mark.fast.with_args("axis", [0, 1], only_fast=0)
+ def test_topk_mask(self, axis, rng: jax.random.PRNGKeyArray):
+   k = 3
+   x = jax.random.uniform(rng, (13, 7, 1))
+   my_topk_mask = functools.partial(
+       soft_sort.topk_mask,
+       squashing_fun=lambda x: x,
+       # small epsilon and many iterations are needed to recover a sharp
+       # mask in the presence of close ties.
+       epsilon=1e-4,
+       max_iterations=15000,
+       axis=axis
+   )
+   mask = my_topk_mask(x, k=k)
+
+   def boolean_topk_mask(u, k):
+     return u >= jnp.flip(jnp.sort(u))[k - 1]
+
+   expected_mask = soft_sort.apply_on_axis(boolean_topk_mask, x, axis, k)
+
+   np.testing.assert_array_equal(x.shape, mask.shape)
+   np.testing.assert_allclose(mask, expected_mask, atol=0.01, rtol=0.1)

@pytest.mark.fast()
@pytest.mark.parametrize("q", [0.2, 0.9])
