Skip to content

Commit

Permalink
Add function to chain (simple) slice expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
funkey committed Jun 26, 2024
1 parent ccb4b81 commit 9708bf7
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
104 changes: 104 additions & 0 deletions funlib/persistence/arrays/slices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import numpy as np


def chain_slices(slices_a, slices_b):

# make sure both slice expressions are tuples
if not isinstance(slices_a, tuple):
slices_a = (slices_a,)
if not isinstance(slices_b, tuple):
slices_b = (slices_b,)

# dimension of a is number of non-int expressions
dim_a = sum([not isinstance(x, int) for x in slices_a])

# slices_b can't slice more dimensions than a has
assert (
len(slices_b) <= dim_a
), f"Slice expression {slices_b} has too many dimensions to chain with {slices_a}"

chained = []

j = 0
for slice_a in slices_a:

# if slice_a is int that dimension does not exist any longer, skip
# also skip if b has no more elements
if j == len(slices_b) or isinstance(slice_a, int):
chained.append(slice_a)
else:
slice_b = slices_b[j]
chained.append(_chain_slice(slice_a, slice_b))
j += 1

return tuple(chained)


def _chain_slice(a, b):

print(f"\t_chain_slice({a}, {b})")

# a is a slice(start, stop, step) expression
if isinstance(a, slice):

print("\ta is slice")
start_a = a.start if a.start else 0
step_a = a.step if a.step else 1

if isinstance(b, int):

print("\tb is int")

idx = start_a + step_a * b
assert not a.stop or idx < a.stop, f"Slice {b} out of range for {b}"
print(f"\tnew idx = {idx}")
return idx

elif isinstance(b, slice):

print("\tb is slice")
start_b = b.start if b.start else 0
step_b = b.step if b.step else 1

start = start_a + step_a * start_b if a.start or b.start else None
stop = step_a * b.stop if b.stop else a.stop
step = step_a * step_b if a.step or b.step else None

return slice(start, stop, step)

elif isinstance(b, list):

return list(_chain_slice(a, x) for x in b)

elif isinstance(b, np.ndarray):

# is b a mask array?
if b.dtype == bool:
raise RuntimeError("Not yet implemented")

return np.array([_chain_slice(a, x) for x in b])

else:

raise RuntimeError(
f"Don't know how to deal with slice {b} of type {type(b)}"
)

# is an index array
elif isinstance(a, list):

print(f"\ta is index array {a}")
print(f"\tb is {b}")

return list(np.array(a)[(b,)])

elif isinstance(a, np.ndarray):

if a.dtype == bool:
raise RuntimeError("Not yet implemented")

return a[(b,)]

else:

raise RuntimeError(f"Don't know how to deal with slice {a} of type {type(a)}")
33 changes: 33 additions & 0 deletions tests/test_slices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from funlib.persistence.arrays.slices import chain_slices


def test_slice_chaining():

base = np.s_[::2, 0, :4]

# chain with index expressions

s1 = chain_slices(base, np.s_[0])
assert s1 == np.s_[0, 0, :4]

s2 = chain_slices(s1, np.s_[1])
assert s2 == np.s_[0, 0, 1]

# chain with index arrays

s1 = chain_slices(base, np.s_[[0, 1, 1, 2, 3, 5], :])
assert s1 == np.s_[[0, 2, 2, 4, 6, 10], 0, :4]

# ...and another index array
s21 = chain_slices(s1, np.s_[[0, 3], :])
assert s21 == np.s_[[0, 4], 0, :4]

# ...and a slice() expression
s22 = chain_slices(s1, np.s_[1:4])
assert s22 == np.s_[[2, 2, 4], 0, :4]

# chain with slice expressions

s1 = chain_slices(base, np.s_[10:20, ::2])
assert s1 == np.s_[20:40:2, 0, :4:2]

0 comments on commit 9708bf7

Please sign in to comment.