diff --git a/python/cudf/cudf/api/types.py b/python/cudf/cudf/api/types.py index daf6d11aa9f..01af22f70bf 100644 --- a/python/cudf/cudf/api/types.py +++ b/python/cudf/cudf/api/types.py @@ -126,6 +126,7 @@ def is_scalar(val): return ( isinstance(val, DeviceScalar) or isinstance(val, cudf.Scalar) + or isinstance(val, cudf.core.tools.datetimes.DateOffset) or pd_types.is_scalar(val) ) @@ -267,3 +268,7 @@ def _union_categoricals( is_re = pd_types.is_re is_re_compilable = pd_types.is_re_compilable is_dtype_equal = pd_types.is_dtype_equal + + +# Aliases of numpy dtype functionality. +issubdtype = np.issubdtype diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index c02bf3d11a4..75d6d712708 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -24,7 +24,6 @@ import cudf from cudf import _lib as libcudf -from cudf._lib.null_mask import MaskState, create_null_mask from cudf.api.types import is_bool_dtype, is_dict_like from cudf.core import column, reshape from cudf.core.abc import Serializable @@ -58,38 +57,6 @@ T = TypeVar("T", bound="DataFrame") -_reverse_op = { - "add": "radd", - "radd": "add", - "sub": "rsub", - "rsub": "sub", - "mul": "rmul", - "rmul": "mul", - "mod": "rmod", - "rmod": "mod", - "pow": "rpow", - "rpow": "pow", - "floordiv": "rfloordiv", - "rfloordiv": "floordiv", - "truediv": "rtruediv", - "rtruediv": "truediv", - "__add__": "__radd__", - "__radd__": "__add__", - "__sub__": "__rsub__", - "__rsub__": "__sub__", - "__mul__": "__rmul__", - "__rmul__": "__mul__", - "__mod__": "__rmod__", - "__rmod__": "__mod__", - "__pow__": "__rpow__", - "__rpow__": "__pow__", - "__floordiv__": "__rfloordiv__", - "__rfloordiv__": "__floordiv__", - "__truediv__": "__rtruediv__", - "__rtruediv__": "__truediv__", -} - - _cupy_nan_methods_map = { "min": "nanmin", "max": "nanmax", @@ -1453,84 +1420,91 @@ def _get_columns_by_label(self, labels, downcast=False): ) return out - # unary, binary, rbinary, orderedcompare, unorderedcompare - def _apply_op(self, fn, other=None, fill_value=None): - - result = DataFrame(index=self.index) - - def op(lhs, rhs): - if fill_value is None: - return getattr(lhs, fn)(rhs) - else: - return getattr(lhs, fn)(rhs, fill_value) - - if other is None: - for col in self._data: - result[col] = getattr(self[col], fn)() - elif isinstance(other, Sequence): - # This adds the ith element of other to the ith column of self. - for k, col in enumerate(self._data): - result[col] = getattr(self[col], fn)(other[k]) - elif isinstance(other, DataFrame): + def _binaryop( + self, + other: Any, + fn: str, + fill_value: Any = None, + reflect: bool = False, + *args, + **kwargs, + ): + lhs, rhs = self, other + + if _is_scalar_or_zero_d_array(rhs): + rhs = [rhs] * lhs._num_columns + + # For columns that exist in rhs but not lhs, we swap the order so that + # we can always assume that left has a binary operator. This + # implementation assumes that binary operations between a column and + # NULL are always commutative, even for binops (like subtraction) that + # are normally anticommutative. + if isinstance(rhs, Sequence): + # TODO: Consider validating sequence length (pandas does). + operands = { + name: (left, right, reflect, fill_value) + for right, (name, left) in zip(rhs, lhs._data.items()) + } + elif isinstance(rhs, DataFrame): if fn in cudf.utils.utils._EQUALITY_OPS: - if not self.columns.equals( - other.columns - ) or not self.index.equals(other.index): + if not lhs.columns.equals(rhs.columns) or not lhs.index.equals( + rhs.index + ): raise ValueError( "Can only compare identically-labeled " "DataFrame objects" ) - lhs, rhs = _align_indices(self, other) - result.index = lhs.index - max_num_rows = max(lhs.shape[0], rhs.shape[0]) - - def fallback(col, fn): - if fill_value is None: - return Series.from_masked_array( - data=column_empty(max_num_rows, dtype="float64"), - mask=create_null_mask( - max_num_rows, state=MaskState.ALL_NULL - ), - ).set_index(col.index) - else: - return getattr(col, fn)(fill_value) + lhs, rhs = _align_indices(lhs, rhs) - for col in lhs._data: - if col not in rhs._data: - result[col] = fallback(lhs[col], fn) - else: - result[col] = op(lhs[col], rhs[col]) - for col in rhs._data: - if col not in lhs._data: - result[col] = fallback(rhs[col], _reverse_op[fn]) - elif isinstance(other, Series): - other_cols = other.to_pandas().to_dict() - df_cols = self._column_names - result_cols = df_cols + tuple( - col for col in other_cols if col not in df_cols + operands = { + name: ( + lcol, + rhs._data[name] + if name in rhs._data + else (fill_value or None), + reflect, + fill_value if name in rhs._data else None, + ) + for name, lcol in lhs._data.items() + } + for name, col in rhs._data.items(): + if name not in lhs._data: + operands[name] = ( + col, + (fill_value or None), + not reflect, + None, + ) + elif isinstance(rhs, Series): + # Note: This logic will need updating if any of the user-facing + # binop methods (e.g. DataFrame.add) ever support axis=0/rows. + right_dict = dict(zip(rhs.index.values_host, rhs.values_host)) + left_cols = lhs._column_names + # mypy thinks lhs._column_names is a List rather than a Tuple, so + # we have to ignore the type check. + result_cols = left_cols + tuple( # type: ignore + col for col in right_dict if col not in left_cols ) - + operands = {} for col in result_cols: - l_opr = ( - self[col] - if col in df_cols - else Series(as_column(np.nan, length=len(self))) - ) - r_opr = other_cols.get(col) - result[col] = op(l_opr, r_opr) - - elif isinstance(other, (numbers.Number, cudf.Scalar)) or ( - isinstance(other, np.ndarray) and other.ndim == 0 - ): - for col in self._data: - result[col] = op(self[col], other) + if col in left_cols: + left = lhs._data[col] + right = right_dict[col] if col in right_dict else None + else: + # We match pandas semantics here by performing binops + # between a NaN (not NULL!) column and the actual values, + # which results in nans, the pandas output. + left = as_column(np.nan, length=lhs._num_rows) + right = right_dict[col] + operands[col] = (left, right, reflect, fill_value) else: - raise NotImplementedError( - "DataFrame operations with " + str(type(other)) + " not " - "supported at this time." - ) - return result + return NotImplemented + + return self._from_data( + ColumnAccessor(type(self)._colwise_binop(operands, fn)), + index=lhs._index, + ) def add(self, other, axis="columns", level=None, fill_value=None): """ @@ -1583,7 +1557,7 @@ def add(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("add", other, fill_value) + return self._binaryop(other, "add", fill_value) def update( self, @@ -1678,8 +1652,13 @@ def update( self._mimic_inplace(source_df, inplace=True) - def __add__(self, other): - return self._apply_op("__add__", other) + def __invert__(self): + # Defer logic to Series since pandas semantics dictate different + # behaviors for different types that requires too much special casing + # of the standard _unaryop. + return DataFrame( + data={col: ~self[col] for col in self}, index=self.index + ) def radd(self, other, axis=1, level=None, fill_value=None): """ @@ -1732,10 +1711,7 @@ def radd(self, other, axis=1, level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("radd", other, fill_value) - - def __radd__(self, other): - return self._apply_op("__radd__", other) + return self._binaryop(other, "add", fill_value, reflect=True) def sub(self, other, axis="columns", level=None, fill_value=None): """ @@ -1788,10 +1764,7 @@ def sub(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("sub", other, fill_value) - - def __sub__(self, other): - return self._apply_op("__sub__", other) + return self._binaryop(other, "sub", fill_value) def rsub(self, other, axis="columns", level=None, fill_value=None): """ @@ -1849,10 +1822,7 @@ def rsub(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("rsub", other, fill_value) - - def __rsub__(self, other): - return self._apply_op("__rsub__", other) + return self._binaryop(other, "sub", fill_value, reflect=True) def mul(self, other, axis="columns", level=None, fill_value=None): """ @@ -1907,10 +1877,7 @@ def mul(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("mul", other, fill_value) - - def __mul__(self, other): - return self._apply_op("__mul__", other) + return self._binaryop(other, "mul", fill_value) def rmul(self, other, axis="columns", level=None, fill_value=None): """ @@ -1965,10 +1932,7 @@ def rmul(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("rmul", other, fill_value) - - def __rmul__(self, other): - return self._apply_op("__rmul__", other) + return self._binaryop(other, "mul", fill_value, reflect=True) def mod(self, other, axis="columns", level=None, fill_value=None): """ @@ -2021,10 +1985,7 @@ def mod(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("mod", other, fill_value) - - def __mod__(self, other): - return self._apply_op("__mod__", other) + return self._binaryop(other, "mod", fill_value) def rmod(self, other, axis="columns", level=None, fill_value=None): """ @@ -2077,10 +2038,7 @@ def rmod(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("rmod", other, fill_value) - - def __rmod__(self, other): - return self._apply_op("__rmod__", other) + return self._binaryop(other, "mod", fill_value, reflect=True) def pow(self, other, axis="columns", level=None, fill_value=None): """ @@ -2133,10 +2091,7 @@ def pow(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("pow", other, fill_value) - - def __pow__(self, other): - return self._apply_op("__pow__", other) + return self._binaryop(other, "pow", fill_value) def rpow(self, other, axis="columns", level=None, fill_value=None): """ @@ -2189,10 +2144,7 @@ def rpow(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("rpow", other, fill_value) - - def __rpow__(self, other): - return self._apply_op("__rpow__", other) + return self._binaryop(other, "pow", fill_value, reflect=True) def floordiv(self, other, axis="columns", level=None, fill_value=None): """ @@ -2245,10 +2197,7 @@ def floordiv(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("floordiv", other, fill_value) - - def __floordiv__(self, other): - return self._apply_op("__floordiv__", other) + return self._binaryop(other, "floordiv", fill_value) def rfloordiv(self, other, axis="columns", level=None, fill_value=None): """ @@ -2311,10 +2260,7 @@ def rfloordiv(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("rfloordiv", other, fill_value) - - def __rfloordiv__(self, other): - return self._apply_op("__rfloordiv__", other) + return self._binaryop(other, "floordiv", fill_value, reflect=True) def truediv(self, other, axis="columns", level=None, fill_value=None): """ @@ -2372,14 +2318,11 @@ def truediv(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("truediv", other, fill_value) + return self._binaryop(other, "truediv", fill_value) # Alias for truediv div = truediv - def __truediv__(self, other): - return self._apply_op("__truediv__", other) - def rtruediv(self, other, axis="columns", level=None, fill_value=None): """ Get Floating division of dataframe and other, element-wise (binary @@ -2441,52 +2384,11 @@ def rtruediv(self, other, axis="columns", level=None, fill_value=None): if level is not None: raise NotImplementedError("level parameter is not supported yet.") - return self._apply_op("rtruediv", other, fill_value) + return self._binaryop(other, "truediv", fill_value, reflect=True) # Alias for rtruediv rdiv = rtruediv - def __rtruediv__(self, other): - return self._apply_op("__rtruediv__", other) - - __div__ = __truediv__ - - def __and__(self, other): - return self._apply_op("__and__", other) - - def __or__(self, other): - return self._apply_op("__or__", other) - - def __xor__(self, other): - return self._apply_op("__xor__", other) - - def __eq__(self, other): - return self._apply_op("__eq__", other) - - def __ne__(self, other): - return self._apply_op("__ne__", other) - - def __lt__(self, other): - return self._apply_op("__lt__", other) - - def __le__(self, other): - return self._apply_op("__le__", other) - - def __gt__(self, other): - return self._apply_op("__gt__", other) - - def __ge__(self, other): - return self._apply_op("__ge__", other) - - def __invert__(self): - return self._apply_op("__invert__") - - def __neg__(self): - return self._apply_op("__neg__") - - def __abs__(self): - return self._apply_op("__abs__") - def __iter__(self): return iter(self.columns) diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index ec6bf13bd15..5d5be111382 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -6,7 +6,7 @@ import functools import warnings from collections import abc -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, Optional, Tuple, TypeVar, Union import cupy import numpy as np @@ -17,7 +17,7 @@ import cudf from cudf import _lib as libcudf from cudf._typing import ColumnLike, DataFrameOrSeries -from cudf.api.types import is_dict_like, is_dtype_equal +from cudf.api.types import is_dict_like, is_dtype_equal, issubdtype from cudf.core.column import ( ColumnBase, as_column, @@ -25,9 +25,11 @@ column_empty, concat_columns, ) +from cudf.core.column_accessor import ColumnAccessor from cudf.core.join import merge from cudf.utils.dtypes import ( _is_non_decimal_numeric_dtype, + _is_scalar_or_zero_d_array, find_common_type, is_categorical_dtype, is_column_like, @@ -40,9 +42,6 @@ T = TypeVar("T", bound="Frame") -if TYPE_CHECKING: - from cudf.core.columnn_accessor import ColumnAccessor - class Frame(libcudf.table.Table): """ @@ -2053,9 +2052,7 @@ def from_arrow(cls, data): ) # as dictionary size can vary, it can't be a single table cudf_dictionaries_columns = { - name: cudf.core.column.ColumnBase.from_arrow( - dict_dictionaries[name] - ) + name: ColumnBase.from_arrow(dict_dictionaries[name]) for name in dict_dictionaries.keys() } @@ -2283,11 +2280,6 @@ def _postprocess_columns(self, other, include_index=True): self._copy_struct_names(other, include_index=include_index) self._copy_interval_data(other, include_index=include_index) - def _unaryop(self, op): - data_columns = (col.unary_operator(op) for col in self._columns) - data = zip(self._column_names, data_columns) - return self.__class__._from_table(Frame(data, self._index)) - def isnull(self): """ Identify missing values. @@ -3338,18 +3330,237 @@ def _reindex( return self._mimic_inplace(result, inplace=inplace) + def _unaryop(self, op): + data_columns = (col.unary_operator(op) for col in self._columns) + data = zip(self._column_names, data_columns) + return self.__class__._from_table(Frame(data, self._index)) + + def _binaryop( + self, + other: T, + fn: str, + fill_value: Any = None, + reflect: bool = False, + *args, + **kwargs, + ) -> Frame: + raise NotImplementedError + + @classmethod + def _colwise_binop( + cls, + operands: Dict[Optional[str], Tuple[ColumnBase, Any, bool, Any]], + fn: str, + ): + """Implement binary ops between two frame-like objects. + + Binary operations for Frames can be reduced to a sequence of binary + operations between column-like objects. Different types of frames need + to preprocess different inputs, so subclasses should implement binary + operations as a preprocessing step that calls this method. + + Parameters + ---------- + operands : Dict[Optional[str], Tuple[ColumnBase, Any, bool, Any]] + A mapping from column names to a tuple containing left and right + operands as well as a boolean indicating whether or not to reflect + an operation and fill value for nulls. + fn : str + The operation to perform. + + Returns + ------- + Frame + A subclass of Frame constructed from the result of performing the + requested operation on the operands. + """ + + # Now actually perform the binop on the columns in left and right. + output = {} + for ( + col, + (left_column, right_column, reflect, fill_value), + ) in operands.items(): + if right_column is cudf.NA: + right_column = cudf.Scalar( + right_column, dtype=left_column.dtype + ) + elif not isinstance(right_column, ColumnBase): + right_column = left_column.normalize_binop_value(right_column) + + fn_apply = fn + if fn == "truediv": + # Decimals in libcudf don't support truediv, see + # https://github.com/rapidsai/cudf/pull/7435 for explanation. + if is_decimal_dtype(left_column.dtype): + fn_apply = "div" + + # Division with integer types results in a suitable float. + truediv_type = { + np.int8: np.float32, + np.int16: np.float32, + np.int32: np.float32, + np.int64: np.float64, + np.uint8: np.float32, + np.uint16: np.float32, + np.uint32: np.float64, + np.uint64: np.float64, + np.bool_: np.float32, + }.get(left_column.dtype.type) + if truediv_type is not None: + left_column = left_column.astype(truediv_type) + + output_mask = None + if fill_value is not None: + if is_scalar(right_column): + if left_column.nullable: + left_column = left_column.fillna(fill_value) + else: + # If both columns are nullable, pandas semantics dictate + # that nulls that are present in both left_column and + # right_column are not filled. + if left_column.nullable and right_column.nullable: + lmask = as_column(left_column.nullmask) + rmask = as_column(right_column.nullmask) + output_mask = (lmask | rmask).data + left_column = left_column.fillna(fill_value) + right_column = right_column.fillna(fill_value) + elif left_column.nullable: + left_column = left_column.fillna(fill_value) + elif right_column.nullable: + right_column = right_column.fillna(fill_value) + + # For bitwise operations we must verify whether the input column + # types are valid, and if so, whether we need to coerce the output + # columns to booleans. + coerce_to_bool = False + if fn_apply in {"and", "or", "xor"}: + err_msg = ( + f"Operation 'bitwise {fn_apply}' not supported between " + f"{left_column.dtype.type.__name__} and {{}}" + ) + if right_column is None: + raise TypeError(err_msg.format(type(None))) + + try: + left_is_bool = issubdtype(left_column.dtype, np.bool_) + right_is_bool = issubdtype(right_column.dtype, np.bool_) + except TypeError: + raise TypeError(err_msg.format(type(right_column))) + + coerce_to_bool = left_is_bool or right_is_bool + + if not ( + (left_is_bool or issubdtype(left_column.dtype, np.integer)) + and ( + right_is_bool + or issubdtype(right_column.dtype, np.integer) + ) + ): + raise TypeError( + err_msg.format(right_column.dtype.type.__name__) + ) + + outcol = ( + left_column.binary_operator( + fn_apply, right_column, reflect=reflect + ) + if right_column is not None + else column_empty( + left_column.size, left_column.dtype, masked=True + ) + ) + + if output_mask is not None: + outcol = outcol.set_mask(output_mask) + + if coerce_to_bool: + outcol = outcol.astype(np.bool_) + + output[col] = outcol + + return output -_truediv_int_dtype_corrections = { - np.int8: np.float32, - np.int16: np.float32, - np.int32: np.float32, - np.int64: np.float64, - np.uint8: np.float32, - np.uint16: np.float32, - np.uint32: np.float64, - np.uint64: np.float64, - np.bool_: np.float32, -} + # Binary arithmetic operations. + def __add__(self, other): + return self._binaryop(other, "add") + + def __radd__(self, other): + return self._binaryop(other, "add", reflect=True) + + def __sub__(self, other): + return self._binaryop(other, "sub") + + def __rsub__(self, other): + return self._binaryop(other, "sub", reflect=True) + + def __mul__(self, other): + return self._binaryop(other, "mul") + + def __rmul__(self, other): + return self._binaryop(other, "mul", reflect=True) + + def __mod__(self, other): + return self._binaryop(other, "mod") + + def __rmod__(self, other): + return self._binaryop(other, "mod", reflect=True) + + def __pow__(self, other): + return self._binaryop(other, "pow") + + def __rpow__(self, other): + return self._binaryop(other, "pow", reflect=True) + + def __floordiv__(self, other): + return self._binaryop(other, "floordiv") + + def __rfloordiv__(self, other): + return self._binaryop(other, "floordiv", reflect=True) + + def __truediv__(self, other): + return self._binaryop(other, "truediv") + + def __rtruediv__(self, other): + return self._binaryop(other, "truediv", reflect=True) + + def __and__(self, other): + return self._binaryop(other, "and") + + def __or__(self, other): + return self._binaryop(other, "or") + + def __xor__(self, other): + return self._binaryop(other, "xor") + + # Binary rich comparison operations. + def __eq__(self, other): + return self._binaryop(other, "eq") + + def __ne__(self, other): + return self._binaryop(other, "ne") + + def __lt__(self, other): + return self._binaryop(other, "lt") + + def __le__(self, other): + return self._binaryop(other, "le") + + def __gt__(self, other): + return self._binaryop(other, "gt") + + def __ge__(self, other): + return self._binaryop(other, "ge") + + # Unary logical operators + def __neg__(self): + return -1 * self + + def __pos__(self): + return self.copy(deep=True) + + def __abs__(self): + return self._unaryop("abs") class SingleColumnFrame(Frame): @@ -3528,7 +3739,7 @@ def from_arrow(cls, array): 2 dtype: object """ - return cls(cudf.core.column.column.ColumnBase.from_arrow(array)) + return cls(ColumnBase.from_arrow(array)) def to_arrow(self): """ @@ -3645,14 +3856,13 @@ def _copy_construct(self, **kwargs): def _binaryop( self, - other, - fn, - fill_value=None, - reflect=False, - lhs=None, + other: T, + fn: str, + fill_value: Any = None, + reflect: bool = False, *args, **kwargs, - ): + ) -> SingleColumnFrame: """Perform a binary operation between two single column frames. Parameters @@ -3676,37 +3886,6 @@ def _binaryop( SingleColumnFrame A new instance containing the result of the operation. """ - if lhs is None: - lhs = self - - rhs = self._normalize_binop_value(other) - - if fn == "truediv": - truediv_type = _truediv_int_dtype_corrections.get(lhs.dtype.type) - if truediv_type is not None: - lhs = lhs.astype(truediv_type) - - output_mask = None - if fill_value is not None: - if is_scalar(rhs): - if lhs.nullable: - lhs = lhs.fillna(fill_value) - else: - # If both columns are nullable, pandas semantics dictate that - # nulls that are present in both lhs and rhs are not filled. - if lhs.nullable and rhs.nullable: - # Note: lhs is a Frame, while rhs is already a column. - lmask = as_column(lhs._column.nullmask) - rmask = as_column(rhs.nullmask) - output_mask = (lmask | rmask).data - lhs = lhs.fillna(fill_value) - rhs = rhs.fillna(fill_value) - elif lhs.nullable: - lhs = lhs.fillna(fill_value) - elif rhs.nullable: - rhs = rhs.fillna(fill_value) - - outcol = lhs._column.binary_operator(fn, rhs, reflect=reflect) # Get the appropriate name for output operations involving two objects # that are Series-like objects. The output shares the lhs's name unless @@ -3719,140 +3898,24 @@ def _binaryop( else: result_name = self.name - output = lhs._copy_construct(data=outcol, name=result_name) - - if output_mask is not None: - output._column = output._column.set_mask(output_mask) - return output - - def _normalize_binop_value(self, other): - """Returns a *column* (not a Series) or scalar for performing - binary operations with self._column. - """ - if isinstance(other, ColumnBase): - return other + # This needs to be tested correctly if isinstance(other, SingleColumnFrame): - return other._column - if other is cudf.NA: - return cudf.Scalar(other, dtype=self.dtype) - else: - return self._column.normalize_binop_value(other) - - def _bitwise_binop(self, other, op): - """Type-coercing wrapper around _binaryop for bitwise operations.""" - # This will catch attempts at bitwise ops on extension dtypes. - try: - self_is_bool = np.issubdtype(self.dtype, np.bool_) - other_is_bool = np.issubdtype(other.dtype, np.bool_) - except TypeError: - raise TypeError( - f"Operation 'bitwise {op}' not supported between " - f"{self.dtype.type.__name__} and {other.dtype.type.__name__}" - ) - - if (self_is_bool or np.issubdtype(self.dtype, np.integer)) and ( - other_is_bool or np.issubdtype(other.dtype, np.integer) - ): - # TODO: This doesn't work on Series (op) DataFrame - # because dataframe doesn't have dtype - ser = self._binaryop(other, op) - if self_is_bool or other_is_bool: - ser = ser.astype(np.bool_) - return ser - else: - raise TypeError( - f"Operation 'bitwise {op}' not supported between " - f"{self.dtype.type.__name__} and {other.dtype.type.__name__}" - ) - - # Binary arithmetic operations. - def __add__(self, other): - return self._binaryop(other, "add") - - def __radd__(self, other): - return self._binaryop(other, "add", reflect=True) - - def __sub__(self, other): - return self._binaryop(other, "sub") - - def __rsub__(self, other): - return self._binaryop(other, "sub", reflect=True) - - def __mul__(self, other): - return self._binaryop(other, "mul") - - def __rmul__(self, other): - return self._binaryop(other, "mul", reflect=True) - - def __mod__(self, other): - return self._binaryop(other, "mod") - - def __rmod__(self, other): - return self._binaryop(other, "mod", reflect=True) - - def __pow__(self, other): - return self._binaryop(other, "pow") - - def __rpow__(self, other): - return self._binaryop(other, "pow", reflect=True) - - def __floordiv__(self, other): - return self._binaryop(other, "floordiv") - - def __rfloordiv__(self, other): - return self._binaryop(other, "floordiv", reflect=True) - - def __truediv__(self, other): - if is_decimal_dtype(self.dtype): - return self._binaryop(other, "div") - else: - return self._binaryop(other, "truediv") - - def __rtruediv__(self, other): - if is_decimal_dtype(self.dtype): - return self._binaryop(other, "div", reflect=True) - else: - return self._binaryop(other, "truediv", reflect=True) - - __div__ = __truediv__ - - def __and__(self, other): - return self._bitwise_binop(other, "and") - - def __or__(self, other): - return self._bitwise_binop(other, "or") - - def __xor__(self, other): - return self._bitwise_binop(other, "xor") - - # Binary rich comparison operations. - def __eq__(self, other): - return self._binaryop(other, "eq") - - def __ne__(self, other): - return self._binaryop(other, "ne") - - def __lt__(self, other): - return self._binaryop(other, "lt") - - def __le__(self, other): - return self._binaryop(other, "le") - - def __gt__(self, other): - return self._binaryop(other, "gt") - - def __ge__(self, other): - return self._binaryop(other, "ge") - - # Unary logical operators - def __neg__(self): - return -1 * self + other = other._column + elif not _is_scalar_or_zero_d_array(other): + # Non-scalar right operands are valid iff they convert to columns. + try: + other = as_column(other) + except Exception: + return NotImplemented - def __pos__(self): - return self.copy(deep=True) + operands: Dict[Optional[str], Tuple[ColumnBase, Any, bool, Any]] = { + result_name: (self._column, other, reflect, fill_value) + } - def __abs__(self): - return self._unaryop("abs") + return self._copy_construct( + data=type(self)._colwise_binop(operands, fn)[result_name], + name=result_name, + ) def _get_replacement_values_for_columns( @@ -3890,7 +3953,7 @@ def _get_replacement_values_for_columns( to_replace_columns = {col: [to_replace] for col in columns_dtype_map} values_columns = {col: [value] for col in columns_dtype_map} elif cudf.utils.dtypes.is_list_like(to_replace) or isinstance( - to_replace, cudf.core.column.ColumnBase + to_replace, ColumnBase ): if is_scalar(value): to_replace_columns = {col: to_replace for col in columns_dtype_map} diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index a2a6d26bee6..691b6ab2e29 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -724,17 +724,6 @@ def difference(self, other, sort=None): return difference - def _binaryop(self, other, fn, fill_value=None, reflect=False): - # TODO: Rather than including an allowlist of acceptable types, we - # should instead return NotImplemented for __all__ other types. That - # will allow other types to support binops with cudf objects if they so - # choose, and just as importantly will allow better error messages if - # they don't support it. - if isinstance(other, (cudf.DataFrame, cudf.Series)): - return NotImplemented - - return super()._binaryop(other, fn, fill_value, reflect) - def _copy_construct(self, **kwargs): # Need to override the parent behavior because pandas allows operations # on unsigned types to return signed values, forcing us to choose the @@ -761,16 +750,6 @@ def _copy_construct(self, **kwargs): cls = CategoricalIndex elif cls is RangeIndex: # RangeIndex must convert to other numerical types for ops - - # TODO: The one exception to the output type selected here is - # that scalar multiplication of a RangeIndex in pandas results - # in another RangeIndex. Propagating that information through - # cudf with the current internals is possible, but requires - # significant hackery since we'd need _copy_construct or some - # other constructor to be intrinsically capable of processing - # operations. We should fix this behavior once we've completed - # a more thorough refactoring of the various Index classes that - # makes it easier to propagate this logic. try: cls = _dtype_to_index[data.dtype.type] except KeyError: @@ -1786,6 +1765,22 @@ def unique(self): # RangeIndex always has unique values return self + def __mul__(self, other): + # Multiplication by raw ints must return a RangeIndex to match pandas. + if isinstance(other, cudf.Scalar) and other.dtype.kind in "iu": + other = other.value + elif ( + isinstance(other, (np.ndarray, cupy.ndarray)) + and other.ndim == 0 + and other.dtype.kind in "iu" + ): + other = other.item() + if isinstance(other, (int, np.integer)): + return RangeIndex( + self.start * other, self.stop * other, self.step * other + ) + return super().__mul__(other) + class GenericIndex(BaseIndex): """An array of orderable values that represent the indices of another Column diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 68f2b42483b..df3ad56848d 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -37,7 +37,7 @@ from cudf.core.column.string import StringMethods from cudf.core.column.struct import StructMethods from cudf.core.column_accessor import ColumnAccessor -from cudf.core.frame import SingleColumnFrame, _drop_rows_by_labels +from cudf.core.frame import Frame, SingleColumnFrame, _drop_rows_by_labels from cudf.core.groupby.groupby import SeriesGroupBy from cudf.core.index import BaseIndex, Index, RangeIndex, as_index from cudf.core.indexing import _SeriesIlocIndexer, _SeriesLocIndexer @@ -1329,16 +1329,29 @@ def __repr__(self): return "\n".join(lines) def _binaryop( - self, other, fn, fill_value=None, reflect=False, can_reindex=False + self, + other: Frame, + fn: str, + fill_value: Any = None, + reflect: bool = False, + can_reindex: bool = False, + *args, + **kwargs, ): - if isinstance(other, cudf.DataFrame): - return NotImplemented - - if isinstance(other, Series): + if isinstance(other, SingleColumnFrame): if ( + # TODO: The can_reindex logic also needs to be applied for + # DataFrame (the methods that need it just don't exist yet). not can_reindex and fn in cudf.utils.utils._EQUALITY_OPS - and not self.index.equals(other.index) + and ( + isinstance(other, Series) + # TODO: mypy doesn't like this line because the index + # property is not defined on SingleColumnFrame (or Index, + # for that matter). Ignoring is the easy solution for now, + # a cleaner fix requires reworking the type hierarchy. + and not self.index.equals(other.index) # type: ignore + ) ): raise ValueError( "Can only compare identically-labeled " "Series objects" @@ -1347,7 +1360,8 @@ def _binaryop( else: lhs = self - return super()._binaryop(other, fn, fill_value, reflect, lhs) + # Note that we call the super on lhs, not self. + return super(Series, lhs)._binaryop(other, fn, fill_value, reflect) def add(self, other, fill_value=None, axis=0): """ diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 1c97cbb10ff..13019ca11f9 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -7,6 +7,7 @@ import random from itertools import product +import cupy as cp import numpy as np import pandas as pd import pytest @@ -2864,3 +2865,26 @@ def set_null_cases(column_l, column_r, case): ) def test_null_equals_columnops(lcol, rcol, ans, case): assert lcol._null_equals(rcol).all() == ans + + +def test_add_series_to_dataframe(): + """Verify that missing columns result in NaNs, not NULLs.""" + assert cp.all( + cp.isnan( + ( + cudf.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + cudf.Series([1, 2, 3], index=["a", "b", "c"]) + )["c"] + ) + ) + + +@pytest.mark.parametrize("obj_class", [cudf.Series, cudf.Index]) +@pytest.mark.parametrize("binop", _binops) +@pytest.mark.parametrize("other_type", [np.array, cp.array, pd.Series, list]) +def test_binops_non_cudf_types(obj_class, binop, other_type): + # Skip 0 to not deal with NaNs from division. + data = range(1, 100) + lhs = obj_class(data) + rhs = other_type(data) + assert cp.all((binop(lhs, rhs) == binop(lhs, lhs)).values)