From 6eca6f17afe843e64e931f28efd445919666da26 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Sat, 24 Jun 2023 13:35:37 -0400 Subject: [PATCH 1/7] array api compatiblity --- flox/aggregations.py | 3 ++- flox/duck_array_ops.py | 18 ++++++++++++++++++ flox/xrutils.py | 11 +++++++++-- 3 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 flox/duck_array_ops.py diff --git a/flox/aggregations.py b/flox/aggregations.py index 13b23fafe..860cd4580 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -11,6 +11,7 @@ from . import aggregate_flox, aggregate_npg, xrutils from . import xrdtypes as dtypes +from .duck_array_ops import asarray if TYPE_CHECKING: FuncTuple = tuple[Callable | str, ...] @@ -64,7 +65,7 @@ def generic_aggregate( f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead." ) - group_idx = np.asarray(group_idx, like=array) + group_idx = asarray(group_idx, like=array) with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") diff --git a/flox/duck_array_ops.py b/flox/duck_array_ops.py new file mode 100644 index 000000000..615ce16fb --- /dev/null +++ b/flox/duck_array_ops.py @@ -0,0 +1,18 @@ +import numpy as np + + +def get_array_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + else: + return np + + +def reshape(array, shape): + xp = get_array_namespace(array) + return xp.reshape(array, shape) + + +def asarray(obj, like): + xp = get_array_namespace(like) + return xp.asarray(obj) diff --git a/flox/xrutils.py b/flox/xrutils.py index 45cf45eec..c7a744df1 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -34,11 +34,18 @@ def is_duck_array(value: Any) -> bool: hasattr(value, "ndim") and hasattr(value, "shape") and hasattr(value, "dtype") - and hasattr(value, "__array_function__") - and hasattr(value, "__array_ufunc__") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) ) +def is_chunked_array(x) -> bool: + """True if dask or cubed""" + return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) + + def is_dask_collection(x): try: import dask From 58d2021f727ed85330a99081c9d43d01767bd9e7 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Sat, 24 Jun 2023 13:36:08 -0400 Subject: [PATCH 2/7] use xarray chunkmanager --- flox/core.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/flox/core.py b/flox/core.py index 02f53b837..50b79d04b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -35,7 +35,8 @@ generic_aggregate, ) from .cache import memoize -from .xrutils import is_duck_array, is_duck_dask_array, isnull +from .duck_array_ops import reshape +from .xrutils import is_duck_array, is_duck_dask_array, is_chunked_array, isnull if TYPE_CHECKING: try: @@ -764,7 +765,7 @@ def chunk_reduce( group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :]) # always reshape to 1D along group dimensions newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),) - array = array.reshape(newshape) + array = reshape(array, newshape) group_idx = group_idx.reshape(-1) assert group_idx.ndim == 1 @@ -1295,6 +1296,10 @@ def dask_groupby_agg( import dask.array from dask.array.core import slices_from_chunks + from xarray.core.parallelcompat import get_chunked_array_type + + chunkmanager = get_chunked_array_type(array) + # I think _tree_reduce expects this assert isinstance(axis, Sequence) assert all(ax >= 0 for ax in axis) @@ -1314,14 +1319,18 @@ def dask_groupby_agg( # Unifying chunks is necessary for argreductions. # We need to rechunk before zipping up with the index # let's always do it anyway - if not is_duck_dask_array(by): + if not is_chunked_array(by): # chunk numpy arrays like the input array # This removes an extra rechunk-merge layer that would be # added otherwise chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)) - by = dask.array.from_array(by, chunks=chunks) - _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :]) + by = chunkmanager.from_array( + by, + chunks=chunks, + spec=array.spec, # cubed needs all arguments to blockwise to have same Spec + ) + _, (array, by) = chunkmanager.unify_chunks(array, inds, by, inds[-by.ndim :]) # preprocess the array: # - for argreductions, this zips the index together with the array block @@ -1363,7 +1372,7 @@ def dask_groupby_agg( blockwise_method = tlz.compose(_expand_dims, blockwise_method) # apply reduction on chunk - intermediate = dask.array.blockwise( + intermediate = chunkmanager.blockwise( partial( blockwise_method, axis=axis, @@ -1381,9 +1390,9 @@ def dask_groupby_agg( inds[-by.ndim :], concatenate=False, dtype=array.dtype, # this is purely for show - meta=array._meta, + #meta=array._meta, align_arrays=False, - name=f"{name}-chunk-{token}", + #name=f"{name}-chunk-{token}", ) group_chunks: tuple[tuple[int | float, ...]] @@ -1397,9 +1406,12 @@ def dask_groupby_agg( combine = partial(_grouped_combine, engine=engine, sort=sort) combine_name = "grouped-combine" + #raise NotImplementedError("reached _tree_reduce call") + tree_reduce = partial( - dask.array.reductions._tree_reduce, - name=f"{name}-reduce-{method}-{combine_name}", + chunkmanager.reduction, + func=lambda x: x, + #name=f"{name}-reduce-{method}-{combine_name}", dtype=array.dtype, axis=axis, keepdims=True, @@ -1479,7 +1491,7 @@ def dask_groupby_agg( reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks) # Can't use map_blocks because it forces concatenate=True along drop_axes, - result = dask.array.blockwise( + result = chunkmanager.blockwise( _extract_result, out_inds, reduced, @@ -1889,7 +1901,7 @@ def groupby_reduce( axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore nax = len(axis_) - has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) + has_dask = is_chunked_array(array) or is_duck_dask_array(by_) if _is_first_last_reduction(func): if has_dask and nax != 1: From 8fdc36779548415dfbfa67b20367231bfee1621f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 Jun 2023 17:39:59 +0000 Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flox/core.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/flox/core.py b/flox/core.py index 50b79d04b..006150792 100644 --- a/flox/core.py +++ b/flox/core.py @@ -36,7 +36,7 @@ ) from .cache import memoize from .duck_array_ops import reshape -from .xrutils import is_duck_array, is_duck_dask_array, is_chunked_array, isnull +from .xrutils import is_chunked_array, is_duck_array, is_duck_dask_array, isnull if TYPE_CHECKING: try: @@ -1295,7 +1295,6 @@ def dask_groupby_agg( ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]: import dask.array from dask.array.core import slices_from_chunks - from xarray.core.parallelcompat import get_chunked_array_type chunkmanager = get_chunked_array_type(array) @@ -1390,9 +1389,9 @@ def dask_groupby_agg( inds[-by.ndim :], concatenate=False, dtype=array.dtype, # this is purely for show - #meta=array._meta, + # meta=array._meta, align_arrays=False, - #name=f"{name}-chunk-{token}", + # name=f"{name}-chunk-{token}", ) group_chunks: tuple[tuple[int | float, ...]] @@ -1401,17 +1400,15 @@ def dask_groupby_agg( combine: Callable[..., IntermediateDict] if do_simple_combine: combine = partial(_simple_combine, reindex=reindex) - combine_name = "simple-combine" else: combine = partial(_grouped_combine, engine=engine, sort=sort) - combine_name = "grouped-combine" - #raise NotImplementedError("reached _tree_reduce call") + # raise NotImplementedError("reached _tree_reduce call") tree_reduce = partial( chunkmanager.reduction, func=lambda x: x, - #name=f"{name}-reduce-{method}-{combine_name}", + # name=f"{name}-reduce-{method}-{combine_name}", dtype=array.dtype, axis=axis, keepdims=True, From 4777e77b47e70b7e920a2029227df495076ddec9 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Sat, 24 Jun 2023 13:46:54 -0400 Subject: [PATCH 4/7] remove commented out line --- flox/core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 50b79d04b..f9cdc61e3 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1406,8 +1406,6 @@ def dask_groupby_agg( combine = partial(_grouped_combine, engine=engine, sort=sort) combine_name = "grouped-combine" - #raise NotImplementedError("reached _tree_reduce call") - tree_reduce = partial( chunkmanager.reduction, func=lambda x: x, From fabaf35f87564e5015d505ca2b7c8101f6856747 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 26 Jun 2023 16:37:31 -0400 Subject: [PATCH 5/7] remove uneccessary asarray --- flox/aggregations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 860cd4580..2675cda62 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -65,7 +65,7 @@ def generic_aggregate( f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead." ) - group_idx = asarray(group_idx, like=array) + group_idx = np.asarray(group_idx, like=array) with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") From 858c98a848607a5200d2ed455b428c18b3a5eca1 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 26 Jun 2023 16:39:07 -0400 Subject: [PATCH 6/7] remove concatenate kwargs, use array API version of reshape, and add axis to chunk identity fn --- flox/core.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/flox/core.py b/flox/core.py index e3cf1ba84..77589e353 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1387,7 +1387,7 @@ def dask_groupby_agg( inds, by, inds[-by.ndim :], - concatenate=False, + #concatenate=False, dtype=array.dtype, # this is purely for show # meta=array._meta, align_arrays=False, @@ -1403,14 +1403,17 @@ def dask_groupby_agg( else: combine = partial(_grouped_combine, engine=engine, sort=sort) + def identity(x, axis, keepdims): + return x + tree_reduce = partial( chunkmanager.reduction, - func=lambda x: x, + func=identity, # name=f"{name}-reduce-{method}-{combine_name}", dtype=array.dtype, axis=axis, keepdims=True, - concatenate=False, + #concatenate=False, ) aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value) @@ -1422,8 +1425,8 @@ def dask_groupby_agg( if method == "map-reduce": reduced = tree_reduce( intermediate, - combine=partial(combine, agg=agg), - aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex), + combine_func=partial(combine, agg=agg), + aggregate_func=partial(aggregate, expected_groups=expected_groups, reindex=reindex), ) if is_duck_dask_array(by_input) and expected_groups is None: groups = _extract_unknown_groups(reduced, dtype=by.dtype) @@ -1495,7 +1498,7 @@ def dask_groupby_agg( dtype=agg.dtype["final"], key=agg.name, name=f"{name}-{token}", - concatenate=False, + #concatenate=False, ) return (result, groups) @@ -2015,9 +2018,10 @@ def groupby_reduce( # nan group labels are factorized to -1, and preserved # now we get rid of them by reindexing # This also handles bins with no data - result = reindex_( + reindexed = reindex_( result, from_=groups[0], to=expected_groups, fill_value=fill_value - ).reshape(result.shape[:-1] + grp_shape) + ) + result = reshape(reindexed, reindexed.shape[:-1] + grp_shape) groups = final_groups if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)): From 786af6a036847c7f55f2e29d5d108ab1c54358db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jun 2023 20:40:28 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flox/aggregations.py | 1 - flox/core.py | 10 ++++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 2675cda62..13b23fafe 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -11,7 +11,6 @@ from . import aggregate_flox, aggregate_npg, xrutils from . import xrdtypes as dtypes -from .duck_array_ops import asarray if TYPE_CHECKING: FuncTuple = tuple[Callable | str, ...] diff --git a/flox/core.py b/flox/core.py index 77589e353..4381b62ad 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1387,7 +1387,7 @@ def dask_groupby_agg( inds, by, inds[-by.ndim :], - #concatenate=False, + # concatenate=False, dtype=array.dtype, # this is purely for show # meta=array._meta, align_arrays=False, @@ -1413,7 +1413,7 @@ def identity(x, axis, keepdims): dtype=array.dtype, axis=axis, keepdims=True, - #concatenate=False, + # concatenate=False, ) aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value) @@ -1498,7 +1498,7 @@ def identity(x, axis, keepdims): dtype=agg.dtype["final"], key=agg.name, name=f"{name}-{token}", - #concatenate=False, + # concatenate=False, ) return (result, groups) @@ -2018,9 +2018,7 @@ def groupby_reduce( # nan group labels are factorized to -1, and preserved # now we get rid of them by reindexing # This also handles bins with no data - reindexed = reindex_( - result, from_=groups[0], to=expected_groups, fill_value=fill_value - ) + reindexed = reindex_(result, from_=groups[0], to=expected_groups, fill_value=fill_value) result = reshape(reindexed, reindexed.shape[:-1] + grp_shape) groups = final_groups