Skip to content

Commit

Permalink
Refactor array_ufunc for Index and unify across all classes (#10346)
Browse files Browse the repository at this point in the history
This PR builds on #10217 and #10287 to bring full ufunc support to Index types, expanding well beyond the small set previously supported in the `cudf.core.ops` namespace. By reusing most of the machinery introduced for IndexedFrame in the prior two PRs, we avoid duplicating much of the logic, so that all ufunc dispatches flow through a relatively standard path of known methods prior to a common cupy dispatch. With this change we are also able to deprecate the various ufunc operations defined in cudf/core/ops.py that exist only for this purpose, as well as a number of Frame methods that are not defined for the corresponding pandas types. Users of those APIs are encouraged to call the corresponding numpy/cupy ufuncs instead to leverage the new dispatch.

This PR also fixes a bug where index binary operations that output booleans would previously return instances of GenericIndex, whereas those pandas operations would return numpy arrays. cudf now returns cupy arrays in those cases.

Resolves #9083. Contributes to #9038.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #10346
  • Loading branch information
vyasr authored Feb 25, 2022
1 parent 3f175ce commit e0af727
Show file tree
Hide file tree
Showing 12 changed files with 498 additions and 234 deletions.
8 changes: 0 additions & 8 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@ class BaseIndex(Serializable):
_accessors: Set[Any] = set()
_data: ColumnAccessor

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

if method == "__call__" and hasattr(cudf, ufunc.__name__):
func = getattr(cudf, ufunc.__name__)
return func(*inputs)
else:
return NotImplemented

@cached_property
def _values(self) -> ColumnBase:
raise NotImplementedError
Expand Down
55 changes: 23 additions & 32 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@
import warnings
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, MutableMapping, Optional, Set, TypeVar
from typing import (
Any,
Dict,
MutableMapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)

import cupy
import numpy as np
Expand Down Expand Up @@ -44,6 +54,7 @@
from cudf.core.abc import Serializable
from cudf.core.column import (
CategoricalColumn,
ColumnBase,
as_column,
build_categorical_column,
build_column,
Expand Down Expand Up @@ -1909,7 +1920,7 @@ def _get_columns_by_label(self, labels, downcast=False):
)
return out

def _prep_for_binop(
def _make_operands_and_index_for_binop(
self,
other: Any,
fn: str,
Expand All @@ -1918,7 +1929,13 @@ def _prep_for_binop(
can_reindex: bool = False,
*args,
**kwargs,
):
) -> Tuple[
Union[
Dict[Optional[str], Tuple[ColumnBase, Any, bool, Any]],
Type[NotImplemented],
],
Optional[BaseIndex],
]:
lhs, rhs = self, other

if _is_scalar_or_zero_d_array(rhs):
Expand Down Expand Up @@ -1999,28 +2016,6 @@ def _prep_for_binop(

return operands, lhs._index

@annotate("DATAFRAME_BINARYOP", color="blue", domain="cudf_python")
def _binaryop(
self,
other: Any,
fn: str,
fill_value: Any = None,
reflect: bool = False,
can_reindex: bool = False,
*args,
**kwargs,
):
operands, out_index = self._prep_for_binop(
other, fn, fill_value, reflect, can_reindex
)
if operands is NotImplemented:
return NotImplemented

return self._from_data(
ColumnAccessor(type(self)._colwise_binop(operands, fn)),
index=out_index,
)

@annotate("DATAFRAME_UPDATE", color="blue", domain="cudf_python")
def update(
self,
Expand Down Expand Up @@ -2183,9 +2178,7 @@ def columns(self, columns):
columns = pd.Index(range(len(self._data.columns)))
is_multiindex = isinstance(columns, pd.MultiIndex)

if isinstance(
columns, (Series, cudf.Index, cudf.core.column.ColumnBase)
):
if isinstance(columns, (Series, cudf.Index, ColumnBase)):
columns = pd.Index(columns.to_numpy(), tupleize_cols=is_multiindex)
elif not isinstance(columns, pd.Index):
columns = pd.Index(columns, tupleize_cols=is_multiindex)
Expand Down Expand Up @@ -6626,7 +6619,7 @@ def _setitem_with_dataframe(
input_df: DataFrame,
replace_df: DataFrame,
input_cols: Any = None,
mask: Optional[cudf.core.column.ColumnBase] = None,
mask: Optional[ColumnBase] = None,
ignore_index: bool = False,
):
"""
Expand Down Expand Up @@ -6717,9 +6710,7 @@ def _get_union_of_series_names(series_list):


def _get_host_unique(array):
if isinstance(
array, (cudf.Series, cudf.Index, cudf.core.column.ColumnBase)
):
if isinstance(array, (cudf.Series, cudf.Index, ColumnBase)):
return array.unique.to_pandas()
elif isinstance(array, (str, numbers.Number)):
return [array]
Expand Down
170 changes: 169 additions & 1 deletion python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,49 @@
T = TypeVar("T", bound="Frame")


# Mapping from ufuncs to the corresponding binary operators.
# Keys are ufunc names (``ufunc.__name__``); values are the stem of the
# dunder method that __array_ufunc__ calls, e.g. "sub" becomes ``__sub__``
# (or ``__rsub__`` for a reflected call).
_ufunc_binary_operations = {
    # Arithmetic binary operations.
    "add": "add",
    "subtract": "sub",
    "multiply": "mul",
    "matmul": "matmul",
    "divide": "truediv",
    "true_divide": "truediv",
    "floor_divide": "floordiv",
    "power": "pow",
    "float_power": "pow",
    "remainder": "mod",
    "mod": "mod",
    # NOTE: "fmod" is deliberately absent. numpy.fmod takes the sign of the
    # dividend, while numpy.mod/remainder (mapped to ``__mod__`` above) take
    # the sign of the divisor, so sending fmod through ``__mod__`` cannot be
    # right for operands of mixed sign. Without an entry here
    # __array_ufunc__ falls through to its generic (sentinel) dispatch.
    # Bitwise binary operations.
    "bitwise_and": "and",
    "bitwise_or": "or",
    "bitwise_xor": "xor",
    # Comparison binary operators
    "greater": "gt",
    "greater_equal": "ge",
    "less": "lt",
    "less_equal": "le",
    "not_equal": "ne",
    "equal": "eq",
}

# Comparison operators have no reflected form of their own, so a reflected
# comparison ufunc call is rewritten as the mirrored comparison listed here.
# This is only needed inside __array_ufunc__: when the operators are invoked
# directly, Python takes care of calling the mirrored operation itself.
_ops_without_reflection = dict(
    gt="lt",
    ge="le",
    lt="gt",
    le="ge",
    # ne and eq are symmetric, so each one is its own mirror image.
    ne="ne",
    eq="eq",
)


class Frame:
"""A collection of Column objects with an optional index.
Expand Down Expand Up @@ -2752,6 +2795,11 @@ def sin(self):
0.8011526357338306, 0.8939966636005579],
dtype='float64')
"""
warnings.warn(
"sin is deprecated and will be removed. Use numpy.sin instead",
FutureWarning,
)

return self._unaryop("sin")

@annotate("FRAME_COS", color="green", domain="cudf_python")
Expand Down Expand Up @@ -2814,6 +2862,11 @@ def cos(self):
-0.5984600690578581, -0.4480736161291701],
dtype='float64')
"""
warnings.warn(
"cos is deprecated and will be removed. Use numpy.cos instead",
FutureWarning,
)

return self._unaryop("cos")

@annotate("FRAME_TAN", color="green", domain="cudf_python")
Expand Down Expand Up @@ -2876,6 +2929,11 @@ def tan(self):
-1.3386902103511544, -1.995200412208242],
dtype='float64')
"""
warnings.warn(
"tan is deprecated and will be removed. Use numpy.tan instead",
FutureWarning,
)

return self._unaryop("tan")

@annotate("FRAME_ASIN", color="green", domain="cudf_python")
Expand Down Expand Up @@ -2927,6 +2985,11 @@ def asin(self):
1.5707963267948966, 0.3046926540153975],
dtype='float64')
"""
warnings.warn(
"asin is deprecated and will be removed in the future",
FutureWarning,
)

