From 0438a7ea511bdf6e8ddcfca137e356b5457a7dce Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 7 Sep 2024 23:21:57 -0600 Subject: [PATCH 1/2] Drop python 3.9, use ruff (#392) * Drop python 3.9, use ruff * switch to Ruff * fix mypy * remove toctrees * fix --- .github/workflows/ci.yaml | 4 +- .pre-commit-config.yaml | 17 +- asv_bench/benchmarks/__init__.py | 1 - asv_bench/benchmarks/cohorts.py | 25 ++- asv_bench/benchmarks/combine.py | 8 +- asv_bench/benchmarks/reduce.py | 10 +- ci/docs.yml | 1 + docs/source/conf.py | 3 + .../user-stories/climatology-hourly.ipynb | 1 - flox/__init__.py | 8 +- flox/aggregate_flox.py | 18 +- flox/aggregations.py | 72 ++++--- flox/core.py | 184 ++++++++---------- flox/xarray.py | 37 ++-- flox/xrutils.py | 18 +- pyproject.toml | 14 +- tests/__init__.py | 4 +- tests/strategies.py | 23 ++- tests/test_asv.py | 4 +- tests/test_core.py | 141 +++++++++----- tests/test_properties.py | 33 ++-- tests/test_xarray.py | 55 +++--- 22 files changed, 384 insertions(+), 297 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 04799a78e..3f1416b2a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -26,7 +26,7 @@ jobs: matrix: os: ["ubuntu-latest"] env: ["environment"] - python-version: ["3.9", "3.12"] + python-version: ["3.10", "3.12"] include: - os: "windows-latest" env: "environment" @@ -36,7 +36,7 @@ jobs: python-version: "3.12" - os: "ubuntu-latest" env: "minimal-requirements" - python-version: "3.9" + python-version: "3.10" steps: - uses: actions/checkout@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 713c11adf..c1d32f7d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,10 +4,11 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: "v0.5.0" + rev: "v0.6.4" hooks: - id: ruff args: ["--fix", "--show-fixes"] + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-prettier rev: "v4.0.0-alpha.8" @@ -22,11 +23,6 @@ repos: - id: end-of-file-fixer - id: check-docstring-first - - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.2 - hooks: - - id: black - - repo: https://github.com/executablebooks/mdformat rev: 0.7.17 hooks: @@ -35,13 +31,6 @@ repos: - mdformat-black - mdformat-myst - - repo: https://github.com/nbQA-dev/nbQA - rev: 1.8.5 - hooks: - - id: nbqa-black - - id: nbqa-ruff - args: [--fix] - - repo: https://github.com/kynan/nbstripout rev: 0.7.1 hooks: @@ -56,7 +45,7 @@ repos: - tomli - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.18 + rev: v0.19 hooks: - id: validate-pyproject diff --git a/asv_bench/benchmarks/__init__.py b/asv_bench/benchmarks/__init__.py index 0a35e6cd7..73029e2c0 100644 --- a/asv_bench/benchmarks/__init__.py +++ b/asv_bench/benchmarks/__init__.py @@ -21,7 +21,6 @@ def _skip_slow(): >>> from . import _skip_slow >>> def time_something_slow(): ... pass - ... 
>>> time_something.setup = _skip_slow """ if os.environ.get("ASV_SKIP_SLOW", "0") == "1": diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py index f2e993798..7a62d9d28 100644 --- a/asv_bench/benchmarks/cohorts.py +++ b/asv_bench/benchmarks/cohorts.py @@ -67,7 +67,12 @@ def track_num_layers(self): track_num_tasks.unit = "tasks" # type: ignore[attr-defined] # Lazy track_num_tasks_optimized.unit = "tasks" # type: ignore[attr-defined] # Lazy track_num_layers.unit = "layers" # type: ignore[attr-defined] # Lazy - for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers, track_num_cohorts]: + for f in [ + track_num_tasks, + track_num_tasks_optimized, + track_num_layers, + track_num_cohorts, + ]: f.repeat = 1 # type: ignore[attr-defined] # Lazy f.rounds = 1 # type: ignore[attr-defined] # Lazy f.number = 1 # type: ignore[attr-defined] # Lazy @@ -82,9 +87,7 @@ def setup(self, *args, **kwargs): y = np.repeat(np.arange(30), 60) by = x[np.newaxis, :] * y[:, np.newaxis] - self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[ - 0 - ][0] + self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[0][0] self.array = dask.array.ones(self.by.shape, chunks=(350, 350)) self.axis = (-2, -1) @@ -101,7 +104,12 @@ def __init__(self, *args, **kwargs): def rechunk(self): self.array = flox.core.rechunk_for_cohorts( - self.array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True + self.array, + -1, + self.by, + force_new_chunk_at=[1], + chunksize=48, + ignore_old_chunks=True, ) @@ -151,7 +159,12 @@ def setup(self, *args, **kwargs): def rechunk(self): self.array = flox.core.rechunk_for_cohorts( - self.array, -1, self.by, force_new_chunk_at=[1], chunksize=4, ignore_old_chunks=True + self.array, + -1, + self.by, + force_new_chunk_at=[1], + chunksize=4, + ignore_old_chunks=True, ) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 27600685f..613b62108 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -65,12 +65,8 @@ def construct_member(groups) -> dict[str, Any]: * 2 ] - self.x_chunk_reindexed = [ - construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4 - ] + self.x_chunk_reindexed = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4] self.kwargs = { - "agg": flox.aggregations._initialize_aggregation( - "sum", "float64", np.float64, 0, 0, {} - ), + "agg": flox.aggregations._initialize_aggregation("sum", "float64", np.float64, 0, 0, {}), "axis": (3,), } diff --git a/asv_bench/benchmarks/reduce.py b/asv_bench/benchmarks/reduce.py index dd866d565..a31da26a1 100644 --- a/asv_bench/benchmarks/reduce.py +++ b/asv_bench/benchmarks/reduce.py @@ -7,7 +7,11 @@ N = 3000 funcs = ["sum", "nansum", "mean", "nanmean", "max", "nanmax", "count"] -engines = [None, "flox", "numpy"] # numbagg is disabled for now since it takes ages in CI +engines = [ + None, + "flox", + "numpy", +] # numbagg is disabled for now since it takes ages in CI expected_groups = { "None": None, "bins": pd.IntervalIndex.from_breaks([1, 2, 4]), @@ -17,9 +21,7 @@ NUMBAGG_FUNCS = ["nansum", "nanmean", "nanmax", "count", "all"] numbagg_skip = [] for name in expected_names: - numbagg_skip.extend( - list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS) - ) + numbagg_skip.extend(list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)) def setup_jit(): diff --git a/ci/docs.yml b/ci/docs.yml index 
50bf98829..fcf9d721a 100644 --- a/ci/docs.yml +++ b/ci/docs.yml @@ -16,6 +16,7 @@ dependencies: - myst-parser - myst-nb - sphinx + - sphinx-remove-toctrees - furo>=2024.08 - ipykernel - jupyter diff --git a/docs/source/conf.py b/docs/source/conf.py index 80412ba23..9160622b7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -40,6 +40,7 @@ "sphinx.ext.napoleon", "myst_nb", "sphinx_codeautolink", + "sphinx_remove_toctrees", ] codeautolink_concat_default = True @@ -54,6 +55,8 @@ master_doc = "index" language = "en" +remove_from_toctrees = ["generated/*"] + # General information about the project. project = "flox" current_year = datetime.datetime.now().year diff --git a/docs/source/user-stories/climatology-hourly.ipynb b/docs/source/user-stories/climatology-hourly.ipynb index d264ac102..f9cb54bdf 100644 --- a/docs/source/user-stories/climatology-hourly.ipynb +++ b/docs/source/user-stories/climatology-hourly.ipynb @@ -92,7 +92,6 @@ "%load_ext watermark\n", "\n", "\n", - "\n", "%watermark -iv" ] }, diff --git a/flox/__init__.py b/flox/__init__.py index 839bfb076..92728f987 100644 --- a/flox/__init__.py +++ b/flox/__init__.py @@ -1,9 +1,15 @@ #!/usr/bin/env python # flake8: noqa """Top-level module for flox .""" + from . import cache from .aggregations import Aggregation, Scan # noqa -from .core import groupby_reduce, groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts # noqa +from .core import ( + groupby_reduce, + groupby_scan, + rechunk_for_blockwise, + rechunk_for_cohorts, +) # noqa def _get_version(): diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 684822363..1e7d330a3 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -89,10 +89,14 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) lo_ = np.floor( - virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) + virtual_index, + casting="unsafe", + out=np.empty(virtual_index.shape, dtype=np.int64), ) hi_ = np.ceil( - virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) + virtual_index, + casting="unsafe", + out=np.empty(virtual_index.shape, dtype=np.int64), ) kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)])) @@ -119,7 +123,15 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non def _np_grouped_op( - group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None, **kwargs + group_idx, + array, + op, + axis=-1, + size=None, + fill_value=None, + dtype=None, + out=None, + **kwargs, ): """ most of this code is from shoyer's gist diff --git a/flox/aggregations.py b/flox/aggregations.py index 5515a5b68..4e0312198 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -3,10 +3,10 @@ import copy import logging import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict +from typing import TYPE_CHECKING, Any, Literal, TypedDict import numpy as np import pandas as pd @@ -110,7 +110,13 @@ def generic_aggregate( with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") result = method( - group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs + group_idx, + array, + axis=axis, + size=size, + fill_value=fill_value, 
+ dtype=dtype, + **kwargs, ) return result @@ -238,9 +244,7 @@ def __init__( # The following are set by _initialize_aggregation self.finalize_kwargs: dict[Any, Any] = {} self.min_count: int = 0 - self.new_dims_func: Callable = ( - returns_empty_tuple if new_dims_func is None else new_dims_func - ) + self.new_dims_func: Callable = returns_empty_tuple if new_dims_func is None else new_dims_func self.preserves_dtype = preserves_dtype @cached_property @@ -386,11 +390,19 @@ def _std_finalize(sumsq, sum_, count, ddof=0): min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True) nanmin = Aggregation( - "nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True + "nanmin", + chunk="nanmin", + combine="nanmin", + fill_value=dtypes.NA, + preserves_dtype=True, ) max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True) nanmax = Aggregation( - "nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True + "nanmax", + chunk="nanmax", + combine="nanmax", + fill_value=dtypes.NA, + preserves_dtype=True, ) @@ -482,10 +494,18 @@ def _pick_second(*x): first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True) last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True) nanfirst = Aggregation( - "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True + "nanfirst", + chunk="nanfirst", + combine="nanfirst", + fill_value=dtypes.NA, + preserves_dtype=True, ) nanlast = Aggregation( - "nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True + "nanlast", + chunk="nanlast", + combine="nanlast", + fill_value=dtypes.NA, + preserves_dtype=True, ) all_ = Aggregation( @@ -510,10 +530,18 @@ def _pick_second(*x): # Support statistical quantities only blockwise # The parallel versions will be approximate and are hard to implement! 
median = Aggregation( - name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating + name="median", + fill_value=dtypes.NA, + chunk=None, + combine=None, + final_dtype=np.floating, ) nanmedian = Aggregation( - name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating + name="nanmedian", + fill_value=dtypes.NA, + chunk=None, + combine=None, + final_dtype=np.floating, ) @@ -537,12 +565,8 @@ def quantile_new_dims_func(q) -> tuple[Dim]: final_dtype=np.floating, new_dims_func=quantile_new_dims_func, ) -mode = Aggregation( - name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True -) -nanmode = Aggregation( - name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True -) +mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) +nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) @dataclass @@ -658,9 +682,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) engine="flox", fill_value=agg.identity, ) - result = AlignedArrays( - array=final_value[..., left.group_idx.size :], group_idx=right.group_idx - ) + result = AlignedArrays(array=final_value[..., left.group_idx.size :], group_idx=right.group_idx) else: raise ValueError(f"Unknown binary op application mode: {agg.mode!r}") @@ -779,9 +801,7 @@ def _initialize_aggregation( dtype_: np.dtype | None = ( np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype ) - final_dtype = dtypes._normalize_dtype( - dtype_ or agg.dtype_init["final"], array_dtype, fill_value - ) + final_dtype = dtypes._normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) if not agg.preserves_dtype: final_dtype = dtypes._maybe_promote_int(final_dtype) agg.dtype = { @@ -794,9 +814,7 @@ def _initialize_aggregation( if int_dtype is None else np.dtype(int_dtype) ) - for int_dtype, int_fv in zip( - agg.dtype_init["intermediate"], agg.fill_value["intermediate"] - ) + for int_dtype, int_fv in zip(agg.dtype_init["intermediate"], agg.fill_value["intermediate"]) ), } diff --git a/flox/core.py b/flox/core.py index bb84ab809..c419f7465 100644 --- a/flox/core.py +++ b/flox/core.py @@ -8,7 +8,7 @@ import sys import warnings from collections import namedtuple -from collections.abc import Sequence +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor from functools import partial, reduce from itertools import product @@ -16,8 +16,8 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Literal, + TypeAlias, TypedDict, TypeVar, Union, @@ -74,37 +74,37 @@ import dask.array.Array as DaskArray from dask.typing import Graph - T_DuckArray = Union[np.ndarray, DaskArray, CubedArray] # Any ? - T_By = T_DuckArray + T_DuckArray: TypeAlias = np.ndarray | DaskArray | CubedArray # Any ? + T_By: TypeAlias = T_DuckArray T_Bys = tuple[T_By, ...] T_ExpectIndex = pd.Index T_ExpectIndexTuple = tuple[T_ExpectIndex, ...] - T_ExpectIndexOpt = Union[T_ExpectIndex, None] + T_ExpectIndexOpt = T_ExpectIndex | None T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...] - T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex] + T_Expect = Sequence | np.ndarray | T_ExpectIndex T_ExpectTuple = tuple[T_Expect, ...] - T_ExpectOpt = Union[Sequence, np.ndarray, T_ExpectIndexOpt] + T_ExpectOpt = Sequence | np.ndarray | T_ExpectIndexOpt T_ExpectOptTuple = tuple[T_ExpectOpt, ...] 
- T_ExpectedGroups = Union[T_Expect, T_ExpectOptTuple] - T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None] - T_Func = Union[str, Callable] - T_Funcs = Union[T_Func, Sequence[T_Func]] - T_Agg = Union[str, Aggregation] - T_Scan = Union[str, Scan] + T_ExpectedGroups = T_Expect | T_ExpectOptTuple + T_ExpectedGroupsOpt = T_ExpectedGroups | None + T_Func = str | Callable + T_Funcs = T_Func | Sequence[T_Func] + T_Agg = str | Aggregation + T_Scan = str | Scan T_Axis = int T_Axes = tuple[T_Axis, ...] - T_AxesOpt = Union[T_Axis, T_Axes, None] - T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None] - T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None] + T_AxesOpt = T_Axis | T_Axes | None + T_Dtypes = np.typing.DTypeLike | Sequence[np.typing.DTypeLike] | None + T_FillValues = np.typing.ArrayLike | Sequence[np.typing.ArrayLike] | None T_Engine = Literal["flox", "numpy", "numba", "numbagg"] T_EngineOpt = None | T_Engine T_Method = Literal["map-reduce", "blockwise", "cohorts"] T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"] - T_IsBins = Union[bool | Sequence[bool]] + T_IsBins = bool | Sequence[bool] T = TypeVar("T") -IntermediateDict = dict[Union[str, Callable], Any] +IntermediateDict = dict[str | Callable, Any] FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]] FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask") @@ -136,9 +136,7 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups): # The condition needs to be # len(found_groups) < size; if so we mask with fill_value (?) default_fv = DEFAULT_FILL_VALUE[func] - needs_masking = fill_value is not None and not np.array_equal( - fill_value, default_fv, equal_nan=True - ) + needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True) groups = np.arange(size) if needs_masking: mask = np.isin(groups, seen_groups, assume_unique=True, invert=True) @@ -164,9 +162,7 @@ def _is_arg_reduction(func: T_Agg) -> bool: def _is_minmax_reduction(func: T_Agg) -> bool: - return not _is_arg_reduction(func) and ( - isinstance(func, str) and ("max" in func or "min" in func) - ) + return not _is_arg_reduction(func) and (isinstance(func, str) and ("max" in func or "min" in func)) def _is_first_last_reduction(func: T_Agg) -> bool: @@ -254,8 +250,7 @@ def slices_from_chunks(chunks): """slightly modified from dask.array.core.slices_from_chunks to be lazy""" cumdims = [tlz.accumulate(operator.add, bds, 0) for bds in chunks] slices = ( - (slice(s, s + dim) for s, dim in zip(starts, shapes)) - for starts, shapes in zip(cumdims, chunks) + (slice(s, s + dim) for s, dim in zip(starts, shapes)) for starts, shapes in zip(cumdims, chunks) ) return product(*slices) @@ -396,9 +391,7 @@ def find_group_cohorts( chunks_per_label = chunks_per_label[present_labels_mask] label_chunks = { - present_labels[idx].item(): bitmask.indices[ - slice(bitmask.indptr[idx], bitmask.indptr[idx + 1]) - ] + present_labels[idx].item(): bitmask.indices[slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])] for idx in range(bitmask.shape[LABEL_AXIS]) } @@ -510,9 +503,7 @@ def invert(x) -> tuple[np.ndarray, ...]: for rowidx in order: if present_labels[rowidx] in merged_keys: continue - cohidx = containment.indices[ - slice(containment.indptr[rowidx], containment.indptr[rowidx + 1]) - ] + cohidx = containment.indices[slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])] cohort_ = present_labels[cohidx] cohort = [elem.item() for 
elem in cohort_ if elem not in merged_keys] if not cohort: @@ -604,9 +595,7 @@ def rechunk_for_cohorts( else: next_break_is_close = False - if (not ignore_old_chunks and idx in oldbreaks) or ( - counter >= chunksize and not next_break_is_close - ): + if (not ignore_old_chunks and idx in oldbreaks) or (counter >= chunksize and not next_break_is_close): divisions.append(idx) counter = 1 continue @@ -922,7 +911,10 @@ def chunk_argreduce( if reindex and expected_groups is not None: results["intermediates"][1] = reindex_( - results["intermediates"][1], results["groups"].squeeze(), expected_groups, fill_value=0 + results["intermediates"][1], + results["groups"].squeeze(), + expected_groups, + fill_value=0, ) assert results["intermediates"][0].shape == results["intermediates"][1].shape @@ -1017,8 +1009,7 @@ def chunk_reduce( order = "C" if nax > 1: needs_broadcast = any( - group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1 - for ax in range(-nax, 0) + group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1 for ax in range(-nax, 0) ) if needs_broadcast: # This is the dim=... case, it's a lot faster to ravel group_idx @@ -1098,9 +1089,7 @@ def chunk_reduce( result = result[..., :-1] # TODO: Figure out how to generalize this if reduction in ("quantile", "nanquantile"): - new_dims_shape = tuple( - dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar - ) + new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar) else: new_dims_shape = tuple() result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape) @@ -1168,7 +1157,10 @@ def _finalize_results( # Final reindexing has to be here to be lazy if not reindex and expected_groups is not None: finalized[agg.name] = reindex_( - finalized[agg.name], squeezed["groups"], expected_groups, fill_value=fill_value + finalized[agg.name], + squeezed["groups"], + expected_groups, + fill_value=fill_value, ) finalized["groups"] = expected_groups.to_numpy() else: @@ -1194,9 +1186,7 @@ def _aggregate( def _expand_dims(results: IntermediateDict) -> IntermediateDict: - results["intermediates"] = tuple( - np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"] - ) + results["intermediates"] = tuple(np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"]) return results @@ -1238,7 +1228,8 @@ def _simple_combine( # So now reindex before combining by reducing along DUMMY_AXIS unique_groups = _find_unique_groups(x_chunk) x_chunk = deepmap( - partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk + partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), + x_chunk, ) else: unique_groups = deepfirst(x_chunk)["groups"] @@ -1280,7 +1271,10 @@ def reindex_intermediates(x: IntermediateDict, agg: Aggregation, unique_groups) newx: IntermediateDict = {"groups": np.broadcast_to(unique_groups, new_shape)} newx["intermediates"] = tuple( reindex_( - v, from_=np.atleast_1d(x["groups"].squeeze()), to=pd.Index(unique_groups), fill_value=f + v, + from_=np.atleast_1d(x["groups"].squeeze()), + to=pd.Index(unique_groups), + fill_value=f, ) for v, f in zip(x["intermediates"], agg.fill_value["intermediate"]) ) @@ -1315,7 +1309,8 @@ def _grouped_combine( # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated unique_groups = _find_unique_groups(x_chunk) x_chunk = deepmap( - partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk + partial(reindex_intermediates, agg=agg, 
unique_groups=unique_groups), + x_chunk, ) # these are negative axis indices useful for concatenating the intermediates @@ -1332,15 +1327,16 @@ def _grouped_combine( # We need to send the intermediate array values & indexes at the same time # intermediates are (value e.g. max, index e.g. argmax, counts) - array_idx = tuple( - _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1) - ) + array_idx = tuple(_conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)) # for a single element along axis, we don't want to run the argreduction twice # This happens when we are reducing along an axis with a single chunk. avoid_reduction = array_idx[0].shape[axis[0]] == 1 if avoid_reduction: - results: IntermediateDict = {"groups": groups, "intermediates": list(array_idx)} + results: IntermediateDict = { + "groups": groups, + "intermediates": list(array_idx), + } else: results = chunk_argreduce( array_idx, @@ -1387,12 +1383,8 @@ def _grouped_combine( array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) if array.shape[-1] == 0: # all empty when combined - results["intermediates"].append( - np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype) - ) - results["groups"] = np.empty( - shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype - ) + results["intermediates"].append(np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype)) + results["groups"] = np.empty(shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype) else: _results = chunk_reduce( array, @@ -1456,9 +1448,7 @@ def _reduce_blockwise( if _is_arg_reduction(agg): results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1] - result = _finalize_results( - results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex - ) + result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex) return result @@ -1570,7 +1560,6 @@ def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]: def _unify_chunks(array, by): - from dask.array import from_array, unify_chunks inds = tuple(range(array.ndim)) @@ -1653,9 +1642,7 @@ def dask_groupby_agg( if method == "blockwise": # use the "non dask" code path, but applied blockwise - blockwise_method = partial( - _reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex - ) + blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex) else: # choose `chunk_reduce` or `chunk_argreduce` blockwise_method = partial( @@ -1752,7 +1739,9 @@ def dask_groupby_agg( reindexed, combine=partial(combine, agg=agg, reindex=do_simple_combine), aggregate=partial( - aggregate, expected_groups=cohort_index, reindex=do_simple_combine + aggregate, + expected_groups=cohort_index, + reindex=do_simple_combine, ), ) ) @@ -1882,9 +1871,7 @@ def _reduction_func(a, by, axis, start_group, num_groups): # let's always do it anyway if not is_chunked_array(by): # chunk numpy arrays like the input array - chunks = tuple( - array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0) - ) + chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)) by = cubed.from_array(by, chunks=chunks, spec=array.spec) _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :]) @@ -1918,9 +1905,7 @@ def _groupby_func(a, by, axis, intermediate_dtype, num_groups): out = blockwise_method(a, by) # Convert dict to one that cubed understands, dropping groups since they are # known, and the same for every 
block. - return { - f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"]) - } + return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])} def _groupby_combine(a, axis, dummy_axis, dtype, keepdims): # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed @@ -2003,18 +1988,14 @@ def _validate_reindex( ) -> bool | None: # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa def first_or_last(): - return func in ["first", "last"] or ( - _is_first_last_reduction(func) and array_dtype.kind != "f" - ) + return func in ["first", "last"] or (_is_first_last_reduction(func) and array_dtype.kind != "f") all_numpy = not is_dask_array and not any_by_dask if reindex is True and not all_numpy: if _is_arg_reduction(func): raise NotImplementedError if method == "cohorts" or (method == "blockwise" and not any_by_dask): - raise ValueError( - "reindex=True is not a valid choice for method='blockwise' or method='cohorts'." - ) + raise ValueError("reindex=True is not a valid choice for method='blockwise' or method='cohorts'.") if first_or_last(): raise ValueError("reindex must be None or False when func is 'first' or 'last.") @@ -2144,9 +2125,7 @@ def _factorize_multiple( for by_, expect in zip(by, expected_groups): if expect is None: if is_duck_dask_array(by_): - raise ValueError( - "Please provide expected_groups when grouping by a dask array." - ) + raise ValueError("Please provide expected_groups when grouping by a dask array.") found_group = pd.unique(by_.reshape(-1)) else: @@ -2177,7 +2156,7 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> return (None,) * nby if nby == 1 and not isinstance(expected_groups, tuple): - if isinstance(expected_groups, (pd.Index, np.ndarray)): + if isinstance(expected_groups, pd.Index | np.ndarray): return (expected_groups,) else: array = np.asarray(expected_groups) @@ -2252,9 +2231,7 @@ def _choose_engine(by, agg: Aggregation): (isinstance(func, str) and "nan" in func) for func in agg.chunk ) if HAS_NUMBAGG: - if agg.name in ["all", "any"] or ( - not_arg_reduce and has_blockwise_nan_skipping and dtype is None - ): + if agg.name in ["all", "any"] or (not_arg_reduce and has_blockwise_nan_skipping and dtype is None): logger.debug("_choose_engine: Choosing 'numbagg'") return "numbagg" @@ -2411,11 +2388,7 @@ def groupby_reduce( any_by_dask = any(by_is_dask) provided_expected = expected_groups is not None - if ( - engine == "numbagg" - and _is_arg_reduction(func) - and (any_by_dask or is_duck_dask_array(array)) - ): + if engine == "numbagg" and _is_arg_reduction(func) and (any_by_dask or is_duck_dask_array(array)): # There is only one test that fails, but I can't figure # out why without deep debugging. # just disable for now. @@ -2515,9 +2488,7 @@ def groupby_reduce( # TODO: Does this depend on chunking of by? # For e.g., we could relax this if there is only one chunk along all # by dim != axis? - raise NotImplementedError( - "Please provide ``expected_groups`` when not reducing along all axes." 
- ) + raise NotImplementedError("Please provide ``expected_groups`` when not reducing along all axes.") assert nax <= by_.ndim if nax < by_.ndim: @@ -2580,7 +2551,13 @@ def groupby_reduce( elif not has_dask: results = _reduce_blockwise( - array, by_, agg, expected_groups=expected_, reindex=bool(reindex), sort=sort, **kwargs + array, + by_, + agg, + expected_groups=expected_, + reindex=bool(reindex), + sort=sort, + **kwargs, ) groups = (results["groups"],) result = results[agg.name] @@ -2627,7 +2604,13 @@ def groupby_reduce( # TODO: clean this up reindex = _validate_reindex( - reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype + reindex, + func, + method, + expected_, + any_by_dask, + is_duck_dask_array(array), + array.dtype, ) if TYPE_CHECKING: @@ -2798,9 +2781,7 @@ def groupby_scan( if expected_groups is not None: raise NotImplementedError("Setting `expected_groups` and binning is not supported yet.") expected_groups = _validate_expected_groups(nby, expected_groups) - expected_groups = _convert_expected_groups_to_index( - expected_groups, isbin=(False,) * nby, sort=False - ) + expected_groups = _convert_expected_groups_to_index(expected_groups, isbin=(False,) * nby, sort=False) # Don't factorize early only when # grouping by dask arrays, and not having expected_groups @@ -2918,7 +2899,12 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray: # 1. zip together group indices & array zipped = map_blocks( - _zip, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess" + _zip, + by, + array, + dtype=array.dtype, + meta=array._meta, + name="groupby-scan-preprocess", ) scan_ = partial(chunk_scan, agg=agg) diff --git a/flox/xarray.py b/flox/xarray.py index 10a33e1d6..1562acc85 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Sequence -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -25,7 +25,7 @@ from .core import T_ExpectedGroupsOpt, T_ExpectIndex, T_ExpectOpt - Dims = Union[str, Iterable[Hashable], None] + Dims = str | Iterable[Hashable] | None def _restore_dim_order(result, obj, by, no_groupby_reorder=False): @@ -286,9 +286,7 @@ def xarray_reduce( try: xr.align(ds, *by_da, join="exact", copy=False) except ValueError as e: - raise ValueError( - "Object being grouped must be exactly aligned with every array in `by`." - ) from e + raise ValueError("Object being grouped must be exactly aligned with every array in `by`.") from e needs_broadcast = any( not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values() @@ -329,15 +327,11 @@ def xarray_reduce( group_names: tuple[Any, ...] = () group_sizes: dict[Any, int] = {} for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups_valid, isbins)): - group_name = ( - f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name - ) + group_name = f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name group_names += (group_name,) if isbin_ and isinstance(expect, int): - raise NotImplementedError( - "flox does not support binning into an integer number of bins yet." 
- ) + raise NotImplementedError("flox does not support binning into an integer number of bins yet.") expect1: T_ExpectOpt if expect is None: @@ -448,7 +442,8 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): output_core_dims=[output_core_dims], dask="allowed", dask_gufunc_kwargs=dict( - output_sizes=output_sizes, output_dtypes=[dtype] if dtype is not None else None + output_sizes=output_sizes, + output_dtypes=[dtype] if dtype is not None else None, ), keep_attrs=keep_attrs, kwargs={ @@ -520,11 +515,12 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): template = obj if actual[var].ndim > 1 + len(vector_dims): - no_groupby_reorder = isinstance( - obj, xr.Dataset - ) # do not re-order dataarrays inside datasets + no_groupby_reorder = isinstance(obj, xr.Dataset) # do not re-order dataarrays inside datasets actual[var] = _restore_dim_order( - actual[var].variable, template, by_da[0], no_groupby_reorder=no_groupby_reorder + actual[var].variable, + template, + by_da[0], + no_groupby_reorder=no_groupby_reorder, ) if missing_dim: @@ -625,13 +621,14 @@ def _rechunk(func, obj, dim, labels, **kwargs): if obj[var].chunks is not None: obj[var] = obj[var].copy( data=func( - obj[var].data, axis=obj[var].get_axis_num(dim), labels=labels.data, **kwargs + obj[var].data, + axis=obj[var].get_axis_num(dim), + labels=labels.data, + **kwargs, ) ) else: if obj.chunks is not None: - obj = obj.copy( - data=func(obj.data, axis=obj.get_axis_num(dim), labels=labels.data, **kwargs) - ) + obj = obj.copy(data=func(obj.data, axis=obj.get_axis_num(dim), labels=labels.data, **kwargs)) return obj diff --git a/flox/xrutils.py b/flox/xrutils.py index 12bf54a10..ba8a56723 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -4,14 +4,14 @@ import datetime import importlib from collections.abc import Iterable -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd from packaging.version import Version -def module_available(module: str, minversion: Optional[str] = None) -> bool: +def module_available(module: str, minversion: str | None = None) -> bool: """Checks whether a module is installed without importing it. Use this for a lightweight check and lazy imports. 
@@ -137,7 +137,7 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool: include_0d = getattr(value, "ndim", None) == 0 return ( include_0d - or isinstance(value, (str, bytes, dict)) + or isinstance(value, str | bytes | dict) or not ( isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES) or hasattr(value, "__array_function__") @@ -150,7 +150,7 @@ def notnull(data): data = np.asarray(data) scalar_type = data.dtype.type - if issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)): + if issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void): # these types cannot represent missing values return np.ones_like(data, dtype=bool) else: @@ -163,7 +163,7 @@ def isnull(data): if not is_duck_array(data): data = np.asarray(data) scalar_type = data.dtype.type - if issubclass(scalar_type, (np.datetime64, np.timedelta64)): + if issubclass(scalar_type, np.datetime64 | np.timedelta64): # datetime types use NaT for null # note: must check timedelta64 before integers, because currently # timedelta64 inherits from np.integer @@ -171,12 +171,12 @@ def isnull(data): elif issubclass(scalar_type, np.inexact): # float types use NaN for null return np.isnan(data) - elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)): + elif issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void): # these types cannot represent missing values return np.zeros_like(data, dtype=bool) else: # at this point, array should have dtype=object - if isinstance(data, (np.ndarray, dask_array_type)): + if isinstance(data, (np.ndarray, dask_array_type)): # noqa return pd.isnull(data) else: # Not reachable yet, but intended for use with other duck array @@ -275,9 +275,7 @@ def timedelta_to_numeric(value, datetime_unit="ns", dtype=float): try: a = pd.to_timedelta(value) except ValueError: - raise ValueError( - f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta" - ) + raise ValueError(f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta") return py_timedelta_to_float(a, datetime_unit) else: raise TypeError( diff --git a/pyproject.toml b/pyproject.toml index 0e0ad70bb..b1dc18781 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "flox" description = "GroupBy operations for dask.array" license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" keywords = ["xarray", "dask", "groupby"] classifiers = [ "Development Status :: 4 - Beta", @@ -11,7 +11,6 @@ classifiers = [ "Natural Language :: English", "Operating System :: OS Independent", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -60,12 +59,9 @@ fallback_version = "999" write_to = "flox/_version.py" write_to_template= '__version__ = "{version}"' -[tool.black] -line-length = 100 -target-version = ["py39"] - [tool.ruff] -target-version = "py39" +line-length = 110 +target-version = "py310" builtins = ["ellipsis"] exclude = [ ".eggs", @@ -109,6 +105,10 @@ known-third-party = [ "xarray" ] +[tool.ruff.format] +# Enable reformatting of code snippets in docstrings. 
+docstring-code-format = true + [tool.mypy] allow_redefinition = true files = "**/*.py" diff --git a/tests/__init__.py b/tests/__init__.py index 3c3ba1dac..deaf7c0ae 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -188,9 +188,9 @@ def dask_assert_eq( a_original = a b_original = b - if isinstance(a, (list, int, float)): + if isinstance(a, list | int | float): a = np.array(a) - if isinstance(b, (list, int, float)): + if isinstance(b, list | int | float): b = np.array(b) a, adt, a_meta, a_computed = _get_dt_meta_computed( diff --git a/tests/strategies.py b/tests/strategies.py index a2f95da1f..6f28db32d 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import cftime import dask @@ -29,19 +30,25 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU") by_dtype_st = supported_dtypes() -NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list( - SCIPY_STATS_FUNCS -) +NON_NUMPY_FUNCS = [ + "first", + "last", + "nanfirst", + "nanlast", + "count", + "any", + "all", +] + list(SCIPY_STATS_FUNCS) SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"] -func_st = st.sampled_from( - [f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS] -) +func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]) numeric_arrays = npst.arrays( elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st ) all_arrays = npst.arrays( - elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes() + elements={"allow_subnormal": False}, + shape=npst.array_shapes(), + dtype=supported_dtypes(), ) calendars = st.sampled_from( diff --git a/tests/test_asv.py b/tests/test_asv.py index c6f9525bc..e26cbbaf1 100644 --- a/tests/test_asv.py +++ b/tests/test_asv.py @@ -7,9 +7,7 @@ from asv_bench.benchmarks import reduce -@pytest.mark.parametrize( - "problem", [reduce.ChunkReduce1D, reduce.ChunkReduce2D, reduce.ChunkReduce2DAllAxes] -) +@pytest.mark.parametrize("problem", [reduce.ChunkReduce1D, reduce.ChunkReduce2D, reduce.ChunkReduce2DAllAxes]) def test_reduce(problem) -> None: testcase = problem() testcase.setup() diff --git a/tests/test_core.py b/tests/test_core.py index 2d2252068..2c33ebcfc 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,8 +3,9 @@ import itertools import logging import warnings +from collections.abc import Callable from functools import partial, reduce -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import numpy as np @@ -126,14 +127,24 @@ def test_alignment_error(): ("sum", np.ones((12,)), nan_labels, [1, 4, 2]), # form 1 ("sum", np.ones((2, 12)), labels, [[3, 4, 5], [3, 4, 5]]), # form 3 ("sum", np.ones((2, 12)), nan_labels, [[1, 4, 2], [1, 4, 2]]), # form 3 - ("sum", np.ones((2, 12)), np.array([labels, labels]), [6, 8, 10]), # form 1 after reshape + ( + "sum", + np.ones((2, 12)), + np.array([labels, labels]), + [6, 8, 10], + ), # form 1 after reshape ("sum", np.ones((2, 12)), np.array([nan_labels, nan_labels]), [2, 8, 4]), # (np.ones((12,)), np.array([labels, labels])), # form 4 ("count", np.ones((12,)), labels, [3, 4, 5]), # form 1 ("count", np.ones((12,)), nan_labels, [1, 4, 2]), # form 1 ("count", np.ones((2, 12)), labels, [[3, 4, 5], [3, 4, 
5]]), # form 3 ("count", np.ones((2, 12)), nan_labels, [[1, 4, 2], [1, 4, 2]]), # form 3 - ("count", np.ones((2, 12)), np.array([labels, labels]), [6, 8, 10]), # form 1 after reshape + ( + "count", + np.ones((2, 12)), + np.array([labels, labels]), + [6, 8, 10], + ), # form 1 after reshape ("count", np.ones((2, 12)), np.array([nan_labels, nan_labels]), [2, 8, 4]), ("nanmean", np.ones((12,)), labels, [1, 1, 1]), # form 1 ("nanmean", np.ones((12,)), nan_labels, [1, 1, 1]), # form 1 @@ -215,9 +226,7 @@ def gen_array_by(size, func): @pytest.mark.parametrize("add_nan_by", [True, False]) @pytest.mark.parametrize("func", ALL_FUNCS) def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): - if ("arg" in func and engine in ["flox", "numbagg"]) or ( - func in BLOCKWISE_FUNCS and chunks != -1 - ): + if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1): pytest.skip() array, by = gen_array_by(size, func) @@ -237,7 +246,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): fill_value = np.nan tolerance = {"rtol": 1e-13, "atol": 1e-15} elif "quantile" in func: - finalize_kwargs = [{"q": DEFAULT_QUANTILE}, {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]}] + finalize_kwargs = [ + {"q": DEFAULT_QUANTILE}, + {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]}, + ] fill_value = None tolerance = None else: @@ -313,7 +325,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): combine_error = RuntimeError("This combine should not have been called.") for method, reindex in params: call = partial( - groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs + groupby_reduce, + array, + *by, + method=method, + reindex=reindex, + **flox_kwargs, ) if ("arg" in func or func in ["first", "last"]) and reindex is True: # simple_combine with argreductions not supported right now @@ -461,7 +478,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp labels[-2:] = np.nan kwargs = dict( - func=func, expected_groups=[0, 1, 2], fill_value=False if func in ["all", "any"] else 123 + func=func, + expected_groups=[0, 1, 2], + fill_value=False if func in ["all", "any"] else 123, ) expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs) @@ -674,15 +693,16 @@ def test_first_last_disallowed_dask(func): # anything else is not. 
with pytest.raises(ValueError): groupby_reduce( - dask.array.empty((2, 3, 2), chunks=(-1, -1, 1)), np.ones((2,)), func=func, axis=-1 + dask.array.empty((2, 3, 2), chunks=(-1, -1, 1)), + np.ones((2,)), + func=func, + axis=-1, ) @requires_dask @pytest.mark.parametrize("func", ALL_FUNCS) -@pytest.mark.parametrize( - "axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)] -) +@pytest.mark.parametrize("axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)]) def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine): if ("arg" in func and engine in ["flox", "numbagg"]) or func in BLOCKWISE_FUNCS: pytest.skip() @@ -797,7 +817,8 @@ def _maybe_chunk(arr): @requires_dask @pytest.mark.parametrize( - "expected_groups, reindex", [(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)] + "expected_groups, reindex", + [(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)], ) def test_groupby_all_nan_blocks_dask(expected_groups, reindex, engine): labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0]) @@ -848,11 +869,16 @@ def test_bad_npg_behaviour(): # fmt: off array = np.array([[1] * 12, [1] * 12]) # fmt: on - assert_equal(aggregate(labels, array, axis=-1, func="argmax"), np.array([[0, 5, 2], [0, 5, 2]])) + assert_equal( + aggregate(labels, array, axis=-1, func="argmax"), + np.array([[0, 5, 2], [0, 5, 2]]), + ) assert ( aggregate( - np.array([0, 1, 2, 0, 1, 2]), np.array([-np.inf, 0, 0, -np.inf, 0, 0]), func="max" + np.array([0, 1, 2, 0, 1, 2]), + np.array([-np.inf, 0, 0, -np.inf, 0, 0]), + func="max", )[0] == -np.inf ) @@ -900,13 +926,17 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None: with raise_if_dask_computes(): actual, *groups = groupby_reduce( - array, labels, func="count", fill_value=0, engine=engine, method=method, **kwargs + array, + labels, + func="count", + fill_value=0, + engine=engine, + method=method, + **kwargs, ) (groups_array,) = groups expected = np.array([3, 1, 0], dtype=np.intp) - for left, right in zip( - groups_array, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy() - ): + for left, right in zip(groups_array, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()): assert left == right assert_equal(actual, expected) @@ -940,7 +970,11 @@ def test_rechunk_for_blockwise(inchunks, expected): [[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4)], [[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1)], [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1)], - [[[0], [1, 2, 3, 4], [5]], np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), (4, 8, 4, 9, 4)], + [ + [[0], [1, 2, 3, 4], [5]], + np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), + (4, 8, 4, 9, 4), + ], ], ) def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None: @@ -1039,11 +1073,14 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine): if chunks: array = dask.array.from_array(array, chunks) actual, _ = groupby_reduce( - array, by, func=func, engine=engine, fill_value=fill_value, expected_groups=[0, 1, 2, 3] - ) - expected = np.array( - [fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)] + array, + by, + func=func, + engine=engine, + fill_value=fill_value, + expected_groups=[0, 1, 2, 3], ) + expected = np.array([fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)]) assert_equal(actual, expected) @@ -1140,7 +1177,12 @@ def test_dtype_promotion(func, fill_value, expected, engine): by = [0, 1] actual, _ = groupby_reduce( - array, 
by, func=func, expected_groups=[1, 2], fill_value=fill_value, engine=engine + array, + by, + func=func, + expected_groups=[1, 2], + fill_value=fill_value, + engine=engine, ) assert np.issubdtype(actual.dtype, expected) @@ -1259,9 +1301,7 @@ def test_group_by_datetime_cubed(engine, method): assert_equal(expected, actual) edges = pd.date_range("1999-12-31", "2000-12-31", freq="ME").to_series().to_numpy() - actual, _ = groupby_reduce( - cubedarray, t.to_numpy(), isbin=True, expected_groups=edges, **kwargs - ) + actual, _ = groupby_reduce(cubedarray, t.to_numpy(), isbin=True, expected_groups=edges, **kwargs) expected = data.resample("ME").mean().to_numpy() assert_equal(expected, actual) @@ -1316,9 +1356,7 @@ def test_multiple_groupers_bins(chunk) -> None: @pytest.mark.parametrize("expected_groups", [None, (np.arange(5), [2, 3]), (None, [2, 3])]) -@pytest.mark.parametrize( - "by1", [np.arange(5)[:, None], np.broadcast_to(np.arange(5)[:, None], (5, 2))] -) +@pytest.mark.parametrize("by1", [np.arange(5)[:, None], np.broadcast_to(np.arange(5)[:, None], (5, 2))]) @pytest.mark.parametrize( "by2", [ @@ -1341,9 +1379,7 @@ def test_multiple_groupers(chunk, by1, by2, expected_groups) -> None: # output from `count` is intp expected = np.ones((5, 2), dtype=np.intp) - actual, *_ = groupby_reduce( - array, by1, by2, axis=(0, 1), func="count", expected_groups=expected_groups - ) + actual, *_ = groupby_reduce(array, by1, by2, axis=(0, 1), func="count", expected_groups=expected_groups) assert_equal(expected, actual) @@ -1435,9 +1471,7 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dty dtype=dtype, ) - agg_median = Aggregation( - name="median", numpy=grouped_median, fill_value=-1, chunk=None, combine=None - ) + agg_median = Aggregation(name="median", numpy=grouped_median, fill_value=-1, chunk=None, combine=None) array = np.arange(100, dtype=np.float32).reshape(5, 20) by = np.ones((20,)) @@ -1480,7 +1514,12 @@ def test_dtype(func, dtype, engine): arr = np.ones((4, 12), dtype=dtype) labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) actual, _ = groupby_reduce( - arr, labels, func=func, dtype=np.float64, engine=engine, finalize_kwargs=finalize_kwargs + arr, + labels, + func=func, + dtype=np.float64, + engine=engine, + finalize_kwargs=finalize_kwargs, ) assert actual.dtype == np.dtype("float64") @@ -1587,9 +1626,7 @@ def test_subset_block_2d(flatblocks, expectidx): [True, None, "sum", ([1], None), True], ], ) -def test_validate_reindex_map_reduce( - dask_expected, reindex, func, expected_groups, any_by_dask -) -> None: +def test_validate_reindex_map_reduce(dask_expected, reindex, func, expected_groups, any_by_dask) -> None: actual = _validate_reindex( reindex, func, @@ -1720,12 +1757,20 @@ def test_1d_blockwise_sort_optimization(): assert all("getitem" not in k for k in actual.dask) actual, _ = groupby_reduce( - array, time.dt.dayofyear.values[::-1], sort=True, method="blockwise", func="count" + array, + time.dt.dayofyear.values[::-1], + sort=True, + method="blockwise", + func="count", ) assert any("getitem" in k for k in actual.dask.layers) actual, _ = groupby_reduce( - array, time.dt.dayofyear.values[::-1], sort=False, method="blockwise", func="count" + array, + time.dt.dayofyear.values[::-1], + sort=False, + method="blockwise", + func="count", ) assert all("getitem" not in k for k in actual.dask.layers) @@ -1760,9 +1805,7 @@ def test_negative_index_factorize_race_condition(): @pytest.mark.parametrize("sort", [True, False]) def 
test_expected_index_conversion_passthrough_range_index(sort): index = pd.RangeIndex(100) - actual = _convert_expected_groups_to_index( - expected_groups=(index,), isbin=(False,), sort=(sort,) - ) + actual = _convert_expected_groups_to_index(expected_groups=(index,), isbin=(False,), sort=(sort,)) assert actual[0] is index @@ -1935,9 +1978,7 @@ def test_ffill_bfill(chunks, size, add_nan_by, func): def test_blockwise_nans(): array = dask.array.ones((1, 10), chunks=2) by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4]) - actual, actual_groups = flox.groupby_reduce( - array, by, func="sum", expected_groups=pd.RangeIndex(0, 5) - ) + actual, actual_groups = flox.groupby_reduce(array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)) expected, expected_groups = flox.groupby_reduce( array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5) ) diff --git a/tests/test_properties.py b/tests/test_properties.py index 3c86dc34b..584314cea 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,5 +1,6 @@ import warnings -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import pandas as pd import pytest @@ -92,18 +93,18 @@ def test_groupby_reduce(data, array, func: str) -> None: flox_kwargs: dict[str, Any] = {} with np.errstate(invalid="ignore", divide="ignore"): actual, *_ = groupby_reduce( - array, by, func=func, axis=axis, engine="numpy", **flox_kwargs, finalize_kwargs=kwargs + array, + by, + func=func, + axis=axis, + engine="numpy", + **flox_kwargs, + finalize_kwargs=kwargs, ) # numpy-groupies always does the calculation in float64 if ( - ( - "var" in func - or "std" in func - or "sum" in func - or "mean" in func - or "quantile" in func - ) + ("var" in func or "std" in func or "sum" in func or "mean" in func or "quantile" in func) and array.dtype.kind == "f" and array.dtype.itemsize != 8 ): @@ -195,8 +196,18 @@ def reverse(arr): def test_first_last(data, array: dask.array.Array, func: str) -> None: by = data.draw(by_arrays(shape=(array.shape[-1],))) - INVERSES = {"first": "last", "last": "first", "nanfirst": "nanlast", "nanlast": "nanfirst"} - MATES = {"first": "nanfirst", "last": "nanlast", "nanfirst": "first", "nanlast": "last"} + INVERSES = { + "first": "last", + "last": "first", + "nanfirst": "nanlast", + "nanlast": "nanfirst", + } + MATES = { + "first": "nanfirst", + "last": "nanlast", + "nanfirst": "first", + "nanlast": "last", + } inverse = INVERSES[func] mate = MATES[func] diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 11b2e23cb..2592fa07c 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -49,7 +49,9 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex): labels2 = np.array([1, 2, 2, 1]) da = xr.DataArray( - arr, dims=("x", "y"), coords={"labels2": ("x", labels2), "labels": ("y", labels)} + arr, + dims=("x", "y"), + coords={"labels2": ("x", labels2), "labels": ("y", labels)}, ).expand_dims(z=4) expected = da.groupby("labels").sum(skipna=skipna, min_count=min_count) @@ -98,7 +100,9 @@ def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_ labels2 = np.array([1, 2, 2, 1]) da = xr.DataArray( - arr, dims=("x", "y"), coords={"labels2": ("x", labels2), "labels": ("y", labels)} + arr, + dims=("x", "y"), + coords={"labels2": ("x", labels2), "labels": ("y", labels)}, ).expand_dims(z=4) if chunk: @@ -177,9 +181,7 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine_n (None, (None, None), [[1, 2], [1, 2]]), ) def 
test_validate_expected_groups(expected_groups): - da = xr.DataArray( - [1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])} - ) + da = xr.DataArray([1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])}) with pytest.raises(ValueError): xarray_reduce( da.chunk({"x": 1}), @@ -196,12 +198,13 @@ def test_xarray_reduce_single_grouper(engine_no_numba): engine = engine_no_numba # DataArray ds = xr.Dataset( - {"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))}, - coords={ - "time": xr.date_range( - "1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap" + { + "Tair": ( + ("time", "x", "y"), + dask.array.ones((36, 205, 275), chunks=(9, -1, -1)), ) }, + coords={"time": xr.date_range("1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap")}, ) actual = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean", engine=engine) expected = ds.Tair.groupby("time.month").mean() @@ -380,12 +383,13 @@ def test_func_is_aggregation(): from flox.aggregations import mean ds = xr.Dataset( - {"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))}, - coords={ - "time": xr.date_range( - "1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap" + { + "Tair": ( + ("time", "x", "y"), + dask.array.ones((36, 205, 275), chunks=(9, -1, -1)), ) }, + coords={"time": xr.date_range("1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap")}, ) expected = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean") actual = xarray_reduce(ds.Tair, ds.time.dt.month, func=mean) @@ -520,7 +524,10 @@ def test_dtype(add_nan, chunk, dtype, dtype_out, engine_no_numba): data, dims=("x", "t"), coords={ - "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])) + "labels": ( + "t", + np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]), + ) }, name="arr", ) @@ -642,7 +649,11 @@ def test_fill_value_xarray_binning(): def test_groupby_2d_dataset(): d = { "coords": { - "bit_index": {"dims": ("bit_index",), "attrs": {"name": "bit_index"}, "data": [0, 1]}, + "bit_index": { + "dims": ("bit_index",), + "attrs": {"name": "bit_index"}, + "data": [0, 1], + }, "index": {"dims": ("index",), "data": [0, 6, 8, 10, 14]}, "clifford": {"dims": ("index",), "attrs": {}, "data": [1, 1, 4, 10, 4]}, }, @@ -664,18 +675,14 @@ def test_groupby_2d_dataset(): expected = ds.groupby("clifford").mean() with xr.set_options(use_flox=True): actual = ds.groupby("clifford").mean() - assert ( - expected.counts.dims == actual.counts.dims - ) # https://github.com/pydata/xarray/issues/8292 + assert expected.counts.dims == actual.counts.dims # https://github.com/pydata/xarray/issues/8292 xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False)) def test_resampling_missing_groups(chunk): # Regression test for https://github.com/pydata/xarray/issues/8592 - time_coords = pd.to_datetime( - ["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"] - ) + time_coords = pd.to_datetime(["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"]) latitude_coords = [0.0] longitude_coords = [0.0] @@ -684,7 +691,11 @@ def test_resampling_missing_groups(chunk): da = xr.DataArray( data, - coords={"time": time_coords, "latitude": latitude_coords, "longitude": longitude_coords}, + coords={ + "time": time_coords, + "latitude": latitude_coords, + "longitude": longitude_coords, + }, 
dims=["time", "latitude", "longitude"], ) if chunk: From 7421cb152ad5fc4336bb1570165f889d18e0ea9b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 8 Sep 2024 00:05:46 -0600 Subject: [PATCH 2/2] Preserve dtype better when specified. (#389) * Preserve dtype better when specified. * Add one more test * tweak test * more test * [revert] test with Xarray PR branch * tweak * show versions * Drop python 3.9, use ruff * switch to Ruff * fix mypy * remove toctrees * fix * one more --- .github/workflows/ci.yaml | 4 +++- ci/environment.yml | 3 ++- ci/no-dask.yml | 3 ++- flox/aggregations.py | 13 ++++++++----- flox/xrdtypes.py | 9 +++++++-- tests/strategies.py | 4 ++-- tests/test_core.py | 15 ++++++++++++++- tests/test_properties.py | 24 +++++++++++++++++++++++- tests/test_xarray.py | 25 ++++++++++++++++++++++++- 9 files changed, 85 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3f1416b2a..bbee15060 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -70,6 +70,7 @@ jobs: - name: Run Tests id: status run: | + python -c "import xarray; xarray.show_versions()" pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci - name: Upload code coverage to Codecov uses: codecov/codecov-action@v4.5.0 @@ -98,7 +99,7 @@ jobs: steps: - uses: actions/checkout@v4 with: - repository: "pydata/xarray" + repository: "dcherian/xarray" fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set up conda environment uses: mamba-org/setup-micromamba@v1 @@ -112,6 +113,7 @@ jobs: pint>=0.22 - name: Install xarray run: | + git checkout flox-preserve-dtype python -m pip install --no-deps . - name: Install upstream flox run: | diff --git a/ci/environment.yml b/ci/environment.yml index 82995d079..dac6880ac 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -19,7 +19,6 @@ dependencies: - pytest-pretty - pytest-xdist - syrupy - - xarray - pre-commit - numpy_groupies>=0.9.19 - pooch @@ -27,3 +26,5 @@ dependencies: - numba - numbagg>=0.3 - hypothesis + - pip: + - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype diff --git a/ci/no-dask.yml b/ci/no-dask.yml index 1f05c63a9..fb2bac92d 100644 --- a/ci/no-dask.yml +++ b/ci/no-dask.yml @@ -14,7 +14,6 @@ dependencies: - pytest-pretty - pytest-xdist - syrupy - - xarray - numpydoc - pre-commit - numpy_groupies>=0.9.19 @@ -22,3 +21,5 @@ dependencies: - toolz - numba - numbagg>=0.3 + - pip: + - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype diff --git a/flox/aggregations.py b/flox/aggregations.py index 4e0312198..0906c8ccb 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -549,12 +549,15 @@ def quantile_new_dims_func(q) -> tuple[Dim]: return (Dim(name="quantile", values=q),) +# if the input contains integers or floats smaller than float64, +# the output data-type is float64. Otherwise, the output data-type is the same as that +# of the input. 
quantile = Aggregation( name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, - final_dtype=np.floating, + final_dtype=np.float64, new_dims_func=quantile_new_dims_func, ) nanquantile = Aggregation( @@ -562,7 +565,7 @@ def quantile_new_dims_func(q) -> tuple[Dim]: fill_value=dtypes.NA, chunk=None, combine=None, - final_dtype=np.floating, + final_dtype=np.float64, new_dims_func=quantile_new_dims_func, ) mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) @@ -801,9 +804,9 @@ def _initialize_aggregation( dtype_: np.dtype | None = ( np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype ) - final_dtype = dtypes._normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) - if not agg.preserves_dtype: - final_dtype = dtypes._maybe_promote_int(final_dtype) + final_dtype = dtypes._normalize_dtype( + dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value + ) agg.dtype = { "user": dtype, # Save to automatically choose an engine "final": final_dtype, diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index 3fd0f4fec..34d0d2a52 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -150,9 +150,14 @@ def is_datetime_like(dtype): return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) -def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype: +def _normalize_dtype( + dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None +) -> np.dtype: if dtype is None: - dtype = array_dtype + if not preserves_dtype: + dtype = _maybe_promote_int(array_dtype) + else: + dtype = array_dtype if dtype is np.floating: # mean, std, var always result in floating # but we preserve the array's dtype if it is floating diff --git a/tests/strategies.py b/tests/strategies.py index 6f28db32d..b1dc7ce3f 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -27,7 +27,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: # TODO: stop excluding everything but U -array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU") +array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU") by_dtype_st = supported_dtypes() NON_NUMPY_FUNCS = [ @@ -43,7 +43,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]) numeric_arrays = npst.arrays( - elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st + elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes ) all_arrays = npst.arrays( elements={"allow_subnormal": False}, diff --git a/tests/test_core.py b/tests/test_core.py index 2c33ebcfc..cef9ad8a1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -81,7 +81,7 @@ def _get_array_func(func: str) -> Callable: def npfunc(x, **kwargs): x = np.asarray(x) - return (~np.isnan(x)).sum() + return (~xrutils.isnull(x)).sum(**kwargs) elif func in ["nanfirst", "nanlast"]: npfunc = getattr(xrutils, func) @@ -1984,3 +1984,16 @@ def test_blockwise_nans(): ) assert_equal(expected_groups, actual_groups) assert_equal(expected, actual) + + +@pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"]) +@pytest.mark.parametrize("engine", ["flox", "numpy"]) +def test_agg_dtypes(func, engine): + # regression test for GH388 + counts = np.array([0, 2, 1, 0, 1]) + group = np.array([1, 1, 1, 2, 2]) + actual, _ = groupby_reduce( + counts, group, 
expected_groups=(np.array([1, 2]),), func=func, dtype="uint8", engine=engine + ) + expected = _get_array_func(func)(counts, dtype="uint8") + assert actual.dtype == np.uint8 == expected.dtype diff --git a/tests/test_properties.py b/tests/test_properties.py index 584314cea..0437ef253 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -20,7 +20,7 @@ from flox.xrutils import notnull from . import assert_equal -from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays +from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays from .strategies import chunks as chunks_strategy dask.config.set(scheduler="sync") @@ -244,3 +244,25 @@ def test_first_last_useless(data, func): actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy") expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype) assert_equal(actual, expected) + + +@given( + func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]), + engine=st.sampled_from(["numpy", "flox"]), + array_dtype=st.none() | array_dtypes, + dtype=st.none() | array_dtypes, +) +def test_agg_dtype_specified(func, array_dtype, dtype, engine): + # regression test for GH388 + counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype) + group = np.array([1, 1, 1, 2, 2]) + actual, _ = groupby_reduce( + counts, + group, + expected_groups=(np.array([1, 2]),), + func=func, + dtype=dtype, + engine=engine, + ) + expected = getattr(np, func)(counts, keepdims=True, dtype=dtype) + assert actual.dtype == expected.dtype diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 2592fa07c..9423eb11e 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -24,7 +24,7 @@ # test against legacy xarray implementation # avoid some compilation overhead -xr.set_options(use_flox=False, use_numbagg=False) +xr.set_options(use_flox=False, use_numbagg=False, use_bottleneck=False) tolerance64 = {"rtol": 1e-15, "atol": 1e-18} np.random.seed(123) @@ -760,3 +760,26 @@ def test_direct_reduction(func): with xr.set_options(use_flox=False): expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs) xr.testing.assert_identical(expected, actual) + + +@pytest.mark.parametrize("reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"]) +def test_groupby_preserve_dtype(reduction): + # all groups are present, we should follow numpy exactly + ds = xr.Dataset( + { + "test": ( + ["x", "y"], + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"), + ) + }, + coords={"idx": ("x", [1, 2, 1])}, + ) + + kwargs = {"engine": "numpy"} + if "nan" in reduction: + kwargs["skipna"] = True + with xr.set_options(use_flox=True): + actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(**kwargs).test.dtype + expected = getattr(np, reduction)(ds.test.data, axis=0).dtype + + assert actual == expected
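
Taken together, these two patches mean an explicitly requested dtype is now honoured end-to-end by groupby_reduce, while the default (no dtype given) still follows NumPy's integer-promotion rules via _normalize_dtype in flox/xrdtypes.py. A minimal usage sketch of the new behaviour, adapted from the test_agg_dtypes test added above (the specific values and keyword choices are illustrative, not part of the patch):

import numpy as np
import flox

counts = np.array([0, 2, 1, 0, 1])
group = np.array([1, 1, 1, 2, 2])

# With PATCH 2/2 applied, the requested dtype is preserved in the result
# instead of being promoted to a wider integer type.
result, groups = flox.groupby_reduce(
    counts,
    group,
    expected_groups=(np.array([1, 2]),),
    func="sum",
    dtype="uint8",
    engine="numpy",
)
assert result.dtype == np.uint8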