diff --git a/python/cudf/cudf/_lib/binaryop.pyx b/python/cudf/cudf/_lib/binaryop.pyx index 1b590db9e6d..b11d31ab368 100644 --- a/python/cudf/cudf/_lib/binaryop.pyx +++ b/python/cudf/cudf/_lib/binaryop.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from enum import IntEnum @@ -160,6 +160,10 @@ def binaryop(lhs, rhs, op, dtype): """ Dispatches a binary op call to the appropriate libcudf function: """ + # TODO: Shouldn't have to keep special-casing. We need to define a separate + # pipeline for libcudf binops that don't map to Python binops. + if op != "NULL_EQUALS": + op = op[2:-2] op = BinaryOperation[op.upper()] cdef binary_operator c_op = ( diff --git a/python/cudf/cudf/_lib/datetime.pyx b/python/cudf/cudf/_lib/datetime.pyx index e41016645cd..e218400a2db 100644 --- a/python/cudf/cudf/_lib/datetime.pyx +++ b/python/cudf/cudf/_lib/datetime.pyx @@ -1,3 +1,5 @@ +# Copyright (c) 2020-2022, NVIDIA CORPORATION. + from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -56,8 +58,8 @@ def extract_datetime_component(Column col, object field): if field == "weekday": # Pandas counts Monday-Sunday as 0-6 - # while we count Monday-Sunday as 1-7 - result = result.binary_operator("sub", result.dtype.type(1)) + # while libcudf counts Monday-Sunday as 1-7 + result = result - result.dtype.type(1) return result diff --git a/python/cudf/cudf/_typing.py b/python/cudf/cudf/_typing.py index ca2024929f3..87988150fd3 100644 --- a/python/cudf/cudf/_typing.py +++ b/python/cudf/cudf/_typing.py @@ -25,7 +25,7 @@ ColumnLike = Any # binary operation -BinaryOperand = Union["cudf.Scalar", "cudf.core.column.ColumnBase"] +ColumnBinaryOperand = Union["cudf.Scalar", "cudf.core.column.ColumnBase"] DataFrameOrSeries = Union["cudf.Series", "cudf.DataFrame"] SeriesOrIndex = Union["cudf.Series", "cudf.core.index.BaseIndex"] diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index caab2294484..e0022ed21ca 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -24,7 +24,7 @@ import cudf from cudf import _lib as libcudf from cudf._lib.transform import bools_to_mask -from cudf._typing import ColumnLike, Dtype, ScalarLike +from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike from cudf.api.types import is_categorical_dtype, is_interval_dtype from cudf.core.buffer import Buffer from cudf.core.column import column @@ -630,6 +630,14 @@ class CategoricalColumn(column.ColumnBase): dtype: cudf.core.dtypes.CategoricalDtype _codes: Optional[NumericalColumn] _children: Tuple[NumericalColumn] + _VALID_BINARY_OPERATIONS = { + "__eq__", + "__ne__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + } def __init__( self, @@ -875,41 +883,29 @@ def slice( offset=codes.offset, ) - def binary_operator( - self, op: str, rhs, reflect: bool = False - ) -> ColumnBase: - if op not in {"eq", "ne", "lt", "le", "gt", "ge", "NULL_EQUALS"}: - raise TypeError( - "Series of dtype `category` cannot perform the operation: " - f"{op}" - ) - rhs = self._wrap_binop_normalization(rhs) + def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase: + other = self._wrap_binop_normalization(other) # TODO: This is currently just here to make mypy happy, but eventually # we'll need to properly establish the APIs for these methods. - if not isinstance(rhs, CategoricalColumn): + if not isinstance(other, CategoricalColumn): raise ValueError # Note: at this stage we are guaranteed that the dtypes are equal. - if not self.ordered and op not in {"eq", "ne", "NULL_EQUALS"}: + if not self.ordered and op not in {"__eq__", "__ne__", "NULL_EQUALS"}: raise TypeError( "The only binary operations supported by unordered " "categorical columns are equality and inequality." ) - return self.as_numerical.binary_operator(op, rhs.as_numerical) + return self.as_numerical._binaryop(other.as_numerical, op) def normalize_binop_value(self, other: ScalarLike) -> CategoricalColumn: if isinstance(other, column.ColumnBase): if not isinstance(other, CategoricalColumn): - raise ValueError( - "Binary operations with categorical columns require both " - "columns to be categorical." - ) + return NotImplemented if other.dtype != self.dtype: raise TypeError( "Categoricals can only compare with the same type" ) return other - if isinstance(other, np.ndarray) and other.ndim == 0: - other = other.item() ary = cudf.utils.utils.scalar_broadcast_to( self._encode(other), size=len(self), dtype=self.codes.dtype diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 2919b62b49c..401d5f82743 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -41,7 +41,7 @@ drop_nulls, ) from cudf._lib.transform import bools_to_mask -from cudf._typing import BinaryOperand, ColumnLike, Dtype, ScalarLike +from cudf._typing import ColumnLike, Dtype, ScalarLike from cudf.api.types import ( _is_non_decimal_numeric_dtype, _is_scalar_or_zero_d_array, @@ -68,7 +68,7 @@ ListDtype, StructDtype, ) -from cudf.core.mixins import Reducible +from cudf.core.mixins import BinaryOperand, Reducible from cudf.utils import utils from cudf.utils.dtypes import ( cudf_dtype_from_pa_type, @@ -78,7 +78,7 @@ pandas_dtypes_alias_to_cudf_alias, pandas_dtypes_to_np_dtypes, ) -from cudf.utils.utils import NotIterable, mask_dtype +from cudf.utils.utils import NotIterable, _array_ufunc, mask_dtype T = TypeVar("T", bound="ColumnBase") # TODO: This workaround allows type hints for `slice`, since `slice` is a @@ -86,7 +86,7 @@ Slice = TypeVar("Slice", bound=slice) -class ColumnBase(Column, Serializable, Reducible, NotIterable): +class ColumnBase(Column, Serializable, BinaryOperand, Reducible, NotIterable): _VALID_REDUCTIONS = { "any", "all", @@ -185,7 +185,10 @@ def equals(self, other: ColumnBase, check_dtypes: bool = False) -> bool: return False if check_dtypes and (self.dtype != other.dtype): return False - return self.binary_operator("NULL_EQUALS", other).all() + ret = self._binaryop(other, "NULL_EQUALS") + if ret is NotImplemented: + raise TypeError(f"Cannot compare equality with {type(other)}") + return ret.all() def all(self, skipna: bool = True) -> bool: # The skipna argument is only used for numerical columns. @@ -521,8 +524,10 @@ def __setitem__(self, key: Any, value: Any): self._mimic_inplace(out, inplace=True) def _wrap_binop_normalization(self, other): - if other is cudf.NA: + if other is cudf.NA or other is None: return cudf.Scalar(other, dtype=self.dtype) + if isinstance(other, np.ndarray) and other.ndim == 0: + other = other.item() return self.normalize_binop_value(other) def _scatter_by_slice( @@ -1029,50 +1034,8 @@ def __cuda_array_interface__(self): "`__cuda_array_interface__`" ) - def __add__(self, other): - return self.binary_operator("add", other) - - def __sub__(self, other): - return self.binary_operator("sub", other) - - def __mul__(self, other): - return self.binary_operator("mul", other) - - def __eq__(self, other): - return self.binary_operator("eq", other) - - def __ne__(self, other): - return self.binary_operator("ne", other) - - def __or__(self, other): - return self.binary_operator("or", other) - - def __and__(self, other): - return self.binary_operator("and", other) - - def __floordiv__(self, other): - return self.binary_operator("floordiv", other) - - def __truediv__(self, other): - return self.binary_operator("truediv", other) - - def __mod__(self, other): - return self.binary_operator("mod", other) - - def __pow__(self, other): - return self.binary_operator("pow", other) - - def __lt__(self, other): - return self.binary_operator("lt", other) - - def __gt__(self, other): - return self.binary_operator("gt", other) - - def __le__(self, other): - return self.binary_operator("le", other) - - def __ge__(self, other): - return self.binary_operator("ge", other) + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + return _array_ufunc(self, ufunc, method, inputs, kwargs) def searchsorted( self, @@ -1133,14 +1096,6 @@ def unary_operator(self, unaryop: str): f"Operation {unaryop} not supported for dtype {self.dtype}." ) - def binary_operator( - self, op: str, other: BinaryOperand, reflect: bool = False - ) -> ColumnBase: - raise TypeError( - f"Operation {op} not supported between dtypes {self.dtype} and " - f"{other.dtype}." - ) - def normalize_binop_value( self, other: ScalarLike ) -> Union[ColumnBase, ScalarLike]: diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index b312f99829f..4ce5a70f0ec 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -7,14 +7,20 @@ import re from locale import nl_langinfo from types import SimpleNamespace -from typing import Any, Mapping, Sequence, Union, cast +from typing import Any, Mapping, Sequence, cast import numpy as np import pandas as pd import cudf from cudf import _lib as libcudf -from cudf._typing import DatetimeLikeScalar, Dtype, DtypeObj, ScalarLike +from cudf._typing import ( + ColumnBinaryOperand, + DatetimeLikeScalar, + Dtype, + DtypeObj, + ScalarLike, +) from cudf.api.types import is_scalar from cudf.core._compat import PANDAS_GE_120 from cudf.core.buffer import Buffer @@ -109,6 +115,19 @@ class DatetimeColumn(column.ColumnBase): The validity mask """ + _VALID_BINARY_OPERATIONS = { + "__eq__", + "__ne__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__add__", + "__sub__", + "__radd__", + "__rsub__", + } + def __init__( self, data: Buffer, @@ -227,8 +246,6 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: if isinstance(other, (cudf.Scalar, ColumnBase, cudf.DateOffset)): return other - if isinstance(other, np.ndarray) and other.ndim == 0: - other = other.item() if isinstance(other, dt.datetime): other = np.datetime64(other) elif isinstance(other, dt.timedelta): @@ -254,10 +271,8 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: return cudf.Scalar(None, dtype=other.dtype) return cudf.Scalar(other) - elif other is None: - return cudf.Scalar(other, dtype=self.dtype) - raise TypeError(f"cannot normalize {type(other)}") + return NotImplemented @property def as_numerical(self) -> "cudf.core.column.NumericalColumn": @@ -388,43 +403,53 @@ def quantile( return pd.Timestamp(result, unit=self.time_unit) return result.astype(self.dtype) - def binary_operator( - self, - op: str, - rhs: Union[ColumnBase, "cudf.Scalar"], - reflect: bool = False, - ) -> ColumnBase: - rhs = self._wrap_binop_normalization(rhs) - if isinstance(rhs, cudf.DateOffset): - return rhs._datetime_binop(self, op, reflect=reflect) - - lhs: Union[ScalarLike, ColumnBase] = self - if op in {"eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"}: + def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase: + reflect, op = self._check_reflected_op(op) + other = self._wrap_binop_normalization(other) + if other is NotImplemented: + return NotImplemented + if isinstance(other, cudf.DateOffset): + return other._datetime_binop(self, op, reflect=reflect) + + # TODO: Figure out if I can reflect before we start these checks. That + # requires figuring out why _timedelta_add_result_dtype and + # _timedelta_sub_result_dtype are 1) not symmetric, and 2) different + # from each other. + if op in { + "__eq__", + "__ne__", + "__lt__", + "__gt__", + "__le__", + "__ge__", + "NULL_EQUALS", + }: out_dtype: Dtype = cudf.dtype(np.bool_) - elif op == "add" and pd.api.types.is_timedelta64_dtype(rhs.dtype): + elif op == "__add__" and pd.api.types.is_timedelta64_dtype( + other.dtype + ): out_dtype = cudf.core.column.timedelta._timedelta_add_result_dtype( - rhs, lhs + other, self ) - elif op == "sub" and pd.api.types.is_timedelta64_dtype(rhs.dtype): + elif op == "__sub__" and pd.api.types.is_timedelta64_dtype( + other.dtype + ): out_dtype = cudf.core.column.timedelta._timedelta_sub_result_dtype( - rhs if reflect else lhs, lhs if reflect else rhs + other if reflect else self, self if reflect else other ) - elif op == "sub" and pd.api.types.is_datetime64_dtype(rhs.dtype): + elif op == "__sub__" and pd.api.types.is_datetime64_dtype(other.dtype): units = ["s", "ms", "us", "ns"] - lhs_time_unit = cudf.utils.dtypes.get_time_unit(lhs) + lhs_time_unit = cudf.utils.dtypes.get_time_unit(self) lhs_unit = units.index(lhs_time_unit) - rhs_time_unit = cudf.utils.dtypes.get_time_unit(rhs) + rhs_time_unit = cudf.utils.dtypes.get_time_unit(other) rhs_unit = units.index(rhs_time_unit) out_dtype = np.dtype( f"timedelta64[{units[max(lhs_unit, rhs_unit)]}]" ) else: - raise TypeError( - f"Series of dtype {self.dtype} cannot perform " - f" the operation {op}" - ) + return NotImplemented - lhs, rhs = (self, rhs) if not reflect else (rhs, self) + lhs, rhs = (other, self) if reflect else (self, other) return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype) def fillna( diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index e011afbd0ff..f10e257d359 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -14,7 +14,7 @@ from cudf._lib.strings.convert.convert_fixed_point import ( from_decimal as cpp_from_decimal, ) -from cudf._typing import Dtype +from cudf._typing import ColumnBinaryOperand, Dtype from cudf.api.types import is_integer_dtype, is_scalar from cudf.core.buffer import Buffer from cudf.core.column import ColumnBase, as_column @@ -24,6 +24,7 @@ Decimal128Dtype, DecimalDtype, ) +from cudf.core.mixins import BinaryOperand from cudf.utils.utils import pa_mask_buffer_to_mask from .numerical_base import NumericalBaseColumn @@ -33,6 +34,7 @@ class DecimalBaseColumn(NumericalBaseColumn): """Base column for decimal32, decimal64 or decimal128 columns""" dtype: DecimalDtype + _VALID_BINARY_OPERATIONS = BinaryOperand._SUPPORTED_BINARY_OPERATIONS def as_decimal_column( self, dtype: Dtype, **kwargs @@ -60,18 +62,25 @@ def as_string_column( "cudf.core.column.StringColumn", as_column([], dtype="object") ) - def binary_operator(self, op, other, reflect=False): - if reflect: - self, other = other, self - # Decimals in libcudf don't support truediv, see - # https://github.com/rapidsai/cudf/pull/7435 for explanation. - op = op.replace("true", "") + # Decimals in libcudf don't support truediv, see + # https://github.com/rapidsai/cudf/pull/7435 for explanation. + def __truediv__(self, other): + return self._binaryop(other, "__div__") + + def __rtruediv__(self, other): + return self._binaryop(other, "__rdiv__") + + def _binaryop(self, other: ColumnBinaryOperand, op: str): + reflect, op = self._check_reflected_op(op) other = self._wrap_binop_normalization(other) + if other is NotImplemented: + return NotImplemented + lhs, rhs = (other, self) if reflect else (self, other) # Binary Arithmetics between decimal columns. `Scale` and `precision` # are computed outside of libcudf try: - if op in {"add", "sub", "mul", "div"}: + if op in {"__add__", "__sub__", "__mul__", "__div__"}: output_type = _get_decimal_type(self.dtype, other.dtype, op) result = libcudf.binaryop.binaryop( self, other, op, output_type @@ -79,7 +88,14 @@ def binary_operator(self, op, other, reflect=False): # TODO: Why is this necessary? Why isn't the result's # precision already set correctly based on output_type? result.dtype.precision = output_type.precision - elif op in {"eq", "ne", "lt", "gt", "le", "ge"}: + elif op in { + "__eq__", + "__ne__", + "__lt__", + "__gt__", + "__le__", + "__ge__", + }: result = libcudf.binaryop.binaryop(self, other, op, bool) except RuntimeError as e: if "Unsupported operator for these types" in str(e): @@ -128,10 +144,7 @@ def normalize_binop_value(self, other): self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0) ) elif not isinstance(other, DecimalBaseColumn): - raise TypeError( - f"Binary operations are not supported between" - f"{str(type(self))} and {str(type(other))}" - ) + return NotImplemented elif not isinstance(self.dtype, other.dtype.__class__): # This branch occurs if we have a DecimalBaseColumn of a # different size (e.g. 64 instead of 32). @@ -151,7 +164,7 @@ def normalize_binop_value(self, other): return other elif is_scalar(other) and isinstance(other, (int, Decimal)): return cudf.Scalar(Decimal(other)) - raise TypeError(f"cannot normalize {type(other)}") + return NotImplemented def _decimal_quantile( self, q: Union[float, Sequence[float]], interpolation: str, exact: bool @@ -350,13 +363,13 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): p1, p2 = lhs_dtype.precision, rhs_dtype.precision s1, s2 = lhs_dtype.scale, rhs_dtype.scale - if op in ("add", "sub"): + if op in {"__add__", "__sub__"}: scale = max(s1, s2) precision = scale + max(p1 - s1, p2 - s2) + 1 - elif op == "mul": + elif op == "__mul__": scale = s1 + s2 precision = p1 + p2 + 1 - elif op == "div": + elif op == "__div__": scale = max(6, s1 + p2 + 1) precision = p1 - s1 + s2 + scale else: diff --git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index 53ab79542e2..0df5be2d862 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -19,7 +19,7 @@ sort_lists, ) from cudf._lib.strings.convert.convert_lists import format_list_column -from cudf._typing import BinaryOperand, ColumnLike, Dtype, ScalarLike +from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike from cudf.api.types import _is_non_decimal_numeric_dtype, is_list_dtype from cudf.core.buffer import Buffer from cudf.core.column import ColumnBase, as_column, column @@ -29,6 +29,7 @@ class ListColumn(ColumnBase): dtype: ListDtype + _VALID_BINARY_OPERATIONS = {"__add__", "__radd__"} def __init__( self, size, dtype, mask=None, offset=0, null_count=None, children=(), @@ -92,50 +93,14 @@ def base_size(self): # avoid it being negative return max(0, len(self.base_children[0]) - 1) - def binary_operator( - self, binop: str, other: BinaryOperand, reflect: bool = False - ) -> ColumnBase: - """ - Calls a binary operator *binop* on operands *self* - and *other*. - - Parameters - ---------- - self, other : list columns - - binop : binary operator - Only "add" operator is currently being supported - for lists concatenation functions - - reflect : boolean, default False - If ``True``, swap the order of the operands. See - https://docs.python.org/3/reference/datamodel.html#object.__ror__ - for more information on when this is necessary. - - Returns - ------- - Series : the output dtype is determined by the - input operands. - - Examples - -------- - >>> import cudf - >>> gdf = cudf.DataFrame({'val': [['a', 'a'], ['b'], ['c']]}) - >>> gdf - val - 0 [a, a] - 1 [b] - 2 [c] - >>> gdf['val'] + gdf['val'] - 0 [a, a, a, a] - 1 [b, b] - 2 [c, c] - Name: val, dtype: list - - """ + def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase: + # Lists only support __add__, which concatenates lists. + reflect, op = self._check_reflected_op(op) other = self._wrap_binop_normalization(other) + if other is NotImplemented: + return NotImplemented if isinstance(other.dtype, ListDtype): - if binop == "add": + if op == "__add__": return concatenate_rows( cudf.core.frame.Frame({0: self, 1: other}) ) @@ -255,6 +220,8 @@ def __cuda_array_interface__(self): ) def normalize_binop_value(self, other): + if not isinstance(other, ListColumn): + return NotImplemented return other def _with_type_metadata( diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 015524b841e..c9bc3c59aea 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -21,7 +21,13 @@ import cudf from cudf import _lib as libcudf from cudf._lib.stream_compaction import drop_nulls -from cudf._typing import BinaryOperand, ColumnLike, Dtype, DtypeObj, ScalarLike +from cudf._typing import ( + ColumnBinaryOperand, + ColumnLike, + Dtype, + DtypeObj, + ScalarLike, +) from cudf.api.types import ( is_bool_dtype, is_float_dtype, @@ -37,6 +43,7 @@ string, ) from cudf.core.dtypes import CategoricalDtype +from cudf.core.mixins import BinaryOperand from cudf.utils import cudautils, utils from cudf.utils.dtypes import ( NUMERIC_TYPES, @@ -63,6 +70,7 @@ class NumericalColumn(NumericalBaseColumn): """ _nan_count: Optional[int] + _VALID_BINARY_OPERATIONS = BinaryOperand._SUPPORTED_BINARY_OPERATIONS def __init__( self, @@ -150,9 +158,7 @@ def unary_operator(self, unaryop: Union[str, Callable]) -> ColumnBase: unaryop = libcudf.unary.UnaryOp[unaryop.upper()] return libcudf.unary.unary_operation(self, unaryop) - def binary_operator( - self, binop: str, rhs: BinaryOperand, reflect: bool = False, - ) -> ColumnBase: + def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase: int_float_dtype_mapping = { np.int8: np.float32, np.int16: np.float32, @@ -165,23 +171,19 @@ def binary_operator( np.bool_: np.float32, } - if binop in {"truediv", "rtruediv"}: + if op in {"__truediv__", "__rtruediv__"}: # Division with integer types results in a suitable float. if (truediv_type := int_float_dtype_mapping.get(self.dtype.type)) : - return self.astype(truediv_type).binary_operator( - binop, rhs, reflect - ) + return self.astype(truediv_type)._binaryop(other, op) - rhs = self._wrap_binop_normalization(rhs) + reflect, op = self._check_reflected_op(op) + if (other := self._wrap_binop_normalization(other)) is NotImplemented: + return NotImplemented out_dtype = self.dtype - if rhs is not None: - if isinstance(rhs, cudf.core.column.DecimalBaseColumn): - dtyp = rhs.dtype.__class__(rhs.dtype.MAX_PRECISION, 0) - return self.as_decimal_column(dtyp).binary_operator(binop, rhs) - - out_dtype = np.result_type(self.dtype, rhs.dtype) - if binop in {"mod", "floordiv"}: - tmp = self if reflect else rhs + if other is not None: + out_dtype = np.result_type(self.dtype, other.dtype) + if op in {"__mod__", "__floordiv__"}: + tmp = self if reflect else other # Guard against division by zero for integers. if ( (tmp.dtype.type in int_float_dtype_mapping) @@ -195,31 +197,29 @@ def binary_operator( ): out_dtype = cudf.dtype("float64") - if binop in { - "l_and", - "l_or", - "lt", - "gt", - "le", - "ge", - "eq", - "ne", + if op in { + "__lt__", + "__gt__", + "__le__", + "__ge__", + "__eq__", + "__ne__", "NULL_EQUALS", }: out_dtype = "bool" - if binop in {"and", "or", "xor"}: - if is_float_dtype(self.dtype) or is_float_dtype(rhs): + if op in {"__and__", "__or__", "__xor__"}: + if is_float_dtype(self.dtype) or is_float_dtype(other): raise TypeError( - f"Operation 'bitwise {binop}' not supported between " + f"Operation 'bitwise {op[2:-2]}' not supported between " f"{self.dtype.type.__name__} and " - f"{rhs.dtype.type.__name__}" + f"{other.dtype.type.__name__}" ) - if is_bool_dtype(self.dtype) or is_bool_dtype(rhs): + if is_bool_dtype(self.dtype) or is_bool_dtype(other): out_dtype = "bool" - lhs, rhs = (self, rhs) if not reflect else (rhs, self) - return libcudf.binaryop.binaryop(lhs, rhs, binop, out_dtype) + lhs, rhs = (other, self) if reflect else (self, other) + return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype) def nans_to_nulls(self: NumericalColumn) -> NumericalColumn: # Only floats can contain nan. @@ -232,15 +232,8 @@ def normalize_binop_value( self, other: ScalarLike ) -> Union[ColumnBase, ScalarLike]: if isinstance(other, ColumnBase): - if not isinstance( - other, (NumericalColumn, cudf.core.column.DecimalBaseColumn,), - ): - raise TypeError( - f"Binary operations are not supported between " - f"{type(self)}and {type(other)}" - ) - return other - if other is None: + if not isinstance(other, NumericalColumn): + return NotImplemented return other if isinstance(other, cudf.Scalar): if self.dtype == other.dtype: @@ -248,8 +241,6 @@ def normalize_binop_value( # expensive device-host transfer just to # adjust the dtype other = other.value - elif isinstance(other, np.ndarray) and other.ndim == 0: - other = other.item() other_dtype = np.min_scalar_type(other) if other_dtype.kind in {"b", "i", "u", "f"}: if isinstance(other, cudf.Scalar): @@ -270,7 +261,7 @@ def normalize_binop_value( data=Buffer(ary), dtype=ary.dtype, mask=self.mask, ) else: - raise TypeError(f"cannot broadcast {type(other)}") + return NotImplemented def int2ip(self) -> "cudf.core.column.StringColumn": if self.dtype != cudf.dtype("int64"): diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 82be924dfbc..95bb06ebb0c 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -50,7 +50,13 @@ def str_to_boolean(column: StringColumn): if TYPE_CHECKING: - from cudf._typing import ColumnLike, Dtype, ScalarLike, SeriesOrIndex + from cudf._typing import ( + ColumnBinaryOperand, + ColumnLike, + Dtype, + ScalarLike, + SeriesOrIndex, + ) _str_to_numeric_typecast_functions = { @@ -5025,6 +5031,26 @@ class StringColumn(column.ColumnBase): _start_offset: Optional[int] _end_offset: Optional[int] + _VALID_BINARY_OPERATIONS = { + "__eq__", + "__ne__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__add__", + "__radd__", + # These operators aren't actually supported, they only exist to allow + # empty column binops with scalars of arbitrary other dtypes. See + # the _binaryop method for more information. + "__sub__", + "__mul__", + "__mod__", + "__pow__", + "__truediv__", + "__floordiv__", + } + def __init__( self, mask: Buffer = None, @@ -5434,50 +5460,49 @@ def normalize_binop_value( and other.dtype == "object" ): return other - if isinstance(other, str) or other is None: - return utils.scalar_broadcast_to( - other, size=len(self), dtype="object" - ) - if isinstance(other, np.ndarray) and other.ndim == 0: - return utils.scalar_broadcast_to( - other.item(), size=len(self), dtype="object" - ) - raise TypeError(f"cannot broadcast {type(other)}") + if isinstance(other, str): + return cudf.Scalar(other) + return NotImplemented - def binary_operator( - self, op: str, rhs, reflect: bool = False + def _binaryop( + self, other: ColumnBinaryOperand, op: str ) -> "column.ColumnBase": - # Handle object columns that are empty or all nulls when performing - # binary operations - # See https://github.com/pandas-dev/pandas/issues/46332 + reflect, op = self._check_reflected_op(op) + # Due to https://github.com/pandas-dev/pandas/issues/46332 we need to + # support binary operations between empty or all null string columns + # and columns of other dtypes, even if those operations would otherwise + # be invalid. For example, you cannot divide strings, but pandas allows + # division between an empty string column and a (nonempty) integer + # column. Ideally we would disable these operators entirely, but until + # the above issue is resolved we cannot avoid this problem. if self.null_count == len(self): if op in { - "add", - "sub", - "mul", - "mod", - "pow", - "truediv", - "floordiv", - "radd", - "rsub", - "rmul", - "rmod", - "rpow", - "rtruediv", - "rfloordiv", + "__add__", + "__sub__", + "__mul__", + "__mod__", + "__pow__", + "__truediv__", + "__floordiv__", }: return self - elif op in {"eq", "lt", "le", "gt", "ge"}: + elif op in {"__eq__", "__lt__", "__le__", "__gt__", "__ge__"}: return self.notnull() - elif op == "ne": + elif op == "__ne__": return self.isnull() - rhs = self._wrap_binop_normalization(rhs) + other = self._wrap_binop_normalization(other) + if other is NotImplemented: + return NotImplemented + + if isinstance(other, (StringColumn, str, cudf.Scalar)): + if op == "__add__": + if isinstance(other, cudf.Scalar): + other = utils.scalar_broadcast_to( + other, size=len(self), dtype="object" + ) + lhs, rhs = (other, self) if reflect else (self, other) - if isinstance(rhs, (StringColumn, str, cudf.Scalar)): - lhs, rhs = (rhs, self) if reflect else (self, rhs) - if op == "add": return cast( "column.ColumnBase", libstrings.concatenate( @@ -5486,13 +5511,20 @@ def binary_operator( na_rep=cudf.Scalar(None, "str"), ), ) - elif op in {"eq", "ne", "gt", "lt", "ge", "le", "NULL_EQUALS"}: + elif op in { + "__eq__", + "__ne__", + "__gt__", + "__lt__", + "__ge__", + "__le__", + "NULL_EQUALS", + }: + lhs, rhs = (other, self) if reflect else (self, other) return libcudf.binaryop.binaryop( lhs=lhs, rhs=rhs, op=op, dtype="bool" ) - raise TypeError( - f"{op} not supported between {type(self)} and {type(rhs)}" - ) + return NotImplemented @copy_docstring(column.ColumnBase.view) def view(self, dtype) -> "cudf.core.column.ColumnBase": diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index 66e6271a4d1..11d295a6190 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -11,7 +11,12 @@ import cudf from cudf import _lib as libcudf -from cudf._typing import BinaryOperand, DatetimeLikeScalar, Dtype, DtypeObj +from cudf._typing import ( + ColumnBinaryOperand, + DatetimeLikeScalar, + Dtype, + DtypeObj, +) from cudf.api.types import is_scalar from cudf.core.buffer import Buffer from cudf.core.column import ColumnBase, column, string @@ -46,6 +51,27 @@ class TimeDeltaColumn(column.ColumnBase): If None, it is calculated automatically. """ + _VALID_BINARY_OPERATIONS = { + "__eq__", + "__ne__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__add__", + "__sub__", + "__mul__", + "__mod__", + "__truediv__", + "__floordiv__", + "__radd__", + "__rsub__", + "__rmul__", + "__rmod__", + "__rtruediv__", + "__rfloordiv__", + } + def __init__( self, data: Buffer, @@ -125,97 +151,106 @@ def to_pandas( return pd_series - def _binary_op_mul(self, rhs: BinaryOperand) -> DtypeObj: - if rhs.dtype.kind in ("f", "i", "u"): + def _binary_op_mul(self, other: ColumnBinaryOperand) -> DtypeObj: + if other.dtype.kind in ("f", "i", "u"): out_dtype = self.dtype else: raise TypeError( - f"Multiplication of {self.dtype} with {rhs.dtype} " + f"Multiplication of {self.dtype} with {other.dtype} " f"cannot be performed." ) return out_dtype - def _binary_op_mod(self, rhs: BinaryOperand) -> DtypeObj: - if pd.api.types.is_timedelta64_dtype(rhs.dtype): - out_dtype = determine_out_dtype(self.dtype, rhs.dtype) - elif rhs.dtype.kind in ("f", "i", "u"): + def _binary_op_mod(self, other: ColumnBinaryOperand) -> DtypeObj: + if pd.api.types.is_timedelta64_dtype(other.dtype): + out_dtype = determine_out_dtype(self.dtype, other.dtype) + elif other.dtype.kind in ("f", "i", "u"): out_dtype = self.dtype else: raise TypeError( - f"Modulus of {self.dtype} with {rhs.dtype} " + f"Modulo of {self.dtype} with {other.dtype} " f"cannot be performed." ) return out_dtype - def _binary_op_lt_gt_le_ge_eq_ne(self, rhs: BinaryOperand) -> DtypeObj: - if pd.api.types.is_timedelta64_dtype(rhs.dtype): + def _binary_op_lt_gt_le_ge_eq_ne( + self, other: ColumnBinaryOperand + ) -> DtypeObj: + if pd.api.types.is_timedelta64_dtype(other.dtype): return np.bool_ raise TypeError( f"Invalid comparison between dtype={self.dtype}" - f" and {rhs.dtype}" + f" and {other.dtype}" ) def _binary_op_div( - self, rhs: BinaryOperand, op: str - ) -> Tuple["column.ColumnBase", BinaryOperand, DtypeObj]: - lhs = self # type: column.ColumnBase - if pd.api.types.is_timedelta64_dtype(rhs.dtype): - common_dtype = determine_out_dtype(self.dtype, rhs.dtype) - lhs = lhs.astype(common_dtype).astype("float64") - if isinstance(rhs, cudf.Scalar): - if rhs.is_valid(): - rhs = rhs.value.astype(common_dtype).astype("float64") + self, other: ColumnBinaryOperand, op: str + ) -> Tuple["column.ColumnBase", ColumnBinaryOperand, DtypeObj]: + this: ColumnBase = self + if pd.api.types.is_timedelta64_dtype(other.dtype): + common_dtype = determine_out_dtype(self.dtype, other.dtype) + this = self.astype(common_dtype).astype("float64") + if isinstance(other, cudf.Scalar): + if other.is_valid(): + other = other.value.astype(common_dtype).astype("float64") else: - rhs = cudf.Scalar(None, "float64") + other = cudf.Scalar(None, "float64") else: - rhs = rhs.astype(common_dtype).astype("float64") + other = other.astype(common_dtype).astype("float64") - out_dtype = cudf.dtype("float64" if op == "truediv" else "int64") - elif rhs.dtype.kind in ("f", "i", "u"): + out_dtype = cudf.dtype( + "float64" if op == "__truediv__" else "int64" + ) + elif other.dtype.kind in ("f", "i", "u"): out_dtype = self.dtype else: raise TypeError( - f"Division of {self.dtype} with {rhs.dtype} " + f"Division of {self.dtype} with {other.dtype} " f"cannot be performed." ) - return lhs, rhs, out_dtype + return this, other, out_dtype - def binary_operator( - self, op: str, rhs: BinaryOperand, reflect: bool = False + def _binaryop( + self, other: ColumnBinaryOperand, op: str ) -> "column.ColumnBase": - rhs = self._wrap_binop_normalization(rhs) - lhs, rhs = self, rhs - - if op in {"eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"}: - out_dtype = self._binary_op_lt_gt_le_ge_eq_ne(rhs) - elif op == "mul": - out_dtype = self._binary_op_mul(rhs) - elif op == "mod": - out_dtype = self._binary_op_mod(rhs) - elif op in {"truediv", "floordiv"}: - lhs, rhs, out_dtype = self._binary_op_div(rhs, op) # type: ignore - op = "truediv" - elif op == "add": - out_dtype = _timedelta_add_result_dtype(lhs, rhs) - elif op == "sub": - out_dtype = _timedelta_sub_result_dtype(lhs, rhs) + reflect, op = self._check_reflected_op(op) + other = self._wrap_binop_normalization(other) + if other is NotImplemented: + return NotImplemented + + this: ColumnBinaryOperand = self + if op in { + "__eq__", + "__ne__", + "__lt__", + "__gt__", + "__le__", + "__ge__", + "NULL_EQUALS", + }: + out_dtype = self._binary_op_lt_gt_le_ge_eq_ne(other) + elif op == "__mul__": + out_dtype = self._binary_op_mul(other) + elif op == "__mod__": + out_dtype = self._binary_op_mod(other) + elif op in {"__truediv__", "__floordiv__"}: + this, other, out_dtype = self._binary_op_div(other, op) + op = "__truediv__" + elif op == "__add__": + out_dtype = _timedelta_add_result_dtype(self, other) + elif op == "__sub__": + out_dtype = _timedelta_sub_result_dtype(self, other) else: - raise TypeError( - f"Series of dtype {self.dtype} cannot perform " - f"the operation {op}" - ) + return NotImplemented - if reflect: - lhs, rhs = rhs, lhs # type: ignore + lhs, rhs = (other, this) if reflect else (this, other) return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype) - def normalize_binop_value(self, other) -> BinaryOperand: + def normalize_binop_value(self, other) -> ColumnBinaryOperand: if isinstance(other, (ColumnBase, cudf.Scalar)): return other - if isinstance(other, np.ndarray) and other.ndim == 0: - other = other.item() if isinstance(other, dt.timedelta): other = np.timedelta64(other) elif isinstance(other, pd.Timestamp): @@ -235,10 +270,7 @@ def normalize_binop_value(self, other) -> BinaryOperand: return cudf.Scalar(other) elif np.isscalar(other): return cudf.Scalar(other) - elif other is None: - return cudf.Scalar(other, dtype=self.dtype) - else: - raise TypeError(f"cannot normalize {type(other)}") + return NotImplemented @property def as_numerical(self) -> "cudf.core.column.NumericalColumn": @@ -556,7 +588,7 @@ def determine_out_dtype(lhs_dtype: Dtype, rhs_dtype: Dtype) -> Dtype: def _timedelta_add_result_dtype( - lhs: BinaryOperand, rhs: BinaryOperand + lhs: ColumnBinaryOperand, rhs: ColumnBinaryOperand ) -> Dtype: if pd.api.types.is_timedelta64_dtype(rhs.dtype): out_dtype = determine_out_dtype(lhs.dtype, rhs.dtype) @@ -577,7 +609,7 @@ def _timedelta_add_result_dtype( def _timedelta_sub_result_dtype( - lhs: BinaryOperand, rhs: BinaryOperand + lhs: ColumnBinaryOperand, rhs: ColumnBinaryOperand ) -> Dtype: if pd.api.types.is_timedelta64_dtype( lhs.dtype diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index a9d7fce9d9b..d78744a719f 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import operator import pickle import warnings from collections import abc @@ -38,7 +39,6 @@ ColumnBase, as_column, build_categorical_column, - column_empty, deserialize_columns, serialize_columns, ) @@ -49,54 +49,11 @@ from cudf.utils import ioutils from cudf.utils.docutils import copy_docstring from cudf.utils.dtypes import find_common_type -from cudf.utils.utils import _cudf_nvtx_annotate +from cudf.utils.utils import _array_ufunc, _cudf_nvtx_annotate T = TypeVar("T", bound="Frame") -# Mapping from ufuncs to the corresponding binary operators. -_ufunc_binary_operations = { - # Arithmetic binary operations. - "add": "add", - "subtract": "sub", - "multiply": "mul", - "matmul": "matmul", - "divide": "truediv", - "true_divide": "truediv", - "floor_divide": "floordiv", - "power": "pow", - "float_power": "pow", - "remainder": "mod", - "mod": "mod", - "fmod": "mod", - # Bitwise binary operations. - "bitwise_and": "and", - "bitwise_or": "or", - "bitwise_xor": "xor", - # Comparison binary operators - "greater": "gt", - "greater_equal": "ge", - "less": "lt", - "less_equal": "le", - "not_equal": "ne", - "equal": "eq", -} - -# These operators need to be mapped to their inverses when performing a -# reflected ufunc operation because no reflected version of the operators -# themselves exist. When these operators are invoked directly (not via -# __array_ufunc__) Python takes care of calling the inverse operation. -_ops_without_reflection = { - "gt": "lt", - "ge": "le", - "lt": "gt", - "le": "ge", - # ne and eq are symmetric, so they are their own inverse op - "ne": "ne", - "eq": "eq", -} - - class Frame(BinaryOperand, Scannable): """A collection of Column objects with an optional index. @@ -2482,30 +2439,6 @@ def _unaryop(self, op): zip(self._column_names, data_columns), self._index ) - def _binaryop( - self, other: T, op: str, fill_value: Any = None, *args, **kwargs, - ) -> Frame: - """Perform a binary operation between two frames. - - Parameters - ---------- - other : Frame - The second operand. - op : str - The operation to perform. - fill_value : Any, default None - The value to replace null values with. If ``None``, nulls are not - filled before the operation. - - Returns - ------- - Frame - A new instance containing the result of the operation. - """ - raise NotImplementedError( - f"Binary operations are not supported for {self.__class__}" - ) - @classmethod @_cudf_nvtx_annotate def _colwise_binop( @@ -2535,8 +2468,6 @@ def _colwise_binop( A dict of columns constructed from the result of performing the requested operation on the operands. """ - fn = fn[2:-2] - # Now actually perform the binop on the columns in left and right. output = {} for ( @@ -2567,11 +2498,9 @@ def _colwise_binop( # are not numerical using the new binops mixin. outcol = ( - left_column.binary_operator(fn, right_column, reflect=reflect) - if right_column is not None - else column_empty( - left_column.size, left_column.dtype, masked=True - ) + getattr(operator, fn)(right_column, left_column) + if reflect + else getattr(operator, fn)(left_column, right_column) ) if output_mask is not None: @@ -2581,44 +2510,8 @@ def _colwise_binop( return output - # For more detail on this function and how it should work, see - # https://numpy.org/doc/stable/reference/ufuncs.html def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - # We don't currently support reduction, accumulation, etc. We also - # don't support any special kwargs or higher arity ufuncs than binary. - if method != "__call__" or kwargs or ufunc.nin > 2: - return NotImplemented - - fname = ufunc.__name__ - if fname in _ufunc_binary_operations: - reflect = self is not inputs[0] - other = inputs[0] if reflect else inputs[1] - - op = _ufunc_binary_operations[fname] - if reflect and op in _ops_without_reflection: - op = _ops_without_reflection[op] - reflect = False - op = f"__{'r' if reflect else ''}{op}__" - - # Float_power returns float irrespective of the input type. - if fname == "float_power": - return getattr(self, op)(other).astype(float) - return getattr(self, op)(other) - - # Special handling for various unary operations. - if fname == "negative": - return self * -1 - if fname == "positive": - return self.copy(deep=True) - if fname == "invert": - return ~self - if fname == "absolute": - return self.abs() - if fname == "fabs": - return self.abs().astype(np.float64) - - # None is a sentinel used by subclasses to trigger cupy dispatch. - return None + return _array_ufunc(self, ufunc, method, inputs, kwargs) def _apply_cupy_ufunc_to_operands( self, ufunc, cupy_func, operands, **kwargs diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index e60cf1f2103..d935da3bd14 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -854,9 +854,7 @@ def _from_data( def _binaryop( self, other: T, op: str, fill_value: Any = None, *args, **kwargs, ) -> SingleColumnFrame: - reflect = self._is_reflected_op(op) - if reflect: - op = op[:2] + op[3:] + reflect, op = self._check_reflected_op(op) operands = self._make_operands_for_binop(other, fill_value, reflect) if operands is NotImplemented: return NotImplemented diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 7e116607017..b8077d7d28b 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -2118,9 +2118,7 @@ def _binaryop( *args, **kwargs, ): - reflect = self._is_reflected_op(op) - if reflect: - op = op[:2] + op[3:] + reflect, op = self._check_reflected_op(op) operands, out_index = self._make_operands_and_index_for_binop( other, op, fill_value, reflect, can_reindex ) diff --git a/python/cudf/cudf/core/mixins/binops.py b/python/cudf/cudf/core/mixins/binops.py index 773b47b62b2..e07977ed4c3 100644 --- a/python/cudf/cudf/core/mixins/binops.py +++ b/python/cudf/cudf/core/mixins/binops.py @@ -48,9 +48,25 @@ }, ) +# TODO: See if there is a better approach to these two issues: 1) The mixin +# assumes a single standard parameter, whereas binops have two, and 2) we need +# a way to determine reflected vs normal ops. -def _is_reflected_op(op): - return op[2] == "r" and op != "__rshift__" +def _binaryop(self, other, op: str): + """The core binary_operation function. -BinaryOperand._is_reflected_op = staticmethod(_is_reflected_op) + Must be overridden by subclasses, the default implementation raises a + NotImplementedError. + """ + raise NotImplementedError + + +def _check_reflected_op(op): + if (reflect := op[2] == "r" and op != "__rshift__") : + op = op[:2] + op[3:] + return reflect, op + + +BinaryOperand._binaryop = _binaryop +BinaryOperand._check_reflected_op = staticmethod(_check_reflected_op) diff --git a/python/cudf/cudf/core/mixins/binops.pyi b/python/cudf/cudf/core/mixins/binops.pyi index 45093cd04d4..ff47cdce418 100644 --- a/python/cudf/cudf/core/mixins/binops.pyi +++ b/python/cudf/cudf/core/mixins/binops.pyi @@ -1,10 +1,16 @@ # Copyright (c) 2022, NVIDIA CORPORATION. -from typing import Set +from typing import Any, Set, Tuple, TypeVar + +# Note: It may be possible to define a narrower bound here eventually. +BinaryOperandType = TypeVar("BinaryOperandType", bound="Any") class BinaryOperand: _SUPPORTED_BINARY_OPERATIONS: Set + def _binaryop(self, other: BinaryOperandType, op: str): + ... + def __add__(self, other): ... @@ -84,5 +90,5 @@ class BinaryOperand: ... @staticmethod - def _is_reflected_op(op) -> bool: + def _check_reflected_op(op) -> Tuple[bool, str]: ... diff --git a/python/cudf/cudf/core/tools/datetimes.py b/python/cudf/cudf/core/tools/datetimes.py index 62c31691ac1..b110a10e1e7 100644 --- a/python/cudf/cudf/core/tools/datetimes.py +++ b/python/cudf/cudf/core/tools/datetimes.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. import math import re @@ -587,12 +587,12 @@ def _combine_kwargs_to_seconds(self, **kwargs): def _datetime_binop( self, datetime_col, op, reflect=False ) -> column.DatetimeColumn: - if reflect and op == "sub": + if reflect and op == "__sub__": raise TypeError( f"Can not subtract a {type(datetime_col).__name__}" f" from a {type(self).__name__}" ) - if op not in {"add", "sub"}: + if op not in {"__add__", "__sub__"}: raise TypeError( f"{op} not supported between {type(self).__name__}" f" and {type(datetime_col).__name__}" @@ -604,7 +604,7 @@ def _datetime_binop( for unit, value in self._scalars.items(): if unit != "months": - value = -value if op == "sub" else value + value = -value if op == "__sub__" else value datetime_col += cudf.core.column.as_column( value, length=len(datetime_col) ) @@ -613,7 +613,7 @@ def _datetime_binop( def _generate_months_column(self, size, op): months = self._scalars["months"] - months = -months if op == "sub" else months + months = -months if op == "__sub__" else months # TODO: pass a scalar instead of constructing a column # https://github.com/rapidsai/cudf/issues/6990 col = cudf.core.column.as_column(months, length=size) diff --git a/python/cudf/cudf/tests/test_list.py b/python/cudf/cudf/tests/test_list.py index fc9ad9711d1..8cc65de739e 100644 --- a/python/cudf/cudf/tests/test_list.py +++ b/python/cudf/cudf/tests/test_list.py @@ -381,7 +381,7 @@ def test_concatenate_rows_of_lists(): def test_concatenate_list_with_nonlist(): - with pytest.raises(TypeError, match="can only concatenate list to list"): + with pytest.raises(TypeError): gdf1 = cudf.DataFrame({"A": [["a", "c"], ["b", "d"], ["c", "d"]]}) gdf2 = cudf.DataFrame({"A": ["a", "b", "c"]}) gdf1["A"] + gdf2["A"] diff --git a/python/cudf/cudf/tests/test_timedelta.py b/python/cudf/cudf/tests/test_timedelta.py index e371cd16180..2623b755cfb 100644 --- a/python/cudf/cudf/tests/test_timedelta.py +++ b/python/cudf/cudf/tests/test_timedelta.py @@ -1175,8 +1175,7 @@ def test_timedelta_invalid_ops(): lfunc_args_and_kwargs=([psr, dt_psr],), rfunc_args_and_kwargs=([sr, dt_sr],), expected_error_message=re.escape( - f"Modulus of {sr.dtype} with {dt_sr.dtype} " - f"cannot be performed." + f"Modulo of {sr.dtype} with {dt_sr.dtype} " f"cannot be performed." ), ) @@ -1186,7 +1185,7 @@ def test_timedelta_invalid_ops(): lfunc_args_and_kwargs=([psr, "a"],), rfunc_args_and_kwargs=([sr, "a"],), expected_error_message=re.escape( - f"Modulus of {sr.dtype} with {np.dtype('object')} " + f"Modulo of {sr.dtype} with {np.dtype('object')} " f"cannot be performed." ), ) @@ -1285,9 +1284,7 @@ def test_timedelta_invalid_ops(): rfunc=operator.xor, lfunc_args_and_kwargs=([psr, psr],), rfunc_args_and_kwargs=([sr, sr],), - expected_error_message=re.escape( - f"Series of dtype {sr.dtype} cannot perform the operation xor" - ), + compare_error_message=False, ) diff --git a/python/cudf/cudf/utils/applyutils.py b/python/cudf/cudf/utils/applyutils.py index 593965046e6..89331b933a8 100644 --- a/python/cudf/cudf/utils/applyutils.py +++ b/python/cudf/cudf/utils/applyutils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. import functools from typing import Any, Dict @@ -103,7 +103,7 @@ def apply_chunks( return applychunks.run(df, chunks=chunks, tpb=tpb) -def make_aggregate_nullmask(df, columns=None, op="and"): +def make_aggregate_nullmask(df, columns=None, op="__and__"): out_mask = None for k in columns or df._data: diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index 1bd3fa7558e..ed714182576 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -24,6 +24,92 @@ mask_dtype = cudf.dtype(np.int32) mask_bitsize = mask_dtype.itemsize * 8 +# Mapping from ufuncs to the corresponding binary operators. +_ufunc_binary_operations = { + # Arithmetic binary operations. + "add": "add", + "subtract": "sub", + "multiply": "mul", + "matmul": "matmul", + "divide": "truediv", + "true_divide": "truediv", + "floor_divide": "floordiv", + "power": "pow", + "float_power": "pow", + "remainder": "mod", + "mod": "mod", + "fmod": "mod", + # Bitwise binary operations. + "bitwise_and": "and", + "bitwise_or": "or", + "bitwise_xor": "xor", + # Comparison binary operators + "greater": "gt", + "greater_equal": "ge", + "less": "lt", + "less_equal": "le", + "not_equal": "ne", + "equal": "eq", +} + +# These operators need to be mapped to their inverses when performing a +# reflected ufunc operation because no reflected version of the operators +# themselves exist. When these operators are invoked directly (not via +# __array_ufunc__) Python takes care of calling the inverse operation. +_ops_without_reflection = { + "gt": "lt", + "ge": "le", + "lt": "gt", + "le": "ge", + # ne and eq are symmetric, so they are their own inverse op + "ne": "ne", + "eq": "eq", +} + + +# This is the implementation of __array_ufunc__ used for Frame and Column. +# For more detail on this function and how it should work, see +# https://numpy.org/doc/stable/reference/ufuncs.html +def _array_ufunc(obj, ufunc, method, inputs, kwargs): + # We don't currently support reduction, accumulation, etc. We also + # don't support any special kwargs or higher arity ufuncs than binary. + if method != "__call__" or kwargs or ufunc.nin > 2: + return NotImplemented + + fname = ufunc.__name__ + if fname in _ufunc_binary_operations: + reflect = obj is not inputs[0] + other = inputs[0] if reflect else inputs[1] + + op = _ufunc_binary_operations[fname] + if reflect and op in _ops_without_reflection: + op = _ops_without_reflection[op] + reflect = False + op = f"__{'r' if reflect else ''}{op}__" + + # float_power returns float irrespective of the input type. + # TODO: Do not get the attribute directly, get from the operator module + # so that we can still exploit reflection. + if fname == "float_power": + return getattr(obj, op)(other).astype(float) + return getattr(obj, op)(other) + + # Special handling for various unary operations. + if fname == "negative": + return obj * -1 + if fname == "positive": + return obj.copy(deep=True) + if fname == "invert": + return ~obj + if fname == "absolute": + # TODO: Make sure all obj (mainly Column) implement abs. + return abs(obj) + if fname == "fabs": + return abs(obj).astype(np.float64) + + # None is a sentinel used by subclasses to trigger cupy dispatch. + return None + _EQUALITY_OPS = { "__eq__",