Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose reflection to check if casting between two types is supported #16239

Merged
merged 4 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cpp/include/cudf/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ std::unique_ptr<column> cast(
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
* @brief Check if a cast between two datatypes is supported.
*
* @param from source type
* @param to target type
*
* @returns true if the cast is supported.
*/
bool is_supported_cast(data_type from, data_type to) noexcept;

/**
* @brief Creates a column of `type_id::BOOL8` elements indicating the presence of `NaN` values
* in a column of floating point values.
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/unary/cast_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cudf/unary.hpp>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>
Expand Down Expand Up @@ -459,6 +460,14 @@ std::unique_ptr<column> cast(column_view const& input,
return type_dispatcher(input.type(), detail::dispatch_unary_cast_from{input}, type, stream, mr);
}

struct is_supported_cast_impl {
template <typename From, typename To>
bool operator()() const
{
return is_supported_cast<From, To>();
}
};

} // namespace detail

std::unique_ptr<column> cast(column_view const& input,
Expand All @@ -470,4 +479,11 @@ std::unique_ptr<column> cast(column_view const& input,
return detail::cast(input, type, stream, mr);
}

bool is_supported_cast(data_type from, data_type to) noexcept
{
// No matching detail API call/nvtx annotation, since this doesn't
// launch a kernel.
return double_type_dispatcher(from, to, detail::is_supported_cast_impl{});
}

} // namespace cudf
2 changes: 2 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/libcudf/unary.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

from libc.stdint cimport int32_t
from libcpp cimport bool
from libcpp.memory cimport unique_ptr

from cudf._lib.pylibcudf.libcudf.column.column cimport column
Expand Down Expand Up @@ -43,5 +44,6 @@ cdef extern from "cudf/unary.hpp" namespace "cudf" nogil:
cdef extern unique_ptr[column] cast(
column_view input,
data_type out_type) except +
cdef extern bool is_supported_cast(data_type from_, data_type to) noexcept
cdef extern unique_ptr[column] is_nan(column_view input) except +
cdef extern unique_ptr[column] is_not_nan(column_view input) except +
4 changes: 4 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/unary.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from libcpp cimport bool

from cudf._lib.pylibcudf.libcudf.unary cimport unary_operator

from .column cimport Column
Expand All @@ -17,3 +19,5 @@ cpdef Column cast(Column input, DataType data_type)
cpdef Column is_nan(Column input)

cpdef Column is_not_nan(Column input)

cpdef bool is_supported_cast(DataType from_, DataType to)
21 changes: 21 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/unary.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

Expand Down Expand Up @@ -154,3 +155,23 @@ cpdef Column is_not_nan(Column input):
result = move(cpp_unary.is_not_nan(input.view()))

return Column.from_libcudf(move(result))

cpdef bool is_supported_cast(DataType from_, DataType to):
wence- marked this conversation as resolved.
Show resolved Hide resolved
"""Check if a cast between datatypes is supported.

For details, see :cpp:func:`is_supported_cast`.

Parameters
----------
from_
The source datatype
to
The target datatype

Returns
-------
bool
True if the cast is supported.
"""
with nogil:
return cpp_unary.is_supported_cast(from_.c_obj, to.c_obj)
19 changes: 19 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/test_unary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from cudf._lib import pylibcudf as plc


def test_is_supported_cast():
assert plc.unary.is_supported_cast(
plc.DataType(plc.TypeId.INT8), plc.DataType(plc.TypeId.UINT64)
)
assert plc.unary.is_supported_cast(
plc.DataType(plc.TypeId.DURATION_MILLISECONDS),
plc.DataType(plc.TypeId.UINT64),
)
assert not plc.unary.is_supported_cast(
plc.DataType(plc.TypeId.INT32), plc.DataType(plc.TypeId.TIMESTAMP_DAYS)
)
assert not plc.unary.is_supported_cast(
plc.DataType(plc.TypeId.INT32), plc.DataType(plc.TypeId.STRING)
)
Loading