From 89557bb0efad2d32098ba86b78e4f4706e7fe88f Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Wed, 13 Sep 2023 19:22:46 -0500 Subject: [PATCH] Allow `numeric_only=True` for reduction operations on numeric types (#14111) Fixes: #14090 This PR allows passing `numeric_only=True` for reduction operation on numerical columns. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/14111 --- python/cudf/cudf/core/single_column_frame.py | 6 ++- python/cudf/cudf/tests/test_stats.py | 44 ++++++++++---------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/python/cudf/cudf/core/single_column_frame.py b/python/cudf/cudf/core/single_column_frame.py index 7c019f0722c..6a56ab8f3a5 100644 --- a/python/cudf/cudf/core/single_column_frame.py +++ b/python/cudf/cudf/core/single_column_frame.py @@ -49,9 +49,11 @@ def _reduce( if level is not None: raise NotImplementedError("level parameter is not implemented yet") - if numeric_only: + if numeric_only and not isinstance( + self._column, cudf.core.column.numerical_base.NumericalBaseColumn + ): raise NotImplementedError( - f"Series.{op} does not implement numeric_only" + f"Series.{op} does not implement numeric_only." ) try: return getattr(self._column, op)(**kwargs) diff --git a/python/cudf/cudf/tests/test_stats.py b/python/cudf/cudf/tests/test_stats.py index 6478fbaad95..463cdb8a7f4 100644 --- a/python/cudf/cudf/tests/test_stats.py +++ b/python/cudf/cudf/tests/test_stats.py @@ -247,30 +247,37 @@ def test_misc_quantiles(data, q): ], ) @pytest.mark.parametrize("null_flag", [False, True]) -def test_kurtosis_series(data, null_flag): +@pytest.mark.parametrize("numeric_only", [False, True]) +def test_kurtosis_series(data, null_flag, numeric_only): pdata = data.to_pandas() if null_flag and len(data) > 2: data.iloc[[0, 2]] = None pdata.iloc[[0, 2]] = None - got = data.kurtosis() + got = data.kurtosis(numeric_only=numeric_only) got = got if np.isscalar(got) else got.to_numpy() - expected = pdata.kurtosis() + expected = pdata.kurtosis(numeric_only=numeric_only) np.testing.assert_array_almost_equal(got, expected) - got = data.kurt() + got = data.kurt(numeric_only=numeric_only) got = got if np.isscalar(got) else got.to_numpy() - expected = pdata.kurt() + expected = pdata.kurt(numeric_only=numeric_only) np.testing.assert_array_almost_equal(got, expected) - got = data.kurt(numeric_only=False) - got = got if np.isscalar(got) else got.to_numpy() - expected = pdata.kurt(numeric_only=False) - np.testing.assert_array_almost_equal(got, expected) - with pytest.raises(NotImplementedError): - data.kurt(numeric_only=True) +@pytest.mark.parametrize("op", ["skew", "kurt"]) +def test_kurt_skew_error(op): + gs = cudf.Series(["ab", "cd"]) + ps = gs.to_pandas() + + with pytest.raises(FutureWarning): + assert_exceptions_equal( + getattr(gs, op), + getattr(ps, op), + lfunc_args_and_kwargs=([], {"numeric_only": True}), + rfunc_args_and_kwargs=([], {"numeric_only": True}), + ) @pytest.mark.parametrize( @@ -290,26 +297,19 @@ def test_kurtosis_series(data, null_flag): ], ) @pytest.mark.parametrize("null_flag", [False, True]) -def test_skew_series(data, null_flag): +@pytest.mark.parametrize("numeric_only", [False, True]) +def test_skew_series(data, null_flag, numeric_only): pdata = data.to_pandas() if null_flag and len(data) > 2: data.iloc[[0, 2]] = None pdata.iloc[[0, 2]] = None - got = data.skew() - expected = pdata.skew() + got = data.skew(numeric_only=numeric_only) + expected = pdata.skew(numeric_only=numeric_only) got = got if np.isscalar(got) else got.to_numpy() np.testing.assert_array_almost_equal(got, expected) - got = data.skew(numeric_only=False) - expected = pdata.skew(numeric_only=False) - got = got if np.isscalar(got) else got.to_numpy() - np.testing.assert_array_almost_equal(got, expected) - - with pytest.raises(NotImplementedError): - data.skew(numeric_only=True) - @pytest.mark.parametrize("dtype", params_dtypes) @pytest.mark.parametrize("num_na", [0, 1, 50, 99, 100])