Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separating mdspan/mdarray infra into host_* and device_* variants #810

Merged
merged 13 commits into from
Sep 22, 2022
3 changes: 2 additions & 1 deletion cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
#include <raft/cluster/detail/kmeans_common.cuh>
#include <raft/cluster/kmeans_params.hpp>
#include <raft/core/cudart_utils.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/cuda_utils.cuh>
#include <raft/distance/distance_type.hpp>
#include <raft/linalg/map_then_reduce.cuh>
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@

#include <raft/cluster/kmeans_params.hpp>
#include <raft/core/cudart_utils.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/cuda_utils.cuh>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_type.hpp>
Expand Down
39 changes: 39 additions & 0 deletions cpp/include/raft/core/detail/accessor_mixin.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace raft::detail {

/**
* @brief A mixin to distinguish host and device memory.
*/
template <typename AccessorPolicy, bool is_host, bool is_device>
struct accessor_mixin : public AccessorPolicy {
using accessor_type = AccessorPolicy;
using is_host_type = std::conditional_t<is_host, std::true_type, std::false_type>;
using is_device_type = std::conditional_t<is_device, std::true_type, std::false_type>;
using is_managed_type = std::conditional_t<is_device && is_host, std::true_type, std::false_type>;
static constexpr bool is_host_accessible = is_host;
static constexpr bool is_device_accessible = is_device;
static constexpr bool is_managed_accessible = is_device && is_host;
// make sure the explicit ctor can fall through
using AccessorPolicy::AccessorPolicy;
using offset_policy = accessor_mixin;
accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT
};

} // namespace raft::detail
68 changes: 68 additions & 0 deletions cpp/include/raft/core/detail/mdspan_util.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/thirdparty/mdspan/include/experimental/mdspan>
#include <tuple>
#include <utility>

namespace raft::detail {

template <class T, std::size_t N, std::size_t... Idx>
MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N], std::index_sequence<Idx...>)
{
return std::make_tuple(arr[Idx]...);
}

template <class T, std::size_t N>
MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N])
{
return arr_to_tup(arr, std::make_index_sequence<N>{});
}

template <typename T>
MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t
{
int c = 0;
for (; v != 0; v &= v - 1) {
c++;
}
return c;
}

MDSPAN_INLINE_FUNCTION auto popc(uint32_t v) -> int32_t
{
#if defined(__CUDA_ARCH__)
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
return __popc(v);
#elif defined(__GNUC__) || defined(__clang__)
return __builtin_popcount(v);
#else
return native_popc(v);
#endif // compiler
}

MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t
{
#if defined(__CUDA_ARCH__)
return __popcll(v);
#elif defined(__GNUC__) || defined(__clang__)
return __builtin_popcountll(v);
#else
return native_popc(v);
#endif // compiler
}

} // end namespace raft::detail
175 changes: 175 additions & 0 deletions cpp/include/raft/core/device_mdarray.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/mdarray.hpp>

namespace raft {

/**
* @brief mdarray with device container policy
* @tparam ElementType the data type of the elements
* @tparam Extents defines the shape
* @tparam LayoutPolicy policy for indexing strides and layout ordering
* @tparam ContainerPolicy storage and accessor policy
*/
template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename ContainerPolicy = detail::device_uvector_policy<ElementType>>
using device_mdarray =
mdarray<ElementType, Extents, LayoutPolicy, device_accessor<ContainerPolicy>>;

/**
* @brief Shorthand for 0-dim host mdarray (scalar).
* @tparam ElementType the data type of the scalar element
* @tparam IndexType the index type of the extents
*/
template <typename ElementType, typename IndexType = std::uint32_t>
using device_scalar = device_mdarray<ElementType, scalar_extent<IndexType>>;

/**
* @brief Shorthand for 1-dim device mdarray.
* @tparam ElementType the data type of the vector elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
using device_vector = device_mdarray<ElementType, vector_extent<IndexType>, LayoutPolicy>;

/**
* @brief Shorthand for c-contiguous device matrix.
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
using device_matrix = device_mdarray<ElementType, matrix_extent<IndexType>, LayoutPolicy>;

/**
* @brief Create a device mdarray.
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param handle raft::handle_t
* @param exts dimensionality of the array (series of integers)
* @return raft::device_mdarray
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_device_mdarray(const raft::handle_t& handle, extents<IndexType, Extents...> exts)
{
using mdarray_t = device_mdarray<ElementType, decltype(exts), LayoutPolicy>;

typename mdarray_t::mapping_type layout{exts};
typename mdarray_t::container_policy_type policy{handle.get_stream()};

return mdarray_t{layout, policy};
}

/**
* @brief Create a device mdarray.
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param handle raft::handle_t
* @param mr rmm memory resource used for allocating the memory for the array
* @param exts dimensionality of the array (series of integers)
* @return raft::device_mdarray
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_device_mdarray(const raft::handle_t& handle,
rmm::mr::device_memory_resource* mr,
extents<IndexType, Extents...> exts)
{
using mdarray_t = device_mdarray<ElementType, decltype(exts), LayoutPolicy>;

typename mdarray_t::mapping_type layout{exts};
typename mdarray_t::container_policy_type policy{handle.get_stream(), mr};

return mdarray_t{layout, policy};
}

/**
* @brief Create a 2-dim c-contiguous device mdarray.
*
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] handle raft handle for managing expensive resources
* @param[in] n_rows number or rows in matrix
* @param[in] n_cols number of columns in matrix
* @return raft::device_matrix
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexType n_cols)
{
return make_device_mdarray<ElementType, IndexType, LayoutPolicy>(
handle.get_stream(), make_extents<IndexType>(n_rows, n_cols));
}

/**
* @brief Create a device scalar from v.
*
* @tparam ElementType the data type of the scalar element
* @tparam IndexType the index type of the extents
* @param[in] handle raft handle for managing expensive cuda resources
* @param[in] v scalar to wrap on device
* @return raft::device_scalar
*/
template <typename ElementType, typename IndexType = std::uint32_t>
auto make_device_scalar(raft::handle_t const& handle, ElementType const& v)
{
scalar_extent<IndexType> extents;
using policy_t = typename device_scalar<ElementType>::container_policy_type;
policy_t policy{handle.get_stream()};
auto scalar = device_scalar<ElementType>{extents, policy};
scalar(0) = v;
return scalar;
}

/**
* @brief Create a 1-dim device mdarray.
* @tparam ElementType the data type of the vector elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] handle raft handle for managing expensive cuda resources
* @param[in] n number of elements in vector
* @return raft::device_vector
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector(raft::handle_t const& handle, IndexType n)
{
return make_device_mdarray<ElementType, IndexType, LayoutPolicy>(handle.get_stream(),
make_extents<IndexType>(n));
}

} // end namespace raft
Loading