Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement cudf-polars datetime extraction methods #16500

Merged
merged 15 commits into from
Sep 5, 2024
42 changes: 5 additions & 37 deletions python/cudf/cudf/_lib/datetime.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ from cudf._lib.pylibcudf.libcudf.scalar.scalar cimport scalar
from cudf._lib.pylibcudf.libcudf.types cimport size_type
from cudf._lib.scalar cimport DeviceScalar

import cudf._lib.pylibcudf as plc


@acquire_spill_lock()
def add_months(Column col, Column months):
Expand All @@ -37,43 +39,9 @@ def add_months(Column col, Column months):

@acquire_spill_lock()
def extract_datetime_component(Column col, object field):

cdef unique_ptr[column] c_result
cdef column_view col_view = col.view()

with nogil:
if field == "year":
c_result = move(libcudf_datetime.extract_year(col_view))
elif field == "month":
c_result = move(libcudf_datetime.extract_month(col_view))
elif field == "day":
c_result = move(libcudf_datetime.extract_day(col_view))
elif field == "weekday":
c_result = move(libcudf_datetime.extract_weekday(col_view))
elif field == "hour":
c_result = move(libcudf_datetime.extract_hour(col_view))
elif field == "minute":
c_result = move(libcudf_datetime.extract_minute(col_view))
elif field == "second":
c_result = move(libcudf_datetime.extract_second(col_view))
elif field == "millisecond":
c_result = move(
libcudf_datetime.extract_millisecond_fraction(col_view)
)
elif field == "microsecond":
c_result = move(
libcudf_datetime.extract_microsecond_fraction(col_view)
)
elif field == "nanosecond":
c_result = move(
libcudf_datetime.extract_nanosecond_fraction(col_view)
)
elif field == "day_of_year":
c_result = move(libcudf_datetime.day_of_year(col_view))
else:
raise ValueError(f"Invalid datetime field: '{field}'")

result = Column.from_unique_ptr(move(c_result))
result = Column.from_pylibcudf(
plc.datetime.extract_datetime_component(col.to_pylibcudf(mode="read"), field)
)

if field == "weekday":
# Pandas counts Monday-Sunday as 0-6
Expand Down
49 changes: 49 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/datetime.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ from libcpp.utility cimport move

from cudf._lib.pylibcudf.libcudf.column.column cimport column
from cudf._lib.pylibcudf.libcudf.datetime cimport (
day_of_year as cpp_day_of_year,
extract_day as cpp_extract_day,
extract_hour as cpp_extract_hour,
extract_microsecond_fraction as cpp_extract_microsecond_fraction,
extract_millisecond_fraction as cpp_extract_millisecond_fraction,
extract_minute as cpp_extract_minute,
extract_month as cpp_extract_month,
extract_nanosecond_fraction as cpp_extract_nanosecond_fraction,
extract_second as cpp_extract_second,
extract_weekday as cpp_extract_weekday,
extract_year as cpp_extract_year,
)

