Remove split_out #170

Merged: 2 commits, Oct 12, 2022
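This PR removes the `split_out` keyword from `groupby_reduce`, `xarray_reduce`, and the internal `dask_groupby_agg`, along with the `_split_groups` and `split_blocks` helpers that implemented it; the grouped axis of the output is no longer split into multiple chunks. Below is a minimal sketch of a call site after the change, assuming a flox version that includes this PR; the array values and printed output are illustrative, not taken from the PR.

```python
# Minimal sketch of a groupby_reduce call after this PR (split_out is gone).
# The input arrays and the expected output in the comments are illustrative.
import numpy as np
import flox

array = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
by = np.array([0, 0, 1, 1, 2, 2])

# Previously a caller could pass split_out=2 here to chunk the output along
# the group axis; that keyword no longer exists.
result, groups = flox.groupby_reduce(array, by, func="sum", fill_value=123)
print(groups)  # group labels, e.g. [0 1 2]
print(result)  # per-group sums, e.g. [ 3.  7. 11.]
```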
95 changes: 18 additions & 77 deletions flox/core.py
@@ -736,13 +736,6 @@ def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDict:
return newresults


def _split_groups(array, j, slicer):
"""Slices out chunks when split_out > 1"""
results = {"groups": array["groups"][..., slicer]}
results["intermediates"] = [v[..., slicer] for v in array["intermediates"]]
return results


def _finalize_results(
results: IntermediateDict,
agg: Aggregation,
@@ -997,38 +990,6 @@ def _grouped_combine(
return results


def split_blocks(applied, split_out, expected_groups, split_name):
import dask.array
from dask.array.core import normalize_chunks
from dask.highlevelgraph import HighLevelGraph

chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
ngroups = len(expected_groups)
group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
idx = tuple(np.cumsum((0,) + group_chunks[0]))

# split each block into `split_out` chunks
dsk = {}
for i in chunk_tuples:
for j in range(split_out):
dsk[(split_name, *i, j)] = (
_split_groups,
(applied.name, *i),
j,
slice(idx[j], idx[j + 1]),
)

# now construct an array that can be passed to _tree_reduce
intergraph = HighLevelGraph.from_collections(split_name, dsk, dependencies=(applied,))
intermediate = dask.array.Array(
intergraph,
name=split_name,
chunks=applied.chunks + ((1,) * split_out,),
meta=applied._meta,
)
return intermediate, group_chunks


def _reduce_blockwise(
array,
by,
@@ -1169,7 +1130,6 @@ def dask_groupby_agg(
agg: Aggregation,
expected_groups: pd.Index | None,
axis: T_Axes = (),
split_out: int = 1,
fill_value: Any = None,
method: T_Method = "map-reduce",
reindex: bool = False,
@@ -1186,19 +1146,14 @@
assert isinstance(axis, Sequence)
assert all(ax >= 0 for ax in axis)

if method == "blockwise" and (split_out > 1 or not isinstance(by, np.ndarray)):
raise NotImplementedError

if split_out > 1 and expected_groups is None:
# This could be implemented using the "hash_split" strategy
# from dask.dataframe
if method == "blockwise" and not isinstance(by, np.ndarray):
raise NotImplementedError

inds = tuple(range(array.ndim))
name = f"groupby_{agg.name}"
token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
token = dask.base.tokenize(array, by, agg, expected_groups, axis)

if expected_groups is None and (reindex or split_out > 1):
if expected_groups is None and reindex:
expected_groups = _get_expected_groups(by, sort=sort)

by_input = by
@@ -1229,9 +1184,7 @@
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)

do_simple_combine = (
method != "blockwise" and reindex and not _is_arg_reduction(agg) and split_out == 1
)
do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
if method == "blockwise":
# use the "non dask" code path, but applied blockwise
blockwise_method = partial(
@@ -1244,14 +1197,14 @@
func=agg.chunk,
fill_value=agg.fill_value["intermediate"],
dtype=agg.dtype["intermediate"],
reindex=reindex or (split_out > 1),
reindex=reindex,
)
if do_simple_combine:
# Add a dummy dimension that then gets reduced over
blockwise_method = tlz.compose(_expand_dims, blockwise_method)

# apply reduction on chunk
applied = dask.array.blockwise(
intermediate = dask.array.blockwise(
partial(
blockwise_method,
axis=axis,
@@ -1271,18 +1224,14 @@
token=f"{name}-chunk-{token}",
)

if split_out > 1:
intermediate, group_chunks = split_blocks(
applied, split_out, expected_groups, split_name=f"{name}-split-{token}"
)
else:
intermediate = applied
if expected_groups is None:
if is_duck_dask_array(by_input):
expected_groups = None
else:
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
if expected_groups is None:
if is_duck_dask_array(by_input):
expected_groups = None
else:
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks: tuple[tuple[Union[int, float], ...]] = (
(len(expected_groups),) if expected_groups is not None else (np.nan,),
)

if method in ["map-reduce", "cohorts", "split-reduce"]:
combine: Callable[..., IntermediateDict]
@@ -1311,9 +1260,7 @@
if method == "map-reduce":
reduced = tree_reduce(
intermediate,
aggregate=partial(
aggregate, expected_groups=None if split_out > 1 else expected_groups
),
aggregate=partial(aggregate, expected_groups=expected_groups),
)
if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1380,7 +1327,7 @@
raise ValueError(f"Unknown method={method}.")

# extract results from the dict
output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
output_chunks = reduced.chunks[: -len(axis)] + group_chunks
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
layer2: dict[tuple, tuple] = {}
agg_name = f"{name}-{token}"
@@ -1392,10 +1339,7 @@
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
if split_out > 1:
inchunk = inchunk + (0,)
inchunk = inchunk + (ochunk[-1],)
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

@@ -1516,7 +1460,6 @@ def groupby_reduce(
fill_value=None,
dtype: np.typing.DTypeLike = None,
min_count: int | None = None,
split_out: int = 1,
method: T_Method = "map-reduce",
engine: T_Engine = "numpy",
reindex: bool | None = None,
@@ -1555,8 +1498,6 @@
fewer than min_count non-NA values are present the result will be
NA. Only used if skipna is set to True or defaults to True for the
array's dtype.
split_out : int, optional
Number of chunks along group axis in output (last axis)
method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
Strategy for reduction of dask arrays only:
* ``"map-reduce"``:
@@ -1750,7 +1691,7 @@
if kwargs["fill_value"] is None:
kwargs["fill_value"] = agg.fill_value[agg.name]

partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
partial_agg = partial(dask_groupby_agg, **kwargs)

if method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)
4 changes: 0 additions & 4 deletions flox/xarray.py
@@ -61,7 +61,6 @@ def xarray_reduce(
isbin: bool | Sequence[bool] = False,
sort: bool = True,
dim: Dims | ellipsis = None,
split_out: int = 1,
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
@@ -94,8 +93,6 @@
dim : hashable
dimension name along which to reduce. If None, reduces across all
dimensions of `by`
split_out : int, optional
Number of output chunks along grouped dimension in output.
fill_value
Value used for missing groups in the output i.e. when one of the labels
in ``expected_groups`` is not actually present in ``by``.
@@ -396,7 +393,6 @@ def wrapper(array, *by, func, skipna, **kwargs):
"func": func,
"axis": axis,
"sort": sort,
"split_out": split_out,
"fill_value": fill_value,
"method": method,
"min_count": min_count,
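The xarray wrapper loses the keyword as well. A small sketch of an `xarray_reduce` call without `split_out`, assuming xarray is installed alongside a flox version containing this PR; the toy DataArray and labels are made up for illustration.

```python
# Minimal sketch: xarray_reduce no longer accepts split_out.
# The toy data below is illustrative only.
import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

da = xr.DataArray(
    np.arange(8.0),
    dims="time",
    coords={"labels": ("time", ["a", "a", "b", "b", "a", "a", "b", "b"])},
    name="temp",
)

# Reduce "temp" over the "time" dimension, grouped by the "labels" coordinate.
result = xarray_reduce(da, da["labels"], func="mean")
print(result)  # one mean per unique label ("a", "b")
```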
4 changes: 1 addition & 3 deletions tests/test_core.py
@@ -79,7 +79,7 @@ def test_alignment_error():


@pytest.mark.parametrize("dtype", (float, int))
@pytest.mark.parametrize("chunk, split_out", [(False, 1), (True, 1), (True, 2), (True, 3)])
@pytest.mark.parametrize("chunk", [False, True])
@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
@pytest.mark.parametrize(
"func, array, by, expected",
@@ -114,7 +114,6 @@ def test_groupby_reduce(
expected: list[float],
expected_groups: T_ExpectedGroupsOpt,
chunk: bool,
split_out: int,
dtype: np.typing.DTypeLike,
) -> None:
array = array.astype(dtype)
@@ -137,7 +136,6 @@
func=func,
expected_groups=expected_groups,
fill_value=123,
split_out=split_out,
engine=engine,
)
g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype