From 3efe56021b558871fdcf518cd5c126b32be3a781 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 27 Apr 2022 18:04:09 +0800 Subject: [PATCH 1/8] Implement unravel index for row-major array. --- cpp/include/raft/detail/mdarray.hpp | 125 ++++++++++++++++++++++++++++ cpp/test/mdarray.cu | 92 ++++++++++++++++++++ 2 files changed, 217 insertions(+) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index cb6f8a0920..99ab44a463 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -27,6 +27,7 @@ #include #include #include +#include namespace raft::detail { /** @@ -240,4 +241,128 @@ namespace stdex = std::experimental; using vector_extent = stdex::extents; using matrix_extent = stdex::extents; using scalar_extent = stdex::extents<1>; + +template +MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t +{ + int c = 0; + for (; v != 0; v &= v - 1) { + c++; + } + return c; +} + +MDSPAN_INLINE_FUNCTION auto popc(uint32_t v) -> int32_t +{ +#if defined(__CUDA_ARCH__) + return __popc(v); +#elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcount(v); +#else + return native_popc(v); +#endif // compiler +} + +MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t +{ +#if defined(__CUDA_ARCH__) + return __popcll(v); +#elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcountll(v); +#else + return native_popc(v); +#endif // compiler +} + +template +MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N], std::index_sequence) +{ + return thrust::make_tuple(arr[Idx]...); +} + +template +MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N]) +{ + return arr_to_tup(arr, std::make_index_sequence{}); +} + +// uint division optimization inspired by the CIndexer in cupy. Division operation is +// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 +// bit when the index is smaller, then try to avoid division when it's exp of 2. +template +MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) +{ + constexpr auto kRank = static_cast(shape.rank()); + size_t index[shape.rank()]{0}; // NOLINT + static_assert(std::is_signed::value, + "Don't change the type without changing the for loop."); + for (int32_t dim = kRank; --dim > 0;) { + auto s = static_cast>>(shape.extent(dim)); + if (s & (s - 1)) { + auto t = idx / s; + index[dim] = idx - t * s; + idx = t; + } else { // exp of 2 + index[dim] = idx & (s - 1); + idx >>= popc(s - 1); + } + } + index[0] = idx; + return arr_to_tup(index); +} + +/** + * \brief Turns linear index into coordinate. Similar to numpy unravel_index. This is not + * exposed to public as it's not part of the mdspan proposal, the returned tuple + * can not be directly used for indexing into mdspan and we might change the return + * type in the future. + * + * \code + * auto m = make_host_matrix(7, 6); + * auto m_v = m.view(); + * auto coord = detail::unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); + * detail::apply(m_v, coord) = 2; + * \endcode + * + * \param idx The linear index. + * \param shape The shape of the array to use. + * \param layout Must be `layout_right` (row-major) in current implementation. + * + * \return A thrust::tuple that represents the coordinate. + */ +template +MDSPAN_INLINE_FUNCTION auto unravel_index(size_t idx, + detail::stdex::extents shape, + LayoutPolicy const&) +{ + static_assert(std::is_same::value, + "Only C layout is supported."); + if (idx > std::numeric_limits::max()) { + return unravel_index_impl(static_cast(idx), shape); + } else { + return unravel_index_impl(static_cast(idx), shape); + } +} + +template +MDSPAN_INLINE_FUNCTION auto constexpr apply_impl(Fn&& f, Tup&& t, std::index_sequence) + -> decltype(auto) +{ + return f(thrust::get(t)...); +} + +/** + * C++ 17 style apply for thrust tuple. + * + * \param f function to apply + * \param t tuple of arguments + */ +template >::value> +MDSPAN_INLINE_FUNCTION auto constexpr apply(Fn&& f, Tup&& t) -> decltype(auto) +{ + return apply_impl( + std::forward(f), std::forward(t), std::make_index_sequence{}); +} } // namespace raft::detail diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index 987f2dcf2e..75add8a2a9 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -421,4 +421,96 @@ TEST(MDArray, FuncArg) std::is_same_v::accessor_type>); } } + +namespace { +void test_mdarray_unravel() +{ + { + uint32_t v{0}; + ASSERT_EQ(detail::native_popc(v), 0); + ASSERT_EQ(detail::popc(v), 0); + v = 1; + ASSERT_EQ(detail::native_popc(v), 1); + ASSERT_EQ(detail::popc(v), 1); + v = 0xffffffff; + ASSERT_EQ(detail::native_popc(v), 32); + ASSERT_EQ(detail::popc(v), 32); + } + { + uint64_t v{0}; + ASSERT_EQ(detail::native_popc(v), 0); + ASSERT_EQ(detail::popc(v), 0); + v = 1; + ASSERT_EQ(detail::native_popc(v), 1); + ASSERT_EQ(detail::popc(v), 1); + v = 0xffffffff; + ASSERT_EQ(detail::native_popc(v), 32); + ASSERT_EQ(detail::popc(v), 32); + v = 0xffffffffffffffff; + ASSERT_EQ(detail::native_popc(v), 64); + ASSERT_EQ(detail::popc(v), 64); + } + + // examples in numpy unravel_index + { + auto coord = detail::unravel_index(22, detail::matrix_extent{7, 6}, stdex::layout_right{}); + static_assert(thrust::tuple_size::value == 2); + ASSERT_EQ(thrust::get<0>(coord), 3); + ASSERT_EQ(thrust::get<1>(coord), 4); + } + { + auto coord = detail::unravel_index(41, detail::matrix_extent{7, 6}, stdex::layout_right{}); + static_assert(thrust::tuple_size::value == 2); + ASSERT_EQ(thrust::get<0>(coord), 6); + ASSERT_EQ(thrust::get<1>(coord), 5); + } + { + auto coord = detail::unravel_index(37, detail::matrix_extent{7, 6}, stdex::layout_right{}); + static_assert(thrust::tuple_size::value == 2); + ASSERT_EQ(thrust::get<0>(coord), 6); + ASSERT_EQ(thrust::get<1>(coord), 1); + } + // assignment + { + auto m = make_host_matrix(7, 6); + auto m_v = m.view(); + for (size_t i = 0; i < m.size(); ++i) { + auto coord = detail::unravel_index(i, m.extents(), typename decltype(m)::layout_type{}); + detail::apply(m_v, coord) = i; + } + for (size_t i = 0; i < m.size(); ++i) { + auto coord = detail::unravel_index(i, m.extents(), typename decltype(m)::layout_type{}); + ASSERT_EQ(detail::apply(m_v, coord), i); + } + } + + { + handle_t handle; + auto m = make_device_matrix(handle, 7, 6); + auto m_v = m.view(); + thrust::for_each_n(handle.get_thrust_policy(), + thrust::make_counting_iterator(0ul), + m_v.size(), + [=] __device__(size_t i) { + auto coord = detail::unravel_index( + i, m_v.extents(), typename decltype(m_v)::layout_type{}); + detail::apply(m_v, coord) = static_cast(i); + }); + thrust::device_vector status(1, 0); + auto p_status = status.data().get(); + thrust::for_each_n(handle.get_thrust_policy(), + thrust::make_counting_iterator(0ul), + m_v.size(), + [=] __device__(size_t i) { + auto coord = detail::unravel_index( + i, m_v.extents(), typename decltype(m_v)::layout_type{}); + auto v = detail::apply(m_v, coord); + if (v != static_cast(i)) { raft::myAtomicAdd(p_status, 1); } + }); + check_status(p_status, handle.get_stream()); + } +} +} // anonymous namespace + +TEST(MDArray, Unravel) { test_mdarray_unravel(); } } // namespace raft From 67dba8ef609048929a2ae5e05abf222e59d22cea Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 24 Jun 2022 09:24:44 +0800 Subject: [PATCH 2/8] Use std apply and expose it to public. --- cpp/include/raft/core/mdarray.hpp | 31 +++++++++++++++- cpp/include/raft/detail/mdarray.hpp | 57 +---------------------------- cpp/test/mdarray.cu | 46 +++++++++++------------ 3 files changed, 54 insertions(+), 80 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 0ab882e7a0..6313b136ef 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -914,4 +914,33 @@ auto reshape(const array_interface_type& mda, extents new_shape) return reshape(mda.view(), new_shape); } -} // namespace raft \ No newline at end of file +/** + * \brief Turns linear index into coordinate. Similar to numpy unravel_index. + * + * \code + * auto m = make_host_matrix(7, 6); + * auto m_v = m.view(); + * auto coord = detail::unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); + * std::apply(m_v, coord) = 2; + * \endcode + * + * \param idx The linear index. + * \param shape The shape of the array to use. + * \param layout Must be `layout_right` (row-major) in current implementation. + * + * \return A std::tuple that represents the coordinate. + */ +template +MDSPAN_INLINE_FUNCTION auto unravel_index(size_t idx, + detail::stdex::extents shape, + LayoutPolicy const&) +{ + static_assert(std::is_same::value, + "Only C layout is supported."); + if (idx > std::numeric_limits::max()) { + return detail::unravel_index_impl(static_cast(idx), shape); + } else { + return detail::unravel_index_impl(static_cast(idx), shape); + } +} +} // namespace raft diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index 99ab44a463..b83e101062 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -277,7 +277,7 @@ MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t template MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N], std::index_sequence) { - return thrust::make_tuple(arr[Idx]...); + return std::make_tuple(arr[Idx]...); } template @@ -310,59 +310,4 @@ MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents index[0] = idx; return arr_to_tup(index); } - -/** - * \brief Turns linear index into coordinate. Similar to numpy unravel_index. This is not - * exposed to public as it's not part of the mdspan proposal, the returned tuple - * can not be directly used for indexing into mdspan and we might change the return - * type in the future. - * - * \code - * auto m = make_host_matrix(7, 6); - * auto m_v = m.view(); - * auto coord = detail::unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); - * detail::apply(m_v, coord) = 2; - * \endcode - * - * \param idx The linear index. - * \param shape The shape of the array to use. - * \param layout Must be `layout_right` (row-major) in current implementation. - * - * \return A thrust::tuple that represents the coordinate. - */ -template -MDSPAN_INLINE_FUNCTION auto unravel_index(size_t idx, - detail::stdex::extents shape, - LayoutPolicy const&) -{ - static_assert(std::is_same::value, - "Only C layout is supported."); - if (idx > std::numeric_limits::max()) { - return unravel_index_impl(static_cast(idx), shape); - } else { - return unravel_index_impl(static_cast(idx), shape); - } -} - -template -MDSPAN_INLINE_FUNCTION auto constexpr apply_impl(Fn&& f, Tup&& t, std::index_sequence) - -> decltype(auto) -{ - return f(thrust::get(t)...); -} - -/** - * C++ 17 style apply for thrust tuple. - * - * \param f function to apply - * \param t tuple of arguments - */ -template >::value> -MDSPAN_INLINE_FUNCTION auto constexpr apply(Fn&& f, Tup&& t) -> decltype(auto) -{ - return apply_impl( - std::forward(f), std::forward(t), std::make_index_sequence{}); -} } // namespace raft::detail diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index 75add8a2a9..67a44ca5db 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -453,34 +453,34 @@ void test_mdarray_unravel() // examples in numpy unravel_index { - auto coord = detail::unravel_index(22, detail::matrix_extent{7, 6}, stdex::layout_right{}); - static_assert(thrust::tuple_size::value == 2); - ASSERT_EQ(thrust::get<0>(coord), 3); - ASSERT_EQ(thrust::get<1>(coord), 4); + auto coord = unravel_index(22, detail::matrix_extent{7, 6}, stdex::layout_right{}); + static_assert(std::tuple_size::value == 2); + ASSERT_EQ(std::get<0>(coord), 3); + ASSERT_EQ(std::get<1>(coord), 4); } { - auto coord = detail::unravel_index(41, detail::matrix_extent{7, 6}, stdex::layout_right{}); - static_assert(thrust::tuple_size::value == 2); - ASSERT_EQ(thrust::get<0>(coord), 6); - ASSERT_EQ(thrust::get<1>(coord), 5); + auto coord = unravel_index(41, detail::matrix_extent{7, 6}, stdex::layout_right{}); + static_assert(std::tuple_size::value == 2); + ASSERT_EQ(std::get<0>(coord), 6); + ASSERT_EQ(std::get<1>(coord), 5); } { - auto coord = detail::unravel_index(37, detail::matrix_extent{7, 6}, stdex::layout_right{}); - static_assert(thrust::tuple_size::value == 2); - ASSERT_EQ(thrust::get<0>(coord), 6); - ASSERT_EQ(thrust::get<1>(coord), 1); + auto coord = unravel_index(37, detail::matrix_extent{7, 6}, stdex::layout_right{}); + static_assert(std::tuple_size::value == 2); + ASSERT_EQ(std::get<0>(coord), 6); + ASSERT_EQ(std::get<1>(coord), 1); } // assignment { auto m = make_host_matrix(7, 6); auto m_v = m.view(); for (size_t i = 0; i < m.size(); ++i) { - auto coord = detail::unravel_index(i, m.extents(), typename decltype(m)::layout_type{}); - detail::apply(m_v, coord) = i; + auto coord = unravel_index(i, m.extents(), typename decltype(m)::layout_type{}); + std::apply(m_v, coord) = i; } for (size_t i = 0; i < m.size(); ++i) { - auto coord = detail::unravel_index(i, m.extents(), typename decltype(m)::layout_type{}); - ASSERT_EQ(detail::apply(m_v, coord), i); + auto coord = unravel_index(i, m.extents(), typename decltype(m)::layout_type{}); + ASSERT_EQ(std::apply(m_v, coord), i); } } @@ -492,19 +492,19 @@ void test_mdarray_unravel() thrust::make_counting_iterator(0ul), m_v.size(), [=] __device__(size_t i) { - auto coord = detail::unravel_index( - i, m_v.extents(), typename decltype(m_v)::layout_type{}); - detail::apply(m_v, coord) = static_cast(i); + auto coord = + unravel_index(i, m_v.extents(), typename decltype(m_v)::layout_type{}); + std::apply(m_v, coord) = static_cast(i); }); thrust::device_vector status(1, 0); auto p_status = status.data().get(); thrust::for_each_n(handle.get_thrust_policy(), thrust::make_counting_iterator(0ul), m_v.size(), - [=] __device__(size_t i) { - auto coord = detail::unravel_index( - i, m_v.extents(), typename decltype(m_v)::layout_type{}); - auto v = detail::apply(m_v, coord); + [=] HD(size_t i) { + auto coord = + unravel_index(i, m_v.extents(), typename decltype(m_v)::layout_type{}); + auto v = std::apply(m_v, coord); if (v != static_cast(i)) { raft::myAtomicAdd(p_status, 1); } }); check_status(p_status, handle.get_stream()); From 01cce85a9441e69e8603a905af7725bdbf7d45f0 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 24 Jun 2022 09:34:35 +0800 Subject: [PATCH 3/8] Remove header. --- cpp/include/raft/detail/mdarray.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index b83e101062..b8dcc39b8f 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -27,7 +27,6 @@ #include #include #include -#include namespace raft::detail { /** From d1cea43d03ef3909cf926818bf2c26a660245cfc Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 24 Jun 2022 09:35:08 +0800 Subject: [PATCH 4/8] cleanup. --- cpp/include/raft/detail/mdarray.hpp | 2 +- cpp/test/mdarray.cu | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index b8dcc39b8f..f8e59a68d2 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -59,7 +59,7 @@ class device_reference { auto operator=(T const& other) -> device_reference& { auto* raw = ptr_.get(); - raft::update_device(raw, &other, 1, stream_); + update_device(raw, &other, 1, stream_); return *this; } }; diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index 67a44ca5db..46ec3ba235 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -451,7 +451,7 @@ void test_mdarray_unravel() ASSERT_EQ(detail::popc(v), 64); } - // examples in numpy unravel_index + // examples from numpy unravel_index { auto coord = unravel_index(22, detail::matrix_extent{7, 6}, stdex::layout_right{}); static_assert(std::tuple_size::value == 2); @@ -491,7 +491,7 @@ void test_mdarray_unravel() thrust::for_each_n(handle.get_thrust_policy(), thrust::make_counting_iterator(0ul), m_v.size(), - [=] __device__(size_t i) { + [=] HD(size_t i) { auto coord = unravel_index(i, m_v.extents(), typename decltype(m_v)::layout_type{}); std::apply(m_v, coord) = static_cast(i); @@ -501,7 +501,7 @@ void test_mdarray_unravel() thrust::for_each_n(handle.get_thrust_policy(), thrust::make_counting_iterator(0ul), m_v.size(), - [=] HD(size_t i) { + [=] __device__(size_t i) { auto coord = unravel_index(i, m_v.extents(), typename decltype(m_v)::layout_type{}); auto v = std::apply(m_v, coord); From 7a4b48ec368c6592b8f17fd2eea0fd7eef545d05 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 24 Jun 2022 12:35:14 +0800 Subject: [PATCH 5/8] doxygen. --- cpp/include/raft/core/mdarray.hpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 6313b136ef..d1caf2498b 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -926,16 +926,17 @@ auto reshape(const array_interface_type& mda, extents new_shape) * * \param idx The linear index. * \param shape The shape of the array to use. - * \param layout Must be `layout_right` (row-major) in current implementation. + * \param layout Must be `layout_c_contiguous` (row-major) in current implementation. * * \return A std::tuple that represents the coordinate. */ template MDSPAN_INLINE_FUNCTION auto unravel_index(size_t idx, - detail::stdex::extents shape, - LayoutPolicy const&) + extents shape, + LayoutPolicy const& layout) { - static_assert(std::is_same::value, + static_assert(std::is_same_v>, + layout_c_contiguous>, "Only C layout is supported."); if (idx > std::numeric_limits::max()) { return detail::unravel_index_impl(static_cast(idx), shape); From 057baef3295a9cabd97cf754ca547205d65de1a7 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 24 Jun 2022 12:48:49 +0800 Subject: [PATCH 6/8] doxygen. --- cpp/include/raft/core/mdarray.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index d1caf2498b..71d30c060a 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -920,7 +920,7 @@ auto reshape(const array_interface_type& mda, extents new_shape) * \code * auto m = make_host_matrix(7, 6); * auto m_v = m.view(); - * auto coord = detail::unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); + * auto coord = unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); * std::apply(m_v, coord) = 2; * \endcode * From ec33b017b285f4d34df78c84c003ffa12165f79a Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 28 Jun 2022 03:52:24 +0800 Subject: [PATCH 7/8] Custom index type. --- cpp/include/raft/core/mdarray.hpp | 8 +++++--- cpp/include/raft/detail/mdarray.hpp | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 71d30c060a..31bd506181 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -930,15 +930,17 @@ auto reshape(const array_interface_type& mda, extents new_shape) * * \return A std::tuple that represents the coordinate. */ -template -MDSPAN_INLINE_FUNCTION auto unravel_index(size_t idx, +template +MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, extents shape, LayoutPolicy const& layout) { static_assert(std::is_same_v>, layout_c_contiguous>, "Only C layout is supported."); - if (idx > std::numeric_limits::max()) { + static_assert(std::is_integral_v, "Index must be integral."); + auto constexpr kIs64 = sizeof(std::remove_cv_t>) == sizeof(uint64_t); + if (kIs64 && idx > std::numeric_limits::max()) { return detail::unravel_index_impl(static_cast(idx), shape); } else { return detail::unravel_index_impl(static_cast(idx), shape); diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index f8e59a68d2..c4557245ae 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -292,7 +292,7 @@ template MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) { constexpr auto kRank = static_cast(shape.rank()); - size_t index[shape.rank()]{0}; // NOLINT + std::size_t index[shape.rank()]{0}; // NOLINT static_assert(std::is_signed::value, "Don't change the type without changing the for loop."); for (int32_t dim = kRank; --dim > 0;) { From a766f347faa753674cf6ff31670c0318aff8ef71 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 28 Jun 2022 04:28:03 +0800 Subject: [PATCH 8/8] Signedness. --- cpp/include/raft/core/mdarray.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 31bd506181..4465de21e7 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -940,7 +940,7 @@ MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, "Only C layout is supported."); static_assert(std::is_integral_v, "Index must be integral."); auto constexpr kIs64 = sizeof(std::remove_cv_t>) == sizeof(uint64_t); - if (kIs64 && idx > std::numeric_limits::max()) { + if (kIs64 && static_cast(idx) > std::numeric_limits::max()) { return detail::unravel_index_impl(static_cast(idx), shape); } else { return detail::unravel_index_impl(static_cast(idx), shape);