From b8fa61a654c82207169f00492cab5f87cd665368 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Tue, 16 Aug 2022 14:35:13 -0400 Subject: [PATCH 1/2] Test datetime with all engines - use numeric for count --- flox/xarray.py | 4 ++-- tests/test_xarray.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flox/xarray.py b/flox/xarray.py index 9302dc318..3e8f8a279 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -296,7 +296,7 @@ def wrapper(array, *by, func, skipna, **kwargs): if "nan" not in func and func not in ["all", "any", "count"]: func = f"nan{func}" - requires_numeric = func not in ["count", "any", "all"] + requires_numeric = func not in ["any", "all"] if requires_numeric: is_npdatetime = array.dtype.kind in "Mm" is_cftime = _contains_cftime_datetimes(array) @@ -311,7 +311,7 @@ def wrapper(array, *by, func, skipna, **kwargs): result, *groups = groupby_reduce(array, *by, func=func, **kwargs) - if requires_numeric: + if requires_numeric and func != 'count': if is_npdatetime: return result.astype(dtype) + offset elif is_cftime: diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 0a696b24a..90a2d50c4 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -420,7 +420,7 @@ def test_cache(): @pytest.mark.parametrize("use_cftime", [True, False]) @pytest.mark.parametrize("func", ["count", "mean"]) -def test_datetime_array_reduce(use_cftime, func): +def test_datetime_array_reduce(use_cftime, func, engine): time = xr.DataArray( xr.date_range("2009-01-01", "2012-12-31", use_cftime=use_cftime), @@ -428,7 +428,7 @@ def test_datetime_array_reduce(use_cftime, func): name="time", ) expected = getattr(time.resample(time="YS"), func)() - actual = resample_reduce(time.resample(time="YS"), func=func, engine="flox") + actual = resample_reduce(time.resample(time="YS"), func=func, engine=engine) assert_equal(expected, actual) From b28fccbdfa305ab1ab3ee22521649a950dca2ae3 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Tue, 16 Aug 
2022 15:09:54 -0400 Subject: [PATCH 2/2] keep engine flox for dtype O and func count : better performance --- flox/xarray.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/flox/xarray.py b/flox/xarray.py index 3e8f8a279..7234f8826 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -296,7 +296,10 @@ def wrapper(array, *by, func, skipna, **kwargs): if "nan" not in func and func not in ["all", "any", "count"]: func = f"nan{func}" - requires_numeric = func not in ["any", "all"] + # Flox's count works with non-numeric and it's faster than converting. + requires_numeric = func not in ["count", "any", "all"] or ( + func == "count" and engine != "flox" + ) if requires_numeric: is_npdatetime = array.dtype.kind in "Mm" is_cftime = _contains_cftime_datetimes(array) @@ -311,7 +314,8 @@ def wrapper(array, *by, func, skipna, **kwargs): result, *groups = groupby_reduce(array, *by, func=func, **kwargs) - if requires_numeric and func != 'count': + # Output of count has an int dtype. + if requires_numeric and func != "count": if is_npdatetime: return result.astype(dtype) + offset elif is_cftime: