Skip to content

Commit

Permalink
Cache JIT GroupBy.apply functions (#12802)
Browse files Browse the repository at this point in the history
This PR sends incoming UDFs that go through the `engine='jit'` codepath through the main UDF cache. This should avoid recompiling if a user reuses the same UDF on different input data, so long as the types of that data are the same.

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - Bradley Dice (https://github.com/bdice)

URL: #12802
  • Loading branch information
brandon-b-miller authored Mar 24, 2023
1 parent fb96fc8 commit a0473cf
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 4 deletions.
15 changes: 12 additions & 3 deletions python/cudf/cudf/core/udf/groupby_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
groupby_apply_kernel_template,
)
from cudf.core.udf.utils import (
_generate_cache_key,
_get_extensionty_size,
_get_kernel,
_get_udf_return_type,
_supported_cols_from_frame,
_supported_dtypes_from_frame,
precompiled,
)
from cudf.utils.utils import _cudf_nvtx_annotate

Expand Down Expand Up @@ -147,12 +149,19 @@ def jit_groupby_apply(offsets, grouped_values, function, *args):
offsets = cp.asarray(offsets)
ngroups = len(offsets) - 1

kernel, return_type = _get_groupby_apply_kernel(
grouped_values, function, args
cache_key = _generate_cache_key(
grouped_values, function, suffix="__GROUPBY_APPLY_UDF"
)
return_type = numpy_support.as_dtype(return_type)

if cache_key not in precompiled:
precompiled[cache_key] = _get_groupby_apply_kernel(
grouped_values, function, args
)
kernel, return_type = precompiled[cache_key]

return_type = numpy_support.as_dtype(return_type)
output = cudf.core.column.column_empty(ngroups, dtype=return_type)

launch_args = [
offsets,
output,
Expand Down
3 changes: 2 additions & 1 deletion python/cudf/cudf/core/udf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _mask_get(mask, pos):
return (mask[pos // MASK_BITSIZE] >> (pos % MASK_BITSIZE)) & 1


def _generate_cache_key(frame, func: Callable):
def _generate_cache_key(frame, func: Callable, suffix="__APPLY_UDF"):
"""Create a cache key that uniquely identifies a compilation.
A new compilation is needed any time any of the following things change:
Expand All @@ -259,6 +259,7 @@ def _generate_cache_key(frame, func: Callable):
),
*(col.mask is None for col in frame._data.values()),
*frame._data.keys(),
suffix,
)


Expand Down
37 changes: 37 additions & 0 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cudf import DataFrame, Series
from cudf.core._compat import PANDAS_GE_150, PANDAS_LT_140
from cudf.core.udf.groupby_typing import SUPPORTED_GROUPBY_NUMPY_TYPES
from cudf.core.udf.utils import precompiled
from cudf.testing._utils import (
DATETIME_TYPES,
SIGNED_TYPES,
Expand Down Expand Up @@ -534,6 +535,42 @@ def diverging_block(grp_df):
run_groupby_apply_jit_test(df, diverging_block, ["a"])


def test_groupby_apply_caching():
# Make sure similar functions that differ
# by simple things like constants actually
# recompile

# begin with a clear cache
precompiled.clear()
assert precompiled.currsize == 0

data = cudf.DataFrame({"a": [1, 1, 1, 2, 2, 2], "b": [1, 2, 3, 4, 5, 6]})

def f(group):
return group["b"].mean() * 2

# a single run should result in a cache size of 1
run_groupby_apply_jit_test(data, f, ["a"])
assert precompiled.currsize == 1

# a second run with f should not increase the count
run_groupby_apply_jit_test(data, f, ["a"])
assert precompiled.currsize == 1

# changing a constant value inside the UDF should miss
def f(group):
return group["b"].mean() * 3

run_groupby_apply_jit_test(data, f, ["a"])
assert precompiled.currsize == 2

# changing the dtypes of the columns should miss
data["b"] = data["b"].astype("float64")
run_groupby_apply_jit_test(data, f, ["a"])

assert precompiled.currsize == 3


@pytest.mark.parametrize("nelem", [2, 3, 100, 500, 1000])
@pytest.mark.parametrize(
"func",
Expand Down

0 comments on commit a0473cf

Please sign in to comment.