Skip to content

Commit

Permalink
Top k fix (#376)
Browse files Browse the repository at this point in the history
* jax ignore sorted in top_k

* Ignore sorted argument for jax top_k

`sorted=True` is a strictly stronger guarantee than `sorted=False`, so
better to always return `sorted=True` than add an annoying inconsistency
between what backends support what.
  • Loading branch information
mattdangerw authored and fchollet committed Jun 20, 2023
1 parent 6cfac4e commit c8d37be
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
6 changes: 2 additions & 4 deletions keras_core/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):


def top_k(x, k, sorted=True):
if not sorted:
return ValueError(
"Jax backend does not support `sorted=False` for `ops.top_k`"
)
# Jax does not supported `sorted`, but in the case where `sorted=False`,
# order is not guaranteed, so OK to return sorted output.
return jax.lax.top_k(x, k)


Expand Down
8 changes: 7 additions & 1 deletion keras_core/operations/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_segment_sum(self):
outputs = kmath.segment_sum(data, segment_ids, num_segments=5)
self.assertEqual(outputs.shape, (5, 4))

def test_topk(self):
def test_top_k(self):
x = KerasTensor((None, 2, 3))
values, indices = kmath.top_k(x, k=1)
self.assertEqual(values.shape, (None, 2, 1))
Expand Down Expand Up @@ -155,6 +155,12 @@ def test_top_k(self):
self.assertAllClose(values, [4, 3])
self.assertAllClose(indices, [1, 4])

x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32)
values, indices = kmath.top_k(x, k=2, sorted=False)
# Any order ok when `sorted=False`.
self.assertEqual(set(backend.convert_to_numpy(values)), set([4, 3]))
self.assertEqual(set(backend.convert_to_numpy(indices)), set([1, 4]))

x = np.random.rand(5, 5)
outputs = kmath.top_k(x, k=2)
expected = tf.math.top_k(x, k=2)
Expand Down

0 comments on commit c8d37be

Please sign in to comment.