diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index c686cd0fd39..c246eb3b266 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -63,6 +63,7 @@ from cudf.core.multiindex import MultiIndex from cudf.core.resample import DataFrameResampler from cudf.core.series import Series +from cudf.core.udf.row_function import _get_row_kernel from cudf.utils import applyutils, docutils, ioutils, queryutils, utils from cudf.utils.docutils import copy_docstring from cudf.utils.dtypes import ( @@ -3926,10 +3927,8 @@ def apply( raise ValueError("The `raw` kwarg is not yet supported.") if result_type is not None: raise ValueError("The `result_type` kwarg is not yet supported.") - if kwargs: - raise ValueError("UDFs using **kwargs are not yet supported.") - return self._apply(func, *args) + return self._apply(func, _get_row_kernel, *args, **kwargs) @applyutils.doc_apply() def apply_rows( diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 69dc5389e7a..891f58657b0 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -45,7 +45,6 @@ ) from cudf.core.column_accessor import ColumnAccessor from cudf.core.join import Merge, MergeSemi -from cudf.core.udf.pipeline import compile_or_get, supported_cols_from_frame from cudf.core.window import Rolling from cudf.utils import ioutils from cudf.utils.docutils import copy_docstring @@ -1367,39 +1366,6 @@ def _quantiles( result._copy_type_metadata(self) return result - @annotate("APPLY", color="purple", domain="cudf_python") - def _apply(self, func, *args): - """ - Apply `func` across the rows of the frame. - """ - kernel, retty = compile_or_get(self, func, args) - - # Mask and data column preallocated - ans_col = cupy.empty(len(self), dtype=retty) - ans_mask = cudf.core.column.column_empty(len(self), dtype="bool") - launch_args = [(ans_col, ans_mask), len(self)] - offsets = [] - - # if compile_or_get succeeds, it is safe to create a kernel that only - # consumes the columns that are of supported dtype - for col in supported_cols_from_frame(self).values(): - data = col.data - mask = col.mask - if mask is None: - launch_args.append(data) - else: - launch_args.append((data, mask)) - offsets.append(col.offset) - launch_args += offsets - launch_args += list(args) - kernel.forall(len(self))(*launch_args) - - col = as_column(ans_col) - col.set_base_mask(libcudf.transform.bools_to_mask(ans_mask)) - result = cudf.Series._from_data({None: col}, self._index) - - return result - def rank( self, axis=0, diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 9b42aca00d0..59040e3ecbb 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -24,11 +24,12 @@ is_integer_dtype, is_list_like, ) -from cudf.core.column import arange +from cudf.core.column import arange, as_column from cudf.core.column_accessor import ColumnAccessor from cudf.core.frame import Frame from cudf.core.index import Index, RangeIndex, _index_from_columns from cudf.core.multiindex import MultiIndex +from cudf.core.udf.utils import _compile_or_get, _supported_cols_from_frame from cudf.utils.utils import cached_property doc_reset_index_template = """ @@ -756,6 +757,51 @@ def add_suffix(self, suffix): Use `Series.add_suffix` or `DataFrame.add_suffix`" ) + @annotate("APPLY", color="purple", domain="cudf_python") + def _apply(self, func, kernel_getter, *args, **kwargs): + """Apply `func` across the rows of the frame.""" + if kwargs: + raise ValueError("UDFs using **kwargs are not yet supported.") + + try: + kernel, retty = _compile_or_get( + self, func, args, kernel_getter=kernel_getter + ) + except Exception as e: + raise ValueError( + "user defined function compilation failed." + ) from e + + # Mask and data column preallocated + ans_col = cp.empty(len(self), dtype=retty) + ans_mask = cudf.core.column.column_empty(len(self), dtype="bool") + launch_args = [(ans_col, ans_mask), len(self)] + offsets = [] + + # if _compile_or_get succeeds, it is safe to create a kernel that only + # consumes the columns that are of supported dtype + for col in _supported_cols_from_frame(self).values(): + data = col.data + mask = col.mask + if mask is None: + launch_args.append(data) + else: + launch_args.append((data, mask)) + offsets.append(col.offset) + launch_args += offsets + launch_args += list(args) + + try: + kernel.forall(len(self))(*launch_args) + except Exception as e: + raise RuntimeError("UDF kernel execution failed.") from e + + col = as_column(ans_col) + col.set_base_mask(libcudf.transform.bools_to_mask(ans_mask)) + result = cudf.Series._from_data({None: col}, self._index) + + return result + def sort_values( self, by, diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 0371c40274f..61975d47af2 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -14,7 +14,6 @@ import cupy import numpy as np import pandas as pd -from numba import cuda from pandas._config import get_option import cudf @@ -67,6 +66,7 @@ doc_reset_index_template, ) from cudf.core.single_column_frame import SingleColumnFrame +from cudf.core.udf.scalar_function import _get_scalar_kernel from cudf.utils import cudautils, docutils from cudf.utils.docutils import copy_docstring from cudf.utils.dtypes import ( @@ -2374,7 +2374,7 @@ def apply(self, func, convert_dtype=True, args=(), **kwargs): by numba based on the function logic and argument types. See examples for details. args : tuple - Not supported + Positional arguments passed to func after the series value. **kwargs Not supported @@ -2440,20 +2440,9 @@ def apply(self, func, convert_dtype=True, args=(), **kwargs): 2 4.5 dtype: float64 """ - if args or kwargs: - raise ValueError( - "UDFs using *args or **kwargs are not yet supported." - ) - - # these functions are generally written as functions of scalar - # values rather than rows. Rather than writing an entirely separate - # numba kernel that is not built around a row object, its simpler - # to just turn this into the equivalent single column dataframe case - name = self.name or "__temp_srname" - df = cudf.DataFrame({name: self}) - f_ = cuda.jit(device=True)(func) - - return df.apply(lambda row: f_(row[name])) + if convert_dtype is not True: + raise ValueError("Series.apply only supports convert_dtype=True") + return self._apply(func, _get_scalar_kernel, *args, **kwargs) def applymap(self, udf, out_dtype=None): """Apply an elementwise function to transform the values in the Column. diff --git a/python/cudf/cudf/core/udf/pipeline.py b/python/cudf/cudf/core/udf/pipeline.py deleted file mode 100644 index 8e798de3bfe..00000000000 --- a/python/cudf/cudf/core/udf/pipeline.py +++ /dev/null @@ -1,390 +0,0 @@ -import math -from typing import Callable - -import cachetools -import numpy as np -from numba import cuda, typeof -from numba.np import numpy_support -from numba.types import Poison, Record, Tuple, boolean, int64, void -from nvtx import annotate - -from cudf.core.dtypes import CategoricalDtype -from cudf.core.udf.api import Masked, pack_return -from cudf.core.udf.typing import MaskedType -from cudf.utils import cudautils -from cudf.utils.dtypes import ( - BOOL_TYPES, - DATETIME_TYPES, - NUMERIC_TYPES, - TIMEDELTA_TYPES, -) - -libcudf_bitmask_type = numpy_support.from_dtype(np.dtype("int32")) -MASK_BITSIZE = np.dtype("int32").itemsize * 8 -precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32) - -JIT_SUPPORTED_TYPES = ( - NUMERIC_TYPES | BOOL_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES -) - - -def _is_jit_supported_type(dtype): - # category dtype isn't hashable - if isinstance(dtype, CategoricalDtype): - return False - return str(dtype) in JIT_SUPPORTED_TYPES - - -def all_dtypes_from_frame(frame): - return { - colname: col.dtype - if _is_jit_supported_type(col.dtype) - else np.dtype("O") - for colname, col in frame._data.items() - } - - -def supported_dtypes_from_frame(frame): - return { - colname: col.dtype - for colname, col in frame._data.items() - if _is_jit_supported_type(col.dtype) - } - - -def supported_cols_from_frame(frame): - return { - colname: col - for colname, col in frame._data.items() - if _is_jit_supported_type(col.dtype) - } - - -def generate_cache_key(frame, func: Callable): - """Create a cache key that uniquely identifies a compilation. - - A new compilation is needed any time any of the following things change: - - The UDF itself as defined in python by the user - - The types of the columns utilized by the UDF - - The existence of the input columns masks - """ - return ( - *cudautils.make_cache_key( - func, tuple(all_dtypes_from_frame(frame).values()) - ), - *(col.mask is None for col in frame._data.values()), - *frame._data.keys(), - ) - - -def get_frame_row_type(dtype): - """ - Get the numba `Record` type corresponding to a frame. - Models each column and its mask as a MaskedType and - models the row as a dictionary like data structure - containing these MaskedTypes. - - Large parts of this function are copied with comments - from the Numba internals and slightly modified to - account for validity bools to be present in the final - struct. - """ - - # Create the numpy structured type corresponding to the numpy dtype. - - fields = [] - offset = 0 - - sizes = [val[0].itemsize for val in dtype.fields.values()] - for i, (name, info) in enumerate(dtype.fields.items()): - # *info* consists of the element dtype, its offset from the beginning - # of the record, and an optional "title" containing metadata. - # We ignore the offset in info because its value assumes no masking; - # instead, we compute the correct offset based on the masked type. - elemdtype = info[0] - title = info[2] if len(info) == 3 else None - ty = numpy_support.from_dtype(elemdtype) - infos = { - "type": MaskedType(ty), - "offset": offset, - "title": title, - } - fields.append((name, infos)) - - # increment offset by itemsize plus one byte for validity - offset += elemdtype.itemsize + 1 - - # Align the next member of the struct to be a multiple of the - # memory access size, per PTX ISA 7.4/5.4.5 - if i < len(sizes) - 1: - next_itemsize = sizes[i + 1] - offset = int(math.ceil(offset / next_itemsize) * next_itemsize) - - # Numba requires that structures are aligned for the CUDA target - _is_aligned_struct = True - return Record(fields, offset, _is_aligned_struct) - - -@annotate("NUMBA JIT", color="green", domain="cudf_python") -def get_udf_return_type(frame, func: Callable, args=()): - - """ - Get the return type of a masked UDF for a given set of argument dtypes. It - is assumed that the function consumes a dictionary whose keys are strings - and whose values are of MaskedType. Initially assume that the UDF may be - written to utilize any field in the row - including those containing an - unsupported dtype. If an unsupported dtype is actually used in the function - the compilation should fail at `compile_udf`. If compilation succeeds, one - can infer that the function does not use any of the columns of unsupported - dtype - meaning we can drop them going forward and the UDF will still end - up getting fed rows containing all the fields it actually needs to use to - compute the answer for that row. - """ - - # present a row containing all fields to the UDF and try and compile - row_type = get_frame_row_type( - np.dtype(list(all_dtypes_from_frame(frame).items())) - ) - compile_sig = (row_type, *(typeof(arg) for arg in args)) - - # Get the return type. The PTX is also returned by compile_udf, but is not - # needed here. - ptx, output_type = cudautils.compile_udf(func, compile_sig) - if not isinstance(output_type, MaskedType): - numba_output_type = numpy_support.from_dtype(np.dtype(output_type)) - else: - numba_output_type = output_type - - return ( - numba_output_type - if not isinstance(numba_output_type, MaskedType) - else numba_output_type.value_type - ) - - -def masked_array_type_from_col(col): - """ - Return a type representing a tuple of arrays, - the first element an array of the numba type - corresponding to `dtype`, and the second an - array of bools representing a mask. - """ - nb_scalar_ty = numpy_support.from_dtype(col.dtype) - if col.mask is None: - return nb_scalar_ty[::1] - else: - return Tuple((nb_scalar_ty[::1], libcudf_bitmask_type[::1])) - - -def construct_signature(frame, return_type, args): - """ - Build the signature of numba types that will be used to - actually JIT the kernel itself later, accounting for types - and offsets. Skips columns with unsupported dtypes. - """ - - # Tuple of arrays, first the output data array, then the mask - return_type = Tuple((return_type[::1], boolean[::1])) - offsets = [] - sig = [return_type, int64] - for col in supported_cols_from_frame(frame).values(): - sig.append(masked_array_type_from_col(col)) - offsets.append(int64) - - # return_type, size, data, masks, offsets, extra args - sig = void(*(sig + offsets + [typeof(arg) for arg in args])) - - return sig - - -@cuda.jit(device=True) -def mask_get(mask, pos): - return (mask[pos // MASK_BITSIZE] >> (pos % MASK_BITSIZE)) & 1 - - -kernel_template = """\ -def _kernel(retval, size, {input_columns}, {input_offsets}, {extra_args}): - i = cuda.grid(1) - ret_data_arr, ret_mask_arr = retval - if i < size: - # Create a structured array with the desired fields - rows = cuda.local.array(1, dtype=row_type) - - # one element of that array - row = rows[0] - -{masked_input_initializers} -{row_initializers} - - # pass the assembled row into the udf - ret = f_(row, {extra_args}) - - # pack up the return values and set them - ret_masked = pack_return(ret) - ret_data_arr[i] = ret_masked.value - ret_mask_arr[i] = ret_masked.valid -""" - -unmasked_input_initializer_template = """\ - d_{idx} = input_col_{idx} - masked_{idx} = Masked(d_{idx}[i], True) -""" - -masked_input_initializer_template = """\ - d_{idx}, m_{idx} = input_col_{idx} - masked_{idx} = Masked(d_{idx}[i], mask_get(m_{idx}, i + offset_{idx})) -""" - -row_initializer_template = """\ - row["{name}"] = masked_{idx} -""" - - -def _define_function(frame, row_type, args): - """ - The kernel we want to JIT compile looks something like the following, - which is an example for two columns that both have nulls present - - def _kernel(retval, input_col_0, input_col_1, offset_0, offset_1, size): - i = cuda.grid(1) - ret_data_arr, ret_mask_arr = retval - if i < size: - rows = cuda.local.array(1, dtype=row_type) - row = rows[0] - - d_0, m_0 = input_col_0 - masked_0 = Masked(d_0[i], mask_get(m_0, i + offset_0)) - d_1, m_1 = input_col_1 - masked_1 = Masked(d_1[i], mask_get(m_1, i + offset_1)) - - row["a"] = masked_0 - row["b"] = masked_1 - - ret = f_(row) - - ret_masked = pack_return(ret) - ret_data_arr[i] = ret_masked.value - ret_mask_arr[i] = ret_masked.valid - - However we do not always have two columns and columns do not always have - an associated mask. Ideally, we would just write one kernel and make use - of `*args` - and then one function would work for any number of columns, - currently numba does not support `*args` and treats functions it JITs as - if `*args` is a singular argument. Thus we are forced to write the right - functions dynamically at runtime and define them using `exec`. - """ - # Create argument list for kernel - frame = supported_cols_from_frame(frame) - - input_columns = ", ".join([f"input_col_{i}" for i in range(len(frame))]) - input_offsets = ", ".join([f"offset_{i}" for i in range(len(frame))]) - extra_args = ", ".join([f"extra_arg_{i}" for i in range(len(args))]) - - # Generate the initializers for each device function argument - initializers = [] - row_initializers = [] - for i, (colname, col) in enumerate(frame.items()): - idx = str(i) - if col.mask is not None: - template = masked_input_initializer_template - else: - template = unmasked_input_initializer_template - - initializer = template.format(idx=idx) - - initializers.append(initializer) - - row_initializer = row_initializer_template.format( - idx=idx, name=colname - ) - row_initializers.append(row_initializer) - - # Incorporate all of the above into the kernel code template - d = { - "input_columns": input_columns, - "input_offsets": input_offsets, - "extra_args": extra_args, - "masked_input_initializers": "\n".join(initializers), - "row_initializers": "\n".join(row_initializers), - "numba_rectype": row_type, # from global - } - - return kernel_template.format(**d) - - -@annotate("UDF COMPILATION", color="darkgreen", domain="cudf_python") -def compile_or_get(frame, func, args): - """ - Return a compiled kernel in terms of MaskedTypes that launches a - kernel equivalent of `f` for the dtypes of `df`. The kernel uses - a thread for each row and calls `f` using that rows data / mask - to produce an output value and output validity for each row. - - If the UDF has already been compiled for this requested dtypes, - a cached version will be returned instead of running compilation. - - CUDA kernels are void and do not return values. Thus, we need to - preallocate a column of the correct dtype and pass it in as one of - the kernel arguments. This creates a chicken-and-egg problem where - we need the column type to compile the kernel, but normally we would - be getting that type FROM compiling the kernel (and letting numba - determine it as a return value). As a workaround, we compile the UDF - itself outside the final kernel to invoke a full typing pass, which - unfortunately is difficult to do without running full compilation. - we then obtain the return type from that separate compilation and - use it to allocate an output column of the right dtype. - """ - - # check to see if we already compiled this function - cache_key = generate_cache_key(frame, func) - if precompiled.get(cache_key) is not None: - kernel, masked_or_scalar = precompiled[cache_key] - return kernel, masked_or_scalar - - # precompile the user udf to get the right return type. - # could be a MaskedType or a scalar type. - scalar_return_type = get_udf_return_type(frame, func, args) - - # get_udf_return_type will throw a TypingError if the user tries to use - # a field in the row containing an unsupported dtype, except in the - # edge case where all the function does is return that element: - - # def f(row): - # return row[] - # In this case numba is happy to return MaskedType() - # because it relies on not finding overloaded operators for types to raise - # the exception, so we have to explicitly check for that case. - if isinstance(scalar_return_type, Poison): - raise TypeError(str(scalar_return_type)) - - # this is the signature for the final full kernel compilation - sig = construct_signature(frame, scalar_return_type, args) - - # this row type is used within the kernel to pack up the column and - # mask data into the dict like data structure the user udf expects - np_field_types = np.dtype(list(supported_dtypes_from_frame(frame).items())) - row_type = get_frame_row_type(np_field_types) - - f_ = cuda.jit(device=True)(func) - # Dict of 'local' variables into which `_kernel` is defined - local_exec_context = {} - global_exec_context = { - "f_": f_, - "cuda": cuda, - "Masked": Masked, - "mask_get": mask_get, - "pack_return": pack_return, - "row_type": row_type, - } - exec( - _define_function(frame, row_type, args), - global_exec_context, - local_exec_context, - ) - # The python function definition representing the kernel - _kernel = local_exec_context["_kernel"] - kernel = cuda.jit(sig)(_kernel) - np_return_type = numpy_support.as_dtype(scalar_return_type) - precompiled[cache_key] = (kernel, np_return_type) - - return kernel, np_return_type diff --git a/python/cudf/cudf/core/udf/row_function.py b/python/cudf/cudf/core/udf/row_function.py new file mode 100644 index 00000000000..5cda9fb8218 --- /dev/null +++ b/python/cudf/cudf/core/udf/row_function.py @@ -0,0 +1,151 @@ +import math + +import numpy as np +from numba import cuda +from numba.np import numpy_support +from numba.types import Record + +from cudf.core.udf.api import Masked, pack_return +from cudf.core.udf.templates import ( + masked_input_initializer_template, + row_initializer_template, + row_kernel_template, + unmasked_input_initializer_template, +) +from cudf.core.udf.typing import MaskedType +from cudf.core.udf.utils import ( + _all_dtypes_from_frame, + _construct_signature, + _get_kernel, + _get_udf_return_type, + _mask_get, + _supported_cols_from_frame, + _supported_dtypes_from_frame, +) + + +def _get_frame_row_type(dtype): + """ + Get the numba `Record` type corresponding to a frame. + Models each column and its mask as a MaskedType and + models the row as a dictionary like data structure + containing these MaskedTypes. + + Large parts of this function are copied with comments + from the Numba internals and slightly modified to + account for validity bools to be present in the final + struct. + + See numba.np.numpy_support.from_struct_dtype for details. + """ + + # Create the numpy structured type corresponding to the numpy dtype. + + fields = [] + offset = 0 + + sizes = [val[0].itemsize for val in dtype.fields.values()] + for i, (name, info) in enumerate(dtype.fields.items()): + # *info* consists of the element dtype, its offset from the beginning + # of the record, and an optional "title" containing metadata. + # We ignore the offset in info because its value assumes no masking; + # instead, we compute the correct offset based on the masked type. + elemdtype = info[0] + title = info[2] if len(info) == 3 else None + ty = numpy_support.from_dtype(elemdtype) + infos = { + "type": MaskedType(ty), + "offset": offset, + "title": title, + } + fields.append((name, infos)) + + # increment offset by itemsize plus one byte for validity + offset += elemdtype.itemsize + 1 + + # Align the next member of the struct to be a multiple of the + # memory access size, per PTX ISA 7.4/5.4.5 + if i < len(sizes) - 1: + next_itemsize = sizes[i + 1] + offset = int(math.ceil(offset / next_itemsize) * next_itemsize) + + # Numba requires that structures are aligned for the CUDA target + _is_aligned_struct = True + return Record(fields, offset, _is_aligned_struct) + + +def _row_kernel_string_from_template(frame, row_type, args): + """ + Function to write numba kernels for `DataFrame.apply` as a string. + Workaround until numba supports functions that use `*args` + + `DataFrame.apply` expects functions of a dict like row as well as + possibly one or more scalar arguments + + def f(row, c, k): + return (row['x'] + c) / k + + Both the number of input columns as well as their nullability and any + scalar arguments may vary, so the kernels vary significantly. See + templates.py for the full row kernel template and more details. + """ + # Create argument list for kernel + frame = _supported_cols_from_frame(frame) + + input_columns = ", ".join([f"input_col_{i}" for i in range(len(frame))]) + input_offsets = ", ".join([f"offset_{i}" for i in range(len(frame))]) + extra_args = ", ".join([f"extra_arg_{i}" for i in range(len(args))]) + + # Generate the initializers for each device function argument + initializers = [] + row_initializers = [] + for i, (colname, col) in enumerate(frame.items()): + idx = str(i) + template = ( + masked_input_initializer_template + if col.mask is not None + else unmasked_input_initializer_template + ) + initializers.append(template.format(idx=idx)) + row_initializers.append( + row_initializer_template.format(idx=idx, name=colname) + ) + + return row_kernel_template.format( + input_columns=input_columns, + input_offsets=input_offsets, + extra_args=extra_args, + masked_input_initializers="\n".join(initializers), + row_initializers="\n".join(row_initializers), + numba_rectype=row_type, + ) + + +def _get_row_kernel(frame, func, args): + row_type = _get_frame_row_type( + np.dtype(list(_all_dtypes_from_frame(frame).items())) + ) + scalar_return_type = _get_udf_return_type(row_type, func, args) + + # this is the signature for the final full kernel compilation + sig = _construct_signature(frame, scalar_return_type, args) + + # this row type is used within the kernel to pack up the column and + # mask data into the dict like data structure the user udf expects + np_field_types = np.dtype( + list(_supported_dtypes_from_frame(frame).items()) + ) + row_type = _get_frame_row_type(np_field_types) + + # Dict of 'local' variables into which `_kernel` is defined + global_exec_context = { + "cuda": cuda, + "Masked": Masked, + "_mask_get": _mask_get, + "pack_return": pack_return, + "row_type": row_type, + } + kernel_string = _row_kernel_string_from_template(frame, row_type, args) + kernel = _get_kernel(kernel_string, global_exec_context, sig, func) + + return kernel, scalar_return_type diff --git a/python/cudf/cudf/core/udf/scalar_function.py b/python/cudf/cudf/core/udf/scalar_function.py new file mode 100644 index 00000000000..7f3b461a1f0 --- /dev/null +++ b/python/cudf/cudf/core/udf/scalar_function.py @@ -0,0 +1,64 @@ +from numba import cuda +from numba.np import numpy_support + +from cudf.core.udf.api import Masked, pack_return +from cudf.core.udf.templates import ( + masked_input_initializer_template, + scalar_kernel_template, + unmasked_input_initializer_template, +) +from cudf.core.udf.typing import MaskedType +from cudf.core.udf.utils import ( + _construct_signature, + _get_kernel, + _get_udf_return_type, + _mask_get, +) + + +def _scalar_kernel_string_from_template(sr, args): + """ + Function to write numba kernels for `Series.apply` as a string. + Workaround until numba supports functions that use `*args` + + `Series.apply` expects functions of a single variable and possibly + one or more constants, such as: + + def f(x, c, k): + return (x + c) / k + + where the `x` are meant to be the values of the series. Since there + can be only one column, the only thing that varies in the kinds of + kernels that we want is the number of extra_args. See templates.py + for the full kernel template. + """ + extra_args = ", ".join([f"extra_arg_{i}" for i in range(len(args))]) + + masked_initializer = ( + masked_input_initializer_template + if sr._column.mask + else unmasked_input_initializer_template + ).format(idx=0) + + return scalar_kernel_template.format( + extra_args=extra_args, masked_initializer=masked_initializer + ) + + +def _get_scalar_kernel(sr, func, args): + sr_type = MaskedType(numpy_support.from_dtype(sr.dtype)) + scalar_return_type = _get_udf_return_type(sr_type, func, args) + + sig = _construct_signature(sr, scalar_return_type, args=args) + f_ = cuda.jit(device=True)(func) + global_exec_context = { + "f_": f_, + "cuda": cuda, + "Masked": Masked, + "_mask_get": _mask_get, + "pack_return": pack_return, + } + kernel_string = _scalar_kernel_string_from_template(sr, args=args) + kernel = _get_kernel(kernel_string, global_exec_context, sig, func) + + return kernel, scalar_return_type diff --git a/python/cudf/cudf/core/udf/templates.py b/python/cudf/cudf/core/udf/templates.py new file mode 100644 index 00000000000..8cb11133323 --- /dev/null +++ b/python/cudf/cudf/core/udf/templates.py @@ -0,0 +1,52 @@ +unmasked_input_initializer_template = """\ + d_{idx} = input_col_{idx} + masked_{idx} = Masked(d_{idx}[i], True) +""" + +masked_input_initializer_template = """\ + d_{idx}, m_{idx} = input_col_{idx} + masked_{idx} = Masked(d_{idx}[i], _mask_get(m_{idx}, i + offset_{idx})) +""" + +row_initializer_template = """\ + row["{name}"] = masked_{idx} +""" + +row_kernel_template = """\ +def _kernel(retval, size, {input_columns}, {input_offsets}, {extra_args}): + i = cuda.grid(1) + ret_data_arr, ret_mask_arr = retval + if i < size: + # Create a structured array with the desired fields + rows = cuda.local.array(1, dtype=row_type) + + # one element of that array + row = rows[0] + +{masked_input_initializers} +{row_initializers} + + # pass the assembled row into the udf + ret = f_(row, {extra_args}) + + # pack up the return values and set them + ret_masked = pack_return(ret) + ret_data_arr[i] = ret_masked.value + ret_mask_arr[i] = ret_masked.valid +""" + +scalar_kernel_template = """ +def _kernel(retval, size, input_col_0, offset_0, {extra_args}): + i = cuda.grid(1) + ret_data_arr, ret_mask_arr = retval + + if i < size: + +{masked_initializer} + + ret = f_(masked_0, {extra_args}) + + ret_masked = pack_return(ret) + ret_data_arr[i] = ret_masked.value + ret_mask_arr[i] = ret_masked.valid +""" diff --git a/python/cudf/cudf/core/udf/utils.py b/python/cudf/cudf/core/udf/utils.py new file mode 100644 index 00000000000..a98ee40274e --- /dev/null +++ b/python/cudf/cudf/core/udf/utils.py @@ -0,0 +1,216 @@ +from typing import Callable + +import cachetools +import numpy as np +from numba import cuda, typeof +from numba.core.errors import TypingError +from numba.np import numpy_support +from numba.types import Poison, Tuple, boolean, int64, void +from nvtx import annotate + +from cudf.core.dtypes import CategoricalDtype +from cudf.core.udf.typing import MaskedType +from cudf.utils import cudautils +from cudf.utils.dtypes import ( + BOOL_TYPES, + DATETIME_TYPES, + NUMERIC_TYPES, + TIMEDELTA_TYPES, +) + +JIT_SUPPORTED_TYPES = ( + NUMERIC_TYPES | BOOL_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES +) + +libcudf_bitmask_type = numpy_support.from_dtype(np.dtype("int32")) +MASK_BITSIZE = np.dtype("int32").itemsize * 8 + +precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32) + + +@annotate("NUMBA JIT", color="green", domain="cudf_python") +def _get_udf_return_type(argty, func: Callable, args=()): + """ + Get the return type of a masked UDF for a given set of argument dtypes. It + is assumed that the function consumes a dictionary whose keys are strings + and whose values are of MaskedType. Initially assume that the UDF may be + written to utilize any field in the row - including those containing an + unsupported dtype. If an unsupported dtype is actually used in the function + the compilation should fail at `compile_udf`. If compilation succeeds, one + can infer that the function does not use any of the columns of unsupported + dtype - meaning we can drop them going forward and the UDF will still end + up getting fed rows containing all the fields it actually needs to use to + compute the answer for that row. + """ + + # present a row containing all fields to the UDF and try and compile + compile_sig = (argty, *(typeof(arg) for arg in args)) + + # Get the return type. The PTX is also returned by compile_udf, but is not + # needed here. + ptx, output_type = cudautils.compile_udf(func, compile_sig) + if not isinstance(output_type, MaskedType): + numba_output_type = numpy_support.from_dtype(np.dtype(output_type)) + else: + numba_output_type = output_type + + result = ( + numba_output_type + if not isinstance(numba_output_type, MaskedType) + else numba_output_type.value_type + ) + + # _get_udf_return_type will throw a TypingError if the user tries to use + # a field in the row containing an unsupported dtype, except in the + # edge case where all the function does is return that element: + + # def f(row): + # return row[] + # In this case numba is happy to return MaskedType() + # because it relies on not finding overloaded operators for types to raise + # the exception, so we have to explicitly check for that case. + if isinstance(result, Poison): + raise TypingError(str(result)) + + return result + + +def _is_jit_supported_type(dtype): + # category dtype isn't hashable + if isinstance(dtype, CategoricalDtype): + return False + return str(dtype) in JIT_SUPPORTED_TYPES + + +def _all_dtypes_from_frame(frame): + return { + colname: col.dtype + if _is_jit_supported_type(col.dtype) + else np.dtype("O") + for colname, col in frame._data.items() + } + + +def _supported_dtypes_from_frame(frame): + return { + colname: col.dtype + for colname, col in frame._data.items() + if _is_jit_supported_type(col.dtype) + } + + +def _supported_cols_from_frame(frame): + return { + colname: col + for colname, col in frame._data.items() + if _is_jit_supported_type(col.dtype) + } + + +def _masked_array_type_from_col(col): + """ + Return a type representing a tuple of arrays, + the first element an array of the numba type + corresponding to `dtype`, and the second an + array of bools representing a mask. + """ + nb_scalar_ty = numpy_support.from_dtype(col.dtype) + if col.mask is None: + return nb_scalar_ty[::1] + else: + return Tuple((nb_scalar_ty[::1], libcudf_bitmask_type[::1])) + + +def _construct_signature(frame, return_type, args): + """ + Build the signature of numba types that will be used to + actually JIT the kernel itself later, accounting for types + and offsets. Skips columns with unsupported dtypes. + """ + + # Tuple of arrays, first the output data array, then the mask + return_type = Tuple((return_type[::1], boolean[::1])) + offsets = [] + sig = [return_type, int64] + for col in _supported_cols_from_frame(frame).values(): + sig.append(_masked_array_type_from_col(col)) + offsets.append(int64) + + # return_type, size, data, masks, offsets, extra args + sig = void(*(sig + offsets + [typeof(arg) for arg in args])) + + return sig + + +@cuda.jit(device=True) +def _mask_get(mask, pos): + """Return the validity of mask[pos] as a word.""" + return (mask[pos // MASK_BITSIZE] >> (pos % MASK_BITSIZE)) & 1 + + +def _generate_cache_key(frame, func: Callable): + """Create a cache key that uniquely identifies a compilation. + + A new compilation is needed any time any of the following things change: + - The UDF itself as defined in python by the user + - The types of the columns utilized by the UDF + - The existence of the input columns masks + """ + return ( + *cudautils.make_cache_key( + func, tuple(_all_dtypes_from_frame(frame).values()) + ), + *(col.mask is None for col in frame._data.values()), + *frame._data.keys(), + ) + + +@annotate("UDF COMPILATION", color="darkgreen", domain="cudf_python") +def _compile_or_get(frame, func, args, kernel_getter=None): + """ + Return a compiled kernel in terms of MaskedTypes that launches a + kernel equivalent of `f` for the dtypes of `df`. The kernel uses + a thread for each row and calls `f` using that rows data / mask + to produce an output value and output validity for each row. + + If the UDF has already been compiled for this requested dtypes, + a cached version will be returned instead of running compilation. + + CUDA kernels are void and do not return values. Thus, we need to + preallocate a column of the correct dtype and pass it in as one of + the kernel arguments. This creates a chicken-and-egg problem where + we need the column type to compile the kernel, but normally we would + be getting that type FROM compiling the kernel (and letting numba + determine it as a return value). As a workaround, we compile the UDF + itself outside the final kernel to invoke a full typing pass, which + unfortunately is difficult to do without running full compilation. + we then obtain the return type from that separate compilation and + use it to allocate an output column of the right dtype. + """ + + # check to see if we already compiled this function + cache_key = _generate_cache_key(frame, func) + if precompiled.get(cache_key) is not None: + kernel, masked_or_scalar = precompiled[cache_key] + return kernel, masked_or_scalar + + # precompile the user udf to get the right return type. + # could be a MaskedType or a scalar type. + + kernel, scalar_return_type = kernel_getter(frame, func, args) + + np_return_type = numpy_support.as_dtype(scalar_return_type) + precompiled[cache_key] = (kernel, np_return_type) + + return kernel, np_return_type + + +def _get_kernel(kernel_string, globals_, sig, func): + """template kernel compilation helper function""" + f_ = cuda.jit(device=True)(func) + globals_["f_"] = f_ + exec(kernel_string, globals_) + _kernel = globals_["_kernel"] + kernel = cuda.jit(sig)(_kernel) + + return kernel diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py index 7f2925f2f06..56090c8eacf 100644 --- a/python/cudf/cudf/tests/test_udf_masked_ops.py +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -13,7 +13,7 @@ comparison_ops, unary_ops, ) -from cudf.core.udf.pipeline import precompiled +from cudf.core.udf.utils import precompiled from cudf.testing._utils import NUMERIC_TYPES, _decimal_series, assert_eq @@ -486,7 +486,7 @@ def outer(row): {"a": [1, cudf.NA, 3, cudf.NA], "b": [1, 2, cudf.NA, cudf.NA]} ) - with pytest.raises(AttributeError): + with pytest.raises(ValueError): gdf.apply(outer, axis=1) pdf = gdf.to_pandas(nullable=True) @@ -539,7 +539,7 @@ def func(row): return row["unsupported_col"] # check that we fail when an unsupported type is used within a function - with pytest.raises(TypeError): + with pytest.raises(ValueError): data.apply(func, axis=1) # also check that a DF containing unsupported dtypes can still run a @@ -596,6 +596,44 @@ def func(row, c, k): run_masked_udf_test(func, data, args=(1, 2), check_dtype=False) +@pytest.mark.parametrize( + "data", + [ + [1, cudf.NA, 3], + [0.5, 2.0, cudf.NA, cudf.NA, 5.0], + [True, False, cudf.NA], + ], +) +@pytest.mark.parametrize("op", arith_ops + comparison_ops) +def test_mask_udf_scalar_args_binops_series(data, op): + data = cudf.Series(data) + + def func(x, c): + return x + c + + run_masked_udf_series(func, data, args=(1,), check_dtype=False) + + +@pytest.mark.parametrize( + "data", + [ + [1, cudf.NA, 3], + [0.5, 2.0, cudf.NA, cudf.NA, 5.0], + [True, False, cudf.NA], + ], +) +@pytest.mark.parametrize("op", arith_ops + comparison_ops) +def test_masked_udf_scalar_args_binops_multiple_series(data, op): + data = cudf.Series(data) + + def func(data, c, k): + x = op(data, c) + y = op(x, k) + return y + + run_masked_udf_series(func, data, args=(1, 2), check_dtype=False) + + def test_masked_udf_caching(): # Make sure similar functions that differ # by simple things like constants actually