diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 6b83f927727..ddf98823192 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -669,6 +669,12 @@ def get_column_values_na(col): matrix[:, i] = get_column_values_na(col) return matrix + # TODO: As of now, calling cupy.asarray is _much_ faster than calling + # to_cupy. We should investigate the reasons why and whether we can provide + # a more efficient method here by exploiting __cuda_array_interface__. In + # particular, we need to benchmark how much of the overhead is coming from + # (potentially unavoidable) local copies in to_cupy and how much comes from + # inefficiencies in the implementation. def to_cupy( self, dtype: Union[Dtype, None] = None, @@ -3622,6 +3628,8 @@ def dot(self, other, reflect=False): >>> [1, 2, 3, 4] @ s 10 """ + # TODO: This function does not currently support nulls. + # TODO: This function does not properly support misaligned indexes. lhs = self.values if isinstance(other, Frame): rhs = other.values @@ -3632,6 +3640,16 @@ def dot(self, other, reflect=False): ): rhs = cupy.asarray(other) else: + # TODO: This should raise an exception, not return NotImplemented, + # but __matmul__ relies on the current behavior. We should either + # move this implementation to __matmul__ and call it from here + # (checking for NotImplemented and raising NotImplementedError if + # that's what's returned), or __matmul__ should catch a + # NotImplementedError from here and return NotImplemented. The + # latter feels cleaner (putting the implementation in this method + # rather than in the operator) but will be slower in the (highly + # unlikely) case that we're multiplying a cudf object with another + # type of object that somehow supports this behavior. return NotImplemented if reflect: lhs, rhs = rhs, lhs diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 5823ea18d1b..247fca4b9ee 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -7,6 +7,7 @@ import pickle import warnings from collections import abc as abc +from itertools import repeat from numbers import Number from shutil import get_terminal_size from typing import Any, MutableMapping, Optional, Set, Union @@ -955,14 +956,123 @@ def to_frame(self, name=None): def memory_usage(self, index=True, deep=False): return sum(super().memory_usage(index, deep).values()) + # For more detail on this function and how it should work, see + # https://numpy.org/doc/stable/reference/ufuncs.html def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - if method == "__call__": - return get_appropriate_dispatched_func( - cudf, cudf.Series, cupy, ufunc, inputs, kwargs - ) - else: + # We don't currently support reduction, accumulation, etc. We also + # don't support any special kwargs or higher arity ufuncs than binary. + if method != "__call__" or kwargs or ufunc.nin > 2: return NotImplemented + # Binary operations + binary_operations = { + # Arithmetic binary operations. + "add": "add", + "subtract": "sub", + "multiply": "mul", + "matmul": "matmul", + "divide": "truediv", + "true_divide": "truediv", + "floor_divide": "floordiv", + "power": "pow", + "float_power": "pow", + "remainder": "mod", + "mod": "mod", + "fmod": "mod", + # Bitwise binary operations. + "bitwise_and": "and", + "bitwise_or": "or", + "bitwise_xor": "xor", + # Comparison binary operators + "greater": "gt", + "greater_equal": "ge", + "less": "lt", + "less_equal": "le", + "not_equal": "ne", + "equal": "eq", + } + + # First look for methods of the class. + fname = ufunc.__name__ + if fname in binary_operations: + not_reflect = self is inputs[0] + other = inputs[not_reflect] + op = f"__{'' if not_reflect else 'r'}{binary_operations[fname]}__" + + # pandas bitwise operations return bools if indexes are misaligned. + # TODO: Generalize for other types of Frames + if ( + "bitwise" in fname + and isinstance(other, Series) + and not self.index.equals(other.index) + ): + return getattr(self, op)(other).astype(bool) + # Float_power returns float irrespective of the input type. + if fname == "float_power": + return getattr(self, op)(other).astype(float) + return getattr(self, op)(other) + + # Special handling for unary operations. + if fname == "negative": + return self * -1 + if fname == "positive": + return self.copy(deep=True) + if fname == "invert": + return ~self + if fname == "absolute": + return self.abs() + if fname == "fabs": + return self.abs().astype(np.float64) + + # Note: There are some operations that may be supported by libcudf but + # are not supported by pandas APIs. In particular, libcudf binary + # operations support logical and/or operations, but those operations + # are not defined on pd.Series/DataFrame. For now those operations will + # dispatch to cupy, but if ufuncs are ever a bottleneck we could add + # special handling to dispatch those (or any other) functions that we + # could implement without cupy. + + # Attempt to dispatch all other functions to cupy. + cupy_func = getattr(cupy, fname) + if cupy_func: + # Indices must be aligned before converting to arrays. + if ufunc.nin == 2 and all(map(isinstance, inputs, repeat(Series))): + inputs = _align_indices(inputs, allow_non_unique=True) + index = inputs[0].index + else: + index = self.index + + cupy_inputs = [] + mask = None + for inp in inputs: + # TODO: Generalize for other types of Frames + if isinstance(inp, Series) and inp.has_nulls: + new_mask = as_column(inp.nullmask) + + # TODO: This is a hackish way to perform a bitwise and of + # bitmasks. Once we expose cudf::detail::bitwise_and, then + # we can use that instead. + mask = new_mask if mask is None else (mask & new_mask) + + # Arbitrarily fill with zeros. For ufuncs, we assume that + # the end result propagates nulls via a bitwise and, so + # these elements are irrelevant. + inp = inp.fillna(0) + cupy_inputs.append(cupy.asarray(inp)) + + cp_output = cupy_func(*cupy_inputs, **kwargs) + + def make_frame(arr): + return self.__class__._from_data( + {self.name: as_column(arr).set_mask(mask)}, index=index + ) + + if ufunc.nout > 1: + return tuple(make_frame(out) for out in cp_output) + return make_frame(cp_output) + + return NotImplemented + def __array_function__(self, func, types, args, kwargs): handled_types = [cudf.Series] for t in types: @@ -1254,15 +1364,31 @@ def _binaryop( ) def logical_and(self, other): + warnings.warn( + "Series.logical_and is deprecated and will be removed.", + FutureWarning, + ) return self._binaryop(other, "l_and").astype(np.bool_) def remainder(self, other): + warnings.warn( + "Series.remainder is deprecated and will be removed.", + FutureWarning, + ) return self._binaryop(other, "mod") def logical_or(self, other): + warnings.warn( + "Series.logical_or is deprecated and will be removed.", + FutureWarning, + ) return self._binaryop(other, "l_or").astype(np.bool_) def logical_not(self): + warnings.warn( + "Series.logical_not is deprecated and will be removed.", + FutureWarning, + ) return self._unaryop("not") @copy_docstring(CategoricalAccessor) # type: ignore diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py index 6c602d321eb..b97b2d660d6 100644 --- a/python/cudf/cudf/testing/_utils.py +++ b/python/cudf/cudf/testing/_utils.py @@ -46,7 +46,7 @@ def set_random_null_mask_inplace(series, null_probability=0.5, seed=None): probs = [null_probability, 1 - null_probability] rng = np.random.default_rng(seed=seed) mask = rng.choice([False, True], size=len(series), p=probs) - series[mask] = None + series.iloc[mask] = None # TODO: This function should be removed. Anywhere that it is being used should diff --git a/python/cudf/cudf/tests/test_array_ufunc.py b/python/cudf/cudf/tests/test_array_ufunc.py index 3fe0321ec54..2db0e8d29b8 100644 --- a/python/cudf/cudf/tests/test_array_ufunc.py +++ b/python/cudf/cudf/tests/test_array_ufunc.py @@ -1,175 +1,153 @@ +import operator +from functools import reduce + import cupy as cp import numpy as np -import pandas as pd import pytest import cudf -from cudf.testing._utils import assert_eq - - -@pytest.fixture -def np_ar_tup(): - np.random.seed(0) - return (np.random.random(100), np.random.random(100)) +from cudf.testing._utils import assert_eq, set_random_null_mask_inplace - -comparison_ops_ls = [ - np.greater, - np.greater_equal, - np.less, - np.less_equal, - np.equal, - np.not_equal, +_UFUNCS = [ + obj + for obj in (getattr(np, name) for name in dir(np)) + if isinstance(obj, np.ufunc) ] -@pytest.mark.parametrize( - "func", comparison_ops_ls + [np.subtract, np.fmod, np.power] -) -def test_ufunc_cudf_non_nullseries(np_ar_tup, func): - x, y = np_ar_tup[0], np_ar_tup[1] - s_1, s_2 = cudf.Series(x), cudf.Series(y) - expect = func(x, y) - got = func(s_1, s_2) - assert_eq(expect, got.to_numpy()) - - -@pytest.mark.parametrize( - "func", [np.bitwise_and, np.bitwise_or, np.bitwise_xor], -) -def test_ufunc_cudf_series_bitwise(func): - np.random.seed(0) - x = np.random.randint(size=100, low=0, high=100) - y = np.random.randint(size=100, low=0, high=100) - - s_1, s_2 = cudf.Series(x), cudf.Series(y) - expect = func(x, y) - got = func(s_1, s_2) - assert_eq(expect, got.to_numpy()) - - -@pytest.mark.parametrize( - "func", - [ - np.subtract, - np.multiply, - np.floor_divide, - np.true_divide, - np.power, - np.remainder, - np.divide, - ], -) -def test_ufunc_cudf_null_series(np_ar_tup, func): - x, y = np_ar_tup[0].astype(np.float32), np_ar_tup[1].astype(np.float32) - x[0] = np.nan - y[1] = np.nan - s_1, s_2 = cudf.Series(x), cudf.Series(y) - expect = func(x, y) - got = func(s_1, s_2) - assert_eq(expect, got.fillna(np.nan).to_numpy()) - - scalar = 0.5 - expect = func(x, scalar) - got = func(s_1, scalar) - assert_eq(expect, got.fillna(np.nan).to_numpy()) - - expect = func(scalar, x) - got = func(scalar, s_1) - assert_eq(expect, got.fillna(np.nan).to_numpy()) - - -@pytest.mark.xfail( - reason="""cuDF comparison operations with incorrectly - returns False rather than """ -) -@pytest.mark.parametrize( - "func", comparison_ops_ls, -) -def test_ufunc_cudf_null_series_comparison_ops(np_ar_tup, func): - x, y = np_ar_tup[0].astype(np.float32), np_ar_tup[1].astype(np.float32) - x[0] = np.nan - y[1] = np.nan - s_1, s_2 = cudf.Series(x), cudf.Series(y) - expect = func(x, y) - got = func(s_1, s_2) - assert_eq(expect, got.fillna(np.nan).to_numpy()) - - scalar = 0.5 - expect = func(x, scalar) - got = func(s_1, scalar) - assert_eq(expect, got.fillna(np.nan).to_numpy()) - - expect = func(scalar, x) - got = func(scalar, s_1) - assert_eq(expect, got.fillna(np.nan).to_numpy()) - - -@pytest.mark.parametrize( - "func", [np.logaddexp, np.fmax, np.fmod], -) -def test_ufunc_cudf_series_cupy_array(np_ar_tup, func): - x, y = np_ar_tup[0], np_ar_tup[1] - expect = func(x, y) - - cudf_s = cudf.Series(x) - cupy_ar = cp.array(y) - got = func(cudf_s, cupy_ar) - assert_eq(expect, got.to_numpy()) - - -@pytest.mark.parametrize( - "func", - [np.fmod, np.logaddexp, np.bitwise_and, np.bitwise_or, np.bitwise_xor], -) -def test_error_with_null_cudf_series(func): - s_1 = cudf.Series([1, 2]) - s_2 = cudf.Series([1, None]) - - # this thows a value error - # because of nulls in cudf.Series - with pytest.raises(ValueError): - func(s_1, s_2) - - s_1 = cudf.Series([1, 2]) - s_2 = cudf.Series([1, 2, None]) - - # this throws a value-error if indexes are not aligned - # following pandas behavior for ufunc numpy dispatching - with pytest.raises( - ValueError, match="Can only compare identically-labeled Series objects" - ): - func(s_1, s_2) - - -@pytest.mark.parametrize( - "func", [np.absolute, np.sign, np.exp2, np.tanh], -) -def test_ufunc_cudf_series_with_index(func): - data = [-1, 2, 3, 0] - index = [2, 3, 1, 0] - cudf_s = cudf.Series(data=data, index=index) - pd_s = pd.Series(data=data, index=index) - - expect = func(pd_s) - got = func(cudf_s) - - assert_eq(got, expect) - - -@pytest.mark.parametrize( - "func", [np.logaddexp2], -) -def test_ufunc_cudf_series_with_nonaligned_index(func): - cudf_s1 = cudf.Series(data=[-1, 2, 3, 0], index=[2, 3, 1, 0]) - cudf_s2 = cudf.Series(data=[-1, 2, 3, 0], index=[3, 1, 0, 2]) - - # this throws a value-error if indexes are not aligned - # following pandas behavior for ufunc numpy dispatching - with pytest.raises( - ValueError, match="Can only compare identically-labeled Series objects" +@pytest.mark.parametrize("ufunc", _UFUNCS) +@pytest.mark.parametrize("has_nulls", [True, False]) +@pytest.mark.parametrize("indexed", [True, False]) +def test_ufunc_series(ufunc, has_nulls, indexed): + # Note: This test assumes that all ufuncs are unary or binary. + fname = ufunc.__name__ + if indexed and fname in ( + "greater", + "greater_equal", + "less", + "less_equal", + "not_equal", + "equal", ): - func(cudf_s1, cudf_s2) + pytest.skip("Comparison operators do not support misaligned indexes.") + + if (indexed or has_nulls) and fname == "matmul": + pytest.xfail("Frame.dot currently does not support indexes or nulls") + + N = 100 + # Avoid zeros in either array to skip division by 0 errors. Also limit the + # scale to avoid issues with overflow, etc. We use ints because some + # operations (like bitwise ops) are not defined for floats. + pandas_args = args = [ + cudf.Series( + cp.random.randint(low=1, high=10, size=N), + index=cp.random.choice(range(N), N, False) if indexed else None, + ) + for _ in range(ufunc.nin) + ] + + if has_nulls: + # Converting nullable integer cudf.Series to pandas will produce a + # float pd.Series, so instead we replace nulls with an arbitrary + # integer value, precompute the mask, and then reapply it afterwards. + for arg in args: + set_random_null_mask_inplace(arg) + pandas_args = [arg.fillna(0) for arg in args] + + # Note: Different indexes must be aligned before the mask is computed. + # This requires using an internal function (_align_indices), and that + # is unlikely to change for the foreseeable future. + aligned = ( + cudf.core.series._align_indices(args, allow_non_unique=True) + if indexed and ufunc.nin == 2 + else args + ) + mask = reduce(operator.or_, (a.isna() for a in aligned)).to_pandas() + + try: + got = ufunc(*args) + except AttributeError as e: + # We xfail if we don't have an explicit dispatch and cupy doesn't have + # the method so that we can easily identify these methods. As of this + # writing, the only missing methods are isnat and heaviside. + if "module 'cupy' has no attribute" in str(e): + pytest.xfail(reason="Operation not supported by cupy") + raise + + expect = ufunc(*(arg.to_pandas() for arg in pandas_args)) + + try: + if ufunc.nout > 1: + for g, e in zip(got, expect): + if has_nulls: + e[mask] = np.nan + assert_eq(g, e) + else: + if has_nulls: + expect[mask] = np.nan + assert_eq(got, expect) + except AssertionError: + # TODO: This branch can be removed when + # https://github.com/rapidsai/cudf/issues/10178 is resolved + if fname in ("power", "float_power"): + not_equal = cudf.from_pandas(expect) != got + not_equal[got.isna()] = False + diffs = got[not_equal] - expect[not_equal.to_pandas()] + if diffs.abs().max() == 1: + pytest.xfail("https://github.com/rapidsai/cudf/issues/10178") + raise + + +@pytest.mark.parametrize("ufunc", [np.add, np.greater, np.logical_and]) +@pytest.mark.parametrize("has_nulls", [True, False]) +@pytest.mark.parametrize("indexed", [True, False]) +@pytest.mark.parametrize("type_", ["cupy", "numpy", "list"]) +def test_binary_ufunc_series_array(ufunc, has_nulls, indexed, type_): + fname = ufunc.__name__ + if fname == "greater" and has_nulls: + pytest.xfail( + "The way cudf casts nans in arrays to nulls during binops with " + "cudf objects is currently incompatible with pandas." + ) + N = 100 + # Avoid zeros in either array to skip division by 0 errors. Also limit the + # scale to avoid issues with overflow, etc. We use ints because some + # operations (like bitwise ops) are not defined for floats. + args = [ + cudf.Series( + cp.random.rand(N), + index=cp.random.choice(range(N), N, False) if indexed else None, + ) + for _ in range(ufunc.nin) + ] + + if has_nulls: + # Converting nullable integer cudf.Series to pandas will produce a + # float pd.Series, so instead we replace nulls with an arbitrary + # integer value, precompute the mask, and then reapply it afterwards. + for arg in args: + set_random_null_mask_inplace(arg) + + # Cupy doesn't support nulls, so we fill with nans before converting. + args[1] = args[1].fillna(cp.nan) + mask = args[0].isna().to_pandas() + + arg1 = args[1].to_cupy() if type_ == "cupy" else args[1].to_numpy() + if type_ == "list": + arg1 = arg1.tolist() + + got = ufunc(args[0], arg1) + expect = ufunc(args[0].to_pandas(), args[1].to_numpy()) + + if ufunc.nout > 1: + for g, e in zip(got, expect): + if has_nulls: + e[mask] = np.nan + assert_eq(g, e) + else: + if has_nulls: + expect[mask] = np.nan + assert_eq(got, expect) @pytest.mark.parametrize( diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index 65a803d6768..8571d9ffed5 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -361,7 +361,10 @@ def get_appropriate_dispatched_func( cupy_compatible_args, index = _get_cupy_compatible_args_index(args) if cupy_compatible_args: cupy_output = cupy_func(*cupy_compatible_args, **kwargs) - return _cast_to_appropriate_cudf_type(cupy_output, index) + if isinstance(cupy_output, cp.ndarray): + return _cast_to_appropriate_cudf_type(cupy_output, index) + else: + return cupy_output return NotImplemented