Skip to content

Commit

Permalink
Define proper binary operation APIs for columns (#10509)
Browse files Browse the repository at this point in the history
This PR changes the way that binary operations are performed between columns. Instead of directly invoking the `_binaryop` method, Frame binary operations now invoke operators directly using the `operator` module. Each `Column` subclass now only defines operations that are well-defined, relying on Python to handle raising `TypeError`s for all others. Binary operations return `NotImplemented` instead of raising a `TypeError` _except_ in specific cases where a meaningful error should be raised, allowing us to take advantage of reflected operations to prevent duplicate logic on how to handle binary operations between distinct types. Finally, various edge cases that were previously handled by Frames are now handled in Column, so that the Column classes for each dtype are the sole source of truth on what operands are supported. These changes move us towards fully functional Column classes that do not rely on preprocessed inputs coming from the Frame layer.

This PR has a large changeset, but a large chunk of the changed lines exist simply because some changes to the pipeline result in operations being referred to by their dunder names instead of having the dunders stripped, e.g. `__add__` instead of `add`.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #10509
  • Loading branch information
vyasr authored Mar 25, 2022
1 parent d73d91f commit 19ab7d6
Show file tree
Hide file tree
Showing 21 changed files with 465 additions and 454 deletions.
6 changes: 5 additions & 1 deletion python/cudf/cudf/_lib/binaryop.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

from enum import IntEnum

Expand Down Expand Up @@ -160,6 +160,10 @@ def binaryop(lhs, rhs, op, dtype):
"""
Dispatches a binary op call to the appropriate libcudf function:
"""
# TODO: Shouldn't have to keep special-casing. We need to define a separate
# pipeline for libcudf binops that don't map to Python binops.
if op != "NULL_EQUALS":
op = op[2:-2]

op = BinaryOperation[op.upper()]
cdef binary_operator c_op = <binary_operator> (
Expand Down
6 changes: 4 additions & 2 deletions python/cudf/cudf/_lib/datetime.pyx
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

Expand Down Expand Up @@ -56,8 +58,8 @@ def extract_datetime_component(Column col, object field):

if field == "weekday":
# Pandas counts Monday-Sunday as 0-6
# while we count Monday-Sunday as 1-7
result = result.binary_operator("sub", result.dtype.type(1))
# while libcudf counts Monday-Sunday as 1-7
result = result - result.dtype.type(1)

return result

Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
ColumnLike = Any

# binary operation
BinaryOperand = Union["cudf.Scalar", "cudf.core.column.ColumnBase"]
ColumnBinaryOperand = Union["cudf.Scalar", "cudf.core.column.ColumnBase"]

DataFrameOrSeries = Union["cudf.Series", "cudf.DataFrame"]
SeriesOrIndex = Union["cudf.Series", "cudf.core.index.BaseIndex"]
Expand Down
34 changes: 15 additions & 19 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cudf
from cudf import _lib as libcudf
from cudf._lib.transform import bools_to_mask
from cudf._typing import ColumnLike, Dtype, ScalarLike
from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.api.types import is_categorical_dtype, is_interval_dtype
from cudf.core.buffer import Buffer
from cudf.core.column import column
Expand Down Expand Up @@ -630,6 +630,14 @@ class CategoricalColumn(column.ColumnBase):
dtype: cudf.core.dtypes.CategoricalDtype
_codes: Optional[NumericalColumn]
_children: Tuple[NumericalColumn]
_VALID_BINARY_OPERATIONS = {
"__eq__",
"__ne__",
"__lt__",
"__le__",
"__gt__",
"__ge__",
}

def __init__(
self,
Expand Down Expand Up @@ -875,41 +883,29 @@ def slice(
offset=codes.offset,
)

def binary_operator(
self, op: str, rhs, reflect: bool = False
) -> ColumnBase:
if op not in {"eq", "ne", "lt", "le", "gt", "ge", "NULL_EQUALS"}:
raise TypeError(
"Series of dtype `category` cannot perform the operation: "
f"{op}"
)
rhs = self._wrap_binop_normalization(rhs)
def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
other = self._wrap_binop_normalization(other)
# TODO: This is currently just here to make mypy happy, but eventually
# we'll need to properly establish the APIs for these methods.
if not isinstance(rhs, CategoricalColumn):
if not isinstance(other, CategoricalColumn):
raise ValueError
# Note: at this stage we are guaranteed that the dtypes are equal.
if not self.ordered and op not in {"eq", "ne", "NULL_EQUALS"}:
if not self.ordered and op not in {"__eq__", "__ne__", "NULL_EQUALS"}:
raise TypeError(
"The only binary operations supported by unordered "
"categorical columns are equality and inequality."
)
return self.as_numerical.binary_operator(op, rhs.as_numerical)
return self.as_numerical._binaryop(other.as_numerical, op)

def normalize_binop_value(self, other: ScalarLike) -> CategoricalColumn:
if isinstance(other, column.ColumnBase):
if not isinstance(other, CategoricalColumn):
raise ValueError(
"Binary operations with categorical columns require both "
"columns to be categorical."
)
return NotImplemented
if other.dtype != self.dtype:
raise TypeError(
"Categoricals can only compare with the same type"
)
return other
if isinstance(other, np.ndarray) and other.ndim == 0:
other = other.item()

ary = cudf.utils.utils.scalar_broadcast_to(
self._encode(other), size=len(self), dtype=self.codes.dtype
Expand Down
71 changes: 13 additions & 58 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
drop_nulls,
)
from cudf._lib.transform import bools_to_mask
from cudf._typing import BinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf._typing import ColumnLike, Dtype, ScalarLike
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
_is_scalar_or_zero_d_array,
Expand All @@ -68,7 +68,7 @@
ListDtype,
StructDtype,
)
from cudf.core.mixins import Reducible
from cudf.core.mixins import BinaryOperand, Reducible
from cudf.utils import utils
from cudf.utils.dtypes import (
cudf_dtype_from_pa_type,
Expand All @@ -78,15 +78,15 @@
pandas_dtypes_alias_to_cudf_alias,
pandas_dtypes_to_np_dtypes,
)
from cudf.utils.utils import NotIterable, mask_dtype
from cudf.utils.utils import NotIterable, _array_ufunc, mask_dtype

T = TypeVar("T", bound="ColumnBase")
# TODO: This workaround allows type hints for `slice`, since `slice` is a
# method in ColumnBase.
Slice = TypeVar("Slice", bound=slice)


class ColumnBase(Column, Serializable, Reducible, NotIterable):
class ColumnBase(Column, Serializable, BinaryOperand, Reducible, NotIterable):
_VALID_REDUCTIONS = {
"any",
"all",
Expand Down Expand Up @@ -185,7 +185,10 @@ def equals(self, other: ColumnBase, check_dtypes: bool = False) -> bool:
return False
if check_dtypes and (self.dtype != other.dtype):
return False
return self.binary_operator("NULL_EQUALS", other).all()
ret = self._binaryop(other, "NULL_EQUALS")
if ret is NotImplemented:
raise TypeError(f"Cannot compare equality with {type(other)}")
return ret.all()

def all(self, skipna: bool = True) -> bool:
# The skipna argument is only used for numerical columns.
Expand Down Expand Up @@ -521,8 +524,10 @@ def __setitem__(self, key: Any, value: Any):
self._mimic_inplace(out, inplace=True)

def _wrap_binop_normalization(self, other):
if other is cudf.NA:
if other is cudf.NA or other is None:
return cudf.Scalar(other, dtype=self.dtype)
if isinstance(other, np.ndarray) and other.ndim == 0:
other = other.item()
return self.normalize_binop_value(other)

def _scatter_by_slice(
Expand Down Expand Up @@ -1029,50 +1034,8 @@ def __cuda_array_interface__(self):
"`__cuda_array_interface__`"
)

def __add__(self, other):
return self.binary_operator("add", other)

def __sub__(self, other):
return self.binary_operator("sub", other)

def __mul__(self, other):
return self.binary_operator("mul", other)

def __eq__(self, other):
return self.binary_operator("eq", other)

def __ne__(self, other):
return self.binary_operator("ne", other)

def __or__(self, other):
return self.binary_operator("or", other)

def __and__(self, other):
return self.binary_operator("and", other)

def __floordiv__(self, other):
return self.binary_operator("floordiv", other)

def __truediv__(self, other):
return self.binary_operator("truediv", other)

def __mod__(self, other):
return self.binary_operator("mod", other)

def __pow__(self, other):
return self.binary_operator("pow", other)

def __lt__(self, other):
return self.binary_operator("lt", other)

def __gt__(self, other):
return self.binary_operator("gt", other)

def __le__(self, other):
return self.binary_operator("le", other)

def __ge__(self, other):
return self.binary_operator("ge", other)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
return _array_ufunc(self, ufunc, method, inputs, kwargs)

def searchsorted(
self,
Expand Down Expand Up @@ -1133,14 +1096,6 @@ def unary_operator(self, unaryop: str):
f"Operation {unaryop} not supported for dtype {self.dtype}."
)

def binary_operator(
self, op: str, other: BinaryOperand, reflect: bool = False
) -> ColumnBase:
raise TypeError(
f"Operation {op} not supported between dtypes {self.dtype} and "
f"{other.dtype}."
)

def normalize_binop_value(
self, other: ScalarLike
) -> Union[ColumnBase, ScalarLike]:
Expand Down
87 changes: 56 additions & 31 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@
import re
from locale import nl_langinfo
from types import SimpleNamespace
from typing import Any, Mapping, Sequence, Union, cast
from typing import Any, Mapping, Sequence, cast

import numpy as np
import pandas as pd

import cudf
from cudf import _lib as libcudf
from cudf._typing import DatetimeLikeScalar, Dtype, DtypeObj, ScalarLike
from cudf._typing import (
ColumnBinaryOperand,
DatetimeLikeScalar,
Dtype,
DtypeObj,
ScalarLike,
)
from cudf.api.types import is_scalar
from cudf.core._compat import PANDAS_GE_120
from cudf.core.buffer import Buffer
Expand Down Expand Up @@ -109,6 +115,19 @@ class DatetimeColumn(column.ColumnBase):
The validity mask
"""

_VALID_BINARY_OPERATIONS = {
"__eq__",
"__ne__",
"__lt__",
"__le__",
"__gt__",
"__ge__",
"__add__",
"__sub__",
"__radd__",
"__rsub__",
}

def __init__(
self,
data: Buffer,
Expand Down Expand Up @@ -227,8 +246,6 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
if isinstance(other, (cudf.Scalar, ColumnBase, cudf.DateOffset)):
return other

if isinstance(other, np.ndarray) and other.ndim == 0:
other = other.item()
if isinstance(other, dt.datetime):
other = np.datetime64(other)
elif isinstance(other, dt.timedelta):
Expand All @@ -254,10 +271,8 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
return cudf.Scalar(None, dtype=other.dtype)

return cudf.Scalar(other)
elif other is None:
return cudf.Scalar(other, dtype=self.dtype)

raise TypeError(f"cannot normalize {type(other)}")
return NotImplemented

@property
def as_numerical(self) -> "cudf.core.column.NumericalColumn":
Expand Down Expand Up @@ -388,43 +403,53 @@ def quantile(
return pd.Timestamp(result, unit=self.time_unit)
return result.astype(self.dtype)

def binary_operator(
self,
op: str,
rhs: Union[ColumnBase, "cudf.Scalar"],
reflect: bool = False,
) -> ColumnBase:
rhs = self._wrap_binop_normalization(rhs)
if isinstance(rhs, cudf.DateOffset):
return rhs._datetime_binop(self, op, reflect=reflect)

lhs: Union[ScalarLike, ColumnBase] = self
if op in {"eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"}:
def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
reflect, op = self._check_reflected_op(op)
other = self._wrap_binop_normalization(other)
if other is NotImplemented:
return NotImplemented
if isinstance(other, cudf.DateOffset):
return other._datetime_binop(self, op, reflect=reflect)

# TODO: Figure out if I can reflect before we start these checks. That
# requires figuring out why _timedelta_add_result_dtype and
# _timedelta_sub_result_dtype are 1) not symmetric, and 2) different
# from each other.
if op in {
"__eq__",
"__ne__",
"__lt__",
"__gt__",
"__le__",
"__ge__",
"NULL_EQUALS",
}:
out_dtype: Dtype = cudf.dtype(np.bool_)
elif op == "add" and pd.api.types.is_timedelta64_dtype(rhs.dtype):
elif op == "__add__" and pd.api.types.is_timedelta64_dtype(
other.dtype
):
out_dtype = cudf.core.column.timedelta._timedelta_add_result_dtype(
rhs, lhs
other, self
)
elif op == "sub" and pd.api.types.is_timedelta64_dtype(rhs.dtype):
elif op == "__sub__" and pd.api.types.is_timedelta64_dtype(
other.dtype
):
out_dtype = cudf.core.column.timedelta._timedelta_sub_result_dtype(
rhs if reflect else lhs, lhs if reflect else rhs
other if reflect else self, self if reflect else other
)
elif op == "sub" and pd.api.types.is_datetime64_dtype(rhs.dtype):
elif op == "__sub__" and pd.api.types.is_datetime64_dtype(other.dtype):
units = ["s", "ms", "us", "ns"]
lhs_time_unit = cudf.utils.dtypes.get_time_unit(lhs)
lhs_time_unit = cudf.utils.dtypes.get_time_unit(self)
lhs_unit = units.index(lhs_time_unit)
rhs_time_unit = cudf.utils.dtypes.get_time_unit(rhs)
rhs_time_unit = cudf.utils.dtypes.get_time_unit(other)
rhs_unit = units.index(rhs_time_unit)
out_dtype = np.dtype(
f"timedelta64[{units[max(lhs_unit, rhs_unit)]}]"
)
else:
raise TypeError(
f"Series of dtype {self.dtype} cannot perform "
f" the operation {op}"
)
return NotImplemented

lhs, rhs = (self, rhs) if not reflect else (rhs, self)
lhs, rhs = (other, self) if reflect else (self, other)
return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype)

def fillna(
Expand Down
Loading

0 comments on commit 19ab7d6

Please sign in to comment.