diff --git a/python/cudf/cudf/tests/test_reshape.py b/python/cudf/cudf/tests/test_reshape.py index df03104eda4..181bff8512a 100644 --- a/python/cudf/cudf/tests/test_reshape.py +++ b/python/cudf/cudf/tests/test_reshape.py @@ -596,6 +596,42 @@ def test_pivot_table_simple(data, aggfunc, fill_value): assert_eq(expected, actual, check_dtype=False) +@pytest.mark.parametrize( + "data", + [ + { + "A": ["one", "one", "two", "three"] * 6, + "B": ["A", "B", "C"] * 8, + "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4, + "D": np.random.randn(24), + "E": np.random.randn(24), + } + ], +) +@pytest.mark.parametrize( + "aggfunc", ["mean", "count", {"D": "sum", "E": "count"}] +) +@pytest.mark.parametrize("fill_value", [0]) +def test_dataframe_pivot_table_simple(data, aggfunc, fill_value): + pdf = pd.DataFrame(data) + expected = pdf.pivot_table( + values=["D", "E"], + index=["A", "B"], + columns=["C"], + aggfunc=aggfunc, + fill_value=fill_value, + ) + cdf = cudf.DataFrame(data) + actual = cdf.pivot_table( + values=["D", "E"], + index=["A", "B"], + columns=["C"], + aggfunc=aggfunc, + fill_value=fill_value, + ) + assert_eq(expected, actual, check_dtype=False) + + def test_crosstab_simple(): a = np.array( [