diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index 624c7a4d07..a507868faa 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -224,6 +224,7 @@ struct accessor_mixin : public AccessorPolicy { using is_host_type = std::conditional_t; // make sure the explicit ctor can fall through using AccessorPolicy::AccessorPolicy; + using offset_policy = accessor_mixin; accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT }; diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index 961a703a8b..855d731d07 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -409,12 +409,14 @@ TEST(MDArray, FuncArg) make_device_matrix(10, 10, rmm::cuda_stream_default); check_matrix_layout(d_matrix.view()); - // FIXME(jiamingy): The slice has a default accessor instead of accessor_mixin, due to - // the hardcoded policy in submdspan implementation. We need to have a rewritten - // version of submdspan for implementing padding. - // auto slice = - // stdex::submdspan(d_matrix.view(), std::make_tuple(2ul, 4ul), std::make_tuple(2ul, 5ul)); - // check_matrix_layout(slice); + auto slice = + stdex::submdspan(d_matrix.view(), std::make_tuple(2ul, 4ul), std::make_tuple(2ul, 5ul)); + static_assert(slice.is_strided()); + ASSERT_EQ(slice.extent(0), 2); + ASSERT_EQ(slice.extent(1), 3); + // is using device_accessor mixin. + static_assert( + std::is_same_v::accessor_type>); } } } // namespace raft