Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace indexed_shape by a version that uses jit #519

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 20 additions & 56 deletions scico/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax

import scico.numpy as snp
from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, Shape
from scico.typing import ArrayIndex, Axes, BlockShape, DType, Shape

from ._blockarray import BlockArray

Expand Down Expand Up @@ -68,72 +68,36 @@ def parse_axes(
return axes


def slice_length(length: int, idx: AxisIndex) -> Optional[int]:
"""Determine the length of an array axis after indexing.

Determine the length of an array axis after slicing. An exception is
raised if the indexing expression is an integer that is out of bounds
for the specified axis length. A value of ``None`` is returned for
valid integer indexing expressions as an indication that the
corresponding axis shape is an empty tuple; this value should be
converted to a unit integer if the axis size is required.

Args:
length: Length of axis being sliced.
idx: Indexing/slice to be applied to axis.

Returns:
Length of indexed/sliced axis.

Raises:
ValueError: If `idx` is an integer index that is out bounds for
the axis length or if the type of `idx` is not one of
`Ellipsis`, `int`, or `slice`.
"""
if idx is Ellipsis:
return length
if isinstance(idx, int):
if idx < -length or idx > length - 1:
raise ValueError(f"Index {idx} out of bounds for axis of length {length}.")
return None
if not isinstance(idx, slice):
raise ValueError(f"Index expression {idx} is of an unrecognized type.")
start, stop, stride = idx.indices(length)
if start > stop:
start = stop
return (stop - start + stride - 1) // stride


def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]:
"""Determine the shape of an array after indexing/slicing.

Args:
shape: Shape of array.
idx: Indexing expression.
idx: Indexing expression (singleton or tuple of `Ellipsis`,
`int`, `slice`, or ``None`` (`np.newaxis`)).

Returns:
Shape of indexed/sliced array.

Raises:
ValueError: If any element of `idx` is not one of `Ellipsis`,
`int`, `slice`, or ``None`` (`np.newaxis`), or if an integer
index is out bounds for the corresponding axis length.
"""
if not isinstance(idx, tuple):
idx = (idx,)
idx_shape: List[Optional[int]] = list(shape)
offset = 0
newaxis = 0
for axis, ax_idx in enumerate(idx):
if ax_idx is None:
idx_shape.insert(axis, 1)
newaxis += 1
continue
if ax_idx is Ellipsis:
offset = len(shape) - len(idx)
continue
idx_shape[axis + offset + newaxis] = slice_length(shape[axis + offset], ax_idx)
return tuple(filter(lambda x: x is not None, idx_shape)) # type: ignore

# convert any slices to its representation (slice, (start, stop, step))
# allows hashing, needed for jax.jit
idx = tuple(exp.__reduce__() if isinstance(exp, slice) else exp for exp in idx)

def get_shape(in_shape, ind_expr):
# convert slices representations back to slices
ind_expr = tuple(
(slice(*exp[1]) if isinstance(exp, tuple) and len(exp) > 0 and exp[0] == slice else exp)
for exp in ind_expr
)
return jax.numpy.empty(in_shape)[ind_expr].shape

# because all arguments are static, this compiles each time it gets new arguments
f = jax.jit(get_shape, static_argnums=(0, 1))

return tuple(t.item() for t in f(shape, idx)) # type: ignore


def no_nan_divide(
Expand Down
21 changes: 0 additions & 21 deletions scico/test/numpy/test_numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
no_nan_divide,
parse_axes,
real_dtype,
slice_length,
)
from scico.random import randn

Expand Down Expand Up @@ -68,26 +67,6 @@ def test_parse_axes():
np.testing.assert_raises(ValueError, parse_axes, axes)


@pytest.mark.parametrize("length", (4, 5, 8, 16, 17))
@pytest.mark.parametrize("start", (None, 0, 1, 2, 3))
@pytest.mark.parametrize("stop", (None, 0, 1, 2, -2, -1))
@pytest.mark.parametrize("stride", (None, 1, 2, 3))
def test_slice_length(length, start, stop, stride):
x = np.zeros(length)
slc = slice(start, stop, stride)
assert x[slc].size == slice_length(length, slc)


@pytest.mark.parametrize("length", (4, 5))
@pytest.mark.parametrize("slc", (0, 1, -4, Ellipsis))
def test_slice_length_other(length, slc):
x = np.zeros(length)
if isinstance(slc, int):
assert slice_length(length, slc) is None
else:
assert x[slc].size == slice_length(length, slc)


@pytest.mark.parametrize("shape", ((8, 8, 1), (7, 1, 6, 5)))
@pytest.mark.parametrize(
"slc",
Expand Down
Loading