From 344fcc385b70f765b751b3e7ae838fd54829ad3e Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Wed, 3 Jul 2024 07:42:26 -0700 Subject: [PATCH 1/4] add logic --- python/cudf_polars/cudf_polars/dsl/expr.py | 23 +++++++++++++++++++ .../cudf_polars/cudf_polars/utils/dtypes.py | 7 ++++++ 2 files changed, 30 insertions(+) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index fe859c8d958..084489861e8 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1186,6 +1186,28 @@ def __init__( and not dtypes.have_compatible_resolution(left.dtype.id(), right.dtype.id()) ): raise NotImplementedError("Casting rules for timelike types") + if op in ( + plc.binaryop.BinaryOperator.MUL, + plc.binaryop.BinaryOperator.DIV, + plc.binaryop.BinaryOperator.TRUE_DIV, + plc.binaryop.BinaryOperator.FLOOR_DIV, + ): + if ( + left.dtype.id() in dtypes.TIMELIKE_TYPES + and right.dtype.id() in dtypes.FLOATING_TYPES + ) or ( + right.dtype.id() in dtypes.TIMELIKE_TYPES + and left.dtype.id() in dtypes.FLOATING_TYPES + ): + raise NotImplementedError( + "No multiplying or dividing durations by floats" + ) + if ( + left.dtype.id() in dtypes.TIMELIKE_TYPES + and right.dtype.id() in dtypes.TIMELIKE_TYPES + and dtype.id() in dtypes.FLOATING_TYPES + ): + raise NotImplementedError("No dividing durations by durations") _MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = { pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL, @@ -1229,6 +1251,7 @@ def do_evaluate( lop = left.obj_scalar elif right.is_scalar: rop = right.obj_scalar + breakpoint() return Column( plc.binaryop.binary_operation(lop, rop, self.op, self.dtype), ) diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py index 507acb5d33a..cde35d4feea 100644 --- a/python/cudf_polars/cudf_polars/utils/dtypes.py +++ b/python/cudf_polars/cudf_polars/utils/dtypes.py @@ -29,6 +29,13 @@ ] ) +FLOATING_TYPES = frozenset( + [ + plc.TypeId.FLOAT32, + plc.TypeId.FLOAT64, + ] +) + def have_compatible_resolution(lid: plc.TypeId, rid: plc.TypeId): """ From c92673a73f537259ce7ce745efad88e5f3d0d2f5 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Wed, 3 Jul 2024 07:43:11 -0700 Subject: [PATCH 2/4] clean --- python/cudf_polars/cudf_polars/dsl/expr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 084489861e8..92fa5fee913 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1251,7 +1251,6 @@ def do_evaluate( lop = left.obj_scalar elif right.is_scalar: rop = right.obj_scalar - breakpoint() return Column( plc.binaryop.binary_operation(lop, rop, self.op, self.dtype), ) From 990265f79700cd1afebb9feb21d15807d07cc735 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Mon, 22 Jul 2024 17:26:18 -0500 Subject: [PATCH 3/4] Update python/cudf_polars/cudf_polars/dsl/expr.py --- python/cudf_polars/cudf_polars/dsl/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 07dad6b284a..4694805a6e7 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1398,7 +1398,7 @@ def __init__( raise NotImplementedError( f"Operation {op.name} not supported " f"for types {left.dtype.id().name} and {right.dtype.id().name} " - f"with output type {self.dtype}" + f"with output type {self.dtype.id().name}" ) _MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = { From 5e0c7479078374568577cdadd32460942b75271d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 23 Jul 2024 09:54:09 +0000 Subject: [PATCH 4/4] Use is_supported_operation in tests and remove unnecessary code --- .../cudf_polars/cudf_polars/utils/dtypes.py | 38 +------------------ .../tests/expressions/test_literal.py | 18 ++++++--- 2 files changed, 14 insertions(+), 42 deletions(-) diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py index 1279fe91d48..cd68d021286 100644 --- a/python/cudf_polars/cudf_polars/utils/dtypes.py +++ b/python/cudf_polars/cudf_polars/utils/dtypes.py @@ -14,43 +14,7 @@ import cudf._lib.pylibcudf as plc -__all__ = ["from_polars", "downcast_arrow_lists", "have_compatible_resolution"] - - -def have_compatible_resolution(lid: plc.TypeId, rid: plc.TypeId): - """ - Do two datetime typeids have matching resolution for a binop. - - Parameters - ---------- - lid - Left type id - rid - Right type id - - Returns - ------- - True if resolutions are compatible, False otherwise. - - Notes - ----- - Polars has different casting rules for combining - datetimes/durations than libcudf, and while we don't encode the - casting rules fully, just reject things we can't handle. - - Precondition for correctness: both lid and rid are timelike. - """ - if lid == rid: - return True - # Timestamps are smaller than durations in the libcudf enum. - lid, rid = sorted([lid, rid]) - if lid == plc.TypeId.TIMESTAMP_MILLISECONDS: - return rid == plc.TypeId.DURATION_MILLISECONDS - elif lid == plc.TypeId.TIMESTAMP_MICROSECONDS: - return rid == plc.TypeId.DURATION_MICROSECONDS - elif lid == plc.TypeId.TIMESTAMP_NANOSECONDS: - return rid == plc.TypeId.DURATION_NANOSECONDS - return False +__all__ = ["from_polars", "downcast_arrow_lists"] def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType: diff --git a/python/cudf_polars/tests/expressions/test_literal.py b/python/cudf_polars/tests/expressions/test_literal.py index 55e688428bd..5bd3131d1d7 100644 --- a/python/cudf_polars/tests/expressions/test_literal.py +++ b/python/cudf_polars/tests/expressions/test_literal.py @@ -6,6 +6,8 @@ import polars as pl +import cudf._lib.pylibcudf as plc + from cudf_polars.testing.asserts import ( assert_gpu_result_equal, assert_ir_translation_raises, @@ -64,11 +66,17 @@ def test_timelike_literal(timestamp, timedelta): adjusted=timestamp + timedelta, two_delta=timedelta + timedelta, ) - schema = q.collect_schema() - time_type = schema["time"] - delta_type = schema["delta"] - if dtypes.have_compatible_resolution( - dtypes.from_polars(time_type).id(), dtypes.from_polars(delta_type).id() + schema = {k: dtypes.from_polars(v) for k, v in q.collect_schema().items()} + if plc.binaryop.is_supported_operation( + schema["adjusted"], + schema["time"], + schema["delta"], + plc.binaryop.BinaryOperator.ADD, + ) and plc.binaryop.is_supported_operation( + schema["two_delta"], + schema["delta"], + schema["delta"], + plc.binaryop.BinaryOperator.ADD, ): assert_gpu_result_equal(q) else: