Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update soft_sort.py #387

Merged
merged 1 commit into from
Jul 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def transport_for_sort(
r"""Solve reg. OT, from inputs to a weighted family of increasing values.
Args:
inputs: jnp.ndarray[num_points]. Must be one dimensional.
weights: jnp.ndarray[num_points]. Weight vector `a` for input values.
target_weights: jnp.ndarray[num_targets]: Weight vector of the target
inputs: Array[num_points]. Must be one dimensional.
weights: Array[num_points]. Weight vector `a` for input values.
target_weights: Array[num_targets]: Weight vector of the target
measure. It may be of different size than `weights`.
squashing_fun: function taking an array to squash all its entries in [0,1].
sigmoid of whitened values by default. Can be set to be the identity by
Expand All @@ -48,8 +48,7 @@ def transport_for_sort(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
Returns:
A jnp.ndarray<float> num_points x num_target transport matrix, from all
inputs onto the sorted target.
A :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` object.
"""
shape = inputs.shape
if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
Expand Down Expand Up @@ -81,15 +80,15 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray:
Args:
op: a differentiable operator (can be ranks, quantile, etc.)
inputs: jnp.ndarray<float> of any shape.
inputs: Array of any shape.
axis: the axis (int) or tuple of ints on which to apply the operator. If
several axes are passed the operator, those are merged as a single
dimension.
args: other positional arguments to the operator.
kwargs: other positional arguments to the operator.
Returns:
A jnp.ndarray holding the output of the differentiable operator on the given
An Array holding the output of the differentiable operator on the given
axis.
"""
op_inner = functools.partial(op, **kwargs)
Expand Down Expand Up @@ -164,7 +163,7 @@ def sort(
Args:
inputs: jnp.ndarray<float> of any shape.
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
topk: if set to a positive value, the returned vector will only contain
the top-k values. This also reduces the complexity of soft-sorting, since
Expand All @@ -189,7 +188,7 @@ def sort(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray of the same shape as the input with soft-sorted values on the
An Array of the same shape as the input with soft-sorted values on the
given axis.
"""
return apply_on_axis(_sort, inputs, axis, topk, num_targets, **kwargs)
Expand Down Expand Up @@ -231,7 +230,7 @@ def ranks(
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
Args:
inputs: jnp.ndarray<float> of any shape.
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
topk: if set to a positive value, the returned vector will only contain
the top-k values. This also reduces the complexity of soft-sorting, since
Expand All @@ -256,7 +255,7 @@ def ranks(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray of the same shape as the input with soft-rank values
An Array of the same shape as the input with soft-rank values
normalized to be in :math:`[0,1]`, replacing the original ones.
"""
Expand Down Expand Up @@ -295,7 +294,7 @@ def quantile(
Args:
inputs: a jnp.ndarray<float> of any shape.
inputs: an Array of any shape.
q: values of the quantile level to be computed, e.g. [0.5] for median.
These values should all lie in :math:`[0,1]`.
axis: the axis on which to apply the operator.
Expand All @@ -317,8 +316,8 @@ def quantile(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray, which has the same shape as the input, except on the ``axis``
that is passed, which has size ``q.shape[0]`` to collect soft-quantile
An Array, which has the same shape as ``inputs``, except on the ``axis``
that is passed, which has size ``q.shape[0]``, to collect soft-quantile
values.
"""

Expand Down