return self._unaryop("asin")

@annotate("FRAME_ACOS", color="green", domain="cudf_python")
Expand Down Expand Up @@ -2978,6 +3041,11 @@ def acos(self):
1.5707963267948966, 1.266103672779499],
dtype='float64')
"""
warnings.warn(
"acos is deprecated and will be removed. Use numpy.acos instead",
FutureWarning,
)

result = self.copy(deep=False)
for col in result._data:
min_float_dtype = cudf.utils.dtypes.get_min_float_dtype(
Expand Down Expand Up @@ -3047,6 +3115,11 @@ def atan(self):
0.2914567944778671],
dtype='float64')
"""
warnings.warn(
"atan is deprecated and will be removed. Use numpy.atan instead",
FutureWarning,
)

return self._unaryop("atan")

@annotate("FRAME_EXP", color="green", domain="cudf_python")
Expand Down Expand Up @@ -3110,6 +3183,11 @@ def exp(self):
2.718281828459045, 1.0, 1.3498588075760032],
dtype='float64')
"""
warnings.warn(
"exp is deprecated and will be removed. Use numpy.exp instead",
FutureWarning,
)

return self._unaryop("exp")

@annotate("FRAME_LOG", color="green", domain="cudf_python")
Expand Down Expand Up @@ -3172,6 +3250,11 @@ def log(self):
Float64Index([2.302585092994046, 2.3978952727983707,
6.214608098422191], dtype='float64')
"""
warnings.warn(
"log is deprecated and will be removed. Use numpy.log instead",
FutureWarning,
)

