Add soft topk_mask operator, fix ranks docs and tests #396

Merged: 4 commits, Jul 14, 2023
199 changes: 132 additions & 67 deletions src/ott/tools/soft_sort.py
@@ -82,7 +82,7 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray:
op: a differentiable operator (can be ranks, quantile, etc.)
inputs: Array of any shape.
axis: the axis (int) or tuple of ints on which to apply the operator. If
several axes are passed the operator, those are merged as a single
several axes are passed to the operator, those are merged as a single
dimension.
args: other positional arguments to the operator.
kwargs: other keyword arguments to the operator.
@@ -152,7 +152,6 @@ def sort(
x = jax.random.uniform(rng, (100,))
x_sorted = sort(x)


will output sorted convex-combinations of values contained in ``x``, that are
differentiable approximations to the sorted vector of entries in ``x``.
These can be compared with the values produced by :func:`jax.numpy.sort`,
@@ -169,46 +168,55 @@ def sort(
the top-k values. This also reduces the complexity of soft-sorting, since
the number of target points to which the slice of the ``inputs`` tensor
will be mapped is equal to ``topk+1``.
num_targets: if ``topk`` is not specified, ``num_targets`` defines the
number of (composite) sorted values computed from the inputs (each value
is a convex combination of values recorded in the inputs, provided in
increasing order). If neither ``topk`` nor ``num_targets`` are specified,
``num_targets`` defaults to the size of the slices of the input that are
sorted, i.e. ``inputs.shape[axis]``, and the number of composite sorted
values is equal to the slice of the inputs that are sorted.
num_targets: if ``topk`` is not specified, a vector of size ``num_targets``
is returned. This defines the number of (composite) sorted values computed
from the inputs (each value is a convex combination of values recorded in
the inputs, provided in increasing order). If neither ``topk`` nor
``num_targets`` is specified, ``num_targets`` defaults to the size of the
slices of the input that are sorted, i.e. ``inputs.shape[axis]``, and the
number of composite sorted values equals the size of those slices. As a
result, the output is of the same size as ``inputs``.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
1D cost function to transport from ``inputs`` to the ``num_targets``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.

Returns:
An Array of the same shape as the input with soft-sorted values on the
given axis.
An Array of the same shape as the input, except along ``axis``, where the
size equals ``topk`` or ``num_targets``, with soft-sorted values on the
given axis. Same size as ``inputs`` if both parameters are ``None``.
"""
return apply_on_axis(_sort, inputs, axis, topk, num_targets, **kwargs)
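
For readers skimming the diff, a minimal sketch of the output-size semantics described above (hedged: shapes follow the Returns section, values are illustrative, assuming the ``ott.tools.soft_sort`` module shown in this diff):

.. code-block:: python

    import jax
    from ott.tools import soft_sort

    x = jax.random.uniform(jax.random.PRNGKey(0), (100,))

    # Default: as many composite sorted values as inputs.
    assert soft_sort.sort(x).shape == (100,)
    # topk=5: only the top-5 soft-sorted values are returned.
    assert soft_sort.sort(x, topk=5).shape == (5,)
    # num_targets=10: ten composite sorted values.
    assert soft_sort.sort(x, num_targets=10).shape == (10,)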


def _ranks(inputs: jnp.ndarray, num_targets, **kwargs: Any) -> jnp.ndarray:
def _ranks(
inputs: jnp.ndarray, num_targets, target_weights, **kwargs: Any
) -> jnp.ndarray:
"""Apply the soft ranks operator on a one dimensional array."""
num_points = inputs.shape[0]
num_targets = num_points if num_targets is None else num_targets
if target_weights is None:
num_targets = num_points if num_targets is None else num_targets
target_weights = jnp.ones((num_targets,)) / num_targets
else:
num_targets = target_weights.shape[0]
a = jnp.ones((num_points,)) / num_points
b = jnp.ones((num_targets,)) / num_targets
ot = transport_for_sort(inputs, a, b, **kwargs)
ot = transport_for_sort(inputs, a, target_weights, **kwargs)
out = 1.0 / a * ot.apply(jnp.arange(num_targets), axis=1)
out *= (num_points - 1.0) / (num_targets - 1.0)
return jnp.reshape(out, inputs.shape)
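
As a sanity check on the rescaling above (a hypothetical, self-contained illustration, not part of this diff): with a hard transport plan, the barycentric target index is rescaled by ``(num_points - 1) / (num_targets - 1)`` so that ranks span ``[0, num_points - 1]``.

.. code-block:: python

    import jax.numpy as jnp

    num_points, num_targets = 4, 2
    # Hypothetical hard plan: first two points go to target 0, last two to target 1.
    plan = jnp.array([[0.25, 0.0],
                      [0.25, 0.0],
                      [0.0, 0.25],
                      [0.0, 0.25]])
    a = jnp.ones((num_points,)) / num_points
    # Barycentric target index per input point, as in `_ranks` above.
    out = (1.0 / a) * (plan @ jnp.arange(num_targets))
    out *= (num_points - 1.0) / (num_targets - 1.0)
    print(out)  # [0. 0. 3. 3.], i.e. ranks rescaled to span [0, num_points - 1]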


def ranks(
inputs: jnp.ndarray,
axis: int = -1,
num_targets: Optional[int] = None,
target_weights: Optional[jnp.ndarray] = None,
**kwargs: Any,
) -> jnp.ndarray:
r"""Apply the soft rank operator on input tensor.
@@ -220,46 +228,103 @@ def ranks(
x = jax.random.uniform(rng, (100,))
x_ranks = ranks(x)

will output fractional values, between 0 and 1, that are differentiable
approximations to the normalized ranks of entries in ``x``. These should be
compared to the non-differentiable rank vectors, namely the normalized inverse
permutation produced by :func:`jax.numpy.argsort`, which can be obtained as:
will output values that are differentiable approximations to the ranks of
entries in ``x``. These should be compared to the non-differentiable rank
vectors, namely the inverse permutation produced by
:func:`jax.numpy.argsort`, which can be obtained as:

.. code-block:: python

x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x))

Args:
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
topk: if set to a positive value, the returned vector will only contain
the top-k values. This also reduces the complexity of soft-sorting, since
the number of target points to which the slice of the ``inputs`` tensor
will be mapped to will be equal to ``topk+1``.
num_targets: if ``topk`` is not specified, ``num_targets`` defines the
number of (composite) sorted values computed from the inputs (each value
is a convex combination of values recorded in the inputs, provided in
increasing order). If neither ``topk`` nor ``num_targets`` are specified,
``num_targets`` defaults to the size of the slices of the input that are
sorted, i.e. ``inputs.shape[axis]``, and the number of composite sorted
values is equal to the slice of the inputs that are sorted.
target_weights: vector of weights (summing to 1) that describes the amount
of mass shipped to each target; when provided, ``num_targets`` is ignored
and set to the length of this vector.
num_targets: if ``target_weights`` is ``None``, ``num_targets`` defines the
number of targets used to rank inputs; the weight of each target is
assumed to be uniform. Each rank returned in the output is a convex
combination of the target indices ``{0, .., num_targets - 1}``, rescaled
by ``(n - 1) / (num_targets - 1)``. If neither ``num_targets`` nor
``target_weights`` is specified, ``num_targets`` defaults to the size
of the slices of the input that are sorted, i.e. ``inputs.shape[axis]``.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
1D cost function to transport from ``inputs`` to the ``num_targets``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.

Returns:
An Array of the same shape as the input with soft-rank values
normalized to be in :math:`[0,1]`, replacing the original ones.
lying in :math:`[0, n-1]`, where ``n`` is ``inputs.shape[axis]``.

"""
return apply_on_axis(_ranks, inputs, axis, num_targets, **kwargs)
return apply_on_axis(
_ranks, inputs, axis, num_targets, target_weights, **kwargs
)
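
A usage sketch of the new ``target_weights`` argument (hedged: the weight values and shapes below are illustrative, the API is as added in this diff):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from ott.tools import soft_sort

    x = jax.random.uniform(jax.random.PRNGKey(0), (20,))

    # Uniform targets, inferred from num_targets.
    r_uniform = soft_sort.ranks(x, num_targets=10)

    # Explicit, non-uniform target weights; num_targets is then inferred
    # from the length of target_weights.
    w = jnp.array([0.7, 0.3])  # hypothetical weights, summing to 1
    r_weighted = soft_sort.ranks(x, target_weights=w)

    assert r_uniform.shape == r_weighted.shape == x.shape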


def topk_mask(
inputs: jnp.ndarray,
axis: int = -1,
k: int = 1,
**kwargs: Any,
) -> jnp.ndarray:
r"""Soft top-$k$ selection mask.

For instance:

.. code-block:: python

k = 5
x = jax.random.uniform(rng, (100,))
mask = topk_mask(x, k=k)

will output a vector of shape ``x.shape``, with values in :math:`[0,1]`, that
are differentiable approximations to the binary mask selecting the top
:math:`k` entries in ``x``. These should be compared to the non-differentiable
mask obtained with :func:`jax.numpy.sort`, which can be computed as:

.. code-block:: python

mask = x >= jax.numpy.flip(jax.numpy.sort(x))[k - 1]

Args:
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
k: topk parameter; must be smaller than ``inputs.shape[axis]``.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the two target values;
``epsilon`` regularization parameter. Remaining ``kwargs`` are passed on
to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.

"""
num_points = inputs.shape[axis]
assert k < num_points, (
f"`k` must be smaller than `inputs.shape[axis]`, yet {k} >= {num_points}."
)
target_weights = jnp.array([1.0 - k / num_points, k / num_points])
out = apply_on_axis(
_ranks,
inputs,
axis,
num_targets=None,
target_weights=target_weights,
**kwargs
)
return out / (num_points - 1)
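
The construction above reduces top-``k`` selection to a two-target ranking problem: mass ``1 - k/n`` is shipped to the first target and ``k/n`` to the second, so soft ranks land near ``0`` for the bottom ``n - k`` entries and near ``n - 1`` for the top ``k``; dividing by ``n - 1`` maps them to :math:`[0,1]`. A hedged usage sketch (values are illustrative):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from ott.tools import soft_sort

    x = jax.random.uniform(jax.random.PRNGKey(0), (100,))

    # Soft mask with values in [0, 1], close to 1 for (roughly) the top-5 entries.
    mask = soft_sort.topk_mask(x, k=5)
    assert mask.shape == x.shape

    # Being differentiable, it can be used inside a loss, e.g. a soft top-k sum.
    loss = lambda u: jnp.sum(soft_sort.topk_mask(u, k=5) * u)
    grads = jax.grad(loss)(x)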


def quantile(
Expand Down Expand Up @@ -299,21 +364,21 @@ def quantile(
These values should all lie in :math:`[0,1]`.
axis: the axis on which to apply the operator.
weight: the weight assigned to each quantile target value in the OT problem.
This weight should be small, typically of the order of ``1/n``, where ``n``
is the size of ``x``. Note: Since the size of ``q`` times ``weight``
must be strictly smaller than ``1``, in order to leave enough mass to set
other target values in the transport problem, the algorithm might ensure
this by setting, when needed, a lower value.
This weight should be small, typically of the order of ``1/n``, where ``n``
is the size of ``x``. Note: Since the size of ``q`` times ``weight``
must be strictly smaller than ``1``, in order to leave enough mass to set
other target values in the transport problem, the algorithm might ensure
this by setting, when needed, a lower value.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.

Returns:
An Array, which has the same shape as ``inputs``, except on the ``axis``
@@ -420,11 +485,11 @@ def quantile_normalization(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
1D cost function to transport from ``inputs`` to the ``num_targets``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.

Returns:
@@ -473,11 +538,11 @@ def sort_with(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
1D cost function to transport from ``inputs`` to the ``num_targets``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.

Returns:
@@ -541,11 +606,11 @@ def quantize(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
1D cost function to transport from ``inputs`` to the ``num_targets``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.


54 changes: 49 additions & 5 deletions tests/tools/soft_sort_test.py
@@ -86,14 +86,58 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int):
np.testing.assert_array_equal(xs.shape, expected_shape)
np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True)

def test_rank_one_array(self, rng: jax.random.PRNGKeyArray):
x = jax.random.uniform(rng, (20,))
expected_ranks = jnp.argsort(jnp.argsort(x, axis=0), axis=0).astype(float)
@pytest.mark.fast.with_args("axis", [0, 1], only_fast0=0)
def test_ranks(self, axis, rng: jax.random.PRNGKeyArray):
rng1, rng2 = jax.random.split(rng, 2)
num_targets = 13
x = jax.random.uniform(rng1, (8, 5, 2))

# Define a custom version of ranks suited to recover ranks that are
# close to true ranks. This notably requires a small epsilon and a large
# number of iterations.
my_ranks = functools.partial(
soft_sort.ranks,
squashing_fun=lambda x: x,
epsilon=1e-4,
axis=axis,
max_iterations=5000
)
expected_ranks = jnp.argsort(
jnp.argsort(x, axis=axis), axis=axis
).astype(float)
ranks = my_ranks(x)
np.testing.assert_array_equal(x.shape, ranks.shape)
np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

ranks = soft_sort.ranks(x, epsilon=0.001)
ranks = my_ranks(x, num_targets=num_targets)
np.testing.assert_array_equal(x.shape, ranks.shape)
np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

target_weights = jax.random.uniform(rng2, (num_targets,))
target_weights /= jnp.sum(target_weights)
ranks = my_ranks(x, target_weights=target_weights)
np.testing.assert_array_equal(x.shape, ranks.shape)
np.testing.assert_allclose(ranks, expected_ranks, atol=0.9, rtol=0.1)
np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

@pytest.mark.fast.with_args("axis", [0, 1], only_fast=0)
def test_topk_mask(self, axis, rng: jax.random.PRNGKeyArray):
k = 3
x = jax.random.uniform(rng, (13, 7, 1))
my_topk_mask = functools.partial(
soft_sort.topk_mask,
squashing_fun=lambda x: x,
epsilon=1e-4, # needed to recover a sharp mask given close ties
max_iterations=15000, # needed to recover a sharp mask given close ties
axis=axis
)
mask = my_topk_mask(x, k=k, axis=axis)

def boolean_topk_mask(u, k):
return u >= jnp.flip(jax.numpy.sort(u))[k - 1]

expected_mask = soft_sort.apply_on_axis(boolean_topk_mask, x, axis, k)

np.testing.assert_array_equal(x.shape, mask.shape)
np.testing.assert_allclose(mask, expected_mask, atol=0.01, rtol=0.1)
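
For contrast with the hard baseline used in the test above, a small sketch (assuming the same ``soft_sort`` module) showing that the soft mask composes with :func:`jax.grad`, whereas the boolean comparison blocks gradients through the selection itself:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from ott.tools import soft_sort

    x = jax.random.uniform(jax.random.PRNGKey(0), (13,))

    # Gradients flow through the soft selection.
    soft_sum = lambda u: jnp.sum(soft_sort.topk_mask(u, k=3) * u)
    g_soft = jax.grad(soft_sum)(x)

    # The hard mask is treated as a constant by autodiff: the gradient of
    # sum(mask * u) reduces to the mask itself, with no signal on the cutoff.
    hard_sum = lambda u: jnp.sum((u >= jnp.flip(jnp.sort(u))[2]) * u)
    g_hard = jax.grad(hard_sum)(x)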

@pytest.mark.fast()
@pytest.mark.parametrize("q", [0.2, 0.9])