Expand Down Expand Up @@ -31,3 +41,42 @@ cpdef Column extract_year(
with nogil:
result = move(cpp_extract_year(values.view()))
return Column.from_libcudf(move(result))


def extract_datetime_component(Column col, str field):

cdef unique_ptr[column] c_result

with nogil:
if field == "year":
c_result = move(cpp_extract_year(col.view()))
elif field == "month":
c_result = move(cpp_extract_month(col.view()))
elif field == "day":
c_result = move(cpp_extract_day(col.view()))
elif field == "weekday":
c_result = move(cpp_extract_weekday(col.view()))
elif field == "hour":
c_result = move(cpp_extract_hour(col.view()))
elif field == "minute":
c_result = move(cpp_extract_minute(col.view()))
elif field == "second":
c_result = move(cpp_extract_second(col.view()))
elif field == "millisecond":
c_result = move(
cpp_extract_millisecond_fraction(col.view())
)
elif field == "microsecond":
c_result = move(
cpp_extract_microsecond_fraction(col.view())
)
elif field == "nanosecond":
c_result = move(
cpp_extract_nanosecond_fraction(col.view())
)
elif field == "day_of_year":
c_result = move(cpp_day_of_year(col.view()))
wence- marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Invalid datetime field: '{field}'")

return Column.from_libcudf(move(c_result))
42 changes: 38 additions & 4 deletions python/cudf/cudf/pylibcudf_tests/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import datetime
import functools

import pyarrow as pa
import pyarrow.compute as pc
import pytest
from utils import assert_column_eq

import cudf._lib.pylibcudf as plc


@pytest.fixture
def column(has_nulls):
def date_column(has_nulls):
values = [
datetime.date(1999, 1, 1),
datetime.date(2024, 10, 12),
Expand All @@ -22,9 +24,41 @@ def column(has_nulls):
return plc.interop.from_arrow(pa.array(values, type=pa.date32()))


def test_extract_year(column):
got = plc.datetime.extract_year(column)
@pytest.fixture(scope="module", params=["s", "ms", "us", "ns"])
def datetime_column(has_nulls, request):
values = [
datetime.datetime(1999, 1, 1),
datetime.datetime(2024, 10, 12),
datetime.datetime(1970, 1, 1),
datetime.datetime(2260, 1, 1),
datetime.datetime(2024, 2, 29, 3, 14, 15),
datetime.datetime(2024, 2, 29, 3, 14, 15, 999),
]
if has_nulls:
values[2] = None
return plc.interop.from_arrow(
pa.array(values, type=pa.timestamp(request.param))
)


@pytest.mark.parametrize(
"component, pc_fun",
[
("year", pc.year),
("month", pc.month),
("day", pc.day),
("weekday", functools.partial(pc.day_of_week, count_from_zero=False)),
("hour", pc.hour),
("minute", pc.minute),
("second", pc.second),
("millisecond", pc.millisecond),
("microsecond", pc.microsecond),
("nanosecond", pc.nanosecond),
],
)
def test_extraction(datetime_column, component, pc_fun):
got = plc.datetime.extract_datetime_component(datetime_column, component)
# libcudf produces an int16, arrow produces an int64
expect = pa.compute.year(plc.interop.to_arrow(column)).cast(pa.int16())
expect = pc_fun(plc.interop.to_arrow(datetime_column)).cast(pa.int16())

assert_column_eq(expect, got)
84 changes: 81 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,19 @@ def do_evaluate(

class TemporalFunction(Expr):
__slots__ = ("name", "options", "children")
_temporal_extraction_function_map: ClassVar[dict[pl_expr.TemporalFunction, str]] = {
pl_expr.TemporalFunction.Year: "year",
pl_expr.TemporalFunction.Month: "month",
pl_expr.TemporalFunction.Day: "day",
pl_expr.TemporalFunction.WeekDay: "weekday",
pl_expr.TemporalFunction.Hour: "hour",
pl_expr.TemporalFunction.Minute: "minute",
pl_expr.TemporalFunction.Second: "second",
pl_expr.TemporalFunction.Millisecond: "millisecond",
pl_expr.TemporalFunction.Microsecond: "microsecond",
pl_expr.TemporalFunction.Nanosecond: "nanosecond",
}
_supported_temporal_functions = _temporal_extraction_function_map
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

Expand All @@ -960,7 +973,7 @@ def __init__(
self.options = options
self.name = name
self.children = children
if self.name != pl_expr.TemporalFunction.Year:
if self.name not in TemporalFunction._supported_temporal_functions:
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(f"Temporal function {self.name}")

def do_evaluate(
Expand All @@ -975,9 +988,74 @@ def do_evaluate(
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
]
if self.name == pl_expr.TemporalFunction.Year:
if self.name in TemporalFunction._temporal_extraction_function_map:
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
(column,) = columns
return Column(plc.datetime.extract_year(column.obj))
if self.name == pl_expr.TemporalFunction.Millisecond:
return Column(
plc.datetime.extract_datetime_component(column.obj, "millisecond")
)
wence- marked this conversation as resolved.
Show resolved Hide resolved
if self.name == pl_expr.TemporalFunction.Microsecond:
millis = plc.datetime.extract_datetime_component(
column.obj, "millisecond"
)
micros = plc.datetime.extract_datetime_component(
column.obj, "microsecond"
)
processed_mili = plc.binaryop.binary_operation(
millis,
plc.interop.from_arrow(pa.scalar(1_000, type=pa.int32())),
plc.binaryop.BinaryOperator.MUL,
plc.DataType(plc.TypeId.INT32),
)
total_micros = plc.binaryop.binary_operation(
micros,
processed_mili,
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
plc.binaryop.BinaryOperator.ADD,
plc.types.DataType(plc.types.TypeId.INT32),
)
return Column(total_micros)
elif self.name == pl_expr.TemporalFunction.Nanosecond:
millis = plc.datetime.extract_datetime_component(
column.obj, "millisecond"
)
micros = plc.datetime.extract_datetime_component(
column.obj, "microsecond"
)
nanos = plc.datetime.extract_datetime_component(
column.obj, "nanosecond"
)
processed_mili = plc.binaryop.binary_operation(
millis,
plc.interop.from_arrow(pa.scalar(1_000_000, type=pa.int32())),
plc.binaryop.BinaryOperator.MUL,
plc.types.DataType(plc.types.TypeId.INT32),
)
processed_micro = plc.binaryop.binary_operation(
micros,
plc.interop.from_arrow(pa.scalar(1_000, type=pa.int32())),
plc.binaryop.BinaryOperator.MUL,
plc.types.DataType(plc.types.TypeId.INT32),
)
total_nanos = plc.binaryop.binary_operation(
nanos,
processed_mili,
plc.binaryop.BinaryOperator.ADD,
plc.types.DataType(plc.types.TypeId.INT32),
)
total_nanos = plc.binaryop.binary_operation(
total_nanos,
processed_micro,
plc.binaryop.BinaryOperator.ADD,
plc.types.DataType(plc.types.TypeId.INT32),
)
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
return Column(total_nanos)

return Column(
plc.datetime.extract_datetime_component(
column.obj,
TemporalFunction._temporal_extraction_function_map[self.name],
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
)
)
raise NotImplementedError(
f"TemporalFunction {self.name}"
) # pragma: no cover; init trips first
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
19 changes: 18 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,29 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex
*(translate_expr(visitor, n=n) for n in node.input),
)
elif isinstance(name, pl_expr.TemporalFunction):
return expr.TemporalFunction(
# functions for which evaluation of the expression may not return
# the same dtype as polars, either due to libcudf returning a different
# dtype, or due to our internal processing affecting what libcudf returns
polars_result_dtypes = {
pl_expr.TemporalFunction.Year: plc.DataType(plc.TypeId.INT32),
pl_expr.TemporalFunction.Month: plc.DataType(plc.TypeId.INT8),
pl_expr.TemporalFunction.Day: plc.DataType(plc.TypeId.INT8),
pl_expr.TemporalFunction.WeekDay: plc.DataType(plc.TypeId.INT8),
pl_expr.TemporalFunction.Hour: plc.DataType(plc.TypeId.INT8),
pl_expr.TemporalFunction.Minute: plc.DataType(plc.TypeId.INT8),
pl_expr.TemporalFunction.Second: plc.DataType(plc.TypeId.INT8),
pl_expr.TemporalFunction.Millisecond: plc.DataType(plc.TypeId.INT32),
}
result_expr = expr.TemporalFunction(
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
dtype,
name,
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
if name in polars_result_dtypes:
return expr.Cast(polars_result_dtypes[name], result_expr)
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
return result_expr

elif isinstance(name, str):
return expr.UnaryFunction(
dtype,
Expand Down
Loading
Loading