return self._unaryop("log")

@annotate("FRAME_SQRT", color="green", domain="cudf_python")
Expand Down Expand Up @@ -3228,6 +3311,11 @@ def sqrt(self):
>>> index.sqrt()
Float64Index([nan, 10.0, 25.0], dtype='float64')
"""
warnings.warn(
"sqrt is deprecated and will be removed. Use numpy.sqrt instead",
FutureWarning,
)

return self._unaryop("sqrt")

@annotate("FRAME_ABS", color="green", domain="cudf_python")
Expand Down Expand Up @@ -3496,7 +3584,9 @@ def _binaryop(
Frame
A new instance containing the result of the operation.
"""
raise NotImplementedError
raise NotImplementedError(
f"Binary operations are not supported for {self.__class__}"
)

@classmethod
@annotate("FRAME_COLWISE_BINOP", color="green", domain="cudf_python")
Expand Down Expand Up @@ -3658,6 +3748,84 @@ def _colwise_binop(

return output

# For more detail on this function and how it should work, see
# https://numpy.org/doc/stable/reference/ufuncs.html
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    """Dispatch a ufunc call to the matching operator on this object.

    Parameters
    ----------
    ufunc
        The ufunc being applied; its ``__name__`` and ``nin`` are read.
    method : str
        The ufunc method being invoked. Only ``"__call__"`` is handled.
    *inputs
        The positional operands of the ufunc call.
    **kwargs
        Extra ufunc keyword arguments. Any keyword makes the call
        unsupported.

    Returns
    -------
    object
        ``NotImplemented`` for unsupported calls, the result of the
        corresponding operator/method when the ufunc is handled here, or
        ``None`` as a sentinel telling subclasses to try cupy dispatch.
    """
    # We don't currently support reduction, accumulation, etc. We also
    # don't support any special kwargs or higher arity ufuncs than binary.
    if method != "__call__" or kwargs or ufunc.nin > 2:
        return NotImplemented

    fname = ufunc.__name__
    if fname in _ufunc_binary_operations:
        # If self is not the first operand the call is reflected and the
        # other operand is the first input.
        reflect = self is not inputs[0]
        other = inputs[0] if reflect else inputs[1]

        op = _ufunc_binary_operations[fname]
        if reflect and op in _ops_without_reflection:
            # Comparisons have no reflected dunder; use the mirrored
            # comparison in its normal orientation instead.
            op = _ops_without_reflection[op]
            reflect = False
        op = f"__{'r' if reflect else ''}{op}__"

        result = getattr(self, op)(other)
        # Float_power returns float irrespective of the input type. The
        # cast is skipped when the operator declined the operand so that
        # NotImplemented propagates instead of raising AttributeError.
        if fname == "float_power" and result is not NotImplemented:
            return result.astype(float)
        return result

    # Special handling for various unary operations.
    if fname == "negative":
        return self * -1
    if fname == "positive":
        return self.copy(deep=True)
    if fname == "invert":
        return ~self
    if fname == "absolute":
        return self.abs()
    if fname == "fabs":
        return self.abs().astype(np.float64)

    # None is a sentinel used by subclasses to trigger cupy dispatch.
    return None

def _apply_cupy_ufunc_to_operands(
    self, ufunc, cupy_func, operands, **kwargs
):
    """Apply a cupy function column by column and wrap the results.

    Parameters
    ----------
    ufunc
        The ufunc being dispatched; only ``nin`` and ``nout`` are read.
    cupy_func
        Callable invoked with the cupy-converted inputs of each column.
    operands
        Mapping of column name to a 4-tuple. The first two items are the
        left and right inputs; the last two items are ignored here.
    **kwargs
        Forwarded unchanged to ``cupy_func``.

    Returns
    -------
    list of dict
        One ``{name: column}`` mapping per ufunc output, i.e.
        ``ufunc.nout`` mappings in total.
    """
    # Note: There are some operations that may be supported by libcudf but
    # are not supported by pandas APIs. In particular, libcudf binary
    # operations support logical and/or operations as well as
    # trigonometric, but those operations are not defined on
    # pd.Series/DataFrame. For now those operations will dispatch to cupy,
    # but if ufuncs are ever a bottleneck we could add special handling to
    # dispatch those (or any other) functions that we could implement
    # without cupy.

    data = [{} for _ in range(ufunc.nout)]
    for name, (left, right, _, _) in operands.items():
        # The mask is rebuilt for every column: an output column must only
        # be null where its own inputs are null, so a mask computed for one
        # column must not carry over to the next.
        mask = None
        cupy_inputs = []
        for inp in (left, right) if ufunc.nin == 2 else (left,):
            if isinstance(inp, ColumnBase) and inp.has_nulls():
                new_mask = as_column(inp.nullmask)

                # TODO: This is a hackish way to perform a bitwise and
                # of bitmasks. Once we expose
                # cudf::detail::bitwise_and, then we can use that
                # instead.
                mask = new_mask if mask is None else (mask & new_mask)

                # Arbitrarily fill with zeros. For ufuncs, we assume
                # that the end result propagates nulls via a bitwise
                # and, so these elements are irrelevant.
                inp = inp.fillna(0)
            cupy_inputs.append(cupy.asarray(inp))

        cp_output = cupy_func(*cupy_inputs, **kwargs)
        # Normalise to a tuple so single- and multi-output ufuncs share
        # the same handling below.
        if ufunc.nout == 1:
            cp_output = (cp_output,)
        for i, out in enumerate(cp_output):
            data[i][name] = as_column(out).set_mask(mask)
    return data

@annotate("FRAME_DOT", color="green", domain="cudf_python")
def dot(self, other, reflect=False):
"""
Expand Down
Loading

0 comments on commit e0af727

Please sign in to comment.