Skip to content

Commit

Permalink
Fix nanmax, nanmin bug (xarray-contrib#411)
Browse files Browse the repository at this point in the history
* Add numpy vs dask property test

* Fix nanmin, nanmax bug
  • Loading branch information
dcherian authored Jan 8, 2025
1 parent 0c4b19f commit 9bd682c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
16 changes: 14 additions & 2 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,17 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
"nanmin",
chunk="nanmin",
combine="nanmin",
fill_value=dtypes.NA,
fill_value=dtypes.INF,
final_fill_value=dtypes.NA,
preserves_dtype=True,
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
nanmax = Aggregation(
"nanmax",
chunk="nanmax",
combine="nanmax",
fill_value=dtypes.NA,
fill_value=dtypes.NINF,
final_fill_value=dtypes.NA,
preserves_dtype=True,
)

Expand Down Expand Up @@ -845,6 +847,16 @@ def _initialize_aggregation(
# absent in one block, but present in another block
# We set it for numpy to get nansum, nanprod tests to pass
# where the identity element is 0, 1
# Also needed for nanmin, nanmax where intermediate fill_value is +-np.inf,
# but final_fill_value is dtypes.NA
if (
# TODO: this is a total hack, setting a default fill_value
# even though numpy doesn't define identity for nanmin, nanmax
agg.name in ["nanmin", "nanmax"] and min_count == 0
):
min_count = 1
agg.fill_value["user"] = agg.fill_value["user"] or agg.fill_value[agg.name]

if min_count > 0:
agg.min_count = min_count
agg.numpy += ("nanlen",)
Expand Down
8 changes: 4 additions & 4 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,6 @@ def _finalize_results(
agg: Aggregation,
axis: T_Axes,
expected_groups: pd.Index | None,
fill_value: Any,
reindex: bool,
) -> FinalResultsDict:
"""Finalize results by
Expand All @@ -1142,6 +1141,7 @@ def _finalize_results(
else:
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)

fill_value = agg.fill_value["user"]
if min_count > 0:
count_mask = counts < min_count
if count_mask.any():
Expand Down Expand Up @@ -1183,7 +1183,7 @@ def _aggregate(
) -> FinalResultsDict:
"""Final aggregation step of tree reduction"""
results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
return _finalize_results(results, agg, axis, expected_groups, reindex)


def _expand_dims(results: IntermediateDict) -> IntermediateDict:
Expand Down Expand Up @@ -1449,7 +1449,7 @@ def _reduce_blockwise(
if _is_arg_reduction(agg):
results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]

result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex)
result = _finalize_results(results, agg, axis, expected_groups, reindex=reindex)
return result


Expand Down Expand Up @@ -1926,7 +1926,7 @@ def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
def _groupby_aggregate(a):
# Convert cubed dict to one that _finalize_results works with
results = {"groups": expected_groups, "intermediates": a.values()}
out = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
out = _finalize_results(results, agg, axis, expected_groups, reindex)
return out[agg.name]

# convert list of dtypes to a structured dtype for cubed
Expand Down
33 changes: 33 additions & 0 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,39 @@ def test_groupby_reduce(data, array, func: str) -> None:
assert_equal(expected, actual, tolerance)


@given(
data=st.data(),
array=chunked_arrays(arrays=numeric_arrays),
func=func_st,
)
def test_groupby_reduce_numpy_vs_dask(data, array, func: str) -> None:
numpy_array = array.compute()
# overflow behaviour differs between bincount and sum (for example)
assume(not_overflowing_array(numpy_array))
# TODO: fix var for complex numbers upstream
assume(not (("quantile" in func or "var" in func or "std" in func) and array.dtype.kind == "c"))
# # arg* with nans in array are weird
assume("arg" not in func and not np.any(np.isnan(numpy_array.ravel())))
if func in ["nanmedian", "nanquantile", "median", "quantile"]:
array = array.rechunk({-1: -1})

axis = -1
by = data.draw(by_arrays(shape=(array.shape[-1],)))
kwargs = {"q": 0.8} if "quantile" in func else {}
flox_kwargs: dict[str, Any] = {}

kwargs = dict(
func=func,
axis=axis,
engine="numpy",
**flox_kwargs,
finalize_kwargs=kwargs,
)
result_dask, *_ = groupby_reduce(array, by, **kwargs)
result_numpy, *_ = groupby_reduce(numpy_array, by, **kwargs)
assert_equal(result_numpy, result_dask)


@given(
data=st.data(),
array=chunked_arrays(arrays=numeric_arrays),
Expand Down

0 comments on commit 9bd682c

Please sign in to comment.