Merge pull request #1545 from rapidsai/branch-23.06
Forward-merge branch-23.06 to branch-23.08
GPUtester authored May 23, 2023
2 parents 254b540 + 42c9c18 commit 99655ea
Showing 72 changed files with 2,053 additions and 959 deletions.
7 changes: 4 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
@@ -181,9 +181,10 @@ struct ivf_pq_knn {
   {
     search_params.n_probes = 20;
     auto queries_view =
-      raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
-    auto idxs_view  = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
-    auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
+      raft::make_device_matrix_view<const ValT, uint32_t>(search_items, ps.n_queries, ps.n_dims);
+    auto idxs_view  = raft::make_device_matrix_view<IdxT, uint32_t>(out_idxs, ps.n_queries, ps.k);
+    auto dists_view =
+      raft::make_device_matrix_view<dist_t, uint32_t>(out_dists, ps.n_queries, ps.k);
     raft::neighbors::ivf_pq::search(
       handle, search_params, *index, queries_view, idxs_view, dists_view);
   }
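The hunk above switches the benchmark's view index type from IdxT to uint32_t, matching the 32-bit matrix extents that raft::neighbors::ivf_pq::search expects even when neighbor ids are 64-bit. A minimal sketch of that call pattern (hedged: buffer names and the wrapper function here are illustrative, not taken from the benchmark):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_pq.cuh>

// Sketch: ivf_pq::search with uint32_t-indexed device matrix views.
// d_queries/d_idxs/d_dists are assumed to be valid device allocations.
template <typename IdxT>
void search_sketch(raft::device_resources const& handle,
                   raft::neighbors::ivf_pq::index<IdxT> const& index,
                   const float* d_queries, IdxT* d_idxs, float* d_dists,
                   uint32_t n_queries, uint32_t n_dims, uint32_t k)
{
  raft::neighbors::ivf_pq::search_params sp;
  sp.n_probes = 20;
  // Note: the views use uint32_t extents, as in the diff above.
  auto queries = raft::make_device_matrix_view<const float, uint32_t>(d_queries, n_queries, n_dims);
  auto idxs    = raft::make_device_matrix_view<IdxT, uint32_t>(d_idxs, n_queries, k);
  auto dists   = raft::make_device_matrix_view<float, uint32_t>(d_dists, n_queries, k);
  raft::neighbors::ivf_pq::search(handle, sp, index, queries, idxs, dists);
}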
10 changes: 9 additions & 1 deletion cpp/include/raft/core/device_span.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,10 +20,18 @@
 
 namespace raft {
 
+/**
+ * @defgroup device_span one-dimensional device span type
+ * @{
+ */
+
 /**
  * @brief A span class for device pointer.
  */
 template <typename T, size_t extent = std::experimental::dynamic_extent>
 using device_span = span<T, true, extent>;
 
+/**
+ * @}
+ */
 } // end namespace raft
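device_span itself is a thin alias: a non-owning view with the is_device flag set. A small hedged sketch of wrapping an existing device allocation (assuming the (pointer, count) constructor that raft::span shares with std::span):

#include <raft/core/device_span.hpp>
#include <cstddef>

// Sketch: a non-owning device_span over n floats at d_data.
// Element access through the span is only valid in device code.
inline raft::device_span<float> view_device_buffer(float* d_data, std::size_t n)
{
  return raft::device_span<float>(d_data, n);
}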
12 changes: 12 additions & 0 deletions cpp/include/raft/core/host_mdarray.hpp
@@ -67,6 +67,11 @@ template <typename ElementType,
           typename LayoutPolicy = layout_c_contiguous>
 using host_matrix = host_mdarray<ElementType, matrix_extent<IndexType>, LayoutPolicy>;
 
+/**
+ * @defgroup host_mdarray_factories factories to create host mdarrays
+ * @{
+ */
+
 /**
  * @brief Create a host mdarray.
  * @tparam ElementType the data type of the matrix elements
@@ -90,6 +95,10 @@ auto make_host_mdarray(raft::resources& res, extents<IndexType, Extents...> exts
   return mdarray_t{res, layout, policy};
 }
 
+/**
+ * @}
+ */
+
 /**
  * @brief Create a host mdarray.
  * @tparam ElementType the data type of the matrix elements
@@ -117,6 +126,7 @@ auto make_host_mdarray(extents<IndexType, Extents...> exts)
 }
 
 /**
+ * @ingroup host_mdarray_factories
  * @brief Create a 2-dim c-contiguous host mdarray.
  * @tparam ElementType the data type of the matrix elements
  * @tparam IndexType the index type of the extents
@@ -157,6 +167,7 @@ auto make_host_matrix(IndexType n_rows, IndexType n_cols)
 }
 
 /**
+ * @ingroup host_mdarray_factories
  * @brief Create a host scalar from v.
  *
  * @tparam ElementType the data type of the scalar element
@@ -206,6 +217,7 @@ auto make_host_scalar(ElementType const& v)
 }
 
 /**
+ * @ingroup host_mdarray_factories
  * @brief Create a 1-dim host mdarray.
  * @tparam ElementType the data type of the vector elements
  * @tparam IndexType the index type of the extents
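The pattern in this file is standard Doxygen grouping: @defgroup ... @{ ... @} collects everything declared between the braces into a named group, while @ingroup attaches a single symbol to a group declared elsewhere. Schematically (hypothetical names, for illustration only):

/**
 * @defgroup my_factories factories to create widgets
 * @{
 */
void make_widget_a();  // member of my_factories by position
void make_widget_b();  // likewise
/** @} */

/**
 * @ingroup my_factories
 * @brief Declared outside the braces, but still listed under my_factories.
 */
void make_widget_c();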
12 changes: 11 additions & 1 deletion cpp/include/raft/core/host_span.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,10 +19,20 @@
 #include <raft/core/span.hpp>
 
 namespace raft {
+
+/**
+ * @defgroup host_span one-dimensional host span type
+ * @{
+ */
+
 /**
  * @brief A span class for host pointer.
  */
 template <typename T, size_t extent = std::experimental::dynamic_extent>
 using host_span = span<T, false, extent>;
 
+/**
+ * @}
+ */
+
 } // end namespace raft
9 changes: 9 additions & 0 deletions cpp/include/raft/core/interruptible.hpp
@@ -30,6 +30,11 @@
 
 namespace raft {
 
+/**
+ * @defgroup interruptible definitions and classes related to the interruptible API
+ * @{
+ */
+
 /**
  * @brief Exception thrown during `interruptible::synchronize` call when it detects a request
  * to cancel the work performed in this CPU thread.
@@ -297,6 +302,10 @@ class interruptible {
   }
 };
 
+/**
+ * @}
+ */
+
 } // namespace raft
 
 #endif
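As a rough usage sketch of the API being grouped here (hedged: the method and exception names below follow the class documented in this file, but their exact signatures are not visible in this truncated hunk and are assumed):

#include <raft/core/interruptible.hpp>
#include <rmm/cuda_stream_view.hpp>

// Sketch: an interruptible synchronization point. Another CPU thread may
// cancel this thread's work via raft::interruptible::cancel(...).
void wait_for_stream(rmm::cuda_stream_view stream)
{
  try {
    // Like cudaStreamSynchronize, but polls for a cancellation request
    // against this CPU thread and throws if one arrives.
    raft::interruptible::synchronize(stream);
  } catch (raft::interrupted_exception const&) {
    // the work of this CPU thread was cancelled from outside
  }
}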
20 changes: 20 additions & 0 deletions cpp/include/raft/core/mdarray.hpp
@@ -32,6 +32,12 @@
 #include <raft/core/resources.hpp>
 
 namespace raft {
+
+/**
+ * @defgroup mdarray multi-dimensional memory-owning type
+ * @{
+ */
+
 /**
  * @brief Interface to implement an owning multi-dimensional array
  *
@@ -207,6 +213,7 @@ class mdarray
     : cp_(cp), map_(m), c_(cp_.create(handle, map_.required_span_size()))
   {
   }
+
   RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(raft::resources const& handle,
                                       mapping_type const& m,
                                       container_policy_type& cp)
@@ -336,6 +343,15 @@ class mdarray
   container_type c_;
 };
 
+/**
+ * @}
+ */
+
+/**
+ * @defgroup mdarray_reshape Reshaping operations on mdarray
+ * @{
+ */
+
 /**
  * @brief Flatten object implementing raft::array_interface into a 1-dim array view
  *
@@ -371,4 +387,8 @@ auto reshape(const array_interface_type& mda, extents<IndexType, Extents...> new
   return reshape(mda.view(), new_shape);
 }
 
+/**
+ * @}
+ */
+
 } // namespace raft
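A hedged sketch of the reshape helper documented above (assumptions: the make_host_vector factory and the make_extents helper shown elsewhere in this commit; no data is copied, only the view's extents change):

#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan.hpp>

// Sketch: view a 6-element host vector as a 2 x 3 row-major matrix.
void reshape_sketch(raft::resources const& res)
{
  auto vec = raft::make_host_vector<float, int>(res, 6);
  auto mat = raft::reshape(vec, raft::make_extents<int>(2, 3));
  // mat(1, 2) aliases vec(5) under the row-major mapping.
}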
9 changes: 9 additions & 0 deletions cpp/include/raft/core/mdspan.hpp
@@ -213,6 +213,11 @@ constexpr auto make_extents(Extents... exts)
   return extents<IndexType, ((void)exts, dynamic_extent)...>{exts...};
 }
 
+/**
+ * @defgroup mdspan_reshape Reshaping operations on mdspan
+ * @{
+ */
+
 /**
  * @brief Flatten raft::mdspan into a 1-dim array view
  *
@@ -298,6 +303,10 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,
   }
 }
 
+/**
+ * @}
+ */
+
 /**
  * @brief Const accessor specialization for default_accessor
  *
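And a matching sketch for the mdspan-level flatten documented above (hedged: the overload's exact template parameters are cut off in this hunk):

#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan.hpp>

// Sketch: collapse a 2-D view into a 1-D view over the same memory.
void flatten_sketch(raft::resources const& res)
{
  auto mat  = raft::make_host_matrix<float, int>(res, 2, 3);
  auto flat = raft::flatten(mat.view());  // one dimension, 6 elements
  static_assert(decltype(flat)::extents_type::rank() == 1);
}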
9 changes: 9 additions & 0 deletions cpp/include/raft/core/span.hpp
@@ -32,6 +32,11 @@
 #include <type_traits>
 
 namespace raft {
+
+/**
+ * @defgroup span one-dimensional span type
+ * @{
+ */
 /**
  * @brief The span class defined in ISO C++20. Iterator is defined as plain pointer and
  * most of the methods have bound check on debug build.
@@ -274,4 +279,8 @@ auto as_writable_bytes(span<T, is_device, E> s) noexcept
 {
   return {reinterpret_cast<std::byte*>(s.data()), s.size_bytes()};
 }
+
+/**
+ * @}
+ */
 } // namespace raft
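as_writable_bytes, visible in the hunk above, reinterprets a span's contents as std::byte. A small hedged sketch, using a host-side span for simplicity:

#include <raft/core/span.hpp>
#include <cstring>

// Sketch: zero a host-side raft::span through its byte view.
void zero_fill(raft::span<float, false> vals)
{
  auto bytes = raft::as_writable_bytes(vals);        // span of std::byte
  std::memset(bytes.data(), 0, bytes.size_bytes());  // same count as size() here
}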
7 changes: 4 additions & 3 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -325,12 +325,13 @@ void fusedL2NNImpl(OutT* min,
                    decltype(distance_op),
                    decltype(fin_op)>;
 
-  // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the
-  // current system. Other methods to determine the architecture (that do not
+  // Get pointer to fp32 SIMT kernel to determine the best compute architecture,
+  // out of all those the kernel was compiled for, that most closely matches the
+  // current device. Other methods to determine the architecture (that do not
   // require a pointer) can be error prone. See:
   // https://github.com/NVIDIA/cub/issues/545
   void* kernel_ptr   = reinterpret_cast<void*>(kernel);
-  auto runtime_arch  = arch::kernel_runtime_arch(kernel_ptr);
+  auto runtime_arch  = arch::kernel_virtual_arch(kernel_ptr);
   auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
 
   if (cutlass_range.contains(runtime_arch)) {
@@ -108,13 +108,14 @@ void pairwise_matrix_dispatch(OpT distance_op,
   auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
   auto legacy_range  = arch::SM_range(arch::SM_min(), arch::SM_80());
 
-  // Get pointer to SM60 kernel to determine the runtime architecture of the
-  // current system. Other methods to determine the architecture (that do not
+  // Get pointer to SM60 kernel to determine the best compute architecture,
+  // out of all those the kernel was compiled for, that most closely matches the
+  // current device. Other methods to determine the architecture (that do not
   // require a pointer) can be error prone. See:
   // https://github.com/NVIDIA/cub/issues/545
   auto sm60_wrapper  = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range);
   void* kernel_ptr   = reinterpret_cast<void*>(sm60_wrapper.kernel_ptr);
-  auto runtime_arch  = arch::kernel_runtime_arch(kernel_ptr);
+  auto runtime_arch  = arch::kernel_virtual_arch(kernel_ptr);
 
   if (cutlass_range.contains(runtime_arch)) {
     // If device is SM_80 or later, use CUTLASS-based kernel.
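Both hunks above rename kernel_runtime_arch to kernel_virtual_arch, reflecting that the query returns the virtual (PTX) architecture of the kernel binary the runtime selects for the current device, not the raw device architecture. A hedged sketch of how such a query can be built on cudaFuncGetAttributes (an illustration, not RAFT's actual implementation):

#include <cuda_runtime_api.h>

// Sketch: query the virtual (PTX) architecture of the kernel binary that
// the CUDA runtime would pick for the current device, via a kernel pointer.
// ptxVersion is reported as major * 10 + minor, e.g. 80 for compute_80.
inline int kernel_virtual_arch_sketch(void* kernel_ptr)
{
  cudaFuncAttributes attr{};
  if (cudaFuncGetAttributes(&attr, kernel_ptr) != cudaSuccess) { return -1; }
  return attr.ptxVersion;
}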
115 changes: 114 additions & 1 deletion cpp/include/raft/linalg/detail/cusolver_wrappers.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -693,6 +693,119 @@ inline cusolverStatus_t CUSOLVERAPI cusolverDngesvdj( // NOLINT
   return cusolverDnDgesvdj(
     handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, info, params);
 }
+
+#if CUDART_VERSION >= 11010
+template <typename T>
+cusolverStatus_t cusolverDnxgesvdr_bufferSize(  // NOLINT
+  cusolverDnHandle_t handle,
+  signed char jobu,
+  signed char jobv,
+  int64_t m,
+  int64_t n,
+  int64_t k,
+  int64_t p,
+  int64_t niters,
+  const T* a,
+  int64_t lda,
+  const T* Srand,
+  const T* Urand,
+  int64_t ldUrand,
+  const T* Vrand,
+  int64_t ldVrand,
+  size_t* workspaceInBytesOnDevice,
+  size_t* workspaceInBytesOnHost,
+  cudaStream_t stream)
+{
+  RAFT_EXPECTS(std::is_floating_point_v<T>, "Unsupported data type");
+  cudaDataType dataType = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
+  RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream));
+  cusolverDnParams_t dn_params = nullptr;
+  RAFT_CUSOLVER_TRY(cusolverDnCreateParams(&dn_params));
+  auto result = cusolverDnXgesvdr_bufferSize(handle,
+                                             dn_params,
+                                             jobu,
+                                             jobv,
+                                             m,
+                                             n,
+                                             k,
+                                             p,
+                                             niters,
+                                             dataType,
+                                             a,
+                                             lda,
+                                             dataType,
+                                             Srand,
+                                             dataType,
+                                             Urand,
+                                             ldUrand,
+                                             dataType,
+                                             Vrand,
+                                             ldVrand,
+                                             dataType,
+                                             workspaceInBytesOnDevice,
+                                             workspaceInBytesOnHost);
+  RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params));
+  return result;
+}
+template <typename T>
+cusolverStatus_t cusolverDnxgesvdr(  // NOLINT
+  cusolverDnHandle_t handle,
+  signed char jobu,
+  signed char jobv,
+  int64_t m,
+  int64_t n,
+  int64_t k,
+  int64_t p,
+  int64_t niters,
+  T* a,
+  int64_t lda,
+  T* Srand,
+  T* Urand,
+  int64_t ldUrand,
+  T* Vrand,
+  int64_t ldVrand,
+  void* bufferOnDevice,
+  size_t workspaceInBytesOnDevice,
+  void* bufferOnHost,
+  size_t workspaceInBytesOnHost,
+  int* d_info,
+  cudaStream_t stream)
+{
+  cudaDataType dataType = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;
+  RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream));
+  cusolverDnParams_t dn_params = nullptr;
+  RAFT_CUSOLVER_TRY(cusolverDnCreateParams(&dn_params));
+  auto result = cusolverDnXgesvdr(handle,
+                                  dn_params,
+                                  jobu,
+                                  jobv,
+                                  m,
+                                  n,
+                                  k,
+                                  p,
+                                  niters,
+                                  dataType,
+                                  a,
+                                  lda,
+                                  dataType,
+                                  Srand,
+                                  dataType,
+                                  Urand,
+                                  ldUrand,
+                                  dataType,
+                                  Vrand,
+                                  ldVrand,
+                                  dataType,
+                                  bufferOnDevice,
+                                  workspaceInBytesOnDevice,
+                                  bufferOnHost,
+                                  workspaceInBytesOnHost,
+                                  d_info);
+  RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params));
+  return result;
+}
+#endif  // CUDART_VERSION >= 11010
+
 /** @} */
 
 /**
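The new wrappers template the cuSOLVER 64-bit randomized SVD (gesvdr) entry points over float/double. A minimal hedged usage sketch (assumptions: the wrappers are reachable as raft::linalg::detail::cusolverDnxgesvdr*; jobu/jobv = 'S' requests thin singular vectors; d_U, d_V, and d_S are sized m x k, n x k, and k per the cuSOLVER documentation; error handling is elided):

#include <cuda_runtime_api.h>
#include <cusolverDn.h>
#include <vector>

// Sketch: approximate rank-k SVD of an m x n column-major device matrix d_A.
// Workspace sizes are queried first, then the factorization is run.
void gesvdr_sketch(cusolverDnHandle_t handle, cudaStream_t stream,
                   float* d_A, int64_t m, int64_t n, int64_t k, int64_t p,
                   int64_t niters, float* d_S, float* d_U, float* d_V, int* d_info)
{
  size_t dev_bytes = 0, host_bytes = 0;
  raft::linalg::detail::cusolverDnxgesvdr_bufferSize<float>(
    handle, 'S', 'S', m, n, k, p, niters, d_A, m, d_S, d_U, m, d_V, n,
    &dev_bytes, &host_bytes, stream);

  void* d_work = nullptr;
  cudaMalloc(&d_work, dev_bytes);  // real code should check the return status
  std::vector<char> h_work(host_bytes);

  raft::linalg::detail::cusolverDnxgesvdr<float>(
    handle, 'S', 'S', m, n, k, p, niters, d_A, m, d_S, d_U, m, d_V, n,
    d_work, dev_bytes, h_work.data(), host_bytes, d_info, stream);

  cudaFree(d_work);
}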