diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp
index 0ab882e7a0..4465de21e7 100644
--- a/cpp/include/raft/core/mdarray.hpp
+++ b/cpp/include/raft/core/mdarray.hpp
@@ -914,4 +914,36 @@ auto reshape(const array_interface_type& mda, extents<Extents...> new_shape)
   return reshape(mda.view(), new_shape);
 }
 
-}  // namespace raft
\ No newline at end of file
+/**
+ * \brief Turns a linear index into a coordinate.  Similar to numpy's unravel_index.
+ *
+ * \code
+ *   auto m = make_host_matrix<float>(7, 6);
+ *   auto m_v = m.view();
+ *   auto coord = unravel_index(2, m.extents(), typename decltype(m)::layout_type{});
+ *   std::apply(m_v, coord) = 2;
+ * \endcode
+ *
+ * \param idx    The linear index.
+ * \param shape  The shape of the array to use.
+ * \param layout Must be `layout_c_contiguous` (row-major) in the current implementation.
+ *
+ * \return A std::tuple that represents the coordinate.
+ */
+template <typename Idx, size_t... Exts, typename LayoutPolicy>
+MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx,
+                                          extents<Exts...> shape,
+                                          LayoutPolicy const& layout)
+{
+  static_assert(std::is_same_v<std::remove_cv_t<std::remove_reference_t<decltype(layout)>>,
+                               layout_c_contiguous>,
+                "Only C layout is supported.");
+  static_assert(std::is_integral_v<Idx>, "Index must be integral.");
+  auto constexpr kIs64 = sizeof(std::remove_cv_t<std::remove_reference_t<Idx>>) == sizeof(uint64_t);
+  if (kIs64 && static_cast<uint64_t>(idx) > std::numeric_limits<uint32_t>::max()) {
+    return detail::unravel_index_impl<uint64_t>(static_cast<uint64_t>(idx), shape);
+  } else {
+    return detail::unravel_index_impl<uint32_t>(static_cast<uint32_t>(idx), shape);
+  }
+}
+}  // namespace raft
diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp
index cb6f8a0920..c4557245ae 100644
--- a/cpp/include/raft/detail/mdarray.hpp
+++ b/cpp/include/raft/detail/mdarray.hpp
@@ -59,7 +59,7 @@ class device_reference {
   auto operator=(T const& other) -> device_reference&
   {
     auto* raw = ptr_.get();
-    raft::update_device(raw, &other, 1, stream_);
+    update_device(raw, &other, 1, stream_);
     return *this;
   }
 };
@@ -240,4 +240,73 @@ namespace stdex = std::experimental;
 using vector_extent = stdex::extents<dynamic_extent>;
 using matrix_extent = stdex::extents<dynamic_extent, dynamic_extent>;
 using scalar_extent = stdex::extents<1>;
+
+template <typename T>
+MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t
+{
+  int c = 0;
+  for (; v != 0; v &= v - 1) {
+    c++;
+  }
+  return c;
+}
+
+MDSPAN_INLINE_FUNCTION auto popc(uint32_t v) -> int32_t
+{
+#if defined(__CUDA_ARCH__)
+  return __popc(v);
+#elif defined(__GNUC__) || defined(__clang__)
+  return __builtin_popcount(v);
+#else
+  return native_popc(v);
+#endif  // compiler
+}
+
+MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t
+{
+#if defined(__CUDA_ARCH__)
+  return __popcll(v);
+#elif defined(__GNUC__) || defined(__clang__)
+  return __builtin_popcountll(v);
+#else
+  return native_popc(v);
+#endif  // compiler
+}
+
+template <class T, std::size_t N, std::size_t... Idx>
+MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N], std::index_sequence<Idx...>)
+{
+  return std::make_tuple(arr[Idx]...);
+}
+
+template <class T, std::size_t N>
+MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N])
+{
+  return arr_to_tup(arr, std::make_index_sequence<N>{});
+}
+
+// uint division optimization inspired by the CIndexer in cupy.  Division is slow on both
+// CPU and GPU, especially for 64 bit integers, so we first try to avoid 64 bit arithmetic
+// when the index is small enough, then avoid division entirely when the extent is a power of 2.
+template <typename I, std::size_t... Extents>
+MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents<Extents...> shape)
+{
+  constexpr auto kRank = static_cast<int32_t>(shape.rank());
+  std::size_t index[shape.rank()]{0};  // NOLINT
+  static_assert(std::is_signed<decltype(kRank)>::value,
+                "Don't change the type without changing the for loop.");
+  for (int32_t dim = kRank; --dim > 0;) {
+    auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(shape.extent(dim));
+    if (s & (s - 1)) {
+      auto t     = idx / s;
+      index[dim] = idx - t * s;
+      idx        = t;
+    } else {  // s is a power of 2
+      index[dim] = idx & (s - 1);
+      idx >>= popc(s - 1);
+    }
+  }
+  index[0] = idx;
+  return arr_to_tup(index);
+}
 }  // namespace raft::detail
diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu
index 987f2dcf2e..46ec3ba235 100644
--- a/cpp/test/mdarray.cu
+++ b/cpp/test/mdarray.cu
@@ -421,4 +421,96 @@ TEST(MDArray, FuncArg)
     std::is_same_v::accessor_type>);
 }
 }
+
+namespace {
+void test_mdarray_unravel()
+{
+  {
+    uint32_t v{0};
+    ASSERT_EQ(detail::native_popc(v), 0);
+    ASSERT_EQ(detail::popc(v), 0);
+    v = 1;
+    ASSERT_EQ(detail::native_popc(v), 1);
+    ASSERT_EQ(detail::popc(v), 1);
+    v = 0xffffffff;
+    ASSERT_EQ(detail::native_popc(v), 32);
+    ASSERT_EQ(detail::popc(v), 32);
+  }
+  {
+    uint64_t v{0};
+    ASSERT_EQ(detail::native_popc(v), 0);
+    ASSERT_EQ(detail::popc(v), 0);
+    v = 1;
+    ASSERT_EQ(detail::native_popc(v), 1);
+    ASSERT_EQ(detail::popc(v), 1);
+    v = 0xffffffff;
+    ASSERT_EQ(detail::native_popc(v), 32);
+    ASSERT_EQ(detail::popc(v), 32);
+    v = 0xffffffffffffffff;
+    ASSERT_EQ(detail::native_popc(v), 64);
+    ASSERT_EQ(detail::popc(v), 64);
+  }
+
+  // examples from numpy unravel_index
+  {
+    auto coord = unravel_index(22, detail::matrix_extent{7, 6}, stdex::layout_right{});
+    static_assert(std::tuple_size<decltype(coord)>::value == 2);
+    ASSERT_EQ(std::get<0>(coord), 3);
+    ASSERT_EQ(std::get<1>(coord), 4);
+  }
+  {
+    auto coord = unravel_index(41, detail::matrix_extent{7, 6}, stdex::layout_right{});
+    static_assert(std::tuple_size<decltype(coord)>::value == 2);
+    ASSERT_EQ(std::get<0>(coord), 6);
+    ASSERT_EQ(std::get<1>(coord), 5);
+  }
+  {
+    auto coord = unravel_index(37, detail::matrix_extent{7, 6}, stdex::layout_right{});
+    static_assert(std::tuple_size<decltype(coord)>::value == 2);
+    ASSERT_EQ(std::get<0>(coord), 6);
+    ASSERT_EQ(std::get<1>(coord), 1);
+  }
+  // assignment
+  {
+    auto m   = make_host_matrix<float>(7, 6);
+    auto m_v = m.view();
+    for (size_t i = 0; i < m.size(); ++i) {
+      auto coord             = unravel_index(i, m.extents(), typename decltype(m)::layout_type{});
+      std::apply(m_v, coord) = i;
+    }
+    for (size_t i = 0; i < m.size(); ++i) {
+      auto coord = unravel_index(i, m.extents(), typename decltype(m)::layout_type{});
+      ASSERT_EQ(std::apply(m_v, coord), i);
+    }
+  }
+
+  {
+    handle_t handle;
+    auto m   = make_device_matrix<float>(handle, 7, 6);
+    auto m_v = m.view();
+    thrust::for_each_n(handle.get_thrust_policy(),
+                       thrust::make_counting_iterator(0ul),
+                       m_v.size(),
+                       [=] HD(size_t i) {
+                         auto coord =
+                           unravel_index(i, m_v.extents(), typename decltype(m_v)::layout_type{});
+                         std::apply(m_v, coord) = static_cast<float>(i);
+                       });
+    thrust::device_vector<int32_t> status(1, 0);
+    auto p_status = status.data().get();
+    thrust::for_each_n(handle.get_thrust_policy(),
+                       thrust::make_counting_iterator(0ul),
+                       m_v.size(),
+                       [=] __device__(size_t i) {
+                         auto coord =
+                           unravel_index(i, m_v.extents(), typename decltype(m_v)::layout_type{});
+                         auto v = std::apply(m_v, coord);
+                         if (v != static_cast<float>(i)) { raft::myAtomicAdd(p_status, 1); }
+                       });
+    check_status(p_status, handle.get_stream());
+  }
+}
+}  // anonymous namespace
+
+TEST(MDArray, Unravel) { test_mdarray_unravel(); }
 }  // namespace raft
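For reference, here is a standalone sketch (not part of the diff) of the same row-major unravel logic that detail::unravel_index_impl implements: walk the extents from the innermost dimension outwards, spend one integer division per non-power-of-two extent, and fall back to mask-and-shift when the extent is a power of two. It is plain C++20 with std::popcount standing in for the popc helpers above; the names unravel_rm and kShape are made up for this illustration only.

#include <array>
#include <bit>
#include <cassert>
#include <cstdint>

// Row-major (C-contiguous) unravel: dimensions are processed from the last
// (fastest-varying) one back to dimension 1; whatever remains is dimension 0.
template <std::size_t Rank>
constexpr auto unravel_rm(std::uint64_t idx, std::array<std::uint64_t, Rank> shape)
{
  std::array<std::uint64_t, Rank> coord{};
  for (std::size_t dim = Rank; dim-- > 1;) {
    auto s = shape[dim];
    if (s & (s - 1)) {           // extent is not a power of two: one division
      auto t     = idx / s;
      coord[dim] = idx - t * s;  // remainder without a second division
      idx        = t;
    } else {                     // power of two: mask and shift instead of dividing
      coord[dim] = idx & (s - 1);
      idx >>= std::popcount(s - 1);
    }
  }
  coord[0] = idx;
  return coord;
}

int main()
{
  // Same case as the numpy example used in the test: linear index 22 in a 7x6 matrix.
  constexpr std::array<std::uint64_t, 2> kShape{7, 6};
  auto c = unravel_rm(22, kShape);
  assert(c[0] == 3 && c[1] == 4);
}

The shape here is a std::array rather than an mdspan extents object, so the sketch only shows the arithmetic, not the 32/64-bit dispatch that the public unravel_index wrapper performs before calling into the detail implementation.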