diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 52c7076b7..e452edfba 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -90,11 +90,16 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int): def test_ranks(self, axis, rng: jax.random.PRNGKeyArray): rng1, rng2 = jax.random.split(rng, 2) num_targets = 13 - x = jax.random.uniform(rng1, (8, 1, 1)) + x = jax.random.uniform(rng1, (8, 5, 2)) - # Define a custom version of ranks suited to recover closely true ranks + # Define a custom version of ranks suited to recover ranks that are + # close to true ranks. This requires notably small epsilon and large # iter. my_ranks = functools.partial( - soft_sort.ranks, squashing_fun=lambda x: x, epsilon=1e-4, axis=axis + soft_sort.ranks, + squashing_fun=lambda x: x, + epsilon=1e-4, + axis=axis, + max_iterations=10000 ) expected_ranks = jnp.argsort( jnp.argsort(x, axis=axis), axis=axis