diff --git a/CHANGELOG.md b/CHANGELOG.md index 1492965c80..24c4c7f12b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -# raft 0.19.0 (Date TBD) +# RAFT 0.19.0 (Date TBD) ## New Features @@ -8,13 +8,9 @@ # RAFT 0.18.0 (Date TBD) -## New Features - -## Improvements - -## Bug Fixes +Please see https://github.com/rapidsai/raft/releases/tag/branch-0.18-latest for the latest changes to this development branch. -# RAFT 0.17.0 (Date TBD) +# RAFT 0.17.0 (10 Dec 2020) ## New Features - PR #65: Adding cuml prims that break circular dependency between cuml and cumlprims projects diff --git a/build.sh b/build.sh index 213aea9347..b05e002788 100755 --- a/build.sh +++ b/build.sh @@ -18,7 +18,7 @@ ARGS=$* # script, and that this script resides in the repo dir! REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean cppraft pyraft -v -g --allgpuarch --nvtx --show_depr_warn -h --buildgtest" +VALIDARGS="clean cppraft pyraft -v -g --allgpuarch --nvtx --show_depr_warn -h --buildgtest --buildfaiss" HELP="$0 [ ...] [ ...] where is: clean - remove all existing build artifacts and configuration (start over) @@ -29,6 +29,7 @@ HELP="$0 [ ...] [ ...] 
-v - verbose build mode -g - build for debug --allgpuarch - build for all supported GPU architectures + --buildfaiss - build faiss statically into raft --nvtx - Enable nvtx for profiling support --show_depr_warn - show cmake deprecation warnings -h - print this text @@ -44,6 +45,7 @@ BUILD_DIRS="${CPP_RAFT_BUILD_DIR} ${PY_RAFT_BUILD_DIR} ${PYTHON_DEPS_CLONE}" VERBOSE="" BUILD_ALL_GPU_ARCH=0 BUILD_GTEST=OFF +BUILD_STATIC_FAISS=OFF SINGLEGPU="" NVTX=OFF CLEAN=0 @@ -89,6 +91,9 @@ fi if hasArg --buildgtest; then BUILD_GTEST=ON fi +if hasArg --buildfaiss; then + BUILD_STATIC_FAISS=ON +fi if hasArg --singlegpu; then SINGLEGPU="--singlegpu" fi @@ -140,6 +145,7 @@ if (( ${NUMARGS} == 0 )) || hasArg cppraft; then -DNCCL_PATH=${INSTALL_PREFIX} \ -DDISABLE_DEPRECATION_WARNING=${BUILD_DISABLE_DEPRECATION_WARNING} \ -DBUILD_GTEST=${BUILD_GTEST} \ + -DBUILD_STATIC_FAISS=${BUILD_STATIC_FAISS} \ .. fi diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 263a12333d..3baee48a5f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -42,6 +42,8 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) option(BUILD_GTEST "Build the GTEST library for running raft test executables" OFF) +option(BUILD_STATIC_FAISS "Build the FAISS library for nearest neighbors search on GPU" OFF) + option(CMAKE_CXX11_ABI "Enable the GLIBCXX11 ABI" ON) option(EMPTY_MARKER_KERNEL "Enable empty marker kernel after nvtxRangePop" ON) @@ -175,6 +177,18 @@ endif() include(cmake/Dependencies.cmake) include(cmake/comms.cmake) +################################################################################################### +# - FAISS ------------------------------------------------------------------------------------------- + +if(NOT BUILD_STATIC_FAISS) + find_path(FAISS_INCLUDE_DIRS "faiss" + HINTS + "$ENV{FAISS_ROOT}/include" + "$ENV{CONDA_PREFIX}/include/faiss" + "$ENV{CONDA_PREFIX}/include") +endif(NOT BUILD_STATIC_FAISS) +message(STATUS "FAISS: FAISS_INCLUDE_DIRS set to ${FAISS_INCLUDE_DIRS}") + 
################################################################################################### # - RMM ------------------------------------------------------------------------------------------- @@ -196,6 +210,7 @@ set(RAFT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE STRING set(RAFT_INCLUDE_DIRECTORIES ${RAFT_INCLUDE_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ${FAISS_INCLUDE_DIRS} ${RMM_INCLUDE_DIRS}) if(NOT CUB_IS_PART_OF_CTK) @@ -218,6 +233,7 @@ set(RAFT_LINK_LIBRARIES ${CUDA_curand_LIBRARY}) set(RAFT_LINK_DIRECTORIES + ${FAISS_INCLUDE_DIRS} ${RMM_INCLUDE_DIRS}) if(DEFINED ENV{CONDA_PREFIX}) @@ -261,6 +277,7 @@ if(BUILD_RAFT_TESTS) test/random/rng.cu test/random/rng_int.cu test/random/sample_without_replacement.cu + test/spatial/knn.cu test/stats/mean.cu test/stats/mean_center.cu test/stats/stddev.cu @@ -283,6 +300,7 @@ if(BUILD_RAFT_TESTS) target_link_libraries(test_raft PRIVATE ${RAFT_LINK_LIBRARIES} + FAISS::FAISS GTest::GTest GTest::Main OpenMP::OpenMP_CXX diff --git a/cpp/cmake/Dependencies.cmake b/cpp/cmake/Dependencies.cmake index 64033327d6..080efb5b1f 100644 --- a/cpp/cmake/Dependencies.cmake +++ b/cpp/cmake/Dependencies.cmake @@ -30,6 +30,48 @@ if(NOT CUB_IS_PART_OF_CTK) INSTALL_COMMAND "") endif(NOT CUB_IS_PART_OF_CTK) +############################################################################## +# - faiss -------------------------------------------------------------------- + +if(BUILD_STATIC_FAISS) + set(FAISS_DIR ${CMAKE_CURRENT_BINARY_DIR}/faiss CACHE STRING + "Path to FAISS source directory") + ExternalProject_Add(faiss + GIT_REPOSITORY https://github.com/facebookresearch/faiss.git + GIT_TAG a5b850dec6f1cd6c88ab467bfd5e87b0cac2e41d + CONFIGURE_COMMAND LIBS=-pthread + CPPFLAGS=-w + LDFLAGS=-L${CMAKE_INSTALL_PREFIX}/lib + ${CMAKE_CURRENT_BINARY_DIR}/faiss/src/faiss/configure + --prefix=${CMAKE_CURRENT_BINARY_DIR}/faiss + --with-blas=${BLAS_LIBRARIES} + --with-cuda=${CUDA_TOOLKIT_ROOT_DIR} + --with-cuda-arch=${FAISS_GPU_ARCHS} + 
-v + PREFIX ${FAISS_DIR} + BUILD_COMMAND make -j${PARALLEL_LEVEL} VERBOSE=1 + BUILD_BYPRODUCTS ${FAISS_DIR}/lib/libfaiss.a + BUILD_ALWAYS 1 + INSTALL_COMMAND make -s install > /dev/null + UPDATE_COMMAND "" + BUILD_IN_SOURCE 1 + PATCH_COMMAND patch -p1 -N < ${CMAKE_CURRENT_SOURCE_DIR}/cmake/faiss_cuda11.patch || true) + + ExternalProject_Get_Property(faiss install_dir) + add_library(FAISS::FAISS STATIC IMPORTED) + set_property(TARGET FAISS::FAISS PROPERTY + IMPORTED_LOCATION ${FAISS_DIR}/lib/libfaiss.a) + # to account for the FAISS file reorg that happened recently after the current + # pinned commit, just change the following line to + # set(FAISS_INCLUDE_DIRS "${FAISS_DIR}/src/faiss") + set(FAISS_INCLUDE_DIRS "${FAISS_DIR}/src") +else() + add_library(FAISS::FAISS SHARED IMPORTED) + set_property(TARGET FAISS::FAISS PROPERTY + IMPORTED_LOCATION $ENV{CONDA_PREFIX}/lib/libfaiss.so) + message(STATUS "Found FAISS: $ENV{CONDA_PREFIX}/lib/libfaiss.so") +endif(BUILD_STATIC_FAISS) + ############################################################################## # - googletest --------------------------------------------------------------- @@ -65,4 +107,5 @@ endif(BUILD_GTEST) if(NOT CUB_IS_PART_OF_CTK) add_dependencies(GTest::GTest cub) -endif(NOT CUB_IS_PART_OF_CTK) +endif(NOT CUB_IS_PART_OF_CTK) +add_dependencies(FAISS::FAISS faiss) diff --git a/cpp/cmake/faiss_cuda11.patch b/cpp/cmake/faiss_cuda11.patch new file mode 100644 index 0000000000..496ca0e7b2 --- /dev/null +++ b/cpp/cmake/faiss_cuda11.patch @@ -0,0 +1,40 @@ +diff --git a/configure b/configure +index ed40dae..f88ed0a 100755 +--- a/configure ++++ b/configure +@@ -2970,7 +2970,7 @@ ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ex + ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + + +- ax_cxx_compile_alternatives="11 0x" ax_cxx_compile_cxx11_required=true ++ ax_cxx_compile_alternatives="14 11 0x" ax_cxx_compile_cxx11_required=true + ac_ext=cpp + ac_cpp='$CXXCPP $CPPFLAGS' + ac_compile='$CXX 
-c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +diff --git a/gpu/utils/DeviceDefs.cuh b/gpu/utils/DeviceDefs.cuh +index 89d3dda..bc0f9b5 100644 +--- a/gpu/utils/DeviceDefs.cuh ++++ b/gpu/utils/DeviceDefs.cuh +@@ -13,7 +13,7 @@ + namespace faiss { namespace gpu { + + #ifdef __CUDA_ARCH__ +-#if __CUDA_ARCH__ <= 750 ++#if __CUDA_ARCH__ <= 800 + constexpr int kWarpSize = 32; + #else + #error Unknown __CUDA_ARCH__; please define parameters for compute capability +diff --git a/gpu/utils/MatrixMult-inl.cuh b/gpu/utils/MatrixMult-inl.cuh +index ede225e..4f7eb44 100644 +--- a/gpu/utils/MatrixMult-inl.cuh ++++ b/gpu/utils/MatrixMult-inl.cuh +@@ -51,6 +51,9 @@ rawGemm(cublasHandle_t handle, + auto cBT = GetCudaType::Type; + + // Always accumulate in f32 ++# if __CUDACC_VER_MAJOR__ >= 11 ++ cublasSetMathMode(handle, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION); ++# endif + return cublasSgemmEx(handle, transa, transb, m, n, k, + &fAlpha, A, cAT, lda, + B, cBT, ldb, diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 73e52e781b..0ca9f3972f 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -130,6 +130,15 @@ class comms_iface { const size_t* recvcounts, const size_t* displs, datatype_t datatype, cudaStream_t stream) const = 0; + virtual void gather(const void* sendbuff, void* recvbuff, size_t sendcount, + datatype_t datatype, int root, + cudaStream_t stream) const = 0; + + virtual void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount, + const size_t* recvcounts, const size_t* displs, + datatype_t datatype, int root, + cudaStream_t stream) const = 0; + virtual void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const = 0; @@ -316,6 +325,45 @@ class comms_t { get_type(), stream); } + /** + * Gathers data from each rank onto all ranks + * @tparam value_t datatype of underlying buffers + * @param sendbuff buffer containing 
data to gather + * @param recvbuff buffer containing gathered data from all ranks + * @param sendcount number of elements in send buffer + * @param root rank to store the results + * @param stream CUDA stream to synchronize operation + */ + template + void gather(const value_t* sendbuff, value_t* recvbuff, size_t sendcount, + int root, cudaStream_t stream) const { + impl_->gather(static_cast(sendbuff), + static_cast(recvbuff), sendcount, get_type(), + root, stream); + } + + /** + * Gathers data from all ranks and delivers to combined data to all ranks + * @param value_t datatype of underlying buffers + * @param sendbuff buffer containing data to send + * @param recvbuff buffer containing data to receive + * @param sendcount number of elements in send buffer + * @param recvcounts pointer to an array (of length num_ranks size) containing the number of + * elements that are to be received from each rank + * @param displs pointer to an array (of length num_ranks size) to specify the displacement + * (relative to recvbuf) at which to place the incoming data from each rank + * @param root rank to store the results + * @param stream CUDA stream to synchronize operation + */ + template + void gatherv(const value_t* sendbuf, value_t* recvbuf, size_t sendcount, + const size_t* recvcounts, const size_t* displs, int root, + cudaStream_t stream) const { + impl_->gatherv(static_cast(sendbuf), + static_cast(recvbuf), sendcount, recvcounts, displs, + get_type(), root, stream); + } + /** * Reduces data from all ranks then scatters the result across ranks * @tparam value_t datatype of underlying buffers diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index a372702c34..8aebcc80cc 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -232,6 +232,39 @@ class mpi_comms : public comms_iface { } } + void gather(const void* sendbuff, void* recvbuff, size_t sendcount, + datatype_t datatype, int root, 
cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv( + static_cast(recvbuff) + sendcount * r * dtype_size, sendcount, + get_nccl_datatype(datatype), r, nccl_comm_, stream)); + } + } + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + + void gatherv(const void* sendbuff, void* recvbuff, size_t sendcount, + const size_t* recvcounts, const size_t* displs, + datatype_t datatype, int root, cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv(static_cast(recvbuff) + displs[r] * dtype_size, + recvcounts[r], get_nccl_datatype(datatype), r, + nccl_comm_, stream)); + } + } + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const { NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount, diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index d4b9d2ba39..a304955ceb 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -346,6 +346,39 @@ class std_comms : public comms_iface { } } + void gather(const void *sendbuff, void *recvbuff, size_t sendcount, + datatype_t datatype, int root, cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv( + static_cast(recvbuff) + sendcount * r * dtype_size, sendcount, + get_nccl_datatype(datatype), r, nccl_comm_, stream)); + } + } + 
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + + void gatherv(const void *sendbuff, void *recvbuff, size_t sendcount, + const size_t *recvcounts, const size_t *displs, + datatype_t datatype, int root, cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv( + static_cast(recvbuff) + displs[r] * dtype_size, recvcounts[r], + get_nccl_datatype(datatype), r, nccl_comm_, stream)); + } + } + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + void reducescatter(const void *sendbuff, void *recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const { NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount, diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index fa7e471174..5dc6f02d21 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -16,11 +16,13 @@ #pragma once -#include #include #include #include +#include +#include + namespace raft { namespace comms { @@ -155,26 +157,114 @@ bool test_collective_allgather(const handle_t &handle, int root) { return true; } +bool test_collective_gather(const handle_t &handle, int root) { + comms_t const &communicator = handle.get_comms(); + + int const send = communicator.get_rank(); + + cudaStream_t stream = handle.get_stream(); + + raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream); + temp_d.resize(1, stream); + + raft::mr::device::buffer recv_d( + handle.get_device_allocator(), stream, + communicator.get_rank() == root ? 
communicator.get_size() : 0); + + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int), + cudaMemcpyHostToDevice, stream)); + + communicator.gather(temp_d.data(), recv_d.data(), 1, root, stream); + communicator.sync_stream(stream); + + if (communicator.get_rank() == root) { + std::vector temp_h(communicator.get_size(), 0); + CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(), + sizeof(int) * temp_h.size(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int i = 0; i < communicator.get_size(); i++) { + if (temp_h[i] != i) return false; + } + } + return true; +} + +bool test_collective_gatherv(const handle_t &handle, int root) { + comms_t const &communicator = handle.get_comms(); + + std::vector sendcounts(communicator.get_size()); + std::iota(sendcounts.begin(), sendcounts.end(), size_t{1}); + std::vector displacements(communicator.get_size() + 1, 0); + std::partial_sum(sendcounts.begin(), sendcounts.end(), + displacements.begin() + 1); + + std::vector sends(displacements[communicator.get_rank() + 1] - + displacements[communicator.get_rank()], + communicator.get_rank()); + + cudaStream_t stream = handle.get_stream(); + + raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream); + temp_d.resize(sends.size(), stream); + + raft::mr::device::buffer recv_d( + handle.get_device_allocator(), stream, + communicator.get_rank() == root ? displacements.back() : 0); + + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), + sends.size() * sizeof(int), cudaMemcpyHostToDevice, + stream)); + + communicator.gatherv( + temp_d.data(), recv_d.data(), temp_d.size(), + communicator.get_rank() == root ? sendcounts.data() + : static_cast(nullptr), + communicator.get_rank() == root ? 
displacements.data() + : static_cast(nullptr), + root, stream); + communicator.sync_stream(stream); + + if (communicator.get_rank() == root) { + std::vector temp_h(displacements.back(), 0); + CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(), + sizeof(int) * displacements.back(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int i = 0; i < communicator.get_size(); i++) { + if (std::count_if(temp_h.begin() + displacements[i], + temp_h.begin() + displacements[i + 1], + [i](auto val) { return val != i; }) != 0) { + return false; + } + } + } + return true; +} + bool test_collective_reducescatter(const handle_t &handle, int root) { comms_t const &communicator = handle.get_comms(); - int const send = 1; + std::vector sends(communicator.get_size(), 1); cudaStream_t stream = handle.get_stream(); raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream, - 1); + sends.size()); raft::mr::device::buffer recv_d(handle.get_device_allocator(), stream, 1); - CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int), - cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), + sends.size() * sizeof(int), cudaMemcpyHostToDevice, + stream)); communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM, stream); communicator.sync_stream(stream); int temp_h = -1; // Verify more than one byte is being sent - CUDA_CHECK(cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), + CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); communicator.barrier(); diff --git a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp new file mode 100644 index 0000000000..1ca5be2052 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "processing.hpp" + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template +__global__ void knn_merge_parts_kernel(value_t *inK, value_idx *inV, + value_t *outK, value_idx *outV, + size_t n_samples, int n_parts, + value_t initK, value_idx initV, int k, + value_idx *translations) { + constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + /** + * Uses shared memory + */ + faiss::gpu::BlockSelect, warp_q, thread_q, + tpb> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + int total_k = k * n_parts; + + int i = threadIdx.x; + + // Get starting pointers for cols in current thread + int part = i / k; + size_t row_idx = (row * k) + (part * n_samples * k); + + int col = i % k; + + value_t *inKStart = inK + (row_idx + col); + value_idx *inVStart = inV + (row_idx + col); + + int limit = faiss::gpu::utils::roundDown(total_k, faiss::gpu::kWarpSize); + value_idx translation = 0; + + for (; i < limit; i += tpb) { + translation = translations[part]; + heap.add(*inKStart, (*inVStart) + translation); + + part = (i + tpb) / k; + row_idx = (row * k) + (part * 
n_samples * k); + + col = (i + tpb) % k; + + inKStart = inK + (row_idx + col); + inVStart = inV + (row_idx + col); + } + + // Handle last remainder fraction of a warp of elements + if (i < total_k) { + translation = translations[part]; + heap.addThreadQ(*inKStart, (*inVStart) + translation); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + outK[row * k + i] = smemK[i]; + outV[row * k + i] = smemV[i]; + } +} + +template +inline void knn_merge_parts_impl(value_t *inK, value_idx *inV, value_t *outK, + value_idx *outV, size_t n_samples, int n_parts, + int k, cudaStream_t stream, + value_idx *translations) { + auto grid = dim3(n_samples); + + constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; + auto block = dim3(n_threads); + + auto kInit = faiss::gpu::Limits::getMax(); + auto vInit = -1; + knn_merge_parts_kernel + <<>>(inK, inV, outK, outV, n_samples, n_parts, + kInit, vInit, k, translations); + CUDA_CHECK(cudaPeekAtLastError()); +} + +/** + * @brief Merge knn distances and index matrix, which have been partitioned + * by row, into a single matrix with only the k-nearest neighbors. 
+ * + * @param inK partitioned knn distance matrix + * @param inV partitioned knn index matrix + * @param outK merged knn distance matrix + * @param outV merged knn index matrix + * @param n_samples number of samples per partition + * @param n_parts number of partitions + * @param k number of neighbors per partition (also number of merged neighbors) + * @param stream CUDA stream to use + * @param translations mapping of index offsets for each partition + */ +template +inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK, + value_idx *outV, size_t n_samples, int n_parts, + int k, cudaStream_t stream, + value_idx *translations) { + if (k == 1) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 32) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 64) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 128) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 256) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 512) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 1024) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); +} + +inline faiss::MetricType build_faiss_metric(distance::DistanceType metric) { + switch (metric) { + case distance::DistanceType::L2Unexpanded: + return faiss::MetricType::METRIC_L2; + case distance::DistanceType::L1: + return faiss::MetricType::METRIC_L1; + case distance::DistanceType::Linf: + return faiss::MetricType::METRIC_Linf; + case distance::DistanceType::LpUnexpanded: + return faiss::MetricType::METRIC_Lp; + case distance::DistanceType::Canberra: + return faiss::MetricType::METRIC_Canberra; + case 
distance::DistanceType::BrayCurtis: + return faiss::MetricType::METRIC_BrayCurtis; + case distance::DistanceType::JensenShannon: + return faiss::MetricType::METRIC_JensenShannon; + default: + return faiss::MetricType::METRIC_INNER_PRODUCT; + } +} + +/** + * Search the kNN for the k-nearest neighbors of a set of query vectors + * @param[in] input vector of device device memory array pointers to search + * @param[in] sizes vector of memory sizes for each device array pointer in input + * @param[in] D number of cols in input and search_items + * @param[in] search_items set of vectors to query for neighbors + * @param[in] n number of items in search_items + * @param[out] res_I pointer to device memory for returning k nearest indices + * @param[out] res_D pointer to device memory for returning k nearest distances + * @param[in] k number of neighbors to query + * @param[in] allocator the device memory allocator to use for temporary scratch memory + * @param[in] userStream the main cuda stream to use + * @param[in] internalStreams optional when n_params > 0, the index partitions can be + * queried in parallel using these streams. Note that n_int_streams also + * has to be > 0 for these to be used and their cardinality does not need + * to correspond to n_parts. + * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the + * user stream will be used. + * @param[in] rowMajorIndex are the index arrays in row-major layout? + * @param[in] rowMajorQuery are the query array in row-major layout? + * @param[in] translations translation ids for indices when index rows represent + * non-contiguous partitions + * @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean) + * @param[in] metricArg metric argument to use. 
Corresponds to the p arg for lp norm + * @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root + */ +template +void brute_force_knn_impl( + std::vector &input, std::vector &sizes, IntType D, + float *search_items, IntType n, int64_t *res_I, float *res_D, IntType k, + std::shared_ptr allocator, + cudaStream_t userStream, cudaStream_t *internalStreams = nullptr, + int n_int_streams = 0, bool rowMajorIndex = true, bool rowMajorQuery = true, + std::vector *translations = nullptr, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + float metricArg = 2.0, bool expanded_form = false) { + ASSERT(input.size() == sizes.size(), + "input and sizes vectors should be the same size"); + + faiss::MetricType m = detail::build_faiss_metric(metric); + + std::vector *id_ranges; + if (translations == nullptr) { + // If we don't have explicit translations + // for offsets of the indices, build them + // from the local partitions + id_ranges = new std::vector(); + int64_t total_n = 0; + for (size_t i = 0; i < input.size(); i++) { + id_ranges->push_back(total_n); + total_n += sizes[i]; + } + } else { + // otherwise, use the given translations + id_ranges = translations; + } + + // perform preprocessing + std::unique_ptr> query_metric_processor = + create_processor(metric, n, D, k, rowMajorQuery, userStream, + allocator); + query_metric_processor->preprocess(search_items); + + std::vector>> metric_processors( + input.size()); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i] = create_processor( + metric, sizes[i], D, k, rowMajorQuery, userStream, allocator); + metric_processors[i]->preprocess(input[i]); + } + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + + raft::mr::device::buffer trans(allocator, userStream, + id_ranges->size()); + raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), + userStream); + + raft::mr::device::buffer all_D(allocator, userStream, 0); + raft::mr::device::buffer 
all_I(allocator, userStream, 0); + + float *out_D = res_D; + int64_t *out_I = res_I; + + if (input.size() > 1) { + all_D.resize(input.size() * k * n, userStream); + all_I.resize(input.size() * k * n, userStream); + + out_D = all_D.data(); + out_I = all_I.data(); + } + + // Sync user stream only if using other streams to parallelize query + if (n_int_streams > 0) CUDA_CHECK(cudaStreamSynchronize(userStream)); + + for (size_t i = 0; i < input.size(); i++) { + faiss::gpu::StandardGpuResources gpu_res; + + cudaStream_t stream = + raft::select_stream(userStream, internalStreams, n_int_streams, i); + + gpu_res.noTempMemory(); + gpu_res.setCudaMallocWarning(false); + gpu_res.setDefaultStream(device, stream); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = metricArg; + args.k = k; + args.dims = D; + args.vectors = input[i]; + args.vectorsRowMajor = rowMajorIndex; + args.numVectors = sizes[i]; + args.queries = search_items; + args.queriesRowMajor = rowMajorQuery; + args.numQueries = n; + args.outDistances = out_D + (i * k * n); + args.outIndices = out_I + (i * k * n); + + /** + * @todo: Until FAISS supports pluggable allocation strategies, + * we will not reap the benefits of the pool allocator for + * avoiding device-wide synchronizations from cudaMalloc/cudaFree + */ + bfKnn(&gpu_res, args); + + CUDA_CHECK(cudaPeekAtLastError()); + } + + // Sync internal streams if used. We don't need to + // sync the user stream because we'll already have + // fully serial execution. + for (int i = 0; i < n_int_streams; i++) { + CUDA_CHECK(cudaStreamSynchronize(internalStreams[i])); + } + + if (input.size() > 1 || translations != nullptr) { + // This is necessary for proper index translations. If there are + // no translations or partitions to combine, it can be skipped. 
+ detail::knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, + userStream, trans.data()); + } + + // Perform necessary post-processing + if ((m == faiss::MetricType::METRIC_L2 || + m == faiss::MetricType::METRIC_Lp) && + !expanded_form) { + /** + * post-processing + */ + float p = 0.5; // standard l2 + if (m == faiss::MetricType::METRIC_Lp) p = 1.0 / metricArg; + raft::linalg::unaryOp( + res_D, res_D, n * k, + [p] __device__(float input) { return powf(input, p); }, userStream); + } + + query_metric_processor->revert(search_items); + query_metric_processor->postprocess(out_D); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i]->revert(input[i]); + } + + if (translations == nullptr) delete id_ranges; +} + +} // namespace detail +} // namespace knn +} // namespace spatial +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/processing.hpp b/cpp/include/raft/spatial/knn/detail/processing.hpp new file mode 100644 index 0000000000..a645412c2f --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/processing.hpp @@ -0,0 +1,192 @@ +/* + * Copyright (c)2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft {
+namespace spatial {
+namespace knn {
+
+using deviceAllocator = raft::mr::device::allocator;
+/**
+ * @brief Interface for metric-specific pre- and post-processing hooks.
+ *
+ * `preprocess()` may temporarily modify the data it is given, `revert()`
+ * undoes that modification, and `postprocess()` is applied to an array of
+ * computed distances. Every hook in this base class is a no-op.
+ */
+
+template
+class MetricProcessor {
+ public:
+  virtual void preprocess(math_t *data) {}
+
+  virtual void revert(math_t *data) {}
+
+  virtual void postprocess(math_t *data) {}
+
+  virtual ~MetricProcessor() = default;
+};
+
+/**
+ * @brief Processor used for cosine distance.
+ *
+ * `preprocess()` computes one norm per row with `rowNorm` (L2Norm followed
+ * by a square root) and divides the data by it through `matrixVectorOp`;
+ * `revert()` multiplies the same values back in. `postprocess()` replaces
+ * each of the `k * n_rows` distances `d` with `1 - d`.
+ */
+template
+class CosineMetricProcessor : public MetricProcessor {
+ protected:
+  int k_;
+  bool row_major_;
+  size_t n_rows_;
+  size_t n_cols_;
+  cudaStream_t stream_;
+  std::shared_ptr device_allocator_;
+  // Despite its name this buffer receives the output of `rowNorm` and is
+  // allocated with n_rows entries.
+  raft::mr::device::buffer colsums_;
+
+ public:
+  CosineMetricProcessor(size_t n_rows, size_t n_cols, int k, bool row_major,
+                        cudaStream_t stream,
+                        std::shared_ptr allocator)
+    // Listed in declaration order, which is the order in which members are
+    // actually initialized.
+    : k_(k),
+      row_major_(row_major),
+      n_rows_(n_rows),
+      n_cols_(n_cols),
+      stream_(stream),
+      device_allocator_(allocator),
+      colsums_(allocator, stream, n_rows) {}
+
+  void preprocess(math_t *data) override {
+    raft::linalg::rowNorm(colsums_.data(), data, n_cols_, n_rows_,
+                          raft::linalg::NormType::L2Norm, row_major_, stream_,
+                          [] __device__(math_t in) { return sqrtf(in); });
+
+    raft::linalg::matrixVectorOp(
+      data, data, colsums_.data(), n_cols_, n_rows_, row_major_, false,
+      [] __device__(math_t mat_in, math_t vec_in) { return mat_in / vec_in; },
+      stream_);
+  }
+
+  void revert(math_t *data) override {
+    raft::linalg::matrixVectorOp(
+      data, data, colsums_.data(), n_cols_, n_rows_, row_major_, false,
+      [] __device__(math_t mat_in, math_t vec_in) { return mat_in * vec_in; },
+      stream_);
+  }
+
+  void postprocess(math_t *data) override {
+    raft::linalg::unaryOp(
+      data, data, k_ * n_rows_, [] __device__(math_t in) { return 1 - in; },
+      stream_);
+  }
+
+  ~CosineMetricProcessor() = default;
+};
+
+/**
+ * @brief Processor used for correlation distance.
+ *
+ * `preprocess()` reduces the data into `means_`, scales the result by
+ * `1 / n_cols`, subtracts it with `meanCenter`, and then applies the cosine
+ * preprocessing. `revert()` undoes those steps in the opposite order
+ * (cosine revert first, then `meanAdd`).
+ */
+template
+class CorrelationMetricProcessor : public CosineMetricProcessor {
+  using cosine = CosineMetricProcessor;
+
+ public:
+  CorrelationMetricProcessor(size_t n_rows, size_t n_cols, int k,
+                             bool row_major, cudaStream_t stream,
+                             std::shared_ptr allocator)
+    : CosineMetricProcessor(n_rows, n_cols, k, row_major, stream,
+                            allocator),
+      means_(allocator, stream, n_rows) {}
+
+  void preprocess(math_t *data) override {
+    math_t normalizer_const = 1.0 / (math_t)cosine::n_cols_;
+
+    raft::linalg::reduce(means_.data(), data, cosine::n_cols_, cosine::n_rows_,
+                         (math_t)0.0, cosine::row_major_, true,
+                         cosine::stream_);
+
+    // Turn the reduced values into means by scaling with 1 / n_cols.
+    raft::linalg::unaryOp(
+      means_.data(), means_.data(), cosine::n_rows_,
+      [=] __device__(math_t in) { return in * normalizer_const; },
+      cosine::stream_);
+
+    raft::stats::meanCenter(data, data, means_.data(), cosine::n_cols_,
+                            cosine::n_rows_, cosine::row_major_, false,
+                            cosine::stream_);
+
+    CosineMetricProcessor::preprocess(data);
+  }
+
+  void revert(math_t *data) override {
+    CosineMetricProcessor::revert(data);
+
+    raft::stats::meanAdd(data, data, means_.data(), cosine::n_cols_,
+                         cosine::n_rows_, cosine::row_major_, false,
+                         cosine::stream_);
+  }
+
+  void postprocess(math_t *data) override {
+    CosineMetricProcessor::postprocess(data);
+  }
+
+  ~CorrelationMetricProcessor() = default;
+
+  raft::mr::device::buffer means_;
+};
+
+/**
+ * @brief Processor for metrics that need no extra processing; every hook is
+ * a no-op.
+ */
+template
+class DefaultMetricProcessor : public MetricProcessor {
+ public:
+  void preprocess(math_t *data) override {}
+
+  void revert(math_t *data) override {}
+
+  void postprocess(math_t *data) override {}
+
+  ~DefaultMetricProcessor() = default;
+};
+
+/**
+ * @brief Build the processor matching a distance metric.
+ *
+ * @param[in] metric `CosineExpanded` and `CorrelationExpanded` select the
+ *            corresponding processor; any other value selects the no-op
+ *            `DefaultMetricProcessor`
+ * @param[in] n forwarded to the processor as its number of rows
+ * @param[in] D forwarded to the processor as its number of columns
+ * @param[in] k forwarded to the processor; used to size `postprocess()`
+ * @param[in] rowMajorQuery forwarded to the processor as its layout flag
+ * @param[in] userStream stream the processor enqueues its work on
+ * @param[in] allocator allocator for the processor's device buffers
+ * @return owning pointer to the newly created processor
+ */
+template
+inline std::unique_ptr> create_processor(
+  distance::DistanceType metric, int n, int D, int k, bool rowMajorQuery,
+  cudaStream_t userStream, std::shared_ptr allocator) {
+  MetricProcessor *mp = nullptr;
+
+  switch (metric) {
+    case distance::DistanceType::CosineExpanded:
+      mp = new CosineMetricProcessor(n, D, k, rowMajorQuery, userStream,
+                                     allocator);
+      break;
+
+    case distance::DistanceType::CorrelationExpanded:
+      mp = new CorrelationMetricProcessor(n, D, k, rowMajorQuery,
+                                          userStream, allocator);
+      break;
+    default:
+      mp = new DefaultMetricProcessor();
+  }
+
+  return std::unique_ptr>(mp);
+}
+
+// Currently only being used by floats
+template class MetricProcessor;
+template class CosineMetricProcessor;
+template class CorrelationMetricProcessor;
+template class DefaultMetricProcessor;
+
+}  // namespace knn
+}  // namespace spatial
+}  // namespace raft
diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp
new file mode 100644
index 0000000000..ccee635701
--- /dev/null
+++ b/cpp/include/raft/spatial/knn/knn.hpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "detail/brute_force_knn.hpp"
+
+#include
+#include
+
+namespace raft {
+namespace spatial {
+namespace knn {
+
+using deviceAllocator = raft::mr::device::allocator;
+
+/**
+ * @brief Flat C++ API function to perform a brute force knn on
+ * a series of input arrays and combine the results into a single
+ * output array for indexes and distances.
+ *
+ * @param[in] handle the raft handle to use
+ * @param[in] input vector of pointers to the input arrays
+ * @param[in] sizes vector of sizes of input arrays; must have the same
+ *            number of entries as `input`
+ * @param[in] D the dimensionality of the arrays
+ * @param[in] search_items array of items to search of dimensionality D
+ * @param[in] n number of rows in search_items
+ * @param[out] res_I the resulting index array of size n * k
+ * @param[out] res_D the resulting distance array of size n * k
+ * @param[in] k the number of nearest neighbors to return
+ * @param[in] rowMajorIndex are the index arrays in row-major order?
+ * @param[in] rowMajorQuery are the query arrays in row-major order?
+ * @param[in] metric distance metric to use. Euclidean (L2) is used by
+ *            default
+ * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This
+ *            is ignored if `metric` is not Minkowski.
+ * @param[in] expanded should lp-based distances be returned in their expanded
+ *            form (e.g., without raising to the 1/p power).
+ *
+ * @note Declared `inline` because the definition lives in a header: without
+ * it, including this header from more than one translation unit produces
+ * duplicate definitions at link time.
+ */
+inline void brute_force_knn(
+  raft::handle_t &handle, std::vector &input, std::vector &sizes,
+  int D, float *search_items, int n, int64_t *res_I, float *res_D, int k,
+  bool rowMajorIndex = false, bool rowMajorQuery = false,
+  distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
+  float metric_arg = 2.0f, bool expanded = false) {
+  ASSERT(input.size() == sizes.size(),
+         "input and sizes vectors must be the same size");
+
+  // Hold the internal streams in a local vector so that a contiguous array
+  // can be handed to the implementation below.
+  std::vector int_streams = handle.get_internal_streams();
+
+  detail::brute_force_knn_impl(
+    input, sizes, D, search_items, n, res_I, res_D, k,
+    handle.get_device_allocator(), handle.get_stream(), int_streams.data(),
+    handle.get_num_internal_streams(), rowMajorIndex, rowMajorQuery, nullptr,
+    metric, metric_arg, expanded);
+}
+
+}  // namespace knn
+}  // namespace spatial
+}  // namespace raft
diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu
new file mode 100644
index 0000000000..cfd4ecc9d1
--- /dev/null
+++ b/cpp/test/spatial/knn.cu
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "../test_utils.h"
+
+namespace raft {
+namespace spatial {
+namespace knn {
+// One test case: the points (one inner vector per row), the number of
+// neighbors to query, and one label per row.
+struct KNNInputs {
+  std::vector> input;
+  int k;
+  std::vector labels;
+};
+
+// One thread per entry of the n_rows * k index matrix: looks up the label of
+// the neighbor stored at that entry.
+__global__ void build_actual_output(int *output, int n_rows, int k,
+                                    const int *idx_labels,
+                                    const int64_t *indices) {
+  int element = threadIdx.x + blockDim.x * blockIdx.x;
+  if (element >= n_rows * k) return;
+
+  int ind = (int)indices[element];
+  output[element] = idx_labels[ind];
+}
+
+// One thread per row: fills the row's k output slots with the row's own
+// label, i.e. the label every returned neighbor is expected to carry.
+__global__ void build_expected_output(int *output, int n_rows, int k,
+                                      const int *labels) {
+  int row = threadIdx.x + blockDim.x * blockIdx.x;
+  if (row >= n_rows) return;
+
+  int cur_label = labels[row];
+  for (int i = 0; i < k; i++) {
+    output[row * k + i] = cur_label;
+  }
+}
+
+/**
+ * Fixture that runs `brute_force_knn` using the same points as both index
+ * and query, then checks that every returned neighbor carries the same
+ * label as the row it was returned for.
+ */
+template
+class KNNTest : public ::testing::TestWithParam {
+ protected:
+  void testBruteForce() {
+    raft::print_device_vector("Input array: ", input_, rows_ * cols_,
+                              std::cout);
+    std::cout << "K: " << k_ << "\n";
+    raft::print_device_vector("Labels array: ", search_labels_, rows_,
+                              std::cout);
+
+    auto stream = handle_.get_stream();
+
+    raft::allocate(actual_labels_, rows_ * k_, true);
+    raft::allocate(expected_labels_, rows_ * k_, true);
+
+    std::vector input_vec;
+    std::vector sizes_vec;
+    input_vec.push_back(input_);
+    sizes_vec.push_back(rows_);
+
+    brute_force_knn(handle_, input_vec, sizes_vec, cols_, search_data_, rows_,
+                    indices_, distances_, k_, true, true);
+
+    build_actual_output<<>>(
+      actual_labels_, rows_, k_, search_labels_, indices_);
+
+    build_expected_output<<>>(
+      expected_labels_, rows_, k_, search_labels_);
+
+    raft::print_device_vector("Output indices: ", indices_, rows_ * k_,
+                              std::cout);
+    raft::print_device_vector("Output distances: ", distances_, rows_ * k_,
+                              std::cout);
+    raft::print_device_vector("Output labels: ", actual_labels_, rows_ * k_,
+                              std::cout);
+    raft::print_device_vector("Expected labels: ", expected_labels_, rows_ * k_,
+                              std::cout);
+
+    ASSERT_TRUE(devArrMatch(expected_labels_, actual_labels_, rows_ * k_,
+                            raft::Compare()));
+  }
+
+  void SetUp() override {
+    params_ = ::testing::TestWithParam::GetParam();
+    rows_ = params_.input.size();
+    cols_ = params_.input[0].size();
+    k_ = params_.k;
+
+    // Flatten the nested rows into one contiguous row-major host array.
+    std::vector row_major_input;
+    for (const auto &row : params_.input) {
+      row_major_input.insert(row_major_input.end(), row.begin(), row.end());
+    }
+    rmm::device_buffer input_d = rmm::device_buffer(
+      row_major_input.data(), row_major_input.size() * sizeof(float));
+    float *input_ptr = static_cast(input_d.data());
+
+    rmm::device_buffer labels_d = rmm::device_buffer(
+      params_.labels.data(), params_.labels.size() * sizeof(int));
+    int *labels_ptr = static_cast(labels_d.data());
+
+    raft::allocate(input_, rows_ * cols_, true);
+    raft::allocate(search_data_, rows_ * cols_, true);
+    raft::allocate(indices_, rows_ * k_, true);
+    raft::allocate(distances_, rows_ * k_, true);
+    raft::allocate(search_labels_, rows_, true);
+
+    // The index and the query are both copies of the same points.
+    raft::copy(input_, input_ptr, rows_ * cols_, handle_.get_stream());
+    raft::copy(search_data_, input_ptr, rows_ * cols_, handle_.get_stream());
+    raft::copy(search_labels_, labels_ptr, rows_, handle_.get_stream());
+  }
+
+  void TearDown() override {
+    // Release every buffer allocated in SetUp() and testBruteForce().
+    // Pointers start out as nullptr, so freeing one that was never
+    // allocated is a no-op.
+    CUDA_CHECK(cudaFree(input_));
+    CUDA_CHECK(cudaFree(search_data_));
+    CUDA_CHECK(cudaFree(indices_));
+    CUDA_CHECK(cudaFree(distances_));
+    CUDA_CHECK(cudaFree(search_labels_));
+    CUDA_CHECK(cudaFree(actual_labels_));
+    CUDA_CHECK(cudaFree(expected_labels_));
+  }
+
+ private:
+  raft::handle_t handle_;
+  KNNInputs params_;
+  int rows_;
+  int cols_;
+  float *input_ = nullptr;
+  float *search_data_ = nullptr;
+  int64_t *indices_ = nullptr;
+  float *distances_ = nullptr;
+  int k_;
+
+  int *search_labels_ = nullptr;
+  int *actual_labels_ = nullptr;
+  int *expected_labels_ = nullptr;
+};
+
+const std::vector inputs = {
+  // 2D
+  {{
+     {2.7810836, 2.550537003},
+     {1.465489372, 2.362125076},
+     {3.396561688, 4.400293529},
+     {1.38807019, 1.850220317},
+     {3.06407232, 3.005305973},
+     {7.627531214, 2.759262235},
+     {5.332441248, 2.088626775},
+     {6.922596716, 1.77106367},
+     {8.675418651, -0.242068655},
+     {7.673756466, 3.508563011},
+   },
+   2,
+   {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}}};
+
+typedef KNNTest KNNTestF;
+TEST_P(KNNTestF, BruteForce) { this->testBruteForce(); }
+
+INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestF, ::testing::ValuesIn(inputs));
+
+}  // namespace knn
+}  // namespace spatial
+}  // namespace raft
diff --git a/python/raft/dask/common/__init__.py b/python/raft/dask/common/__init__.py
index 788af46c92..73bb5d6700 100644
--- a/python/raft/dask/common/__init__.py
+++ b/python/raft/dask/common/__init__.py
@@ -21,6 +21,8 @@
 from .comms_utils import perform_test_comms_allreduce
 from .comms_utils import perform_test_comms_send_recv
 from .comms_utils import perform_test_comms_allgather
