Skip to content

Commit

Permalink
Remove cudf._lib.binops in favor of inlining pylibcudf (#17468)
Browse files Browse the repository at this point in the history
Contributes to #17317

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #17468
  • Loading branch information
mroeschke authored Dec 4, 2024
1 parent 4505c53 commit 351ece5
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 101 deletions.
1 change: 0 additions & 1 deletion python/cudf/cudf/_lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

set(cython_sources
aggregation.pyx
binaryop.pyx
column.pyx
copying.pyx
csv.pyx
Expand Down
1 change: 0 additions & 1 deletion python/cudf/cudf/_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np

from . import (
binaryop,
copying,
csv,
groupby,
Expand Down
61 changes: 0 additions & 61 deletions python/cudf/cudf/_lib/binaryop.pyx

This file was deleted.

60 changes: 60 additions & 0 deletions python/cudf/cudf/core/_internals/binaryop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
from __future__ import annotations

from typing import TYPE_CHECKING

import pylibcudf as plc

from cudf._lib.column import Column
from cudf._lib.types import dtype_to_pylibcudf_type
from cudf.core.buffer import acquire_spill_lock

if TYPE_CHECKING:
from cudf._typing import Dtype
from cudf.core.column import ColumnBase
from cudf.core.scalar import Scalar


@acquire_spill_lock()
def binaryop(
lhs: ColumnBase | Scalar, rhs: ColumnBase | Scalar, op: str, dtype: Dtype
) -> ColumnBase:
"""
Dispatches a binary op call to the appropriate libcudf function:
"""
# TODO: Shouldn't have to keep special-casing. We need to define a separate
# pipeline for libcudf binops that don't map to Python binops.
if op not in {"INT_POW", "NULL_EQUALS", "NULL_NOT_EQUALS"}:
op = op[2:-2]
# Map pandas operation names to pylibcudf operation names.
_op_map = {
"TRUEDIV": "TRUE_DIV",
"FLOORDIV": "FLOOR_DIV",
"MOD": "PYMOD",
"EQ": "EQUAL",
"NE": "NOT_EQUAL",
"LT": "LESS",
"GT": "GREATER",
"LE": "LESS_EQUAL",
"GE": "GREATER_EQUAL",
"AND": "BITWISE_AND",
"OR": "BITWISE_OR",
"XOR": "BITWISE_XOR",
"L_AND": "LOGICAL_AND",
"L_OR": "LOGICAL_OR",
}
op = op.upper()
op = _op_map.get(op, op)

return Column.from_pylibcudf(
plc.binaryop.binary_operation(
lhs.to_pylibcudf(mode="read")
if isinstance(lhs, Column)
else lhs.device_value.c_value,
rhs.to_pylibcudf(mode="read")
if isinstance(rhs, Column)
else rhs.device_value.c_value,
plc.binaryop.BinaryOperator[op],
dtype_to_pylibcudf_type(dtype),
)
)
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ def nans_to_nulls(self: Self) -> Self:

def normalize_binop_value(
self, other: ScalarLike
) -> ColumnBase | ScalarLike:
) -> ColumnBase | cudf.Scalar:
raise NotImplementedError

def _reduce(
Expand Down
16 changes: 9 additions & 7 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import cudf.core.column.string as string
from cudf import _lib as libcudf
from cudf.core._compat import PANDAS_GE_220
from cudf.core._internals import unary
from cudf.core._internals import binaryop, unary
from cudf.core._internals.search import search_sorted
from cudf.core._internals.timezones import (
check_ambiguous_and_nonexistent,
Expand Down Expand Up @@ -509,7 +509,9 @@ def isocalendar(self) -> dict[str, ColumnBase]:
)
}

def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
def normalize_binop_value( # type: ignore[override]
self, other: DatetimeLikeScalar
) -> cudf.Scalar | cudf.DateOffset | ColumnBase:
if isinstance(other, (cudf.Scalar, ColumnBase, cudf.DateOffset)):
return other

Expand Down Expand Up @@ -789,12 +791,12 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
if out_dtype is None:
return NotImplemented

result_col = libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype)
if out_dtype != cudf.dtype(np.bool_) and op == "__add__":
result_col = binaryop.binaryop(lhs, rhs, op, out_dtype)
if out_dtype.kind != "b" and op == "__add__":
return result_col
elif cudf.get_option(
"mode.pandas_compatible"
) and out_dtype == cudf.dtype(np.bool_):
elif (
cudf.get_option("mode.pandas_compatible") and out_dtype.kind == "b"
):
return result_col.fillna(op == "__ne__")
else:
return result_col
Expand Down
13 changes: 7 additions & 6 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
import pyarrow as pa

import cudf
from cudf import _lib as libcudf
from cudf._lib.strings.convert.convert_fixed_point import (
from_decimal as cpp_from_decimal,
)
from cudf.api.types import is_scalar
from cudf.core._internals import unary
from cudf.core._internals import binaryop, unary
from cudf.core.buffer import as_buffer
from cudf.core.column.column import ColumnBase
from cudf.core.column.numerical_base import NumericalBaseColumn
Expand All @@ -30,6 +29,8 @@
from cudf.utils.utils import pa_mask_buffer_to_mask

if TYPE_CHECKING:
from typing_extensions import Self

from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.core.buffer import Buffer

Expand Down Expand Up @@ -141,7 +142,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str):
rhs = rhs.astype(
type(output_type)(rhs.dtype.precision, rhs.dtype.scale)
)
result = libcudf.binaryop.binaryop(lhs, rhs, op, output_type)
result = binaryop.binaryop(lhs, rhs, op, output_type)
# libcudf doesn't support precision, so result.dtype doesn't
# maintain output_type.precision
result.dtype.precision = output_type.precision
Expand All @@ -153,7 +154,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str):
"__le__",
"__ge__",
}:
result = libcudf.binaryop.binaryop(lhs, rhs, op, bool)
result = binaryop.binaryop(lhs, rhs, op, bool)
else:
raise TypeError(
f"{op} not supported for the following dtypes: "
Expand All @@ -177,7 +178,7 @@ def _validate_fillna_value(
"integer values"
)

def normalize_binop_value(self, other):
def normalize_binop_value(self, other) -> Self | cudf.Scalar:
if isinstance(other, ColumnBase):
if isinstance(other, cudf.core.column.NumericalColumn):
if other.dtype.kind not in "iu":
Expand Down Expand Up @@ -209,7 +210,7 @@ def normalize_binop_value(self, other):
other = Decimal(other)
metadata = other.as_tuple()
precision = max(len(metadata.digits), metadata.exponent)
scale = -metadata.exponent
scale = -cast(int, metadata.exponent)
return cudf.Scalar(
other, dtype=self.dtype.__class__(precision, scale)
)
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def __cuda_array_interface__(self):
"Lists are not yet supported via `__cuda_array_interface__`"
)

