Skip to content

Commit

Permalink
Simplify column binary operations (#10421)
Browse files Browse the repository at this point in the history
This PR contains a large set of changes aimed at simplifying the logic used to normalize binary operands and perform binary operations at the column level. Many previously disparate code paths have been consolidated, while others have simply been moved into the narrowest possible scope so that they no longer affect all operations. In particular, all dtype-specific logic has been moved from `Frame._colwise_binop` into the appropriate column dtypes.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Michael Wang (https://github.com/isVoid)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

URL: #10421
  • Loading branch information
vyasr authored Mar 17, 2022
1 parent 04933a2 commit 22a9f35
Show file tree
Hide file tree
Showing 15 changed files with 230 additions and 308 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _union_categoricals(
is_datetime64tz_dtype = pd_types.is_datetime64tz_dtype
is_extension_type = pd_types.is_extension_type
is_extension_array_dtype = pd_types.is_extension_array_dtype
is_float_dtype = pd_types.is_float_dtype
is_float_dtype = _wrap_pandas_is_dtype_api(pd_types.is_float_dtype)
is_int64_dtype = pd_types.is_int64_dtype
is_integer_dtype = _wrap_pandas_is_dtype_api(pd_types.is_integer_dtype)
is_object_dtype = pd_types.is_object_dtype
Expand Down
42 changes: 26 additions & 16 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,38 +878,48 @@ def slice(
def binary_operator(
self, op: str, rhs, reflect: bool = False
) -> ColumnBase:
if not (self.ordered and rhs.ordered) and op not in (
"eq",
"ne",
"NULL_EQUALS",
):
if op in ("lt", "gt", "le", "ge"):
raise TypeError(
"Unordered Categoricals can only compare equality or not"
)
if op not in {"eq", "ne", "lt", "le", "gt", "ge", "NULL_EQUALS"}:
raise TypeError(
f"Series of dtype `{self.dtype}` cannot perform the "
f"operation: {op}"
"Series of dtype `category` cannot perform the operation: "
f"{op}"
)
rhs = self._wrap_binop_normalization(rhs)
# TODO: This is currently just here to make mypy happy, but eventually
# we'll need to properly establish the APIs for these methods.
if not isinstance(rhs, CategoricalColumn):
raise ValueError
# Note: at this stage we are guaranteed that the dtypes are equal.
if not self.ordered and op not in {"eq", "ne", "NULL_EQUALS"}:
raise TypeError(
"The only binary operations supported by unordered "
"categorical columns are equality and inequality."
)
if self.dtype != rhs.dtype:
raise TypeError("Categoricals can only compare with the same type")
return self.as_numerical.binary_operator(op, rhs.as_numerical)

def normalize_binop_value(self, other: ScalarLike) -> CategoricalColumn:

if isinstance(other, column.ColumnBase):
if not isinstance(other, CategoricalColumn):
raise ValueError(
"Binary operations with categorical columns require both "
"columns to be categorical."
)
if other.dtype != self.dtype:
raise TypeError(
"Categoricals can only compare with the same type"
)
return other
if isinstance(other, np.ndarray) and other.ndim == 0:
other = other.item()

ary = cudf.utils.utils.scalar_broadcast_to(
self._encode(other), size=len(self), dtype=self.codes.dtype
)
col = column.build_categorical_column(
return column.build_categorical_column(
categories=self.dtype.categories._values,
codes=column.as_column(ary),
mask=self.base_mask,
ordered=self.dtype.ordered,
)
return col

def sort_by_values(
self, ascending: bool = True, na_position="last"
Expand Down
5 changes: 5 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,11 @@ def __setitem__(self, key: Any, value: Any):
if out:
self._mimic_inplace(out, inplace=True)

def _wrap_binop_normalization(self, other):
    """Normalize the right-hand operand of a binary operation.

    ``cudf.NA`` is handled here, once, for every column type: it is
    wrapped in a ``cudf.Scalar`` carrying this column's dtype. Any other
    operand is passed to this column's ``normalize_binop_value``.
    """
    if other is not cudf.NA:
        # Everything except the NA singleton goes through the
        # column's own normalization hook.
        return self.normalize_binop_value(other)
    return cudf.Scalar(other, dtype=self.dtype)

def _scatter_by_slice(
self, key: Slice, value: Union[cudf.core.scalar.Scalar, ColumnBase]
) -> Optional[ColumnBase]:
Expand Down
21 changes: 10 additions & 11 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,11 @@ def round(self, freq: str) -> ColumnBase:
return libcudf.datetime.round_datetime(self, freq)

def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
if isinstance(other, cudf.Scalar):
if isinstance(other, (cudf.Scalar, ColumnBase, cudf.DateOffset)):
return other

if isinstance(other, np.ndarray) and other.ndim == 0:
other = other.item()

if isinstance(other, dt.datetime):
other = np.datetime64(other)
elif isinstance(other, dt.timedelta):
Expand All @@ -239,8 +238,7 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
other = other.to_datetime64()
elif isinstance(other, pd.Timedelta):
other = other.to_timedelta64()
elif isinstance(other, cudf.DateOffset):
return other

if isinstance(other, np.datetime64):
if np.isnat(other):
return cudf.Scalar(None, dtype=self.dtype)
Expand All @@ -250,7 +248,7 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
elif isinstance(other, np.timedelta64):
other_time_unit = cudf.utils.dtypes.get_time_unit(other)

if other_time_unit not in ("s", "ms", "ns", "us"):
if other_time_unit not in {"s", "ms", "ns", "us"}:
other = other.astype("timedelta64[s]")

if np.isnat(other):
Expand All @@ -259,8 +257,8 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
return cudf.Scalar(other)
elif other is None:
return cudf.Scalar(other, dtype=self.dtype)
else:
raise TypeError(f"cannot normalize {type(other)}")

raise TypeError(f"cannot normalize {type(other)}")

@property
def as_numerical(self) -> "cudf.core.column.NumericalColumn":
Expand Down Expand Up @@ -390,11 +388,13 @@ def binary_operator(
rhs: Union[ColumnBase, "cudf.Scalar"],
reflect: bool = False,
) -> ColumnBase:
rhs = self._wrap_binop_normalization(rhs)
if isinstance(rhs, cudf.DateOffset):
return rhs._datetime_binop(self, op, reflect=reflect)

lhs: Union[ScalarLike, ColumnBase] = self
if op in ("eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"):
out_dtype = cudf.dtype(np.bool_) # type: Dtype
if op in {"eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"}:
out_dtype: Dtype = cudf.dtype(np.bool_)
elif op == "add" and pd.api.types.is_timedelta64_dtype(rhs.dtype):
out_dtype = cudf.core.column.timedelta._timedelta_add_result_dtype(
rhs, lhs
Expand All @@ -418,8 +418,7 @@ def binary_operator(
f" the operation {op}"
)

if reflect:
lhs, rhs = rhs, lhs
lhs, rhs = (self, rhs) if not reflect else (rhs, self)
return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype)

def fillna(
Expand Down
84 changes: 41 additions & 43 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,46 +63,23 @@ def as_string_column(
def binary_operator(self, op, other, reflect=False):
if reflect:
self, other = other, self

if not isinstance(
other,
(
DecimalBaseColumn,
cudf.core.column.NumericalColumn,
cudf.Scalar,
),
):
raise TypeError(
f"Operator {op} not supported between"
f"{str(type(self))} and {str(type(other))}"
)
elif isinstance(
other, cudf.core.column.NumericalColumn
) and not is_integer_dtype(other.dtype):
raise TypeError(
f"Only decimal and integer column is supported for {op}."
)
if isinstance(other, cudf.core.column.NumericalColumn):
other = other.as_decimal_column(
self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0)
)
if not isinstance(self.dtype, other.dtype.__class__):
if (
self.dtype.precision == other.dtype.precision
and self.dtype.scale == other.dtype.scale
):
other = other.astype(self.dtype)
# Decimals in libcudf don't support truediv, see
# https://github.com/rapidsai/cudf/pull/7435 for explanation.
op = op.replace("true", "")
other = self._wrap_binop_normalization(other)

# Binary Arithmetics between decimal columns. `Scale` and `precision`
# are computed outside of libcudf
try:
if op in ("add", "sub", "mul", "div"):
if op in {"add", "sub", "mul", "div"}:
output_type = _get_decimal_type(self.dtype, other.dtype, op)
result = libcudf.binaryop.binaryop(
self, other, op, output_type
)
# TODO: Why is this necessary? Why isn't the result's
# precision already set correctly based on output_type?
result.dtype.precision = output_type.precision
elif op in ("eq", "ne", "lt", "gt", "le", "ge"):
elif op in {"eq", "ne", "lt", "gt", "le", "ge"}:
result = libcudf.binaryop.binaryop(self, other, op, bool)
except RuntimeError as e:
if "Unsupported operator for these types" in str(e):
Expand Down Expand Up @@ -140,14 +117,41 @@ def fillna(
return result._with_type_metadata(self.dtype)

def normalize_binop_value(self, other):
if is_scalar(other) and isinstance(other, (int, np.int, Decimal)):
return cudf.Scalar(Decimal(other))
elif isinstance(other, cudf.Scalar) and isinstance(
other.dtype, cudf.core.dtypes.DecimalDtype
if isinstance(other, ColumnBase):
if isinstance(other, cudf.core.column.NumericalColumn):
if not is_integer_dtype(other.dtype):
raise TypeError(
"Decimal columns only support binary operations with "
"integer numerical columns."
)
other = other.as_decimal_column(
self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0)
)
elif not isinstance(other, DecimalBaseColumn):
raise TypeError(
f"Binary operations are not supported between"
f"{str(type(self))} and {str(type(other))}"
)
elif not isinstance(self.dtype, other.dtype.__class__):
# This branch occurs if we have a DecimalBaseColumn of a
# different size (e.g. 64 instead of 32).
if (
self.dtype.precision == other.dtype.precision
and self.dtype.scale == other.dtype.scale
):
other = other.astype(self.dtype)

return other
if isinstance(other, cudf.Scalar) and isinstance(
# TODO: Should it be possible to cast scalars of other numerical
# types to decimal?
other.dtype,
cudf.core.dtypes.DecimalDtype,
):
return other
else:
raise TypeError(f"cannot normalize {type(other)}")
elif is_scalar(other) and isinstance(other, (int, Decimal)):
return cudf.Scalar(Decimal(other))
raise TypeError(f"cannot normalize {type(other)}")

def _decimal_quantile(
self, q: Union[float, Sequence[float]], interpolation: str, exact: bool
Expand Down Expand Up @@ -256,12 +260,6 @@ def _with_type_metadata(
class Decimal64Column(DecimalBaseColumn):
dtype: Decimal64Dtype

def __truediv__(self, other):
# TODO: This override is not sufficient. While it will change the
# behavior of x / y for two decimal columns, it will not affect
# col1.binary_operator(col2), which is how Series/Index will call this.
return self.binary_operator("div", other)

def __setitem__(self, key, value):
if isinstance(value, np.integer):
value = int(value)
Expand Down
5 changes: 4 additions & 1 deletion python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def binary_operator(
Name: val, dtype: list
"""

other = self._wrap_binop_normalization(other)
if isinstance(other.dtype, ListDtype):
if binop == "add":
return concatenate_rows(
Expand Down Expand Up @@ -254,6 +254,9 @@ def __cuda_array_interface__(self):
"Lists are not yet supported via `__cuda_array_interface__`"
)

def normalize_binop_value(self, other):
    """Return ``other`` unchanged.

    This implementation performs no conversion or validation of the
    operand; the very same object that was passed in is handed back.
    """
    return other

def _with_type_metadata(
self: "cudf.core.column.ListColumn", dtype: Dtype
) -> "cudf.core.column.ListColumn":
Expand Down
Loading

0 comments on commit 22a9f35

Please sign in to comment.