Skip to content

Commit

Permalink
Update soft_sort.py (#387)
Browse files Browse the repository at this point in the history
pydocs
  • Loading branch information
marcocuturi authored Jul 4, 2023
1 parent ab3cf0a commit d10348a
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def transport_for_sort(
r"""Solve reg. OT, from inputs to a weighted family of increasing values.
Args:
inputs: jnp.ndarray[num_points]. Must be one dimensional.
weights: jnp.ndarray[num_points]. Weight vector `a` for input values.
target_weights: jnp.ndarray[num_targets]: Weight vector of the target
inputs: Array[num_points]. Must be one dimensional.
weights: Array[num_points]. Weight vector `a` for input values.
target_weights: Array[num_targets]: Weight vector of the target
measure. It may be of different size than `weights`.
squashing_fun: function taking an array to squash all its entries in [0,1].
sigmoid of whitened values by default. Can be set to be the identity by
Expand All @@ -48,8 +48,7 @@ def transport_for_sort(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
Returns:
A jnp.ndarray<float> num_points x num_target transport matrix, from all
inputs onto the sorted target.
A :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` object.
"""
shape = inputs.shape
if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
Expand Down Expand Up @@ -81,15 +80,15 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray:
Args:
op: a differentiable operator (can be ranks, quantile, etc.)
inputs: jnp.ndarray<float> of any shape.
inputs: Array of any shape.
axis: the axis (int) or tuple of ints on which to apply the operator. If
several axes are passed the operator, those are merged as a single
dimension.
args: other positional arguments to the operator.
kwargs: other positional arguments to the operator.
Returns:
A jnp.ndarray holding the output of the differentiable operator on the given
An Array holding the output of the differentiable operator on the given
axis.
"""
op_inner = functools.partial(op, **kwargs)
Expand Down Expand Up @@ -164,7 +163,7 @@ def sort(
Args:
inputs: jnp.ndarray<float> of any shape.
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
topk: if set to a positive value, the returned vector will only contain
the top-k values. This also reduces the complexity of soft-sorting, since
Expand All @@ -189,7 +188,7 @@ def sort(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray of the same shape as the input with soft-sorted values on the
An Array of the same shape as the input with soft-sorted values on the
given axis.
"""
return apply_on_axis(_sort, inputs, axis, topk, num_targets, **kwargs)
Expand Down Expand Up @@ -231,7 +230,7 @@ def ranks(
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
Args:
inputs: jnp.ndarray<float> of any shape.
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
topk: if set to a positive value, the returned vector will only contain
the top-k values. This also reduces the complexity of soft-sorting, since
Expand All @@ -256,7 +255,7 @@ def ranks(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray of the same shape as the input with soft-rank values
An Array of the same shape as the input with soft-rank values
normalized to be in :math:`[0,1]`, replacing the original ones.
"""
Expand Down Expand Up @@ -295,7 +294,7 @@ def quantile(
Args:
inputs: a jnp.ndarray<float> of any shape.
inputs: an Array of any shape.
q: values of the quantile level to be computed, e.g. [0.5] for median.
These values should all lie in :math:`[0,1]`.
axis: the axis on which to apply the operator.
Expand All @@ -317,8 +316,8 @@ def quantile(
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray, which has the same shape as the input, except on the ``axis``
that is passed, which has size ``q.shape[0]`` to collect soft-quantile
An Array, which has the same shape as ``inputs``, except on the ``axis``
that is passed, which has size ``q.shape[0]``, to collect soft-quantile
values.
"""

Expand Down

0 comments on commit d10348a

Please sign in to comment.