+from .comms_utils import perform_test_comms_gather
+from .comms_utils import perform_test_comms_gatherv
 from .comms_utils import perform_test_comms_bcast
 from .comms_utils import perform_test_comms_reduce
 from .comms_utils import perform_test_comms_reducescatter
diff --git a/python/raft/dask/common/comms_utils.pyx b/python/raft/dask/common/comms_utils.pyx
index 4dbd2f1a7c..1a703485a9 100644
--- a/python/raft/dask/common/comms_utils.pyx
+++ b/python/raft/dask/common/comms_utils.pyx
@@ -60,6 +60,8 @@ cdef extern from "raft/comms/test.hpp" namespace "raft::comms":
     bool test_collective_broadcast(const handle_t &h, int root) except +
     bool test_collective_reduce(const handle_t &h, int root) except +
     bool test_collective_allgather(const handle_t &h, int root) except +
+    bool test_collective_gather(const handle_t &h, int root) except +
+    bool test_collective_gatherv(const handle_t &h, int root) except +
     bool test_collective_reducescatter(const handle_t &h, int root) except +
     bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) except +
@@ -131,6 +133,36 @@ def perform_test_comms_allgather(handle, root):
     return test_collective_allgather(deref(h), root)
+def perform_test_comms_gather(handle, root):
+    """
+    Performs a gather on the current worker
+
+    Parameters
+    ----------
+    handle : raft.common.Handle
+             handle containing comms_t to use
+    root : int
+           Rank of the root worker
+    """
+    cdef const handle_t* h = handle.getHandle()
+    return test_collective_gather(deref(h), root)
+
+
+def perform_test_comms_gatherv(handle, root):
+    """
+    Performs a gatherv on the current worker
+
+    Parameters
+    ----------
+    handle : raft.common.Handle
+             handle containing comms_t to use
+    root : int
+           Rank of the root worker
+    """
+    cdef const handle_t* h = handle.getHandle()
+    return test_collective_gatherv(deref(h), root)
+
+
 def perform_test_comms_send_recv(handle, n_trials):
     """
     Performs a p2p send/recv on the current worker
diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py
index 7dccb7bbae..a0db3b7f4f 100644
--- a/python/raft/test/test_comms.py
+++ b/python/raft/test/test_comms.py
@@ -28,6 +28,8 @@
     from raft.dask.common import perform_test_comms_bcast
     from raft.dask.common import perform_test_comms_reduce
     from raft.dask.common import perform_test_comms_allgather
+    from raft.dask.common import perform_test_comms_gather
+    from raft.dask.common import perform_test_comms_gatherv
     from raft.dask.common import perform_test_comms_reducescatter
     from raft.dask.common import perform_test_comm_split
@@ -130,6 +132,8 @@ def _has_handle(sessionId):
     perform_test_comms_allgather,
     perform_test_comms_allreduce,
     perform_test_comms_bcast,
+    perform_test_comms_gather,
+    perform_test_comms_gatherv,
     perform_test_comms_reduce,
     perform_test_comms_reducescatter,
 ]