def normalize_binop_value(self, other):
if not isinstance(other, ListColumn):
def normalize_binop_value(self, other) -> Self:
if not isinstance(other, type(self)):
return NotImplemented
return other

Expand Down
10 changes: 4 additions & 6 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import cudf.core.column.string as string
from cudf import _lib as libcudf
from cudf.api.types import is_integer, is_scalar
from cudf.core._internals import unary
from cudf.core._internals import binaryop, unary
from cudf.core.column.column import ColumnBase, as_column
from cudf.core.column.numerical_base import NumericalBaseColumn
from cudf.core.dtypes import CategoricalDtype
Expand Down Expand Up @@ -292,7 +292,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:

lhs, rhs = (other, self) if reflect else (self, other)

return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype)
return binaryop.binaryop(lhs, rhs, op, out_dtype)

def nans_to_nulls(self: Self) -> Self:
# Only floats can contain nan.
Expand All @@ -301,11 +301,9 @@ def nans_to_nulls(self: Self) -> Self:
newmask = libcudf.transform.nans_to_nulls(self)
return self.set_mask(newmask)

def normalize_binop_value(
self, other: ScalarLike
) -> ColumnBase | cudf.Scalar:
def normalize_binop_value(self, other: ScalarLike) -> Self | cudf.Scalar:
if isinstance(other, ColumnBase):
if not isinstance(other, NumericalColumn):
if not isinstance(other, type(self)):
return NotImplemented
return other
if isinstance(other, cudf.Scalar):
Expand Down
10 changes: 4 additions & 6 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import cudf.api.types
import cudf.core.column.column as column
import cudf.core.column.datetime as datetime
from cudf import _lib as libcudf
from cudf._lib import string_casting as str_cast, strings as libstrings
from cudf._lib.column import Column
from cudf._lib.types import size_type_dtype
from cudf.api.types import is_integer, is_scalar, is_string_dtype
from cudf.core._internals import binaryop
from cudf.core.buffer import acquire_spill_lock
from cudf.core.column.column import ColumnBase
from cudf.core.column.methods import ColumnMethods
Expand Down Expand Up @@ -6200,7 +6200,7 @@ def normalize_binop_value(self, other) -> column.ColumnBase | cudf.Scalar:

