Skip to content

Commit

Permalink
Merge pull request #932 from hawkinsp/master
Browse files Browse the repository at this point in the history
Use constant-time algorithm for static slice index calculation.
  • Loading branch information
hawkinsp authored Jun 26, 2019
2 parents c7afa1e + 014d235 commit 1508405
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,17 +2341,16 @@ def _canonicalize_tuple_index(arr, idx):

def _static_idx(idx, size):
"""Helper function to compute the static slice start/limit/stride values."""
indices = onp.arange(size)[idx] # get shape statically
if not len(indices): # pylint: disable=g-explicit-length-test
assert isinstance(idx, slice)
start, stop, step = idx.indices(size)
if (step < 0 and stop >= start) or (step > 0 and start >= stop):
return 0, 0, 1, False # sliced to size zero
start, stop_inclusive = indices[0], indices[-1]
step = 1 if idx.step is None else idx.step

if step > 0:
end = _min(stop_inclusive + step, size)
return start, end, step, False
return start, stop, step, False
else:
end = _min(start - step, size)
return stop_inclusive, end, -step, True
k = (start - stop - 1) % (-step)
return stop + k + 1, start + 1, -step, True


blackman = onp.blackman
Expand Down

0 comments on commit 1508405

Please sign in to comment.