From a887c11713291c2d1de8bf74fc49bbce9a8ab9a0 Mon Sep 17 00:00:00 2001
From: Marco Cuturi
Date: Tue, 4 Jul 2023 14:09:43 +0200
Subject: [PATCH] Update soft_sort.py pydocs

---
 src/ott/tools/soft_sort.py | 27 +++++++++++++--------------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py
index 3f601bd8f..5786262b2 100644
--- a/src/ott/tools/soft_sort.py
+++ b/src/ott/tools/soft_sort.py
@@ -36,9 +36,9 @@ def transport_for_sort(
   r"""Solve reg. OT, from inputs to a weighted family of increasing values.
 
   Args:
-    inputs: jnp.ndarray[num_points]. Must be one dimensional.
-    weights: jnp.ndarray[num_points]. Weight vector `a` for input values.
-    target_weights: jnp.ndarray[num_targets]: Weight vector of the target
+    inputs: Array[num_points]. Must be one dimensional.
+    weights: Array[num_points]. Weight vector `a` for input values.
+    target_weights: Array[num_targets]: Weight vector of the target
       measure. It may be of different size than `weights`.
     squashing_fun: function taking an array to squash all its entries in [0,1].
       sigmoid of whitened values by default. Can be set to be the identity by
@@ -48,8 +48,7 @@ def transport_for_sort(
       :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
 
   Returns:
-    A jnp.ndarray num_points x num_target transport matrix, from all
-    inputs onto the sorted target.
+    A :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` object.
   """
   shape = inputs.shape
   if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
@@ -81,7 +80,7 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray:
 
   Args:
     op: a differentiable operator (can be ranks, quantile, etc.)
-    inputs: jnp.ndarray of any shape.
+    inputs: Array of any shape.
     axis: the axis (int) or tuple of ints on which to apply the operator. If
       several axes are passed the operator, those are merged as a single
       dimension.
@@ -89,7 +88,7 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray:
     kwargs: other positional arguments to the operator.
 
   Returns:
-    A jnp.ndarray holding the output of the differentiable operator on the given
+    An Array holding the output of the differentiable operator on the given
     axis.
   """
   op_inner = functools.partial(op, **kwargs)
@@ -164,7 +163,7 @@ def sort(
 
   Args:
-    inputs: jnp.ndarray of any shape.
+    inputs: Array of any shape.
     axis: the axis on which to apply the soft-sorting operator.
     topk: if set to a positive value, the returned vector will only contain
       the top-k values. This also reduces the complexity of soft-sorting, since
@@ -189,7 +188,7 @@ def sort(
       :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
-    A jnp.ndarray of the same shape as the input with soft-sorted values on the
+    An Array of the same shape as the input with soft-sorted values on the
     given axis.
   """
   return apply_on_axis(_sort, inputs, axis, topk, num_targets, **kwargs)
@@ -231,7 +230,7 @@ def ranks(
     x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
 
   Args:
-    inputs: jnp.ndarray of any shape.
+    inputs: Array of any shape.
     axis: the axis on which to apply the soft-sorting operator.
     topk: if set to a positive value, the returned vector will only contain
       the top-k values. This also reduces the complexity of soft-sorting, since
@@ -256,7 +255,7 @@ def ranks(
       :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
-    A jnp.ndarray of the same shape as the input with soft-rank values
+    An Array of the same shape as the input with soft-rank values
     normalized to be in :math:`[0,1]`, replacing the original ones.
   """
@@ -295,7 +294,7 @@ def quantile(
 
   Args:
-    inputs: a jnp.ndarray of any shape.
+    inputs: an Array of any shape.
     q: values of the quantile level to be computed, e.g. [0.5] for median.
       These values should all lie in :math:`[0,1]`.
     axis: the axis on which to apply the operator.
@@ -317,8 +316,8 @@ def quantile(
       :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
-    A jnp.ndarray, which has the same shape as the input, except on the ``axis``
-    that is passed, which has size ``q.shape[0]`` to collect soft-quantile
+    An Array, which has the same shape as ``inputs``, except on the ``axis``
+    that is passed, which has size ``q.shape[0]``, to collect soft-quantile
     values.
   """