Remove split_out #170

Merged · 2 commits · Oct 12, 2022
Changes from 1 commit
93 changes: 16 additions & 77 deletions flox/core.py
@@ -736,13 +736,6 @@ def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDict:
return newresults


def _split_groups(array, j, slicer):
"""Slices out chunks when split_out > 1"""
results = {"groups": array["groups"][..., slicer]}
results["intermediates"] = [v[..., slicer] for v in array["intermediates"]]
return results


def _finalize_results(
results: IntermediateDict,
agg: Aggregation,
@@ -997,38 +990,6 @@ def _grouped_combine(
return results


def split_blocks(applied, split_out, expected_groups, split_name):
import dask.array
from dask.array.core import normalize_chunks
from dask.highlevelgraph import HighLevelGraph

chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
ngroups = len(expected_groups)
group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
idx = tuple(np.cumsum((0,) + group_chunks[0]))

# split each block into `split_out` chunks
dsk = {}
for i in chunk_tuples:
for j in range(split_out):
dsk[(split_name, *i, j)] = (
_split_groups,
(applied.name, *i),
j,
slice(idx[j], idx[j + 1]),
)

# now construct an array that can be passed to _tree_reduce
intergraph = HighLevelGraph.from_collections(split_name, dsk, dependencies=(applied,))
intermediate = dask.array.Array(
intergraph,
name=split_name,
chunks=applied.chunks + ((1,) * split_out,),
meta=applied._meta,
)
return intermediate, group_chunks


def _reduce_blockwise(
array,
by,
@@ -1169,7 +1130,6 @@ def dask_groupby_agg(
agg: Aggregation,
expected_groups: pd.Index | None,
axis: T_Axes = (),
split_out: int = 1,
fill_value: Any = None,
method: T_Method = "map-reduce",
reindex: bool = False,
@@ -1186,19 +1146,14 @@
assert isinstance(axis, Sequence)
assert all(ax >= 0 for ax in axis)

if method == "blockwise" and (split_out > 1 or not isinstance(by, np.ndarray)):
raise NotImplementedError

if split_out > 1 and expected_groups is None:
# This could be implemented using the "hash_split" strategy
# from dask.dataframe
if method == "blockwise" and not isinstance(by, np.ndarray):
raise NotImplementedError

inds = tuple(range(array.ndim))
name = f"groupby_{agg.name}"
token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
token = dask.base.tokenize(array, by, agg, expected_groups, axis)

if expected_groups is None and (reindex or split_out > 1):
if expected_groups is None and reindex:
expected_groups = _get_expected_groups(by, sort=sort)

by_input = by
@@ -1229,9 +1184,7 @@ def dask_groupby_agg(
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)

do_simple_combine = (
method != "blockwise" and reindex and not _is_arg_reduction(agg) and split_out == 1
)
do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
if method == "blockwise":
# use the "non dask" code path, but applied blockwise
blockwise_method = partial(
@@ -1244,14 +1197,14 @@
func=agg.chunk,
fill_value=agg.fill_value["intermediate"],
dtype=agg.dtype["intermediate"],
reindex=reindex or (split_out > 1),
reindex=reindex,
)
if do_simple_combine:
# Add a dummy dimension that then gets reduced over
blockwise_method = tlz.compose(_expand_dims, blockwise_method)

# apply reduction on chunk
applied = dask.array.blockwise(
intermediate = dask.array.blockwise(
partial(
blockwise_method,
axis=axis,
@@ -1271,18 +1224,12 @@
token=f"{name}-chunk-{token}",
)

if split_out > 1:
intermediate, group_chunks = split_blocks(
applied, split_out, expected_groups, split_name=f"{name}-split-{token}"
)
else:
intermediate = applied
if expected_groups is None:
if is_duck_dask_array(by_input):
expected_groups = None
else:
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
if expected_groups is None:
if is_duck_dask_array(by_input):
expected_groups = None
else:
expected_groups = _get_expected_groups(by_input, sort=sort)
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
Contributor:
Suggested change
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
group_chunks: tuple[tuple[float, ...]] = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)

Try explicitly defining how group_chunks is supposed to look.

mypy gets the type hint from the first time a variable is defined; it's not clear from this line that it can hold multiple floats, as it does later on line 1302.

flox/core.py:1302: error: Incompatible types in assignment (expression has type "Tuple[Tuple[int, ...]]", variable has type "Tuple[Tuple[float]]")  [assignment]
flox/core.py:1322: error: Incompatible types in assignment (expression has type "Tuple[Tuple[int, ...]]", variable has type "Tuple[Tuple[float]]")  [assignment]
Found 2 errors in 1 file (checked 10 source files)

Collaborator (author):
oh right. it's only ever float in that one case and tuple[tuple[int]] otherwise

Copy link
Contributor

@Illviljan Illviljan Oct 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For tuples you have to add ... to indicate multiple elements of the same type, and those loops at line 1302 will add multiple ints.
np.nan is a float, so the entire thing then becomes float: tuple[tuple[float, ...]]
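For illustration, here is a minimal standalone sketch of the mypy behavior discussed above. It is not part of this diff; the file name and variable names are hypothetical, and it assumes numpy is installed and mypy is the type checker.

# sketch.py -- hypothetical example, not flox code
from __future__ import annotations

import numpy as np

# With no annotation, mypy pins the type from this first assignment:
# it infers tuple[tuple[float]], i.e. a 1-tuple holding a 1-tuple of one float.
chunks = ((np.nan,),)

# A later reassignment whose inner tuple holds several ints then fails:
# error: Incompatible types in assignment (expression has type
# "Tuple[Tuple[int, int, int]]", variable has type "Tuple[Tuple[float]]")
chunks = ((2, 2, 1),)

# Annotating the first assignment with the widest intended shape fixes it:
# "..." allows any number of elements, and int is accepted where float is
# expected under mypy's numeric promotion rules.
chunks_ok: tuple[tuple[float, ...]] = ((np.nan,),)
chunks_ok = ((2, 2, 1),)  # OK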


if method in ["map-reduce", "cohorts", "split-reduce"]:
combine: Callable[..., IntermediateDict]
@@ -1311,9 +1258,7 @@ def dask_groupby_agg(
if method == "map-reduce":
reduced = tree_reduce(
intermediate,
aggregate=partial(
aggregate, expected_groups=None if split_out > 1 else expected_groups
),
aggregate=partial(aggregate, expected_groups=expected_groups),
)
if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1380,7 +1325,7 @@ def dask_groupby_agg(
raise ValueError(f"Unknown method={method}.")

# extract results from the dict
output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
output_chunks = reduced.chunks[: -len(axis)] + group_chunks
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
layer2: dict[tuple, tuple] = {}
agg_name = f"{name}-{token}"
@@ -1392,10 +1337,7 @@ def dask_groupby_agg(
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
if split_out > 1:
inchunk = inchunk + (0,)
inchunk = inchunk + (ochunk[-1],)
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

@@ -1516,7 +1458,6 @@ def groupby_reduce(
fill_value=None,
dtype: np.typing.DTypeLike = None,
min_count: int | None = None,
split_out: int = 1,
method: T_Method = "map-reduce",
engine: T_Engine = "numpy",
reindex: bool | None = None,
@@ -1555,8 +1496,6 @@ def groupby_reduce(
fewer than min_count non-NA values are present the result will be
NA. Only used if skipna is set to True or defaults to True for the
array's dtype.
split_out : int, optional
Number of chunks along group axis in output (last axis)
method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
Strategy for reduction of dask arrays only:
* ``"map-reduce"``:
@@ -1750,7 +1689,7 @@ def groupby_reduce(
if kwargs["fill_value"] is None:
kwargs["fill_value"] = agg.fill_value[agg.name]

partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
partial_agg = partial(dask_groupby_agg, **kwargs)

if method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)
4 changes: 0 additions & 4 deletions flox/xarray.py
@@ -61,7 +61,6 @@ def xarray_reduce(
isbin: bool | Sequence[bool] = False,
sort: bool = True,
dim: Dims | ellipsis = None,
split_out: int = 1,
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
@@ -94,8 +93,6 @@
dim : hashable
dimension name along which to reduce. If None, reduces across all
dimensions of `by`
split_out : int, optional
Number of output chunks along grouped dimension in output.
fill_value
Value used for missing groups in the output i.e. when one of the labels
in ``expected_groups`` is not actually present in ``by``.
@@ -396,7 +393,6 @@ def wrapper(array, *by, func, skipna, **kwargs):
"func": func,
"axis": axis,
"sort": sort,
"split_out": split_out,
"fill_value": fill_value,
"method": method,
"min_count": min_count,
4 changes: 1 addition & 3 deletions tests/test_core.py
@@ -79,7 +79,7 @@ def test_alignment_error():


@pytest.mark.parametrize("dtype", (float, int))
@pytest.mark.parametrize("chunk, split_out", [(False, 1), (True, 1), (True, 2), (True, 3)])
@pytest.mark.parametrize("chunk", [False, True])
@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
@pytest.mark.parametrize(
"func, array, by, expected",
@@ -114,7 +114,6 @@ def test_groupby_reduce(
expected: list[float],
expected_groups: T_ExpectedGroupsOpt,
chunk: bool,
split_out: int,
dtype: np.typing.DTypeLike,
) -> None:
array = array.astype(dtype)
@@ -137,7 +136,6 @@
func=func,
expected_groups=expected_groups,
fill_value=123,
split_out=split_out,
engine=engine,
)
g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype