Merge branch 'branch-21.12' into bug-2202-fix_macro_collision
cjnolet authored Nov 15, 2021
2 parents dfbb6a4 + 224e16e commit 235a6f4
Showing 70 changed files with 1,533 additions and 897 deletions.
87 changes: 83 additions & 4 deletions README.md
@@ -1,15 +1,94 @@
-# <div align="left"><img src="https://rapids.ai/assets/images/rapids_logo.png" width="90px"/>&nbsp;RAFT: RAPIDS Analytics Frameworks Toolset</div>
+# <div align="left"><img src="https://rapids.ai/assets/images/rapids_logo.png" width="90px"/>&nbsp;RAFT: RAPIDS Analytics Framework Toolkit</div>

-RAFT is a repository containing shared utilities, mathematical operations and common functions for the analytics components of RAPIDS. Both the C++ and Python components can be included in consuming libraries.
+RAFT is a library containing building blocks for rapid composition of RAPIDS analytics. These building blocks include shared representations, mathematical computational primitives, and utilities that accelerate building analytics and data science algorithms in the RAPIDS ecosystem. Both the C++ and Python components can be included in consuming libraries, providing building blocks for both dense and sparse matrix formats in the following general categories:
#####
| Category | Description / Examples |
| --- | --- |
| **Data Formats** | tensor representations and conversions for both sparse and dense formats |
| **Data Generation** | graph, spatial, and machine learning dataset generation |
| **Dense Operations** | linear algebra, statistics |
| **Spatial** | pairwise distances, nearest neighbors, neighborhood / proximity graph construction |
| **Sparse/Graph Operations** | linear algebra, statistics, slicing, minimum spanning forest (MSF), spectral embedding/clustering, single-linkage hierarchical clustering (SLHC), vertex degree |
| **Solvers** | eigenvalue decomposition, least squares, Lanczos |
| **Tools** | multi-node multi-GPU communicator, utilities |

By taking a primitives-based approach to algorithm development, RAFT accelerates algorithm construction time and reduces
the maintenance burden by maximizing reuse across projects. RAFT relies on the [RAPIDS memory manager (RMM)](https://github.com/rapidsai/rmm) which,
like other projects in the RAPIDS ecosystem, eases the burden of configuring different allocation strategies globally
across the libraries that use it. RMM also provides RAII wrappers around device arrays that handle the allocation and cleanup.
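
For example, an application can install a pooled allocation strategy once at startup, and all device allocations made through RMM (including those made by RAFT primitives) will then use it. A minimal sketch, assuming only the standard RMM memory-resource headers; the initial pool size below is illustrative, not a recommendation:

```c++
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

int main() {
  // Upstream resource that allocates directly with cudaMalloc/cudaFree
  rmm::mr::cuda_memory_resource cuda_mr;

  // Suballocate from a pool to amortize allocation cost; the initial
  // size here (1 GiB) is illustrative only
  rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
      &cuda_mr, 1ull << 30};

  // Every subsequent RMM device allocation on this device uses the pool
  rmm::mr::set_current_device_resource(&pool_mr);

  // ... run RAFT-based code ...
  return 0;
}
```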

## Getting started

Refer to the [Build and Development Guide](BUILD.md) for details on RAFT's design, building, testing and development guidelines.

Most of the primitives in RAFT accept a `raft::handle_t` object for the management of resources that are expensive to create, such as CUDA streams, stream pools, and handles to other CUDA libraries like `cublas` and `cusolver`.
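
A handle is cheap to pass around and is typically created once and shared across calls; the expensive underlying resources are created lazily and cached. A minimal sketch of the accessors this implies, using only members that appear in this release's headers:

```c++
#include <raft/handle.hpp>

raft::handle_t handle;

// The stream on which the handle's primitives are enqueued
auto stream = handle.get_stream();

// Cached handles to vendor libraries, created on first use and reused
cublasHandle_t cublas_h     = handle.get_cublas_handle();
cusolverDnHandle_t solver_h = handle.get_cusolver_dn_handle();
```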


### C++ Example

The example below demonstrates creating a RAFT handle and using it, together with RMM's `device_uvector`, to allocate device memory and compute pairwise Euclidean distances:
```c++
#include <raft/handle.hpp>
#include <raft/distance/distance.hpp>

#include <rmm/device_uvector.hpp>

raft::handle_t handle;

int n_samples = ...;
int n_features = ...;

rmm::device_uvector<float> input(n_samples * n_features, handle.get_stream());
rmm::device_uvector<float> output(n_samples * n_samples, handle.get_stream());

// ... Populate feature matrix ...

auto metric = raft::distance::DistanceType::L2SqrtExpanded;
rmm::device_uvector<char> workspace(0, handle.get_stream());
raft::distance::pairwise_distance(handle, input.data(), input.data(),
                                  output.data(),
                                  n_samples, n_samples, n_features,
                                  workspace.data(), metric);
```
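
Note that the distance computation is enqueued asynchronously on the handle's stream; below is a sketch of the synchronization needed before `output` can be read on the host, assuming `get_stream()` returns an `rmm::cuda_stream_view` as in this release:

```c++
// Block the calling thread until the pairwise distances are ready
handle.get_stream().synchronize();
```
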
## Folder Structure and Contents
-The folder structure mirrors the main RAPIDS repos (cuDF, cuML, cuGraph...), with the following folders:
+The folder structure mirrors other RAPIDS repos (cuDF, cuML, cuGraph...), with the following folders:
-- `cpp`: Source code for all C++ code. The code is header only, therefore it is in the `include` folder (with no `src`).
+- `cpp`: Source code for all C++ code. The code is currently header-only, therefore it is in the `include` folder (with no `src`).
- `python`: Source code for all Python code.
- `ci`: Scripts for running CI in PRs
[comment]: <> (TODO: This needs to be updated after the public API is established)
[comment]: <> (The library layout contains the following structure:)
[comment]: <> (```bash)
[comment]: <> (cpp/include/raft)
[comment]: <> ( |------------ comms [communication abstraction layer])
[comment]: <> ( |------------ distance [dense pairwise distances])
[comment]: <> ( |------------ linalg [dense linear algebra])
[comment]: <> ( |------------ matrix [dense matrix format])
[comment]: <> ( |------------ random [random matrix generation])
[comment]: <> ( |------------ sparse [sparse matrix and graph algorithms])
[comment]: <> ( |------------ spatial [spatial algorithms])
[comment]: <> ( |------------ spectral [spectral clustering])
[comment]: <> ( |------------ stats [statistics primitives])
[comment]: <> ( |------------ handle.hpp [raft handle])
[comment]: <> (```)
2 changes: 1 addition & 1 deletion ci/gpu/build.sh
@@ -15,7 +15,7 @@ function hasArg {

# Set path and build parallel level
export PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH
-export PARALLEL_LEVEL=${PARALLEL_LEVEL:-4}
+export PARALLEL_LEVEL=${PARALLEL_LEVEL:-8}
export CUDA_REL=${CUDA_VERSION%.*}

# Set home to the job's workspace
4 changes: 3 additions & 1 deletion cpp/CMakeLists.txt
@@ -15,7 +15,7 @@
#=============================================================================

cmake_minimum_required(VERSION 3.20.1 FATAL_ERROR)
-file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-22.02/RAPIDS.cmake
+file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-21.12/RAPIDS.cmake
${CMAKE_BINARY_DIR}/RAPIDS.cmake)
include(${CMAKE_BINARY_DIR}/RAPIDS.cmake)
include(rapids-cmake)
@@ -100,8 +100,10 @@ endif()
# add third party dependencies using CPM
rapids_cpm_init()

+# thrust and libcudacxx need to be before cuco!
include(cmake/thirdparty/get_thrust.cmake)
include(cmake/thirdparty/get_rmm.cmake)
+include(cmake/thirdparty/get_libcudacxx.cmake)
include(cmake/thirdparty/get_cuco.cmake)

if(BUILD_TESTS)
21 changes: 21 additions & 0 deletions cpp/cmake/libcudacxx.patch
@@ -0,0 +1,21 @@
diff --git a/include/cuda/std/detail/__config b/include/cuda/std/detail/__config
index d55a43688..654142d7e 100644
--- a/include/cuda/std/detail/__config
+++ b/include/cuda/std/detail/__config
@@ -23,7 +23,7 @@
#define _LIBCUDACXX_CUDACC_VER_MINOR __CUDACC_VER_MINOR__
#define _LIBCUDACXX_CUDACC_VER_BUILD __CUDACC_VER_BUILD__
#define _LIBCUDACXX_CUDACC_VER \
- _LIBCUDACXX_CUDACC_VER_MAJOR * 10000 + _LIBCUDACXX_CUDACC_VER_MINOR * 100 + \
+ _LIBCUDACXX_CUDACC_VER_MAJOR * 100000 + _LIBCUDACXX_CUDACC_VER_MINOR * 1000 + \
_LIBCUDACXX_CUDACC_VER_BUILD

#define _LIBCUDACXX_HAS_NO_LONG_DOUBLE
@@ -64,7 +64,7 @@
# endif
#endif

-#if defined(_LIBCUDACXX_COMPILER_MSVC) || (defined(_LIBCUDACXX_CUDACC_VER) && (_LIBCUDACXX_CUDACC_VER < 110500))
+#if defined(_LIBCUDACXX_COMPILER_MSVC) || (defined(_LIBCUDACXX_CUDACC_VER) && (_LIBCUDACXX_CUDACC_VER < 1105000))
# define _LIBCUDACXX_HAS_NO_INT128
#endif
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_cuco.cmake
@@ -22,7 +22,7 @@ function(find_and_configure_cuco VERSION)
INSTALL_EXPORT_SET raft-exports
CPM_ARGS
GIT_REPOSITORY https://github.com/NVIDIA/cuCollections.git
-GIT_TAG 729857a5698a0e8d8f812e0464f65f37854ae17b
+GIT_TAG f0eecb203590f1f4ac4a9f1700229f4434ac64dc
OPTIONS "BUILD_TESTS OFF"
"BUILD_BENCHMARKS OFF"
"BUILD_EXAMPLES OFF"
26 changes: 26 additions & 0 deletions cpp/cmake/thirdparty/get_libcudacxx.cmake
@@ -0,0 +1,26 @@
# =============================================================================
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# This function finds libcudacxx and sets any additional necessary environment variables.
function(find_and_configure_libcudacxx)
include(${rapids-cmake-dir}/cpm/libcudacxx.cmake)

rapids_cpm_libcudacxx(
BUILD_EXPORT_SET raft-exports INSTALL_EXPORT_SET raft-exports PATCH_COMMAND patch
--reject-file=- -p1 -N < ${RAFT_SOURCE_DIR}/cmake/libcudacxx.patch || true
)

endfunction()

find_and_configure_libcudacxx()
71 changes: 71 additions & 0 deletions cpp/include/raft/linalg/cusolver_wrappers.h
@@ -719,5 +719,76 @@ inline cusolverStatus_t cusolverSpcsrqrsvBatched( // NOLINT
}
/** @} */

#if CUDART_VERSION >= 11010
/**
* @defgroup DnXsyevd cusolver DnXsyevd operations
* @{
*/
template <typename T>
cusolverStatus_t cusolverDnxsyevd_bufferSize( // NOLINT
cusolverDnHandle_t handle, cusolverDnParams_t params, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int64_t n, const T *A, int64_t lda, const T *W,
size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost,
cudaStream_t stream);

template <>
inline cusolverStatus_t cusolverDnxsyevd_bufferSize( // NOLINT
cusolverDnHandle_t handle, cusolverDnParams_t params, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int64_t n, const float *A, int64_t lda, const float *W,
size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost,
cudaStream_t stream) {
CUSOLVER_CHECK(cusolverDnSetStream(handle, stream));
return cusolverDnXsyevd_bufferSize(
handle, params, jobz, uplo, n, CUDA_R_32F, A, lda, CUDA_R_32F, W,
CUDA_R_32F, workspaceInBytesOnDevice, workspaceInBytesOnHost);
}

template <>
inline cusolverStatus_t cusolverDnxsyevd_bufferSize( // NOLINT
cusolverDnHandle_t handle, cusolverDnParams_t params, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int64_t n, const double *A, int64_t lda,
const double *W, size_t *workspaceInBytesOnDevice,
size_t *workspaceInBytesOnHost, cudaStream_t stream) {
CUSOLVER_CHECK(cusolverDnSetStream(handle, stream));
return cusolverDnXsyevd_bufferSize(
handle, params, jobz, uplo, n, CUDA_R_64F, A, lda, CUDA_R_64F, W,
CUDA_R_64F, workspaceInBytesOnDevice, workspaceInBytesOnHost);
}

template <typename T>
cusolverStatus_t cusolverDnxsyevd( // NOLINT
cusolverDnHandle_t handle, cusolverDnParams_t params, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int64_t n, T *A, int64_t lda, T *W, T *bufferOnDevice,
size_t workspaceInBytesOnDevice, T *bufferOnHost,
size_t workspaceInBytesOnHost, int *info, cudaStream_t stream);

template <>
inline cusolverStatus_t cusolverDnxsyevd( // NOLINT
cusolverDnHandle_t handle, cusolverDnParams_t params, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int64_t n, float *A, int64_t lda, float *W,
float *bufferOnDevice, size_t workspaceInBytesOnDevice, float *bufferOnHost,
size_t workspaceInBytesOnHost, int *info, cudaStream_t stream) {
CUSOLVER_CHECK(cusolverDnSetStream(handle, stream));
return cusolverDnXsyevd(handle, params, jobz, uplo, n, CUDA_R_32F, A, lda,
CUDA_R_32F, W, CUDA_R_32F, bufferOnDevice,
workspaceInBytesOnDevice, bufferOnHost,
workspaceInBytesOnHost, info);
}

template <>
inline cusolverStatus_t cusolverDnxsyevd( // NOLINT
cusolverDnHandle_t handle, cusolverDnParams_t params, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int64_t n, double *A, int64_t lda, double *W,
double *bufferOnDevice, size_t workspaceInBytesOnDevice, double *bufferOnHost,
size_t workspaceInBytesOnHost, int *info, cudaStream_t stream) {
CUSOLVER_CHECK(cusolverDnSetStream(handle, stream));
return cusolverDnXsyevd(handle, params, jobz, uplo, n, CUDA_R_64F, A, lda,
CUDA_R_64F, W, CUDA_R_64F, bufferOnDevice,
workspaceInBytesOnDevice, bufferOnHost,
workspaceInBytesOnHost, info);
}
/** @} */
#endif

} // namespace linalg
} // namespace raft
76 changes: 59 additions & 17 deletions cpp/include/raft/linalg/eig.cuh
@@ -21,13 +21,41 @@
#include <raft/linalg/cusolver_wrappers.h>
#include <raft/cuda_utils.cuh>
#include <raft/handle.hpp>
-#include <raft/matrix/matrix.cuh>
+#include <raft/matrix/matrix.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

namespace raft {
namespace linalg {

template <typename math_t>
void eigDC_legacy(const raft::handle_t &handle, const math_t *in,
std::size_t n_rows, std::size_t n_cols, math_t *eig_vectors,
math_t *eig_vals, cudaStream_t stream) {
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();

int lwork;
CUSOLVER_CHECK(cusolverDnsyevd_bufferSize(cusolverH, CUSOLVER_EIG_MODE_VECTOR,
CUBLAS_FILL_MODE_UPPER, n_rows, in,
n_cols, eig_vals, &lwork));

rmm::device_uvector<math_t> d_work(lwork, stream);
rmm::device_scalar<int> d_dev_info(stream);

raft::matrix::copy(in, eig_vectors, n_rows, n_cols, stream);

CUSOLVER_CHECK(cusolverDnsyevd(cusolverH, CUSOLVER_EIG_MODE_VECTOR,
CUBLAS_FILL_MODE_UPPER, n_rows, eig_vectors,
n_cols, eig_vals, d_work.data(), lwork,
d_dev_info.data(), stream));
CUDA_CHECK(cudaGetLastError());

auto dev_info = d_dev_info.value(stream);
ASSERT(dev_info == 0,
"eig.cuh: eigensolver couldn't converge to a solution. "
"This usually occurs when some of the features do not vary enough.");
}

/**
* @defgroup eig decomp with divide and conquer method for the column-major
* symmetric matrices
@@ -42,31 +70,43 @@ namespace linalg {
* @{
*/
template <typename math_t>
-void eigDC(const raft::handle_t &handle, const math_t *in, int n_rows,
-           int n_cols, math_t *eig_vectors, math_t *eig_vals,
+void eigDC(const raft::handle_t &handle, const math_t *in, std::size_t n_rows,
+           std::size_t n_cols, math_t *eig_vectors, math_t *eig_vals,
cudaStream_t stream) {
+#if CUDART_VERSION < 11010
+  eigDC_legacy(handle, in, n_rows, n_cols, eig_vectors, eig_vals, stream);
+#else
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();

-  int lwork;
-  CUSOLVER_CHECK(cusolverDnsyevd_bufferSize(cusolverH, CUSOLVER_EIG_MODE_VECTOR,
-                                            CUBLAS_FILL_MODE_UPPER, n_rows, in,
-                                            n_cols, eig_vals, &lwork));
+  cusolverDnParams_t dn_params = nullptr;
+  CUSOLVER_CHECK(cusolverDnCreateParams(&dn_params));

-  rmm::device_uvector<math_t> d_work(lwork, stream);
+  size_t workspaceDevice = 0;
+  size_t workspaceHost = 0;
+  CUSOLVER_CHECK(cusolverDnxsyevd_bufferSize(
+    cusolverH, dn_params, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER,
+    static_cast<int64_t>(n_rows), eig_vectors, static_cast<int64_t>(n_cols),
+    eig_vals, &workspaceDevice, &workspaceHost, stream));
+
+  rmm::device_uvector<math_t> d_work(workspaceDevice / sizeof(math_t), stream);
   rmm::device_scalar<int> d_dev_info(stream);
+  std::vector<math_t> h_work(workspaceHost / sizeof(math_t));

raft::matrix::copy(in, eig_vectors, n_rows, n_cols, stream);

-  CUSOLVER_CHECK(cusolverDnsyevd(cusolverH, CUSOLVER_EIG_MODE_VECTOR,
-                                 CUBLAS_FILL_MODE_UPPER, n_rows, eig_vectors,
-                                 n_cols, eig_vals, d_work.data(), lwork,
-                                 d_dev_info.data(), stream));
-  CUDA_CHECK(cudaGetLastError());
+  CUSOLVER_CHECK(cusolverDnxsyevd(
+    cusolverH, dn_params, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER,
+    static_cast<int64_t>(n_rows), eig_vectors, static_cast<int64_t>(n_cols),
+    eig_vals, d_work.data(), workspaceDevice, h_work.data(), workspaceHost,
+    d_dev_info.data(), stream));
+
+  CUDA_CHECK(cudaGetLastError());
+  CUSOLVER_CHECK(cusolverDnDestroyParams(dn_params));
int dev_info = d_dev_info.value(stream);
ASSERT(dev_info == 0,
"eig.cuh: eigensolver couldn't converge to a solution. "
"This usually occurs when some of the features do not vary enough.");
+#endif
}

enum EigVecMemUsage { OVERWRITE_INPUT, COPY_INPUT };
@@ -155,15 +195,17 @@ void eigSelDC(const raft::handle_t &handle, math_t *in, int n_rows, int n_cols,
* @{
*/
template <typename math_t>
-void eigJacobi(const raft::handle_t &handle, const math_t *in, int n_rows,
-               int n_cols, math_t *eig_vectors, math_t *eig_vals,
-               cudaStream_t stream, math_t tol = 1.e-7, int sweeps = 15) {
+void eigJacobi(const raft::handle_t &handle, const math_t *in,
+               std::size_t n_rows, std::size_t n_cols, math_t *eig_vectors,
+               math_t *eig_vals, cudaStream_t stream, math_t tol = 1.e-7,
+               std::uint32_t sweeps = 15) {
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();

syevjInfo_t syevj_params = nullptr;
CUSOLVER_CHECK(cusolverDnCreateSyevjInfo(&syevj_params));
CUSOLVER_CHECK(cusolverDnXsyevjSetTolerance(syevj_params, tol));
-  CUSOLVER_CHECK(cusolverDnXsyevjSetMaxSweeps(syevj_params, sweeps));
+  CUSOLVER_CHECK(
+    cusolverDnXsyevjSetMaxSweeps(syevj_params, static_cast<int>(sweeps)));

int lwork;
CUSOLVER_CHECK(cusolverDnsyevj_bufferSize(
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/qr.cuh
@@ -18,7 +18,7 @@

#include <raft/linalg/cublas_wrappers.h>
#include <raft/linalg/cusolver_wrappers.h>
-#include <raft/matrix/matrix.cuh>
+#include <raft/matrix/matrix.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

