From 769585048375688bbcdfa040385220439a76ade5 Mon Sep 17 00:00:00 2001 From: Bobbi Winema Yogatama <34973829+bwyogatama@users.noreply.github.com> Date: Fri, 27 Jan 2023 17:47:28 -0600 Subject: [PATCH] Implement groupby apply with JIT (#11452) Experimental cuDF Groupby Apply JIT pipeline. Authors: - Bobbi Winema Yogatama (https://github.com/bwyogatama) - Vyas Ramasubramani (https://github.com/vyasr) - https://github.com/brandon-b-miller - Yunsong Wang (https://github.com/PointKernel) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) - Yunsong Wang (https://github.com/PointKernel) - Lawrence Mitchell (https://github.com/wence-) - Ashwin Srinath (https://github.com/shwina) - David Wendt (https://github.com/davidwendt) - Vyas Ramasubramani (https://github.com/vyasr) - Robert Maynard (https://github.com/robertmaynard) URL: https://github.com/rapidsai/cudf/pull/11452 --- .github/CODEOWNERS | 3 +- .gitignore | 1 + ci/release/update-version.sh | 3 + python/cudf/CMakeLists.txt | 7 +- python/cudf/cudf/__init__.py | 11 +- python/cudf/cudf/core/groupby/groupby.py | 133 ++++++-- python/cudf/cudf/core/udf/__init__.py | 4 +- python/cudf/cudf/core/udf/groupby_lowering.py | 157 +++++++++ python/cudf/cudf/core/udf/groupby_typing.py | 213 ++++++++++++ python/cudf/cudf/core/udf/groupby_utils.py | 200 +++++++++++ python/cudf/cudf/core/udf/templates.py | 28 +- python/cudf/cudf/core/udf/utils.py | 178 +++++++++- python/cudf/cudf/tests/test_groupby.py | 143 ++++++++ python/cudf/udf_cpp/groupby/CMakeLists.txt | 79 +++++ python/cudf/udf_cpp/groupby/function.cu | 323 ++++++++++++++++++ python/strings_udf/strings_udf/__init__.py | 128 +------ python/strings_udf/strings_udf/_typing.py | 17 +- 17 files changed, 1438 insertions(+), 190 deletions(-) create mode 100644 python/cudf/cudf/core/udf/groupby_lowering.py create mode 100644 python/cudf/cudf/core/udf/groupby_typing.py create mode 100644 python/cudf/cudf/core/udf/groupby_utils.py create mode 100644 python/cudf/udf_cpp/groupby/CMakeLists.txt create mode 100644 python/cudf/udf_cpp/groupby/function.cu diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 4b3ed8d3e38..9578d32d13d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,5 +1,6 @@ #cpp code owners -cpp/ @rapidsai/cudf-cpp-codeowners +cpp/ @rapidsai/cudf-cpp-codeowners +python/cudf/udf_cpp/ @rapidsai/cudf-cpp-codeowners #python code owners python/ @rapidsai/cudf-python-codeowners diff --git a/.gitignore b/.gitignore index 1867e65b7be..0f81dcb6f2b 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ python/cudf/*/_cuda/*.cpp python/cudf/*.ipynb python/cudf/.ipynb_checkpoints python/*/record.txt +python/cudf/cudf/core/udf/*.ptx python/cudf_kafka/*/_lib/**/*.cpp python/cudf_kafka/*/_lib/**/*.h python/custreamz/*/_lib/**/*.cpp diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 335d907b7b9..c59b6bc4f1d 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -43,6 +43,9 @@ sed_runner 's/'"cudf_version .*)"'/'"cudf_version ${NEXT_FULL_TAG})"'/g' python/ # Strings UDF update sed_runner 's/'"strings_udf_version .*)"'/'"strings_udf_version ${NEXT_FULL_TAG})"'/g' python/strings_udf/CMakeLists.txt +# Groupby UDF update +sed_runner 's/'"VERSION ${CURRENT_SHORT_TAG}.*"'/'"VERSION ${NEXT_FULL_TAG}"'/g' python/cudf/udf_cpp/CMakeLists.txt + # cpp libcudf_kafka update sed_runner 's/'"VERSION ${CURRENT_SHORT_TAG}.*"'/'"VERSION ${NEXT_FULL_TAG}"'/g' 
cpp/libcudf_kafka/CMakeLists.txt diff --git a/python/cudf/CMakeLists.txt b/python/cudf/CMakeLists.txt index 1c2d06aeb62..638606e27bc 100644 --- a/python/cudf/CMakeLists.txt +++ b/python/cudf/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -17,6 +17,8 @@ cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) set(cudf_version 23.02.00) include(../../fetch_rapids.cmake) +include(rapids-cuda) +rapids_cuda_init_architectures(cudf-python) project( cudf-python @@ -25,7 +27,7 @@ project( # language to be enabled here. The test project that is built in scikit-build to verify # various linking options for the python library is hardcoded to build with C, so until # that is fixed we need to keep C. - C CXX + C CXX CUDA ) option(FIND_CUDF_CPP "Search for existing CUDF C++ installations before defaulting to local files" @@ -117,6 +119,7 @@ endif() rapids_cython_init() add_subdirectory(cudf/_lib) +add_subdirectory(udf_cpp/groupby) include(cmake/Modules/ProtobufHelpers.cmake) codegen_protoc(cudf/utils/metadata/orc_column_statistics.proto) diff --git a/python/cudf/cudf/__init__.py b/python/cudf/cudf/__init__.py index 28eb380f7cb..b86fb72d955 100644 --- a/python/cudf/cudf/__init__.py +++ b/python/cudf/cudf/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. from cudf.utils.gpu_utils import validate_setup @@ -88,7 +88,14 @@ pass else: # Patch Numba to support CUDA enhanced compatibility. 
- patch_numba_linker_if_needed() + # cuDF requires a stronger set of conditions than what is + # checked by patch_numba_linker_if_needed due to the PTX + # files needed for JIT Groupby Apply and string UDFs + from cudf.core.udf.groupby_utils import dev_func_ptx + from cudf.core.udf.utils import _setup_numba_linker + + _setup_numba_linker(dev_func_ptx) + del patch_numba_linker_if_needed cuda.set_memory_manager(rmm.RMMNumbaManager) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 7acbc408b0a..91e00eb43f3 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -23,6 +23,7 @@ from cudf.core.column_accessor import ColumnAccessor from cudf.core.mixins import Reducible, Scannable from cudf.core.multiindex import MultiIndex +from cudf.core.udf.groupby_utils import jit_groupby_apply from cudf.utils.utils import GetAttrGetItemMixin, _cudf_nvtx_annotate @@ -786,14 +787,83 @@ def pipe(self, func, *args, **kwargs): """ return cudf.core.common.pipe(self, func, *args, **kwargs) - def apply(self, function, *args): + def _jit_groupby_apply( + self, function, group_names, offsets, group_keys, grouped_values, *args + ): + # Nulls are not yet supported + for colname in self.grouping.values._data.keys(): + if self.obj._data[colname].has_nulls(): + raise ValueError( + "Nulls not yet supported with groupby JIT engine" + ) + + chunk_results = jit_groupby_apply( + offsets, grouped_values, function, *args + ) + result = cudf.Series._from_data( + {None: chunk_results}, index=group_names + ) + result.index.names = self.grouping.names + result = result.reset_index() + result[None] = result.pop(0) + return result + + def _iterative_groupby_apply( + self, function, group_names, offsets, group_keys, grouped_values, *args + ): + ngroups = len(offsets) - 1 + if ngroups > self._MAX_GROUPS_BEFORE_WARN: + warnings.warn( + f"GroupBy.apply() performance scales poorly with " + f"number of groups. Got {ngroups} groups. Some functions " + "may perform better by passing engine='jit'", + RuntimeWarning, + ) + + chunks = [ + grouped_values[s:e] for s, e in zip(offsets[:-1], offsets[1:]) + ] + chunk_results = [function(chk, *args) for chk in chunks] + if not len(chunk_results): + return self.obj.head(0) + + if cudf.api.types.is_scalar(chunk_results[0]): + result = cudf.Series._from_data( + {None: chunk_results}, index=group_names + ) + result.index.names = self.grouping.names + elif isinstance(chunk_results[0], cudf.Series) and isinstance( + self.obj, cudf.DataFrame + ): + result = cudf.concat(chunk_results, axis=1).T + result.index.names = self.grouping.names + else: + result = cudf.concat(chunk_results) + if self._group_keys: + index_data = group_keys._data.copy(deep=True) + index_data[None] = grouped_values.index._column + result.index = cudf.MultiIndex._from_data(index_data) + return result + + def apply(self, function, *args, engine="cudf"): """Apply a python transformation function over the grouped chunk. Parameters ---------- - func : function + function : callable The python transformation function that will be applied on the grouped chunk. + args : tuple + Optional positional arguments to pass to the function. + engine: {'cudf', 'jit'}, default 'cudf' + Selects the GroupBy.apply implementation. Use `jit` to + select the numba JIT pipeline. 
Only certain operations are allowed + within the function when using this option: min, max, sum, mean, var, + std, idxmax, and idxmin and any arithmetic formula involving them are + allowed. Binary operations are not yet supported, so syntax like + `df['x'] * 2` is not yet allowed. + For more information, see the `cuDF guide to user defined functions + `__. Examples -------- @@ -850,40 +920,45 @@ def mult(df): a b c 0 1 1 1 2 2 1 3 + + ``engine='jit'`` may be used to accelerate certain functions, + initially those that contain reductions and arithmetic operations + between results of those reductions: + >>> import cudf + >>> df = cudf.DataFrame({'a':[1,1,2,2,3,3], 'b':[1,2,3,4,5,6]}) + >>> df.groupby('a').apply( + ... lambda group: group['b'].max() - group['b'].min(), + ... engine='jit' + ... ) + a None + 0 1 1 + 1 2 1 + 2 3 1 """ if not callable(function): raise TypeError(f"type {type(function)} is not callable") group_names, offsets, group_keys, grouped_values = self._grouped() - ngroups = len(offsets) - 1 - if ngroups > self._MAX_GROUPS_BEFORE_WARN: - warnings.warn( - f"GroupBy.apply() performance scales poorly with " - f"number of groups. Got {ngroups} groups." + if engine == "jit": + result = self._jit_groupby_apply( + function, + group_names, + offsets, + group_keys, + grouped_values, + *args, + ) + elif engine == "cudf": + result = self._iterative_groupby_apply( + function, + group_names, + offsets, + group_keys, + grouped_values, + *args, ) - - chunks = [ - grouped_values[s:e] for s, e in zip(offsets[:-1], offsets[1:]) - ] - chunk_results = [function(chk, *args) for chk in chunks] - if not len(chunk_results): - return self.obj.head(0) - - if cudf.api.types.is_scalar(chunk_results[0]): - result = cudf.Series(chunk_results, index=group_names) - result.index.names = self.grouping.names else: - if isinstance(chunk_results[0], cudf.Series) and isinstance( - self.obj, cudf.DataFrame - ): - result = cudf.concat(chunk_results, axis=1).T - result.index.names = self.grouping.names - else: - result = cudf.concat(chunk_results) - if self._group_keys: - index_data = group_keys._data.copy(deep=True) - index_data[None] = grouped_values.index._column - result.index = cudf.MultiIndex._from_data(index_data) + raise ValueError(f"Unsupported engine '{engine}'") if self._sort: result = result.sort_index() diff --git a/python/cudf/cudf/core/udf/__init__.py b/python/cudf/cudf/core/udf/__init__.py index 8092207e037..06ceecf0a35 100644 --- a/python/cudf/cudf/core/udf/__init__.py +++ b/python/cudf/cudf/core/udf/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. from functools import lru_cache @@ -9,7 +9,7 @@ from cudf.core.udf import api, row_function, utils from cudf.utils.dtypes import STRING_TYPES -from . import masked_lowering, masked_typing +from . import groupby_lowering, groupby_typing, masked_lowering, masked_typing _units = ["ns", "ms", "us", "s"] _datetime_cases = {types.NPDatetime(u) for u in _units} diff --git a/python/cudf/cudf/core/udf/groupby_lowering.py b/python/cudf/cudf/core/udf/groupby_lowering.py new file mode 100644 index 00000000000..376eccb9308 --- /dev/null +++ b/python/cudf/cudf/core/udf/groupby_lowering.py @@ -0,0 +1,157 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. 
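+# Lowering for the Group reduction methods registered in groupby_typing.
+# Each typed call such as Group.max() is lowered into a call to a
+# forward-declared __device__ function (looked up in call_cuda_functions by
+# reduction name and operand dtype) that is resolved at link time against
+# the PTX compiled from udf_cpp/groupby/function.cu.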
+ +from functools import partial + +from numba import types +from numba.core import cgutils +from numba.core.extending import lower_builtin +from numba.core.typing import signature as nb_signature +from numba.cuda.cudaimpl import lower as cuda_lower + +from cudf.core.udf.groupby_typing import ( + SUPPORTED_GROUPBY_NUMBA_TYPES, + Group, + GroupType, + call_cuda_functions, + group_size_type, + index_default_type, +) + + +def group_reduction_impl_basic(context, builder, sig, args, function): + """ + Instruction boilerplate used for calling a groupby reduction + __device__ function. Centers around a forward declaration of + this function and adds the pre/post processing instructions + necessary for calling it. + """ + # return type + retty = sig.return_type + + # a variable logically corresponding to the calling `Group` + grp = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[0] + ) + + # what specific (numba) GroupType + grp_type = sig.args[0] + group_dataty = grp_type.group_data_type + + # logically take the address of the group's data pointer + group_data_ptr = builder.alloca(grp.group_data.type) + builder.store(grp.group_data, group_data_ptr) + + # obtain the correct forward declaration from registry + type_key = (sig.return_type, grp_type.group_scalar_type) + func = call_cuda_functions[function][type_key] + + # insert the forward declaration and return its result + # pass it the data pointer and the group's size + return context.compile_internal( + builder, + func, + nb_signature(retty, group_dataty, grp_type.group_size_type), + (builder.load(group_data_ptr), grp.size), + ) + + +@lower_builtin(Group, types.Array, group_size_type, types.Array) +def group_constructor(context, builder, sig, args): + """ + Instruction boilerplate used for instantiating a Group + struct from a data pointer, an index pointer, and a size + """ + # a variable logically corresponding to the calling `Group` + grp = cgutils.create_struct_proxy(sig.return_type)(context, builder) + grp.group_data = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[0] + ).data + grp.index = cgutils.create_struct_proxy(sig.args[2])( + context, builder, value=args[2] + ).data + grp.size = args[1] + return grp._getvalue() + + +def group_reduction_impl_idx_max_or_min(context, builder, sig, args, function): + """ + Instruction boilerplate used for calling a groupby reduction + __device__ function in the case where the function is either + `idxmax` or `idxmin`. See `group_reduction_impl_basic` for + details. This lowering differs from other reductions due to + the presence of the index. This results in the forward + declaration expecting an extra arg. + """ + retty = sig.return_type + + grp = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[0] + ) + grp_type = sig.args[0] + + if grp_type.index_type != index_default_type: + raise TypeError( + f"Only inputs with default index dtype {index_default_type} " + "are supported." 
+ ) + + group_dataty = grp_type.group_data_type + group_data_ptr = builder.alloca(grp.group_data.type) + builder.store(grp.group_data, group_data_ptr) + + index_dataty = grp_type.group_index_type + index_ptr = builder.alloca(grp.index.type) + builder.store(grp.index, index_ptr) + type_key = (index_default_type, grp_type.group_scalar_type) + func = call_cuda_functions[function][type_key] + + return context.compile_internal( + builder, + func, + nb_signature( + retty, group_dataty, index_dataty, grp_type.group_size_type + ), + (builder.load(group_data_ptr), builder.load(index_ptr), grp.size), + ) + + +cuda_Group_max = partial(group_reduction_impl_basic, function="max") +cuda_Group_min = partial(group_reduction_impl_basic, function="min") +cuda_Group_sum = partial(group_reduction_impl_basic, function="sum") +cuda_Group_mean = partial(group_reduction_impl_basic, function="mean") +cuda_Group_std = partial(group_reduction_impl_basic, function="std") +cuda_Group_var = partial(group_reduction_impl_basic, function="var") + +cuda_Group_idxmax = partial( + group_reduction_impl_idx_max_or_min, function="idxmax" +) +cuda_Group_idxmin = partial( + group_reduction_impl_idx_max_or_min, function="idxmin" +) + + +def cuda_Group_size(context, builder, sig, args): + grp = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[0] + ) + return grp.size + + +cuda_Group_count = cuda_Group_size + + +for ty in SUPPORTED_GROUPBY_NUMBA_TYPES: + cuda_lower("GroupType.max", GroupType(ty))(cuda_Group_max) + cuda_lower("GroupType.min", GroupType(ty))(cuda_Group_min) + cuda_lower("GroupType.sum", GroupType(ty))(cuda_Group_sum) + cuda_lower("GroupType.count", GroupType(ty))(cuda_Group_count) + cuda_lower("GroupType.size", GroupType(ty))(cuda_Group_size) + cuda_lower("GroupType.mean", GroupType(ty))(cuda_Group_mean) + cuda_lower("GroupType.std", GroupType(ty))(cuda_Group_std) + cuda_lower("GroupType.var", GroupType(ty))(cuda_Group_var) + cuda_lower("GroupType.idxmax", GroupType(ty, types.int64))( + cuda_Group_idxmax + ) + cuda_lower("GroupType.idxmin", GroupType(ty, types.int64))( + cuda_Group_idxmin + ) diff --git a/python/cudf/cudf/core/udf/groupby_typing.py b/python/cudf/cudf/core/udf/groupby_typing.py new file mode 100644 index 00000000000..37381a95fdf --- /dev/null +++ b/python/cudf/cudf/core/udf/groupby_typing.py @@ -0,0 +1,213 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. +from typing import Any, Dict + +import numba +from numba import cuda, types +from numba.core.extending import ( + make_attribute_wrapper, + models, + register_model, + type_callable, + typeof_impl, +) +from numba.core.typing import signature as nb_signature +from numba.core.typing.templates import AbstractTemplate, AttributeTemplate +from numba.cuda.cudadecl import registry as cuda_registry +from numba.np import numpy_support + +index_default_type = types.int64 +group_size_type = types.int64 +SUPPORTED_GROUPBY_NUMBA_TYPES = [types.int64, types.float64] +SUPPORTED_GROUPBY_NUMPY_TYPES = [ + numpy_support.as_dtype(dt) for dt in [types.int64, types.float64] +] + + +class Group: + """ + A piece of python code whose purpose is to be replaced + during compilation. After being registered to GroupType, + serves as a handle for instantiating GroupType objects + in python code and accessing their attributes + """ + + pass + + +class GroupType(numba.types.Type): + """ + Numba extension type carrying metadata associated with a single + GroupBy group. 
This metadata ultimately is passed to the CUDA + __device__ function which actually performs the work. + """ + + def __init__(self, group_scalar_type, index_type=index_default_type): + self.group_scalar_type = group_scalar_type + self.index_type = index_type + self.group_data_type = types.CPointer(group_scalar_type) + self.group_size_type = group_size_type + self.group_index_type = types.CPointer(index_type) + super().__init__( + name=f"Group({self.group_scalar_type}, {self.index_type})" + ) + + +@typeof_impl.register(Group) +def typeof_group(val, c): + """ + Tie Group and GroupType together such that when Numba + sees usage of Group in raw python code, it knows to + treat those usages as uses of GroupType + """ + return GroupType( + numba.np.numpy_support.from_dtype(val.dtype), + numba.np.numpy_support.from_dtype(val.index_dtype), + ) + + +# The typing of the python "function" Group.__init__ +# as it appears in python code +@type_callable(Group) +def type_group(context): + def typer(group_data, size, index): + if ( + isinstance(group_data, types.Array) + and isinstance(size, types.Integer) + and isinstance(index, types.Array) + ): + return GroupType(group_data.dtype, index.dtype) + + return typer + + +@register_model(GroupType) +class GroupModel(models.StructModel): + """ + Model backing GroupType instances. See the link below for details. + https://github.com/numba/numba/blob/main/numba/core/datamodel/models.py + """ + + def __init__(self, dmm, fe_type): + members = [ + ("group_data", types.CPointer(fe_type.group_scalar_type)), + ("size", group_size_type), + ("index", types.CPointer(fe_type.index_type)), + ] + super().__init__(dmm, fe_type, members) + + +call_cuda_functions: Dict[Any, Any] = {} + + +def _register_cuda_reduction_caller(funcname, inputty, retty): + cuda_func = cuda.declare_device( + f"Block{funcname}_{inputty}", + retty(types.CPointer(inputty), group_size_type), + ) + + def caller(data, size): + return cuda_func(data, size) + + call_cuda_functions.setdefault(funcname.lower(), {}) + + type_key = (retty, inputty) + call_cuda_functions[funcname.lower()][type_key] = caller + + +def _register_cuda_idx_reduction_caller(funcname, inputty): + cuda_func = cuda.declare_device( + f"Block{funcname}_{inputty}", + types.int64( + types.CPointer(inputty), + types.CPointer(index_default_type), + group_size_type, + ), + ) + + def caller(data, index, size): + return cuda_func(data, index, size) + + # only support default index type right now + type_key = (index_default_type, inputty) + call_cuda_functions.setdefault(funcname.lower(), {}) + call_cuda_functions[funcname.lower()][type_key] = caller + + +def _create_reduction_attr(name, retty=None): + class Attr(AbstractTemplate): + key = name + + def generic(self, args, kws): + return nb_signature( + self.this.group_scalar_type if not retty else retty, + recvr=self.this, + ) + + Attr.generic = generic + + def _attr(self, mod): + return types.BoundFunction( + Attr, GroupType(mod.group_scalar_type, mod.index_type) + ) + + return _attr + + +class GroupIdxMax(AbstractTemplate): + key = "GroupType.idxmax" + + def generic(self, args, kws): + return nb_signature(self.this.index_type, recvr=self.this) + + +class GroupIdxMin(AbstractTemplate): + key = "GroupType.idxmin" + + def generic(self, args, kws): + return nb_signature(self.this.index_type, recvr=self.this) + + +@cuda_registry.register_attr +class GroupAttr(AttributeTemplate): + key = GroupType + + resolve_max = _create_reduction_attr("GroupType.max") + resolve_min = 
_create_reduction_attr("GroupType.min") + resolve_sum = _create_reduction_attr("GroupType.sum") + + resolve_size = _create_reduction_attr( + "GroupType.size", retty=group_size_type + ) + resolve_count = _create_reduction_attr( + "GroupType.count", retty=types.int64 + ) + resolve_mean = _create_reduction_attr( + "GroupType.mean", retty=types.float64 + ) + resolve_var = _create_reduction_attr("GroupType.var", retty=types.float64) + resolve_std = _create_reduction_attr("GroupType.std", retty=types.float64) + + def resolve_idxmax(self, mod): + return types.BoundFunction( + GroupIdxMax, GroupType(mod.group_scalar_type, mod.index_type) + ) + + def resolve_idxmin(self, mod): + return types.BoundFunction( + GroupIdxMin, GroupType(mod.group_scalar_type, mod.index_type) + ) + + +for ty in SUPPORTED_GROUPBY_NUMBA_TYPES: + _register_cuda_reduction_caller("Max", ty, ty) + _register_cuda_reduction_caller("Min", ty, ty) + _register_cuda_reduction_caller("Sum", ty, ty) + _register_cuda_reduction_caller("Mean", ty, types.float64) + _register_cuda_reduction_caller("Std", ty, types.float64) + _register_cuda_reduction_caller("Var", ty, types.float64) + _register_cuda_idx_reduction_caller("IdxMax", ty) + _register_cuda_idx_reduction_caller("IdxMin", ty) + + +for attr in ("group_data", "index", "size"): + make_attribute_wrapper(GroupType, attr, attr) diff --git a/python/cudf/cudf/core/udf/groupby_utils.py b/python/cudf/cudf/core/udf/groupby_utils.py new file mode 100644 index 00000000000..a1174835db9 --- /dev/null +++ b/python/cudf/cudf/core/udf/groupby_utils.py @@ -0,0 +1,200 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. + +import os + +import cupy as cp +import numpy as np +from numba import cuda, types +from numba.cuda.cudadrv.devices import get_context +from numba.np import numpy_support +from numba.types import Record + +import cudf.core.udf.utils +from cudf.core.udf.groupby_typing import ( + SUPPORTED_GROUPBY_NUMPY_TYPES, + Group, + GroupType, +) +from cudf.core.udf.templates import ( + group_initializer_template, + groupby_apply_kernel_template, +) +from cudf.core.udf.utils import ( + _get_extensionty_size, + _get_kernel, + _get_ptx_file, + _get_udf_return_type, + _supported_cols_from_frame, + _supported_dtypes_from_frame, +) +from cudf.utils.utils import _cudf_nvtx_annotate + +dev_func_ptx = _get_ptx_file(os.path.dirname(__file__), "function_") +cudf.core.udf.utils.ptx_files.append(dev_func_ptx) + + +def _get_frame_groupby_type(dtype, index_dtype): + """ + Get the numba `Record` type corresponding to a frame. + Models the column as a dictionary like data structure + containing GroupTypes. + See numba.np.numpy_support.from_struct_dtype for details. + + Parameters + ---------- + level : np.dtype + A numpy structured array dtype associating field names + to scalar dtypes + index_dtype : np.dtype + A numpy scalar dtype associated with the index of the + incoming grouped data + """ + # Create the numpy structured type corresponding to the numpy dtype. 
+ fields = [] + offset = 0 + + sizes = [val[0].itemsize for val in dtype.fields.values()] + for i, (name, info) in enumerate(dtype.fields.items()): + elemdtype = info[0] + title = info[2] if len(info) == 3 else None + ty = numpy_support.from_dtype(elemdtype) + indexty = numpy_support.from_dtype(index_dtype) + groupty = GroupType(ty, indexty) + infos = { + "type": groupty, + "offset": offset, + "title": title, + } + fields.append((name, infos)) + offset += _get_extensionty_size(groupty) + + # Align the next member of the struct to be a multiple of the + # memory access size, per PTX ISA 7.4/5.4.5 + if i < len(sizes) - 1: + alignment = offset % 8 + if alignment != 0: + offset += 8 - alignment + + # Numba requires that structures are aligned for the CUDA target + _is_aligned_struct = True + return Record(fields, offset, _is_aligned_struct) + + +def _groupby_apply_kernel_string_from_template(frame, args): + """ + Function to write numba kernels for `Groupby.apply` as a string. + Workaround until numba supports functions that use `*args` + """ + # Create argument list for kernel + frame = _supported_cols_from_frame( + frame, supported_types=SUPPORTED_GROUPBY_NUMPY_TYPES + ) + input_columns = ", ".join([f"input_col_{i}" for i in range(len(frame))]) + extra_args = ", ".join([f"extra_arg_{i}" for i in range(len(args))]) + + # Generate the initializers for each device function argument + initializers = [] + for i, colname in enumerate(frame.keys()): + initializers.append( + group_initializer_template.format(idx=i, name=colname) + ) + + return groupby_apply_kernel_template.format( + input_columns=input_columns, + extra_args=extra_args, + group_initializers="\n".join(initializers), + ) + + +def _get_groupby_apply_kernel(frame, func, args): + np_field_types = np.dtype( + list( + _supported_dtypes_from_frame( + frame, supported_types=SUPPORTED_GROUPBY_NUMPY_TYPES + ).items() + ) + ) + dataframe_group_type = _get_frame_groupby_type( + np_field_types, frame.index.dtype + ) + return_type = _get_udf_return_type(dataframe_group_type, func, args) + + # Dict of 'local' variables into which `_kernel` is defined + global_exec_context = { + "cuda": cuda, + "Group": Group, + "dataframe_group_type": dataframe_group_type, + "types": types, + } + kernel_string = _groupby_apply_kernel_string_from_template(frame, args) + + kernel = _get_kernel(kernel_string, global_exec_context, None, func) + + return kernel, return_type + + +@_cudf_nvtx_annotate +def jit_groupby_apply(offsets, grouped_values, function, *args): + """ + Main entrypoint for JIT Groupby.apply via Numba. 
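+    Builds a kernel for the frame and function from
+    groupby_apply_kernel_template and launches it with one thread block per
+    group. The block size is chosen with the occupancy API, limited to 256
+    threads (or, for smaller groups, to the largest group size rounded up to
+    a multiple of the warp size).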
+ + Parameters + ---------- + offsets : list + A list of integers denoting the indices of the group + boundaries in grouped_values + grouped_values : DataFrame + A DataFrame representing the source data + sorted by group keys + function : callable + The user-defined function to execute + """ + offsets = cp.asarray(offsets) + ngroups = len(offsets) - 1 + + kernel, return_type = _get_groupby_apply_kernel( + grouped_values, function, args + ) + return_type = numpy_support.as_dtype(return_type) + + output = cudf.core.column.column_empty(ngroups, dtype=return_type) + launch_args = [ + offsets, + output, + grouped_values.index, + ] + launch_args += list( + _supported_cols_from_frame( + grouped_values, supported_types=SUPPORTED_GROUPBY_NUMPY_TYPES + ).values() + ) + launch_args += list(args) + + max_group_size = cp.diff(offsets).max() + + if max_group_size >= 256: + blocklim = 256 + else: + blocklim = ((max_group_size + 32 - 1) // 32) * 32 + + if kernel.specialized: + specialized = kernel + else: + specialized = kernel.specialize(*launch_args) + + # Ask the driver to give a good config + ctx = get_context() + # Dispatcher is specialized, so there's only one definition - get + # it so we can get the cufunc from the code library + (kern_def,) = specialized.overloads.values() + grid, tpb = ctx.get_max_potential_block_size( + func=kern_def._codelibrary.get_cufunc(), + b2d_func=0, + memsize=0, + blocksizelimit=int(blocklim), + ) + + # Launch kernel + specialized[ngroups, tpb](*launch_args) + + return output diff --git a/python/cudf/cudf/core/udf/templates.py b/python/cudf/cudf/core/udf/templates.py index 3ac7083582f..9a032146992 100644 --- a/python/cudf/cudf/core/udf/templates.py +++ b/python/cudf/cudf/core/udf/templates.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. unmasked_input_initializer_template = """\ d_{idx} = input_col_{idx} @@ -14,6 +14,11 @@ row["{name}"] = masked_{idx} """ +group_initializer_template = """\ + arr_{idx} = input_col_{idx}[offset[block_id]:offset[block_id+1]] + dataframe_group["{name}"] = Group(arr_{idx}, size, arr_index) +""" + row_kernel_template = """\ def _kernel(retval, size, {input_columns}, {input_offsets}, {extra_args}): i = cuda.grid(1) @@ -52,3 +57,24 @@ def _kernel(retval, size, input_col_0, offset_0, {extra_args}): ret_data_arr[i] = ret_masked.value ret_mask_arr[i] = ret_masked.valid """ + +groupby_apply_kernel_template = """ +def _kernel(offset, out, index, {input_columns}, {extra_args}): + tid = cuda.threadIdx.x + block_id = cuda.blockIdx.x + tb_size = cuda.blockDim.x + + recarray = cuda.local.array(1, dtype=dataframe_group_type) + dataframe_group = recarray[0] + + if block_id < (len(offset) - 1): + + size = offset[block_id+1] - offset[block_id] + arr_index = index[offset[block_id]:offset[block_id+1]] + +{group_initializers} + + result = f_(dataframe_group, {extra_args}) + if cuda.threadIdx.x == 0: + out[block_id] = result +""" diff --git a/python/cudf/cudf/core/udf/utils.py b/python/cudf/cudf/core/udf/utils.py index 4d40d41f9c3..3ee1d8edcbd 100644 --- a/python/cudf/cudf/core/udf/utils.py +++ b/python/cudf/cudf/core/udf/utils.py @@ -1,19 +1,25 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. 
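+# Besides the masked-UDF helpers, this module now also hosts the PTX file
+# discovery and Numba linker patching utilities shared by the strings_udf
+# and groupby JIT pipelines (moved here from strings_udf/__init__.py).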
+import glob +import os from typing import Any, Callable, Dict, List import cachetools import cupy as cp +import llvmlite.binding as ll import numpy as np +from cubinlinker.patch import _numba_version_ok, get_logger, new_patched_linker from numba import cuda, typeof +from numba.core.datamodel import default_manager from numba.core.errors import TypingError +from numba.cuda.cudadrv import nvvm +from numba.cuda.cudadrv.driver import Linker from numba.np import numpy_support from numba.types import CPointer, Poison, Tuple, boolean, int64, void import rmm from cudf.core.column.column import as_column -from cudf.core.dtypes import CategoricalDtype from cudf.core.udf.masked_typing import MaskedType from cudf.utils import cudautils from cudf.utils.dtypes import ( @@ -24,6 +30,9 @@ ) from cudf.utils.utils import _cudf_nvtx_annotate +logger = get_logger() + + JIT_SUPPORTED_TYPES = ( NUMERIC_TYPES | BOOL_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES ) @@ -87,35 +96,28 @@ def _get_udf_return_type(argty, func: Callable, args=()): return result -def _is_jit_supported_type(dtype): - # category dtype isn't hashable - if isinstance(dtype, CategoricalDtype): - return False - return str(dtype) in JIT_SUPPORTED_TYPES - - -def _all_dtypes_from_frame(frame): +def _all_dtypes_from_frame(frame, supported_types=JIT_SUPPORTED_TYPES): return { colname: col.dtype - if _is_jit_supported_type(col.dtype) + if str(col.dtype) in supported_types else np.dtype("O") for colname, col in frame._data.items() } -def _supported_dtypes_from_frame(frame): +def _supported_dtypes_from_frame(frame, supported_types=JIT_SUPPORTED_TYPES): return { colname: col.dtype for colname, col in frame._data.items() - if _is_jit_supported_type(col.dtype) + if str(col.dtype) in supported_types } -def _supported_cols_from_frame(frame): +def _supported_cols_from_frame(frame, supported_types=JIT_SUPPORTED_TYPES): return { colname: col for colname, col in frame._data.items() - if _is_jit_supported_type(col.dtype) + if str(col.dtype) in supported_types } @@ -272,3 +274,149 @@ def _post_process_output_col(col, retty): if getter := output_col_getters.get(retty): col = getter(col) return as_column(col, retty) + + +def _get_best_ptx_file(archs, max_compute_capability): + """ + Determine of the available PTX files which one is + the most recent up to and including the device cc + """ + filtered_archs = [x for x in archs if x[0] <= max_compute_capability] + if filtered_archs: + return max(filtered_archs, key=lambda y: y[0]) + else: + return None + + +def _get_ptx_file(path, prefix): + if "RAPIDS_NO_INITIALIZE" in os.environ: + # cc=60 ptx is always built + cc = int(os.environ.get("STRINGS_UDF_CC", "60")) + else: + dev = cuda.get_current_device() + + # Load the highest compute capability file available that is less than + # the current device's. 
+ cc = int("".join(str(x) for x in dev.compute_capability)) + files = glob.glob(os.path.join(path, f"{prefix}*.ptx")) + if len(files) == 0: + raise RuntimeError(f"Missing PTX files for cc={cc}") + regular_sms = [] + + for f in files: + file_name = os.path.basename(f) + sm_number = file_name.rstrip(".ptx").lstrip(prefix) + if sm_number.endswith("a"): + processed_sm_number = int(sm_number.rstrip("a")) + if processed_sm_number == cc: + return f + else: + regular_sms.append((int(sm_number), f)) + + regular_result = None + + if regular_sms: + regular_result = _get_best_ptx_file(regular_sms, cc) + + if regular_result is None: + raise RuntimeError( + "This cuDF installation is missing the necessary PTX " + f"files that are <={cc}." + ) + else: + return regular_result[1] + + +def _get_extensionty_size(ty): + """ + Return the size of an extension type in bytes + """ + data_layout = nvvm.data_layout + if isinstance(data_layout, dict): + data_layout = data_layout[64] + target_data = ll.create_target_data(data_layout) + llty = default_manager[ty].get_value_type() + return llty.get_abi_size(target_data) + + +def _get_cuda_version_from_ptx_file(path): + """ + https://docs.nvidia.com/cuda/parallel-thread-execution/ + Each PTX module must begin with a .version + directive specifying the PTX language version + + example header: + // + // Generated by NVIDIA NVVM Compiler + // + // Compiler Build ID: CL-31057947 + // Cuda compilation tools, release 11.6, V11.6.124 + // Based on NVVM 7.0.1 + // + + .version 7.6 + .target sm_52 + .address_size 64 + + """ + with open(path) as ptx_file: + for line in ptx_file: + if line.startswith(".version"): + ver_line = line + break + else: + raise ValueError("Could not read CUDA version from ptx file.") + version = ver_line.strip("\n").split(" ")[1] + # from ptx_docs/release_notes above: + ver_map = { + "7.5": (11, 5), + "7.6": (11, 6), + "7.7": (11, 7), + "7.8": (11, 8), + "8.0": (12, 0), + } + + cuda_ver = ver_map.get(version) + if cuda_ver is None: + raise ValueError( + f"Could not map PTX version {version} to a CUDA version" + ) + + return cuda_ver + + +def _setup_numba_linker(path): + from ptxcompiler.patch import NO_DRIVER, safe_get_versions + + from cudf.core.udf.utils import ( + _get_cuda_version_from_ptx_file, + maybe_patch_numba_linker, + ) + + versions = safe_get_versions() + if versions != NO_DRIVER: + driver_version, runtime_version = versions + ptx_toolkit_version = _get_cuda_version_from_ptx_file(path) + maybe_patch_numba_linker( + driver_version, runtime_version, ptx_toolkit_version + ) + + +def maybe_patch_numba_linker( + driver_version, runtime_version, ptx_toolkit_version +): + # Numba thinks cubinlinker is only needed if the driver is older than + # the ctk, but when PTX files are present, it might also need to patch + # because those PTX files may newer than the driver as well + if (driver_version < ptx_toolkit_version) or ( + driver_version < runtime_version + ): + logger.debug( + "Driver version %s.%s needs patching due to PTX files" + % driver_version + ) + if _numba_version_ok: + logger.debug("Patching Numba Linker") + Linker.new = new_patched_linker + else: + logger.debug("Cannot patch Numba Linker - unsupported version") diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index b3eac9b0f33..c5b330fd89c 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -2,6 +2,7 @@ import datetime import itertools +import textwrap from decimal import Decimal import numpy as 
np @@ -20,6 +21,7 @@ PANDAS_GE_150, PANDAS_LT_140, ) +from cudf.core.udf.groupby_typing import SUPPORTED_GROUPBY_NUMPY_TYPES from cudf.testing._utils import ( DATETIME_TYPES, SIGNED_TYPES, @@ -376,6 +378,147 @@ def emulate(df): assert_groupby_results_equal(expect, got) +@pytest.fixture(scope="module") +def groupby_jit_data(): + np.random.seed(0) + df = DataFrame() + nelem = 20 + df["key1"] = np.random.randint(0, 3, nelem) + df["key2"] = np.random.randint(0, 2, nelem) + df["val1"] = np.random.random(nelem) + df["val2"] = np.random.random(nelem) + return df + + +def run_groupby_apply_jit_test(data, func, keys, *args): + expect_groupby_obj = data.to_pandas().groupby(keys, as_index=False) + got_groupby_obj = data.groupby(keys) + + # compare cuDF jit to pandas + cudf_jit_result = got_groupby_obj.apply(func, *args, engine="jit") + pandas_result = expect_groupby_obj.apply(func, *args) + assert_groupby_results_equal(cudf_jit_result, pandas_result) + + +@pytest.mark.parametrize("dtype", SUPPORTED_GROUPBY_NUMPY_TYPES) +@pytest.mark.parametrize( + "func", ["min", "max", "sum", "mean", "var", "std", "idxmin", "idxmax"] +) +def test_groupby_apply_jit_reductions(func, groupby_jit_data, dtype): + # ideally we'd just have: + # lambda group: getattr(group, func)() + # but the current kernel caching mechanism relies on pickle which + # does not play nice with local functions. What's below uses + # exec as a workaround to write the test functions dynamically + + funcstr = textwrap.dedent( + f""" + def func(df): + return df['val1'].{func}() + """ + ) + lcl = {} + exec(funcstr, lcl) + func = lcl["func"] + + groupby_jit_data["val1"] = groupby_jit_data["val1"].astype(dtype) + groupby_jit_data["val2"] = groupby_jit_data["val2"].astype(dtype) + + run_groupby_apply_jit_test(groupby_jit_data, func, ["key1"]) + + +@pytest.mark.parametrize("dtype", ["float64"]) +@pytest.mark.parametrize("func", ["min", "max", "sum", "mean", "var", "std"]) +@pytest.mark.parametrize("special_val", [np.nan, np.inf, -np.inf]) +def test_groupby_apply_jit_reductions_special_vals( + func, groupby_jit_data, dtype, special_val +): + # dynamically generate to avoid pickling error. + # see test_groupby_apply_jit_reductions for details. + funcstr = textwrap.dedent( + f""" + def func(df): + return df['val1'].{func}() + """ + ) + lcl = {} + exec(funcstr, lcl) + func = lcl["func"] + + groupby_jit_data["val1"] = special_val + groupby_jit_data["val1"] = groupby_jit_data["val1"].astype(dtype) + + run_groupby_apply_jit_test(groupby_jit_data, func, ["key1"]) + + +@pytest.mark.parametrize("dtype", ["float64"]) +@pytest.mark.parametrize("func", ["idxmax", "idxmin"]) +@pytest.mark.parametrize("special_val", [np.nan, np.inf, -np.inf]) +def test_groupby_apply_jit_idx_reductions_special_vals( + func, groupby_jit_data, dtype, special_val +): + # dynamically generate to avoid pickling error. + # see test_groupby_apply_jit_reductions for details. 
+ funcstr = textwrap.dedent( + f""" + def func(df): + return df['val1'].{func}() + """ + ) + lcl = {} + exec(funcstr, lcl) + func = lcl["func"] + + groupby_jit_data["val1"] = special_val + groupby_jit_data["val1"] = groupby_jit_data["val1"].astype(dtype) + + expect = ( + groupby_jit_data.to_pandas() + .groupby("key1", as_index=False) + .apply(func) + ) + + grouped = groupby_jit_data.groupby("key1") + sorted = grouped._grouped()[3].to_pandas() + expect_vals = sorted["key1"].drop_duplicates().index + expect[None] = expect_vals + + got = grouped.apply(func, engine="jit") + assert_eq(expect, got) + + +@pytest.mark.parametrize( + "func", + [ + lambda df: df["val1"].max() + df["val2"].min(), + lambda df: df["val1"].sum() + df["val2"].var(), + lambda df: df["val1"].mean() + df["val2"].std(), + ], +) +def test_groupby_apply_jit_basic(func, groupby_jit_data): + run_groupby_apply_jit_test(groupby_jit_data, func, ["key1", "key2"]) + + +def create_test_groupby_apply_jit_args_params(): + def f1(df, k): + return df["val1"].max() + df["val2"].min() + k + + def f2(df, k, L): + return df["val1"].sum() - df["val2"].var() + (k / L) + + def f3(df, k, L, m): + return ((k * df["val1"].mean()) + (L * df["val2"].std())) / m + + return [(f1, (42,)), (f2, (42, 119)), (f3, (42, 119, 212.1))] + + +@pytest.mark.parametrize( + "func,args", create_test_groupby_apply_jit_args_params() +) +def test_groupby_apply_jit_args(func, args, groupby_jit_data): + run_groupby_apply_jit_test(groupby_jit_data, func, ["key1", "key2"], *args) + + @pytest.mark.parametrize("nelem", [2, 3, 100, 500, 1000]) @pytest.mark.parametrize( "func", diff --git a/python/cudf/udf_cpp/groupby/CMakeLists.txt b/python/cudf/udf_cpp/groupby/CMakeLists.txt new file mode 100644 index 00000000000..043ab28f362 --- /dev/null +++ b/python/cudf/udf_cpp/groupby/CMakeLists.txt @@ -0,0 +1,79 @@ +# ============================================================================= +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +cmake_minimum_required(VERSION 3.23.1) + +include(rapids-find) + +# This function will copy the generated PTX file from its generator-specific location in the build +# tree into a specified location in the build tree from which we can install it. +function(copy_ptx_to_location target destination) + set(cmake_generated_file + "${CMAKE_CURRENT_BINARY_DIR}/cmake/cp_${target}_$>_ptx.cmake" + ) + file( + GENERATE + OUTPUT "${cmake_generated_file}" + CONTENT + " +set(ptx_paths \"$\") +file(COPY_FILE \${ptx_paths} \"${destination}/${target}.ptx\")" + ) + + add_custom_target( + ${target}_cp_ptx ALL + COMMAND ${CMAKE_COMMAND} -P "${cmake_generated_file}" + DEPENDS $ + COMMENT "Copying PTX files to '${destination}'" + ) +endfunction() + +# Create the shim library for each architecture. 
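+# Each entry in CMAKE_CUDA_ARCHITECTURES gets its own OBJECT library built
+# with CUDA_PTX_COMPILATION ON. The generated PTX is copied out of the
+# generator-specific object directory and installed as
+# cudf/core/udf/function_<arch>.ptx, from which _get_ptx_file() picks the
+# best match for the current device at import time.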
+set(GROUPBY_FUNCTION_CUDA_FLAGS --expt-relaxed-constexpr) + +# always build a default PTX file in case RAPIDS_NO_INITIALIZE is set and the device cc can't be +# safely queried through a context +list(INSERT CMAKE_CUDA_ARCHITECTURES 0 "60") + +list(TRANSFORM CMAKE_CUDA_ARCHITECTURES REPLACE "-real" "") +list(TRANSFORM CMAKE_CUDA_ARCHITECTURES REPLACE "-virtual" "") +list(SORT CMAKE_CUDA_ARCHITECTURES) +list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES) + +foreach(arch IN LISTS CMAKE_CUDA_ARCHITECTURES) + set(tgt function_${arch}) + + add_library(${tgt} OBJECT function.cu) + set_target_properties( + ${tgt} + PROPERTIES CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + CUDA_ARCHITECTURES ${arch} + CUDA_PTX_COMPILATION ON + CUDA_SEPARABLE_COMPILATION ON + ) + + target_include_directories(${tgt} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include) + target_compile_options( + ${tgt} PRIVATE "$<$:${GROUPBY_FUNCTION_CUDA_FLAGS}>" + ) + target_link_libraries(${tgt} PUBLIC cudf::cudf) + + copy_ptx_to_location(${tgt} "${CMAKE_CURRENT_BINARY_DIR}/") + install( + FILES $ + DESTINATION ./cudf/core/udf/ + RENAME ${tgt}.ptx + ) +endforeach() diff --git a/python/cudf/udf_cpp/groupby/function.cu b/python/cudf/udf_cpp/groupby/function.cu new file mode 100644 index 00000000000..f94f99c4b49 --- /dev/null +++ b/python/cudf/udf_cpp/groupby/function.cu @@ -0,0 +1,323 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include + +#include + +#include +#include + +template +__device__ bool are_all_nans(cooperative_groups::thread_block const& block, + T const* data, + int64_t size) +{ + // TODO: to be refactored with CG vote functions once + // block size is known at build time + __shared__ int64_t count; + + if (block.thread_rank() == 0) { count = 0; } + block.sync(); + + for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) { + if (not std::isnan(data[idx])) { + cuda::atomic_ref ref{count}; + ref.fetch_add(1, cuda::std::memory_order_relaxed); + break; + } + } + + block.sync(); + return count == 0; +} + +template +__device__ void device_sum(cooperative_groups::thread_block const& block, + T const* data, + int64_t size, + T* sum) +{ + T local_sum = 0; + + for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) { + local_sum += data[idx]; + } + + cuda::atomic_ref ref{*sum}; + ref.fetch_add(local_sum, cuda::std::memory_order_relaxed); + + block.sync(); +} + +template +__device__ T BlockSum(T const* data, int64_t size) +{ + auto block = cooperative_groups::this_thread_block(); + + if constexpr (std::is_floating_point_v) { + if (are_all_nans(block, data, size)) { return 0; } + } + + __shared__ T block_sum; + if (block.thread_rank() == 0) { block_sum = 0; } + block.sync(); + + device_sum(block, data, size, &block_sum); + return block_sum; +} + +template +__device__ double BlockMean(T const* data, int64_t size) +{ + auto block = cooperative_groups::this_thread_block(); + + __shared__ T block_sum; + if (block.thread_rank() == 0) { block_sum = 0; } + block.sync(); + + device_sum(block, data, size, &block_sum); + return static_cast(block_sum) / static_cast(size); +} + +template +__device__ double BlockVar(T const* data, int64_t size) +{ + auto block = cooperative_groups::this_thread_block(); + + __shared__ double block_var; + __shared__ T block_sum; + if (block.thread_rank() == 0) { + block_var = 0; + block_sum = 0; + } + block.sync(); + + T local_sum = 0; + double local_var = 0; + + device_sum(block, data, size, &block_sum); + + auto const mean = static_cast(block_sum) / static_cast(size); + + for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) { + auto const delta = static_cast(data[idx]) - mean; + local_var += delta * delta; + } + + cuda::atomic_ref ref{block_var}; + ref.fetch_add(local_var, cuda::std::memory_order_relaxed); + block.sync(); + + if (block.thread_rank() == 0) { block_var = block_var / static_cast(size - 1); } + block.sync(); + return block_var; +} + +template +__device__ double BlockStd(T const* data, int64_t size) +{ + auto const var = BlockVar(data, size); + return sqrt(var); +} + +template +__device__ T BlockMax(T const* data, int64_t size) +{ + auto block = cooperative_groups::this_thread_block(); + + if constexpr (std::is_floating_point_v) { + if (are_all_nans(block, data, size)) { return std::numeric_limits::quiet_NaN(); } + } + + auto local_max = cudf::DeviceMax::identity(); + __shared__ T block_max; + if (block.thread_rank() == 0) { block_max = local_max; } + block.sync(); + + for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) { + local_max = max(local_max, data[idx]); + } + + cuda::atomic_ref ref{block_max}; + ref.fetch_max(local_max, cuda::std::memory_order_relaxed); + + block.sync(); + + return block_max; +} + +template +__device__ T BlockMin(T const* data, int64_t size) +{ + auto block = cooperative_groups::this_thread_block(); + + if constexpr (std::is_floating_point_v) { + if 
(are_all_nans(block, data, size)) { return std::numeric_limits::quiet_NaN(); } + } + + auto local_min = cudf::DeviceMin::identity(); + + __shared__ T block_min; + if (block.thread_rank() == 0) { block_min = local_min; } + block.sync(); + + for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) { + local_min = min(local_min, data[idx]); + } + + cuda::atomic_ref ref{block_min}; + ref.fetch_min(local_min, cuda::std::memory_order_relaxed); + + block.sync(); + + return block_min; +} + +template +__device__ int64_t BlockIdxMax(T const* data, int64_t* index, int64_t size) +{ + auto block = cooperative_groups::this_thread_block(); + + __shared__ T block_max; + __shared__ int64_t block_idx_max; + __shared__ bool found_max; + + auto local_max = cudf::DeviceMax::identity(); + auto local_idx_max = cudf::DeviceMin::identity(); + + if (block.thread_rank() == 0) { + block_max = local_max; + block_idx_max = local_idx_max; + found_max = false; + } + block.sync(); + + for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) { + auto const current_data = data[idx]; + if (current_data > local_max) { + local_max = current_data; + local_idx_max = index[idx]; + found_max = true; + } + } + + cuda::atomic_ref ref{block_max}; + ref.fetch_max(local_max, cuda::std::memory_order_relaxed); + block.sync(); + + if (found_max) { + if (local_max == block_max) { + cuda::atomic_ref ref_idx{block_idx_max}; + ref_idx.fetch_min(local_idx_max, cuda::std::memory_order_relaxed); + } + } else { + if (block.thread_rank() == 0) { block_idx_max = index[0]; } + } + block.sync(); + + return block_idx_max; +} + +template +__device__ int64_t BlockIdxMin(T const* data, int64_t* index, int64_t size) +{ + auto block = cooperative_groups::this_thread_block(); + + __shared__ T block_min; + __shared__ int64_t block_idx_min; + __shared__ bool found_min; + + auto local_min = cudf::DeviceMin::identity(); + auto local_idx_min = cudf::DeviceMin::identity(); + + if (block.thread_rank() == 0) { + block_min = local_min; + block_idx_min = local_idx_min; + found_min = false; + } + block.sync(); + + for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) { + auto const current_data = data[idx]; + if (current_data < local_min) { + local_min = current_data; + local_idx_min = index[idx]; + found_min = true; + } + } + + cuda::atomic_ref ref{block_min}; + ref.fetch_min(local_min, cuda::std::memory_order_relaxed); + block.sync(); + + if (found_min) { + if (local_min == block_min) { + cuda::atomic_ref ref_idx{block_idx_min}; + ref_idx.fetch_min(local_idx_min, cuda::std::memory_order_relaxed); + } + } else { + if (block.thread_rank() == 0) { block_idx_min = index[0]; } + } + block.sync(); + + return block_idx_min; +} + +extern "C" { +#define make_definition(name, cname, type, return_type) \ + __device__ int name##_##cname(return_type* numba_return_value, type* const data, int64_t size) \ + { \ + return_type const res = name(data, size); \ + if (threadIdx.x == 0) { *numba_return_value = res; } \ + __syncthreads(); \ + return 0; \ + } + +make_definition(BlockSum, int64, int64_t, int64_t); +make_definition(BlockSum, float64, double, double); +make_definition(BlockMean, int64, int64_t, double); +make_definition(BlockMean, float64, double, double); +make_definition(BlockStd, int64, int64_t, double); +make_definition(BlockStd, float64, double, double); +make_definition(BlockVar, int64, int64_t, double); +make_definition(BlockVar, float64, double, double); +make_definition(BlockMin, int64, int64_t, int64_t); 
+make_definition(BlockMin, float64, double, double); +make_definition(BlockMax, int64, int64_t, int64_t); +make_definition(BlockMax, float64, double, double); +#undef make_definition +} + +extern "C" { +#define make_definition_idx(name, cname, type) \ + __device__ int name##_##cname( \ + int64_t* numba_return_value, type* const data, int64_t* index, int64_t size) \ + { \ + auto const res = name(data, index, size); \ + if (threadIdx.x == 0) { *numba_return_value = res; } \ + __syncthreads(); \ + return 0; \ + } + +make_definition_idx(BlockIdxMin, int64, int64_t); +make_definition_idx(BlockIdxMin, float64, double); +make_definition_idx(BlockIdxMax, int64, int64_t); +make_definition_idx(BlockIdxMax, float64, double); +#undef make_definition_idx +} diff --git a/python/strings_udf/strings_udf/__init__.py b/python/strings_udf/strings_udf/__init__.py index a26dc0a4064..66c037125e6 100644 --- a/python/strings_udf/strings_udf/__init__.py +++ b/python/strings_udf/strings_udf/__init__.py @@ -1,134 +1,17 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. -import glob import os -from cubinlinker.patch import _numba_version_ok, get_logger, new_patched_linker from cuda import cudart -from numba import cuda -from numba.cuda.cudadrv.driver import Linker from ptxcompiler.patch import NO_DRIVER, safe_get_versions +from cudf.core.udf.utils import _get_cuda_version_from_ptx_file, _get_ptx_file + from . import _version __version__ = _version.get_versions()["version"] -logger = get_logger() - - -def _get_cuda_version_from_ptx_file(path): - """ - https://docs.nvidia.com/cuda/parallel-thread-execution/ - Each PTX module must begin with a .version - directive specifying the PTX language version - - example header: - // - // Generated by NVIDIA NVVM Compiler - // - // Compiler Build ID: CL-31057947 - // Cuda compilation tools, release 11.6, V11.6.124 - // Based on NVVM 7.0.1 - // - - .version 7.6 - .target sm_52 - .address_size 64 - - """ - with open(path) as ptx_file: - for line in ptx_file: - if line.startswith(".version"): - ver_line = line - break - else: - raise ValueError("Could not read CUDA version from ptx file.") - version = ver_line.strip("\n").split(" ")[1] - # from ptx_docs/release_notes above: - ver_map = { - "7.5": (11, 5), - "7.6": (11, 6), - "7.7": (11, 7), - "7.8": (11, 8), - "8.0": (12, 0), - } - - cuda_ver = ver_map.get(version) - if cuda_ver is None: - raise ValueError( - f"Could not map PTX version {version} to a CUDA version" - ) - - return cuda_ver - - -def _get_appropriate_file(sms, cc): - filtered_sms = list(filter(lambda x: x[0] <= cc, sms)) - if filtered_sms: - return max(filtered_sms, key=lambda y: y[0]) - else: - return None - -def maybe_patch_numba_linker(driver_version, ptx_toolkit_version): - # Numba thinks cubinlinker is only needed if the driver is older than the ctk - # but when strings_udf is present, it might also need to patch because the PTX - # file strings_udf relies on may be newer than the driver as well - if driver_version < ptx_toolkit_version: - logger.debug( - "Driver version %s.%s needs patching due to strings_udf" - % driver_version - ) - if _numba_version_ok: - logger.debug("Patching Numba Linker") - Linker.new = new_patched_linker - else: - logger.debug("Cannot patch Numba Linker - unsupported version") - - -def _get_ptx_file(): - if "RAPIDS_NO_INITIALIZE" in os.environ: - # shim_60.ptx is always built - cc = int(os.environ.get("STRINGS_UDF_CC", "60")) - else: - dev = cuda.get_current_device() - - # Load the highest compute capability file available that is less 
than - # the current device's. - cc = int("".join(str(x) for x in dev.compute_capability)) - files = glob.glob(os.path.join(os.path.dirname(__file__), "shim_*.ptx")) - if len(files) == 0: - raise RuntimeError( - "This strings_udf installation is missing the necessary PTX " - f"files for compute capability {cc}. " - "Please file an issue reporting this error and how you " - "installed cudf and strings_udf." - "https://github.com/rapidsai/cudf/issues" - ) - - regular_sms = [] - - for f in files: - file_name = os.path.basename(f) - sm_number = file_name.rstrip(".ptx").lstrip("shim_") - if sm_number.endswith("a"): - processed_sm_number = int(sm_number.rstrip("a")) - if processed_sm_number == cc: - return f - else: - regular_sms.append((int(sm_number), f)) - - regular_result = None - - if regular_sms: - regular_result = _get_appropriate_file(regular_sms, cc) - - if regular_result is None: - raise RuntimeError( - "This strings_udf installation is missing the necessary PTX " - f"files that are <={cc}." - ) - else: - return regular_result[1] +path = os.path.dirname(__file__) # Maximum size of a string column is 2 GiB @@ -158,7 +41,4 @@ def set_malloc_heap_size(size=None): ptxpath = None versions = safe_get_versions() if versions != NO_DRIVER: - driver_version, runtime_version = versions - ptxpath = _get_ptx_file() - strings_udf_ptx_version = _get_cuda_version_from_ptx_file(ptxpath) - maybe_patch_numba_linker(driver_version, strings_udf_ptx_version) + ptxpath = _get_ptx_file(path, "shim_") diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py index 80deb881ec8..fa87ad63dc2 100644 --- a/python/strings_udf/strings_udf/_typing.py +++ b/python/strings_udf/strings_udf/_typing.py @@ -2,28 +2,19 @@ import operator -import llvmlite.binding as ll import numpy as np from numba import types -from numba.core.datamodel import default_manager from numba.core.extending import models, register_model from numba.core.typing import signature as nb_signature from numba.core.typing.templates import AbstractTemplate, AttributeTemplate from numba.cuda.cudadecl import registry as cuda_decl_registry -from numba.cuda.cudadrv import nvvm import rmm - -data_layout = nvvm.data_layout +from cudf.core.udf.utils import _get_extensionty_size # libcudf size_type size_type = types.int32 -# workaround for numba < 0.56 -if isinstance(data_layout, dict): - data_layout = data_layout[64] -target_data = ll.create_target_data(data_layout) - # String object definitions class UDFString(types.Type): @@ -32,8 +23,7 @@ class UDFString(types.Type): def __init__(self): super().__init__(name="udf_string") - llty = default_manager[self].get_value_type() - self.size_bytes = llty.get_abi_size(target_data) + self.size_bytes = _get_extensionty_size(self) @property def return_type(self): @@ -46,8 +36,7 @@ class StringView(types.Type): def __init__(self): super().__init__(name="string_view") - llty = default_manager[self].get_value_type() - self.size_bytes = llty.get_abi_size(target_data) + self.size_bytes = _get_extensionty_size(self) @property def return_type(self):