From 1017045f46c44d205d6294bbb95e7bade1415e9c Mon Sep 17 00:00:00 2001
From: Bradley Dice
Date: Fri, 28 Oct 2022 04:33:49 -0700
Subject: [PATCH] Add DataFrame.pivot_table. (#12015)

This PR adds the method `DataFrame.pivot_table` to enhance pandas API compatibility. It uses the exact same arguments as `cudf.pivot_table` but automatically supplies the first argument (a DataFrame).

Related: #11314

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Ashwin Srinath (https://github.com/shwina)

URL: https://github.com/rapidsai/cudf/pull/12015
---
 docs/cudf/source/api_docs/dataframe.rst |  1 +
 python/cudf/cudf/core/dataframe.py      | 30 ++++++++++++++++++++-
 python/cudf/cudf/tests/test_reshape.py  | 36 +++++++++++++++++++++++++
 3 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/docs/cudf/source/api_docs/dataframe.rst b/docs/cudf/source/api_docs/dataframe.rst
index bd868e85cc7..f5c9053ec92 100644
--- a/docs/cudf/source/api_docs/dataframe.rst
+++ b/docs/cudf/source/api_docs/dataframe.rst
@@ -210,6 +210,7 @@ Reshaping, sorting, transposing
    DataFrame.interleave_columns
    DataFrame.partition_by_hash
    DataFrame.pivot
+   DataFrame.pivot_table
    DataFrame.scatter_by_map
    DataFrame.sort_values
    DataFrame.sort_index
diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py
index a3dd82d060e..02c5542a88a 100644
--- a/python/cudf/cudf/core/dataframe.py
+++ b/python/cudf/cudf/core/dataframe.py
@@ -6407,11 +6407,39 @@ def append(
     @_cudf_nvtx_annotate
     @copy_docstring(reshape.pivot)
     def pivot(self, index, columns, values=None):
-
         return cudf.core.reshape.pivot(
             self, index=index, columns=columns, values=values
         )
 
+    @_cudf_nvtx_annotate
+    @copy_docstring(reshape.pivot_table)
+    def pivot_table(
+        self,
+        values=None,
+        index=None,
+        columns=None,
+        aggfunc="mean",
+        fill_value=None,
+        margins=False,
+        dropna=None,
+        margins_name="All",
+        observed=False,
+        sort=True,
+    ):
+        return cudf.core.reshape.pivot_table(
+            self,
+            values=values,
+            index=index,
+            columns=columns,
+            aggfunc=aggfunc,
+            fill_value=fill_value,
+            margins=margins,
+            dropna=dropna,
+            margins_name=margins_name,
+            observed=observed,
+            sort=sort,
+        )
+
     @_cudf_nvtx_annotate
     @copy_docstring(reshape.unstack)
     def unstack(self, level=-1, fill_value=None):
diff --git a/python/cudf/cudf/tests/test_reshape.py b/python/cudf/cudf/tests/test_reshape.py
index df03104eda4..181bff8512a 100644
--- a/python/cudf/cudf/tests/test_reshape.py
+++ b/python/cudf/cudf/tests/test_reshape.py
@@ -596,6 +596,42 @@ def test_pivot_table_simple(data, aggfunc, fill_value):
     assert_eq(expected, actual, check_dtype=False)
 
 
+@pytest.mark.parametrize(
+    "data",
+    [
+        {
+            "A": ["one", "one", "two", "three"] * 6,
+            "B": ["A", "B", "C"] * 8,
+            "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4,
+            "D": np.random.randn(24),
+            "E": np.random.randn(24),
+        }
+    ],
+)
+@pytest.mark.parametrize(
+    "aggfunc", ["mean", "count", {"D": "sum", "E": "count"}]
+)
+@pytest.mark.parametrize("fill_value", [0])
+def test_dataframe_pivot_table_simple(data, aggfunc, fill_value):
+    pdf = pd.DataFrame(data)
+    expected = pdf.pivot_table(
+        values=["D", "E"],
+        index=["A", "B"],
+        columns=["C"],
+        aggfunc=aggfunc,
+        fill_value=fill_value,
+    )
+    cdf = cudf.DataFrame(data)
+    actual = cdf.pivot_table(
+        values=["D", "E"],
+        index=["A", "B"],
+        columns=["C"],
+        aggfunc=aggfunc,
+        fill_value=fill_value,
+    )
+    assert_eq(expected, actual, check_dtype=False)
+
+
 def test_crosstab_simple():
     a = np.array(
         [