-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add function to chain (simple) slice expressions
- Loading branch information
Showing
2 changed files
with
137 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import numpy as np | ||
|
||
|
||
def chain_slices(slices_a, slices_b): | ||
|
||
# make sure both slice expressions are tuples | ||
if not isinstance(slices_a, tuple): | ||
slices_a = (slices_a,) | ||
if not isinstance(slices_b, tuple): | ||
slices_b = (slices_b,) | ||
|
||
# dimension of a is number of non-int expressions | ||
dim_a = sum([not isinstance(x, int) for x in slices_a]) | ||
|
||
# slices_b can't slice more dimensions than a has | ||
assert ( | ||
len(slices_b) <= dim_a | ||
), f"Slice expression {slices_b} has too many dimensions to chain with {slices_a}" | ||
|
||
chained = [] | ||
|
||
j = 0 | ||
for slice_a in slices_a: | ||
|
||
# if slice_a is int that dimension does not exist any longer, skip | ||
# also skip if b has no more elements | ||
if j == len(slices_b) or isinstance(slice_a, int): | ||
chained.append(slice_a) | ||
else: | ||
slice_b = slices_b[j] | ||
chained.append(_chain_slice(slice_a, slice_b)) | ||
j += 1 | ||
|
||
return tuple(chained) | ||
|
||
|
||
def _chain_slice(a, b): | ||
|
||
print(f"\t_chain_slice({a}, {b})") | ||
|
||
# a is a slice(start, stop, step) expression | ||
if isinstance(a, slice): | ||
|
||
print("\ta is slice") | ||
start_a = a.start if a.start else 0 | ||
step_a = a.step if a.step else 1 | ||
|
||
if isinstance(b, int): | ||
|
||
print("\tb is int") | ||
|
||
idx = start_a + step_a * b | ||
assert not a.stop or idx < a.stop, f"Slice {b} out of range for {b}" | ||
print(f"\tnew idx = {idx}") | ||
return idx | ||
|
||
elif isinstance(b, slice): | ||
|
||
print("\tb is slice") | ||
start_b = b.start if b.start else 0 | ||
step_b = b.step if b.step else 1 | ||
|
||
start = start_a + step_a * start_b if a.start or b.start else None | ||
stop = step_a * b.stop if b.stop else a.stop | ||
step = step_a * step_b if a.step or b.step else None | ||
|
||
return slice(start, stop, step) | ||
|
||
elif isinstance(b, list): | ||
|
||
return list(_chain_slice(a, x) for x in b) | ||
|
||
elif isinstance(b, np.ndarray): | ||
|
||
# is b a mask array? | ||
if b.dtype == bool: | ||
raise RuntimeError("Not yet implemented") | ||
|
||
return np.array([_chain_slice(a, x) for x in b]) | ||
|
||
else: | ||
|
||
raise RuntimeError( | ||
f"Don't know how to deal with slice {b} of type {type(b)}" | ||
) | ||
|
||
# is an index array | ||
elif isinstance(a, list): | ||
|
||
print(f"\ta is index array {a}") | ||
print(f"\tb is {b}") | ||
|
||
return list(np.array(a)[(b,)]) | ||
|
||
elif isinstance(a, np.ndarray): | ||
|
||
if a.dtype == bool: | ||
raise RuntimeError("Not yet implemented") | ||
|
||
return a[(b,)] | ||
|
||
else: | ||
|
||
raise RuntimeError(f"Don't know how to deal with slice {a} of type {type(a)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import numpy as np | ||
from funlib.persistence.arrays.slices import chain_slices | ||
|
||
|
||
def test_slice_chaining(): | ||
|
||
base = np.s_[::2, 0, :4] | ||
|
||
# chain with index expressions | ||
|
||
s1 = chain_slices(base, np.s_[0]) | ||
assert s1 == np.s_[0, 0, :4] | ||
|
||
s2 = chain_slices(s1, np.s_[1]) | ||
assert s2 == np.s_[0, 0, 1] | ||
|
||
# chain with index arrays | ||
|
||
s1 = chain_slices(base, np.s_[[0, 1, 1, 2, 3, 5], :]) | ||
assert s1 == np.s_[[0, 2, 2, 4, 6, 10], 0, :4] | ||
|
||
# ...and another index array | ||
s21 = chain_slices(s1, np.s_[[0, 3], :]) | ||
assert s21 == np.s_[[0, 4], 0, :4] | ||
|
||
# ...and a slice() expression | ||
s22 = chain_slices(s1, np.s_[1:4]) | ||
assert s22 == np.s_[[2, 2, 4], 0, :4] | ||
|
||
# chain with slice expressions | ||
|
||
s1 = chain_slices(base, np.s_[10:20, ::2]) | ||
assert s1 == np.s_[20:40:2, 0, :4:2] |