Skip to content

Commit

Permalink
Implement groupby apply with JIT (#11452)
Browse files Browse the repository at this point in the history
Experimental cuDF Groupby Apply JIT pipeline.

Authors:
  - Bobbi Winema Yogatama (https://github.com/bwyogatama)
  - Vyas Ramasubramani (https://github.com/vyasr)
  - https://github.com/brandon-b-miller
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Yunsong Wang (https://github.com/PointKernel)
  - Lawrence Mitchell (https://github.com/wence-)
  - Ashwin Srinath (https://github.com/shwina)
  - David Wendt (https://github.com/davidwendt)
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Robert Maynard (https://github.com/robertmaynard)

URL: #11452
  • Loading branch information
bwyogatama authored Jan 27, 2023
1 parent fb17ac7 commit 7695850
Show file tree
Hide file tree
Showing 17 changed files with 1,438 additions and 190 deletions.
3 changes: 2 additions & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#cpp code owners
cpp/ @rapidsai/cudf-cpp-codeowners
cpp/ @rapidsai/cudf-cpp-codeowners
python/cudf/udf_cpp/ @rapidsai/cudf-cpp-codeowners

#python code owners
python/ @rapidsai/cudf-python-codeowners
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ python/cudf/*/_cuda/*.cpp
python/cudf/*.ipynb
python/cudf/.ipynb_checkpoints
python/*/record.txt
python/cudf/cudf/core/udf/*.ptx
python/cudf_kafka/*/_lib/**/*.cpp
python/cudf_kafka/*/_lib/**/*.h
python/custreamz/*/_lib/**/*.cpp
Expand Down
3 changes: 3 additions & 0 deletions ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ sed_runner 's/'"cudf_version .*)"'/'"cudf_version ${NEXT_FULL_TAG})"'/g' python/
# Strings UDF update
sed_runner 's/'"strings_udf_version .*)"'/'"strings_udf_version ${NEXT_FULL_TAG})"'/g' python/strings_udf/CMakeLists.txt

# Groupby UDF update
sed_runner 's/'"VERSION ${CURRENT_SHORT_TAG}.*"'/'"VERSION ${NEXT_FULL_TAG}"'/g' python/cudf/udf_cpp/CMakeLists.txt

# cpp libcudf_kafka update
sed_runner 's/'"VERSION ${CURRENT_SHORT_TAG}.*"'/'"VERSION ${NEXT_FULL_TAG}"'/g' cpp/libcudf_kafka/CMakeLists.txt

Expand Down
7 changes: 5 additions & 2 deletions python/cudf/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand All @@ -17,6 +17,8 @@ cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR)
set(cudf_version 23.02.00)

include(../../fetch_rapids.cmake)
include(rapids-cuda)
rapids_cuda_init_architectures(cudf-python)

project(
cudf-python
Expand All @@ -25,7 +27,7 @@ project(
# language to be enabled here. The test project that is built in scikit-build to verify
# various linking options for the python library is hardcoded to build with C, so until
# that is fixed we need to keep C.
C CXX
C CXX CUDA
)

option(FIND_CUDF_CPP "Search for existing CUDF C++ installations before defaulting to local files"
Expand Down Expand Up @@ -117,6 +119,7 @@ endif()
rapids_cython_init()

add_subdirectory(cudf/_lib)
add_subdirectory(udf_cpp/groupby)

include(cmake/Modules/ProtobufHelpers.cmake)
codegen_protoc(cudf/utils/metadata/orc_column_statistics.proto)
Expand Down
11 changes: 9 additions & 2 deletions python/cudf/cudf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
# Copyright (c) 2018-2023, NVIDIA CORPORATION.

from cudf.utils.gpu_utils import validate_setup

Expand Down Expand Up @@ -88,7 +88,14 @@
pass
else:
# Patch Numba to support CUDA enhanced compatibility.
patch_numba_linker_if_needed()
# cuDF requires a stronger set of conditions than what is
# checked by patch_numba_linker_if_needed due to the PTX
# files needed for JIT Groupby Apply and string UDFs
from cudf.core.udf.groupby_utils import dev_func_ptx
from cudf.core.udf.utils import _setup_numba_linker

_setup_numba_linker(dev_func_ptx)

del patch_numba_linker_if_needed

cuda.set_memory_manager(rmm.RMMNumbaManager)
Expand Down
133 changes: 104 additions & 29 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.mixins import Reducible, Scannable
from cudf.core.multiindex import MultiIndex
from cudf.core.udf.groupby_utils import jit_groupby_apply
from cudf.utils.utils import GetAttrGetItemMixin, _cudf_nvtx_annotate


Expand Down Expand Up @@ -786,14 +787,83 @@ def pipe(self, func, *args, **kwargs):
"""
return cudf.core.common.pipe(self, func, *args, **kwargs)

def apply(self, function, *args):
def _jit_groupby_apply(
self, function, group_names, offsets, group_keys, grouped_values, *args
):
# Nulls are not yet supported
for colname in self.grouping.values._data.keys():
if self.obj._data[colname].has_nulls():
raise ValueError(
"Nulls not yet supported with groupby JIT engine"
)

chunk_results = jit_groupby_apply(
offsets, grouped_values, function, *args
)
result = cudf.Series._from_data(
{None: chunk_results}, index=group_names
)
result.index.names = self.grouping.names
result = result.reset_index()
result[None] = result.pop(0)
return result

def _iterative_groupby_apply(
self, function, group_names, offsets, group_keys, grouped_values, *args
):
ngroups = len(offsets) - 1
if ngroups > self._MAX_GROUPS_BEFORE_WARN:
warnings.warn(
f"GroupBy.apply() performance scales poorly with "
f"number of groups. Got {ngroups} groups. Some functions "
"may perform better by passing engine='jit'",
RuntimeWarning,
)

chunks = [
grouped_values[s:e] for s, e in zip(offsets[:-1], offsets[1:])
]
chunk_results = [function(chk, *args) for chk in chunks]
if not len(chunk_results):
return self.obj.head(0)

if cudf.api.types.is_scalar(chunk_results[0]):
result = cudf.Series._from_data(
{None: chunk_results}, index=group_names
)
result.index.names = self.grouping.names
elif isinstance(chunk_results[0], cudf.Series) and isinstance(
self.obj, cudf.DataFrame
):
result = cudf.concat(chunk_results, axis=1).T
result.index.names = self.grouping.names
else:
result = cudf.concat(chunk_results)
if self._group_keys:
index_data = group_keys._data.copy(deep=True)
index_data[None] = grouped_values.index._column
result.index = cudf.MultiIndex._from_data(index_data)
return result

def apply(self, function, *args, engine="cudf"):
"""Apply a python transformation function over the grouped chunk.
Parameters
----------
func : function
function : callable
The python transformation function that will be applied
on the grouped chunk.
args : tuple
Optional positional arguments to pass to the function.
engine: {'cudf', 'jit'}, default 'cudf'
Selects the GroupBy.apply implementation. Use `jit` to
select the numba JIT pipeline. Only certain operations are allowed
within the function when using this option: min, max, sum, mean, var,
std, idxmax, and idxmin and any arithmetic formula involving them are
allowed. Binary operations are not yet supported, so syntax like
`df['x'] * 2` is not yet allowed.
For more information, see the `cuDF guide to user defined functions
<https://docs.rapids.ai/api/cudf/stable/user_guide/guide-to-udfs.html>`__.
Examples
--------
Expand Down Expand Up @@ -850,40 +920,45 @@ def mult(df):
a b c
0 1 1 1
2 2 1 3
``engine='jit'`` may be used to accelerate certain functions,
initially those that contain reductions and arithmetic operations
between results of those reductions:
>>> import cudf
>>> df = cudf.DataFrame({'a':[1,1,2,2,3,3], 'b':[1,2,3,4,5,6]})
>>> df.groupby('a').apply(
... lambda group: group['b'].max() - group['b'].min(),
... engine='jit'
... )
a None
0 1 1
1 2 1
2 3 1
"""
if not callable(function):
raise TypeError(f"type {type(function)} is not callable")
group_names, offsets, group_keys, grouped_values = self._grouped()

ngroups = len(offsets) - 1
if ngroups > self._MAX_GROUPS_BEFORE_WARN:
warnings.warn(
f"GroupBy.apply() performance scales poorly with "
f"number of groups. Got {ngroups} groups."
if engine == "jit":
result = self._jit_groupby_apply(
function,
group_names,
offsets,
group_keys,
grouped_values,
*args,
)
elif engine == "cudf":
result = self._iterative_groupby_apply(
function,
group_names,
offsets,
group_keys,
grouped_values,
*args,
)

chunks = [
grouped_values[s:e] for s, e in zip(offsets[:-1], offsets[1:])
]
chunk_results = [function(chk, *args) for chk in chunks]
if not len(chunk_results):
return self.obj.head(0)

if cudf.api.types.is_scalar(chunk_results[0]):
result = cudf.Series(chunk_results, index=group_names)
result.index.names = self.grouping.names
else:
if isinstance(chunk_results[0], cudf.Series) and isinstance(
self.obj, cudf.DataFrame
):
result = cudf.concat(chunk_results, axis=1).T
result.index.names = self.grouping.names
else:
result = cudf.concat(chunk_results)
if self._group_keys:
index_data = group_keys._data.copy(deep=True)
index_data[None] = grouped_values.index._column
result.index = cudf.MultiIndex._from_data(index_data)
raise ValueError(f"Unsupported engine '{engine}'")

if self._sort:
result = result.sort_index()
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/udf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.

from functools import lru_cache

Expand All @@ -9,7 +9,7 @@
from cudf.core.udf import api, row_function, utils
from cudf.utils.dtypes import STRING_TYPES

from . import masked_lowering, masked_typing
from . import groupby_lowering, groupby_typing, masked_lowering, masked_typing

_units = ["ns", "ms", "us", "s"]
_datetime_cases = {types.NPDatetime(u) for u in _units}
Expand Down
Loading

0 comments on commit 7695850

Please sign in to comment.