Skip to content

Commit

Permalink
Use dynamic_extent from stdex. (#523)
Browse files Browse the repository at this point in the history
Close #478 .

Authors:
  - Jiaming Yuan (https://github.com/trivialfis)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #523
  • Loading branch information
trivialfis authored Feb 22, 2022
1 parent badaee0 commit e6d148b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
5 changes: 3 additions & 2 deletions cpp/include/raft/detail/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#pragma once
#include <experimental/mdspan>
#include <raft/detail/span.hpp> // dynamic_extent
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <thrust/device_ptr.h>
Expand Down Expand Up @@ -234,7 +235,7 @@ using device_accessor = accessor_mixin<AccessorPolicy, false>;

namespace stdex = std::experimental;

using vector_extent = stdex::extents<stdex::dynamic_extent>;
using matrix_extent = stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>;
using vector_extent = stdex::extents<dynamic_extent>;
using matrix_extent = stdex::extents<dynamic_extent, dynamic_extent>;
using scalar_extent = stdex::extents<1>;
} // namespace raft::detail
3 changes: 2 additions & 1 deletion cpp/include/raft/detail/span.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
*/
#pragma once

#include <experimental/mdspan>
#include <limits> // numeric_limits
#include <thrust/host_vector.h> // __host__ __device__
#include <type_traits>

namespace raft {
constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max();
constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent;

template <class ElementType, bool is_device, std::size_t Extent>
class span;
Expand Down
8 changes: 4 additions & 4 deletions cpp/test/mdarray.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void test_mdspan()
auto stream = rmm::cuda_stream_default;
rmm::device_uvector<float> a{16ul, stream};
thrust::sequence(rmm::exec_policy(stream), a.begin(), a.end());
stdex::mdspan<float, stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>> span{
stdex::mdspan<float, stdex::extents<raft::dynamic_extent, raft::dynamic_extent>> span{
a.data(), 4, 4};
thrust::device_vector<int32_t> status(1, 0);
auto p_status = status.data().get();
Expand Down Expand Up @@ -74,7 +74,7 @@ TEST(MDArray, Policy) { test_uvector_policy(); }

void test_mdarray_basic()
{
using matrix_extent = stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>;
using matrix_extent = stdex::extents<dynamic_extent, dynamic_extent>;
auto s = rmm::cuda_stream_default;
{
/**
Expand Down Expand Up @@ -180,7 +180,7 @@ TEST(MDArray, Basic) { test_mdarray_basic(); }
template <typename BasicMDarray, typename PolicyFn, typename ThrustPolicy>
void test_mdarray_copy_move(ThrustPolicy exec, PolicyFn make_policy)
{
using matrix_extent = stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>;
using matrix_extent = stdex::extents<dynamic_extent, dynamic_extent>;
layout_c_contiguous::mapping<matrix_extent> layout{matrix_extent{4, 4}};

using mdarray_t = BasicMDarray;
Expand Down Expand Up @@ -251,7 +251,7 @@ void test_mdarray_copy_move(ThrustPolicy exec, PolicyFn make_policy)

TEST(MDArray, CopyMove)
{
using matrix_extent = stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>;
using matrix_extent = stdex::extents<dynamic_extent, dynamic_extent>;
using d_matrix_t = device_mdarray<float, matrix_extent>;
using policy_t = typename d_matrix_t::container_policy_type;
auto s = rmm::cuda_stream_default;
Expand Down

0 comments on commit e6d148b

Please sign in to comment.