diff --git a/cpp/include/cudf/unary.hpp b/cpp/include/cudf/unary.hpp index 74c8bc67d3a..a1825ffccfe 100644 --- a/cpp/include/cudf/unary.hpp +++ b/cpp/include/cudf/unary.hpp @@ -202,6 +202,16 @@ std::unique_ptr cast( rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()); +/** + * @brief Check if a cast between two datatypes is supported. + * + * @param from source type + * @param to target type + * + * @returns true if the cast is supported. + */ +bool is_supported_cast(data_type from, data_type to) noexcept; + /** * @brief Creates a column of `type_id::BOOL8` elements indicating the presence of `NaN` values * in a column of floating point values. diff --git a/cpp/src/unary/cast_ops.cu b/cpp/src/unary/cast_ops.cu index 64427326d87..ec21813705a 100644 --- a/cpp/src/unary/cast_ops.cu +++ b/cpp/src/unary/cast_ops.cu @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -459,6 +460,14 @@ std::unique_ptr cast(column_view const& input, return type_dispatcher(input.type(), detail::dispatch_unary_cast_from{input}, type, stream, mr); } +struct is_supported_cast_impl { + template + bool operator()() const + { + return is_supported_cast(); + } +}; + } // namespace detail std::unique_ptr cast(column_view const& input, @@ -470,4 +479,11 @@ std::unique_ptr cast(column_view const& input, return detail::cast(input, type, stream, mr); } +bool is_supported_cast(data_type from, data_type to) noexcept +{ + // No matching detail API call/nvtx annotation, since this doesn't + // launch a kernel. + return double_type_dispatcher(from, to, detail::is_supported_cast_impl{}); +} + } // namespace cudf diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/unary.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/unary.pxd index 7f8ae2b7617..2a1b189af51 100644 --- a/python/cudf/cudf/_lib/pylibcudf/libcudf/unary.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/unary.pxd @@ -1,6 +1,7 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. from libc.stdint cimport int32_t +from libcpp cimport bool from libcpp.memory cimport unique_ptr from cudf._lib.pylibcudf.libcudf.column.column cimport column @@ -43,5 +44,6 @@ cdef extern from "cudf/unary.hpp" namespace "cudf" nogil: cdef extern unique_ptr[column] cast( column_view input, data_type out_type) except + + cdef extern bool is_supported_cast(data_type from_, data_type to) noexcept cdef extern unique_ptr[column] is_nan(column_view input) except + cdef extern unique_ptr[column] is_not_nan(column_view input) except + diff --git a/python/cudf/cudf/_lib/pylibcudf/unary.pxd b/python/cudf/cudf/_lib/pylibcudf/unary.pxd index 4aa4543bb80..d07df838172 100644 --- a/python/cudf/cudf/_lib/pylibcudf/unary.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/unary.pxd @@ -1,5 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +from libcpp cimport bool + from cudf._lib.pylibcudf.libcudf.unary cimport unary_operator from .column cimport Column @@ -17,3 +19,5 @@ cpdef Column cast(Column input, DataType data_type) cpdef Column is_nan(Column input) cpdef Column is_not_nan(Column input) + +cpdef bool is_supported_cast(DataType from_, DataType to) diff --git a/python/cudf/cudf/_lib/pylibcudf/unary.pyx b/python/cudf/cudf/_lib/pylibcudf/unary.pyx index 0879b501a49..8da46f0a832 100644 --- a/python/cudf/cudf/_lib/pylibcudf/unary.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/unary.pyx @@ -1,5 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -154,3 +155,23 @@ cpdef Column is_not_nan(Column input): result = move(cpp_unary.is_not_nan(input.view())) return Column.from_libcudf(move(result)) + +cpdef bool is_supported_cast(DataType from_, DataType to): + """Check if a cast between datatypes is supported. + + For details, see :cpp:func:`is_supported_cast`. + + Parameters + ---------- + from_ + The source datatype + to + The target datatype + + Returns + ------- + bool + True if the cast is supported. + """ + with nogil: + return cpp_unary.is_supported_cast(from_.c_obj, to.c_obj) diff --git a/python/cudf/cudf/pylibcudf_tests/test_unary.py b/python/cudf/cudf/pylibcudf_tests/test_unary.py new file mode 100644 index 00000000000..b5e4f0cb0e8 --- /dev/null +++ b/python/cudf/cudf/pylibcudf_tests/test_unary.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from cudf._lib import pylibcudf as plc + + +def test_is_supported_cast(): + assert plc.unary.is_supported_cast( + plc.DataType(plc.TypeId.INT8), plc.DataType(plc.TypeId.UINT64) + ) + assert plc.unary.is_supported_cast( + plc.DataType(plc.TypeId.DURATION_MILLISECONDS), + plc.DataType(plc.TypeId.UINT64), + ) + assert not plc.unary.is_supported_cast( + plc.DataType(plc.TypeId.INT32), plc.DataType(plc.TypeId.TIMESTAMP_DAYS) + ) + assert not plc.unary.is_supported_cast( + plc.DataType(plc.TypeId.INT32), plc.DataType(plc.TypeId.STRING) + )