Major fix to subset_to_blocks #173

Merged Oct 16, 2022 (21 commits)
Changes from all commits
77 changes: 72 additions & 5 deletions asv_bench/benchmarks/cohorts.py
@@ -29,7 +29,22 @@ def track_num_tasks(self):
)[0]
return len(result.dask.to_dict())

def track_num_tasks_optimized(self):
result = flox.groupby_reduce(
self.array, self.by, func="sum", axis=self.axis, method="cohorts"
)[0]
(opt,) = dask.optimize(result)
return len(opt.dask.to_dict())

def track_num_layers(self):
result = flox.groupby_reduce(
self.array, self.by, func="sum", axis=self.axis, method="cohorts"
)[0]
return len(result.dask.layers)

track_num_tasks.unit = "tasks"
track_num_tasks_optimized.unit = "tasks"
track_num_layers.unit = "layers"


class NWMMidwest(Cohorts):
@@ -45,16 +60,68 @@ def setup(self, *args, **kwargs):
self.axis = (-2, -1)


class ERA5(Cohorts):
class ERA5Dataset:
"""ERA5"""

def __init__(self, *args, **kwargs):
self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
self.axis = (-1,)
self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 48))

def rechunk(self):
self.array = flox.core.rechunk_for_cohorts(
self.array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
)


class ERA5DayOfYear(ERA5Dataset, Cohorts):
def setup(self, *args, **kwargs):
super().__init__()
self.by = self.time.dt.dayofyear.values


class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
def setup(self, *args, **kwargs):
super().setup()
super().rechunk()


class ERA5MonthHour(ERA5Dataset, Cohorts):
def setup(self, *args, **kwargs):
time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
super().__init__()
by = (self.time.dt.month.values, self.time.dt.hour.values)
ret = flox.core._factorize_multiple(
by,
expected_groups=(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))),
by_is_dask=False,
reindex=False,
)
# Add one so the rechunk code is simpler and makes sense
self.by = ret[0][0] + 1

self.by = time.dt.dayofyear.values

class ERA5MonthHourRechunked(ERA5MonthHour, Cohorts):
def setup(self, *args, **kwargs):
super().setup()
super().rechunk()


class PerfectMonthly(Cohorts):
"""Perfectly chunked for a "cohorts" monthly mean climatology"""

def setup(self, *args, **kwargs):
self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M"))
self.axis = (-1,)
self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
self.by = self.time.dt.month.values

array = dask.array.random.random((721, 1440, len(time)), chunks=(-1, -1, 48))
def rechunk(self):
self.array = flox.core.rechunk_for_cohorts(
array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
self.array, -1, self.by, force_new_chunk_at=[1], chunksize=4, ignore_old_chunks=True
)


class PerfectMonthlyRechunked(PerfectMonthly):
def setup(self, *args, **kwargs):
super().setup()
super().rechunk()
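
For context on the new track_num_tasks_optimized and track_num_layers benchmarks above: they report graph size (task count and HighLevelGraph layer count) rather than runtime. Below is a minimal sketch of the same measurement outside ASV, assuming only flox, dask and pandas are installed; the array shape and chunking are illustrative, not the benchmark's.

import dask
import dask.array
import pandas as pd
import flox

# Hourly timestamps for one year; group by day of year, as in the benchmarks.
time = pd.Series(pd.date_range("2016-01-01", "2016-12-31 23:59", freq="H"))
array = dask.array.random.random((10, 10, len(time)), chunks=(-1, -1, 48))
by = time.dt.dayofyear.values

result = flox.groupby_reduce(array, by, func="sum", axis=-1, method="cohorts")[0]
print("tasks:", len(result.dask.to_dict()))    # raw task count
print("layers:", len(result.dask.layers))      # HighLevelGraph layer count

(opt,) = dask.optimize(result)
print("optimized tasks:", len(opt.dask.to_dict()))  # after dask's graph optimization
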
96 changes: 64 additions & 32 deletions flox/core.py
@@ -6,6 +6,7 @@
import operator
from collections import namedtuple
from functools import partial, reduce
from numbers import Integral
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Sequence, Union

import numpy as np
@@ -288,7 +289,7 @@ def rechunk_for_cohorts(
divisions = []
counter = 1
for idx, lab in enumerate(labels):
if lab in force_new_chunk_at:
if lab in force_new_chunk_at or idx == 0:
divisions.append(idx)
counter = 1
continue
@@ -305,6 +306,7 @@
divisions.append(idx)
counter = 1
continue

counter += 1

divisions.append(len(labels))
@@ -313,6 +315,9 @@
print(labels_at_breaks[:40])

newchunks = tuple(np.diff(divisions))
if debug:
print(divisions[:10], newchunks[:10])
print(divisions[-10:], newchunks[-10:])
assert sum(newchunks) == len(labels)

if newchunks == array.chunks[axis]:
Expand Down Expand Up @@ -1046,26 +1051,18 @@ def _reduce_blockwise(
return result


def subset_to_blocks(
array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
) -> DaskArray:
def _normalize_indexes(array, flatblocks, blkshape):
"""
Advanced indexing of .blocks such that we always get a regular array back.
.blocks accessor can only accept one iterable at a time,
but can handle multiple slices.
To minimize tasks and layers, we normalize to produce slices
along as many axes as possible, and then repeatedly apply
any remaining iterables in a loop.

Parameters
----------
array : dask.array
flatblocks : flat indices of blocks to extract
blkshape : shape of blocks with which to unravel flatblocks

Returns
-------
dask.array
TODO: move this upstream
"""
if blkshape is None:
blkshape = array.blocks.shape

unraveled = np.unravel_index(flatblocks, blkshape)

normalized: list[Union[int, np.ndarray, slice]] = []
for ax, idx in enumerate(unraveled):
i = _unique(idx).squeeze()
@@ -1077,30 +1074,65 @@ def subset_to_blocks(
elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
normalized.append(slice(i[0], i[-1] + 1))
else:
normalized.append(i)
normalized.append(list(i))
full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)

# has no iterables
noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
# has all iterables
alliter = {
ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
}
alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")}

# apply everything but the iterables
if all(i == slice(None) for i in noiter):
mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values())))

full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter))

return full_tuple


def subset_to_blocks(
array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
) -> DaskArray:
"""
Advanced indexing of .blocks such that we always get a regular array back.

Parameters
----------
array : dask.array
flatblocks : flat indices of blocks to extract
blkshape : shape of blocks with which to unravel flatblocks

Returns
-------
dask.array
"""
import dask.array
from dask.array.slicing import normalize_index
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

if blkshape is None:
blkshape = array.blocks.shape

index = _normalize_indexes(array, flatblocks, blkshape)

if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
return array

subset = array.blocks[noiter]
# The rest is copied from dask.array.core.py with slight modifications
index = normalize_index(index, array.numblocks)
index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

for ax, inds in alliter.items():
if isinstance(inds, slice):
continue
idxr = [slice(None, None)] * array.ndim
idxr[ax] = inds
subset = subset.blocks[tuple(idxr)]
name = "blocks-" + tokenize(array, index)
new_keys = array._key_array[index]

squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))

keys = itertools.product(*(range(len(c)) for c in chunks))
layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

return subset
return dask.array.Array(graph, name, chunks, meta=array)


def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
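
To summarize the flox/core.py change: _normalize_indexes converts flat block indices into per-axis slices wherever the selected blocks are contiguous (falling back to np.ix_-style integer arrays otherwise), and the rewritten subset_to_blocks materializes that selection as a single new "blocks-" layer on top of the input graph instead of repeatedly indexing .blocks in a loop. A small usage sketch of the two functions as added here; shapes are illustrative and the printed values are shown approximately.

import dask.array
from flox.core import _normalize_indexes, subset_to_blocks

array = dask.array.ones((5, 5), chunks=1)  # 5 x 5 grid of single-element blocks

# Blocks 1-3 all lie in block-row 0: axis 0 normalizes to a scalar, axis 1 to a slice.
print(_normalize_indexes(array, [1, 2, 3], array.blocks.shape))
# -> roughly (0, slice(1, 4))

# Scattered blocks become integer arrays, broadcast together with np.ix_.
print(_normalize_indexes(array, [0, 7, 9], array.blocks.shape))
# -> roughly (slice(0, 2), array([0, 2, 4]))

# subset_to_blocks builds the subset as one new graph layer over the input,
# so the result has exactly two layers regardless of how many blocks are kept.
subset = subset_to_blocks(array, [0, 7, 9])
print(len(subset.dask.layers))  # 2
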
12 changes: 12 additions & 0 deletions tests/__init__.py
@@ -115,6 +115,18 @@ def assert_equal(a, b, tolerance=None):
np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)


def assert_equal_tuple(a, b):
"""assert_equal for .blocks indexing tuples"""
assert len(a) == len(b)

for a_, b_ in zip(a, b):
assert type(a_) == type(b_)
if isinstance(a_, np.ndarray):
np.testing.assert_array_equal(a_, b_)
else:
assert a_ == b_


@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
def engine(request):
if request.param == "numba":
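
A quick usage sketch of the new assert_equal_tuple helper, mirroring how the tests in test_core.py below call it; the bare "tests" import path is an assumption that you are running from the repository root (inside the package the tests use a relative import).

import numpy as np
from tests import assert_equal_tuple  # assumed path; test_core.py uses "from . import assert_equal_tuple"

# Elements are compared pairwise: types must match exactly, ndarrays by contents.
assert_equal_tuple((slice(0, 2), np.array([0, 2, 4])),
                   (slice(0, 2), np.array([0, 2, 4])))  # passes

# A slice and a list describing the same blocks are NOT considered equal:
# assert_equal_tuple((slice(1, 4),), ([1, 2, 3],))  # would raise AssertionError
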
92 changes: 91 additions & 1 deletion tests/test_core.py
@@ -12,14 +12,23 @@
from flox.core import (
_convert_expected_groups_to_index,
_get_optimal_chunks_for_groups,
_normalize_indexes,
factorize_,
find_group_cohorts,
groupby_reduce,
rechunk_for_cohorts,
reindex_,
subset_to_blocks,
)

from . import assert_equal, engine, has_dask, raise_if_dask_computes, requires_dask
from . import (
assert_equal,
assert_equal_tuple,
engine,
has_dask,
raise_if_dask_computes,
requires_dask,
)

labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
nan_labels = labels.astype(float) # copy
@@ -1035,3 +1044,84 @@ def test_dtype(func, dtype, engine):
labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
assert actual.dtype == np.dtype("float64")


@requires_dask
def test_subset_blocks():
array = dask.array.random.random((120,), chunks=(4,))

blockid = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27)
subset = subset_to_blocks(array, blockid)
assert subset.blocks.shape == (len(blockid),)


@requires_dask
@pytest.mark.parametrize(
"flatblocks, expected",
(
((0, 1, 2, 3, 4), (slice(None),)),
((1, 2, 3), (slice(1, 4),)),
((1, 3), ([1, 3],)),
((0, 1, 3), ([0, 1, 3],)),
),
)
def test_normalize_block_indexing_1d(flatblocks, expected):
nblocks = 5
array = dask.array.ones((nblocks,), chunks=(1,))
expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
assert_equal_tuple(expected, actual)


@requires_dask
@pytest.mark.parametrize(
"flatblocks, expected",
(
((0, 1, 2, 3, 4), (0, slice(None))),
((1, 2, 3), (0, slice(1, 4))),
((1, 3), (0, [1, 3])),
((0, 1, 3), (0, [0, 1, 3])),
(tuple(range(10)), (slice(0, 2), slice(None))),
((0, 1, 3, 5, 6, 8), (slice(0, 2), [0, 1, 3])),
((0, 3, 4, 5, 6, 8, 24), np.ix_([0, 1, 4], [0, 1, 3, 4])),
),
)
def test_normalize_block_indexing_2d(flatblocks, expected):
nblocks = 5
ndim = 2
array = dask.array.ones((nblocks,) * ndim, chunks=(1,) * ndim)
expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
assert_equal_tuple(expected, actual)


@requires_dask
def test_subset_block_passthrough():
# full slice pass through
array = dask.array.ones((5,), chunks=(1,))
subset = subset_to_blocks(array, np.arange(5))
assert subset.name == array.name

array = dask.array.ones((5, 5), chunks=1)
subset = subset_to_blocks(array, np.arange(25))
assert subset.name == array.name


@requires_dask
@pytest.mark.parametrize(
"flatblocks, expectidx",
[
(np.arange(10), (slice(2), slice(None))),
(np.arange(8), (slice(2), slice(None))),
([0, 10], ([0, 2], slice(1))),
([0, 7], (slice(2), [0, 2])),
([0, 7, 9], (slice(2), [0, 2, 4])),
([0, 6, 12, 14], (slice(3), [0, 1, 2, 4])),
([0, 12, 14, 19], np.ix_([0, 2, 3], [0, 2, 4])),
],
)
def test_subset_block_2d(flatblocks, expectidx):
array = dask.array.from_array(np.arange(25).reshape((5, 5)), chunks=1)
subset = subset_to_blocks(array, flatblocks)
assert len(subset.dask.layers) == 2
assert_equal(subset, array.compute()[expectidx])