Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Jul 13, 2023
1 parent 1745e0f commit 9f20b50
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/tools/soft_sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,16 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int):
def test_ranks(self, axis, rng: jax.random.PRNGKeyArray):
rng1, rng2 = jax.random.split(rng, 2)
num_targets = 13
x = jax.random.uniform(rng1, (8, 1, 1))
x = jax.random.uniform(rng1, (8, 5, 2))

# Define a custom version of ranks suited to recover closely true ranks
# Define a custom version of ranks suited to recover ranks that are
# close to true ranks. This requires notably small epsilon and large # iter.
my_ranks = functools.partial(
soft_sort.ranks, squashing_fun=lambda x: x, epsilon=1e-4, axis=axis
soft_sort.ranks,
squashing_fun=lambda x: x,
epsilon=1e-4,
axis=axis,
max_iterations=10000
)
expected_ranks = jnp.argsort(
jnp.argsort(x, axis=axis), axis=axis
Expand Down

0 comments on commit 9f20b50

Please sign in to comment.