Skip to content

Commit

Permalink
Refactor binary ops for timedelta and datetime columns (#10581)
Browse files Browse the repository at this point in the history
This PR simplifies the handling of binary operations for datetime and timedelta columns. It reduces the number of nearly identical helper functions and consolidates logic for datetime-timedelta interop into the DatetimeColumn since timedeltas don't need to know how to work with datetimes. These changes also significantly reduce the number of redundant checks for the type of the other operand. The raised errors are no longer as highly customized as they used to be, but the type of exception is still the same which is the level of pandas compatibility that we want to provide, and the changes let us take advantage of reflection which is a major advantage.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #10581
  • Loading branch information
vyasr authored Apr 5, 2022
1 parent 090f6b8 commit 0aef0c1
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 269 deletions.
94 changes: 50 additions & 44 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
DtypeObj,
ScalarLike,
)
from cudf.api.types import is_scalar
from cudf.api.types import is_datetime64_dtype, is_scalar, is_timedelta64_dtype
from cudf.core._compat import PANDAS_GE_120
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, as_column, column, string
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.utils.utils import _fillna_natwise

if PANDAS_GE_120:
Expand All @@ -33,16 +34,6 @@
_guess_datetime_format = pd.core.tools.datetimes._guess_datetime_format

# nanoseconds per time_unit
_numpy_to_pandas_conversion = {
"ns": 1,
"us": 1000,
"ms": 1000000,
"s": 1000000000,
"m": 60000000000,
"h": 3600000000000,
"D": 86400000000000,
}

_dtype_to_format_conversion = {
"datetime64[ns]": "%Y-%m-%d %H:%M:%S.%9f",
"datetime64[us]": "%Y-%m-%d %H:%M:%S.%6f",
Expand Down Expand Up @@ -378,7 +369,7 @@ def std(
self.as_numerical.std(
skipna=skipna, min_count=min_count, dtype=dtype, ddof=ddof
)
* _numpy_to_pandas_conversion[self.time_unit],
* _unit_to_nanoseconds_conversion[self.time_unit],
)

def median(self, skipna: bool = None) -> pd.Timestamp:
Expand Down Expand Up @@ -411,45 +402,49 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
if isinstance(other, cudf.DateOffset):
return other._datetime_binop(self, op, reflect=reflect)

# TODO: Figure out if I can reflect before we start these checks. That
# requires figuring out why _timedelta_add_result_dtype and
# _timedelta_sub_result_dtype are 1) not symmetric, and 2) different
# from each other.
# We check this on `other` before reflection since we already know the
# dtype of `self`.
other_is_timedelta = is_timedelta64_dtype(other.dtype)
other_is_datetime64 = not other_is_timedelta and is_datetime64_dtype(
other.dtype
)
lhs, rhs = (other, self) if reflect else (self, other)
out_dtype = None
if op in {
"__eq__",
"__ne__",
"__lt__",
"__gt__",
"__le__",
"__ge__",
"NULL_EQUALS",
}:
out_dtype: Dtype = cudf.dtype(np.bool_)
elif op == "__add__" and pd.api.types.is_timedelta64_dtype(
other.dtype
out_dtype = cudf.dtype(np.bool_)
elif (
op
in {
"__ne__",
"__lt__",
"__gt__",
"__le__",
"__ge__",
}
and other_is_datetime64
):
out_dtype = cudf.core.column.timedelta._timedelta_add_result_dtype(
other, self
)
elif op == "__sub__" and pd.api.types.is_timedelta64_dtype(
other.dtype
):
out_dtype = cudf.core.column.timedelta._timedelta_sub_result_dtype(
other if reflect else self, self if reflect else other
)
elif op == "__sub__" and pd.api.types.is_datetime64_dtype(other.dtype):
units = ["s", "ms", "us", "ns"]
lhs_time_unit = cudf.utils.dtypes.get_time_unit(self)
lhs_unit = units.index(lhs_time_unit)
rhs_time_unit = cudf.utils.dtypes.get_time_unit(other)
rhs_unit = units.index(rhs_time_unit)
out_dtype = np.dtype(
f"timedelta64[{units[max(lhs_unit, rhs_unit)]}]"
)
else:
out_dtype = cudf.dtype(np.bool_)
elif op == "__add__" and other_is_timedelta:
# The only thing we can add to a datetime is a timedelta. This
# operation is symmetric, i.e. we allow `datetime + timedelta` or
# `timedelta + datetime`. Both result in DatetimeColumns.
out_dtype = _resolve_mixed_dtypes(lhs, rhs, "datetime64")
elif op == "__sub__":
# Subtracting a datetime from a datetime results in a timedelta.
if other_is_datetime64:
out_dtype = _resolve_mixed_dtypes(lhs, rhs, "timedelta64")
# We can subtract a timedelta from a datetime, but not vice versa.
# Not only is subtraction antisymmetric (as is normal), it is only
# well-defined if this operation was not invoked via reflection.
elif other_is_timedelta and not reflect:
out_dtype = _resolve_mixed_dtypes(lhs, rhs, "datetime64")

if out_dtype is None:
return NotImplemented

lhs, rhs = (other, self) if reflect else (self, other)
return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype)

def fillna(
Expand Down Expand Up @@ -573,3 +568,14 @@ def infer_format(element: str, **kwargs) -> str:
raise ValueError("Unable to infer the timestamp format from the data")

return fmt


def _resolve_mixed_dtypes(
lhs: ColumnBinaryOperand, rhs: ColumnBinaryOperand, base_type: str
) -> Dtype:
units = ["s", "ms", "us", "ns"]
lhs_time_unit = cudf.utils.dtypes.get_time_unit(lhs)
lhs_unit = units.index(lhs_time_unit)
rhs_time_unit = cudf.utils.dtypes.get_time_unit(rhs)
rhs_unit = units.index(rhs_time_unit)
return cudf.dtype(f"{base_type}[{units[max(lhs_unit, rhs_unit)]}]")
Loading

0 comments on commit 0aef0c1

Please sign in to comment.