Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lower more index_update cases to slice #15347

Closed
wants to merge 1 commit into from

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Mar 31, 2023

Before:

In [1]: import jax.numpy as jnp
In [2]: import jax
In [3]: x = jnp.arange(24).reshape(2, 3, 4)
In [4]: jax.make_jaxpr(lambda x: x[1,:,:2])(x)
Out[4]: 
{ lambda ; a:i32[2,3,4]. let
    b:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
    c:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:i32[2] = concatenate[dimension=0] b c
    e:i32[3,2] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(0,), start_index_map=(0, 2))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 3, 2)
      unique_indices=True
    ] a d
  in (e,) }

After:

In [4]: jax.make_jaxpr(lambda x: x[1,:,:2])(x)
Out[4]: 
{ lambda ; a:i32[2,3,4]. let
    b:i32[1,3,2] = slice[
      limit_indices=(2, 3, 2)
      start_indices=(1, 0, 0)
      strides=(1, 1, 1)
    ] a
    c:i32[3,2] = squeeze[dimensions=(0,)] b
  in (c,) }

@jakevdp jakevdp self-assigned this Mar 31, 2023
@jakevdp jakevdp force-pushed the getitem-slice branch 5 times, most recently from a8ad117 to 05554a4 Compare April 3, 2023 15:26
@jakevdp jakevdp requested a review from yashk2810 April 3, 2023 15:27
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Apr 3, 2023
@@ -854,16 +854,43 @@ def testJVPOfGradOfIndexing(self):
self.assertAllClose(expected, primals)
self.assertAllClose(np.zeros_like(x), tangents)

def testSimpleIndexingUsesSlice(self):
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of just checking the jaxpr, please check the output too

See this function: https://github.com/google/jax/blob/main/tests/array_test.py#L481-L484

Copy link
Collaborator

@yashk2810 yashk2810 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also please test the slicing with Arrays sharded over multiple devices

You can add more cases here if you want: https://github.com/google/jax/blob/main/tests/array_test.py#L480

@jakevdp
Copy link
Collaborator Author

jakevdp commented Apr 3, 2023

From offline discussion: it looks like lax.slice is fundamentally broken for sharded arrays, so @yashk2810 suggested that a way forward here is probably to skip the new code path if the input array is sharded.

@jakevdp
Copy link
Collaborator Author

jakevdp commented Apr 14, 2023

closing in favor of #15377

@jakevdp jakevdp closed this Apr 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants