From f54dd43fdf40ca3063ec2225a5d554df153476ff Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Mon, 8 Jul 2024 19:53:51 -0700 Subject: [PATCH 1/6] raise when casting from timestamp to numeric --- python/cudf_polars/cudf_polars/dsl/expr.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 93cb9db7cbd..23bf0bc5cd3 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1115,6 +1115,33 @@ class Cast(Expr): def __init__(self, dtype: plc.DataType, value: Expr) -> None: super().__init__(dtype) self.children = (value,) + if ( + self.dtype.id() == plc.TypeId.STRING + or value.dtype.id() == plc.TypeId.STRING + ): + raise NotImplementedError( + "Need to implement cast to/from string separately." + ) + # TODO: use exposed libcudf trait checking APIs + elif self.dtype.id() in { + plc.TypeId.TIMESTAMP_DAYS, + plc.TypeId.TIMESTAMP_MICROSECONDS, + plc.TypeId.TIMESTAMP_MILLISECONDS, + plc.TypeId.TIMESTAMP_NANOSECONDS, + } and value.dtype.id() in { + plc.TypeId.INT8, + plc.TypeId.INT16, + plc.TypeId.INT32, + plc.TypeId.INT64, + plc.TypeId.UINT8, + plc.TypeId.UINT16, + plc.TypeId.UINT32, + plc.TypeId.UINT64, + plc.TypeId.FLOAT32, + plc.TypeId.FLOAT64, + plc.TypeId.BOOL, + }: + raise NotImplementedError("Can't cast duration to numeric") def do_evaluate( self, From 1c6db1bc0db7169a1ad33babd2884c73255794dc Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Tue, 9 Jul 2024 06:07:40 -0700 Subject: [PATCH 2/6] use libcudf traits --- python/cudf_polars/cudf_polars/dsl/expr.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index fa17c9e26d2..8158f58aad0 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1123,24 +1123,7 @@ def __init__(self, dtype: plc.DataType, value: Expr) -> None: "Need to implement cast to/from string separately." ) # TODO: use exposed libcudf trait checking APIs - elif self.dtype.id() in { - plc.TypeId.TIMESTAMP_DAYS, - plc.TypeId.TIMESTAMP_MICROSECONDS, - plc.TypeId.TIMESTAMP_MILLISECONDS, - plc.TypeId.TIMESTAMP_NANOSECONDS, - } and value.dtype.id() in { - plc.TypeId.INT8, - plc.TypeId.INT16, - plc.TypeId.INT32, - plc.TypeId.INT64, - plc.TypeId.UINT8, - plc.TypeId.UINT16, - plc.TypeId.UINT32, - plc.TypeId.UINT64, - plc.TypeId.FLOAT32, - plc.TypeId.FLOAT64, - plc.TypeId.BOOL, - }: + elif plc.traits.is_chrono(self.dtype) and plc.traits.is_numeric(value.dtype): raise NotImplementedError("Can't cast duration to numeric") def do_evaluate( From 35c0e70fb7e6d7a9cb2b3a6cceebe002f884e825 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Tue, 9 Jul 2024 06:08:36 -0700 Subject: [PATCH 3/6] cleanup --- python/cudf_polars/cudf_polars/dsl/expr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 8158f58aad0..1b918f19eaa 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1122,7 +1122,6 @@ def __init__(self, dtype: plc.DataType, value: Expr) -> None: raise NotImplementedError( "Need to implement cast to/from string separately." ) - # TODO: use exposed libcudf trait checking APIs elif plc.traits.is_chrono(self.dtype) and plc.traits.is_numeric(value.dtype): raise NotImplementedError("Can't cast duration to numeric") From dc0ebbbd8f2c76a8ab00f20046597aedd788279c Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Mon, 22 Jul 2024 08:17:31 -0700 Subject: [PATCH 4/6] use is_supported_cast --- python/cudf_polars/cudf_polars/dsl/expr.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 05457c25cfa..168eeacf3d5 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1160,6 +1160,10 @@ class Cast(Expr): def __init__(self, dtype: plc.DataType, value: Expr) -> None: super().__init__(dtype) self.children = (value,) + if not plc.unary.is_supported_cast(self.dtype, value.dtype): + raise NotImplementedError( + f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}" + ) if ( self.dtype.id() == plc.TypeId.STRING or value.dtype.id() == plc.TypeId.STRING @@ -1167,8 +1171,6 @@ def __init__(self, dtype: plc.DataType, value: Expr) -> None: raise NotImplementedError( "Need to implement cast to/from string separately." ) - elif plc.traits.is_chrono(self.dtype) and plc.traits.is_numeric(value.dtype): - raise NotImplementedError("Can't cast duration to numeric") def do_evaluate( self, From a140cd07e1ff55b5ff54d5797e4399db31cea8fa Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:55:21 -0500 Subject: [PATCH 5/6] Update python/cudf_polars/cudf_polars/dsl/expr.py Co-authored-by: Lawrence Mitchell --- python/cudf_polars/cudf_polars/dsl/expr.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 168eeacf3d5..f60c2b6a88f 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1164,13 +1164,6 @@ def __init__(self, dtype: plc.DataType, value: Expr) -> None: raise NotImplementedError( f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}" ) - if ( - self.dtype.id() == plc.TypeId.STRING - or value.dtype.id() == plc.TypeId.STRING - ): - raise NotImplementedError( - "Need to implement cast to/from string separately." - ) def do_evaluate( self, From a2858e6fd919732c5eff7349a6fe59efe127c2fd Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Mon, 22 Jul 2024 12:40:41 -0700 Subject: [PATCH 6/6] basic tests --- .../tests/expressions/test_casting.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 python/cudf_polars/tests/expressions/test_casting.py diff --git a/python/cudf_polars/tests/expressions/test_casting.py b/python/cudf_polars/tests/expressions/test_casting.py new file mode 100644 index 00000000000..3e003054338 --- /dev/null +++ b/python/cudf_polars/tests/expressions/test_casting.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) + +_supported_dtypes = [(pl.Int8(), pl.Int64())] + +_unsupported_dtypes = [ + (pl.String(), pl.Int64()), +] + + +@pytest.fixture +def dtypes(request): + return request.param + + +@pytest.fixture +def tests(dtypes): + fromtype, totype = dtypes + if fromtype == pl.String(): + data = ["a", "b", "c"] + else: + data = [1, 2, 3] + return pl.DataFrame( + { + "a": pl.Series(data, dtype=fromtype), + } + ).lazy(), totype + + +@pytest.mark.parametrize("dtypes", _supported_dtypes, indirect=True) +def test_cast_supported(tests): + df, totype = tests + q = df.select(pl.col("a").cast(totype)) + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize("dtypes", _unsupported_dtypes, indirect=True) +def test_cast_unsupported(tests): + df, totype = tests + assert_ir_translation_raises( + df.select(pl.col("a").cast(totype)), NotImplementedError + )