Skip to content

Commit

Permalink
Remove split_out
Browse files Browse the repository at this point in the history
Closes #166
Closes #11
  • Loading branch information
dcherian committed Oct 11, 2022
1 parent 72dfc87 commit f04a87e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 84 deletions.
93 changes: 16 additions & 77 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,13 +736,6 @@ def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDic
return newresults


def _split_groups(array, j, slicer):
"""Slices out chunks when split_out > 1"""
results = {"groups": array["groups"][..., slicer]}
results["intermediates"] = [v[..., slicer] for v in array["intermediates"]]
return results


def _finalize_results(
results: IntermediateDict,
agg: Aggregation,
Expand Down Expand Up @@ -997,38 +990,6 @@ def _grouped_combine(
return results


def split_blocks(applied, split_out, expected_groups, split_name):
import dask.array
from dask.array.core import normalize_chunks
from dask.highlevelgraph import HighLevelGraph

chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
ngroups = len(expected_groups)
group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
idx = tuple(np.cumsum((0,) + group_chunks[0]))

# split each block into `split_out` chunks
dsk = {}
for i in chunk_tuples:
for j in range(split_out):
dsk[(split_name, *i, j)] = (
_split_groups,
(applied.name, *i),
j,
slice(idx[j], idx[j + 1]),
)

# now construct an array that can be passed to _tree_reduce
intergraph = HighLevelGraph.from_collections(split_name, dsk, dependencies=(applied,))
intermediate = dask.array.Array(
intergraph,
name=split_name,
chunks=applied.chunks + ((1,) * split_out,),
meta=applied._meta,
)
return intermediate, group_chunks


def _reduce_blockwise(
array,
by,
Expand Down Expand Up @@ -1169,7 +1130,6 @@ def dask_groupby_agg(
agg: Aggregation,
expected_groups: pd.Index | None,
axis: T_Axes = (),
split_out: int = 1,
fill_value: Any = None,
method: T_Method = "map-reduce",
reindex: bool = False,
Expand All @@ -1186,19 +1146,14 @@ def dask_groupby_agg(
assert isinstance(axis, Sequence)
assert all(ax >= 0 for ax in axis)

if method == "blockwise" and (split_out > 1 or not isinstance(by, np.ndarray)):
raise NotImplementedError

if split_out > 1 and expected_groups is None:
# This could be implemented using the "hash_split" strategy
# from dask.dataframe
if method == "blockwise" and not isinstance(by, np.ndarray):
raise NotImplementedError

inds = tuple(range(array.ndim))
name = f"groupby_{agg.name}"
token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
token = dask.base.tokenize(array, by, agg, expected_groups, axis)

if expected_groups is None and (reindex or split_out > 1):
if expected_groups is None and reindex:
expected_groups = _get_expected_groups(by, sort=sort)

by_input = by
Expand Down Expand Up @@ -1229,9 +1184,7 @@ def dask_groupby_agg(
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)

do_simple_combine = (
method != "blockwise" and reindex and not _is_arg_reduction(agg) and split_out == 1
)
do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
if method == "blockwise":
# use the "non dask" code path, but applied blockwise
blockwise_method = partial(
Expand All @@ -1244,14 +1197,14 @@ def dask_groupby_agg(
func=agg.chunk,
fill_value=agg.fill_value["intermediate"],
dtype=agg.dtype["intermediate"],
reindex=reindex or (split_out > 1),
reindex=reindex,
)
if do_simple_combine:
# Add a dummy dimension that then gets reduced over
blockwise_method = tlz.compose(_expand_dims, blockwise_method)

# apply reduction on chunk
applied = dask.array.blockwise(
intermediate = dask.array.blockwise(
partial(
blockwise_method,
axis=axis,
Expand All @@ -1271,18 +1224,12 @@ def dask_groupby_agg(
token=f"{name}-chunk-{token}",
)

if split_out > 1:
intermediate, group_chunks = split_blocks(
applied, split_out, expected_groups, split_name=f"{name}-split-{token}"
)
else:
intermediate = applied
if expected_groups is None:
if is_duck_dask_array(by_input):
expected_groups = None
else:
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
if expected_groups is None:
if is_duck_dask_array(by_input):
expected_groups = None
else:
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)

if method in ["map-reduce", "cohorts", "split-reduce"]:
combine: Callable[..., IntermediateDict]
Expand Down Expand Up @@ -1311,9 +1258,7 @@ def dask_groupby_agg(
if method == "map-reduce":
reduced = tree_reduce(
intermediate,
aggregate=partial(
aggregate, expected_groups=None if split_out > 1 else expected_groups
),
aggregate=partial(aggregate, expected_groups=expected_groups),
)
if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
Expand Down Expand Up @@ -1380,7 +1325,7 @@ def dask_groupby_agg(
raise ValueError(f"Unknown method={method}.")

# extract results from the dict
output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
output_chunks = reduced.chunks[: -len(axis)] + group_chunks
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
layer2: dict[tuple, tuple] = {}
agg_name = f"{name}-{token}"
Expand All @@ -1392,10 +1337,7 @@ def dask_groupby_agg(
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
if split_out > 1:
inchunk = inchunk + (0,)
inchunk = inchunk + (ochunk[-1],)
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

Expand Down Expand Up @@ -1516,7 +1458,6 @@ def groupby_reduce(
fill_value=None,
dtype: np.typing.DTypeLike = None,
min_count: int | None = None,
split_out: int = 1,
method: T_Method = "map-reduce",
engine: T_Engine = "numpy",
reindex: bool | None = None,
Expand Down Expand Up @@ -1555,8 +1496,6 @@ def groupby_reduce(
fewer than min_count non-NA values are present the result will be
NA. Only used if skipna is set to True or defaults to True for the
array's dtype.
split_out : int, optional
Number of chunks along group axis in output (last axis)
method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
Strategy for reduction of dask arrays only:
* ``"map-reduce"``:
Expand Down Expand Up @@ -1750,7 +1689,7 @@ def groupby_reduce(
if kwargs["fill_value"] is None:
kwargs["fill_value"] = agg.fill_value[agg.name]

partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
partial_agg = partial(dask_groupby_agg, **kwargs)

if method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)
Expand Down
4 changes: 0 additions & 4 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def xarray_reduce(
isbin: bool | Sequence[bool] = False,
sort: bool = True,
dim: Dims | ellipsis = None,
split_out: int = 1,
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
Expand Down Expand Up @@ -94,8 +93,6 @@ def xarray_reduce(
dim : hashable
dimension name along which to reduce. If None, reduces across all
dimensions of `by`
split_out : int, optional
Number of output chunks along grouped dimension in output.
fill_value
Value used for missing groups in the output i.e. when one of the labels
in ``expected_groups`` is not actually present in ``by``.
Expand Down Expand Up @@ -396,7 +393,6 @@ def wrapper(array, *by, func, skipna, **kwargs):
"func": func,
"axis": axis,
"sort": sort,
"split_out": split_out,
"fill_value": fill_value,
"method": method,
"min_count": min_count,
Expand Down
4 changes: 1 addition & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_alignment_error():


@pytest.mark.parametrize("dtype", (float, int))
@pytest.mark.parametrize("chunk, split_out", [(False, 1), (True, 1), (True, 2), (True, 3)])
@pytest.mark.parametrize("chunk", [False, True])
@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
@pytest.mark.parametrize(
"func, array, by, expected",
Expand Down Expand Up @@ -114,7 +114,6 @@ def test_groupby_reduce(
expected: list[float],
expected_groups: T_ExpectedGroupsOpt,
chunk: bool,
split_out: int,
dtype: np.typing.DTypeLike,
) -> None:
array = array.astype(dtype)
Expand All @@ -137,7 +136,6 @@ def test_groupby_reduce(
func=func,
expected_groups=expected_groups,
fill_value=123,
split_out=split_out,
engine=engine,
)
g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype
Expand Down

0 comments on commit f04a87e

Please sign in to comment.