From 39f256c3397afc9c495cb819636abddb23f81dc0 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Tue, 23 Jul 2024 19:03:16 -0500 Subject: [PATCH] Fall back to CPU for unsupported libcudf binaryops in cudf-polars (#16188) This PR adds logic that should trigger CPU fallback for unsupported binary ops. Authors: - https://github.com/brandon-b-miller - Lawrence Mitchell (https://github.com/wence-) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/16188 --- python/cudf_polars/cudf_polars/dsl/expr.py | 13 ++++--- .../cudf_polars/cudf_polars/utils/dtypes.py | 38 +------------------ .../tests/expressions/test_literal.py | 18 ++++++--- 3 files changed, 21 insertions(+), 48 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 9835e6f8461..6325feced94 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1424,13 +1424,14 @@ def __init__( super().__init__(dtype) self.op = op self.children = (left, right) - if ( - op in (plc.binaryop.BinaryOperator.ADD, plc.binaryop.BinaryOperator.SUB) - and plc.traits.is_chrono(left.dtype) - and plc.traits.is_chrono(right.dtype) - and not dtypes.have_compatible_resolution(left.dtype.id(), right.dtype.id()) + if not plc.binaryop.is_supported_operation( + self.dtype, left.dtype, right.dtype, op ): - raise NotImplementedError("Casting rules for timelike types") + raise NotImplementedError( + f"Operation {op.name} not supported " + f"for types {left.dtype.id().name} and {right.dtype.id().name} " + f"with output type {self.dtype.id().name}" + ) _MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = { pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL, diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py index 1279fe91d48..cd68d021286 100644 --- 
a/python/cudf_polars/cudf_polars/utils/dtypes.py +++ b/python/cudf_polars/cudf_polars/utils/dtypes.py @@ -14,43 +14,7 @@ import cudf._lib.pylibcudf as plc -__all__ = ["from_polars", "downcast_arrow_lists", "have_compatible_resolution"] - - -def have_compatible_resolution(lid: plc.TypeId, rid: plc.TypeId): - """ - Do two datetime typeids have matching resolution for a binop. - - Parameters - ---------- - lid - Left type id - rid - Right type id - - Returns - ------- - True if resolutions are compatible, False otherwise. - - Notes - ----- - Polars has different casting rules for combining - datetimes/durations than libcudf, and while we don't encode the - casting rules fully, just reject things we can't handle. - - Precondition for correctness: both lid and rid are timelike. - """ - if lid == rid: - return True - # Timestamps are smaller than durations in the libcudf enum. - lid, rid = sorted([lid, rid]) - if lid == plc.TypeId.TIMESTAMP_MILLISECONDS: - return rid == plc.TypeId.DURATION_MILLISECONDS - elif lid == plc.TypeId.TIMESTAMP_MICROSECONDS: - return rid == plc.TypeId.DURATION_MICROSECONDS - elif lid == plc.TypeId.TIMESTAMP_NANOSECONDS: - return rid == plc.TypeId.DURATION_NANOSECONDS - return False +__all__ = ["from_polars", "downcast_arrow_lists"] def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType: diff --git a/python/cudf_polars/tests/expressions/test_literal.py b/python/cudf_polars/tests/expressions/test_literal.py index 55e688428bd..5bd3131d1d7 100644 --- a/python/cudf_polars/tests/expressions/test_literal.py +++ b/python/cudf_polars/tests/expressions/test_literal.py @@ -6,6 +6,8 @@ import polars as pl +import cudf._lib.pylibcudf as plc + from cudf_polars.testing.asserts import ( assert_gpu_result_equal, assert_ir_translation_raises, @@ -64,11 +66,17 @@ def test_timelike_literal(timestamp, timedelta): adjusted=timestamp + timedelta, two_delta=timedelta + timedelta, ) - schema = q.collect_schema() - time_type = schema["time"] - delta_type = 
schema["delta"] - if dtypes.have_compatible_resolution( - dtypes.from_polars(time_type).id(), dtypes.from_polars(delta_type).id() + schema = {k: dtypes.from_polars(v) for k, v in q.collect_schema().items()} + if plc.binaryop.is_supported_operation( + schema["adjusted"], + schema["time"], + schema["delta"], + plc.binaryop.BinaryOperator.ADD, + ) and plc.binaryop.is_supported_operation( + schema["two_delta"], + schema["delta"], + schema["delta"], + plc.binaryop.BinaryOperator.ADD, ): assert_gpu_result_equal(q) else: