Skip to content

Commit

Permalink
Add DataFrame.pivot_table. (#12015)
Browse files Browse the repository at this point in the history
This PR adds the method `DataFrame.pivot_table` to enhance pandas API compatibility. It uses the exact same arguments as `cudf.pivot_table` but automatically supplies the first argument (a DataFrame).

Related: #11314

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Ashwin Srinath (https://github.com/shwina)

URL: #12015
  • Loading branch information
bdice authored Oct 28, 2022
1 parent 69fac8a commit 1017045
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/cudf/source/api_docs/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ Reshaping, sorting, transposing
DataFrame.interleave_columns
DataFrame.partition_by_hash
DataFrame.pivot
DataFrame.pivot_table
DataFrame.scatter_by_map
DataFrame.sort_values
DataFrame.sort_index
Expand Down
30 changes: 29 additions & 1 deletion python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6407,11 +6407,39 @@ def append(
@_cudf_nvtx_annotate
@copy_docstring(reshape.pivot)
def pivot(self, index, columns, values=None):

return cudf.core.reshape.pivot(
self, index=index, columns=columns, values=values
)

@_cudf_nvtx_annotate
@copy_docstring(reshape.pivot_table)
def pivot_table(
self,
values=None,
index=None,
columns=None,
aggfunc="mean",
fill_value=None,
margins=False,
dropna=None,
margins_name="All",
observed=False,
sort=True,
):
return cudf.core.reshape.pivot_table(
self,
values=values,
index=index,
columns=columns,
aggfunc=aggfunc,
fill_value=fill_value,
margins=margins,
dropna=dropna,
margins_name=margins_name,
observed=observed,
sort=sort,
)

@_cudf_nvtx_annotate
@copy_docstring(reshape.unstack)
def unstack(self, level=-1, fill_value=None):
Expand Down
36 changes: 36 additions & 0 deletions python/cudf/cudf/tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,42 @@ def test_pivot_table_simple(data, aggfunc, fill_value):
assert_eq(expected, actual, check_dtype=False)


@pytest.mark.parametrize(
"data",
[
{
"A": ["one", "one", "two", "three"] * 6,
"B": ["A", "B", "C"] * 8,
"C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4,
"D": np.random.randn(24),
"E": np.random.randn(24),
}
],
)
@pytest.mark.parametrize(
"aggfunc", ["mean", "count", {"D": "sum", "E": "count"}]
)
@pytest.mark.parametrize("fill_value", [0])
def test_dataframe_pivot_table_simple(data, aggfunc, fill_value):
pdf = pd.DataFrame(data)
expected = pdf.pivot_table(
values=["D", "E"],
index=["A", "B"],
columns=["C"],
aggfunc=aggfunc,
fill_value=fill_value,
)
cdf = cudf.DataFrame(data)
actual = cdf.pivot_table(
values=["D", "E"],
index=["A", "B"],
columns=["C"],
aggfunc=aggfunc,
fill_value=fill_value,
)
assert_eq(expected, actual, check_dtype=False)


def test_crosstab_simple():
a = np.array(
[
Expand Down

0 comments on commit 1017045

Please sign in to comment.