diff --git a/cpp/cmake/Modules/JitifyPreprocessKernels.cmake b/cpp/cmake/Modules/JitifyPreprocessKernels.cmake index eb1ade61440..7e2ec5254d3 100644 --- a/cpp/cmake/Modules/JitifyPreprocessKernels.cmake +++ b/cpp/cmake/Modules/JitifyPreprocessKernels.cmake @@ -56,6 +56,7 @@ endfunction() jit_preprocess_files(SOURCE_DIRECTORY ${CUDF_SOURCE_DIR}/src FILES binaryop/jit/kernel.cu + transform/jit/masked_udf_kernel.cu transform/jit/kernel.cu rolling/jit/kernel.cu ) diff --git a/cpp/include/cudf/transform.hpp b/cpp/include/cudf/transform.hpp index 460c62e3598..f5880e9b37f 100644 --- a/cpp/include/cudf/transform.hpp +++ b/cpp/include/cudf/transform.hpp @@ -53,6 +53,12 @@ std::unique_ptr transform( bool is_ptx, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +std::unique_ptr generalized_masked_op( + table_view const& data_view, + std::string const& binary_udf, + data_type output_type, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + /** * @brief Creates a null_mask from `input` by converting `NaN` to null and * preserving existing null values and also returns new null_count. diff --git a/cpp/src/transform/jit/masked_udf_kernel.cu b/cpp/src/transform/jit/masked_udf_kernel.cu new file mode 100644 index 00000000000..319ad730c53 --- /dev/null +++ b/cpp/src/transform/jit/masked_udf_kernel.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace cudf { +namespace transformation { +namespace jit { + +template +struct Masked { + T value; + bool valid; +}; + +template +__device__ auto make_args(cudf::size_type id, TypeIn in_ptr, MaskType in_mask, OffsetType in_offset) +{ + bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true; + return cuda::std::make_tuple(in_ptr[id], valid); +} + +template +__device__ auto make_args(cudf::size_type id, + InType in_ptr, + MaskType in_mask, // in practice, always cudf::bitmask_type const* + OffsetType in_offset, // in practice, always cudf::size_type + Arguments... args) +{ + bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true; + return cuda::std::tuple_cat(cuda::std::make_tuple(in_ptr[id], valid), make_args(id, args...)); +} + +template +__global__ void generic_udf_kernel(cudf::size_type size, + TypeOut* out_data, + bool* out_mask, + Arguments... args) +{ + int const tid = threadIdx.x; + int const blkid = blockIdx.x; + int const blksz = blockDim.x; + int const gridsz = gridDim.x; + int const start = tid + blkid * blksz; + int const step = blksz * gridsz; + + Masked output; + for (cudf::size_type i = start; i < size; i += step) { + auto func_args = cuda::std::tuple_cat( + cuda::std::make_tuple(&output.value), + make_args(i, args...) 
// passed int64*, bool*, int64, int64*, bool*, int64 + ); + cuda::std::apply(GENERIC_OP, func_args); + out_data[i] = output.value; + out_mask[i] = output.valid; + } +} + +} // namespace jit +} // namespace transformation +} // namespace cudf diff --git a/cpp/src/transform/transform.cpp b/cpp/src/transform/transform.cpp index 40feab00b3c..5230b853a79 100644 --- a/cpp/src/transform/transform.cpp +++ b/cpp/src/transform/transform.cpp @@ -14,20 +14,22 @@ * limitations under the License. */ -#include - -#include -#include -#include - #include #include #include #include #include +#include #include #include +#include +#include + +#include +#include +#include + #include namespace cudf { @@ -63,6 +65,80 @@ void unary_operation(mutable_column_view output, cudf::jit::get_data_ptr(input)); } +std::vector make_template_types(column_view outcol_view, table_view const& data_view) +{ + std::string mskptr_type = + cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id())) + "*"; + std::string offset_type = + cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id())); + + std::vector template_types; + template_types.reserve((3 * data_view.num_columns()) + 1); + + template_types.push_back(cudf::jit::get_type_name(outcol_view.type())); + for (auto const& col : data_view) { + template_types.push_back(cudf::jit::get_type_name(col.type()) + "*"); + template_types.push_back(mskptr_type); + template_types.push_back(offset_type); + } + return template_types; +} + +void generalized_operation(table_view const& data_view, + std::string const& udf, + data_type output_type, + mutable_column_view outcol_view, + mutable_column_view outmsk_view, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto const template_types = make_template_types(outcol_view, data_view); + + std::string generic_kernel_name = + jitify2::reflection::Template("cudf::transformation::jit::generic_udf_kernel") + .instantiate(template_types); + + std::string generic_cuda_source = cudf::jit::parse_single_function_ptx( + udf, "GENERIC_OP", cudf::jit::get_type_name(output_type), {0}); + + std::vector kernel_args; + kernel_args.reserve((data_view.num_columns() * 3) + 3); + + cudf::size_type size = outcol_view.size(); + const void* outcol_ptr = cudf::jit::get_data_ptr(outcol_view); + const void* outmsk_ptr = cudf::jit::get_data_ptr(outmsk_view); + kernel_args.insert(kernel_args.begin(), {&size, &outcol_ptr, &outmsk_ptr}); + + std::vector data_ptrs; + std::vector mask_ptrs; + std::vector offsets; + + data_ptrs.reserve(data_view.num_columns()); + mask_ptrs.reserve(data_view.num_columns()); + offsets.reserve(data_view.num_columns()); + + auto const iters = thrust::make_zip_iterator( + thrust::make_tuple(data_ptrs.begin(), mask_ptrs.begin(), offsets.begin())); + + std::for_each(iters, iters + data_view.num_columns(), [&](auto const& tuple_vals) { + kernel_args.push_back(&thrust::get<0>(tuple_vals)); + kernel_args.push_back(&thrust::get<1>(tuple_vals)); + kernel_args.push_back(&thrust::get<2>(tuple_vals)); + }); + + std::transform(data_view.begin(), data_view.end(), iters, [&](column_view const& col) { + return thrust::make_tuple(cudf::jit::get_data_ptr(col), col.null_mask(), col.offset()); + }); + + cudf::jit::get_program_cache(*transform_jit_masked_udf_kernel_cu_jit) + .get_kernel(generic_kernel_name, + {}, + {{"transform/jit/operation-udf.hpp", generic_cuda_source}}, + {"-arch=sm_."}) + ->configure_1d_max_occupancy(0, 0, 0, stream.value()) + ->launch(kernel_args.data()); +} + } // namespace jit } // namespace transformation @@ 
-89,6 +165,24 @@ std::unique_ptr transform(column_view const& input, return output; } +std::unique_ptr generalized_masked_op(table_view const& data_view, + std::string const& udf, + data_type output_type, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + std::unique_ptr output = make_fixed_width_column(output_type, data_view.num_rows()); + std::unique_ptr output_mask = + make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8}, data_view.num_rows()); + + transformation::jit::generalized_operation( + data_view, udf, output_type, *output, *output_mask, stream, mr); + + auto final_output_mask = cudf::bools_to_mask(*output_mask); + output.get()->set_null_mask(std::move(*(final_output_mask.first))); + return output; +} + } // namespace detail std::unique_ptr transform(column_view const& input, @@ -101,4 +195,12 @@ std::unique_ptr transform(column_view const& input, return detail::transform(input, unary_udf, output_type, is_ptx, rmm::cuda_stream_default, mr); } +std::unique_ptr generalized_masked_op(table_view const& data_view, + std::string const& udf, + data_type output_type, + rmm::mr::device_memory_resource* mr) +{ + return detail::generalized_masked_op(data_view, udf, output_type, rmm::cuda_stream_default, mr); +} + } // namespace cudf diff --git a/python/cudf/cudf/_lib/cpp/transform.pxd b/python/cudf/cudf/_lib/cpp/transform.pxd index 5e37336cb94..9cb5bc10162 100644 --- a/python/cudf/cudf/_lib/cpp/transform.pxd +++ b/python/cudf/cudf/_lib/cpp/transform.pxd @@ -38,6 +38,12 @@ cdef extern from "cudf/transform.hpp" namespace "cudf" nogil: bool is_ptx ) except + + cdef unique_ptr[column] generalized_masked_op( + const table_view& data_view, + string udf, + data_type output_type, + ) except + + cdef pair[unique_ptr[table], unique_ptr[column]] encode( table_view input ) except + diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index 2c83f8b86e0..3ba9aac5687 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -123,6 +123,27 @@ def transform(Column input, op): return Column.from_unique_ptr(move(c_output)) +def masked_udf(Table incols, op, output_type): + cdef table_view data_view = incols.data_view() + cdef string c_str = op.encode("UTF-8") + cdef type_id c_tid + cdef data_type c_dtype + + c_tid = ( + np_to_cudf_types[output_type] + ) + c_dtype = data_type(c_tid) + + with nogil: + c_output = move(libcudf_transform.generalized_masked_op( + data_view, + c_str, + c_dtype, + )) + + return Column.from_unique_ptr(move(c_output)) + + def table_encode(Table input): cdef table_view c_input = input.data_view() cdef pair[unique_ptr[table], unique_ptr[column]] c_result diff --git a/python/cudf/cudf/core/__init__.py b/python/cudf/cudf/core/__init__.py index bf54a90cb01..5eaa5b52fd4 100644 --- a/python/cudf/cudf/core/__init__.py +++ b/python/cudf/cudf/core/__init__.py @@ -27,4 +27,5 @@ from cudf.core.multiindex import MultiIndex from cudf.core.scalar import NA, Scalar from cudf.core.series import Series +import cudf.core.udf from cudf.core.cut import cut diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index c02bf3d11a4..780466458cc 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -4697,6 +4697,177 @@ def query(self, expr, local_dict=None): boolmask = queryutils.query_execute(self, expr, callenv) return self._apply_boolean_mask(boolmask) + def apply(self, func, axis=1): + """ + Apply a function along an axis of the DataFrame. 
+ + Designed to mimic `pandas.DataFrame.apply`. Applies a user + defined function row wise over a dataframe, with true null + handling. Works with UDFs using `core.udf.pipeline.nulludf` + and returns a single series. Uses numba to jit compile the + function to PTX via LLVM. + + Parameters + ---------- + func : function + Function to apply to each row. + + axis : {0 or 'index', 1 or 'columns'}, default 0 + Axis along which the function is applied: + * 0 or 'index': apply function to each column. + Note: axis=0 is not yet supported. + * 1 or 'columns': apply function to each row. + + Examples + -------- + + Simple function of a single variable which could be NA + + >>> from cudf.core.udf.pipeline import nulludf + >>> @nulludf + ... def f(x): + ... if x is cudf.NA: + ... return 0 + ... else: + ... return x + 1 + ... + >>> df = cudf.DataFrame({'a': [1, cudf.NA, 3]}) + >>> df.apply(lambda row: f(row['a'])) + 0 2 + 1 0 + 2 4 + dtype: int64 + + Function of multiple variables will operate in + a null aware manner + + >>> @nulludf + ... def f(x, y): + ... return x - y + ... + >>> df = cudf.DataFrame({ + ... 'a': [1, cudf.NA, 3, cudf.NA], + ... 'b': [5, 6, cudf.NA, cudf.NA] + ... }) + >>> df.apply(lambda row: f(row['a'], row['b'])) + 0 -4 + 1 + 2 + 3 + dtype: int64 + + Functions may conditionally return NA as in pandas + + >>> @nulludf + ... def f(x, y): + ... if x + y > 3: + ... return cudf.NA + ... else: + ... return x + y + ... + >>> df = cudf.DataFrame({ + ... 'a': [1, 2, 3], + ... 'b': [2, 1, 1] + ... }) + >>> df.apply(lambda row: f(row['a'], row['b'])) + 0 3 + 1 3 + 2 + dtype: int64 + + Mixed types are allowed, but will return the common + type, rather than object as in pandas + + >>> @nulludf + ... def f(x, y): + ... return x + y + ... + >>> df = cudf.DataFrame({ + ... 'a': [1, 2, 3], + ... 'b': [0.5, cudf.NA, 3.14] + ... }) + >>> df.apply(lambda row: f(row['a'], row['b'])) + 0 1.5 + 1 + 2 6.14 + dtype: float64 + + Functions may also return scalar values, however the + result will be promoted to a safe type regardless of + the data + + >>> @nulludf + ... def f(x): + ... if x > 3: + ... return x + ... else: + ... return 1.5 + ... + >>> df = cudf.DataFrame({ + ... 'a': [1, 3, 5] + ... }) + >>> df.apply(lambda row: f(row['a'])) + 0 1.5 + 1 1.5 + 2 5.0 + dtype: float64 + + Ops against N columns are supported generally + + >>> @nulludf + ... def f(v, w, x, y, z): + ... return x + (y - (z / w)) % v + ... + >>> df = cudf.DataFrame({ + ... 'a': [1, 2, 3], + ... 'b': [4, 5, 6], + ... 'c': [cudf.NA, 4, 4], + ... 'd': [8, 7, 8], + ... 'e': [7, 1, 6] + ... }) + >>> df.apply( + ... lambda row: f( + ... row['a'], + ... row['b'], + ... row['c'], + ... row['d'], + ... row['e'] + ... ) + ... 
) + 0 + 1 4.8 + 2 5.0 + dtype: float64 + + Notes + ----- + Available only using cuda 11.1+ due to particular required + runtime compilation features + """ + + # libcudacxx tuples are not compatible with nvrtc 11.0 + runtime = cuda.cudadrv.runtime.Runtime() + mjr, mnr = runtime.get_version() + if mjr < 11 or (mjr == 11 and mnr < 1): + raise RuntimeError("DataFrame.apply requires CUDA 11.1+") + + for dtype in self.dtypes: + if ( + isinstance(dtype, cudf.core.dtypes._BaseDtype) + or dtype == "object" + ): + raise TypeError( + "DataFrame.apply currently only " + "supports non decimal numeric types" + ) + + if axis != 1: + raise ValueError( + "DataFrame.apply currently only supports row wise ops" + ) + + return cudf.Series(func(self)) + @applyutils.doc_apply() def apply_rows( self, diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 3629358ee9f..c01aa201b7c 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1456,6 +1456,17 @@ def _quantiles( result._copy_type_metadata(self) return result + @annotate("APPLY", color="purple", domain="cudf_python") + def _apply(self, func): + """ + Apply `func` across the rows of the frame. + """ + output_dtype, ptx = cudf.core.udf.pipeline.compile_masked_udf( + func, self.dtypes + ) + result = cudf._lib.transform.masked_udf(self, ptx, output_dtype) + return result + def rank( self, axis=0, diff --git a/python/cudf/cudf/core/udf/__init__.py b/python/cudf/cudf/core/udf/__init__.py new file mode 100644 index 00000000000..4608cae3228 --- /dev/null +++ b/python/cudf/cudf/core/udf/__init__.py @@ -0,0 +1 @@ +from . import typing, lowering diff --git a/python/cudf/cudf/core/udf/_ops.py b/python/cudf/cudf/core/udf/_ops.py new file mode 100644 index 00000000000..25201356fd9 --- /dev/null +++ b/python/cudf/cudf/core/udf/_ops.py @@ -0,0 +1,20 @@ +import operator + +arith_ops = [ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + operator.mod, + operator.pow, +] + +comparison_ops = [ + operator.eq, + operator.ne, + operator.lt, + operator.le, + operator.gt, + operator.ge, +] diff --git a/python/cudf/cudf/core/udf/classes.py b/python/cudf/cudf/core/udf/classes.py new file mode 100644 index 00000000000..fe2fbd9daad --- /dev/null +++ b/python/cudf/cudf/core/udf/classes.py @@ -0,0 +1,16 @@ +class Masked: + """ + Most of the time, MaskedType as defined in typing.py + combined with the ops defined to operate on them are + enough to fulfill the obligations of DataFrame.apply + However sometimes we need to refer to an instance of + a masked scalar outside the context of a UDF like as + a global variable. To get numba to identify that var + a of type MaskedType and treat it as such we need to + have an actual python class we can tie to MaskedType + This is that class + """ + + def __init__(self, value, valid): + self.value = value + self.valid = valid diff --git a/python/cudf/cudf/core/udf/lowering.py b/python/cudf/cudf/core/udf/lowering.py new file mode 100644 index 00000000000..1467a61f215 --- /dev/null +++ b/python/cudf/cudf/core/udf/lowering.py @@ -0,0 +1,273 @@ +import operator + +from llvmlite import ir +from numba.core import cgutils +from numba.core.typing import signature as nb_signature +from numba.cuda.cudaimpl import ( + lower as cuda_lower, + registry as cuda_lowering_registry, +) +from numba.extending import lower_builtin, types + +from cudf.core.udf.typing import MaskedType, NAType + +from . 
import classes +from ._ops import arith_ops, comparison_ops + + +@cuda_lowering_registry.lower_constant(NAType) +def constant_na(context, builder, ty, pyval): + # This handles None, etc. + return context.get_dummy_value() + + +# In the typing phase, we declared that a `MaskedType` can be +# added to another `MaskedType` and specified what kind of a +# `MaskedType` would result. Now we have to actually fill in +# the implementation details of how to do that. This is where +# we can involve both validities in constructing the answer + + +def make_arithmetic_op(op): + """ + Make closures that implement arithmetic operations. See + register_arithmetic_op for details. + """ + + def masked_scalar_op_impl(context, builder, sig, args): + """ + Implement `MaskedType` `MaskedType` + """ + # MaskedType(...), MaskedType(...) + masked_type_1, masked_type_2 = sig.args + # MaskedType(...) + masked_return_type = sig.return_type + + # Let there be two actual LLVM structs backing the two inputs + # https://mapping-high-level-constructs-to-llvm-ir.readthedocs.io/en/latest/basic-constructs/structures.html + m1 = cgutils.create_struct_proxy(masked_type_1)( + context, builder, value=args[0] + ) + m2 = cgutils.create_struct_proxy(masked_type_2)( + context, builder, value=args[1] + ) + + # we will return an output struct + result = cgutils.create_struct_proxy(masked_return_type)( + context, builder + ) + + # compute output validity + valid = builder.and_(m1.valid, m2.valid) + result.valid = valid + with builder.if_then(valid): + # Let numba handle generating the extra IR needed to perform + # operations on mixed types, by compiling the final core op between + # the two primitive values as a separate function and calling it + result.value = context.compile_internal( + builder, + lambda x, y: op(x, y), + nb_signature( + masked_return_type.value_type, + masked_type_1.value_type, + masked_type_2.value_type, + ), + (m1.value, m2.value), + ) + return result._getvalue() + + return masked_scalar_op_impl + + +def register_arithmetic_op(op): + """ + Register a lowering implementation for the + arithmetic op `op`. + + Because the lowering implementations compile the final + op separately using a lambda and compile_internal, `op` + needs to be tied to each lowering implementation using + a closure. + + This function makes and lowers a closure for one op. + + """ + to_lower_op = make_arithmetic_op(op) + cuda_lower(op, MaskedType, MaskedType)(to_lower_op) + + +def masked_scalar_null_op_impl(context, builder, sig, args): + """ + Implement `MaskedType` `NAType` + or `NAType` `MaskedType` + The answer to this is known up front so no actual operation + needs to take place + """ + + return_type = sig.return_type # MaskedType(...) 
+ result = cgutils.create_struct_proxy(MaskedType(return_type.value_type))( + context, builder + ) + + # Invalidate the struct and leave `value` uninitialized + result.valid = context.get_constant(types.boolean, 0) + return result._getvalue() + + +def make_const_op(op): + def masked_scalar_const_op_impl(context, builder, sig, args): + return_type = sig.return_type + result = cgutils.create_struct_proxy(return_type)(context, builder) + result.valid = context.get_constant(types.boolean, 0) + if isinstance(sig.args[0], MaskedType): + masked_type, const_type = sig.args + masked_value, const_value = args + + indata = cgutils.create_struct_proxy(masked_type)( + context, builder, value=masked_value + ) + nb_sig = nb_signature( + return_type.value_type, masked_type.value_type, const_type + ) + compile_args = (indata.value, const_value) + else: + const_type, masked_type = sig.args + const_value, masked_value = args + indata = cgutils.create_struct_proxy(masked_type)( + context, builder, value=masked_value + ) + nb_sig = nb_signature( + return_type.value_type, const_type, masked_type.value_type + ) + compile_args = (const_value, indata.value) + with builder.if_then(indata.valid): + result.value = context.compile_internal( + builder, lambda x, y: op(x, y), nb_sig, compile_args + ) + result.valid = context.get_constant(types.boolean, 1) + return result._getvalue() + + return masked_scalar_const_op_impl + + +def register_const_op(op): + to_lower_op = make_const_op(op) + cuda_lower(op, MaskedType, types.Number)(to_lower_op) + cuda_lower(op, types.Number, MaskedType)(to_lower_op) + + # to_lower_op_reflected = make_reflected_const_op(op) + # cuda_lower(op, types.Number, MaskedType)(to_lower_op_reflected) + + +# register all lowering at init +for op in arith_ops + comparison_ops: + register_arithmetic_op(op) + register_const_op(op) + # null op impl can be shared between all ops + cuda_lower(op, MaskedType, NAType)(masked_scalar_null_op_impl) + cuda_lower(op, NAType, MaskedType)(masked_scalar_null_op_impl) + + +@cuda_lower(operator.is_, MaskedType, NAType) +@cuda_lower(operator.is_, NAType, MaskedType) +def masked_scalar_is_null_impl(context, builder, sig, args): + """ + Implement `MaskedType` is `NA` + """ + if isinstance(sig.args[1], NAType): + masked_type, na = sig.args + value = args[0] + else: + na, masked_type = sig.args + value = args[1] + + indata = cgutils.create_struct_proxy(masked_type)( + context, builder, value=value + ) + result = cgutils.alloca_once(builder, ir.IntType(1)) + with builder.if_else(indata.valid) as (then, otherwise): + with then: + builder.store(context.get_constant(types.boolean, 0), result) + with otherwise: + builder.store(context.get_constant(types.boolean, 1), result) + + return builder.load(result) + + +@cuda_lower(operator.truth, MaskedType) +def masked_scalar_truth_impl(context, builder, sig, args): + indata = cgutils.create_struct_proxy(MaskedType(types.boolean))( + context, builder, value=args[0] + ) + return indata.value + + +@cuda_lower(bool, MaskedType) +def masked_scalar_bool_impl(context, builder, sig, args): + indata = cgutils.create_struct_proxy(MaskedType(types.boolean))( + context, builder, value=args[0] + ) + return indata.value + + +# To handle the unification, we need to support casting from any type to a +# masked type. The cast implementation takes the value passed in and returns +# a masked type struct wrapping that value. 
+@cuda_lowering_registry.lower_cast(types.Any, MaskedType) +def cast_primitive_to_masked(context, builder, fromty, toty, val): + casted = context.cast(builder, val, fromty, toty.value_type) + ext = cgutils.create_struct_proxy(toty)(context, builder) + ext.value = casted + ext.valid = context.get_constant(types.boolean, 1) + return ext._getvalue() + + +@cuda_lowering_registry.lower_cast(NAType, MaskedType) +def cast_na_to_masked(context, builder, fromty, toty, val): + result = cgutils.create_struct_proxy(toty)(context, builder) + result.valid = context.get_constant(types.boolean, 0) + + return result._getvalue() + + +@cuda_lowering_registry.lower_cast(MaskedType, MaskedType) +def cast_masked_to_masked(context, builder, fromty, toty, val): + """ + When numba encounters an op that expects a certain type and + the input to the op is not of the expected type it will try + to cast the input to the appropriate type. But, in our case + the input may be a MaskedType, which numba doesn't natively + know how to cast to a different MaskedType with a different + `value_type`. This implements and registers that cast. + """ + + # We will + operand = cgutils.create_struct_proxy(fromty)(context, builder, value=val) + casted = context.cast( + builder, operand.value, fromty.value_type, toty.value_type + ) + ext = cgutils.create_struct_proxy(toty)(context, builder) + ext.value = casted + ext.valid = operand.valid + return ext._getvalue() + + +# Masked constructor for use in a kernel for testing +@lower_builtin(classes.Masked, types.Number, types.boolean) +def masked_constructor(context, builder, sig, args): + ty = sig.return_type + value, valid = args + masked = cgutils.create_struct_proxy(ty)(context, builder) + masked.value = value + masked.valid = valid + return masked._getvalue() + + +# Allows us to make an instance of MaskedType a global variable +# and properly use it inside functions we will later compile +@cuda_lowering_registry.lower_constant(MaskedType) +def lower_constant_masked(context, builder, ty, val): + masked = cgutils.create_struct_proxy(ty)(context, builder) + masked.value = context.get_constant(ty.value_type, val.value) + masked.valid = context.get_constant(types.boolean, val.valid) + return masked._getvalue() diff --git a/python/cudf/cudf/core/udf/pipeline.py b/python/cudf/cudf/core/udf/pipeline.py new file mode 100644 index 00000000000..c7b8be92c00 --- /dev/null +++ b/python/cudf/cudf/core/udf/pipeline.py @@ -0,0 +1,52 @@ +from numba.np import numpy_support +from nvtx import annotate + +from cudf.core.udf.typing import MaskedType +from cudf.utils import cudautils + + +@annotate("NUMBA JIT", color="green", domain="cudf_python") +def compile_masked_udf(func, dtypes): + """ + Generate an inlineable PTX function that will be injected into + a variadic kernel inside libcudf + + assume all input types are `MaskedType(input_col.dtype)` and then + compile the requestied PTX function as a function over those types + """ + to_compiler_sig = tuple( + MaskedType(arg) + for arg in (numpy_support.from_dtype(np_type) for np_type in dtypes) + ) + # Get the inlineable PTX function + ptx, numba_output_type = cudautils.compile_udf(func, to_compiler_sig) + numpy_output_type = numpy_support.as_dtype(numba_output_type.value_type) + + return numpy_output_type, ptx + + +def nulludf(func): + """ + Mimic pandas API: + + def f(x, y): + return x + y + df.apply(lambda row: f(row['x'], row['y'])) + + in this scheme, `row` is actually the whole dataframe + `DataFrame` sends `self` in as `row` and subsequently + we end 
up calling `f` on the resulting columns since + the dataframe is dict-like + """ + + def wrapper(*args): + from cudf import DataFrame + + # This probably creates copies but is fine for now + to_udf_table = DataFrame( + {idx: arg for idx, arg in zip(range(len(args)), args)} + ) + # Frame._apply + return to_udf_table._apply(func) + + return wrapper diff --git a/python/cudf/cudf/core/udf/typing.py b/python/cudf/cudf/core/udf/typing.py new file mode 100644 index 00000000000..6e026412f24 --- /dev/null +++ b/python/cudf/cudf/core/udf/typing.py @@ -0,0 +1,294 @@ +import operator + +from numba import types +from numba.core.extending import ( + make_attribute_wrapper, + models, + register_model, + typeof_impl, +) +from numba.core.typing import signature as nb_signature +from numba.core.typing.templates import ( + AbstractTemplate, + AttributeTemplate, + ConcreteTemplate, +) +from numba.core.typing.typeof import typeof +from numba.cuda.cudadecl import registry as cuda_decl_registry +from pandas._libs.missing import NAType as _NAType + +from . import classes +from ._ops import arith_ops, comparison_ops + + +class MaskedType(types.Type): + """ + A Numba type consisting of a value of some primitive type + and a validity boolean, over which we can define math ops + """ + + def __init__(self, value): + # MaskedType in Numba shall be parameterized + # with a value type + if not isinstance(value, (types.Number, types.Boolean)): + raise TypeError("value_type must be a numeric scalar type") + self.value_type = value + super().__init__(name=f"Masked{self.value_type}") + + def __hash__(self): + """ + Needed so that numba caches type instances with different + `value_type` separately. + """ + return self.__repr__().__hash__() + + def unify(self, context, other): + """ + Often within a UDF an instance arises where a variable could + be a `MaskedType`, an `NAType`, or a literal based off + the data at runtime, for examplem the variable `ret` here: + + def f(x): + if x == 1: + ret = x + elif x > 2: + ret = 1 + else: + ret = cudf.NA + return ret + + When numba analyzes this function it will eventually figure + out that the variable `ret` could be any of the three types + from above. This scenario will only work if numba knows how + to find some kind of common type between the possibilities, + and this function implements that - the goal is to return a + common type when comparing `self` to other. + + """ + + # If we have Masked and NA, the output should be a + # MaskedType with the original type as its value_type + if isinstance(other, NAType): + return self + + # two MaskedType unify to a new MaskedType whose value_type + # is the result of unifying `self` and `other` `value_type` + elif isinstance(other, MaskedType): + return MaskedType( + context.unify_pairs(self.value_type, other.value_type) + ) + + # if we have MaskedType and something that results in a + # scalar, unify between the MaskedType's value_type + # and that other thing + unified = context.unify_pairs(self.value_type, other) + if unified is None: + # The value types don't unify, so there is no unified masked type + return None + + return MaskedType(unified) + + def __eq__(self, other): + # Equality is required for determining whether a cast is required + # between two different types. 
+ if not isinstance(other, MaskedType): + # Require a cast when the other type is not masked + return False + + # Require a cast for another masked with a different value type + return self.value_type == other.value_type + + +# For typing a Masked constant value defined outside a kernel (e.g. captured in +# a closure). +@typeof_impl.register(classes.Masked) +def typeof_masked(val, c): + return MaskedType(typeof(val.value)) + + +# Implemented typing for Masked(value, valid) - the construction of a Masked +# type in a kernel. +@cuda_decl_registry.register +class MaskedConstructor(ConcreteTemplate): + key = classes.Masked + + cases = [ + nb_signature(MaskedType(t), t, types.boolean) + for t in (types.integer_domain | types.real_domain) + ] + + +# Provide access to `m.value` and `m.valid` in a kernel for a Masked `m`. +make_attribute_wrapper(MaskedType, "value", "value") +make_attribute_wrapper(MaskedType, "valid", "valid") + + +# Typing for `classes.Masked` +@cuda_decl_registry.register_attr +class ClassesTemplate(AttributeTemplate): + key = types.Module(classes) + + def resolve_Masked(self, mod): + return types.Function(MaskedConstructor) + + +# Registration of the global is also needed for Numba to type classes.Masked +cuda_decl_registry.register_global(classes, types.Module(classes)) +# For typing bare Masked (as in `from .classes import Masked` +cuda_decl_registry.register_global( + classes.Masked, types.Function(MaskedConstructor) +) + + +# Tell numba how `MaskedType` is constructed on the backend in terms +# of primitive things that exist at the LLVM level +@register_model(MaskedType) +class MaskedModel(models.StructModel): + def __init__(self, dmm, fe_type): + # This struct has two members, a value and a validity + # let the type of the `value` field be the same as the + # `value_type` and let `valid` be a boolean + members = [("value", fe_type.value_type), ("valid", types.bool_)] + models.StructModel.__init__(self, dmm, fe_type, members) + + +class NAType(types.Type): + """ + A type for handling ops against nulls + Exists so we can: + 1. Teach numba that all occurances of `cudf.NA` are + to be read as instances of this type instead + 2. Define ops like `if x is cudf.NA` where `x` is of + type `Masked` to mean `if x.valid is False` + """ + + def __init__(self): + super().__init__(name="NA") + + def unify(self, context, other): + """ + Masked <-> NA is deferred to MaskedType.unify() + Literal <-> NA -> Masked + """ + if isinstance(other, MaskedType): + # bounce to MaskedType.unify + return None + elif isinstance(other, NAType): + # unify {NA, NA} -> NA + return self + else: + return MaskedType(other) + + +na_type = NAType() + + +@typeof_impl.register(_NAType) +def typeof_na(val, c): + """ + Tie instances of _NAType (cudf.NA) to our NAType. + Effectively make it so numba sees `cudf.NA` as an + instance of this NAType -> handle it accordingly. + """ + return na_type + + +register_model(NAType)(models.OpaqueModel) + + +# Ultimately, we want numba to produce PTX code that specifies how to implement +# an operation on two singular `Masked` structs together, which is defined +# as producing a new `Masked` with the right validity and if valid, +# the correct value. This happens in two phases: +# 1. Specify that `Masked` `Masked` exists and what it should return +# 2. 
Implement how to actually do (1) at the LLVM level +# The following code accomplishes (1) - it is really just a way of specifying +# that the has a CUDA overload that accepts two `Masked` that +# are parameterized with `value_type` and what flavor of `Masked` to return. +class MaskedScalarArithOp(AbstractTemplate): + def generic(self, args, kws): + """ + Typing for `Masked` `Masked` + Numba expects a valid numba type to be returned if typing is successful + else `None` signifies the error state (this pattern is commonly used + in Numba) + """ + if isinstance(args[0], MaskedType) and isinstance(args[1], MaskedType): + # In the case of op(Masked, Masked), the return type is a Masked + # such that Masked.value is the primitive type that would have + # been resolved if we were just operating on the + # `value_type`s. + return_type = self.context.resolve_function_type( + self.key, (args[0].value_type, args[1].value_type), kws + ).return_type + return nb_signature(MaskedType(return_type), args[0], args[1]) + + +class MaskedScalarNullOp(AbstractTemplate): + def generic(self, args, kws): + """ + Typing for `Masked` + `NA` + Handles situations like `x + cudf.NA` + """ + if isinstance(args[0], MaskedType) and isinstance(args[1], NAType): + # In the case of op(Masked, NA), the result has the same + # dtype as the original regardless of what it is + return nb_signature(args[0], args[0], na_type,) + elif isinstance(args[0], NAType) and isinstance(args[1], MaskedType): + return nb_signature(args[1], na_type, args[1]) + + +class MaskedScalarScalarOp(AbstractTemplate): + def generic(self, args, kws): + """ + Typing for `Masked` a scalar (and vice-versa). + handles situations like `x + 1` + """ + # In the case of op(Masked, scalar), we resolve the type between + # the Masked value_type and the scalar's type directly + if isinstance(args[0], MaskedType) and isinstance( + args[1], types.Number + ): + to_resolve_types = (args[0].value_type, args[1]) + elif isinstance(args[0], types.Number) and isinstance( + args[1], MaskedType + ): + to_resolve_types = (args[1].value_type, args[0]) + return_type = self.context.resolve_function_type( + self.key, to_resolve_types, kws + ).return_type + return nb_signature(MaskedType(return_type), args[0], args[1],) + + +@cuda_decl_registry.register_global(operator.is_) +class MaskedScalarIsNull(AbstractTemplate): + """ + Typing for `Masked is cudf.NA` + """ + + def generic(self, args, kws): + if isinstance(args[0], MaskedType) and isinstance(args[1], NAType): + return nb_signature(types.boolean, args[0], na_type) + elif isinstance(args[1], MaskedType) and isinstance(args[0], NAType): + return nb_signature(types.boolean, na_type, args[1]) + + +@cuda_decl_registry.register_global(operator.truth) +class MaskedScalarTruth(AbstractTemplate): + """ + Typing for `if Masked` + Used for `if x > y` + The truthiness of a MaskedType shall be the truthiness + of the `value` stored therein + """ + + def generic(self, args, kws): + if isinstance(args[0], MaskedType): + return nb_signature(types.boolean, MaskedType(types.boolean)) + + +for op in arith_ops + comparison_ops: + # Every op shares the same typing class + cuda_decl_registry.register_global(op)(MaskedScalarArithOp) + cuda_decl_registry.register_global(op)(MaskedScalarNullOp) + cuda_decl_registry.register_global(op)(MaskedScalarScalarOp) diff --git a/python/cudf/cudf/tests/test_extension_compilation.py b/python/cudf/cudf/tests/test_extension_compilation.py new file mode 100644 index 00000000000..e527fd0af17 --- /dev/null +++ 
b/python/cudf/cudf/tests/test_extension_compilation.py @@ -0,0 +1,309 @@ +import operator + +import pytest +from numba import cuda, types +from numba.cuda import compile_ptx + +from cudf import NA +from cudf.core.udf.classes import Masked +from cudf.core.udf.typing import MaskedType + +arith_ops = ( + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + operator.mod, + operator.pow, +) + +comparison_ops = ( + operator.lt, + operator.le, + operator.eq, + operator.ne, + operator.ge, + operator.gt, +) + +unary_ops = (operator.truth,) + +ops = arith_ops + comparison_ops + +number_types = ( + types.float32, + types.float64, + types.int8, + types.int16, + types.int32, + types.int64, + types.uint8, + types.uint16, + types.uint32, + types.uint64, +) + +QUICK = False + +if QUICK: + arith_ops = (operator.add, operator.truediv, operator.pow) + number_types = (types.int32, types.float32) + + +number_ids = tuple(str(t) for t in number_types) + + +@pytest.mark.parametrize("op", unary_ops) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +def test_compile_masked_unary(op, ty): + def func(x): + return op(x) + + cc = (7, 5) + ptx, resty = compile_ptx(func, (MaskedType(ty),), cc=cc, device=True) + + +@pytest.mark.parametrize("op", arith_ops) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +def test_execute_masked_binary(op, ty): + @cuda.jit(device=True) + def func(x, y): + return op(x, y) + + @cuda.jit(debug=True) + def test_kernel(x, y): + # Reference result with unmasked value + u = func(x, y) + + # Construct masked values to test with + x0, y0 = Masked(x, False), Masked(y, False) + x1, y1 = Masked(x, True), Masked(y, True) + + # Call with masked types + r0 = func(x0, y0) + r1 = func(x1, y1) + + # Check masks are as expected, and unmasked result matches masked + # result + if r0.valid: + raise RuntimeError("Expected r0 to be invalid") + if not r1.valid: + raise RuntimeError("Expected r1 to be valid") + if u != r1.value: + print("Values: ", u, r1.value) + raise RuntimeError("u != r1.value") + + test_kernel[1, 1](1, 2) + + +@pytest.mark.parametrize("op", ops) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +@pytest.mark.parametrize("constant", [1, 1.5]) +def test_compile_arith_masked_vs_constant(op, ty, constant): + def func(x): + return op(x, constant) + + cc = (7, 5) + ptx, resty = compile_ptx(func, (MaskedType(ty),), cc=cc, device=True) + + assert isinstance(resty, MaskedType) + + # Check that the masked typing matches that of the unmasked typing + um_ptx, um_resty = compile_ptx(func, (ty,), cc=cc, device=True) + assert resty.value_type == um_resty + + +@pytest.mark.parametrize("op", ops) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +@pytest.mark.parametrize("constant", [1, 1.5]) +def test_compile_arith_constant_vs_masked(op, ty, constant): + def func(x): + return op(constant, x) + + cc = (7, 5) + ptx, resty = compile_ptx(func, (MaskedType(ty),), cc=cc, device=True) + + assert isinstance(resty, MaskedType) + + +@pytest.mark.parametrize("op", ops) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +def test_compile_arith_masked_vs_na(op, ty): + def func(x): + return op(x, NA) + + cc = (7, 5) + ptx, resty = compile_ptx(func, (MaskedType(ty),), cc=cc, device=True) + + assert isinstance(resty, MaskedType) + + +@pytest.mark.parametrize("op", ops) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +def test_compile_arith_na_vs_masked(op, ty): + def func(x): + return op(NA, x) + + cc = (7, 5) + 
ptx, resty = compile_ptx(func, (MaskedType(ty),), cc=cc, device=True) + + +@pytest.mark.parametrize("op", ops) +@pytest.mark.parametrize("ty1", number_types, ids=number_ids) +@pytest.mark.parametrize("ty2", number_types, ids=number_ids) +@pytest.mark.parametrize( + "masked", + ((False, True), (True, False), (True, True)), + ids=("um", "mu", "mm"), +) +def test_compile_arith_masked_ops(op, ty1, ty2, masked): + def func(x, y): + return op(x, y) + + cc = (7, 5) + + if masked[0]: + ty1 = MaskedType(ty1) + if masked[1]: + ty2 = MaskedType(ty2) + + ptx, resty = compile_ptx(func, (ty1, ty2), cc=cc, device=True) + + +def func_x_is_na(x): + return x is NA + + +def func_na_is_x(x): + return NA is x + + +@pytest.mark.parametrize("fn", (func_x_is_na, func_na_is_x)) +def test_is_na(fn): + + valid = Masked(1, True) + invalid = Masked(1, False) + + device_fn = cuda.jit(device=True)(fn) + + @cuda.jit(debug=True) + def test_kernel(): + valid_is_na = device_fn(valid) + invalid_is_na = device_fn(invalid) + + if valid_is_na: + raise RuntimeError("Valid masked value is NA and should not be") + + if not invalid_is_na: + raise RuntimeError("Invalid masked value is not NA and should be") + + test_kernel[1, 1]() + + +def func_lt_na(x): + return x < NA + + +def func_gt_na(x): + return x > NA + + +def func_eq_na(x): + return x == NA + + +def func_ne_na(x): + return x != NA + + +def func_ge_na(x): + return x >= NA + + +def func_le_na(x): + return x <= NA + + +def func_na_lt(x): + return x < NA + + +def func_na_gt(x): + return x > NA + + +def func_na_eq(x): + return x == NA + + +def func_na_ne(x): + return x != NA + + +def func_na_ge(x): + return x >= NA + + +def func_na_le(x): + return x <= NA + + +na_comparison_funcs = ( + func_lt_na, + func_gt_na, + func_eq_na, + func_ne_na, + func_ge_na, + func_le_na, + func_na_lt, + func_na_gt, + func_na_eq, + func_na_ne, + func_na_ge, + func_na_le, +) + + +@pytest.mark.parametrize("fn", na_comparison_funcs) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +def test_na_masked_comparisons(fn, ty): + + device_fn = cuda.jit(device=True)(fn) + + @cuda.jit(debug=True) + def test_kernel(): + unmasked = ty(1) + valid_masked = Masked(unmasked, True) + invalid_masked = Masked(unmasked, False) + + valid_cmp_na = device_fn(valid_masked) + invalid_cmp_na = device_fn(invalid_masked) + + if valid_cmp_na: + raise RuntimeError("Valid masked value compared True with NA") + + if invalid_cmp_na: + raise RuntimeError("Invalid masked value compared True with NA") + + test_kernel[1, 1]() + + +# xfail because scalars do not yet cast for a comparison to NA +@pytest.mark.xfail +@pytest.mark.parametrize("fn", na_comparison_funcs) +@pytest.mark.parametrize("ty", number_types, ids=number_ids) +def test_na_scalar_comparisons(fn, ty): + + device_fn = cuda.jit(device=True)(fn) + + @cuda.jit(debug=True) + def test_kernel(): + unmasked = ty(1) + + unmasked_cmp_na = device_fn(unmasked) + + if unmasked_cmp_na: + raise RuntimeError("Unmasked value compared True with NA") + + test_kernel[1, 1]() diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py new file mode 100644 index 00000000000..f73f1526c7f --- /dev/null +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -0,0 +1,292 @@ +import operator + +import pandas as pd +import pytest +from numba import cuda + +import cudf +from cudf.core.udf.pipeline import nulludf +from cudf.testing._utils import NUMERIC_TYPES, assert_eq + +arith_ops = [ + operator.add, + operator.sub, + operator.mul, + 
operator.truediv, + operator.floordiv, + operator.mod, + pytest.param( + operator.pow, + marks=pytest.mark.xfail( + reason="https://github.com/rapidsai/cudf/issues/8470" + ), + ), +] + +comparison_ops = [ + operator.eq, + operator.ne, + operator.lt, + operator.le, + operator.gt, + operator.ge, +] + + +def run_masked_udf_test(func_pdf, func_gdf, data, **kwargs): + + # Skip testing CUDA 11.0 + runtime = cuda.cudadrv.runtime.Runtime() + mjr, mnr = runtime.get_version() + if mjr < 11 or (mjr == 11 and mnr < 1): + pytest.skip("Skip testing for CUDA 11.0") + + gdf = data + pdf = data.to_pandas(nullable=True) + + expect = pdf.apply( + lambda row: func_pdf(*[row[i] for i in data.columns]), axis=1 + ) + obtain = gdf.apply( + lambda row: func_gdf(*[row[i] for i in data.columns]), axis=1 + ) + assert_eq(expect, obtain, **kwargs) + + +@pytest.mark.parametrize("op", arith_ops) +def test_arith_masked_vs_masked(op): + # This test should test all the typing + # and lowering for arithmetic ops between + # two columns + def func_pdf(x, y): + return op(x, y) + + @nulludf + def func_gdf(x, y): + return op(x, y) + + gdf = cudf.DataFrame({"a": [1, None, 3, None], "b": [4, 5, None, None]}) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +@pytest.mark.parametrize("op", comparison_ops) +def test_compare_masked_vs_masked(op): + # this test should test all the + # typing and lowering for comparisons + # between columns + + def func_pdf(x, y): + return op(x, y) + + @nulludf + def func_gdf(x, y): + return op(x, y) + + # we should get: + # [?, ?, , , ] + gdf = cudf.DataFrame( + {"a": [1, 0, None, 1, None], "b": [0, 1, 0, None, None]} + ) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +@pytest.mark.parametrize("op", arith_ops) +@pytest.mark.parametrize("constant", [1, 1.5]) +def test_arith_masked_vs_constant(op, constant): + def func_pdf(x): + return op(x, constant) + + @nulludf + def func_gdf(x): + return op(x, constant) + + # Just a single column -> result will be all NA + gdf = cudf.DataFrame({"data": [1, 2, None]}) + + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +@pytest.mark.parametrize("op", arith_ops) +@pytest.mark.parametrize("constant", [1, 1.5]) +def test_arith_masked_vs_constant_reflected(op, constant): + def func_pdf(x): + return op(constant, x) + + @nulludf + def func_gdf(x): + return op(constant, x) + + # Just a single column -> result will be all NA + gdf = cudf.DataFrame({"data": [1, 2, None]}) + + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +@pytest.mark.parametrize("op", arith_ops) +def test_arith_masked_vs_null(op): + def func_pdf(x): + return op(x, pd.NA) + + @nulludf + def func_gdf(x): + return op(x, cudf.NA) + + gdf = cudf.DataFrame({"data": [1, None, 3]}) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +@pytest.mark.parametrize("op", arith_ops) +def test_arith_masked_vs_null_reflected(op): + def func_pdf(x): + return op(pd.NA, x) + + @nulludf + def func_gdf(x): + return op(cudf.NA, x) + + gdf = cudf.DataFrame({"data": [1, None, 3]}) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +def test_masked_is_null_conditional(): + def func_pdf(x, y): + if x is pd.NA: + return y + else: + return x + y + + @nulludf + def func_gdf(x, y): + if x is cudf.NA: + return y + else: + return x + y + + gdf = cudf.DataFrame({"a": [1, None, 3, None], "b": [4, 5, None, None]}) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + 
+@pytest.mark.parametrize("dtype_a", list(NUMERIC_TYPES)) +@pytest.mark.parametrize("dtype_b", list(NUMERIC_TYPES)) +def test_apply_mixed_dtypes(dtype_a, dtype_b): + """ + Test that operations can be performed between columns + of different dtypes and return a column with the correct + values and nulls + """ + # TODO: Parameterize over the op here + def func_pdf(x, y): + return x + y + + @nulludf + def func_gdf(x, y): + return x + y + + gdf = cudf.DataFrame({"a": [1.5, None, 3, None], "b": [4, 5, None, None]}) + gdf["a"] = gdf["a"].astype(dtype_a) + gdf["b"] = gdf["b"].astype(dtype_b) + + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +@pytest.mark.parametrize("val", [5, 5.5]) +def test_apply_return_literal(val): + """ + Test unification codepath for scalars and MaskedType + makes sure that numba knows how to cast a scalar value + to a MaskedType + """ + + def func_pdf(x, y): + if x is not pd.NA and x < 2: + return val + else: + return x + y + + @nulludf + def func_gdf(x, y): + if x is not cudf.NA and x < 2: + return val + else: + return x + y + + gdf = cudf.DataFrame({"a": [1, None, 3, None], "b": [4, 5, None, None]}) + + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +def test_apply_return_null(): + """ + Tests casting / unification of Masked and NA + """ + + def func_pdf(x): + if x is pd.NA: + return pd.NA + else: + return x + + @nulludf + def func_gdf(x): + if x is cudf.NA: + return cudf.NA + else: + return x + + gdf = cudf.DataFrame({"a": [1, None, 3]}) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +def test_apply_return_either_null_or_literal(): + def func_pdf(x): + if x > 5: + return 2 + else: + return pd.NA + + @nulludf + def func_gdf(x): + if x > 5: + return 2 + else: + return cudf.NA + + gdf = cudf.DataFrame({"a": [1, 3, 6]}) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) + + +def test_apply_everything(): + def func_pdf(w, x, y, z): + if x is pd.NA: + return w + y - z + elif ((z > y) is not pd.NA) and z > y: + return x + elif ((x + y) is not pd.NA) and x + y == 0: + return z / x + elif x + y is pd.NA: + return 2.5 + else: + return y > 2 + + @nulludf + def func_gdf(w, x, y, z): + if x is cudf.NA: + return w + y - z + elif ((z > y) is not cudf.NA) and z > y: + return x + elif ((x + y) is not cudf.NA) and x + y == 0: + return z / x + elif x + y is cudf.NA: + return 2.5 + else: + return y > 2 + + gdf = cudf.DataFrame( + { + "a": [1, 3, 6, 0, None, 5, None], + "b": [3.0, 2.5, None, 5.0, 1.0, 5.0, 11.0], + "c": [2, 3, 6, 0, None, 5, None], + "d": [4, None, 6, 0, None, 5, None], + } + ) + run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False) diff --git a/python/cudf/cudf/utils/cudautils.py b/python/cudf/cudf/utils/cudautils.py index 262fe304dd8..312fbc425dd 100755 --- a/python/cudf/cudf/utils/cudautils.py +++ b/python/cudf/cudf/utils/cudautils.py @@ -262,10 +262,13 @@ def compile_udf(udf, type_signature): ptx_code, return_type = cuda.compile_ptx_for_current_device( udf, type_signature, device=True ) - output_type = numpy_support.as_dtype(return_type) + if not isinstance(return_type, cudf.core.udf.typing.MaskedType): + output_type = numpy_support.as_dtype(return_type).type + else: + output_type = return_type # Populate the cache for this function - res = (ptx_code, output_type.type) + res = (ptx_code, output_type) _udf_code_cache[key] = res return res