def _binaryop(
self, other: ColumnBinaryOperand, op: str
) -> "column.ColumnBase":
) -> column.ColumnBase:
reflect, op = self._check_reflected_op(op)
# Due to https://github.com/pandas-dev/pandas/issues/46332 we need to
# support binary operations between empty or all null string columns
Expand Down Expand Up @@ -6229,7 +6229,7 @@ def _binaryop(
if other is NotImplemented:
return NotImplemented

if isinstance(other, (StringColumn, str, cudf.Scalar)):
if isinstance(other, (StringColumn, cudf.Scalar)):
if isinstance(other, cudf.Scalar) and other.dtype != "O":
if op in {
"__eq__",
Expand Down Expand Up @@ -6279,9 +6279,7 @@ def _binaryop(
"NULL_NOT_EQUALS",
}:
lhs, rhs = (other, self) if reflect else (self, other)
return libcudf.binaryop.binaryop(
lhs=lhs, rhs=rhs, op=op, dtype="bool"
)
return binaryop.binaryop(lhs=lhs, rhs=rhs, op=op, dtype="bool")
return NotImplemented

@copy_docstring(column.ColumnBase.view)
Expand Down
13 changes: 5 additions & 8 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
import cudf
import cudf.core.column.column as column
import cudf.core.column.string as string
from cudf import _lib as libcudf
from cudf.api.types import is_scalar
from cudf.core._internals import unary
from cudf.core._internals import binaryop, unary
from cudf.core.buffer import Buffer, acquire_spill_lock
from cudf.core.column.column import ColumnBase
from cudf.utils.dtypes import np_to_pa_dtype
Expand Down Expand Up @@ -188,8 +187,8 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
this = self.astype(common_dtype).astype(out_dtype)
if isinstance(other, cudf.Scalar):
if other.is_valid():
other = other.value.astype(common_dtype).astype(
out_dtype
other = cudf.Scalar(
other.value.astype(common_dtype).astype(out_dtype)
)
else:
other = cudf.Scalar(None, out_dtype)
Expand Down Expand Up @@ -219,10 +218,8 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:

lhs, rhs = (other, this) if reflect else (this, other)

result = libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype)
if cudf.get_option(
"mode.pandas_compatible"
) and out_dtype == cudf.dtype(np.bool_):
result = binaryop.binaryop(lhs, rhs, op, out_dtype)
if cudf.get_option("mode.pandas_compatible") and out_dtype.kind == "b":
result = result.fillna(op == "__ne__")
return result

Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/utils/applyutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numba.core.utils import pysignature

import cudf
from cudf import _lib as libcudf
from cudf.core._internals import binaryop
from cudf.core.buffer import acquire_spill_lock
from cudf.core.column import column
from cudf.utils import utils
Expand Down Expand Up @@ -121,7 +121,7 @@ def make_aggregate_nullmask(df, columns=None, op="__and__"):
nullmask.copy(), dtype=utils.mask_dtype
)
else:
out_mask = libcudf.binaryop.binaryop(
out_mask = binaryop.binaryop(
nullmask, out_mask, op, out_mask.dtype
)

Expand Down

0 comments on commit 351ece5

Please sign in to comment.