diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index d4a88477f541..c8f351c6cc47 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2341,17 +2341,16 @@ def _canonicalize_tuple_index(arr, idx): def _static_idx(idx, size): """Helper function to compute the static slice start/limit/stride values.""" - indices = onp.arange(size)[idx] # get shape statically - if not len(indices): # pylint: disable=g-explicit-length-test + assert isinstance(idx, slice) + start, stop, step = idx.indices(size) + if (step < 0 and stop >= start) or (step > 0 and start >= stop): return 0, 0, 1, False # sliced to size zero - start, stop_inclusive = indices[0], indices[-1] - step = 1 if idx.step is None else idx.step + if step > 0: - end = _min(stop_inclusive + step, size) - return start, end, step, False + return start, stop, step, False else: - end = _min(start - step, size) - return stop_inclusive, end, -step, True + k = (start - stop - 1) % (-step) + return stop + k + 1, start + 1, -step, True blackman = onp.blackman