Skip to content

Commit

Permalink
Enable Datetime/Timedelta dtypes in Masked UDFs (#9451)
Browse files Browse the repository at this point in the history
Closes #9432.

Enables masked UDFs that operate on datetime and timedelta columns, for example:
```python
import cudf

df = cudf.DataFrame({'a':['2011-01-01'], 'b':[1]})
df['a'] = df['a'].astype('datetime64[ns]')
df['b'] = df['b'].astype('timedelta64[ns]')

def f(row):
    return row['a'] + row['b']

res = df.apply(f, axis=1)
```

```
0   2011-01-01 00:00:00.000000001
dtype: datetime64[ns]
```

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Michael Wang (https://github.com/isVoid)

URL: #9451
  • Loading branch information
brandon-b-miller authored Oct 20, 2021
1 parent fc868b8 commit 919fedf
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 5 deletions.
8 changes: 8 additions & 0 deletions python/cudf/cudf/core/udf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ def register_const_op(op):
cuda_lower(op, types.Number, MaskedType)(to_lower_op)
cuda_lower(op, MaskedType, types.Boolean)(to_lower_op)
cuda_lower(op, types.Boolean, MaskedType)(to_lower_op)
# Datetime and timedelta scalars: register the same lowering for a
# MaskedType paired with the scalar, in both operand orders.
cuda_lower(op, MaskedType, types.NPDatetime)(to_lower_op)
cuda_lower(op, types.NPDatetime, MaskedType)(to_lower_op)
cuda_lower(op, MaskedType, types.NPTimedelta)(to_lower_op)
cuda_lower(op, types.NPTimedelta, MaskedType)(to_lower_op)


# register all lowering at init
Expand Down Expand Up @@ -266,6 +270,8 @@ def pack_return_masked_impl(context, builder, sig, args):

# Lowering of api.pack_return for a bare (unmasked) scalar argument.
# Registered for Boolean, Number, NPDatetime and NPTimedelta argument types.
@cuda_lower(api.pack_return, types.Boolean)
@cuda_lower(api.pack_return, types.Number)
@cuda_lower(api.pack_return, types.NPDatetime)
@cuda_lower(api.pack_return, types.NPTimedelta)
def pack_return_scalar_impl(context, builder, sig, args):
# Build a struct proxy of the signature's return type and store the
# incoming scalar as its value. The rest of the body lies outside this hunk.
outdata = cgutils.create_struct_proxy(sig.return_type)(context, builder)
outdata.value = args[0]
Expand Down Expand Up @@ -335,6 +341,8 @@ def cast_masked_to_masked(context, builder, fromty, toty, val):
# Masked constructor for use in a kernel for testing
# Registered for Boolean, Number, NPDatetime and NPTimedelta first arguments,
# each paired with a boolean second argument.
@lower_builtin(api.Masked, types.Boolean, types.boolean)
@lower_builtin(api.Masked, types.Number, types.boolean)
@lower_builtin(api.Masked, types.NPDatetime, types.boolean)
@lower_builtin(api.Masked, types.NPTimedelta, types.boolean)
def masked_constructor(context, builder, sig, args):
ty = sig.return_type
# The two constructor arguments arrive in order: the value, then the flag.
# The rest of the body lies outside this hunk.
value, valid = args
Expand Down
26 changes: 21 additions & 5 deletions python/cudf/cudf/core/udf/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from cudf.core.udf import api
from cudf.core.udf._ops import arith_ops, comparison_ops, unary_ops

# Numba scalar type classes accepted wherever this module checks a value
# with ``isinstance(..., SUPPORTED_NUMBA_TYPES)``.
SUPPORTED_NUMBA_TYPES = (
    types.Number,
    types.Boolean,
    types.NPDatetime,
    types.NPTimedelta,
)


class MaskedType(types.Type):
"""
Expand All @@ -30,7 +37,7 @@ class MaskedType(types.Type):
def __init__(self, value):
# MaskedType in Numba shall be parameterized
# with a value type
# Pre-change check, removed by this commit:
if not isinstance(value, (types.Number, types.Boolean)):
# Replacement check, added by this commit; SUPPORTED_NUMBA_TYPES also
# admits NPDatetime and NPTimedelta:
if not isinstance(value, SUPPORTED_NUMBA_TYPES):
# NOTE(review): the message still says "numeric" although datetime and
# timedelta value types now pass the check above - confirm wording.
raise TypeError("value_type must be a numeric scalar type")
self.value_type = value
# The Numba type name embeds the value type.
super().__init__(name=f"Masked{self.value_type}")
Expand Down Expand Up @@ -111,9 +118,18 @@ def typeof_masked(val, c):
# Typing template for api.Masked: one concrete signature
# MaskedType(t)(t, boolean) per value type t in the iteration domain below.
@cuda_decl_registry.register
class MaskedConstructor(ConcreteTemplate):
key = api.Masked
# Time units for which datetime/timedelta signatures are generated
# (these three assignments are added by this commit).
units = ["ns", "ms", "us", "s"]
datetime_cases = set(types.NPDatetime(u) for u in units)
timedelta_cases = set(types.NPTimedelta(u) for u in units)
cases = [
nb_signature(MaskedType(t), t, types.boolean)
# Pre-change iteration domain, removed by this commit:
for t in (types.integer_domain | types.real_domain | {types.boolean})
# Replacement iteration domain, added by this commit:
for t in (
types.integer_domain
| types.real_domain
| datetime_cases
| timedelta_cases
| {types.boolean}
)
]


Expand Down Expand Up @@ -255,10 +271,10 @@ def generic(self, args, kws):
# In the case of op(Masked, scalar), we resolve the type between
# the Masked value_type and the scalar's type directly
if isinstance(args[0], MaskedType) and isinstance(
# Pre-change scalar check, removed by this commit:
args[1], (types.Number, types.Boolean)
# Replacement scalar check, added by this commit:
args[1], SUPPORTED_NUMBA_TYPES
):
to_resolve_types = (args[0].value_type, args[1])
# Mirror case op(scalar, Masked).
# Pre-change condition, removed by this commit:
elif isinstance(args[0], (types.Number, types.Boolean)) and isinstance(
# Replacement condition, added by this commit:
elif isinstance(args[0], SUPPORTED_NUMBA_TYPES) and isinstance(
args[1], MaskedType
):
# The Masked operand's value_type comes first, then the scalar's type.
to_resolve_types = (args[1].value_type, args[0])
Expand Down Expand Up @@ -306,7 +322,7 @@ def generic(self, args, kws):
if isinstance(args[0], MaskedType):
# MaskedType(dtype, valid) -> MaskedType(dtype, valid)
return nb_signature(args[0], args[0])
# Pre-change scalar check, removed by this commit:
elif isinstance(args[0], (types.Number, types.Boolean)):
# Replacement scalar check, added by this commit:
elif isinstance(args[0], SUPPORTED_NUMBA_TYPES):
# scalar_type -> MaskedType(scalar_type, True)
return_type = MaskedType(args[0])
return nb_signature(return_type, args[0])
Expand Down
46 changes: 46 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,52 @@ def func_gdf(row):
run_masked_udf_test(func_pdf, func_gdf, gdf, check_dtype=False)


@pytest.mark.parametrize(
    "dtype_l",
    ["datetime64[ns]", "datetime64[us]", "datetime64[ms]", "datetime64[s]"],
)
@pytest.mark.parametrize(
    "dtype_r",
    [
        "timedelta64[ns]",
        "timedelta64[us]",
        "timedelta64[ms]",
        "timedelta64[s]",
        "datetime64[ns]",
        "datetime64[ms]",
        "datetime64[us]",
        "datetime64[s]",
    ],
)
@pytest.mark.parametrize("op", [operator.add, operator.sub])
def test_arith_masked_vs_masked_datelike(op, dtype_l, dtype_r):
    """Exercise ``+`` and ``-`` between two datelike columns in a masked UDF.

    Datetime counterpart of the numeric masked-vs-masked arithmetic test.
    Column ``a`` is cast to ``dtype_l`` and column ``b`` to ``dtype_r``; the
    rows mix present values with ``cudf.NA`` in either or both columns. Not
    every dtype combination is covered for now.
    """
    # Adding two datetimes is not a valid operation, so skip that pairing.
    adding_two_datetimes = (
        op is operator.add and "datetime" in dtype_l and "datetime" in dtype_r
    )
    if adding_two_datetimes:
        pytest.skip("Adding datetime to datetime is not valid")

    def func_pdf(row):
        return op(row["a"], row["b"])

    def func_gdf(row):
        return op(row["a"], row["b"])

    columns = {
        "a": ["2011-01-01", cudf.NA, "2011-03-01", cudf.NA],
        "b": [4, 5, cudf.NA, cudf.NA],
    }
    frame = cudf.DataFrame(columns)
    frame["a"] = frame["a"].astype(dtype_l)
    frame["b"] = frame["b"].astype(dtype_r)

    # Both UDFs and the frame are handed to the shared masked-UDF checker.
    run_masked_udf_test(func_pdf, func_gdf, frame, check_dtype=False)


@pytest.mark.parametrize("op", comparison_ops)
def test_compare_masked_vs_masked(op):
# this test should test all the
Expand Down

0 comments on commit 919fedf

Please sign in to comment.