diff --git a/docs/cudf/source/user_guide/guide-to-udfs.ipynb b/docs/cudf/source/user_guide/guide-to-udfs.ipynb index 3cd1ac2e0c4..67cc90f9236 100644 --- a/docs/cudf/source/user_guide/guide-to-udfs.ipynb +++ b/docs/cudf/source/user_guide/guide-to-udfs.ipynb @@ -22,7 +22,7 @@ "- CuPy NDArrays\n", "- Numba DeviceNDArrays\n", "\n", - "It also demonstrates cuDF's default null handling behavior, and how to write UDFs that can interact with null values in a limited fashion. Finally, it demonstrates some newer more general null handling via the `DataFrame.apply` API." + "It also demonstrates cuDF's default null handling behavior, and how to write UDFs that can interact with null values in a limited fashion. Finally, it demonstrates some newer more general null handling via the `apply` API." ] }, { @@ -1447,20 +1447,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "More general support for `NA` handling is provided on an experimental basis. While the details of the way this works are out of scope of this guide, the broad strokes of the pipeline are similar to those of `Series.applymap`: Numba is used to translate a standard python function into an operation on the data columns and their masks, and then the reduced and optimized version of this function is runtime compiled and called using the data. \n", + "More general support for `NA` handling is provided on an experimental basis. Numba is used to translate a standard python function into an operation on the data columns and their masks, and then the reduced and optimized version of this function is runtime compiled and called using the data. \n", "\n", - "One advantage of this approach apart from the ability to handle nulls generally in an intuitive manner is it results in a very familiar API to Pandas users. Let's see how this works with an example.\n", - "\n", - "The key to accessing this API is a decorator: `cudf.core.udf.pipeline.nulludf`:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "from cudf.core.udf.pipeline import nulludf" + "One advantage of this approach apart from the ability to handle nulls generally in an intuitive manner is it results in a very familiar API to Pandas users. Let's see how this works with an example." ] }, { @@ -1472,7 +1461,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -1527,7 +1516,7 @@ "2 3 6" ] }, - "execution_count": 25, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -1544,18 +1533,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The entrypoint for UDFs used in this manner is `cudf.DataFrame.apply`. To use it, start by defining a completely standard python function decorated with the decorator `nulludf`:" + "The entrypoint for UDFs used in this manner is `cudf.DataFrame.apply`. To use it, start by defining a standard python function designed to accept a single dict-like row of the dataframe:" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ - "@nulludf\n", - "def f(x, y):\n", - " return x + y" + "def f(row):\n", + " return row['A'] + row['B']" ] }, { @@ -1567,7 +1555,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -1579,31 +1567,25 @@ "dtype: int64" ] }, - "execution_count": 27, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df.apply(\n", - " lambda row: f(\n", - " row['A'],\n", - " row['B']\n", - " ),\n", - " axis=1\n", - ")" + "df.apply(f, axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Advanced users might recognize that cuDF does not actually have a `row` object (a special type of Pandas series that behaves like a dict). The `nulludf` decorator is the key to making this work - it really just rearranges things nicely such that the API works in this way. The same function works the same way in pandas, except without the decorator of course:" + "The same function should produce the same result as pandas:" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -1615,22 +1597,13 @@ "dtype: object" ] }, - "execution_count": 28, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "def g(x, y):\n", - " return x + y\n", - "\n", - "df.to_pandas(nullable=True).apply(\n", - " lambda row: g(\n", - " row['A'],\n", - " row['B']\n", - " ),\n", - " axis=1\n", - ")" + "df.to_pandas(nullable=True).apply(f, axis=1)" ] }, { @@ -1649,7 +1622,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -1700,14 +1673,14 @@ "2 3" ] }, - "execution_count": 29, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "@nulludf\n", - "def f(x):\n", + "def f(row):\n", + " x = row['a']\n", " if x is cudf.NA:\n", " return 0\n", " else:\n", @@ -1719,7 +1692,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1731,13 +1704,13 @@ "dtype: int64" ] }, - "execution_count": 30, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df.apply(lambda row: f(row['a']))" + "df.apply(f, axis=1)" ] }, { @@ -1749,7 +1722,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -1804,14 +1777,15 @@ "2 3 1" ] }, - "execution_count": 31, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "@nulludf\n", - "def f(x, y):\n", + "def f(row):\n", + " x = row['a']\n", + " y = row['b']\n", " if x + y > 3:\n", " return cudf.NA\n", " else:\n", @@ -1826,7 +1800,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -1838,13 +1812,13 @@ "dtype: int64" ] }, - "execution_count": 32, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df.apply(lambda row: f(row['a'], row['b']))" + "df.apply(f, axis=1)" ] }, { @@ -1856,7 +1830,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1911,15 +1885,14 @@ "2 3 3.14" ] }, - "execution_count": 33, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "@nulludf\n", - "def f(x, y):\n", - " return x + y\n", + "def f(row):\n", + " return row['a'] + row['b']\n", "\n", "df = cudf.DataFrame({\n", " 'a': [1, 2, 3], \n", @@ -1930,7 +1903,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -1942,13 +1915,13 @@ "dtype: float64" ] }, - "execution_count": 34, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df.apply(lambda row: f(row['a'], row['b']))" + "df.apply(f, axis=1)" ] }, { @@ -1973,7 +1946,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -2024,14 +1997,14 @@ "2 5" ] }, - "execution_count": 35, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "@nulludf\n", - "def f(x):\n", + "def f(row):\n", + " x = row['a']\n", " if x > 3:\n", " return x\n", " else:\n", @@ -2045,7 +2018,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -2057,13 +2030,13 @@ "dtype: float64" ] }, - "execution_count": 36, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df.apply(lambda row: f(row['a']))" + "df.apply(f, axis=1)" ] }, { @@ -2075,7 +2048,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -2142,16 +2115,14 @@ "2 3 6 4 8 6" ] }, - "execution_count": 37, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "\n", - "@nulludf\n", - "def f(v, w, x, y, z):\n", - " return x + (y - (z / w)) % v\n", + "def f(row):\n", + " return row['a'] + (row['b'] - (row['c'] / row['d'])) % row['e']\n", "\n", "df = cudf.DataFrame({\n", " 'a': [1, 2, 3],\n", @@ -2165,33 +2136,110 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0 \n", - "1 4.8\n", - "2 5.0\n", + "0 \n", + "1 2.428571429\n", + "2 8.5\n", "dtype: float64" ] }, - "execution_count": 38, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.apply(f, axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# `cudf.Series.apply`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "cuDF provides a similar API to `pandas.Series.apply` for applying scalar UDFs to series objects. Like pandas, these UDFs do not need to be written in terms of rows. These UDFs have generalized null handling and are slightly more flexible than those that work with `applymap`. Ultimately, `applymap` will be deprecated and removed in favor of `apply`. Here is an example: " + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a cuDF series\n", + "sr = cudf.Series([1, cudf.NA, 3])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "# define a scalar function\n", + "def f(x):\n", + " if x is cudf.NA:\n", + " return 42\n", + " else:\n", + " return 2**x" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 2\n", + "1 42\n", + "2 8\n", + "dtype: int64" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sr.apply(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 2\n", + "1 42\n", + "2 8\n", + "dtype: int64" + ] + }, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df.apply(\n", - " lambda row: f(\n", - " row['a'],\n", - " row['b'],\n", - " row['c'],\n", - " row['d'],\n", - " row['e']\n", - " )\n", - ")" + "# Check the pandas result\n", + "sr.to_pandas(nullable=True).apply(f)" ] }, { @@ -2206,16 +2254,7 @@ "metadata": {}, "source": [ "- Only numeric nondecimal scalar types are currently supported as of yet, but strings and structured types are in planning. Attempting to use this API with those types will throw a `TypeError`.\n", - "- Due to some more recent CUDA features being leveraged in the pipeline, support for CUDA 11.0 is currently unavailable. In particular, the 11.1+ toolkit will be needed, else the API will raise.\n", - "- We do not yet fully support all arithmetic operators. Certain ops like bitwise operations are not currently implemented, but planned in future releases. If an operator is needed, a github issue should be raised so that it can be properly prioritized and implemented.\n", - "- Due to limitations in the Numba's output is currently runtime compiled, we can't yet support certain functions:\n", - " - `pow`\n", - " - `sin`\n", - " - `cos`\n", - " - `tan`\n", - " \n", - " Attempting to use these functions inside a UDF will result in an NVRTC error.\n", - " " + "- We do not yet fully support all arithmetic operators. Certain ops like bitwise operations are not currently implemented, but planned in future releases. If an operator is needed, a github issue should be raised so that it can be properly prioritized and implemented." ] }, { @@ -2255,7 +2294,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.12" } }, "nbformat": 4, diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 5c7aabd18fd..00bc67a24c0 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -4856,7 +4856,7 @@ def apply( if args or kwargs: raise ValueError("args and kwargs are not yet supported.") - return cudf.Series(func(self)) + return self._apply(func) @applyutils.doc_apply() def apply_rows( diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 6eb12870956..0dc5c6ceb14 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -13,6 +13,7 @@ import cupy import numpy as np import pandas as pd +from numba import cuda from pandas._config import get_option import cudf @@ -3372,7 +3373,10 @@ def apply(self, func, convert_dtype=True, args=(), **kwargs): Notes ----- UDFs are cached in memory to avoid recompilation. The first - call to the UDF will incur compilation overhead. + call to the UDF will incur compilation overhead. `func` may + call nested functions that are decorated with the decorator + `numba.cuda.jit(device=True)`, otherwise numba will raise a + typing error. Examples -------- @@ -3425,16 +3429,21 @@ def apply(self, func, convert_dtype=True, args=(), **kwargs): 1 2 4.5 dtype: float64 - - - """ if args or kwargs: raise ValueError( "UDFs using *args or **kwargs are not yet supported." ) - return super()._apply(func) + # these functions are generally written as functions of scalar + # values rather than rows. Rather than writing an entirely separate + # numba kernel that is not built around a row object, its simpler + # to just turn this into the equivalent single column dataframe case + name = self.name or "__temp_srname" + df = cudf.DataFrame({name: self}) + f_ = cuda.jit(device=True)(func) + + return df.apply(lambda row: f_(row[name])) def applymap(self, udf, out_dtype=None): """Apply an elementwise function to transform the values in the Column. diff --git a/python/cudf/cudf/core/udf/lowering.py b/python/cudf/cudf/core/udf/lowering.py index 3986abc2bf0..7af73537aa8 100644 --- a/python/cudf/cudf/core/udf/lowering.py +++ b/python/cudf/cudf/core/udf/lowering.py @@ -269,6 +269,7 @@ def cast_masked_to_masked(context, builder, fromty, toty, val): # Masked constructor for use in a kernel for testing +@lower_builtin(api.Masked, types.Boolean, types.boolean) @lower_builtin(api.Masked, types.Number, types.boolean) def masked_constructor(context, builder, sig, args): ty = sig.return_type diff --git a/python/cudf/cudf/core/udf/pipeline.py b/python/cudf/cudf/core/udf/pipeline.py index 7f3aa7baa93..b52668fcd05 100644 --- a/python/cudf/cudf/core/udf/pipeline.py +++ b/python/cudf/cudf/core/udf/pipeline.py @@ -1,8 +1,10 @@ +import math + import cachetools import numpy as np from numba import cuda from numba.np import numpy_support -from numba.types import Tuple, boolean, int64, void +from numba.types import Record, Tuple, boolean, int64, void from nvtx import annotate from cudf.core.udf.api import Masked, pack_return @@ -14,21 +16,67 @@ precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32) +def get_frame_row_type(fr): + """ + Get the numba `Record` type corresponding to a frame. + Models each column and its mask as a MaskedType and + models the row as a dictionary like data structure + containing these MaskedTypes. + + Large parts of this function are copied with comments + from the Numba internals and slightly modified to + account for validity bools to be present in the final + struct. + """ + + # Create the numpy structured type corresponding to the frame. + dtype = np.dtype([(name, col.dtype) for name, col in fr._data.items()]) + + fields = [] + offset = 0 + + sizes = [val[0].itemsize for val in dtype.fields.values()] + for i, (name, info) in enumerate(dtype.fields.items()): + # *info* consists of the element dtype, its offset from the beginning + # of the record, and an optional "title" containing metadata. + # We ignore the offset in info because its value assumes no masking; + # instead, we compute the correct offset based on the masked type. + elemdtype = info[0] + title = info[2] if len(info) == 3 else None + ty = numpy_support.from_dtype(elemdtype) + infos = { + "type": MaskedType(ty), + "offset": offset, + "title": title, + } + fields.append((name, infos)) + + # increment offset by itemsize plus one byte for validity + offset += elemdtype.itemsize + 1 + + # Align the next member of the struct to be a multiple of the + # memory access size, per PTX ISA 7.4/5.4.5 + if i < len(sizes) - 1: + next_itemsize = sizes[i + 1] + offset = int(math.ceil(offset / next_itemsize) * next_itemsize) + + # Numba requires that structures are aligned for the CUDA target + _is_aligned_struct = True + return Record(fields, offset, _is_aligned_struct) + + @annotate("NUMBA JIT", color="green", domain="cudf_python") -def get_udf_return_type(func, dtypes): +def get_udf_return_type(func, df): """ Get the return type of a masked UDF for a given set of argument dtypes. It is assumed that a `MaskedType(dtype)` is passed to the function for each input dtype. """ - to_compiler_sig = tuple( - MaskedType(arg) - for arg in (numpy_support.from_dtype(np_type) for np_type in dtypes) - ) + row_type = get_frame_row_type(df) + # Get the return type. The PTX is also returned by compile_udf, but is not # needed here. - ptx, output_type = cudautils.compile_udf(func, to_compiler_sig) - + ptx, output_type = cudautils.compile_udf(func, (row_type,)) if not isinstance(output_type, MaskedType): numba_output_type = numpy_support.from_dtype(np.dtype(output_type)) else: @@ -37,33 +85,6 @@ def get_udf_return_type(func, dtypes): return numba_output_type -def nulludf(func): - """ - Mimic pandas API: - - def f(x, y): - return x + y - df.apply(lambda row: f(row['x'], row['y'])) - - in this scheme, `row` is actually the whole dataframe - `DataFrame` sends `self` in as `row` and subsequently - we end up calling `f` on the resulting columns since - the dataframe is dict-like - """ - - def wrapper(*args): - from cudf import DataFrame - - # This probably creates copies but is fine for now - to_udf_table = DataFrame( - {idx: arg for idx, arg in zip(range(len(args)), args)} - ) - # Frame._apply - return to_udf_table._apply(func) - - return wrapper - - def masked_array_type_from_col(col): """ Return a type representing a tuple of arrays, @@ -109,8 +130,19 @@ def _kernel(retval, {input_columns}, {input_offsets}, size): i = cuda.grid(1) ret_data_arr, ret_mask_arr = retval if i < size: + # Create a structured array with the desired fields + rows = cuda.local.array(1, dtype=row_type) + + # one element of that array + row = rows[0] + {masked_input_initializers} - ret = {user_udf_call} +{row_initializers} + + # pass the assembled row into the udf + ret = f_(row) + + # pack up the return values and set them ret_masked = pack_return(ret) ret_data_arr[i] = ret_masked.value ret_mask_arr[i] = ret_masked.valid @@ -126,19 +158,52 @@ def _kernel(retval, {input_columns}, {input_offsets}, size): masked_{idx} = Masked(d_{idx}[i], mask_get(m_{idx}, i + offset_{idx})) """ +row_initializer_template = """\ + row["{name}"] = masked_{idx} +""" -def _define_function(df, scalar_return=False): - # Create argument list for kernel - input_columns = ", ".join([f"input_col_{i}" for i in range(len(df._data))]) - input_offsets = ", ".join([f"offset_{i}" for i in range(len(df._data))]) - # Create argument list to pass to device function - args = ", ".join([f"masked_{i}" for i in range(len(df._data))]) - user_udf_call = f"f_({args})" +def _define_function(fr, row_type, scalar_return=False): + """ + The kernel we want to JIT compile looks something like the following, + which is an example for two columns that both have nulls present + + def _kernel(retval, input_col_0, input_col_1, offset_0, offset_1, size): + i = cuda.grid(1) + ret_data_arr, ret_mask_arr = retval + if i < size: + rows = cuda.local.array(1, dtype=row_type) + row = rows[0] + + d_0, m_0 = input_col_0 + masked_0 = Masked(d_0[i], mask_get(m_0, i + offset_0)) + d_1, m_1 = input_col_1 + masked_1 = Masked(d_1[i], mask_get(m_1, i + offset_1)) + + row["a"] = masked_0 + row["b"] = masked_1 + + ret = f_(row) + + ret_masked = pack_return(ret) + ret_data_arr[i] = ret_masked.value + ret_mask_arr[i] = ret_masked.valid + + However we do not always have two columns and columns do not always have + an associated mask. Ideally, we would just write one kernel and make use + of `*args` - and then one function would work for any number of columns, + currently numba does not support `*args` and treats functions it JITs as + if `*args` is a singular argument. Thus we are forced to write the right + funtions dynamically at runtime and define them using `exec`. + """ + # Create argument list for kernel + input_columns = ", ".join([f"input_col_{i}" for i in range(len(fr._data))]) + input_offsets = ", ".join([f"offset_{i}" for i in range(len(fr._data))]) # Generate the initializers for each device function argument initializers = [] - for i, col in enumerate(df._data.values()): + row_initializers = [] + for i, (colname, col) in enumerate(fr._data.items()): idx = str(i) if col.mask is not None: template = masked_input_initializer_template @@ -149,14 +214,21 @@ def _define_function(df, scalar_return=False): initializers.append(initializer) + row_initializer = row_initializer_template.format( + idx=idx, name=colname + ) + row_initializers.append(row_initializer) + masked_input_initializers = "\n".join(initializers) + row_initializers = "\n".join(row_initializers) # Incorporate all of the above into the kernel code template d = { "input_columns": input_columns, "input_offsets": input_offsets, "masked_input_initializers": masked_input_initializers, - "user_udf_call": user_udf_call, + "row_initializers": row_initializers, + "numba_rectype": row_type, # from global } return kernel_template.format(**d) @@ -173,6 +245,16 @@ def compile_or_get(df, f): If the UDF has already been compiled for this requested dtypes, a cached version will be returned instead of running compilation. + CUDA kernels are void and do not return values. Thus, we need to + preallocate a column of the correct dtype and pass it in as one of + the kernel arguments. This creates a chicken-and-egg problem where + we need the column type to compile the kernel, but normally we would + be getting that type FROM compiling the kernel (and letting numba + determine it as a return value). As a workaround, we compile the UDF + itself outside the final kernel to invoke a full typing pass, which + unfortunately is difficult to do without running full compilation. + we then obtain the return type from that separate compilation and + use it to allocate an output column of the right dtype. """ # check to see if we already compiled this function @@ -180,12 +262,15 @@ def compile_or_get(df, f): cache_key = ( *cudautils.make_cache_key(f, frame_dtypes), *(col.mask is None for col in df._data.values()), + *df._data.keys(), ) if precompiled.get(cache_key) is not None: kernel, scalar_return_type = precompiled[cache_key] return kernel, scalar_return_type - numba_return_type = get_udf_return_type(f, frame_dtypes) + # precompile the user udf to get the right return type. + # could be a MaskedType or a scalar type. + numba_return_type = get_udf_return_type(f, df) _is_scalar_return = not isinstance(numba_return_type, MaskedType) scalar_return_type = ( @@ -194,9 +279,14 @@ def compile_or_get(df, f): else numba_return_type.value_type ) + # this is the signature for the final full kernel compilation sig = construct_signature(df, scalar_return_type) - f_ = cuda.jit(device=True)(f) + # this row type is used within the kernel to pack up the column and + # mask data into the dict like data structure the user udf expects + row_type = get_frame_row_type(df) + + f_ = cuda.jit(device=True)(f) # Dict of 'local' variables into which `_kernel` is defined local_exec_context = {} global_exec_context = { @@ -205,9 +295,10 @@ def compile_or_get(df, f): "Masked": Masked, "mask_get": mask_get, "pack_return": pack_return, + "row_type": row_type, } exec( - _define_function(df, scalar_return=_is_scalar_return), + _define_function(df, row_type, scalar_return=_is_scalar_return), global_exec_context, local_exec_context, ) diff --git a/python/cudf/cudf/core/udf/typing.py b/python/cudf/cudf/core/udf/typing.py index 042d97db838..06ba1b5ad5c 100644 --- a/python/cudf/cudf/core/udf/typing.py +++ b/python/cudf/cudf/core/udf/typing.py @@ -111,10 +111,9 @@ def typeof_masked(val, c): @cuda_decl_registry.register class MaskedConstructor(ConcreteTemplate): key = api.Masked - cases = [ nb_signature(MaskedType(t), t, types.boolean) - for t in (types.integer_domain | types.real_domain) + for t in (types.integer_domain | types.real_domain | {types.boolean}) ] diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py index c0018dae47d..0d767b6d2b3 100644 --- a/python/cudf/cudf/tests/test_udf_masked_ops.py +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -2,9 +2,9 @@ import pandas as pd import pytest +from numba import cuda import cudf -from cudf.core.udf.pipeline import nulludf from cudf.testing._utils import NUMERIC_TYPES, assert_eq arith_ops = [ @@ -31,12 +31,8 @@ def run_masked_udf_test(func_pdf, func_gdf, data, **kwargs): gdf = data pdf = data.to_pandas(nullable=True) - expect = pdf.apply( - lambda row: func_pdf(*[row[i] for i in data.columns]), axis=1 - ) - obtain = gdf.apply( - lambda row: func_gdf(*[row[i] for i in data.columns]), axis=1 - ) + expect = pdf.apply(func_pdf, axis=1) + obtain = gdf.apply(func_gdf, axis=1) assert_eq(expect, obtain, **kwargs) @@ -54,11 +50,14 @@ def test_arith_masked_vs_masked(op): # This test should test all the typing # and lowering for arithmetic ops between # two columns - def func_pdf(x, y): + def func_pdf(row): + x = row["a"] + y = row["b"] return op(x, y) - @nulludf - def func_gdf(x, y): + def func_gdf(row): + x = row["a"] + y = row["b"] return op(x, y) gdf = cudf.DataFrame({"a": [1, None, 3, None], "b": [4, 5, None, None]}) @@ -71,11 +70,14 @@ def test_compare_masked_vs_masked(op): # typing and lowering for comparisons # between columns - def func_pdf(x, y): + def func_pdf(row): + x = row["a"] + y = row["b"] return op(x, y) - @nulludf - def func_gdf(x, y): + def func_gdf(row): + x = row["a"] + y = row["b"] return op(x, y) # we should get: @@ -90,11 +92,12 @@ def func_gdf(x, y): @pytest.mark.parametrize("constant", [1, 1.5, True, False]) @pytest.mark.parametrize("data", [[1, 2, cudf.NA]]) def test_arith_masked_vs_constant(op, constant, data): - def func_pdf(x): + def func_pdf(row): + x = row["data"] return op(x, constant) - @nulludf - def func_gdf(x): + def func_gdf(row): + x = row["data"] return op(x, constant) gdf = cudf.DataFrame({"data": data}) @@ -119,11 +122,12 @@ def func_gdf(x): @pytest.mark.parametrize("constant", [1, 1.5, True, False]) @pytest.mark.parametrize("data", [[2, 3, cudf.NA], [1, cudf.NA, 1]]) def test_arith_masked_vs_constant_reflected(op, constant, data): - def func_pdf(x): + def func_pdf(row): + x = row["data"] return op(constant, x) - @nulludf - def func_gdf(x): + def func_gdf(row): + x = row["data"] return op(constant, x) # Just a single column -> result will be all NA @@ -141,11 +145,12 @@ def func_gdf(x): @pytest.mark.parametrize("op", arith_ops) @pytest.mark.parametrize("data", [[1, cudf.NA, 3], [2, 3, cudf.NA]]) def test_arith_masked_vs_null(op, data): - def func_pdf(x): + def func_pdf(row): + x = row["data"] return op(x, pd.NA) - @nulludf - def func_gdf(x): + def func_gdf(row): + x = row["data"] return op(x, cudf.NA) gdf = cudf.DataFrame({"data": data}) @@ -158,11 +163,12 @@ def func_gdf(x): @pytest.mark.parametrize("op", arith_ops) def test_arith_masked_vs_null_reflected(op): - def func_pdf(x): + def func_pdf(row): + x = row["data"] return op(pd.NA, x) - @nulludf - def func_gdf(x): + def func_gdf(row): + x = row["data"] return op(cudf.NA, x) gdf = cudf.DataFrame({"data": [1, None, 3]}) @@ -170,14 +176,17 @@ def func_gdf(x): def test_masked_is_null_conditional(): - def func_pdf(x, y): + def func_pdf(row): + x = row["a"] + y = row["b"] if x is pd.NA: return y else: return x + y - @nulludf - def func_gdf(x, y): + def func_gdf(row): + x = row["a"] + y = row["b"] if x is cudf.NA: return y else: @@ -196,11 +205,14 @@ def test_apply_mixed_dtypes(dtype_a, dtype_b): values and nulls """ # TODO: Parameterize over the op here - def func_pdf(x, y): + def func_pdf(row): + x = row["a"] + y = row["b"] return x + y - @nulludf - def func_gdf(x, y): + def func_gdf(row): + x = row["a"] + y = row["b"] return x + y gdf = cudf.DataFrame({"a": [1.5, None, 3, None], "b": [4, 5, None, None]}) @@ -218,14 +230,17 @@ def test_apply_return_literal(val): to a MaskedType """ - def func_pdf(x, y): + def func_pdf(row): + x = row["a"] + y = row["b"] if x is not pd.NA and x < 2: return val else: return x + y - @nulludf - def func_gdf(x, y): + def func_gdf(row): + x = row["a"] + y = row["b"] if x is not cudf.NA and x < 2: return val else: @@ -241,14 +256,15 @@ def test_apply_return_null(): Tests casting / unification of Masked and NA """ - def func_pdf(x): + def func_pdf(row): + x = row["a"] if x is pd.NA: return pd.NA else: return x - @nulludf - def func_gdf(x): + def func_gdf(row): + x = row["a"] if x is cudf.NA: return cudf.NA else: @@ -259,14 +275,15 @@ def func_gdf(x): def test_apply_return_either_null_or_literal(): - def func_pdf(x): + def func_pdf(row): + x = row["a"] if x > 5: return 2 else: return pd.NA - @nulludf - def func_gdf(x): + def func_gdf(row): + x = row["a"] if x > 5: return 2 else: @@ -280,7 +297,6 @@ def test_apply_return_literal_only(): def func_pdf(x): return 5 - @nulludf def func_gdf(x): return 5 @@ -289,7 +305,11 @@ def func_gdf(x): def test_apply_everything(): - def func_pdf(w, x, y, z): + def func_pdf(row): + w = row["a"] + x = row["b"] + y = row["c"] + z = row["d"] if x is pd.NA: return w + y - z elif ((z > y) is not pd.NA) and z > y: @@ -301,8 +321,11 @@ def func_pdf(w, x, y, z): else: return y > 2 - @nulludf - def func_gdf(w, x, y, z): + def func_gdf(row): + w = row["a"] + x = row["b"] + y = row["c"] + z = row["d"] if x is cudf.NA: return w + y - z elif ((z > y) is not cudf.NA) and z > y: @@ -430,3 +453,69 @@ def func_gsr(x): data = cudf.Series([1, cudf.NA, 3, cudf.NA]) run_masked_udf_series(func_psr, func_gsr, data, check_dtype=False) + + +@pytest.mark.parametrize("op", arith_ops + comparison_ops) +def test_masked_udf_lambda_support(op): + func = lambda row: op(row["a"], row["b"]) # noqa: E731 + + data = cudf.DataFrame( + {"a": [1, cudf.NA, 3, cudf.NA], "b": [1, 2, cudf.NA, cudf.NA]} + ) + + run_masked_udf_test(func, func, data, check_dtype=False) + + +@pytest.mark.parametrize("op", arith_ops + comparison_ops) +def test_masked_udf_nested_function_support(op): + """ + Nested functions need to be explicitly jitted by the user + for numba to recognize them. Unfortunately the object + representing the jitted function can not itself be used in + pandas udfs. + """ + + def inner(x, y): + return op(x, y) + + def outer(row): + x = row["a"] + y = row["b"] + return inner(x, y) + + data = cudf.DataFrame( + {"a": [1, cudf.NA, 3, cudf.NA], "b": [1, 2, cudf.NA, cudf.NA]} + ) + + with pytest.raises(AttributeError): + run_masked_udf_test(outer, outer, data, check_dtype=False) + + inner_gpu = cuda.jit(device=True)(inner) + + def outer_gpu(row): + x = row["a"] + y = row["b"] + return inner_gpu(x, y) + + run_masked_udf_test(outer, outer_gpu, data, check_dtype=False) + + +@pytest.mark.parametrize( + "data", + [ + {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}, + {"a": [1, 2, 3], "c": [4, 5, 6], "b": [7, 8, 9]}, + pytest.param( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]}, + marks=pytest.mark.xfail( + reason="Until cudf/9359 is merged, this will fail" + ), + ), + ], +) +def test_masked_udf_subset_selection(data): + def func(row): + return row["a"] + row["b"] + + data = cudf.DataFrame(data) + run_masked_udf_test(func, func, data) diff --git a/python/cudf/cudf/utils/cudautils.py b/python/cudf/cudf/utils/cudautils.py index 7b7fe674210..5fa091a0081 100755 --- a/python/cudf/cudf/utils/cudautils.py +++ b/python/cudf/cudf/utils/cudautils.py @@ -211,6 +211,10 @@ def grouped_window_sizes_from_offset(arr, group_starts, offset): def make_cache_key(udf, sig): + """ + Build a cache key for a user defined function. Used to avoid + recompiling the same function for the same set of types + """ codebytes = udf.__code__.co_code if udf.__closure__ is not None: cvars = tuple([x.cell_contents for x in udf.__closure__]) @@ -252,8 +256,6 @@ def compile_udf(udf, type_signature): """ import cudf.core.udf - # Check if we've already compiled a similar (but possibly distinct) - # function before key = make_cache_key(udf, type_signature) res = _udf_code_cache.get(key) if res: