diff --git a/python/cudf/cudf/core/udf/groupby_utils.py b/python/cudf/cudf/core/udf/groupby_utils.py index ae09dd1d704..6e20ddbd08e 100644 --- a/python/cudf/cudf/core/udf/groupby_utils.py +++ b/python/cudf/cudf/core/udf/groupby_utils.py @@ -19,13 +19,12 @@ ) from cudf.core.udf.utils import ( Row, - _generate_cache_key, + _compile_or_get, _get_extensionty_size, _get_kernel, _get_udf_return_type, _supported_cols_from_frame, _supported_dtypes_from_frame, - precompiled, ) from cudf.utils.utils import _cudf_nvtx_annotate @@ -145,20 +144,18 @@ def jit_groupby_apply(offsets, grouped_values, function, *args): function : callable The user-defined function to execute """ - offsets = cp.asarray(offsets) - ngroups = len(offsets) - 1 - cache_key = _generate_cache_key( - grouped_values, function, args, suffix="__GROUPBY_APPLY_UDF" + kernel, return_type = _compile_or_get( + grouped_values, + function, + args, + kernel_getter=_get_groupby_apply_kernel, + suffix="__GROUPBY_APPLY_UDF", ) - if cache_key not in precompiled: - precompiled[cache_key] = _get_groupby_apply_kernel( - grouped_values, function, args - ) - kernel, return_type = precompiled[cache_key] + offsets = cp.asarray(offsets) + ngroups = len(offsets) - 1 - return_type = numpy_support.as_dtype(return_type) output = cudf.core.column.column_empty(ngroups, dtype=return_type) launch_args = [ diff --git a/python/cudf/cudf/core/udf/utils.py b/python/cudf/cudf/core/udf/utils.py index d890b94127f..9d7df530ccc 100644 --- a/python/cudf/cudf/core/udf/utils.py +++ b/python/cudf/cudf/core/udf/utils.py @@ -283,7 +283,9 @@ def _generate_cache_key(frame, func: Callable, args, suffix="__APPLY_UDF"): @_cudf_nvtx_annotate -def _compile_or_get(frame, func, args, kernel_getter=None): +def _compile_or_get( + frame, func, args, kernel_getter=None, suffix="__APPLY_UDF" +): """ Return a compiled kernel in terms of MaskedTypes that launches a kernel equivalent of `f` for the dtypes of `df`. The kernel uses @@ -308,7 +310,7 @@ def _compile_or_get(frame, func, args, kernel_getter=None): raise TypeError("only scalar valued args are supported by apply") # check to see if we already compiled this function - cache_key = _generate_cache_key(frame, func, args) + cache_key = _generate_cache_key(frame, func, args, suffix=suffix) if precompiled.get(cache_key) is not None: kernel, masked_or_scalar = precompiled[cache_key] return kernel, masked_or_scalar