Skip to content

Commit

Permalink
use dasks fuse_slice operation
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Jul 8, 2024
1 parent 7797e1d commit fbde9a1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 175 deletions.
70 changes: 7 additions & 63 deletions funlib/persistence/arrays/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from .adapters import Adapter
import numpy as np
import dask.array as da
from functools import reduce
from dask.array.optimization import fuse_slice

from typing import Optional, Iterable, Any, Union

Expand Down Expand Up @@ -295,7 +297,6 @@ def __setitem__(self, key, value: np.ndarray):

region_slices = self.__slices(roi)


da.store(
self.data[roi_slices], self._source_data, regions=region_slices
)
Expand All @@ -306,7 +307,7 @@ def __setitem__(self, key, value: np.ndarray):
adapter for adapter in self.adapters if self._is_slice(adapter)
]

region_slices = self._combine_slices(*adapter_slices, key)
region_slices = reduce(fuse_slice, [*adapter_slices, key])

da.store(self.data[key], self._source_data, regions=region_slices)

Expand Down Expand Up @@ -352,63 +353,6 @@ def to_ndarray(self, roi, fill_value=0):

return data

def _combine_slices(
self, *roi_slices: list[Union[tuple[slice], slice]]
) -> list[slice]:
"""Combine slices into a single slice."""
# if there are multiple slices, then we are using adapters
# this is important because if we are considering the adapter slices
# we need to use the shape of the source data, not the adapted data
use_adapters = len(roi_slices) > 1
roi_slices = [
roi_slice if isinstance(roi_slice, tuple) else (roi_slice,)
for roi_slice in roi_slices
]
num_dims = max([len(roi_slice) for roi_slice in roi_slices])

remaining_dims = list(range(num_dims))
combined_ranges = [
(
range(0, self.shape[d], 1)
if not use_adapters
else range(0, self._source_data.shape[d], 1)
)
for d in range(num_dims)
]
combined_slices = []

for roi_slice in roi_slices:
dim_slices = [roi_slice[d] for d in range(num_dims) if len(roi_slice) > d]

del_dims = []
for d, s in enumerate(dim_slices):
current_dimension = remaining_dims[d]
combined_ranges[current_dimension] = combined_ranges[current_dimension][
s
]
if isinstance(s, int):
del_dims.append(d)
for d in del_dims:
del remaining_dims[d]

for combined_range in combined_ranges:
if isinstance(combined_range, int):
combined_slices.append(combined_range)
elif len(combined_range) == 0:
combined_slices.append(slice(0))
elif combined_range.stop < 0:
combined_slices.append(
slice(combined_range.start, None, combined_range.step)
)
else:
combined_slices.append(
slice(
combined_range.start, combined_range.stop, combined_range.step
)
)

return tuple(combined_slices)

def __slices(self, roi, use_adapters: bool = True, check_chunk_align: bool = False):
"""Get the voxel slices for the given roi."""

Expand Down Expand Up @@ -437,7 +381,7 @@ def __slices(self, roi, use_adapters: bool = True, check_chunk_align: bool = Fal
else []
)

combined_slice = self._combine_slices(*adapter_slices, roi_slices)
combined_slice = reduce(fuse_slice, [*adapter_slices, roi_slices])

return combined_slice

Expand All @@ -448,9 +392,9 @@ def _is_slice(self, adapter: Adapter):
or isinstance(adapter, list)
):
return True
elif isinstance(adapter, tuple) and all(
[isinstance(a, slice) or isinstance(a, int) for a in adapter]
):
elif isinstance(adapter, tuple) and all([self._is_slice(a) for a in adapter]):
return True
elif isinstance(adapter, np.ndarray) and adapter.dtype == bool:
return True
return False

Expand Down
98 changes: 0 additions & 98 deletions funlib/persistence/arrays/slices.py

This file was deleted.

36 changes: 22 additions & 14 deletions tests/test_slices.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
import numpy as np
from funlib.persistence.arrays.slices import chain_slices
from dask.array.optimization import fuse_slice
from functools import reduce
import pytest


def test_slice_chaining():

base = np.s_[::2, 0, :4]
def combine_slices(*slices):
return reduce(fuse_slice, slices)

base = np.s_[::2, :, :4]

# chain with index expressions

s1 = chain_slices(base, np.s_[0])
assert s1 == np.s_[0, 0, :4]
s1 = combine_slices(base, np.s_[0])
assert s1 == np.s_[0, :, :4]

s2 = chain_slices(s1, np.s_[1])
assert s2 == np.s_[0, 0, 1]
s2 = combine_slices(s1, np.s_[1])
assert s2 == np.s_[0, 1, :4]

# chain with index arrays

s1 = chain_slices(base, np.s_[[0, 1, 1, 2, 3, 5], :])
assert s1 == np.s_[[0, 2, 2, 4, 6, 10], 0, :4]
s1 = combine_slices(base, np.s_[[0, 1, 1, 2, 3, 5], :])
assert s1 == np.s_[[0, 2, 2, 4, 6, 10], 0:, :4]

# ...and another index array
s21 = chain_slices(s1, np.s_[[0, 3], :])
assert s21 == np.s_[[0, 4], 0, :4]
with pytest.raises(NotImplementedError):
# this is not supported because the combined indexing
# operation would not behave the same as the individual
# indexing operations performed in sequence
combine_slices(s1, np.s_[[0, 3], 2])

# ...and a slice() expression
s22 = chain_slices(s1, np.s_[1:4])
assert s22 == np.s_[[2, 2, 4], 0, :4]
s22 = combine_slices(s1, np.s_[1:4])
assert s22 == np.s_[[2, 2, 4], 0:, :4]

# chain with slice expressions

s1 = chain_slices(base, np.s_[10:20, ::2])
assert s1 == np.s_[20:40:2, 0, :4:2]
s1 = combine_slices(base, np.s_[10:20, ::2, 0])
assert s1 == np.s_[20:40:2, 0::2, 0]

0 comments on commit fbde9a1

Please sign in to comment.