diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/binaryop.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/binaryop.rst new file mode 100644 index 00000000000..e5bc6aa7cda --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/binaryop.rst @@ -0,0 +1,6 @@ +======== +binaryop +======== + +.. automodule:: cudf._lib.pylibcudf.binaryop + :members: diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst index 435278afeeb..7504295de92 100644 --- a/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst @@ -8,6 +8,7 @@ This page provides API documentation for pylibcudf. :maxdepth: 1 :caption: API Documentation + binaryop column copying gpumemoryview diff --git a/python/cudf/cudf/_lib/binaryop.pxd b/python/cudf/cudf/_lib/binaryop.pxd deleted file mode 100644 index 1f6022251b3..00000000000 --- a/python/cudf/cudf/_lib/binaryop.pxd +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. - -from libc.stdint cimport int32_t - -ctypedef int32_t underlying_type_t_binary_operator diff --git a/python/cudf/cudf/_lib/binaryop.pyx b/python/cudf/cudf/_lib/binaryop.pyx index 6212347b5b1..969be426044 100644 --- a/python/cudf/cudf/_lib/binaryop.pyx +++ b/python/cudf/cudf/_lib/binaryop.pyx @@ -1,160 +1,30 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. -from enum import IntEnum - -from libcpp.memory cimport unique_ptr -from libcpp.string cimport string -from libcpp.utility cimport move - -from cudf._lib.binaryop cimport underlying_type_t_binary_operator from cudf._lib.column cimport Column - -from cudf._lib.scalar import as_device_scalar - from cudf._lib.scalar cimport DeviceScalar +from cudf._lib.types cimport dtype_to_pylibcudf_type -from cudf._lib.types import SUPPORTED_NUMPY_TO_LIBCUDF_TYPES - -from cudf._lib.cpp.column.column cimport column -from cudf._lib.cpp.column.column_view cimport column_view -from cudf._lib.cpp.scalar.scalar cimport scalar -from cudf._lib.cpp.types cimport data_type, type_id -from cudf._lib.types cimport dtype_to_data_type, underlying_type_t_type_id - -from cudf.api.types import is_scalar +from cudf._lib import pylibcudf +from cudf._lib.scalar import as_device_scalar from cudf.core.buffer import acquire_spill_lock -cimport cudf._lib.cpp.binaryop as cpp_binaryop -from cudf._lib.cpp.binaryop cimport binary_operator -import cudf - - -class BinaryOperation(IntEnum): - ADD = ( - binary_operator.ADD - ) - SUB = ( - binary_operator.SUB - ) - MUL = ( - binary_operator.MUL - ) - DIV = ( - binary_operator.DIV - ) - TRUEDIV = ( - binary_operator.TRUE_DIV - ) - FLOORDIV = ( - binary_operator.FLOOR_DIV - ) - MOD = ( - binary_operator.PYMOD - ) - POW = ( - binary_operator.POW - ) - INT_POW = ( - binary_operator.INT_POW - ) - EQ = ( - binary_operator.EQUAL - ) - NE = ( - binary_operator.NOT_EQUAL - ) - LT = ( - binary_operator.LESS - ) - GT = ( - binary_operator.GREATER - ) - LE = ( - binary_operator.LESS_EQUAL - ) - GE = ( - binary_operator.GREATER_EQUAL - ) - AND = ( - binary_operator.BITWISE_AND - ) - OR = ( - binary_operator.BITWISE_OR - ) - XOR = ( - binary_operator.BITWISE_XOR - ) - L_AND = ( - binary_operator.LOGICAL_AND - ) - L_OR = ( - binary_operator.LOGICAL_OR - ) - GENERIC_BINARY = ( - binary_operator.GENERIC_BINARY - ) - NULL_EQUALS = ( - binary_operator.NULL_EQUALS - ) - - -cdef binaryop_v_v(Column lhs, Column rhs, - binary_operator c_op, data_type c_dtype): - cdef column_view c_lhs = lhs.view() - cdef column_view c_rhs = rhs.view() - - cdef unique_ptr[column] c_result - - with nogil: - c_result = move( - cpp_binaryop.binary_operation( - c_lhs, - c_rhs, - c_op, - c_dtype - ) - ) - - return Column.from_unique_ptr(move(c_result)) - - -cdef binaryop_v_s(Column lhs, DeviceScalar rhs, - binary_operator c_op, data_type c_dtype): - cdef column_view c_lhs = lhs.view() - cdef const scalar* c_rhs = rhs.get_raw_ptr() - - cdef unique_ptr[column] c_result - - with nogil: - c_result = move( - cpp_binaryop.binary_operation( - c_lhs, - c_rhs[0], - c_op, - c_dtype - ) - ) - - return Column.from_unique_ptr(move(c_result)) - -cdef binaryop_s_v(DeviceScalar lhs, Column rhs, - binary_operator c_op, data_type c_dtype): - cdef const scalar* c_lhs = lhs.get_raw_ptr() - cdef column_view c_rhs = rhs.view() - - cdef unique_ptr[column] c_result - - with nogil: - c_result = move( - cpp_binaryop.binary_operation( - c_lhs[0], - c_rhs, - c_op, - c_dtype - ) - ) - - return Column.from_unique_ptr(move(c_result)) +# Map pandas operation names to pylibcudf operation names. +_op_map = { + "TRUEDIV": "TRUE_DIV", + "FLOORDIV": "FLOOR_DIV", + "MOD": "PYMOD", + "EQ": "EQUAL", + "NE": "NOT_EQUAL", + "LT": "LESS", + "GT": "GREATER", + "LE": "LESS_EQUAL", + "GE": "GREATER_EQUAL", + "AND": "BITWISE_AND", + "OR": "BITWISE_OR", + "XOR": "BITWISE_XOR", + "L_AND": "LOGICAL_AND", + "L_OR": "LOGICAL_OR", +} @acquire_spill_lock() @@ -166,74 +36,25 @@ def binaryop(lhs, rhs, op, dtype): # pipeline for libcudf binops that don't map to Python binops. if op not in {"INT_POW", "NULL_EQUALS"}: op = op[2:-2] - - op = BinaryOperation[op.upper()] - cdef binary_operator c_op = ( - op - ) - - cdef data_type c_dtype = dtype_to_data_type(dtype) - - if is_scalar(lhs) or lhs is None: - s_lhs = as_device_scalar(lhs, dtype=rhs.dtype if lhs is None else None) - result = binaryop_s_v( - s_lhs, - rhs, - c_op, - c_dtype - ) - - elif is_scalar(rhs) or rhs is None: - s_rhs = as_device_scalar(rhs, dtype=lhs.dtype if rhs is None else None) - result = binaryop_v_s( - lhs, - s_rhs, - c_op, - c_dtype - ) - - else: - result = binaryop_v_v( - lhs, - rhs, - c_op, - c_dtype - ) - return result - - -@acquire_spill_lock() -def binaryop_udf(Column lhs, Column rhs, udf_ptx, dtype): - """ - Apply a user-defined binary operator (a UDF) defined in `udf_ptx` on - the two input columns `lhs` and `rhs`. The output type of the UDF - has to be specified in `dtype`, a numpy data type. - Currently ONLY int32, int64, float32 and float64 are supported. - """ - cdef column_view c_lhs = lhs.view() - cdef column_view c_rhs = rhs.view() - - cdef type_id tid = ( - ( - ( - SUPPORTED_NUMPY_TO_LIBCUDF_TYPES[cudf.dtype(dtype)] - ) + op = op.upper() + op = _op_map.get(op, op) + + return Column.from_pylibcudf( + # Check if the dtype args are desirable here. + pylibcudf.binaryop.binary_operation( + lhs.to_pylibcudf(mode="read") if isinstance(lhs, Column) + else ( + as_device_scalar( + lhs, dtype=rhs.dtype if lhs is None else None + ) + ).c_value, + rhs.to_pylibcudf(mode="read") if isinstance(rhs, Column) + else ( + as_device_scalar( + rhs, dtype=lhs.dtype if rhs is None else None + ) + ).c_value, + pylibcudf.binaryop.BinaryOperator[op], + dtype_to_pylibcudf_type(dtype), ) ) - cdef data_type c_dtype = data_type(tid) - - cdef string cpp_str = udf_ptx.encode("UTF-8") - - cdef unique_ptr[column] c_result - - with nogil: - c_result = move( - cpp_binaryop.binary_operation( - c_lhs, - c_rhs, - cpp_str, - c_dtype - ) - ) - - return Column.from_unique_ptr(move(c_result)) diff --git a/python/cudf/cudf/_lib/cpp/CMakeLists.txt b/python/cudf/cudf/_lib/cpp/CMakeLists.txt index a99aa58dfe8..764f28add0e 100644 --- a/python/cudf/cudf/_lib/cpp/CMakeLists.txt +++ b/python/cudf/cudf/_lib/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -12,7 +12,7 @@ # the License. # ============================================================================= -set(cython_sources copying.pyx types.pyx) +set(cython_sources binaryop.pyx copying.pyx types.pyx) set(linked_libraries cudf::cudf) diff --git a/python/cudf/cudf/_lib/cpp/binaryop.pxd b/python/cudf/cudf/_lib/cpp/binaryop.pxd index f73a9502cd1..735216e656a 100644 --- a/python/cudf/cudf/_lib/cpp/binaryop.pxd +++ b/python/cudf/cudf/_lib/cpp/binaryop.pxd @@ -1,5 +1,6 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. +from libc.stdint cimport int32_t from libcpp.memory cimport unique_ptr from libcpp.string cimport string @@ -10,30 +11,30 @@ from cudf._lib.cpp.types cimport data_type cdef extern from "cudf/binaryop.hpp" namespace "cudf" nogil: - ctypedef enum binary_operator: - ADD "cudf::binary_operator::ADD" - SUB "cudf::binary_operator::SUB" - MUL "cudf::binary_operator::MUL" - DIV "cudf::binary_operator::DIV" - TRUE_DIV "cudf::binary_operator::TRUE_DIV" - FLOOR_DIV "cudf::binary_operator::FLOOR_DIV" - MOD "cudf::binary_operator::MOD" - PYMOD "cudf::binary_operator::PYMOD" - POW "cudf::binary_operator::POW" - INT_POW "cudf::binary_operator::INT_POW" - EQUAL "cudf::binary_operator::EQUAL" - NOT_EQUAL "cudf::binary_operator::NOT_EQUAL" - LESS "cudf::binary_operator::LESS" - GREATER "cudf::binary_operator::GREATER" - LESS_EQUAL "cudf::binary_operator::LESS_EQUAL" - GREATER_EQUAL "cudf::binary_operator::GREATER_EQUAL" - NULL_EQUALS "cudf::binary_operator::NULL_EQUALS" - BITWISE_AND "cudf::binary_operator::BITWISE_AND" - BITWISE_OR "cudf::binary_operator::BITWISE_OR" - BITWISE_XOR "cudf::binary_operator::BITWISE_XOR" - LOGICAL_AND "cudf::binary_operator::LOGICAL_AND" - LOGICAL_OR "cudf::binary_operator::LOGICAL_OR" - GENERIC_BINARY "cudf::binary_operator::GENERIC_BINARY" + cpdef enum class binary_operator(int32_t): + ADD + SUB + MUL + DIV + TRUE_DIV + FLOOR_DIV + MOD + PYMOD + POW + INT_POW + EQUAL + NOT_EQUAL + LESS + GREATER + LESS_EQUAL + GREATER_EQUAL + NULL_EQUALS + BITWISE_AND + BITWISE_OR + BITWISE_XOR + LOGICAL_AND + LOGICAL_OR + GENERIC_BINARY cdef unique_ptr[column] binary_operation ( const scalar& lhs, @@ -62,27 +63,3 @@ cdef extern from "cudf/binaryop.hpp" namespace "cudf" nogil: const string& op, data_type output_type ) except + - - unique_ptr[column] jit_binary_operation \ - "cudf::jit::binary_operation" ( - const column_view& lhs, - const column_view& rhs, - binary_operator op, - data_type output_type - ) except + - - unique_ptr[column] jit_binary_operation \ - "cudf::jit::binary_operation" ( - const column_view& lhs, - const scalar& rhs, - binary_operator op, - data_type output_type - ) except + - - unique_ptr[column] jit_binary_operation \ - "cudf::jit::binary_operation" ( - const scalar& lhs, - const column_view& rhs, - binary_operator op, - data_type output_type - ) except + diff --git a/python/cudf/cudf/_lib/cpp/binaryop.pyx b/python/cudf/cudf/_lib/cpp/binaryop.pyx new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt b/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt index 870a00f99a9..acb013c8b8c 100644 --- a/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt +++ b/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -12,8 +12,8 @@ # the License. # ============================================================================= -set(cython_sources column.pyx copying.pyx gpumemoryview.pyx interop.pyx scalar.pyx table.pyx - types.pyx utils.pyx +set(cython_sources binaryop.pyx column.pyx copying.pyx gpumemoryview.pyx interop.pyx scalar.pyx + table.pyx types.pyx utils.pyx ) set(linked_libraries cudf::cudf) rapids_cython_create_modules( diff --git a/python/cudf/cudf/_lib/pylibcudf/__init__.pxd b/python/cudf/cudf/_lib/pylibcudf/__init__.pxd index 7a35854392c..f4b8c50eecc 100644 --- a/python/cudf/cudf/_lib/pylibcudf/__init__.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/__init__.pxd @@ -1,7 +1,7 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # TODO: Verify consistent usage of relative/absolute imports in pylibcudf. -from . cimport copying, interop +from . cimport binaryop, copying, interop from .column cimport Column from .gpumemoryview cimport gpumemoryview from .scalar cimport Scalar @@ -15,6 +15,7 @@ __all__ = [ "DataType", "Scalar", "Table", + "binaryop", "copying", "gpumemoryview", "interop", diff --git a/python/cudf/cudf/_lib/pylibcudf/__init__.py b/python/cudf/cudf/_lib/pylibcudf/__init__.py index 72b74a57b87..a27d80fc5a2 100644 --- a/python/cudf/cudf/_lib/pylibcudf/__init__.py +++ b/python/cudf/cudf/_lib/pylibcudf/__init__.py @@ -1,6 +1,6 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. -from . import copying, interop +from . import binaryop, copying, interop from .column import Column from .gpumemoryview import gpumemoryview from .scalar import Scalar @@ -13,6 +13,7 @@ "Scalar", "Table", "TypeId", + "binaryop", "copying", "gpumemoryview", "interop", diff --git a/python/cudf/cudf/_lib/pylibcudf/binaryop.pxd b/python/cudf/cudf/_lib/pylibcudf/binaryop.pxd new file mode 100644 index 00000000000..56b98333757 --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/binaryop.pxd @@ -0,0 +1,14 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from cudf._lib.cpp.binaryop cimport binary_operator + +from .column cimport Column +from .types cimport DataType + + +cpdef Column binary_operation( + object lhs, + object rhs, + binary_operator op, + DataType data_type +) diff --git a/python/cudf/cudf/_lib/pylibcudf/binaryop.pyx b/python/cudf/cudf/_lib/pylibcudf/binaryop.pyx new file mode 100644 index 00000000000..af248ba2071 --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/binaryop.pyx @@ -0,0 +1,86 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from cython.operator import dereference + +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move + +from cudf._lib.cpp cimport binaryop as cpp_binaryop +from cudf._lib.cpp.binaryop cimport binary_operator +from cudf._lib.cpp.column.column cimport column + +from cudf._lib.cpp.binaryop import \ + binary_operator as BinaryOperator # no-cython-lint + +from .column cimport Column +from .scalar cimport Scalar +from .types cimport DataType + + +cpdef Column binary_operation( + object lhs, + object rhs, + binary_operator op, + DataType data_type +): + """Perform a binary operation between a column and another column or scalar. + + Either ``lhs`` or ``rhs`` must be a + :py:class:`~cudf._lib.pylibcudf.column.Column`. The other may be a + :py:class:`~cudf._lib.pylibcudf.column.Column` or a + :py:class:`~cudf._lib.pylibcudf.scalar.Scalar`. + + For details, see :cpp:func:`binary_operation`. + + Parameters + ---------- + lhs : Column or Scalar + The left hand side argument. + rhs : Column or Scalar + The right hand side argument. + op : BinaryOperator + The operation to perform. + data_type : DataType + The output to use for the output. + + Returns + ------- + pylibcudf.Column + The result of the binary operation + """ + cdef unique_ptr[column] result + + if isinstance(lhs, Column) and isinstance(rhs, Column): + with nogil: + result = move( + cpp_binaryop.binary_operation( + ( lhs).view(), + ( rhs).view(), + op, + data_type.c_obj + ) + ) + elif isinstance(lhs, Column) and isinstance(rhs, Scalar): + with nogil: + result = move( + cpp_binaryop.binary_operation( + ( lhs).view(), + dereference(( rhs).c_obj), + op, + data_type.c_obj + ) + ) + elif isinstance(lhs, Scalar) and isinstance(rhs, Column): + with nogil: + result = move( + cpp_binaryop.binary_operation( + dereference(( lhs).c_obj), + ( rhs).view(), + op, + data_type.c_obj + ) + ) + else: + raise ValueError(f"Invalid arguments {lhs} and {rhs}") + + return Column.from_libcudf(move(result)) diff --git a/python/cudf/cudf/tests/test_udf_binops.py b/python/cudf/cudf/tests/test_udf_binops.py deleted file mode 100644 index 1ad45e721a3..00000000000 --- a/python/cudf/cudf/tests/test_udf_binops.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. - -import numpy as np -import pytest -from numba.cuda import compile_ptx -from numba.np import numpy_support - -import rmm - -import cudf -from cudf import Series, _lib as libcudf -from cudf.utils import dtypes as dtypeutils - -_driver_version = rmm._cuda.gpu.driverGetVersion() -_runtime_version = rmm._cuda.gpu.runtimeGetVersion() -_CUDA_JIT128INT_SUPPORTED = (_driver_version >= 11050) and ( - _runtime_version >= 11050 -) - - -@pytest.mark.skipif(not _CUDA_JIT128INT_SUPPORTED, reason="requires CUDA 11.5") -@pytest.mark.parametrize( - "dtype", sorted(list(dtypeutils.NUMERIC_TYPES - {"int8"})) -) -def test_generic_ptx(dtype): - - size = 500 - - lhs_arr = np.random.random(size).astype(dtype) - lhs_col = Series(lhs_arr)._column - - rhs_arr = np.random.random(size).astype(dtype) - rhs_col = Series(rhs_arr)._column - - def generic_function(a, b): - return a**3 + b - - nb_type = numpy_support.from_dtype(cudf.dtype(dtype)) - type_signature = (nb_type, nb_type) - - ptx_code, output_type = compile_ptx( - generic_function, type_signature, device=True - ) - - dtype = numpy_support.as_dtype(output_type).type - - out_col = libcudf.binaryop.binaryop_udf(lhs_col, rhs_col, ptx_code, dtype) - - result = lhs_arr**3 + rhs_arr - - np.testing.assert_almost_equal(result, out_col.values_host)