From a48fb545c11392e12d9bb86183242439726bbaed Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 30 Aug 2023 19:04:31 -0700 Subject: [PATCH] Match pandas.Series.quantile behavior --- python/cudf/cudf/core/column/numerical_base.py | 10 ++++++++++ python/cudf/cudf/core/dataframe.py | 13 ++++++++++--- python/cudf/cudf/core/series.py | 9 +++++++-- python/cudf/cudf/tests/test_quantiles.py | 15 +++++++++++++++ 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/core/column/numerical_base.py b/python/cudf/cudf/core/column/numerical_base.py index 08c2f7cc7b1..e59d56af9dc 100644 --- a/python/cudf/cudf/core/column/numerical_base.py +++ b/python/cudf/cudf/core/column/numerical_base.py @@ -115,6 +115,16 @@ def quantile( result = self._numeric_quantile(q, interpolation, exact) if return_scalar: scalar_result = result.element_indexing(0) + if interpolation in {"lower", "higher", "nearest"}: + try: + new_scalar = self.dtype.type(scalar_result) + scalar_result = ( + new_scalar + if new_scalar == scalar_result + else scalar_result + ) + except (TypeError, ValueError): + pass return ( cudf.utils.dtypes._get_nan_for_dtype(self.dtype) if scalar_result is NA diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 3f89f78d278..191e5a81c5a 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -5487,16 +5487,23 @@ def quantile( numeric_only : bool, default True If False, the quantile of datetime and timedelta data will be computed as well. - interpolation : {`linear`, `lower`, `higher`, `midpoint`, `nearest`} + interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} This parameter specifies the interpolation method to use, when the desired quantile lies between two data points i and j. - Default is ``linear`` for ``method="single"``, and ``nearest`` + Default is ``'linear'`` for ``method="single"``, and ``'nearest'`` for ``method="table"``. + + * linear: `i + (j - i) * fraction`, where `fraction` is the + fractional part of the index surrounded by `i` and `j`. + * lower: `i`. + * higher: `j`. + * nearest: `i` or `j` whichever is nearest. + * midpoint: (`i` + `j`) / 2. columns : list of str List of column names to include. exact : boolean Whether to use approximate or exact quantile algorithm. - method : {`single`, `table`}, default `single` + method : {'single', 'table'}, default `'single'` Whether to compute quantiles per-column ('single') or over all columns ('table'). When 'table', the only allowed interpolation methods are 'nearest', 'lower', and 'higher'. diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 30d584c2270..2fef741ac09 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -3132,8 +3132,13 @@ def quantile( interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} This optional parameter specifies the interpolation method to use, when the desired quantile lies between two data points i and j: - columns : list of str - List of column names to include. + + * linear: `i + (j - i) * fraction`, where `fraction` is the + fractional part of the index surrounded by `i` and `j`. + * lower: `i`. + * higher: `j`. + * nearest: `i` or `j` whichever is nearest. + * midpoint: (`i` + `j`) / 2. exact : boolean Whether to use approximate or exact quantile algorithm. quant_index : boolean diff --git a/python/cudf/cudf/tests/test_quantiles.py b/python/cudf/cudf/tests/test_quantiles.py index 53b06e64a91..8b126073a0f 100644 --- a/python/cudf/cudf/tests/test_quantiles.py +++ b/python/cudf/cudf/tests/test_quantiles.py @@ -75,3 +75,18 @@ def test_quantile_q_type(): ), ): gs.quantile(cudf.DataFrame()) + + +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "midpoint", "nearest"] +) +def test_quantile_type_int_float(interpolation): + data = [1, 3, 4] + psr = pd.Series(data) + gsr = cudf.Series(data) + + expected = psr.quantile(0.5, interpolation=interpolation) + actual = gsr.quantile(0.5, interpolation=interpolation) + + assert expected == actual + assert type(expected) == type(actual)