From 57cfa20924c8e02631083a02b353377323d7f4b8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 27 Jan 2023 11:03:24 -0800 Subject: [PATCH 01/62] Replace faiss bfKnn Replace faiss bfKnn with code that leverages our pairwise_distance api and select_k api - by computing tiling over the inputs. This lets us remove faiss as a dependency --- build.sh | 8 +- .../all_cuda-118_arch-x86_64.yaml | 2 - conda/recipes/libraft/conda_build_config.yaml | 3 - conda/recipes/libraft/meta.yaml | 6 +- cpp/CMakeLists.txt | 25 +- cpp/cmake/thirdparty/get_faiss.cmake | 89 --- .../raft/spatial/knn/detail/common_faiss.h | 56 -- .../knn/detail/faiss_select/DistanceUtils.h | 52 ++ .../knn/detail/knn_brute_force_faiss.cuh | 163 ++++- .../spatial/knn/detail/selection_faiss.cuh | 1 - cpp/include/raft/spatial/knn/faiss_mr.hpp | 640 ------------------ cpp/test/CMakeLists.txt | 2 +- cpp/test/neighbors/faiss_mr.cu | 94 --- cpp/test/neighbors/fused_l2_knn.cu | 54 +- cpp/test/neighbors/knn_utils.cuh | 84 +++ cpp/test/neighbors/tiled_knn.cu | 165 +++++ dependencies.yaml | 2 - docs/source/build.md | 20 +- 18 files changed, 444 insertions(+), 1022 deletions(-) delete mode 100644 cpp/cmake/thirdparty/get_faiss.cmake delete mode 100644 cpp/include/raft/spatial/knn/detail/common_faiss.h create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/DistanceUtils.h delete mode 100644 cpp/include/raft/spatial/knn/faiss_mr.hpp delete mode 100644 cpp/test/neighbors/faiss_mr.cu create mode 100644 cpp/test/neighbors/knn_utils.cuh create mode 100644 cpp/test/neighbors/tiled_knn.cu diff --git a/build.sh b/build.sh index b47e1ed862..2496eea5c2 100755 --- a/build.sh +++ b/build.sh @@ -18,7 +18,7 @@ ARGS=$* # script, and that this script resides in the repo dir! REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean --uninstall -v -g -n --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps" +VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean --uninstall -v -g -n --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --minimal-deps" HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] [--limit-bench=] where is: clean - remove all existing build artifacts and configuration (start over) @@ -45,7 +45,6 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=\\\" - pass arbitrary list of CMake configuration options (escape all quotes in argument) @@ -68,7 +67,6 @@ BUILD_ALL_GPU_ARCH=0 BUILD_TESTS=OFF BUILD_TYPE=Release BUILD_BENCH=OFF -BUILD_STATIC_FAISS=OFF COMPILE_LIBRARIES=OFF COMPILE_NN_LIBRARY=OFF COMPILE_DIST_LIBRARY=OFF @@ -335,9 +333,6 @@ if hasArg bench || (( ${NUMARGS} == 0 )); then fi -if hasArg --buildfaiss; then - BUILD_STATIC_FAISS=ON -fi if hasArg --no-nvtx; then NVTX=OFF fi @@ -407,7 +402,6 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has -DCMAKE_MESSAGE_LOG_LEVEL=${CMAKE_LOG_LEVEL} \ -DRAFT_COMPILE_NN_LIBRARY=${COMPILE_NN_LIBRARY} \ -DRAFT_COMPILE_DIST_LIBRARY=${COMPILE_DIST_LIBRARY} \ - -DRAFT_USE_FAISS_STATIC=${BUILD_STATIC_FAISS} \ -DRAFT_ENABLE_thrust_DEPENDENCY=${ENABLE_thrust_DEPENDENCY} \ ${CACHE_ARGS} \ ${EXTRA_CMAKE_ARGS} diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 87b7075935..316c4ede79 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -22,7 +22,6 @@ dependencies: - dask>=2022.12.0 - distributed>=2022.12.0 - doxygen>=1.8.20 -- faiss-proc=*=cuda - gcc_linux-64=9 - libcublas-dev=11.11.3.6 - libcublas=11.11.3.6 @@ -32,7 +31,6 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 -- libfaiss>=1.7.1=cuda* - ninja - pytest - pytest-cov diff --git a/conda/recipes/libraft/conda_build_config.yaml b/conda/recipes/libraft/conda_build_config.yaml index 1012bddb40..ae4ba68229 100644 --- a/conda/recipes/libraft/conda_build_config.yaml +++ b/conda/recipes/libraft/conda_build_config.yaml @@ -19,9 +19,6 @@ nccl_version: gtest_version: - "=1.10.0" -libfaiss_version: - - "1.7.2 *_cuda" - # The CTK libraries below are missing from the conda-forge::cudatoolkit # package. The "*_host_*" version specifiers correspond to `11.8` packages and the # "*_run_*" version specifiers correspond to `11.x` packages. diff --git a/conda/recipes/libraft/meta.yaml b/conda/recipes/libraft/meta.yaml index b0d6c47ee9..4729b01ed2 100644 --- a/conda/recipes/libraft/meta.yaml +++ b/conda/recipes/libraft/meta.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # Usage: # conda build . -c conda-forge -c nvidia -c rapidsai @@ -126,7 +126,6 @@ outputs: host: - {{ pin_subpackage('libraft-headers', exact=True) }} - cuda-profiler-api {{ cuda_profiler_api_host_version }} - - faiss-proc=*=cuda - lapack - libcublas {{ libcublas_host_version }} - libcublas-dev {{ libcublas_host_version }} @@ -136,10 +135,7 @@ outputs: - libcusolver-dev {{ libcusolver_host_version }} - libcusparse {{ libcusparse_host_version }} - libcusparse-dev {{ libcusparse_host_version }} - - libfaiss {{ libfaiss_version }} run: - - faiss-proc=*=cuda - - libfaiss {{ libfaiss_version }} - {{ pin_subpackage('libraft-headers', exact=True) }} about: home: https://rapids.ai/ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 784bbbb935..a084e7f0ca 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -65,9 +65,7 @@ option( option(RAFT_COMPILE_DIST_LIBRARY "Enable building raft distant shared library instantiations" ${RAFT_COMPILE_LIBRARIES} ) -option(RAFT_ENABLE_NN_DEPENDENCIES "Search for raft::nn dependencies like faiss" - ${RAFT_COMPILE_LIBRARIES} -) +option(RAFT_ENABLE_NN_DEPENDENCIES "Search for raft::nn dependencies" ${RAFT_COMPILE_LIBRARIES}) option(RAFT_ENABLE_thrust_DEPENDENCY "Enable Thrust dependency" ON) @@ -83,16 +81,7 @@ if(BUILD_TESTS AND NOT RAFT_ENABLE_thrust_DEPENDENCY) set(RAFT_ENABLE_thrust_DEPENDENCY ON) endif() -option(RAFT_EXCLUDE_FAISS_FROM_ALL "Exclude FAISS targets from RAFT's 'all' target" ON) - include(CMakeDependentOption) -cmake_dependent_option( - RAFT_USE_FAISS_STATIC - "Build and statically link the FAISS library for nearest neighbors search on GPU" - ON - RAFT_COMPILE_LIBRARIES - OFF -) message(VERBOSE "RAFT: Building optional components: ${raft_FIND_COMPONENTS}") message(VERBOSE "RAFT: Build RAFT unit-tests: ${BUILD_TESTS}") @@ -177,7 +166,6 @@ rapids_cpm_init() # thrust before rmm/cuco so we get the right version of thrust/cub include(cmake/thirdparty/get_thrust.cmake) include(cmake/thirdparty/get_rmm.cmake) -include(cmake/thirdparty/get_faiss.cmake) include(cmake/thirdparty/get_cutlass.cmake) if(RAFT_ENABLE_cuco_DEPENDENCY) @@ -292,7 +280,6 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/distance/specializations/detail/chebyshev.cu src/distance/distance/specializations/detail/correlation.cu src/distance/distance/specializations/detail/cosine.cu - src/distance/distance/specializations/detail/cosine.cu src/distance/distance/specializations/detail/hamming_unexpanded.cu src/distance/distance/specializations/detail/hellinger_expanded.cu src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -433,7 +420,7 @@ if(RAFT_COMPILE_NN_LIBRARY) target_link_libraries( raft_nn_lib - PUBLIC faiss::faiss raft::raft + PUBLIC raft::raft PRIVATE nvidia::cutlass::cutlass ) target_compile_options( @@ -641,12 +628,6 @@ endif() if(nn IN_LIST raft_FIND_COMPONENTS) enable_language(CUDA) - - if(TARGET faiss AND (NOT TARGET faiss::faiss)) - add_library(faiss::faiss ALIAS faiss) - elseif(TARGET faiss::faiss AND (NOT TARGET faiss)) - add_library(faiss ALIAS faiss::faiss) - endif() endif() ]=] ) diff --git a/cpp/cmake/thirdparty/get_faiss.cmake b/cpp/cmake/thirdparty/get_faiss.cmake deleted file mode 100644 index e6f06a00a5..0000000000 --- a/cpp/cmake/thirdparty/get_faiss.cmake +++ /dev/null @@ -1,89 +0,0 @@ -#============================================================================= -# Copyright (c) 2021-2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#============================================================================= - -function(find_and_configure_faiss) - set(oneValueArgs VERSION REPOSITORY PINNED_TAG BUILD_STATIC_LIBS EXCLUDE_FROM_ALL) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) - - if(RAFT_ENABLE_NN_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) - rapids_find_generate_module(faiss - HEADER_NAMES faiss/IndexFlat.h - LIBRARY_NAMES faiss - ) - - set(BUILD_SHARED_LIBS ON) - if (PKG_BUILD_STATIC_LIBS) - set(BUILD_SHARED_LIBS OFF) - set(CPM_DOWNLOAD_faiss ON) - endif() - - rapids_cpm_find(faiss ${PKG_VERSION} - GLOBAL_TARGETS faiss::faiss - CPM_ARGS - GIT_REPOSITORY ${PKG_REPOSITORY} - GIT_TAG ${PKG_PINNED_TAG} - EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} - OPTIONS - "FAISS_ENABLE_PYTHON OFF" - "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" - "FAISS_ENABLE_GPU ON" - "BUILD_TESTING OFF" - "CMAKE_MESSAGE_LOG_LEVEL VERBOSE" - "FAISS_USE_CUDA_TOOLKIT_STATIC ${CUDA_STATIC_RUNTIME}" - ) - - if(TARGET faiss AND NOT TARGET faiss::faiss) - add_library(faiss::faiss ALIAS faiss) - endif() - - if(faiss_ADDED) - rapids_export(BUILD faiss - EXPORT_SET faiss-targets - GLOBAL_TARGETS faiss - NAMESPACE faiss::) - endif() - endif() - - # We generate the faiss-config files when we built faiss locally, so always do `find_dependency` - rapids_export_package(BUILD OpenMP raft-nn-lib-exports) # faiss uses openMP but doesn't export a need for it - rapids_export_package(BUILD faiss raft-nn-lib-exports GLOBAL_TARGETS faiss::faiss faiss) - rapids_export_package(INSTALL faiss raft-nn-lib-exports GLOBAL_TARGETS faiss::faiss faiss) - - # Tell cmake where it can find the generated faiss-config.cmake we wrote. - include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(BUILD faiss [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-lib-exports) -endfunction() - -if(NOT RAFT_FAISS_GIT_TAG) - # TODO: Remove this once faiss supports FAISS_USE_CUDA_TOOLKIT_STATIC - # (https://github.com/facebookresearch/faiss/pull/2446) - set(RAFT_FAISS_GIT_TAG fea/statically-link-ctk-v1.7.0) - # set(RAFT_FAISS_GIT_TAG bde7c0027191f29c9dadafe4f6e68ca0ee31fb30) -endif() - -if(NOT RAFT_FAISS_GIT_REPOSITORY) - # TODO: Remove this once faiss supports FAISS_USE_CUDA_TOOLKIT_STATIC - # (https://github.com/facebookresearch/faiss/pull/2446) - set(RAFT_FAISS_GIT_REPOSITORY https://github.com/trxcllnt/faiss.git) - # set(RAFT_FAISS_GIT_REPOSITORY https://github.com/facebookresearch/faiss.git) -endif() - -find_and_configure_faiss(VERSION 1.7.0 - REPOSITORY ${RAFT_FAISS_GIT_REPOSITORY} - PINNED_TAG ${RAFT_FAISS_GIT_TAG} - BUILD_STATIC_LIBS ${RAFT_USE_FAISS_STATIC} - EXCLUDE_FROM_ALL ${RAFT_EXCLUDE_FAISS_FROM_ALL}) diff --git a/cpp/include/raft/spatial/knn/detail/common_faiss.h b/cpp/include/raft/spatial/knn/detail/common_faiss.h deleted file mode 100644 index 57076350f0..0000000000 --- a/cpp/include/raft/spatial/knn/detail/common_faiss.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -inline faiss::MetricType build_faiss_metric(raft::distance::DistanceType metric) -{ - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::CorrelationExpanded: - return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::L2Expanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2Unexpanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2SqrtExpanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2SqrtUnexpanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L1: return faiss::MetricType::METRIC_L1; - case raft::distance::DistanceType::InnerProduct: return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::LpUnexpanded: return faiss::MetricType::METRIC_Lp; - case raft::distance::DistanceType::Linf: return faiss::MetricType::METRIC_Linf; - case raft::distance::DistanceType::Canberra: return faiss::MetricType::METRIC_Canberra; - case raft::distance::DistanceType::BrayCurtis: return faiss::MetricType::METRIC_BrayCurtis; - case raft::distance::DistanceType::JensenShannon: - return faiss::MetricType::METRIC_JensenShannon; - default: THROW("MetricType not supported: %d", metric); - } -} - -} // namespace detail -} // namespace knn -} // namespace spatial -} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/DistanceUtils.h b/cpp/include/raft/spatial/knn/detail/faiss_select/DistanceUtils.h new file mode 100644 index 0000000000..51b7955d5a --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/DistanceUtils.h @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +namespace raft::spatial::knn::detail::faiss_select { +// If the inner size (dim) of the vectors is small, we want a larger query tile +// size, like 1024 +inline void chooseTileSize(size_t numQueries, + size_t numCentroids, + size_t dim, + size_t elementSize, + size_t totalMem, + size_t& tileRows, + size_t& tileCols) +{ + // The matrix multiplication should be large enough to be efficient, but if + // it is too large, we seem to lose efficiency as opposed to + // double-streaming. Each tile size here defines 1/2 of the memory use due + // to double streaming. We ignore available temporary memory, as that is + // adjusted independently by the user and can thus meet these requirements + // (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs, + // prefer 768 MB of usage. Otherwise, prefer 1 GB of usage. + size_t targetUsage = 0; + + if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) { + targetUsage = 512 * 1024 * 1024; + } else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) { + targetUsage = 768 * 1024 * 1024; + } else { + targetUsage = 1024 * 1024 * 1024; + } + + targetUsage /= 2 * elementSize; + + // 512 seems to be a batch size sweetspot for float32. + // If we are on float16, increase to 512. + // If the k size (vec dim) of the matrix multiplication is small (<= 32), + // increase to 1024. + size_t preferredTileRows = 512; + if (dim <= 32) { preferredTileRows = 1024; } + + tileRows = std::min(preferredTileRows, numQueries); + + // tileCols is the remainder size + tileCols = std::min(targetUsage / preferredTileRows, numCentroids); +} +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index b246121958..22f18127b5 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -22,14 +22,14 @@ #include -#include - #include #include #include +#include #include +#include #include -#include +#include #include #include @@ -37,8 +37,6 @@ #include "haversine_distance.cuh" #include "processing.cuh" -#include "common_faiss.h" - namespace raft { namespace spatial { namespace knn { @@ -141,6 +139,128 @@ inline void knn_merge_parts_impl(value_t* inK, RAFT_CUDA_TRY(cudaPeekAtLastError()); } +/** + * Calculates brute force knn, using a fixed memory budget + * by tiling over both the rows and columns of pairwise_distances + */ +template +void tiled_brute_force_knn(const raft::handle_t& handle, + const ElementType* search, // size (m ,d) + const ElementType* index, // size (n ,d) + size_t m, + size_t n, + size_t d, + int k, + ElementType* distances, // size (m, k) + IndexType* indices, // size (m, k) + raft::distance::DistanceType metric, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0) +{ + // Figure out the number of rows/cols to tile for + size_t tile_rows = 0; + size_t tile_cols = 0; + auto stream = handle.get_stream(); + auto device_memory = handle.get_workspace_resource(); + auto total_mem = device_memory->get_mem_info(stream).second; + raft::spatial::knn::detail::faiss_select::chooseTileSize( + m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); + + // for unittesting, its convenient to be able to put a max size on the tiles + // so we can test the tiling logic without having to use huge inputs. + if (max_row_tile_size && (tile_rows > max_row_tile_size)) { tile_rows = max_row_tile_size; } + if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } + + // stores pairwise distances for the current tile + rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); + + // if we're tiling over columns, we need additional buffers for temporary output + // distances/indices + size_t num_col_tiles = raft::ceildiv(n, tile_cols); + size_t temp_out_cols = k * num_col_tiles; + + // the final column tile could have less than 'k' items in it + // in which case the number of columns here is too high in the temp output. + // adjust if necessary + auto last_col_tile_size = n % tile_cols; + if (last_col_tile_size && (last_col_tile_size < static_cast(k))) { + temp_out_cols -= k - last_col_tile_size; + } + rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); + rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); + + for (size_t i = 0; i < m; i += tile_rows) { + size_t current_query_size = std::min(tile_rows, m - i); + + for (size_t j = 0; j < n; j += tile_cols) { + size_t current_centroid_size = std::min(tile_cols, n - j); + size_t current_k = std::min(current_centroid_size, static_cast(k)); + + // calculate the top-k elements for the current tile, by calculating the + // full pairwise distance for the tile - and then selecting the top-k from that + distance::pairwise_distance(handle, + search + i * d, + index + j * d, + temp_distances.data(), + current_query_size, + current_centroid_size, + d, + metric, + true); + + detail::select_k(temp_distances.data(), + nullptr, + current_query_size, + current_centroid_size, + distances + i * k, + indices + i * k, + true, + current_k, + stream); + + // if we're tiling over columns, we need to do a couple things to fix up + // the output of select_k + // 1. The column id's in the output are relative to the tile, so we need + // to adjust the column ids by adding the column the tile starts at (j) + // 2. select_k writes out output in a row-major format, which means we + // can't just concat the output of all the tiles and do a select_k on the + // concatenation. + // Fix both of these problems in a single pass here + if (tile_cols != n) { + const ElementType* in_distances = distances + i * k; + const IndexType* in_indices = indices + i * k; + ElementType* out_distances = temp_out_distances.data(); + IndexType* out_indices = temp_out_indices.data(); + + auto count = thrust::make_counting_iterator(0); + thrust::for_each(handle.get_thrust_policy(), + count, + count + current_query_size * current_k, + [=] __device__(IndexType i) { + IndexType row = i / current_k, col = i % current_k; + IndexType out_index = row * temp_out_cols + j * k / tile_cols + col; + + out_distances[out_index] = in_distances[i]; + out_indices[out_index] = in_indices[i] + j; + }); + } + } + + if (tile_cols != n) { + // select the actual top-k items here from the temporary output + detail::select_k(temp_out_distances.data(), + temp_out_indices.data(), + current_query_size, + temp_out_cols, + distances + i * k, + indices + i * k, + true, + k, + stream); + } + } +} + /** * @brief Merge knn distances and index matrix, which have been partitioned * by row, into a single matrix with only the k-nearest neighbors. @@ -311,7 +431,6 @@ void brute_force_knn_impl( } else { switch (metric) { case raft::distance::DistanceType::Haversine: - ASSERT(D == 2, "Haversine distance requires 2 dimensions " "(latitude / longitude)."); @@ -319,35 +438,9 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: - faiss::MetricType m = build_faiss_metric(metric); - - raft::spatial::knn::RmmGpuResources gpu_res; - - gpu_res.noTempMemory(); - gpu_res.setDefaultStream(device, stream); - - faiss::gpu::GpuDistanceParams args; - args.metric = m; - args.metricArg = metricArg; - args.k = k; - args.dims = D; - args.vectors = input[i]; - args.vectorsRowMajor = rowMajorIndex; - args.numVectors = sizes[i]; - args.queries = search_items; - args.queriesRowMajor = rowMajorQuery; - args.numQueries = n; - args.outDistances = out_d_ptr; - args.outIndices = out_i_ptr; - args.outIndicesType = sizeof(IdxType) == 4 ? faiss::gpu::IndicesDataType::I32 - : faiss::gpu::IndicesDataType::I64; - - /** - * @todo: Until FAISS supports pluggable allocation strategies, - * we will not reap the benefits of the pool allocator for - * avoiding device-wide synchronizations from cudaMalloc/cudaFree - */ - bfKnn(&gpu_res, args); + tiled_brute_force_knn( + handle, input[i], search_items, sizes[i], n, D, k, out_d_ptr, out_i_ptr, metric); + break; } } diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index 2cdc0fae91..fa1f556f22 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -16,7 +16,6 @@ #pragma once -#include #include #include diff --git a/cpp/include/raft/spatial/knn/faiss_mr.hpp b/cpp/include/raft/spatial/knn/faiss_mr.hpp deleted file mode 100644 index 3cae417996..0000000000 --- a/cpp/include/raft/spatial/knn/faiss_mr.hpp +++ /dev/null @@ -1,640 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -/* -This code contains unnecessary code duplication. These could be deleted -once the relevant changes would be made on the FAISS side. Indeed most of -the logic in the below code is similar to FAISS's standard implementation -and should thus be inherited instead of duplicated. This FAISS's issue -once solved should allow the removal of the unnecessary duplicates -in this file : https://github.com/facebookresearch/faiss/issues/2097 -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -using namespace faiss::gpu; - -namespace { - -// How many streams per device we allocate by default (for multi-streaming) -constexpr int kNumStreams = 2; - -// Use 256 MiB of pinned memory for async CPU <-> GPU copies by default -constexpr size_t kDefaultPinnedMemoryAllocation = (size_t)256 * 1024 * 1024; - -// Default temporary memory allocation for <= 4 GiB memory GPUs -constexpr size_t k4GiBTempMem = (size_t)512 * 1024 * 1024; - -// Default temporary memory allocation for <= 8 GiB memory GPUs -constexpr size_t k8GiBTempMem = (size_t)1024 * 1024 * 1024; - -// Maximum temporary memory allocation for all GPUs -constexpr size_t kMaxTempMem = (size_t)1536 * 1024 * 1024; - -std::string allocsToString(const std::unordered_map& map) -{ - // Produce a sorted list of all outstanding allocations by type - std::unordered_map> stats; - - for (auto& entry : map) { - auto& a = entry.second; - - auto it = stats.find(a.type); - if (it != stats.end()) { - stats[a.type].first++; - stats[a.type].second += a.size; - } else { - stats[a.type] = std::make_pair(1, a.size); - } - } - - std::stringstream ss; - for (auto& entry : stats) { - ss << "Alloc type " << allocTypeToString(entry.first) << ": " << entry.second.first - << " allocations, " << entry.second.second << " bytes\n"; - } - - return ss.str(); -} - -} // namespace - -/// RMM implementation of the GpuResources object that provides for a -/// temporary memory manager -class RmmGpuResourcesImpl : public GpuResources { - public: - RmmGpuResourcesImpl() - : pinnedMemAlloc_(nullptr), - pinnedMemAllocSize_(0), - // let the adjustment function determine the memory size for us by passing - // in a huge value that will then be adjusted - tempMemSize_(getDefaultTempMemForGPU(-1, std::numeric_limits::max())), - pinnedMemSize_(kDefaultPinnedMemoryAllocation), - allocLogging_(false), - cmr(new rmm::mr::cuda_memory_resource), - mmr(new rmm::mr::managed_memory_resource), - pmr(new rmm::mr::pinned_memory_resource){}; - - ~RmmGpuResourcesImpl() - { - // The temporary memory allocator has allocated memory through us, so clean - // that up before we finish fully de-initializing ourselves - tempMemory_.clear(); - - // Make sure all allocations have been freed - bool allocError = false; - - for (auto& entry : allocs_) { - auto& map = entry.second; - - if (!map.empty()) { - std::cerr << "RmmGpuResources destroyed with allocations outstanding:\n" - << "Device " << entry.first << " outstanding allocations:\n"; - std::cerr << allocsToString(map); - allocError = true; - } - } - - FAISS_ASSERT_MSG(!allocError, "GPU memory allocations not properly cleaned up"); - - for (auto& entry : defaultStreams_) { - DeviceScope scope(entry.first); - - // We created these streams, so are responsible for destroying them - CUDA_VERIFY(cudaStreamDestroy(entry.second)); - } - - for (auto& entry : alternateStreams_) { - DeviceScope scope(entry.first); - - for (auto stream : entry.second) { - CUDA_VERIFY(cudaStreamDestroy(stream)); - } - } - - for (auto& entry : asyncCopyStreams_) { - DeviceScope scope(entry.first); - - CUDA_VERIFY(cudaStreamDestroy(entry.second)); - } - - for (auto& entry : blasHandles_) { - DeviceScope scope(entry.first); - - auto blasStatus = cublasDestroy(entry.second); - FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); - } - - if (pinnedMemAlloc_) { pmr->deallocate(pinnedMemAlloc_, pinnedMemAllocSize_); } - }; - - /// Disable allocation of temporary memory; all temporary memory - /// requests will call cudaMalloc / cudaFree at the point of use - void noTempMemory() { setTempMemory(0); }; - - /// Specify that we wish to use a certain fixed size of memory on - /// all devices as temporary memory. This is the upper bound for the GPU - /// memory that we will reserve. We will never go above 1.5 GiB on any GPU; - /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that. - /// To avoid any temporary memory allocation, pass 0. - void setTempMemory(size_t size) - { - if (tempMemSize_ != size) { - // adjust based on general limits - tempMemSize_ = getDefaultTempMemForGPU(-1, size); - - // We need to re-initialize memory resources for all current devices that - // have been initialized. - // This should be safe to do, even if we are currently running work, because - // the cudaFree call that this implies will force-synchronize all GPUs with - // the CPU - for (auto& p : tempMemory_) { - int device = p.first; - // Free the existing memory first - p.second.reset(); - - // Allocate new - p.second = std::unique_ptr( - new StackDeviceMemory(this, - p.first, - // adjust for this specific device - getDefaultTempMemForGPU(device, tempMemSize_))); - } - } - }; - - /// Set amount of pinned memory to allocate, for async GPU <-> CPU - /// transfers - void setPinnedMemory(size_t size) - { - // Should not call this after devices have been initialized - FAISS_ASSERT(defaultStreams_.size() == 0); - FAISS_ASSERT(!pinnedMemAlloc_); - - pinnedMemSize_ = size; - }; - - /// Called to change the stream for work ordering. We do not own `stream`; - /// i.e., it will not be destroyed when the GpuResources object gets cleaned - /// up. - /// We are guaranteed that all Faiss GPU work is ordered with respect to - /// this stream upon exit from an index or other Faiss GPU call. - void setDefaultStream(int device, cudaStream_t stream) - { - if (isInitialized(device)) { - // A new series of calls may not be ordered with what was the previous - // stream, so if the stream being specified is different, then we need to - // ensure ordering between the two (new stream waits on old). - auto it = userDefaultStreams_.find(device); - cudaStream_t prevStream = nullptr; - - if (it != userDefaultStreams_.end()) { - prevStream = it->second; - } else { - FAISS_ASSERT(defaultStreams_.count(device)); - prevStream = defaultStreams_[device]; - } - - if (prevStream != stream) { streamWait({stream}, {prevStream}); } - } - - userDefaultStreams_[device] = stream; - }; - - /// Revert the default stream to the original stream managed by this resources - /// object, in case someone called `setDefaultStream`. - void revertDefaultStream(int device) - { - if (isInitialized(device)) { - auto it = userDefaultStreams_.find(device); - - if (it != userDefaultStreams_.end()) { - // There was a user stream set that we need to synchronize against - cudaStream_t prevStream = userDefaultStreams_[device]; - - FAISS_ASSERT(defaultStreams_.count(device)); - cudaStream_t newStream = defaultStreams_[device]; - - streamWait({newStream}, {prevStream}); - } - } - - userDefaultStreams_.erase(device); - }; - - /// Returns the stream for the given device on which all Faiss GPU work is - /// ordered. - /// We are guaranteed that all Faiss GPU work is ordered with respect to - /// this stream upon exit from an index or other Faiss GPU call. - cudaStream_t getDefaultStream(int device) - { - initializeForDevice(device); - - auto it = userDefaultStreams_.find(device); - if (it != userDefaultStreams_.end()) { - // There is a user override stream set - return it->second; - } - - // Otherwise, our base default stream - return defaultStreams_[device]; - }; - - /// Called to change the work ordering streams to the null stream - /// for all devices - void setDefaultNullStreamAllDevices() - { - for (int dev = 0; dev < getNumDevices(); ++dev) { - setDefaultStream(dev, nullptr); - } - }; - - /// If enabled, will print every GPU memory allocation and deallocation to - /// standard output - void setLogMemoryAllocations(bool enable) { allocLogging_ = enable; }; - - public: - /// Internal system calls - - /// Initialize resources for this device - void initializeForDevice(int device) - { - if (isInitialized(device)) { return; } - - // If this is the first device that we're initializing, create our - // pinned memory allocation - if (defaultStreams_.empty() && pinnedMemSize_ > 0) { - pinnedMemAlloc_ = pmr->allocate(pinnedMemSize_); - pinnedMemAllocSize_ = pinnedMemSize_; - } - - FAISS_ASSERT(device < getNumDevices()); - DeviceScope scope(device); - - // Make sure that device properties for all devices are cached - auto& prop = getDeviceProperties(device); - - // Also check to make sure we meet our minimum compute capability (3.0) - FAISS_ASSERT_FMT(prop.major >= 3, - "Device id %d with CC %d.%d not supported, " - "need 3.0+ compute capability", - device, - prop.major, - prop.minor); - - // Create streams - cudaStream_t defaultStream = 0; - CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream, cudaStreamNonBlocking)); - - defaultStreams_[device] = defaultStream; - - cudaStream_t asyncCopyStream = 0; - CUDA_VERIFY(cudaStreamCreateWithFlags(&asyncCopyStream, cudaStreamNonBlocking)); - - asyncCopyStreams_[device] = asyncCopyStream; - - std::vector deviceStreams; - for (int j = 0; j < kNumStreams; ++j) { - cudaStream_t stream = 0; - CUDA_VERIFY(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - - deviceStreams.push_back(stream); - } - - alternateStreams_[device] = std::move(deviceStreams); - - // Create cuBLAS handle - cublasHandle_t blasHandle = 0; - auto blasStatus = cublasCreate(&blasHandle); - FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); - blasHandles_[device] = blasHandle; - - // For CUDA 10 on V100, enabling tensor core usage would enable automatic - // rounding down of inputs to f16 (though accumulate in f32) which results in - // unacceptable loss of precision in general. - // For CUDA 11 / A100, only enable tensor core support if it doesn't result in - // a loss of precision. -#if CUDA_VERSION >= 11000 - cublasSetMathMode(blasHandle, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION); -#endif - - FAISS_ASSERT(allocs_.count(device) == 0); - allocs_[device] = std::unordered_map(); - - FAISS_ASSERT(tempMemory_.count(device) == 0); - auto mem = std::unique_ptr( - new StackDeviceMemory(this, - device, - // adjust for this specific device - getDefaultTempMemForGPU(device, tempMemSize_))); - - tempMemory_.emplace(device, std::move(mem)); - }; - - cublasHandle_t getBlasHandle(int device) - { - initializeForDevice(device); - return blasHandles_[device]; - }; - - std::vector getAlternateStreams(int device) - { - initializeForDevice(device); - return alternateStreams_[device]; - }; - - /// Allocate non-temporary GPU memory - void* allocMemory(const AllocRequest& req) - { - initializeForDevice(req.device); - - // We don't allocate a placeholder for zero-sized allocations - if (req.size == 0) { return nullptr; } - - // Make sure that the allocation is a multiple of 16 bytes for alignment - // purposes - auto adjReq = req; - adjReq.size = utils::roundUp(adjReq.size, (size_t)16); - - void* p = nullptr; - - if (allocLogging_) { std::cout << "RmmGpuResources: alloc " << adjReq.toString() << "\n"; } - - if (adjReq.space == MemorySpace::Temporary) { - // If we don't have enough space in our temporary memory manager, we need - // to allocate this request separately - auto& tempMem = tempMemory_[adjReq.device]; - - if (adjReq.size > tempMem->getSizeAvailable()) { - // We need to allocate this ourselves - AllocRequest newReq = adjReq; - newReq.space = MemorySpace::Device; - newReq.type = AllocType::TemporaryMemoryOverflow; - - return allocMemory(newReq); - } - - // Otherwise, we can handle this locally - p = tempMemory_[adjReq.device]->allocMemory(adjReq.stream, adjReq.size); - - } else if (adjReq.space == MemorySpace::Device) { - p = cmr->allocate(adjReq.size, adjReq.stream); - } else if (adjReq.space == MemorySpace::Unified) { - p = mmr->allocate(adjReq.size, adjReq.stream); - } else { - FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int)adjReq.space); - } - - allocs_[adjReq.device][p] = adjReq; - - return p; - }; - - /// Returns a previous allocation - void deallocMemory(int device, void* p) - { - FAISS_ASSERT(isInitialized(device)); - - if (!p) { return; } - - auto& a = allocs_[device]; - auto it = a.find(p); - FAISS_ASSERT(it != a.end()); - - auto& req = it->second; - - if (allocLogging_) { std::cout << "RmmGpuResources: dealloc " << req.toString() << "\n"; } - - if (req.space == MemorySpace::Temporary) { - tempMemory_[device]->deallocMemory(device, req.stream, req.size, p); - } else if (req.space == MemorySpace::Device) { - cmr->deallocate(p, req.size, req.stream); - } else if (req.space == MemorySpace::Unified) { - mmr->deallocate(p, req.size, req.stream); - } else { - FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int)req.space); - } - - a.erase(it); - }; - - size_t getTempMemoryAvailable(int device) const - { - FAISS_ASSERT(isInitialized(device)); - - auto it = tempMemory_.find(device); - FAISS_ASSERT(it != tempMemory_.end()); - - return it->second->getSizeAvailable(); - }; - - /// Export a description of memory used for Python - std::map>> getMemoryInfo() const - { - using AT = std::map>; - - std::map out; - - for (auto& entry : allocs_) { - AT outDevice; - - for (auto& a : entry.second) { - auto& v = outDevice[allocTypeToString(a.second.type)]; - v.first++; - v.second += a.second.size; - } - - out[entry.first] = std::move(outDevice); - } - - return out; - }; - - std::pair getPinnedMemory() - { - return std::make_pair(pinnedMemAlloc_, pinnedMemAllocSize_); - }; - - cudaStream_t getAsyncCopyStream(int device) - { - initializeForDevice(device); - return asyncCopyStreams_[device]; - }; - - private: - /// Have GPU resources been initialized for this device yet? - bool isInitialized(int device) const - { - // Use default streams as a marker for whether or not a certain - // device has been initialized - return defaultStreams_.count(device) != 0; - }; - - /// Adjust the default temporary memory allocation based on the total GPU - /// memory size - static size_t getDefaultTempMemForGPU(int device, size_t requested) - { - auto totalMem = device != -1 ? getDeviceProperties(device).totalGlobalMem - : std::numeric_limits::max(); - - if (totalMem <= (size_t)4 * 1024 * 1024 * 1024) { - // If the GPU has <= 4 GiB of memory, reserve 512 MiB - - if (requested > k4GiBTempMem) { return k4GiBTempMem; } - } else if (totalMem <= (size_t)8 * 1024 * 1024 * 1024) { - // If the GPU has <= 8 GiB of memory, reserve 1 GiB - - if (requested > k8GiBTempMem) { return k8GiBTempMem; } - } else { - // Never use more than 1.5 GiB - if (requested > kMaxTempMem) { return kMaxTempMem; } - } - - // use whatever lower limit the user requested - return requested; - }; - - private: - /// Set of currently outstanding memory allocations per device - /// device -> (alloc request, allocated ptr) - std::unordered_map> allocs_; - - /// Temporary memory provider, per each device - std::unordered_map> tempMemory_; - - /// Our default stream that work is ordered on, one per each device - std::unordered_map defaultStreams_; - - /// This contains particular streams as set by the user for - /// ordering, if any - std::unordered_map userDefaultStreams_; - - /// Other streams we can use, per each device - std::unordered_map> alternateStreams_; - - /// Async copy stream to use for GPU <-> CPU pinned memory copies - std::unordered_map asyncCopyStreams_; - - /// cuBLAS handle for each device - std::unordered_map blasHandles_; - - /// Pinned memory allocation for use with this GPU - void* pinnedMemAlloc_; - size_t pinnedMemAllocSize_; - - /// Another option is to use a specified amount of memory on all - /// devices - size_t tempMemSize_; - - /// Amount of pinned memory we should allocate - size_t pinnedMemSize_; - - /// Whether or not we log every GPU memory allocation and deallocation - bool allocLogging_; - - // cuda_memory_resource - std::unique_ptr cmr; - - // managed_memory_resource - std::unique_ptr mmr; - - // pinned_memory_resource - std::unique_ptr pmr; -}; - -/// Default implementation of GpuResources that allocates a cuBLAS -/// stream and 2 streams for use, as well as temporary memory. -/// Internally, the Faiss GPU code uses the instance managed by getResources, -/// but this is the user-facing object that is internally reference counted. -class RmmGpuResources : public GpuResourcesProvider { - public: - RmmGpuResources() : res_(new RmmGpuResourcesImpl){}; - - ~RmmGpuResources(){}; - - std::shared_ptr getResources() { return res_; }; - - /// Disable allocation of temporary memory; all temporary memory - /// requests will call cudaMalloc / cudaFree at the point of use - void noTempMemory() { res_->noTempMemory(); }; - - /// Specify that we wish to use a certain fixed size of memory on - /// all devices as temporary memory. This is the upper bound for the GPU - /// memory that we will reserve. We will never go above 1.5 GiB on any GPU; - /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that. - /// To avoid any temporary memory allocation, pass 0. - void setTempMemory(size_t size) { res_->setTempMemory(size); }; - - /// Set amount of pinned memory to allocate, for async GPU <-> CPU - /// transfers - void setPinnedMemory(size_t size) { res_->setPinnedMemory(size); }; - - /// Called to change the stream for work ordering. We do not own `stream`; - /// i.e., it will not be destroyed when the GpuResources object gets cleaned - /// up. - /// We are guaranteed that all Faiss GPU work is ordered with respect to - /// this stream upon exit from an index or other Faiss GPU call. - void setDefaultStream(int device, cudaStream_t stream) - { - res_->setDefaultStream(device, stream); - }; - - /// Revert the default stream to the original stream managed by this resources - /// object, in case someone called `setDefaultStream`. - void revertDefaultStream(int device) { res_->revertDefaultStream(device); }; - - /// Called to change the work ordering streams to the null stream - /// for all devices - void setDefaultNullStreamAllDevices() { res_->setDefaultNullStreamAllDevices(); }; - - /// Export a description of memory used for Python - std::map>> getMemoryInfo() const - { - return res_->getMemoryInfo(); - }; - - /// Returns the current default stream - cudaStream_t getDefaultStream(int device) { return res_->getDefaultStream(device); }; - - /// Returns the current amount of temp memory available - size_t getTempMemoryAvailable(int device) const { return res_->getTempMemoryAvailable(device); }; - - /// Synchronize our default stream with the CPU - void syncDefaultStreamCurrentDevice() { res_->syncDefaultStreamCurrentDevice(); }; - - /// If enabled, will print every GPU memory allocation and deallocation to - /// standard output - void setLogMemoryAllocations(bool enable) { res_->setLogMemoryAllocations(enable); }; - - private: - std::shared_ptr res_; -}; - -} // namespace knn -} // namespace spatial -} // namespace raft \ No newline at end of file diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 8ca30a5c82..049899754a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -244,10 +244,10 @@ if(BUILD_TESTS) test/neighbors/ann_ivf_pq/test_uint8_t_uint64_t.cu test/neighbors/knn.cu test/neighbors/fused_l2_knn.cu + test/neighbors/tiled_knn.cu test/neighbors/haversine.cu test/neighbors/ball_cover.cu test/neighbors/epsilon_neighborhood.cu - test/neighbors/faiss_mr.cu test/neighbors/refine.cu test/neighbors/selection.cu OPTIONAL diff --git a/cpp/test/neighbors/faiss_mr.cu b/cpp/test/neighbors/faiss_mr.cu deleted file mode 100644 index 38e793d120..0000000000 --- a/cpp/test/neighbors/faiss_mr.cu +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" - -#include -#include -#include - -#include - -#include - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -using namespace faiss::gpu; - -struct AllocInputs { - size_t size; -}; - -template -class FAISS_MR_Test : public ::testing::TestWithParam { - public: - FAISS_MR_Test() - : params_(::testing::TestWithParam::GetParam()), stream(handle.get_stream()) - { - } - - protected: - size_t getFreeMemory(MemorySpace mem_space) - { - if (mem_space == MemorySpace::Device) { - rmm::mr::cuda_memory_resource cmr; - rmm::mr::device_memory_resource* dmr = &cmr; - return dmr->get_mem_info(stream).first; - } else if (mem_space == MemorySpace::Unified) { - rmm::mr::managed_memory_resource mmr; - rmm::mr::device_memory_resource* dmr = &mmr; - return dmr->get_mem_info(stream).first; - } - return 0; - } - - void testAllocs(MemorySpace mem_space) - { - raft::spatial::knn::RmmGpuResources faiss_mr; - auto faiss_mr_impl = faiss_mr.getResources(); - size_t free_before = getFreeMemory(mem_space); - AllocRequest req(AllocType::Other, 0, mem_space, stream, params_.size); - void* ptr = faiss_mr_impl->allocMemory(req); - size_t free_after_alloc = getFreeMemory(mem_space); - faiss_mr_impl->deallocMemory(0, ptr); - ASSERT_TRUE(free_after_alloc <= free_before - params_.size); - } - - raft::handle_t handle; - cudaStream_t stream; - AllocInputs params_; -}; - -const std::vector inputs = {{19687}}; - -typedef FAISS_MR_Test FAISS_MR_TestF; -TEST_P(FAISS_MR_TestF, TestAllocs) -{ - testAllocs(MemorySpace::Device); - testAllocs(MemorySpace::Unified); -} - -INSTANTIATE_TEST_CASE_P(FAISS_MR_Test, FAISS_MR_TestF, ::testing::ValuesIn(inputs)); - -} // namespace knn -} // namespace spatial -} // namespace raft diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index ca20bebaf6..5bb72e6b3d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -15,6 +15,7 @@ */ #include "../test_utils.cuh" +#include "./knn_utils.cuh" #include #include @@ -45,59 +46,6 @@ struct FusedL2KNNInputs { raft::distance::DistanceType metric_; }; -template -struct idx_dist_pair { - IdxT idx; - DistT dist; - compareDist eq_compare; - bool operator==(const idx_dist_pair& a) const - { - if (idx == a.idx) return true; - if (eq_compare(dist, a.dist)) return true; - return false; - } - idx_dist_pair(IdxT x, DistT y, compareDist op) : idx(x), dist(y), eq_compare(op) {} -}; - -template -testing::AssertionResult devArrMatchKnnPair(const T* expected_idx, - const T* actual_idx, - const DistT* expected_dist, - const DistT* actual_dist, - size_t rows, - size_t cols, - const DistT eps, - cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr exp_idx_h(new T[size]); - std::unique_ptr act_idx_h(new T[size]); - std::unique_ptr exp_dist_h(new DistT[size]); - std::unique_ptr act_dist_h(new DistT[size]); - raft::update_host(exp_idx_h.get(), expected_idx, size, stream); - raft::update_host(act_idx_h.get(), actual_idx, size, stream); - raft::update_host(exp_dist_h.get(), expected_dist, size, stream); - raft::update_host(act_dist_h.get(), actual_dist, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! - auto exp_idx = exp_idx_h.get()[idx]; - auto act_idx = act_idx_h.get()[idx]; - auto exp_dist = exp_dist_h.get()[idx]; - auto act_dist = act_dist_h.get()[idx]; - idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox(eps)); - idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox(eps)); - if (!(exp_kvp == act_kvp)) { - return testing::AssertionFailure() - << "actual=" << act_kvp.idx << "," << act_kvp.dist << "!=" - << "expected" << exp_kvp.idx << "," << exp_kvp.dist << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - template class FusedL2KNNTest : public ::testing::TestWithParam { public: diff --git a/cpp/test/neighbors/knn_utils.cuh b/cpp/test/neighbors/knn_utils.cuh new file mode 100644 index 0000000000..2c4dad5c0b --- /dev/null +++ b/cpp/test/neighbors/knn_utils.cuh @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../test_utils.cuh" +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::spatial::knn { +template +struct idx_dist_pair { + IdxT idx; + DistT dist; + compareDist eq_compare; + bool operator==(const idx_dist_pair& a) const + { + if (idx == a.idx) return true; + if (eq_compare(dist, a.dist)) return true; + return false; + } + idx_dist_pair(IdxT x, DistT y, compareDist op) : idx(x), dist(y), eq_compare(op) {} +}; + +template +testing::AssertionResult devArrMatchKnnPair(const T* expected_idx, + const T* actual_idx, + const DistT* expected_dist, + const DistT* actual_dist, + size_t rows, + size_t cols, + const DistT eps, + cudaStream_t stream = 0) +{ + size_t size = rows * cols; + std::unique_ptr exp_idx_h(new T[size]); + std::unique_ptr act_idx_h(new T[size]); + std::unique_ptr exp_dist_h(new DistT[size]); + std::unique_ptr act_dist_h(new DistT[size]); + raft::update_host(exp_idx_h.get(), expected_idx, size, stream); + raft::update_host(act_idx_h.get(), actual_idx, size, stream); + raft::update_host(exp_dist_h.get(), expected_dist, size, stream); + raft::update_host(act_dist_h.get(), actual_dist, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto exp_idx = exp_idx_h.get()[idx]; + auto act_idx = act_idx_h.get()[idx]; + auto exp_dist = exp_dist_h.get()[idx]; + auto act_dist = act_dist_h.get()[idx]; + idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox(eps)); + idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox(eps)); + if (!(exp_kvp == act_kvp)) { + return testing::AssertionFailure() + << "actual=" << act_kvp.idx << "," << act_kvp.dist << "!=" + << "expected" << exp_kvp.idx << "," << exp_kvp.dist << " @" << i << "," << j; + } + } + } + return testing::AssertionSuccess(); +} +} // namespace raft::spatial::knn diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu new file mode 100644 index 0000000000..0805b628f7 --- /dev/null +++ b/cpp/test/neighbors/tiled_knn.cu @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "./knn_utils.cuh" + +#include +#include +#include +#include + +#if defined RAFT_NN_COMPILED +#include +#include +#endif + +#include + +#include + +#include + +#include +#include +#include + +namespace raft::neighbors::brute_force { +struct TiledKNNInputs { + int num_queries; + int num_db_vecs; + int dim; + int k; + int row_tiles; + int col_tiles; + raft::distance::DistanceType metric_; +}; + +template +class TiledKNNTest : public ::testing::TestWithParam { + public: + TiledKNNTest() + : stream_(handle_.get_stream()), + params_(::testing::TestWithParam::GetParam()), + database(params_.num_db_vecs * params_.dim, stream_), + search_queries(params_.num_queries * params_.dim, stream_), + raft_indices_(params_.num_queries * params_.k, stream_), + raft_distances_(params_.num_queries * params_.k, stream_), + ref_indices_(params_.num_queries * params_.k, stream_), + ref_distances_(params_.num_queries * params_.k, stream_) + { + RAFT_CUDA_TRY(cudaMemsetAsync(database.data(), 0, database.size() * sizeof(T), stream_)); + RAFT_CUDA_TRY( + cudaMemsetAsync(search_queries.data(), 0, search_queries.size() * sizeof(T), stream_)); + RAFT_CUDA_TRY( + cudaMemsetAsync(raft_indices_.data(), 0, raft_indices_.size() * sizeof(int), stream_)); + RAFT_CUDA_TRY( + cudaMemsetAsync(raft_distances_.data(), 0, raft_distances_.size() * sizeof(T), stream_)); + RAFT_CUDA_TRY( + cudaMemsetAsync(ref_indices_.data(), 0, ref_indices_.size() * sizeof(int), stream_)); + RAFT_CUDA_TRY( + cudaMemsetAsync(ref_distances_.data(), 0, ref_distances_.size() * sizeof(T), stream_)); + } + + protected: + void testBruteForce() + { + // calculate the naive knn, by calculating the full pairwise distances and doing a k-select + rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); + distance::pairwise_distance( + handle_, + raft::make_device_matrix_view(search_queries.data(), num_queries, dim), + raft::make_device_matrix_view(database.data(), num_db_vecs, dim), + raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), + metric); + + using namespace raft::spatial; + knn::select_k(temp_distances.data(), + nullptr, + num_queries, + num_db_vecs, + ref_distances_.data(), + ref_indices_.data(), + true, + k_, + stream_); + + knn::detail::tiled_brute_force_knn(handle_, + search_queries.data(), + database.data(), + num_queries, + num_db_vecs, + dim, + k_, + raft_distances_.data(), + raft_indices_.data(), + metric, + params_.row_tiles, + params_.col_tiles); + + // verify. + ASSERT_TRUE(knn::devArrMatchKnnPair(ref_indices_.data(), + raft_indices_.data(), + ref_distances_.data(), + raft_distances_.data(), + num_queries, + k_, + float(0.001), + stream_)); + } + + void SetUp() override + { + num_queries = params_.num_queries; + num_db_vecs = params_.num_db_vecs; + dim = params_.dim; + k_ = params_.k; + metric = params_.metric_; + + unsigned long long int seed = 1234ULL; + raft::random::RngState r(seed); + uniform(handle_, r, database.data(), num_db_vecs * dim, T(-1.0), T(1.0)); + uniform(handle_, r, search_queries.data(), num_queries * dim, T(-1.0), T(1.0)); + } + + private: + raft::handle_t handle_; + cudaStream_t stream_ = 0; + TiledKNNInputs params_; + int num_queries; + int num_db_vecs; + int dim; + rmm::device_uvector database; + rmm::device_uvector search_queries; + rmm::device_uvector raft_indices_; + rmm::device_uvector raft_distances_; + rmm::device_uvector ref_indices_; + rmm::device_uvector ref_distances_; + int k_; + raft::distance::DistanceType metric; +}; + +const std::vector random_inputs = { + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Expanded}, + {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, + {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded}, +}; + +typedef TiledKNNTest TiledKNNTestF; +TEST_P(TiledKNNTestF, BruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(TiledKNNTest, TiledKNNTestF, ::testing::ValuesIn(random_inputs)); +} // namespace raft::neighbors::brute_force diff --git a/dependencies.yaml b/dependencies.yaml index ae900542c0..36d7dd809f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -172,8 +172,6 @@ dependencies: - ucx>=1.13.0 - ucx-py=0.30 - ucx-proc=*=gpu - - libfaiss>=1.7.1=cuda* - - faiss-proc=*=cuda - dask-cuda=23.02 test_python: common: diff --git a/docs/source/build.md b/docs/source/build.md index 4052e49cf8..a6e16d3824 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -47,7 +47,6 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth - [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API. - [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 - Used by cuCollections - [CUTLASS](https://github.com/NVIDIA/cutlass) v2.9.1 - Used in `raft::distance` API. -- [FAISS](https://github.com/facebookresearch/faiss) v1.7.0 - Used in `raft::neighbors` API. - [NCCL](https://github.com/NVIDIA/nccl) - Used in `raft::comms` API and needed to build `raft-dask`. - [UCX](https://github.com/openucx/ucx) - Used in `raft::comms` API and needed to build `raft-dask`. - [Googletest](https://github.com/google/googletest) - Needed to build tests @@ -60,14 +59,14 @@ The recommended way to build and install RAFT is to use the `build.sh` script in ### Header-only C++ -`build.sh` uses [rapids-cmake](https://github.com/rapidsai/rapids-cmake), which will automatically download any dependencies which are not already installed. It's important to note that while all the headers will be installed and available, some parts of the RAFT API depend on libraries like `FAISS`, which will need to be explicitly enabled in `build.sh`. +`build.sh` uses [rapids-cmake](https://github.com/rapidsai/rapids-cmake), which will automatically download any dependencies which are not already installed. It's important to note that while all the headers will be installed and available, some parts of the RAFT API depend on libraries like `cuCollections`, which will need to be explicitly enabled in `build.sh`. The following example will download the needed dependencies and install the RAFT headers into `$INSTALL_PREFIX/include/raft`. ```bash ./build.sh libraft ``` -The `-n` flag can be passed to just have the build download the needed dependencies. Since RAFT is primarily used at build-time, the dependencies will never be installed by the RAFT build, with the exception of building FAISS statically into the shared libraries. +The `-n` flag can be passed to just have the build download the needed dependencies. Since RAFT is primarily used at build-time, the dependencies will never be installed by the RAFT build. ```bash ./build.sh libraft -n ``` @@ -167,8 +166,7 @@ RAFT's cmake has the following configurable flags available:. | RAFT_COMPILE_LIBRARIES | ON, OFF | OFF | Compiles all `libraft` shared libraries (these are required for Googletests) | | RAFT_COMPILE_NN_LIBRARY | ON, OFF | OFF | Compiles the `libraft-nn` shared library | | RAFT_COMPILE_DIST_LIBRARY | ON, OFF | OFF | Compiles the `libraft-distance` shared library | -| RAFT_ENABLE_NN_DEPENDENCIES | ON, OFF | OFF | Searches for dependencies of nearest neighbors API, such as FAISS, and compiles them if not found. Needed for `raft::spatial::knn` | -| RAFT_USE_FAISS_STATIC | ON, OFF | OFF | Statically link FAISS into `libraft-nn` | +| RAFT_ENABLE_NN_DEPENDENCIES | ON, OFF | OFF | Searches for dependencies of nearest neighbors API, and compiles them if not found. Needed for `raft::spatial::knn` | | RAFT_STATIC_LINK_LIBRARIES | ON, OFF | ON | Build static link libraries instead of shared libraries | | DETECT_CONDA_ENV | ON, OFF | ON | Enable detection of conda environment for dependencies | | NVTX | ON, OFF | OFF | Enable NVTX Markers | @@ -176,7 +174,7 @@ RAFT's cmake has the following configurable flags available:. | CUDA_ENABLE_LINEINFO | ON, OFF | OFF | Enable the -lineinfo option for nvcc | | CUDA_STATIC_RUNTIME | ON, OFF | OFF | Statically link the CUDA runtime | -Currently, shared libraries are provided for the `libraft-nn` and `libraft-distance` components. The `libraft-nn` component depends upon [FAISS](https://github.com/facebookresearch/faiss) and the `RAFT_ENABLE_NN_DEPENDENCIES` option will build it from source if it is not already installed. +Currently, shared libraries are provided for the `libraft-nn` and `libraft-distance` components. ### Python @@ -278,7 +276,7 @@ If RAFT has already been installed, such as by using the `build.sh` script, use ### Using C++ pre-compiled shared libraries -Use `find_package(raft COMPONENTS nn distance)` to enable the shared libraries and transitively pass dependencies through separate targets for each component. In this example, the `raft::distance` and `raft::nn` targets will be available for configuring linking paths in addition to `raft::raft`. These targets will also pass through any transitive dependencies (such as FAISS for the `nn` package). +Use `find_package(raft COMPONENTS nn distance)` to enable the shared libraries and transitively pass dependencies through separate targets for each component. In this example, the `raft::distance` and `raft::nn` targets will be available for configuring linking paths in addition to `raft::raft`. These targets will also pass through any transitive dependencies. The pre-compiled libraries contain template specializations for commonly used types, such as single- and double-precision floating-point. In order to use the symbols in the pre-compiled libraries, the compiler needs to be told not to instantiate templates that are already contained in the shared libraries. By convention, these header files are named `specializations.cuh` and located in the base directory for the packages that contain specializations. @@ -303,7 +301,7 @@ set(RAFT_FORK "rapidsai") set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG USE_FAISS_STATIC + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARIES ENABLE_NN_DEPENDENCIES CLONE_ON_PIN USE_NN_LIBRARY USE_DISTANCE_LIBRARY ENABLE_thrust_DEPENDENCY) @@ -348,7 +346,6 @@ function(find_and_configure_raft) "BUILD_TESTS OFF" "BUILD_BENCH OFF" "RAFT_ENABLE_NN_DEPENDENCIES ${PKG_ENABLE_NN_DEPENDENCIES}" - "RAFT_USE_FAISS_STATIC ${PKG_USE_FAISS_STATIC}" "RAFT_COMPILE_LIBRARIES ${PKG_COMPILE_LIBRARIES}" "RAFT_ENABLE_thrust_DEPENDENCY ${PKG_ENABLE_thrust_DEPENDENCY}" ) @@ -370,8 +367,7 @@ find_and_configure_raft(VERSION ${RAFT_VERSION}.00 COMPILE_LIBRARIES NO USE_NN_LIBRARY NO USE_DISTANCE_LIBRARY NO - ENABLE_NN_DEPENDENCIES NO # This builds FAISS if not installed - USE_FAISS_STATIC NO + ENABLE_NN_DEPENDENCIES NO ENABLE_thrust_DEPENDENCY YES ) ``` @@ -388,4 +384,4 @@ Once built and installed, RAFT can be safely uninstalled using `build.sh` by spe Leaving off the installed components will uninstall everything that's been installed: ```bash ./build.sh --uninstall -``` \ No newline at end of file +``` From 805abc72e071b353b99e09dea5c1d494cc9e3f62 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 27 Jan 2023 11:36:06 -0800 Subject: [PATCH 02/62] fix merge --- cpp/test/neighbors/tiled_knn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 0805b628f7..ae5bc1e976 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -136,7 +136,7 @@ class TiledKNNTest : public ::testing::TestWithParam { } private: - raft::handle_t handle_; + raft::device_resources handle_; cudaStream_t stream_ = 0; TiledKNNInputs params_; int num_queries; From 74bd44fdc5071e14f5f4df3daf92ec3d3174b041 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 27 Jan 2023 14:32:58 -0800 Subject: [PATCH 03/62] Fix bug with col_tiles < K --- .../raft/spatial/knn/detail/knn_brute_force_faiss.cuh | 3 +++ cpp/test/neighbors/tiled_knn.cu | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 8efca16613..dafc44ff3f 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -171,6 +171,9 @@ void tiled_brute_force_knn(const raft::device_resources& handle, if (max_row_tile_size && (tile_rows > max_row_tile_size)) { tile_rows = max_row_tile_size; } if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } + // tile_cols must be at least k items + tile_cols = std::max(tile_cols, static_cast(k)); + // stores pairwise distances for the current tile rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index ae5bc1e976..3dd9a5394b 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -154,8 +154,13 @@ class TiledKNNTest : public ::testing::TestWithParam { const std::vector random_inputs = { {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Expanded}, - {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded}, + {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded}, + {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded}, + // Test where the final column tile has < K items: + {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, + // Test where passing column_tiles < K + {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded}, }; typedef TiledKNNTest TiledKNNTestF; From 1d9581b61fd8b1ded2997b7c4aa2ff19beca6936 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 30 Jan 2023 12:33:35 -0800 Subject: [PATCH 04/62] Include metric_arg in bfknn --- .../knn/detail/knn_brute_force_faiss.cuh | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index dafc44ff3f..be08c1e3aa 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -154,6 +154,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, + float metric_arg = 0.0, size_t max_row_tile_size = 0, size_t max_col_tile_size = 0) { @@ -209,7 +210,8 @@ void tiled_brute_force_knn(const raft::device_resources& handle, current_centroid_size, d, metric, - true); + true, + metric_arg); detail::select_k(temp_distances.data(), nullptr, @@ -441,8 +443,17 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: - tiled_brute_force_knn( - handle, input[i], search_items, sizes[i], n, D, k, out_d_ptr, out_i_ptr, metric); + tiled_brute_force_knn(handle, + input[i], + search_items, + sizes[i], + n, + D, + k, + out_d_ptr, + out_i_ptr, + metric, + metricArg); break; } } From b4cf88c5dab6e1c26c398001c89dbe9f4e51a487 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 30 Jan 2023 13:07:41 -0800 Subject: [PATCH 05/62] speedup compile times --- cpp/CMakeLists.txt | 5 +- cpp/src/nn/specializations/knn.cu | 86 ------------------- .../nn/specializations/knn_long_float_int.cu | 45 ++++++++++ .../nn/specializations/knn_long_float_uint.cu | 44 ++++++++++ .../nn/specializations/knn_uint_float_int.cu | 44 ++++++++++ .../nn/specializations/knn_uint_float_uint.cu | 46 ++++++++++ cpp/test/neighbors/knn.cu | 7 +- 7 files changed, 187 insertions(+), 90 deletions(-) delete mode 100644 cpp/src/nn/specializations/knn.cu create mode 100644 cpp/src/nn/specializations/knn_long_float_int.cu create mode 100644 cpp/src/nn/specializations/knn_long_float_uint.cu create mode 100644 cpp/src/nn/specializations/knn_uint_float_int.cu create mode 100644 cpp/src/nn/specializations/knn_uint_float_uint.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7bf63f484f..a16f1d25a0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -414,7 +414,10 @@ if(RAFT_COMPILE_NN_LIBRARY) src/nn/specializations/fused_l2_knn_long_float_false.cu src/nn/specializations/fused_l2_knn_int_float_true.cu src/nn/specializations/fused_l2_knn_int_float_false.cu - src/nn/specializations/knn.cu + src/nn/specializations/knn_long_float_int.cu + src/nn/specializations/knn_long_float_uint.cu + src/nn/specializations/knn_uint_float_int.cu + src/nn/specializations/knn_uint_float_uint.cu ) set_target_properties( raft_nn_lib diff --git a/cpp/src/nn/specializations/knn.cu b/cpp/src/nn/specializations/knn.cu deleted file mode 100644 index d135610bfb..0000000000 --- a/cpp/src/nn/specializations/knn.cu +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_long_float_int.cu b/cpp/src/nn/specializations/knn_long_float_int.cu new file mode 100644 index 0000000000..1360430132 --- /dev/null +++ b/cpp/src/nn/specializations/knn_long_float_int.cu @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +// include specializations for pairwise_distance and fusedl2knn +// to avoid recompiling again here +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + int D, + float* search_items, + int n, + long* res_I, + float* res_D, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_long_float_uint.cu b/cpp/src/nn/specializations/knn_long_float_uint.cu new file mode 100644 index 0000000000..a84a9e9456 --- /dev/null +++ b/cpp/src/nn/specializations/knn_long_float_uint.cu @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +// include specializations for pairwise_distance and fusedl2knn +// to avoid recompiling again here +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + unsigned int D, + float* search_items, + unsigned int n, + long* res_I, + float* res_D, + unsigned int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_uint_float_int.cu b/cpp/src/nn/specializations/knn_uint_float_int.cu new file mode 100644 index 0000000000..da8bf0eeec --- /dev/null +++ b/cpp/src/nn/specializations/knn_uint_float_int.cu @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +// include specializations for pairwise_distance and fusedl2knn +// to avoid recompiling again here +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + int D, + float* search_items, + int n, + uint32_t* res_I, + float* res_D, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_uint_float_uint.cu b/cpp/src/nn/specializations/knn_uint_float_uint.cu new file mode 100644 index 0000000000..b2a482a868 --- /dev/null +++ b/cpp/src/nn/specializations/knn_uint_float_uint.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +// include specializations for pairwise_distance and fusedl2knn +// to avoid recompiling again here +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + unsigned int D, + float* search_items, + unsigned int n, + uint32_t* res_I, + float* res_D, + unsigned int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); + +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index 6814d47dcb..a4bf8f807a 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -21,6 +21,7 @@ #include #include #if defined RAFT_NN_COMPILED +#include #include #endif @@ -188,12 +189,12 @@ const std::vector inputs = { 2, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}}}; -typedef KNNTest KNNTestFint64_t; -TEST_P(KNNTestFint64_t, BruteForce) { this->testBruteForce(); } +typedef KNNTest KNNTestFint32_t; +TEST_P(KNNTestFint32_t, BruteForce) { this->testBruteForce(); } typedef KNNTest KNNTestFuint32_t; TEST_P(KNNTestFuint32_t, BruteForce) { this->testBruteForce(); } -INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint64_t, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint32_t, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFuint32_t, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors::brute_force From 5442d311acbae9b22edc8894f759b4b7690704c0 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 31 Jan 2023 15:09:55 -0800 Subject: [PATCH 06/62] Suggestions from code review --- build.sh | 1 - cpp/CMakeLists.txt | 1 - cpp/include/raft/neighbors/brute_force.cuh | 32 +++++++++---------- .../detail/knn_brute_force.cuh} | 19 ++++------- .../raft/spatial/knn/detail/ball_cover.cuh | 29 +++++++++-------- cpp/include/raft/spatial/knn/knn.cuh | 32 +++++++++---------- cpp/test/neighbors/ball_cover.cu | 4 +-- docs/source/build.md | 6 ++-- 8 files changed, 58 insertions(+), 66 deletions(-) rename cpp/include/raft/{spatial/knn/detail/knn_brute_force_faiss.cuh => neighbors/detail/knn_brute_force.cuh} (98%) diff --git a/build.sh b/build.sh index 2496eea5c2..849c6d9500 100755 --- a/build.sh +++ b/build.sh @@ -394,7 +394,6 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has -DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ -DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \ - -DRAFT_ENABLE_NN_DEPENDENCIES=${ENABLE_NN_DEPENDENCIES} \ -DRAFT_NVTX=${NVTX} \ -DDISABLE_DEPRECATION_WARNINGS=${DISABLE_DEPRECATION_WARNINGS} \ -DBUILD_TESTS=${BUILD_TESTS} \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a16f1d25a0..c26e494f4c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -65,7 +65,6 @@ option( option(RAFT_COMPILE_DIST_LIBRARY "Enable building raft distant shared library instantiations" ${RAFT_COMPILE_LIBRARIES} ) -option(RAFT_ENABLE_NN_DEPENDENCIES "Search for raft::nn dependencies" ${RAFT_COMPILE_LIBRARIES}) option(RAFT_ENABLE_thrust_DEPENDENCY "Enable Thrust dependency" ON) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index ac9d14ce17..f359c64677 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -18,8 +18,8 @@ #include #include +#include #include -#include #include namespace raft::neighbors::brute_force { @@ -181,21 +181,21 @@ void knn(raft::device_resources const& handle, std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; - raft::spatial::knn::detail::brute_force_knn_impl(handle, - inputs, - sizes, - static_cast(index[0].extent(1)), - // TODO: This is unfortunate. Need to fix. - const_cast(search.data_handle()), - static_cast(search.extent(0)), - indices.data_handle(), - distances.data_handle(), - k, - rowMajorIndex, - rowMajorQuery, - trans_arg, - metric, - metric_arg.value_or(2.0f)); + raft::neighbors::detail::brute_force_knn_impl(handle, + inputs, + sizes, + static_cast(index[0].extent(1)), + // TODO: This is unfortunate. Need to fix. + const_cast(search.data_handle()), + static_cast(search.extent(0)), + indices.data_handle(), + distances.data_handle(), + k, + rowMajorIndex, + rowMajorQuery, + trans_arg, + metric, + metric_arg.value_or(2.0f)); } /** diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh similarity index 98% rename from cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh rename to cpp/include/raft/neighbors/detail/knn_brute_force.cuh index be08c1e3aa..44f92365d1 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -29,18 +29,16 @@ #include #include #include +#include +#include +#include #include #include #include -#include "fused_l2_knn.cuh" -#include "haversine_distance.cuh" -#include "processing.cuh" - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { +namespace raft::neighbors::detail { +using namespace raft::spatial::knn::detail; +using namespace raft::spatial::knn; template @@ -33,6 +32,7 @@ #include #include +#include #include #include @@ -182,19 +182,20 @@ void k_closest_landmarks(raft::device_resources const& handle, std::vector input = {const_cast(index.get_R().data_handle())}; std::vector sizes = {index.n_landmarks}; - brute_force_knn_impl(handle, - input, - sizes, - index.n, - const_cast(query_pts), - n_query_pts, - R_knn_inds, - R_knn_dists, - k, - true, - true, - nullptr, - index.get_metric()); + raft::neighbors::detail::brute_force_knn_impl( + handle, + input, + sizes, + index.n, + const_cast(query_pts), + n_query_pts, + R_knn_inds, + R_knn_dists, + k, + true, + true, + nullptr, + index.get_metric()); } /** diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index ca2c248392..727fb313ce 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -16,13 +16,13 @@ #pragma once -#include "detail/knn_brute_force_faiss.cuh" #include "detail/selection_faiss.cuh" #include #include #include #include +#include namespace raft::spatial::knn { @@ -61,7 +61,7 @@ inline void knn_merge_parts(value_t* in_keys, cudaStream_t stream, idx_t* translations) { - detail::knn_merge_parts( + raft::neighbors::detail::knn_merge_parts( in_keys, in_values, out_keys, out_values, n_samples, n_parts, k, stream, translations); } @@ -212,20 +212,20 @@ void brute_force_knn(raft::device_resources const& handle, { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); - detail::brute_force_knn_impl(handle, - input, - sizes, - D, - search_items, - n, - res_I, - res_D, - k, - rowMajorIndex, - rowMajorQuery, - translations, - metric, - metric_arg); + raft::neighbors::detail::brute_force_knn_impl(handle, + input, + sizes, + D, + search_items, + n, + res_I, + res_D, + k, + rowMajorIndex, + rowMajorQuery, + translations, + metric, + metric_arg); } } // namespace raft::spatial::knn diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index a97df7df75..3f96defc0c 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -19,8 +19,8 @@ #include #include #include +#include #include -#include #include #if defined RAFT_NN_COMPILED #include @@ -361,4 +361,4 @@ INSTANTIATE_TEST_CASE_P(BallCoverKNNQueryTest, TEST_P(BallCoverAllKNNTestF, Fit) { basicTest(); } TEST_P(BallCoverKNNQueryTestF, Fit) { basicTest(); } -} // namespace raft::neighbors::ball_cover \ No newline at end of file +} // namespace raft::neighbors::ball_cover diff --git a/docs/source/build.md b/docs/source/build.md index a6e16d3824..08cb6be961 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -59,7 +59,7 @@ The recommended way to build and install RAFT is to use the `build.sh` script in ### Header-only C++ -`build.sh` uses [rapids-cmake](https://github.com/rapidsai/rapids-cmake), which will automatically download any dependencies which are not already installed. It's important to note that while all the headers will be installed and available, some parts of the RAFT API depend on libraries like `cuCollections`, which will need to be explicitly enabled in `build.sh`. +`build.sh` uses [rapids-cmake](https://github.com/rapidsai/rapids-cmake), which will automatically download any dependencies which are not already installed. The following example will download the needed dependencies and install the RAFT headers into `$INSTALL_PREFIX/include/raft`. ```bash @@ -152,7 +152,7 @@ Use `CMAKE_INSTALL_PREFIX` to install RAFT into a specific location. The snippet cd cpp mkdir build cd build -cmake -D BUILD_TESTS=ON -DRAFT_COMPILE_LIBRARIES=ON -DRAFT_ENABLE_NN_DEPENDENCIES=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX ../ +cmake -D BUILD_TESTS=ON -DRAFT_COMPILE_LIBRARIES=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX ../ make -j install ``` @@ -166,7 +166,6 @@ RAFT's cmake has the following configurable flags available:. | RAFT_COMPILE_LIBRARIES | ON, OFF | OFF | Compiles all `libraft` shared libraries (these are required for Googletests) | | RAFT_COMPILE_NN_LIBRARY | ON, OFF | OFF | Compiles the `libraft-nn` shared library | | RAFT_COMPILE_DIST_LIBRARY | ON, OFF | OFF | Compiles the `libraft-distance` shared library | -| RAFT_ENABLE_NN_DEPENDENCIES | ON, OFF | OFF | Searches for dependencies of nearest neighbors API, and compiles them if not found. Needed for `raft::spatial::knn` | | RAFT_STATIC_LINK_LIBRARIES | ON, OFF | ON | Build static link libraries instead of shared libraries | | DETECT_CONDA_ENV | ON, OFF | ON | Enable detection of conda environment for dependencies | | NVTX | ON, OFF | OFF | Enable NVTX Markers | @@ -345,7 +344,6 @@ function(find_and_configure_raft) OPTIONS "BUILD_TESTS OFF" "BUILD_BENCH OFF" - "RAFT_ENABLE_NN_DEPENDENCIES ${PKG_ENABLE_NN_DEPENDENCIES}" "RAFT_COMPILE_LIBRARIES ${PKG_COMPILE_LIBRARIES}" "RAFT_ENABLE_thrust_DEPENDENCY ${PKG_ENABLE_thrust_DEPENDENCY}" ) From 0f5d206e357e9efee02d82a6fc5be5306337f30c Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 31 Jan 2023 15:38:52 -0800 Subject: [PATCH 07/62] fixes --- cpp/include/raft/neighbors/brute_force.cuh | 18 +++++++-------- cpp/test/neighbors/ball_cover.cu | 26 +++++++++++----------- cpp/test/neighbors/tiled_knn.cu | 24 ++++++++++---------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index f359c64677..76e05f3234 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -96,15 +96,15 @@ inline void knn_merge_parts( "Number of columns in output indices and distances matrices must be equal to k"); auto n_parts = in_keys.extent(0) / n_samples; - spatial::knn::detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - out_values.data_handle(), - n_samples, - n_parts, - in_keys.extent(1), - handle.get_stream(), - translations.value_or(nullptr)); + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + n_samples, + n_parts, + in_keys.extent(1), + handle.get_stream(), + translations.value_or(nullptr)); } /** diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 3f96defc0c..0906bb230a 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -117,19 +117,19 @@ void compute_bfknn(const raft::device_resources& handle, std::vector* translations = nullptr; - raft::spatial::knn::detail::brute_force_knn_impl(handle, - input_vec, - sizes_vec, - d, - const_cast(X2), - n_query_rows, - inds, - dists, - k, - true, - true, - translations, - metric); + raft::neighbors::detail::brute_force_knn_impl(handle, + input_vec, + sizes_vec, + d, + const_cast(X2), + n_query_rows, + inds, + dists, + k, + true, + true, + translations, + metric); } struct ToRadians { diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 3dd9a5394b..e1c9e4ceac 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -97,18 +97,18 @@ class TiledKNNTest : public ::testing::TestWithParam { k_, stream_); - knn::detail::tiled_brute_force_knn(handle_, - search_queries.data(), - database.data(), - num_queries, - num_db_vecs, - dim, - k_, - raft_distances_.data(), - raft_indices_.data(), - metric, - params_.row_tiles, - params_.col_tiles); + neighbors::detail::tiled_brute_force_knn(handle_, + search_queries.data(), + database.data(), + num_queries, + num_db_vecs, + dim, + k_, + raft_distances_.data(), + raft_indices_.data(), + metric, + params_.row_tiles, + params_.col_tiles); // verify. ASSERT_TRUE(knn::devArrMatchKnnPair(ref_indices_.data(), From e870eb332cdd4120fc96679514d8a5218a5a780f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 6 Feb 2023 16:47:51 -0800 Subject: [PATCH 08/62] use pairwise_distance specialization to speed up compile times --- .../raft/neighbors/detail/knn_brute_force.cuh | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 44f92365d1..15564052c1 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -200,16 +200,19 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // calculate the top-k elements for the current tile, by calculating the // full pairwise distance for the tile - and then selecting the top-k from that - distance::pairwise_distance(handle, - search + i * d, - index + j * d, - temp_distances.data(), - current_query_size, - current_centroid_size, - d, - metric, - true, - metric_arg); + // note: we're using a int32 IndexType here on purpose in order to + // use the pairwise_distance specializations. Since the tile size will ensure + // that the total memory is < 1GB per tile, this will not cause any issues + distance::pairwise_distance(handle, + search + i * d, + index + j * d, + temp_distances.data(), + current_query_size, + current_centroid_size, + d, + metric, + true, + metric_arg); detail::select_k(temp_distances.data(), nullptr, From 8445aed6bdc95a3cf102d38bc8f69a301838c891 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Feb 2023 10:26:45 -0800 Subject: [PATCH 09/62] Use distance specializations --- cpp/src/nn/specializations/brute_force_knn_long_float_int.cu | 1 + cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu | 1 + cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu | 1 + .../nn/specializations/brute_force_knn_uint32_t_float_uint.cu | 1 + 4 files changed, 4 insertions(+) diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu index b08bcfbc79..9926ccef87 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu @@ -15,6 +15,7 @@ */ #include +#include #include namespace raft { diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu index 78cb92bb38..8efe7c5b7b 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu @@ -15,6 +15,7 @@ */ #include +#include #include namespace raft { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu index 0082a30796..add2cb0add 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu @@ -15,6 +15,7 @@ */ #include +#include #include namespace raft { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu index b2a1af2cf0..89891a0920 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu @@ -15,6 +15,7 @@ */ #include +#include #include namespace raft { From 5905b2dc698f686c669d3d6540cf6574899e75f7 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 13 Feb 2023 18:02:33 -0800 Subject: [PATCH 10/62] use specializations in RBC code --- .../raft/spatial/knn/detail/ball_cover.cuh | 22 +++++--------- cpp/test/neighbors/ball_cover.cu | 30 +++++++------------ 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 885af90672..9d89967dd2 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -32,7 +32,7 @@ #include #include -#include +#include #include #include @@ -178,23 +178,15 @@ void k_closest_landmarks(raft::device_resources const& handle, value_idx* R_knn_inds, value_t* R_knn_dists) { - // TODO: Add const to the brute-force knn inputs - std::vector input = {const_cast(index.get_R().data_handle())}; - std::vector sizes = {index.n_landmarks}; + std::vector> inputs = {index.get_R()}; - raft::neighbors::detail::brute_force_knn_impl( + raft::neighbors::brute_force::knn( handle, - input, - sizes, - index.n, - const_cast(query_pts), - n_query_pts, - R_knn_inds, - R_knn_dists, + inputs, + make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)), + make_device_matrix_view(R_knn_inds, n_query_pts, k), + make_device_matrix_view(R_knn_dists, n_query_pts, k), k, - true, - true, - nullptr, index.get_metric()); } diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 0906bb230a..d6b7dea5de 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #if defined RAFT_NN_COMPILED @@ -112,24 +112,16 @@ void compute_bfknn(const raft::device_resources& handle, value_t* dists, int64_t* inds) { - std::vector input_vec = {const_cast(X1)}; - std::vector sizes_vec = {n_rows}; - - std::vector* translations = nullptr; - - raft::neighbors::detail::brute_force_knn_impl(handle, - input_vec, - sizes_vec, - d, - const_cast(X2), - n_query_rows, - inds, - dists, - k, - true, - true, - translations, - metric); + std::vector> input_vec = { + make_device_matrix_view(X1, n_rows, d)}; + + raft::neighbors::brute_force::knn(handle, + input_vec, + make_device_matrix_view(X2, n_query_rows, d), + make_device_matrix_view(inds, n_query_rows, k), + make_device_matrix_view(dists, n_query_rows, k), + k, + metric); } struct ToRadians { From 8eaba848d22d96577b34f3b0d07ee9328633546b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Feb 2023 14:34:25 -0800 Subject: [PATCH 11/62] use pw specializations in rbc --- cpp/CMakeLists.txt | 2 +- .../ball_cover_all_knn_query.cu | 4 +- .../specializations/ball_cover_build_index.cu | 4 +- .../specializations/ball_cover_knn_query.cu | 4 +- .../nn/specializations/knn_long_float_int.cu | 45 ------------------ .../nn/specializations/knn_long_float_uint.cu | 44 ------------------ .../nn/specializations/knn_uint_float_int.cu | 44 ------------------ .../nn/specializations/knn_uint_float_uint.cu | 46 ------------------- 8 files changed, 4 insertions(+), 189 deletions(-) delete mode 100644 cpp/src/nn/specializations/knn_long_float_int.cu delete mode 100644 cpp/src/nn/specializations/knn_long_float_uint.cu delete mode 100644 cpp/src/nn/specializations/knn_uint_float_int.cu delete mode 100644 cpp/src/nn/specializations/knn_uint_float_uint.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 41f196dbad..f7ae76b5e4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -455,7 +455,7 @@ if(RAFT_COMPILE_NN_LIBRARY) target_link_libraries( raft_nn_lib - PUBLIC raft::raft + PUBLIC raft::raft raft::raft_distance_lib PRIVATE nvidia::cutlass::cutlass ) target_compile_options( diff --git a/cpp/src/nn/specializations/ball_cover_all_knn_query.cu b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu index da5cd8de4f..184e18e2ba 100644 --- a/cpp/src/nn/specializations/ball_cover_all_knn_query.cu +++ b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu @@ -18,10 +18,8 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -#ifdef RAFT_DISTANCE_COMPILED +static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); #include -#endif - #include #include #include diff --git a/cpp/src/nn/specializations/ball_cover_build_index.cu b/cpp/src/nn/specializations/ball_cover_build_index.cu index 70fcbec356..05b3beec73 100644 --- a/cpp/src/nn/specializations/ball_cover_build_index.cu +++ b/cpp/src/nn/specializations/ball_cover_build_index.cu @@ -18,10 +18,8 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -#ifdef RAFT_DISTANCE_COMPILED +static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); #include -#endif - #include #include #include diff --git a/cpp/src/nn/specializations/ball_cover_knn_query.cu b/cpp/src/nn/specializations/ball_cover_knn_query.cu index d5ca1cbc1c..a11f6ba2d2 100644 --- a/cpp/src/nn/specializations/ball_cover_knn_query.cu +++ b/cpp/src/nn/specializations/ball_cover_knn_query.cu @@ -18,10 +18,8 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -#ifdef RAFT_DISTANCE_COMPILED +static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); #include -#endif - #include #include #include diff --git a/cpp/src/nn/specializations/knn_long_float_int.cu b/cpp/src/nn/specializations/knn_long_float_int.cu deleted file mode 100644 index 1360430132..0000000000 --- a/cpp/src/nn/specializations/knn_long_float_int.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_long_float_uint.cu b/cpp/src/nn/specializations/knn_long_float_uint.cu deleted file mode 100644 index a84a9e9456..0000000000 --- a/cpp/src/nn/specializations/knn_long_float_uint.cu +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_uint_float_int.cu b/cpp/src/nn/specializations/knn_uint_float_int.cu deleted file mode 100644 index da8bf0eeec..0000000000 --- a/cpp/src/nn/specializations/knn_uint_float_int.cu +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_uint_float_uint.cu b/cpp/src/nn/specializations/knn_uint_float_uint.cu deleted file mode 100644 index b2a482a868..0000000000 --- a/cpp/src/nn/specializations/knn_uint_float_uint.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft From fe728e9052ff9322854cc1c645f16e3b288827b8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Feb 2023 15:38:49 -0800 Subject: [PATCH 12/62] use matrix::select_k in bfknn call --- .../raft/neighbors/detail/knn_brute_force.cuh | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 15564052c1..c4159ef29a 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -27,12 +27,11 @@ #include #include #include +#include #include -#include #include #include #include -#include #include #include @@ -214,15 +213,16 @@ void tiled_brute_force_knn(const raft::device_resources& handle, true, metric_arg); - detail::select_k(temp_distances.data(), - nullptr, - current_query_size, - current_centroid_size, - distances + i * k, - indices + i * k, - true, - current_k, - stream); + matrix::select_k( + handle, + raft::make_device_matrix_view( + temp_distances.data(), current_query_size, current_centroid_size), + std::nullopt, + raft::make_device_matrix_view( + distances + i * k, current_query_size, k), + raft::make_device_matrix_view( + indices + i * k, current_query_size, k), + true); // if we're tiling over columns, we need to do a couple things to fix up // the output of select_k @@ -254,15 +254,17 @@ void tiled_brute_force_knn(const raft::device_resources& handle, if (tile_cols != n) { // select the actual top-k items here from the temporary output - detail::select_k(temp_out_distances.data(), - temp_out_indices.data(), - current_query_size, - temp_out_cols, - distances + i * k, - indices + i * k, - true, - k, - stream); + matrix::select_k( + handle, + raft::make_device_matrix_view( + temp_out_distances.data(), current_query_size, temp_out_cols), + raft::make_device_matrix_view( + temp_out_indices.data(), current_query_size, temp_out_cols), + raft::make_device_matrix_view( + distances + i * k, current_query_size, k), + raft::make_device_matrix_view( + indices + i * k, current_query_size, k), + true); } } } From 96e05e1a95c6827f21cbd05c08cb94ba10f2e81e Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Feb 2023 17:18:07 -0800 Subject: [PATCH 13/62] expose bf detail specialization --- .../raft/neighbors/specializations/knn.cuh | 114 ++++++++---------- 1 file changed, 48 insertions(+), 66 deletions(-) diff --git a/cpp/include/raft/neighbors/specializations/knn.cuh b/cpp/include/raft/neighbors/specializations/knn.cuh index b1cfa278d6..e0b64415fe 100644 --- a/cpp/include/raft/neighbors/specializations/knn.cuh +++ b/cpp/include/raft/neighbors/specializations/knn.cuh @@ -16,73 +16,55 @@ #pragma once +#include #include -namespace raft { -namespace spatial { -namespace knn { -extern template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); +namespace raft::spatial::knn { +#define RAFT_INST(IdxT, T, IntT) \ + extern template void brute_force_knn(raft::device_resources const& handle, \ + std::vector& input, \ + std::vector& sizes, \ + IntT D, \ + T* search_items, \ + IntT n, \ + IdxT* res_I, \ + T* res_D, \ + IntT k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + std::vector* translations, \ + distance::DistanceType metric, \ + float metric_arg); -extern template void brute_force_knn( - raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); +RAFT_INST(long, float, int); +RAFT_INST(long, float, unsigned int); +RAFT_INST(uint32_t, float, int); +RAFT_INST(uint32_t, float, unsigned int); +#undef RAFT_INST +}; // namespace raft::spatial::knn -extern template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -extern template void brute_force_knn( - raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft +// also define the detail api, which is used by raft::neighbors::brute_force +// (not doing the public api, since has extra template params on index_layout, matrix_index, +// search_layout etc - and isn't clear what the defaults here should be) +namespace raft::neighbors::detail { +#define RAFT_INST(IdxT, T, IntT) \ + extern template void brute_force_knn_impl(raft::device_resources const& handle, \ + std::vector& input, \ + std::vector& sizes, \ + IntT D, \ + T* search_items, \ + IntT n, \ + IdxT* res_I, \ + T* res_D, \ + IntT k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + std::vector* translations, \ + raft::distance::DistanceType metric, \ + float metricArg); +RAFT_INST(long, float, int); +RAFT_INST(long, float, unsigned int); +RAFT_INST(uint32_t, float, int); +RAFT_INST(uint32_t, float, unsigned int); +#undef RAFT_INST +} // namespace raft::neighbors::detail From 59060b29c352bbabb96e4bef515751fb98aca578 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Feb 2023 20:16:11 -0800 Subject: [PATCH 14/62] Revert "use pw specializations in rbc" This reverts commit 8eaba848d22d96577b34f3b0d07ee9328633546b. Change didn't seem to build in CI --- cpp/CMakeLists.txt | 2 +- cpp/src/nn/specializations/ball_cover_all_knn_query.cu | 4 +++- cpp/src/nn/specializations/ball_cover_build_index.cu | 4 +++- cpp/src/nn/specializations/ball_cover_knn_query.cu | 4 +++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index f7ae76b5e4..41f196dbad 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -455,7 +455,7 @@ if(RAFT_COMPILE_NN_LIBRARY) target_link_libraries( raft_nn_lib - PUBLIC raft::raft raft::raft_distance_lib + PUBLIC raft::raft PRIVATE nvidia::cutlass::cutlass ) target_compile_options( diff --git a/cpp/src/nn/specializations/ball_cover_all_knn_query.cu b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu index 184e18e2ba..da5cd8de4f 100644 --- a/cpp/src/nn/specializations/ball_cover_all_knn_query.cu +++ b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu @@ -18,8 +18,10 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); +#ifdef RAFT_DISTANCE_COMPILED #include +#endif + #include #include #include diff --git a/cpp/src/nn/specializations/ball_cover_build_index.cu b/cpp/src/nn/specializations/ball_cover_build_index.cu index 05b3beec73..70fcbec356 100644 --- a/cpp/src/nn/specializations/ball_cover_build_index.cu +++ b/cpp/src/nn/specializations/ball_cover_build_index.cu @@ -18,8 +18,10 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); +#ifdef RAFT_DISTANCE_COMPILED #include +#endif + #include #include #include diff --git a/cpp/src/nn/specializations/ball_cover_knn_query.cu b/cpp/src/nn/specializations/ball_cover_knn_query.cu index a11f6ba2d2..d5ca1cbc1c 100644 --- a/cpp/src/nn/specializations/ball_cover_knn_query.cu +++ b/cpp/src/nn/specializations/ball_cover_knn_query.cu @@ -18,8 +18,10 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); +#ifdef RAFT_DISTANCE_COMPILED #include +#endif + #include #include #include From c734bace91a7b3d07caffef7aaeec65b2fc39c45 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 15 Feb 2023 12:34:56 -0800 Subject: [PATCH 15/62] Add tests for other metrics Also remove metrics processors - since is handled inside PW distance --- .../raft/neighbors/detail/knn_brute_force.cuh | 19 +------------------ cpp/test/neighbors/tiled_knn.cu | 13 +++++++++++++ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index c4159ef29a..78bccdc6e0 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -378,18 +378,6 @@ void brute_force_knn_impl( id_ranges = translations; } - // perform preprocessing - std::unique_ptr> query_metric_processor = - create_processor(metric, n, D, k, rowMajorQuery, userStream); - query_metric_processor->preprocess(search_items); - - std::vector>> metric_processors(input.size()); - for (size_t i = 0; i < input.size(); i++) { - metric_processors[i] = - create_processor(metric, sizes[i], D, k, rowMajorQuery, userStream); - metric_processors[i]->preprocess(input[i]); - } - int device; RAFT_CUDA_TRY(cudaGetDevice(&device)); @@ -476,6 +464,7 @@ void brute_force_knn_impl( } // Perform necessary post-processing + // TODO: is this only really necessary for fusedL2Knn code? if (metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::LpUnexpanded) { @@ -492,12 +481,6 @@ void brute_force_knn_impl( userStream); } - query_metric_processor->revert(search_items); - query_metric_processor->postprocess(out_D); - for (size_t i = 0; i < input.size(); i++) { - metric_processors[i]->revert(input[i]); - } - if (translations == nullptr) delete id_ranges; }; diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index e1c9e4ceac..00c09eff9a 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -111,6 +111,7 @@ class TiledKNNTest : public ::testing::TestWithParam { params_.col_tiles); // verify. + std::cout << "testing out " << metric << std::endl; ASSERT_TRUE(knn::devArrMatchKnnPair(ref_indices_.data(), raft_indices_.data(), ref_distances_.data(), @@ -154,6 +155,18 @@ class TiledKNNTest : public ::testing::TestWithParam { const std::vector random_inputs = { {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Expanded}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Unexpanded}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtUnexpanded}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L1}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Linf}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::InnerProduct}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CorrelationExpanded}, + // JensenShannon produces incorrect results + // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon}, + // BrayCurtis isn't currently supported by pairwise_distance api + // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::BrayCurtis}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra}, {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded}, {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded}, {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded}, From c65e4bbe9d471030b84d23189e7995e250d4de8d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Feb 2023 12:43:49 -0800 Subject: [PATCH 16/62] Fix parameter order --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 78bccdc6e0..6089636a33 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -435,10 +435,10 @@ void brute_force_knn_impl( break; default: tiled_brute_force_knn(handle, - input[i], search_items, - sizes[i], + input[i], n, + sizes[i], D, k, out_d_ptr, From 3830e5362dd77b003efb6d1046056c688cae76d2 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Feb 2023 16:47:54 -0800 Subject: [PATCH 17/62] Fix Lp distance --- .../raft/neighbors/detail/knn_brute_force.cuh | 35 ++++----- cpp/test/neighbors/tiled_knn.cu | 74 ++++++++++++++----- 2 files changed, 73 insertions(+), 36 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 6089636a33..d96d54770c 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -411,7 +411,8 @@ void brute_force_knn_impl( (metric == raft::distance::DistanceType::L2Unexpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded)) { + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::LpUnexpanded)) { fusedL2Knn(D, out_i_ptr, out_d_ptr, @@ -424,6 +425,20 @@ void brute_force_knn_impl( rowMajorQuery, stream, metric); + + // Perform necessary post-processing + if (metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::LpUnexpanded) { + float p = 0.5; // standard l2 + if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; + raft::linalg::unaryOp( + res_D, + res_D, + n * k, + [p] __device__(float input) { return powf(fabsf(input), p); }, + userStream); + } } else { switch (metric) { case raft::distance::DistanceType::Haversine: @@ -463,24 +478,6 @@ void brute_force_knn_impl( knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); } - // Perform necessary post-processing - // TODO: is this only really necessary for fusedL2Knn code? - if (metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::LpUnexpanded) { - /** - * post-processing - */ - float p = 0.5; // standard l2 - if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; - raft::linalg::unaryOp( - res_D, - res_D, - n * k, - [p] __device__(float input) { return powf(fabsf(input), p); }, - userStream); - } - if (translations == nullptr) delete id_ranges; }; diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 00c09eff9a..ca535294e7 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -15,6 +15,7 @@ */ #include "../test_utils.cuh" +#include "./ann_utils.cuh" #include "./knn_utils.cuh" #include @@ -48,6 +49,13 @@ struct TiledKNNInputs { raft::distance::DistanceType metric_; }; +std::ostream& operator<<(std::ostream& os, const TiledKNNInputs& input) +{ + return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs + << " dim:" << input.dim << " k:" << input.k << " row_tiles:" << input.row_tiles + << " col_tiles:" << input.col_tiles << " metric:" << print_metric{input.metric_}; +} + template class TiledKNNTest : public ::testing::TestWithParam { public: @@ -77,6 +85,8 @@ class TiledKNNTest : public ::testing::TestWithParam { protected: void testBruteForce() { + float metric_arg = 3.0; + // calculate the naive knn, by calculating the full pairwise distances and doing a k-select rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); distance::pairwise_distance( @@ -84,7 +94,8 @@ class TiledKNNTest : public ::testing::TestWithParam { raft::make_device_matrix_view(search_queries.data(), num_queries, dim), raft::make_device_matrix_view(database.data(), num_db_vecs, dim), raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), - metric); + metric, + metric_arg); using namespace raft::spatial; knn::select_k(temp_distances.data(), @@ -97,21 +108,40 @@ class TiledKNNTest : public ::testing::TestWithParam { k_, stream_); - neighbors::detail::tiled_brute_force_knn(handle_, - search_queries.data(), - database.data(), - num_queries, - num_db_vecs, - dim, - k_, - raft_distances_.data(), - raft_indices_.data(), - metric, - params_.row_tiles, - params_.col_tiles); + if ((params_.row_tiles == 0) && (params_.col_tiles == 0)) { + std::vector input{database.data()}; + std::vector sizes{static_cast(num_db_vecs)}; + raft::spatial::knn::brute_force_knn(handle_, + input, + sizes, + dim, + const_cast(search_queries.data()), + num_queries, + raft_indices_.data(), + raft_distances_.data(), + k_, + true, + true, + nullptr, + metric, + metric_arg); + } else { + neighbors::detail::tiled_brute_force_knn(handle_, + search_queries.data(), + database.data(), + num_queries, + num_db_vecs, + dim, + k_, + raft_distances_.data(), + raft_indices_.data(), + metric, + metric_arg, + params_.row_tiles, + params_.col_tiles); + } // verify. - std::cout << "testing out " << metric << std::endl; ASSERT_TRUE(knn::devArrMatchKnnPair(ref_indices_.data(), raft_indices_.data(), ref_distances_.data(), @@ -162,6 +192,8 @@ const std::vector random_inputs = { {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Linf}, {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::InnerProduct}, {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CorrelationExpanded}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CosineExpanded}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::LpUnexpanded}, // JensenShannon produces incorrect results // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon}, // BrayCurtis isn't currently supported by pairwise_distance api @@ -169,12 +201,20 @@ const std::vector random_inputs = { {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra}, {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded}, {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded}, - {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded}, + // TODO: next two tests are failing (and definitely used to work) + // {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded}, // Test where the final column tile has < K items: - {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, + // {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, // Test where passing column_tiles < K {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded}, -}; + // Passing tile sizes of 0 means to use the public api (instead of the + // detail api). + {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded}, + {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::LpUnexpanded}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::L2SqrtExpanded}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::CosineExpanded}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::InnerProduct}}; typedef TiledKNNTest TiledKNNTestF; TEST_P(TiledKNNTestF, BruteForce) { this->testBruteForce(); } From 3f0b9a7c2e5e17f7d4468968c77017cb9dfeee21 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 12:29:01 -0800 Subject: [PATCH 18/62] Revert "use matrix::select_k in bfknn call" This reverts commit fe728e9052ff9322854cc1c645f16e3b288827b8. This is causing incorrect results, just use the faiss select_k call instead --- .../raft/neighbors/detail/knn_brute_force.cuh | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index d96d54770c..514d3d980f 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -27,11 +27,12 @@ #include #include #include -#include #include +#include #include #include #include +#include #include #include @@ -213,16 +214,15 @@ void tiled_brute_force_knn(const raft::device_resources& handle, true, metric_arg); - matrix::select_k( - handle, - raft::make_device_matrix_view( - temp_distances.data(), current_query_size, current_centroid_size), - std::nullopt, - raft::make_device_matrix_view( - distances + i * k, current_query_size, k), - raft::make_device_matrix_view( - indices + i * k, current_query_size, k), - true); + detail::select_k(temp_distances.data(), + nullptr, + current_query_size, + current_centroid_size, + distances + i * k, + indices + i * k, + true, + current_k, + stream); // if we're tiling over columns, we need to do a couple things to fix up // the output of select_k @@ -254,17 +254,15 @@ void tiled_brute_force_knn(const raft::device_resources& handle, if (tile_cols != n) { // select the actual top-k items here from the temporary output - matrix::select_k( - handle, - raft::make_device_matrix_view( - temp_out_distances.data(), current_query_size, temp_out_cols), - raft::make_device_matrix_view( - temp_out_indices.data(), current_query_size, temp_out_cols), - raft::make_device_matrix_view( - distances + i * k, current_query_size, k), - raft::make_device_matrix_view( - indices + i * k, current_query_size, k), - true); + detail::select_k(temp_out_distances.data(), + temp_out_indices.data(), + current_query_size, + temp_out_cols, + distances + i * k, + indices + i * k, + true, + k, + stream); } } } From 39005701a819964730798afa43d9f2628934f7c3 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 12:55:44 -0800 Subject: [PATCH 19/62] re-enable failing tests --- cpp/test/neighbors/tiled_knn.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index ca535294e7..f42a17872c 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -201,10 +201,9 @@ const std::vector random_inputs = { {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra}, {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded}, {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded}, - // TODO: next two tests are failing (and definitely used to work) - // {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded}, + {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded}, // Test where the final column tile has < K items: - // {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, + {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, // Test where passing column_tiles < K {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded}, // Passing tile sizes of 0 means to use the public api (instead of the From 8e719156afc5bca35d930df68b15475a81ea9665 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 13:51:14 -0800 Subject: [PATCH 20/62] fix cosine/innerproduct in bfknn --- cpp/include/raft/distance/distance_types.hpp | 22 ++++++++++++++++++- .../raft/neighbors/detail/knn_brute_force.cuh | 6 +++-- cpp/include/raft/neighbors/detail/refine.cuh | 2 +- .../spatial/knn/detail/ivf_flat_search.cuh | 22 +------------------ cpp/test/neighbors/tiled_knn.cu | 2 +- 5 files changed, 28 insertions(+), 26 deletions(-) diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index f5ed68af4a..4060147f1d 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,6 +66,26 @@ enum DistanceType : unsigned short { Precomputed = 100 }; +/** + * Whether minimal distance corresponds to similar elements (using the given metric). + */ +inline bool is_min_close(DistanceType metric) +{ + bool select_min; + switch (metric) { + case DistanceType::InnerProduct: + case DistanceType::CosineExpanded: + case DistanceType::CorrelationExpanded: + // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger + // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 + // {perform_k_selection}) + select_min = false; + break; + default: select_min = true; + } + return select_min; +} + namespace kernels { enum KernelType { LINEAR, POLYNOMIAL, RBF, TANH }; diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 514d3d980f..d1bef3170e 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -191,6 +191,8 @@ void tiled_brute_force_knn(const raft::device_resources& handle, rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); + bool select_min = raft::distance::is_min_close(metric); + for (size_t i = 0; i < m; i += tile_rows) { size_t current_query_size = std::min(tile_rows, m - i); @@ -220,7 +222,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, current_centroid_size, distances + i * k, indices + i * k, - true, + select_min, current_k, stream); @@ -260,7 +262,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, temp_out_cols, distances + i * k, indices + i * k, - true, + select_min, k, stream); } diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index b264643584..ce79e40433 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -128,7 +128,7 @@ void refine_device(raft::device_resources const& handle, refinement_index.metric(), 1, k, - raft::spatial::knn::ivf_flat::detail::is_min_close(metric), + raft::distance::is_min_close(metric), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 7f70d4b8a5..23d5699e8a 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1248,26 +1248,6 @@ void search_impl(raft::device_resources const& handle, } } -/** - * Whether minimal distance corresponds to similar elements (using the given metric). - */ -inline bool is_min_close(distance::DistanceType metric) -{ - bool select_min; - switch (metric) { - case raft::distance::DistanceType::InnerProduct: - case raft::distance::DistanceType::CosineExpanded: - case raft::distance::DistanceType::CorrelationExpanded: - // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger - // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 - // {perform_k_selection}) - select_min = false; - break; - default: select_min = true; - } - return select_min; -} - /** See raft::spatial::knn::ivf_flat::search docs */ template inline void search(raft::device_resources const& handle, @@ -1299,7 +1279,7 @@ inline void search(raft::device_resources const& handle, n_queries, k, n_probes, - is_min_close(index.metric()), + raft::distance::is_min_close(index.metric()), neighbors, distances, mr); diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index f42a17872c..9a61027ced 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -104,7 +104,7 @@ class TiledKNNTest : public ::testing::TestWithParam { num_db_vecs, ref_distances_.data(), ref_indices_.data(), - true, + raft::distance::is_min_close(metric), k_, stream_); From f806bf6711ec26ea8f3db918803b6ed4fb5e74c5 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 14:40:41 -0800 Subject: [PATCH 21/62] Test JensenShannon distance --- cpp/test/neighbors/tiled_knn.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 9a61027ced..1d0b8207ff 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -162,8 +162,11 @@ class TiledKNNTest : public ::testing::TestWithParam { unsigned long long int seed = 1234ULL; raft::random::RngState r(seed); - uniform(handle_, r, database.data(), num_db_vecs * dim, T(-1.0), T(1.0)); - uniform(handle_, r, search_queries.data(), num_queries * dim, T(-1.0), T(1.0)); + + // JensenShannon distance requires positive values + T min_val = metric == raft::distance::DistanceType::JensenShannon ? T(0.0) : T(-1.0); + uniform(handle_, r, database.data(), num_db_vecs * dim, min_val, T(1.0)); + uniform(handle_, r, search_queries.data(), num_queries * dim, min_val, T(1.0)); } private: @@ -194,8 +197,7 @@ const std::vector random_inputs = { {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CorrelationExpanded}, {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CosineExpanded}, {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::LpUnexpanded}, - // JensenShannon produces incorrect results - // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon}, // BrayCurtis isn't currently supported by pairwise_distance api // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::BrayCurtis}, {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra}, From 3315dca313cfcf9b0eee387221da62921351b168 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 20:01:10 -0800 Subject: [PATCH 22/62] support k up to 2048 in faiss select The faiss repo allows k values up to 2048, but we were limiting to 512 or 1024 instead. This seems to be because of compiler errors such as ``` ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb1ELi2048ELi8ELi64EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb0ELi2048ELi8ELi64EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb1ELi1024ELi8ELi128EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb0ELi1024ELi8ELi128EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ``` when compiling with larger k values. However, these only affect double precision with 64bit indices - and float32 works up to k=2048 (even with 64 bit indices). --- .../spatial/knn/detail/selection_faiss.cuh | 14 +++++++++++--- cpp/test/neighbors/selection.cu | 19 ++++++++++--------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index 2cdc0fae91..7e648d6cdc 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -27,10 +27,13 @@ namespace spatial { namespace knn { namespace detail { -template +template constexpr int kFaissMaxK() { - return (sizeof(key_t) + sizeof(payload_t) > 8) ? 512 : 1024; + if (sizeof(key_t) >= 8) { + return sizeof(payload_t) >= 8 ? 512: 1024; + } + return 2048; } template @@ -159,7 +162,12 @@ inline void select_k(const key_t* inK, select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else if (k <= 1024 && k <= max_k) - select_k_impl( + // note: have to use constexpr std::min here to avoid instantiating templates + // for parameters we don't support + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 2048 && k <= max_k) + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 61a6345e5e..f8404272cd 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -487,13 +487,14 @@ INSTANTIATE_TEST_CASE_P( * SelectionTest/ReferencedRandomFloatSizeT.LargeK/0 * Indicices do not match! ref[91628] = 131.359 != res[36504] = 158.438 * Actual: false (actual=36504 != expected=91628 @38999; - */ -// typedef SelectionTest::params_random> -// ReferencedRandomFloatSizeT; -// TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } -// INSTANTIATE_TEST_CASE_P(SelectionTest, -// ReferencedRandomFloatSizeT, -// testing::Combine(inputs_random_largek, -// testing::Values(knn::SelectKAlgo::RADIX_11_BITS))); - +*/ +typedef SelectionTest::params_random> + ReferencedRandomFloatSizeT; +TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } +INSTANTIATE_TEST_CASE_P( + SelectionTest, + ReferencedRandomFloatSizeT, + testing::Combine(inputs_random_largek, + testing::Values(knn::SelectKAlgo::FAISS), + testing::Values(std::make_shared()))); } // namespace raft::spatial::selection From a83bef3be18403df21639d68f1c28e4205bbc9bf Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 20:06:48 -0800 Subject: [PATCH 23/62] cmake format --- cpp/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d3de543ec1..a018b1c9f0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -61,7 +61,9 @@ set(RAFT_COMPILE_LIBRARIES_DEFAULT OFF) if(BUILD_TESTS OR BUILD_BENCH) set(RAFT_COMPILE_LIBRARIES_DEFAULT ON) endif() -option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ${RAFT_COMPILE_LIBRARIES_DEFAULT}) +option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" + ${RAFT_COMPILE_LIBRARIES_DEFAULT} +) option( RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" ${RAFT_COMPILE_LIBRARIES} From 3b811a14462265dc4a8cd6b11e82d7439d1e663d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 20:01:10 -0800 Subject: [PATCH 24/62] support k up to 2048 in faiss select The faiss repo allows k values up to 2048, but we were limiting to 512 or 1024 instead. This seems to be because of compiler errors such as ``` ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb1ELi2048ELi8ELi64EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb0ELi2048ELi8ELi64EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb1ELi1024ELi8ELi128EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ptxas error : Entry function '_ZN4raft7spatial3knn6detail15select_k_kernelIdmLb0ELi1024ELi8ELi128EEEvPKT_PKT0_mmPS4_PS7_S4_S7_i' uses too much shared data (0x10000 bytes, 0xc000 max) ``` when compiling with larger k values. However, these only affect double precision with 64bit indices - and float32 works up to k=2048 (even with 64 bit indices). --- .../spatial/knn/detail/selection_faiss.cuh | 14 +++++++++++--- cpp/test/neighbors/selection.cu | 19 ++++++++++--------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index fa1f556f22..b284e60316 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -26,10 +26,13 @@ namespace spatial { namespace knn { namespace detail { -template +template constexpr int kFaissMaxK() { - return (sizeof(key_t) + sizeof(payload_t) > 8) ? 512 : 1024; + if (sizeof(key_t) >= 8) { + return sizeof(payload_t) >= 8 ? 512: 1024; + } + return 2048; } template @@ -158,7 +161,12 @@ inline void select_k(const key_t* inK, select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else if (k <= 1024 && k <= max_k) - select_k_impl( + // note: have to use constexpr std::min here to avoid instantiating templates + // for parameters we don't support + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 2048 && k <= max_k) + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 61a6345e5e..f8404272cd 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -487,13 +487,14 @@ INSTANTIATE_TEST_CASE_P( * SelectionTest/ReferencedRandomFloatSizeT.LargeK/0 * Indicices do not match! ref[91628] = 131.359 != res[36504] = 158.438 * Actual: false (actual=36504 != expected=91628 @38999; - */ -// typedef SelectionTest::params_random> -// ReferencedRandomFloatSizeT; -// TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } -// INSTANTIATE_TEST_CASE_P(SelectionTest, -// ReferencedRandomFloatSizeT, -// testing::Combine(inputs_random_largek, -// testing::Values(knn::SelectKAlgo::RADIX_11_BITS))); - +*/ +typedef SelectionTest::params_random> + ReferencedRandomFloatSizeT; +TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } +INSTANTIATE_TEST_CASE_P( + SelectionTest, + ReferencedRandomFloatSizeT, + testing::Combine(inputs_random_largek, + testing::Values(knn::SelectKAlgo::FAISS), + testing::Values(std::make_shared()))); } // namespace raft::spatial::selection From c39dc65ed5d88232fc8ed80b0ead77d00361c63b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 21:18:58 -0800 Subject: [PATCH 25/62] style --- cpp/CMakeLists.txt | 4 +++- cpp/include/raft/spatial/knn/detail/selection_faiss.cuh | 4 +--- cpp/test/neighbors/selection.cu | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b020b8421f..98c21d192f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -61,7 +61,9 @@ set(RAFT_COMPILE_LIBRARIES_DEFAULT OFF) if(BUILD_TESTS OR BUILD_BENCH) set(RAFT_COMPILE_LIBRARIES_DEFAULT ON) endif() -option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ${RAFT_COMPILE_LIBRARIES_DEFAULT}) +option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" + ${RAFT_COMPILE_LIBRARIES_DEFAULT} +) option( RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" ${RAFT_COMPILE_LIBRARIES} diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index 7e648d6cdc..c036ca4d32 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -30,9 +30,7 @@ namespace detail { template constexpr int kFaissMaxK() { - if (sizeof(key_t) >= 8) { - return sizeof(payload_t) >= 8 ? 512: 1024; - } + if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } return 2048; } diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index f8404272cd..147aa6bac3 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -487,7 +487,7 @@ INSTANTIATE_TEST_CASE_P( * SelectionTest/ReferencedRandomFloatSizeT.LargeK/0 * Indicices do not match! ref[91628] = 131.359 != res[36504] = 158.438 * Actual: false (actual=36504 != expected=91628 @38999; -*/ + */ typedef SelectionTest::params_random> ReferencedRandomFloatSizeT; TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } From 275229480b5fdb8787df7f4c8e090966232dc954 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 17 Feb 2023 21:23:31 -0800 Subject: [PATCH 26/62] code review suggestions --- ci/check_style.sh | 2 +- cpp/include/raft/spatial/knn/detail/selection_faiss.cuh | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/check_style.sh b/ci/check_style.sh index 345f7fd866..0ee6e88e58 100755 --- a/ci/check_style.sh +++ b/ci/check_style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. set -euo pipefail diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index c036ca4d32..5264f5d12e 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -34,7 +34,7 @@ constexpr int kFaissMaxK() return 2048; } -template +template __global__ void select_k_kernel(const key_t* inK, const payload_t* inV, size_t n_rows, @@ -106,10 +106,10 @@ inline void select_k_impl(const key_t* inK, auto kInit = select_min ? upper_bound() : lower_bound(); auto vInit = -1; if (select_min) { - select_k_kernel + select_k_kernel <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); } else { - select_k_kernel + select_k_kernel <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); } RAFT_CUDA_TRY(cudaGetLastError()); From 1548a78c06fb6c637dbb122d133757cab6c94c74 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 21 Feb 2023 10:46:09 -0800 Subject: [PATCH 27/62] Remove ENABLE_NN_DEPENDENCIES option --- build.sh | 4 ---- docs/source/build.md | 5 +---- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/build.sh b/build.sh index a8818e757b..d58ce56ce0 100755 --- a/build.sh +++ b/build.sh @@ -71,7 +71,6 @@ BUILD_BENCH=OFF COMPILE_LIBRARIES=OFF COMPILE_NN_LIBRARY=OFF COMPILE_DIST_LIBRARY=OFF -ENABLE_NN_DEPENDENCIES=OFF INSTALL_TARGET=install TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" @@ -276,7 +275,6 @@ if hasArg --compile-libs || (( ${NUMARGS} == 0 )); then fi if hasArg --compile-nn || hasArg --compile-libs || (( ${NUMARGS} == 0 )); then - ENABLE_NN_DEPENDENCIES=ON COMPILE_NN_LIBRARY=ON CMAKE_TARGET="${CMAKE_TARGET};raft_nn_lib" fi @@ -297,7 +295,6 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"STATS_TEST"* ]]; then echo "-- Enabling nearest neighbors lib for gtests" - ENABLE_NN_DEPENDENCIES=ON COMPILE_NN_LIBRARY=ON fi @@ -321,7 +318,6 @@ if hasArg bench || (( ${NUMARGS} == 0 )); then if [[ $CMAKE_TARGET == *"CLUSTER_BENCH"* || \ $CMAKE_TARGET == *"NEIGHBORS_BENCH"* ]]; then echo "-- Enabling nearest neighbors lib for benchmarks" - ENABLE_NN_DEPENDENCIES=ON COMPILE_NN_LIBRARY=ON fi diff --git a/docs/source/build.md b/docs/source/build.md index 56d7b4f1bf..4b8bbad279 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -300,7 +300,7 @@ set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") function(find_and_configure_raft) set(oneValueArgs VERSION FORK PINNED_TAG - COMPILE_LIBRARIES ENABLE_NN_DEPENDENCIES CLONE_ON_PIN + COMPILE_LIBRARIES CLONE_ON_PIN USE_NN_LIBRARY USE_DISTANCE_LIBRARY ENABLE_thrust_DEPENDENCY) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" @@ -364,13 +364,10 @@ find_and_configure_raft(VERSION ${RAFT_VERSION}.00 COMPILE_LIBRARIES NO USE_NN_LIBRARY NO USE_DISTANCE_LIBRARY NO - ENABLE_NN_DEPENDENCIES NO ENABLE_thrust_DEPENDENCY YES ) ``` -If using the nearest neighbors APIs without the shared libraries, set `ENABLE_NN_DEPENDENCIES=ON` and keep `USE_NN_LIBRARY=OFF` - ## Uninstall Once built and installed, RAFT can be safely uninstalled using `build.sh` by specifying any or all of the installed components. Please note that since `pylibraft` depends on `libraft`, uninstalling `pylibraft` will also uninstall `libraft`: From 31c9cf26db8e0a710ead36e2c05d310ee4c00b87 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 21 Feb 2023 14:24:51 -0800 Subject: [PATCH 28/62] temporarily re-add faiss build targets This is needed by the ANN benchmarks code, as well as cuml --- build.sh | 13 ++- .../all_cuda-118_arch-x86_64.yaml | 2 + cpp/CMakeLists.txt | 21 ++++- cpp/cmake/thirdparty/get_faiss.cmake | 89 +++++++++++++++++++ docs/source/build.md | 23 +++-- 5 files changed, 139 insertions(+), 9 deletions(-) create mode 100644 cpp/cmake/thirdparty/get_faiss.cmake diff --git a/build.sh b/build.sh index d58ce56ce0..93f11d11a1 100755 --- a/build.sh +++ b/build.sh @@ -18,7 +18,7 @@ ARGS=$* # script, and that this script resides in the repo dir! REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean --uninstall -v -g -n --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --minimal-deps" +VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean --uninstall -v -g -n --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps" HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] [--limit-bench=] where is: clean - remove all existing build artifacts and configuration (start over) @@ -45,6 +45,7 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=\\\" - pass arbitrary list of CMake configuration options (escape all quotes in argument) @@ -68,9 +69,11 @@ BUILD_ALL_GPU_ARCH=0 BUILD_TESTS=OFF BUILD_TYPE=Release BUILD_BENCH=OFF +BUILD_STATIC_FAISS=OFF COMPILE_LIBRARIES=OFF COMPILE_NN_LIBRARY=OFF COMPILE_DIST_LIBRARY=OFF +ENABLE_NN_DEPENDENCIES=OFF INSTALL_TARGET=install TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" @@ -275,6 +278,7 @@ if hasArg --compile-libs || (( ${NUMARGS} == 0 )); then fi if hasArg --compile-nn || hasArg --compile-libs || (( ${NUMARGS} == 0 )); then + ENABLE_NN_DEPENDENCIES=ON COMPILE_NN_LIBRARY=ON CMAKE_TARGET="${CMAKE_TARGET};raft_nn_lib" fi @@ -295,6 +299,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"STATS_TEST"* ]]; then echo "-- Enabling nearest neighbors lib for gtests" + ENABLE_NN_DEPENDENCIES=ON COMPILE_NN_LIBRARY=ON fi @@ -318,6 +323,7 @@ if hasArg bench || (( ${NUMARGS} == 0 )); then if [[ $CMAKE_TARGET == *"CLUSTER_BENCH"* || \ $CMAKE_TARGET == *"NEIGHBORS_BENCH"* ]]; then echo "-- Enabling nearest neighbors lib for benchmarks" + ENABLE_NN_DEPENDENCIES=ON COMPILE_NN_LIBRARY=ON fi @@ -330,6 +336,9 @@ if hasArg bench || (( ${NUMARGS} == 0 )); then fi +if hasArg --buildfaiss; then + BUILD_STATIC_FAISS=ON +fi if hasArg --no-nvtx; then NVTX=OFF fi @@ -391,6 +400,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has -DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ -DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \ + -DRAFT_ENABLE_NN_DEPENDENCIES=${ENABLE_NN_DEPENDENCIES} \ -DRAFT_NVTX=${NVTX} \ -DDISABLE_DEPRECATION_WARNINGS=${DISABLE_DEPRECATION_WARNINGS} \ -DBUILD_TESTS=${BUILD_TESTS} \ @@ -398,6 +408,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has -DCMAKE_MESSAGE_LOG_LEVEL=${CMAKE_LOG_LEVEL} \ -DRAFT_COMPILE_NN_LIBRARY=${COMPILE_NN_LIBRARY} \ -DRAFT_COMPILE_DIST_LIBRARY=${COMPILE_DIST_LIBRARY} \ + -DRAFT_USE_FAISS_STATIC=${BUILD_STATIC_FAISS} \ -DRAFT_ENABLE_thrust_DEPENDENCY=${ENABLE_thrust_DEPENDENCY} \ ${CACHE_ARGS} \ ${EXTRA_CMAKE_ARGS} diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 4c2ea0ce3e..e498d10312 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -22,6 +22,7 @@ dependencies: - dask>=2023.1.1 - distributed>=2023.1.1 - doxygen>=1.8.20 +- faiss-proc=*=cuda - gcc_linux-64=9 - graphviz - ipython @@ -33,6 +34,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libfaiss>=1.7.1=cuda* - ninja - numpydoc - pydata-sphinx-theme diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 94e98904f6..7e5b10b227 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -71,6 +71,9 @@ option( option(RAFT_COMPILE_DIST_LIBRARY "Enable building raft distant shared library instantiations" ${RAFT_COMPILE_LIBRARIES} ) +option(RAFT_ENABLE_NN_DEPENDENCIES "Search for raft::nn dependencies like faiss" + ${RAFT_COMPILE_NN_LIBRARY} +) option(RAFT_ENABLE_thrust_DEPENDENCY "Enable Thrust dependency" ON) @@ -86,7 +89,16 @@ if(BUILD_TESTS AND NOT RAFT_ENABLE_thrust_DEPENDENCY) set(RAFT_ENABLE_thrust_DEPENDENCY ON) endif() +option(RAFT_EXCLUDE_FAISS_FROM_ALL "Exclude FAISS targets from RAFT's 'all' target" ON) + include(CMakeDependentOption) +cmake_dependent_option( + RAFT_USE_FAISS_STATIC + "Build and statically link the FAISS library for nearest neighbors search on GPU" + ON + RAFT_COMPILE_LIBRARIES + OFF +) message(VERBOSE "RAFT: Building optional components: ${raft_FIND_COMPONENTS}") message(VERBOSE "RAFT: Build RAFT unit-tests: ${BUILD_TESTS}") @@ -171,6 +183,7 @@ rapids_cpm_init() # thrust before rmm/cuco so we get the right version of thrust/cub include(cmake/thirdparty/get_thrust.cmake) include(cmake/thirdparty/get_rmm.cmake) +include(cmake/thirdparty/get_faiss.cmake) include(cmake/thirdparty/get_cutlass.cmake) if(RAFT_ENABLE_cuco_DEPENDENCY) @@ -474,7 +487,7 @@ if(RAFT_COMPILE_NN_LIBRARY) target_link_libraries( raft_nn_lib - PUBLIC raft::raft + PUBLIC faiss::faiss raft::raft PRIVATE nvidia::cutlass::cutlass ) target_compile_options( @@ -682,6 +695,12 @@ endif() if(nn IN_LIST raft_FIND_COMPONENTS) enable_language(CUDA) + + if(TARGET faiss AND (NOT TARGET faiss::faiss)) + add_library(faiss::faiss ALIAS faiss) + elseif(TARGET faiss::faiss AND (NOT TARGET faiss)) + add_library(faiss ALIAS faiss::faiss) + endif() endif() ]=] ) diff --git a/cpp/cmake/thirdparty/get_faiss.cmake b/cpp/cmake/thirdparty/get_faiss.cmake new file mode 100644 index 0000000000..e6f06a00a5 --- /dev/null +++ b/cpp/cmake/thirdparty/get_faiss.cmake @@ -0,0 +1,89 @@ +#============================================================================= +# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#============================================================================= + +function(find_and_configure_faiss) + set(oneValueArgs VERSION REPOSITORY PINNED_TAG BUILD_STATIC_LIBS EXCLUDE_FROM_ALL) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + if(RAFT_ENABLE_NN_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) + rapids_find_generate_module(faiss + HEADER_NAMES faiss/IndexFlat.h + LIBRARY_NAMES faiss + ) + + set(BUILD_SHARED_LIBS ON) + if (PKG_BUILD_STATIC_LIBS) + set(BUILD_SHARED_LIBS OFF) + set(CPM_DOWNLOAD_faiss ON) + endif() + + rapids_cpm_find(faiss ${PKG_VERSION} + GLOBAL_TARGETS faiss::faiss + CPM_ARGS + GIT_REPOSITORY ${PKG_REPOSITORY} + GIT_TAG ${PKG_PINNED_TAG} + EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} + OPTIONS + "FAISS_ENABLE_PYTHON OFF" + "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" + "FAISS_ENABLE_GPU ON" + "BUILD_TESTING OFF" + "CMAKE_MESSAGE_LOG_LEVEL VERBOSE" + "FAISS_USE_CUDA_TOOLKIT_STATIC ${CUDA_STATIC_RUNTIME}" + ) + + if(TARGET faiss AND NOT TARGET faiss::faiss) + add_library(faiss::faiss ALIAS faiss) + endif() + + if(faiss_ADDED) + rapids_export(BUILD faiss + EXPORT_SET faiss-targets + GLOBAL_TARGETS faiss + NAMESPACE faiss::) + endif() + endif() + + # We generate the faiss-config files when we built faiss locally, so always do `find_dependency` + rapids_export_package(BUILD OpenMP raft-nn-lib-exports) # faiss uses openMP but doesn't export a need for it + rapids_export_package(BUILD faiss raft-nn-lib-exports GLOBAL_TARGETS faiss::faiss faiss) + rapids_export_package(INSTALL faiss raft-nn-lib-exports GLOBAL_TARGETS faiss::faiss faiss) + + # Tell cmake where it can find the generated faiss-config.cmake we wrote. + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root(BUILD faiss [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-lib-exports) +endfunction() + +if(NOT RAFT_FAISS_GIT_TAG) + # TODO: Remove this once faiss supports FAISS_USE_CUDA_TOOLKIT_STATIC + # (https://github.com/facebookresearch/faiss/pull/2446) + set(RAFT_FAISS_GIT_TAG fea/statically-link-ctk-v1.7.0) + # set(RAFT_FAISS_GIT_TAG bde7c0027191f29c9dadafe4f6e68ca0ee31fb30) +endif() + +if(NOT RAFT_FAISS_GIT_REPOSITORY) + # TODO: Remove this once faiss supports FAISS_USE_CUDA_TOOLKIT_STATIC + # (https://github.com/facebookresearch/faiss/pull/2446) + set(RAFT_FAISS_GIT_REPOSITORY https://github.com/trxcllnt/faiss.git) + # set(RAFT_FAISS_GIT_REPOSITORY https://github.com/facebookresearch/faiss.git) +endif() + +find_and_configure_faiss(VERSION 1.7.0 + REPOSITORY ${RAFT_FAISS_GIT_REPOSITORY} + PINNED_TAG ${RAFT_FAISS_GIT_TAG} + BUILD_STATIC_LIBS ${RAFT_USE_FAISS_STATIC} + EXCLUDE_FROM_ALL ${RAFT_EXCLUDE_FAISS_FROM_ALL}) diff --git a/docs/source/build.md b/docs/source/build.md index 4b8bbad279..70b07f4e81 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -47,6 +47,7 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth - [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API. - [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 - Used by cuCollections - [CUTLASS](https://github.com/NVIDIA/cutlass) v2.9.1 - Used in `raft::distance` API. +- [FAISS](https://github.com/facebookresearch/faiss) v1.7.0 - Used in `raft::neighbors` API. - [NCCL](https://github.com/NVIDIA/nccl) - Used in `raft::comms` API and needed to build `raft-dask`. - [UCX](https://github.com/openucx/ucx) - Used in `raft::comms` API and needed to build `raft-dask`. - [Googletest](https://github.com/google/googletest) - Needed to build tests @@ -59,14 +60,14 @@ The recommended way to build and install RAFT is to use the `build.sh` script in ### Header-only C++ -`build.sh` uses [rapids-cmake](https://github.com/rapidsai/rapids-cmake), which will automatically download any dependencies which are not already installed. +`build.sh` uses [rapids-cmake](https://github.com/rapidsai/rapids-cmake), which will automatically download any dependencies which are not already installed. It's important to note that while all the headers will be installed and available, some parts of the RAFT API depend on libraries like `FAISS`, which will need to be explicitly enabled in `build.sh`. The following example will download the needed dependencies and install the RAFT headers into `$INSTALL_PREFIX/include/raft`. ```bash ./build.sh libraft ``` -The `-n` flag can be passed to just have the build download the needed dependencies. Since RAFT is primarily used at build-time, the dependencies will never be installed by the RAFT build. +The `-n` flag can be passed to just have the build download the needed dependencies. Since RAFT is primarily used at build-time, the dependencies will never be installed by the RAFT build, with the exception of building FAISS statically into the shared libraries. ```bash ./build.sh libraft -n ``` @@ -152,7 +153,7 @@ Use `CMAKE_INSTALL_PREFIX` to install RAFT into a specific location. The snippet cd cpp mkdir build cd build -cmake -D BUILD_TESTS=ON -DRAFT_COMPILE_LIBRARIES=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX ../ +cmake -D BUILD_TESTS=ON -DRAFT_COMPILE_LIBRARIES=ON -DRAFT_ENABLE_NN_DEPENDENCIES=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX ../ make -j install ``` @@ -166,13 +167,15 @@ RAFT's cmake has the following configurable flags available:. | RAFT_COMPILE_LIBRARIES | ON, OFF | ON if either BUILD_TESTS or BUILD_BENCH is ON; otherwise OFF | Compiles all `libraft` shared libraries (these are required for Googletests) | | RAFT_COMPILE_NN_LIBRARY | ON, OFF | OFF | Compiles the `libraft-nn` shared library | | RAFT_COMPILE_DIST_LIBRARY | ON, OFF | OFF | Compiles the `libraft-distance` shared library | +| RAFT_ENABLE_NN_DEPENDENCIES | ON, OFF | OFF | Searches for dependencies of nearest neighbors API, such as FAISS, and compiles them if not found. Needed for `raft::spatial::knn` | +| RAFT_USE_FAISS_STATIC | ON, OFF | OFF | Statically link FAISS into `libraft-nn` | | DETECT_CONDA_ENV | ON, OFF | ON | Enable detection of conda environment for dependencies | | RAFT_NVTX | ON, OFF | OFF | Enable NVTX Markers | | CUDA_ENABLE_KERNELINFO | ON, OFF | OFF | Enables `kernelinfo` in nvcc. This is useful for `compute-sanitizer` | | CUDA_ENABLE_LINEINFO | ON, OFF | OFF | Enable the -lineinfo option for nvcc | | CUDA_STATIC_RUNTIME | ON, OFF | OFF | Statically link the CUDA runtime | -Currently, shared libraries are provided for the `libraft-nn` and `libraft-distance` components. +Currently, shared libraries are provided for the `libraft-nn` and `libraft-distance` components. The `libraft-nn` component depends upon [FAISS](https://github.com/facebookresearch/faiss) and the `RAFT_ENABLE_NN_DEPENDENCIES` option will build it from source if it is not already installed. ### Python @@ -274,7 +277,7 @@ If RAFT has already been installed, such as by using the `build.sh` script, use ### Using C++ pre-compiled shared libraries -Use `find_package(raft COMPONENTS nn distance)` to enable the shared libraries and transitively pass dependencies through separate targets for each component. In this example, the `raft::distance` and `raft::nn` targets will be available for configuring linking paths in addition to `raft::raft`. These targets will also pass through any transitive dependencies. +Use `find_package(raft COMPONENTS nn distance)` to enable the shared libraries and transitively pass dependencies through separate targets for each component. In this example, the `raft::distance` and `raft::nn` targets will be available for configuring linking paths in addition to `raft::raft`. These targets will also pass through any transitive dependencies (such as FAISS for the `nn` package). The pre-compiled libraries contain template specializations for commonly used types, such as single- and double-precision floating-point. In order to use the symbols in the pre-compiled libraries, the compiler needs to be told not to instantiate templates that are already contained in the shared libraries. By convention, these header files are named `specializations.cuh` and located in the base directory for the packages that contain specializations. @@ -299,8 +302,8 @@ set(RAFT_FORK "rapidsai") set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG - COMPILE_LIBRARIES CLONE_ON_PIN + set(oneValueArgs VERSION FORK PINNED_TAG USE_FAISS_STATIC + COMPILE_LIBRARIES ENABLE_NN_DEPENDENCIES CLONE_ON_PIN USE_NN_LIBRARY USE_DISTANCE_LIBRARY ENABLE_thrust_DEPENDENCY) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" @@ -343,6 +346,8 @@ function(find_and_configure_raft) OPTIONS "BUILD_TESTS OFF" "BUILD_BENCH OFF" + "RAFT_ENABLE_NN_DEPENDENCIES ${PKG_ENABLE_NN_DEPENDENCIES}" + "RAFT_USE_FAISS_STATIC ${PKG_USE_FAISS_STATIC}" "RAFT_COMPILE_LIBRARIES ${PKG_COMPILE_LIBRARIES}" "RAFT_ENABLE_thrust_DEPENDENCY ${PKG_ENABLE_thrust_DEPENDENCY}" ) @@ -364,10 +369,14 @@ find_and_configure_raft(VERSION ${RAFT_VERSION}.00 COMPILE_LIBRARIES NO USE_NN_LIBRARY NO USE_DISTANCE_LIBRARY NO + ENABLE_NN_DEPENDENCIES NO # This builds FAISS if not installed + USE_FAISS_STATIC NO ENABLE_thrust_DEPENDENCY YES ) ``` +If using the nearest neighbors APIs without the shared libraries, set `ENABLE_NN_DEPENDENCIES=ON` and keep `USE_NN_LIBRARY=OFF` + ## Uninstall Once built and installed, RAFT can be safely uninstalled using `build.sh` by specifying any or all of the installed components. Please note that since `pylibraft` depends on `libraft`, uninstalling `pylibraft` will also uninstall `libraft`: From f7fd6a7ff94fbd711a31a9213b60db8d14bc7952 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 21 Feb 2023 14:27:30 -0800 Subject: [PATCH 29/62] couple more files to re-add faiss --- conda/recipes/libraft/conda_build_config.yaml | 3 +++ conda/recipes/libraft/meta.yaml | 4 ++++ dependencies.yaml | 2 ++ 3 files changed, 9 insertions(+) diff --git a/conda/recipes/libraft/conda_build_config.yaml b/conda/recipes/libraft/conda_build_config.yaml index ae4ba68229..1012bddb40 100644 --- a/conda/recipes/libraft/conda_build_config.yaml +++ b/conda/recipes/libraft/conda_build_config.yaml @@ -19,6 +19,9 @@ nccl_version: gtest_version: - "=1.10.0" +libfaiss_version: + - "1.7.2 *_cuda" + # The CTK libraries below are missing from the conda-forge::cudatoolkit # package. The "*_host_*" version specifiers correspond to `11.8` packages and the # "*_run_*" version specifiers correspond to `11.x` packages. diff --git a/conda/recipes/libraft/meta.yaml b/conda/recipes/libraft/meta.yaml index 4732a63af6..b84f979572 100644 --- a/conda/recipes/libraft/meta.yaml +++ b/conda/recipes/libraft/meta.yaml @@ -128,6 +128,7 @@ outputs: host: - {{ pin_subpackage('libraft-headers', exact=True) }} - cuda-profiler-api {{ cuda_profiler_api_host_version }} + - faiss-proc=*=cuda - lapack - libcublas {{ libcublas_host_version }} - libcublas-dev {{ libcublas_host_version }} @@ -137,7 +138,10 @@ outputs: - libcusolver-dev {{ libcusolver_host_version }} - libcusparse {{ libcusparse_host_version }} - libcusparse-dev {{ libcusparse_host_version }} + - libfaiss {{ libfaiss_version }} run: + - faiss-proc=*=cuda + - libfaiss {{ libfaiss_version }} - {{ pin_subpackage('libraft-headers', exact=True) }} about: home: https://rapids.ai/ diff --git a/dependencies.yaml b/dependencies.yaml index 0bdc921a70..571c9a095a 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -179,6 +179,8 @@ dependencies: - ucx-py=0.31.* - ucx-proc=*=gpu - rmm=23.04 + - libfaiss>=1.7.1=cuda* + - faiss-proc=*=cuda - dask-cuda=23.04 test_python: common: From 37d66d20761ee111634c6513709dc67edddaf842 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 22 Feb 2023 10:29:44 -0800 Subject: [PATCH 30/62] re-add faiss_mr --- .../raft/neighbors/detail/knn_brute_force.cuh | 1 - cpp/include/raft/spatial/knn/faiss_mr.hpp | 640 ++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/neighbors/faiss_mr.cu | 94 +++ 4 files changed, 735 insertions(+), 1 deletion(-) create mode 100644 cpp/include/raft/spatial/knn/faiss_mr.hpp create mode 100644 cpp/test/neighbors/faiss_mr.cu diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index d1bef3170e..81066698f2 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/faiss_mr.hpp b/cpp/include/raft/spatial/knn/faiss_mr.hpp new file mode 100644 index 0000000000..3cae417996 --- /dev/null +++ b/cpp/include/raft/spatial/knn/faiss_mr.hpp @@ -0,0 +1,640 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* +This code contains unnecessary code duplication. These could be deleted +once the relevant changes would be made on the FAISS side. Indeed most of +the logic in the below code is similar to FAISS's standard implementation +and should thus be inherited instead of duplicated. This FAISS's issue +once solved should allow the removal of the unnecessary duplicates +in this file : https://github.com/facebookresearch/faiss/issues/2097 +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +using namespace faiss::gpu; + +namespace { + +// How many streams per device we allocate by default (for multi-streaming) +constexpr int kNumStreams = 2; + +// Use 256 MiB of pinned memory for async CPU <-> GPU copies by default +constexpr size_t kDefaultPinnedMemoryAllocation = (size_t)256 * 1024 * 1024; + +// Default temporary memory allocation for <= 4 GiB memory GPUs +constexpr size_t k4GiBTempMem = (size_t)512 * 1024 * 1024; + +// Default temporary memory allocation for <= 8 GiB memory GPUs +constexpr size_t k8GiBTempMem = (size_t)1024 * 1024 * 1024; + +// Maximum temporary memory allocation for all GPUs +constexpr size_t kMaxTempMem = (size_t)1536 * 1024 * 1024; + +std::string allocsToString(const std::unordered_map& map) +{ + // Produce a sorted list of all outstanding allocations by type + std::unordered_map> stats; + + for (auto& entry : map) { + auto& a = entry.second; + + auto it = stats.find(a.type); + if (it != stats.end()) { + stats[a.type].first++; + stats[a.type].second += a.size; + } else { + stats[a.type] = std::make_pair(1, a.size); + } + } + + std::stringstream ss; + for (auto& entry : stats) { + ss << "Alloc type " << allocTypeToString(entry.first) << ": " << entry.second.first + << " allocations, " << entry.second.second << " bytes\n"; + } + + return ss.str(); +} + +} // namespace + +/// RMM implementation of the GpuResources object that provides for a +/// temporary memory manager +class RmmGpuResourcesImpl : public GpuResources { + public: + RmmGpuResourcesImpl() + : pinnedMemAlloc_(nullptr), + pinnedMemAllocSize_(0), + // let the adjustment function determine the memory size for us by passing + // in a huge value that will then be adjusted + tempMemSize_(getDefaultTempMemForGPU(-1, std::numeric_limits::max())), + pinnedMemSize_(kDefaultPinnedMemoryAllocation), + allocLogging_(false), + cmr(new rmm::mr::cuda_memory_resource), + mmr(new rmm::mr::managed_memory_resource), + pmr(new rmm::mr::pinned_memory_resource){}; + + ~RmmGpuResourcesImpl() + { + // The temporary memory allocator has allocated memory through us, so clean + // that up before we finish fully de-initializing ourselves + tempMemory_.clear(); + + // Make sure all allocations have been freed + bool allocError = false; + + for (auto& entry : allocs_) { + auto& map = entry.second; + + if (!map.empty()) { + std::cerr << "RmmGpuResources destroyed with allocations outstanding:\n" + << "Device " << entry.first << " outstanding allocations:\n"; + std::cerr << allocsToString(map); + allocError = true; + } + } + + FAISS_ASSERT_MSG(!allocError, "GPU memory allocations not properly cleaned up"); + + for (auto& entry : defaultStreams_) { + DeviceScope scope(entry.first); + + // We created these streams, so are responsible for destroying them + CUDA_VERIFY(cudaStreamDestroy(entry.second)); + } + + for (auto& entry : alternateStreams_) { + DeviceScope scope(entry.first); + + for (auto stream : entry.second) { + CUDA_VERIFY(cudaStreamDestroy(stream)); + } + } + + for (auto& entry : asyncCopyStreams_) { + DeviceScope scope(entry.first); + + CUDA_VERIFY(cudaStreamDestroy(entry.second)); + } + + for (auto& entry : blasHandles_) { + DeviceScope scope(entry.first); + + auto blasStatus = cublasDestroy(entry.second); + FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); + } + + if (pinnedMemAlloc_) { pmr->deallocate(pinnedMemAlloc_, pinnedMemAllocSize_); } + }; + + /// Disable allocation of temporary memory; all temporary memory + /// requests will call cudaMalloc / cudaFree at the point of use + void noTempMemory() { setTempMemory(0); }; + + /// Specify that we wish to use a certain fixed size of memory on + /// all devices as temporary memory. This is the upper bound for the GPU + /// memory that we will reserve. We will never go above 1.5 GiB on any GPU; + /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that. + /// To avoid any temporary memory allocation, pass 0. + void setTempMemory(size_t size) + { + if (tempMemSize_ != size) { + // adjust based on general limits + tempMemSize_ = getDefaultTempMemForGPU(-1, size); + + // We need to re-initialize memory resources for all current devices that + // have been initialized. + // This should be safe to do, even if we are currently running work, because + // the cudaFree call that this implies will force-synchronize all GPUs with + // the CPU + for (auto& p : tempMemory_) { + int device = p.first; + // Free the existing memory first + p.second.reset(); + + // Allocate new + p.second = std::unique_ptr( + new StackDeviceMemory(this, + p.first, + // adjust for this specific device + getDefaultTempMemForGPU(device, tempMemSize_))); + } + } + }; + + /// Set amount of pinned memory to allocate, for async GPU <-> CPU + /// transfers + void setPinnedMemory(size_t size) + { + // Should not call this after devices have been initialized + FAISS_ASSERT(defaultStreams_.size() == 0); + FAISS_ASSERT(!pinnedMemAlloc_); + + pinnedMemSize_ = size; + }; + + /// Called to change the stream for work ordering. We do not own `stream`; + /// i.e., it will not be destroyed when the GpuResources object gets cleaned + /// up. + /// We are guaranteed that all Faiss GPU work is ordered with respect to + /// this stream upon exit from an index or other Faiss GPU call. + void setDefaultStream(int device, cudaStream_t stream) + { + if (isInitialized(device)) { + // A new series of calls may not be ordered with what was the previous + // stream, so if the stream being specified is different, then we need to + // ensure ordering between the two (new stream waits on old). + auto it = userDefaultStreams_.find(device); + cudaStream_t prevStream = nullptr; + + if (it != userDefaultStreams_.end()) { + prevStream = it->second; + } else { + FAISS_ASSERT(defaultStreams_.count(device)); + prevStream = defaultStreams_[device]; + } + + if (prevStream != stream) { streamWait({stream}, {prevStream}); } + } + + userDefaultStreams_[device] = stream; + }; + + /// Revert the default stream to the original stream managed by this resources + /// object, in case someone called `setDefaultStream`. + void revertDefaultStream(int device) + { + if (isInitialized(device)) { + auto it = userDefaultStreams_.find(device); + + if (it != userDefaultStreams_.end()) { + // There was a user stream set that we need to synchronize against + cudaStream_t prevStream = userDefaultStreams_[device]; + + FAISS_ASSERT(defaultStreams_.count(device)); + cudaStream_t newStream = defaultStreams_[device]; + + streamWait({newStream}, {prevStream}); + } + } + + userDefaultStreams_.erase(device); + }; + + /// Returns the stream for the given device on which all Faiss GPU work is + /// ordered. + /// We are guaranteed that all Faiss GPU work is ordered with respect to + /// this stream upon exit from an index or other Faiss GPU call. + cudaStream_t getDefaultStream(int device) + { + initializeForDevice(device); + + auto it = userDefaultStreams_.find(device); + if (it != userDefaultStreams_.end()) { + // There is a user override stream set + return it->second; + } + + // Otherwise, our base default stream + return defaultStreams_[device]; + }; + + /// Called to change the work ordering streams to the null stream + /// for all devices + void setDefaultNullStreamAllDevices() + { + for (int dev = 0; dev < getNumDevices(); ++dev) { + setDefaultStream(dev, nullptr); + } + }; + + /// If enabled, will print every GPU memory allocation and deallocation to + /// standard output + void setLogMemoryAllocations(bool enable) { allocLogging_ = enable; }; + + public: + /// Internal system calls + + /// Initialize resources for this device + void initializeForDevice(int device) + { + if (isInitialized(device)) { return; } + + // If this is the first device that we're initializing, create our + // pinned memory allocation + if (defaultStreams_.empty() && pinnedMemSize_ > 0) { + pinnedMemAlloc_ = pmr->allocate(pinnedMemSize_); + pinnedMemAllocSize_ = pinnedMemSize_; + } + + FAISS_ASSERT(device < getNumDevices()); + DeviceScope scope(device); + + // Make sure that device properties for all devices are cached + auto& prop = getDeviceProperties(device); + + // Also check to make sure we meet our minimum compute capability (3.0) + FAISS_ASSERT_FMT(prop.major >= 3, + "Device id %d with CC %d.%d not supported, " + "need 3.0+ compute capability", + device, + prop.major, + prop.minor); + + // Create streams + cudaStream_t defaultStream = 0; + CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream, cudaStreamNonBlocking)); + + defaultStreams_[device] = defaultStream; + + cudaStream_t asyncCopyStream = 0; + CUDA_VERIFY(cudaStreamCreateWithFlags(&asyncCopyStream, cudaStreamNonBlocking)); + + asyncCopyStreams_[device] = asyncCopyStream; + + std::vector deviceStreams; + for (int j = 0; j < kNumStreams; ++j) { + cudaStream_t stream = 0; + CUDA_VERIFY(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + + deviceStreams.push_back(stream); + } + + alternateStreams_[device] = std::move(deviceStreams); + + // Create cuBLAS handle + cublasHandle_t blasHandle = 0; + auto blasStatus = cublasCreate(&blasHandle); + FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); + blasHandles_[device] = blasHandle; + + // For CUDA 10 on V100, enabling tensor core usage would enable automatic + // rounding down of inputs to f16 (though accumulate in f32) which results in + // unacceptable loss of precision in general. + // For CUDA 11 / A100, only enable tensor core support if it doesn't result in + // a loss of precision. +#if CUDA_VERSION >= 11000 + cublasSetMathMode(blasHandle, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION); +#endif + + FAISS_ASSERT(allocs_.count(device) == 0); + allocs_[device] = std::unordered_map(); + + FAISS_ASSERT(tempMemory_.count(device) == 0); + auto mem = std::unique_ptr( + new StackDeviceMemory(this, + device, + // adjust for this specific device + getDefaultTempMemForGPU(device, tempMemSize_))); + + tempMemory_.emplace(device, std::move(mem)); + }; + + cublasHandle_t getBlasHandle(int device) + { + initializeForDevice(device); + return blasHandles_[device]; + }; + + std::vector getAlternateStreams(int device) + { + initializeForDevice(device); + return alternateStreams_[device]; + }; + + /// Allocate non-temporary GPU memory + void* allocMemory(const AllocRequest& req) + { + initializeForDevice(req.device); + + // We don't allocate a placeholder for zero-sized allocations + if (req.size == 0) { return nullptr; } + + // Make sure that the allocation is a multiple of 16 bytes for alignment + // purposes + auto adjReq = req; + adjReq.size = utils::roundUp(adjReq.size, (size_t)16); + + void* p = nullptr; + + if (allocLogging_) { std::cout << "RmmGpuResources: alloc " << adjReq.toString() << "\n"; } + + if (adjReq.space == MemorySpace::Temporary) { + // If we don't have enough space in our temporary memory manager, we need + // to allocate this request separately + auto& tempMem = tempMemory_[adjReq.device]; + + if (adjReq.size > tempMem->getSizeAvailable()) { + // We need to allocate this ourselves + AllocRequest newReq = adjReq; + newReq.space = MemorySpace::Device; + newReq.type = AllocType::TemporaryMemoryOverflow; + + return allocMemory(newReq); + } + + // Otherwise, we can handle this locally + p = tempMemory_[adjReq.device]->allocMemory(adjReq.stream, adjReq.size); + + } else if (adjReq.space == MemorySpace::Device) { + p = cmr->allocate(adjReq.size, adjReq.stream); + } else if (adjReq.space == MemorySpace::Unified) { + p = mmr->allocate(adjReq.size, adjReq.stream); + } else { + FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int)adjReq.space); + } + + allocs_[adjReq.device][p] = adjReq; + + return p; + }; + + /// Returns a previous allocation + void deallocMemory(int device, void* p) + { + FAISS_ASSERT(isInitialized(device)); + + if (!p) { return; } + + auto& a = allocs_[device]; + auto it = a.find(p); + FAISS_ASSERT(it != a.end()); + + auto& req = it->second; + + if (allocLogging_) { std::cout << "RmmGpuResources: dealloc " << req.toString() << "\n"; } + + if (req.space == MemorySpace::Temporary) { + tempMemory_[device]->deallocMemory(device, req.stream, req.size, p); + } else if (req.space == MemorySpace::Device) { + cmr->deallocate(p, req.size, req.stream); + } else if (req.space == MemorySpace::Unified) { + mmr->deallocate(p, req.size, req.stream); + } else { + FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int)req.space); + } + + a.erase(it); + }; + + size_t getTempMemoryAvailable(int device) const + { + FAISS_ASSERT(isInitialized(device)); + + auto it = tempMemory_.find(device); + FAISS_ASSERT(it != tempMemory_.end()); + + return it->second->getSizeAvailable(); + }; + + /// Export a description of memory used for Python + std::map>> getMemoryInfo() const + { + using AT = std::map>; + + std::map out; + + for (auto& entry : allocs_) { + AT outDevice; + + for (auto& a : entry.second) { + auto& v = outDevice[allocTypeToString(a.second.type)]; + v.first++; + v.second += a.second.size; + } + + out[entry.first] = std::move(outDevice); + } + + return out; + }; + + std::pair getPinnedMemory() + { + return std::make_pair(pinnedMemAlloc_, pinnedMemAllocSize_); + }; + + cudaStream_t getAsyncCopyStream(int device) + { + initializeForDevice(device); + return asyncCopyStreams_[device]; + }; + + private: + /// Have GPU resources been initialized for this device yet? + bool isInitialized(int device) const + { + // Use default streams as a marker for whether or not a certain + // device has been initialized + return defaultStreams_.count(device) != 0; + }; + + /// Adjust the default temporary memory allocation based on the total GPU + /// memory size + static size_t getDefaultTempMemForGPU(int device, size_t requested) + { + auto totalMem = device != -1 ? getDeviceProperties(device).totalGlobalMem + : std::numeric_limits::max(); + + if (totalMem <= (size_t)4 * 1024 * 1024 * 1024) { + // If the GPU has <= 4 GiB of memory, reserve 512 MiB + + if (requested > k4GiBTempMem) { return k4GiBTempMem; } + } else if (totalMem <= (size_t)8 * 1024 * 1024 * 1024) { + // If the GPU has <= 8 GiB of memory, reserve 1 GiB + + if (requested > k8GiBTempMem) { return k8GiBTempMem; } + } else { + // Never use more than 1.5 GiB + if (requested > kMaxTempMem) { return kMaxTempMem; } + } + + // use whatever lower limit the user requested + return requested; + }; + + private: + /// Set of currently outstanding memory allocations per device + /// device -> (alloc request, allocated ptr) + std::unordered_map> allocs_; + + /// Temporary memory provider, per each device + std::unordered_map> tempMemory_; + + /// Our default stream that work is ordered on, one per each device + std::unordered_map defaultStreams_; + + /// This contains particular streams as set by the user for + /// ordering, if any + std::unordered_map userDefaultStreams_; + + /// Other streams we can use, per each device + std::unordered_map> alternateStreams_; + + /// Async copy stream to use for GPU <-> CPU pinned memory copies + std::unordered_map asyncCopyStreams_; + + /// cuBLAS handle for each device + std::unordered_map blasHandles_; + + /// Pinned memory allocation for use with this GPU + void* pinnedMemAlloc_; + size_t pinnedMemAllocSize_; + + /// Another option is to use a specified amount of memory on all + /// devices + size_t tempMemSize_; + + /// Amount of pinned memory we should allocate + size_t pinnedMemSize_; + + /// Whether or not we log every GPU memory allocation and deallocation + bool allocLogging_; + + // cuda_memory_resource + std::unique_ptr cmr; + + // managed_memory_resource + std::unique_ptr mmr; + + // pinned_memory_resource + std::unique_ptr pmr; +}; + +/// Default implementation of GpuResources that allocates a cuBLAS +/// stream and 2 streams for use, as well as temporary memory. +/// Internally, the Faiss GPU code uses the instance managed by getResources, +/// but this is the user-facing object that is internally reference counted. +class RmmGpuResources : public GpuResourcesProvider { + public: + RmmGpuResources() : res_(new RmmGpuResourcesImpl){}; + + ~RmmGpuResources(){}; + + std::shared_ptr getResources() { return res_; }; + + /// Disable allocation of temporary memory; all temporary memory + /// requests will call cudaMalloc / cudaFree at the point of use + void noTempMemory() { res_->noTempMemory(); }; + + /// Specify that we wish to use a certain fixed size of memory on + /// all devices as temporary memory. This is the upper bound for the GPU + /// memory that we will reserve. We will never go above 1.5 GiB on any GPU; + /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that. + /// To avoid any temporary memory allocation, pass 0. + void setTempMemory(size_t size) { res_->setTempMemory(size); }; + + /// Set amount of pinned memory to allocate, for async GPU <-> CPU + /// transfers + void setPinnedMemory(size_t size) { res_->setPinnedMemory(size); }; + + /// Called to change the stream for work ordering. We do not own `stream`; + /// i.e., it will not be destroyed when the GpuResources object gets cleaned + /// up. + /// We are guaranteed that all Faiss GPU work is ordered with respect to + /// this stream upon exit from an index or other Faiss GPU call. + void setDefaultStream(int device, cudaStream_t stream) + { + res_->setDefaultStream(device, stream); + }; + + /// Revert the default stream to the original stream managed by this resources + /// object, in case someone called `setDefaultStream`. + void revertDefaultStream(int device) { res_->revertDefaultStream(device); }; + + /// Called to change the work ordering streams to the null stream + /// for all devices + void setDefaultNullStreamAllDevices() { res_->setDefaultNullStreamAllDevices(); }; + + /// Export a description of memory used for Python + std::map>> getMemoryInfo() const + { + return res_->getMemoryInfo(); + }; + + /// Returns the current default stream + cudaStream_t getDefaultStream(int device) { return res_->getDefaultStream(device); }; + + /// Returns the current amount of temp memory available + size_t getTempMemoryAvailable(int device) const { return res_->getTempMemoryAvailable(device); }; + + /// Synchronize our default stream with the CPU + void syncDefaultStreamCurrentDevice() { res_->syncDefaultStreamCurrentDevice(); }; + + /// If enabled, will print every GPU memory allocation and deallocation to + /// standard output + void setLogMemoryAllocations(bool enable) { res_->setLogMemoryAllocations(enable); }; + + private: + std::shared_ptr res_; +}; + +} // namespace knn +} // namespace spatial +} // namespace raft \ No newline at end of file diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index b3894da53b..49f924cc9f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -266,6 +266,7 @@ if(BUILD_TESTS) test/neighbors/tiled_knn.cu test/neighbors/haversine.cu test/neighbors/ball_cover.cu + test/neighbors/faiss_mr.cu test/neighbors/epsilon_neighborhood.cu test/neighbors/refine.cu test/neighbors/selection.cu diff --git a/cpp/test/neighbors/faiss_mr.cu b/cpp/test/neighbors/faiss_mr.cu new file mode 100644 index 0000000000..5f0bcae933 --- /dev/null +++ b/cpp/test/neighbors/faiss_mr.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include + +#include + +#include + +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +using namespace faiss::gpu; + +struct AllocInputs { + size_t size; +}; + +template +class FAISS_MR_Test : public ::testing::TestWithParam { + public: + FAISS_MR_Test() + : params_(::testing::TestWithParam::GetParam()), stream(handle.get_stream()) + { + } + + protected: + size_t getFreeMemory(MemorySpace mem_space) + { + if (mem_space == MemorySpace::Device) { + rmm::mr::cuda_memory_resource cmr; + rmm::mr::device_memory_resource* dmr = &cmr; + return dmr->get_mem_info(stream).first; + } else if (mem_space == MemorySpace::Unified) { + rmm::mr::managed_memory_resource mmr; + rmm::mr::device_memory_resource* dmr = &mmr; + return dmr->get_mem_info(stream).first; + } + return 0; + } + + void testAllocs(MemorySpace mem_space) + { + raft::spatial::knn::RmmGpuResources faiss_mr; + auto faiss_mr_impl = faiss_mr.getResources(); + size_t free_before = getFreeMemory(mem_space); + AllocRequest req(AllocType::Other, 0, mem_space, stream, params_.size); + void* ptr = faiss_mr_impl->allocMemory(req); + size_t free_after_alloc = getFreeMemory(mem_space); + faiss_mr_impl->deallocMemory(0, ptr); + ASSERT_TRUE(free_after_alloc <= free_before - params_.size); + } + + raft::device_resources handle; + cudaStream_t stream; + AllocInputs params_; +}; + +const std::vector inputs = {{19687}}; + +typedef FAISS_MR_Test FAISS_MR_TestF; +TEST_P(FAISS_MR_TestF, TestAllocs) +{ + testAllocs(MemorySpace::Device); + testAllocs(MemorySpace::Unified); +} + +INSTANTIATE_TEST_CASE_P(FAISS_MR_Test, FAISS_MR_TestF, ::testing::ValuesIn(inputs)); + +} // namespace knn +} // namespace spatial +} // namespace raft From a61c92f99142b5371d486c243ce49d83f7fec2ac Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 22 Feb 2023 17:07:36 -0800 Subject: [PATCH 31/62] explicitly include faiss_mr --- cpp/test/neighbors/faiss_mr.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/test/neighbors/faiss_mr.cu b/cpp/test/neighbors/faiss_mr.cu index 5f0bcae933..89f012db0f 100644 --- a/cpp/test/neighbors/faiss_mr.cu +++ b/cpp/test/neighbors/faiss_mr.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include From dbd31b25efcdcfff1e4d147e944dce0428b62875 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 24 Feb 2023 13:10:05 -0800 Subject: [PATCH 32/62] Allow col_major input to bfknn --- .../raft/neighbors/detail/knn_brute_force.cuh | 36 ++++--- cpp/test/neighbors/tiled_knn.cu | 94 +++++++++++-------- 2 files changed, 82 insertions(+), 48 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 81066698f2..aae985b5fd 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -448,17 +449,30 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: - tiled_brute_force_knn(handle, - search_items, - input[i], - n, - sizes[i], - D, - k, - out_d_ptr, - out_i_ptr, - metric, - metricArg); + // currently we don't support col_major inside tiled_brute_force_knn, because + // of limitattions of the pairwise_distance API: + // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have + // multiple options here (like rowMajorQuery/rowMajorIndex) + // 2) because of tiling, we need to be able to set a custom stride in the PW + // api, which isn't supported + // Instead, transpose the input matrices if they are passed as col-major. + auto search = search_items; + rmm::device_uvector search_row_major(0, stream); + if (!rowMajorQuery) { + search_row_major.resize(n * D, stream); + raft::linalg::transpose(handle, search, search_row_major.data(), n, D, stream); + search = search_row_major.data(); + } + auto index = input[i]; + rmm::device_uvector index_row_major(0, stream); + if (!rowMajorIndex) { + index_row_major.resize(sizes[i] * D, stream); + raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, stream); + index = index_row_major.data(); + } + + tiled_brute_force_knn( + handle, search, index, n, sizes[i], D, k, out_d_ptr, out_i_ptr, metric, metricArg); break; } } diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 1d0b8207ff..9e86f1ef87 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -21,6 +21,7 @@ #include #include #include +#include #include #if defined RAFT_NN_COMPILED @@ -46,14 +47,16 @@ struct TiledKNNInputs { int k; int row_tiles; int col_tiles; - raft::distance::DistanceType metric_; + raft::distance::DistanceType metric; + bool row_major; }; std::ostream& operator<<(std::ostream& os, const TiledKNNInputs& input) { return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs << " dim:" << input.dim << " k:" << input.k << " row_tiles:" << input.row_tiles - << " col_tiles:" << input.col_tiles << " metric:" << print_metric{input.metric_}; + << " col_tiles:" << input.col_tiles << " metric:" << print_metric{input.metric} + << " row_major:" << input.row_major; } template @@ -89,16 +92,32 @@ class TiledKNNTest : public ::testing::TestWithParam { // calculate the naive knn, by calculating the full pairwise distances and doing a k-select rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); - distance::pairwise_distance( - handle_, - raft::make_device_matrix_view(search_queries.data(), num_queries, dim), - raft::make_device_matrix_view(database.data(), num_db_vecs, dim), - raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), - metric, - metric_arg); + rmm::device_uvector workspace(0, stream_); + distance::pairwise_distance(handle_, + search_queries.data(), + database.data(), + temp_distances.data(), + num_queries, + num_db_vecs, + dim, + workspace, + metric, + params_.row_major, + metric_arg); + + // setting the 'isRowMajor' flag in the pairwise distances api, not only sets + // the inputs as colmajor - but also the output. this means we have to transpose in this + // case + auto temp_dist = temp_distances.data(); + rmm::device_uvector temp_row_major_dist(num_db_vecs * num_queries, stream_); + if (!params_.row_major) { + raft::linalg::transpose( + handle_, temp_dist, temp_row_major_dist.data(), num_queries, num_db_vecs, stream_); + temp_dist = temp_row_major_dist.data(); + } using namespace raft::spatial; - knn::select_k(temp_distances.data(), + knn::select_k(temp_dist, nullptr, num_queries, num_db_vecs, @@ -120,8 +139,8 @@ class TiledKNNTest : public ::testing::TestWithParam { raft_indices_.data(), raft_distances_.data(), k_, - true, - true, + params_.row_major, + params_.row_major, nullptr, metric, metric_arg); @@ -158,7 +177,7 @@ class TiledKNNTest : public ::testing::TestWithParam { num_db_vecs = params_.num_db_vecs; dim = params_.dim; k_ = params_.k; - metric = params_.metric_; + metric = params_.metric; unsigned long long int seed = 1234ULL; raft::random::RngState r(seed); @@ -187,35 +206,36 @@ class TiledKNNTest : public ::testing::TestWithParam { }; const std::vector random_inputs = { - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Expanded}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Unexpanded}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtUnexpanded}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L1}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Linf}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::InnerProduct}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CorrelationExpanded}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CosineExpanded}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::LpUnexpanded}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Expanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Unexpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtUnexpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L1, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Linf, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::InnerProduct, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CorrelationExpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CosineExpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::LpUnexpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded, true}, // BrayCurtis isn't currently supported by pairwise_distance api // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::BrayCurtis}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra}, - {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded}, - {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded}, - {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra, true}, + {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded, true}, + {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded, true}, + {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded, true}, // Test where the final column tile has < K items: - {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded}, + {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded, true}, // Test where passing column_tiles < K - {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded}, + {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded, true}, // Passing tile sizes of 0 means to use the public api (instead of the - // detail api). - {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded}, - {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::LpUnexpanded}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::L2SqrtExpanded}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::CosineExpanded}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::InnerProduct}}; + // detail api). Note that we can only test col_major in the public api + {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, true}, + {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, false}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::LpUnexpanded, true}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::CosineExpanded, true}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::InnerProduct, false}}; typedef TiledKNNTest TiledKNNTestF; TEST_P(TiledKNNTestF, BruteForce) { this->testBruteForce(); } From fddecc36ff1d2f7731b419ba15800e0165534a1a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 24 Feb 2023 15:24:30 -0800 Subject: [PATCH 33/62] fix faiss queryempty test --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index aae985b5fd..ba0940be3c 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -188,6 +188,17 @@ void tiled_brute_force_knn(const raft::device_resources& handle, if (last_col_tile_size && (last_col_tile_size < static_cast(k))) { temp_out_cols -= k - last_col_tile_size; } + + // if we have less than k items in the index, we should fill out the result + // to indicate that we are missing items (and match behaviour in faiss) + if (n < static_cast(k)) { + thrust::fill(handle.get_thrust_policy(), + distances, + distances + m * k, + std::numeric_limits::lowest()); + thrust::fill(handle.get_thrust_policy(), indices, indices + m * k, -1); + } + rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); From 4687144488b12a3576a64fbf3bacb2392b70859a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 27 Feb 2023 10:07:11 -0800 Subject: [PATCH 34/62] exclude LP from fused --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index ba0940be3c..516768ccf4 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -422,8 +422,7 @@ void brute_force_knn_impl( (metric == raft::distance::DistanceType::L2Unexpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::LpUnexpanded)) { + metric == raft::distance::DistanceType::L2SqrtExpanded)) { fusedL2Knn(D, out_i_ptr, out_d_ptr, From 06c8674ea4addbe2226c9ca6626f2c2efca05764 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 27 Feb 2023 15:43:01 -0800 Subject: [PATCH 35/62] use metric processor for cosine/correlation cuml expects cosine distance to be (1-Cosine(a,b)), where the pairwise_distances api was returning just Cosine(a,b). Revert --- .../raft/neighbors/detail/knn_brute_force.cuh | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 516768ccf4..d7244b120d 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -389,6 +389,18 @@ void brute_force_knn_impl( id_ranges = translations; } + // perform preprocessing + std::unique_ptr> query_metric_processor = + create_processor(metric, n, D, k, rowMajorQuery, userStream); + query_metric_processor->preprocess(search_items); + + std::vector>> metric_processors(input.size()); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i] = + create_processor(metric, sizes[i], D, k, rowMajorQuery, userStream); + metric_processors[i]->preprocess(input[i]); + } + int device; RAFT_CUDA_TRY(cudaGetDevice(&device)); @@ -481,8 +493,24 @@ void brute_force_knn_impl( index = index_row_major.data(); } - tiled_brute_force_knn( - handle, search, index, n, sizes[i], D, k, out_d_ptr, out_i_ptr, metric, metricArg); + // cosine/correlation are handled by metric processor, use IP distance + // for brute force knn call + auto tiled_metric = metric; + if (metric == raft::distance::DistanceType::CosineExpanded || + metric == raft::distance::DistanceType::CorrelationExpanded) { + tiled_metric = raft::distance::DistanceType::InnerProduct; + } + tiled_brute_force_knn(handle, + search, + index, + n, + sizes[i], + D, + k, + out_d_ptr, + out_i_ptr, + tiled_metric, + metricArg); break; } } @@ -501,6 +529,12 @@ void brute_force_knn_impl( knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); } + query_metric_processor->revert(search_items); + query_metric_processor->postprocess(out_D); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i]->revert(input[i]); + } + if (translations == nullptr) delete id_ranges; }; From b44d15c4a50fc96c2cd7c9043f9cdd9aa0e9a0c5 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 28 Feb 2023 15:18:48 -0800 Subject: [PATCH 36/62] exclude cosine --- cpp/test/neighbors/tiled_knn.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 9e86f1ef87..a44474eb04 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -234,7 +234,6 @@ const std::vector random_inputs = { {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, false}, {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::LpUnexpanded, true}, {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::CosineExpanded, true}, {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::InnerProduct, false}}; typedef TiledKNNTest TiledKNNTestF; From eb0271acd171df51f6d4cf2b921d0520eb6bf379 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 1 Mar 2023 11:27:07 -0800 Subject: [PATCH 37/62] avoid l2expanded distance I'm seeing a bunch of test failures in cuml CI - and at least some of them fail with L2Expanded distance and work with L2Unexpanded distance. --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index d7244b120d..438762c4b9 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -500,6 +500,16 @@ void brute_force_knn_impl( metric == raft::distance::DistanceType::CorrelationExpanded) { tiled_metric = raft::distance::DistanceType::InnerProduct; } + + // L2Expanded distance seems to cause a bunch of test failures in cuml + // revert to using L2Unexpanded while we figure this out + if (metric == raft::distance::DistanceType::L2Expanded) { + tiled_metric = raft::distance::DistanceType::L2Unexpanded; + } + if (metric == raft::distance::DistanceType::L2SqrtExpanded) { + tiled_metric = raft::distance::DistanceType::L2SqrtUnexpanded; + } + tiled_brute_force_knn(handle, search, index, From 4c41c63b85263c4775a3449c271d2b82f230ca8e Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 6 Mar 2023 10:50:52 -0800 Subject: [PATCH 38/62] Expanded L2 Changes Only calculate norms once per input (rather than per-tile) when calculating L2Expanded distance --- .../raft/neighbors/detail/knn_brute_force.cuh | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 438762c4b9..894bec6816 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -176,6 +176,22 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // stores pairwise distances for the current tile rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); + // calculate norms for L2 expanded distances - this lets us avoid calculating + // norms repeatedly per-tile, and just do once for the entire input + auto pairwise_metric = metric; + rmm::device_uvector search_norms(0, stream); + rmm::device_uvector index_norms(0, stream); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + search_norms.resize(m, stream); + index_norms.resize(n, stream); + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + raft::linalg::rowNorm( + index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + pairwise_metric = raft::distance::DistanceType::InnerProduct; + } + // if we're tiling over columns, we need additional buffers for temporary output // distances/indices size_t num_col_tiles = raft::ceildiv(n, tile_cols); @@ -223,9 +239,28 @@ void tiled_brute_force_knn(const raft::device_resources& handle, current_query_size, current_centroid_size, d, - metric, + pairwise_metric, true, metric_arg); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + auto row_norms = search_norms.data() + i; + auto col_norms = index_norms.data() + j; + auto dist = temp_distances.data(); + auto count = thrust::make_counting_iterator(0); + + thrust::for_each(handle.get_thrust_policy(), + count, + count + current_query_size * current_centroid_size, + [=] __device__(IndexType i) { + IndexType row = i / current_centroid_size, + col = i % current_centroid_size; + dist[i] = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + if (metric == raft::distance::DistanceType::L2SqrtExpanded) { + dist[i] = sqrt(dist[i]); + } + }); + } detail::select_k(temp_distances.data(), nullptr, @@ -494,22 +529,13 @@ void brute_force_knn_impl( } // cosine/correlation are handled by metric processor, use IP distance - // for brute force knn call + // for brute force knn call. auto tiled_metric = metric; if (metric == raft::distance::DistanceType::CosineExpanded || metric == raft::distance::DistanceType::CorrelationExpanded) { tiled_metric = raft::distance::DistanceType::InnerProduct; } - // L2Expanded distance seems to cause a bunch of test failures in cuml - // revert to using L2Unexpanded while we figure this out - if (metric == raft::distance::DistanceType::L2Expanded) { - tiled_metric = raft::distance::DistanceType::L2Unexpanded; - } - if (metric == raft::distance::DistanceType::L2SqrtExpanded) { - tiled_metric = raft::distance::DistanceType::L2SqrtUnexpanded; - } - tiled_brute_force_knn(handle, search, index, From cdf196271e5d6aa9108543185953e510355078cc Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 6 Mar 2023 17:06:27 -0800 Subject: [PATCH 39/62] correct for small instabilities in l2sqrtexpanded distance --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 894bec6816..cc3b0daf52 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -255,7 +255,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, [=] __device__(IndexType i) { IndexType row = i / current_centroid_size, col = i % current_centroid_size; - dist[i] = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + // correct for small instabilities + dist[i] = val > 1e-6 ? val : 0; if (metric == raft::distance::DistanceType::L2SqrtExpanded) { dist[i] = sqrt(dist[i]); } From 6a1e2d8c53fa1ba7d692f08864c20508d84d8ecf Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 6 Mar 2023 17:25:35 -0800 Subject: [PATCH 40/62] warp divergence --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index cc3b0daf52..62672d5586 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -258,7 +258,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; // correct for small instabilities - dist[i] = val > 1e-6 ? val : 0; + dist[i] = val * (fabs(val) >= 1e-6); if (metric == raft::distance::DistanceType::L2SqrtExpanded) { dist[i] = sqrt(dist[i]); } From 1e2817c46f2334e09acd68824dfd95832edb60d6 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 6 Mar 2023 19:36:21 -0800 Subject: [PATCH 41/62] clamp to 0 --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 62672d5586..f7d7d44df4 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -257,8 +257,11 @@ void tiled_brute_force_knn(const raft::device_resources& handle, col = i % current_centroid_size; auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; - // correct for small instabilities - dist[i] = val * (fabs(val) >= 1e-6); + + // due to numerical instability (especially around self-distance) + // the distances here could be slightly negative, which will + // cause NaN values in the subsequent sqrt. Clamp to 0 + dist[i] = val * (val > 0.0); if (metric == raft::distance::DistanceType::L2SqrtExpanded) { dist[i] = sqrt(dist[i]); } From 4b56fac3f78e9208d9299a3ccd119108ab46e224 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Mar 2023 10:17:50 -0800 Subject: [PATCH 42/62] threshold --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index f7d7d44df4..a497db8e5d 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -261,7 +261,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // due to numerical instability (especially around self-distance) // the distances here could be slightly negative, which will // cause NaN values in the subsequent sqrt. Clamp to 0 - dist[i] = val * (val > 0.0); + dist[i] = val * (val >= 0.0001); if (metric == raft::distance::DistanceType::L2SqrtExpanded) { dist[i] = sqrt(dist[i]); } From 4b41e2c37713f78810f57594f902d25ce8ed7872 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Mar 2023 22:00:54 -0800 Subject: [PATCH 43/62] Transpose for fusedl2knn as well --- .../raft/neighbors/detail/knn_brute_force.cuh | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index a497db8e5d..030f9eebe1 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -471,16 +471,38 @@ void brute_force_knn_impl( auto stream = handle.get_next_usable_stream(i); - if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && - (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded)) { + // currently we don't support col_major inside tiled_brute_force_knn, because + // of limitattions of the pairwise_distance API: + // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have + // multiple options here (like rowMajorQuery/rowMajorIndex) + // 2) because of tiling, we need to be able to set a custom stride in the PW + // api, which isn't supported + // Instead, transpose the input matrices if they are passed as col-major. + // This also lets us use fusedL2KNN for col-major inputs + auto search = search_items; + rmm::device_uvector search_row_major(0, stream); + if (!rowMajorQuery) { + search_row_major.resize(n * D, stream); + raft::linalg::transpose(handle, search, search_row_major.data(), n, D, stream); + search = search_row_major.data(); + } + auto index = input[i]; + rmm::device_uvector index_row_major(0, stream); + if (!rowMajorIndex) { + index_row_major.resize(sizes[i] * D, stream); + raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, stream); + index = index_row_major.data(); + } + + if (k <= 64 && (metric == raft::distance::DistanceType::L2Unexpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded)) { fusedL2Knn(D, out_i_ptr, out_d_ptr, - input[i], - search_items, + index, + search, sizes[i], n, k, @@ -512,28 +534,6 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: - // currently we don't support col_major inside tiled_brute_force_knn, because - // of limitattions of the pairwise_distance API: - // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have - // multiple options here (like rowMajorQuery/rowMajorIndex) - // 2) because of tiling, we need to be able to set a custom stride in the PW - // api, which isn't supported - // Instead, transpose the input matrices if they are passed as col-major. - auto search = search_items; - rmm::device_uvector search_row_major(0, stream); - if (!rowMajorQuery) { - search_row_major.resize(n * D, stream); - raft::linalg::transpose(handle, search, search_row_major.data(), n, D, stream); - search = search_row_major.data(); - } - auto index = input[i]; - rmm::device_uvector index_row_major(0, stream); - if (!rowMajorIndex) { - index_row_major.resize(sizes[i] * D, stream); - raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, stream); - index = index_row_major.data(); - } - // cosine/correlation are handled by metric processor, use IP distance // for brute force knn call. auto tiled_metric = metric; From 455c9525d848372890ff1846b6817baa1b8e12c1 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Mar 2023 22:21:27 -0800 Subject: [PATCH 44/62] fix --- .../raft/neighbors/detail/knn_brute_force.cuh | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 030f9eebe1..3cedccf7be 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -498,18 +498,8 @@ void brute_force_knn_impl( metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded)) { - fusedL2Knn(D, - out_i_ptr, - out_d_ptr, - index, - search, - sizes[i], - n, - k, - rowMajorIndex, - rowMajorQuery, - stream, - metric); + fusedL2Knn( + D, out_i_ptr, out_d_ptr, index, search, sizes[i], n, k, true, true, stream, metric); // Perform necessary post-processing if (metric == raft::distance::DistanceType::L2SqrtExpanded || From df46b658053ac19abc059622ea5d26c5019b57a3 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Mar 2023 22:21:27 -0800 Subject: [PATCH 45/62] Fix stream handling on col-major inputs For col-major inputs, we transpose the inputs since we only handle row-major in the tiled_brute_force_knn call. However, this was happening on the stream from stream pool when given multiple partitions - causing dask errors later on. Fix. --- .../raft/neighbors/detail/knn_brute_force.cuh | 65 ++++++++++--------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 030f9eebe1..11929b536a 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -462,6 +462,21 @@ void brute_force_knn_impl( out_I = all_I.data(); } + // currently we don't support col_major inside tiled_brute_force_knn, because + // of limitattions of the pairwise_distance API: + // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have + // multiple options here (like rowMajorQuery/rowMajorIndex) + // 2) because of tiling, we need to be able to set a custom stride in the PW + // api, which isn't supported + // Instead, transpose the input matrices if they are passed as col-major. + auto search = search_items; + rmm::device_uvector search_row_major(0, userStream); + if (!rowMajorQuery) { + search_row_major.resize(n * D, userStream); + raft::linalg::transpose(handle, search, search_row_major.data(), n, D, userStream); + search = search_row_major.data(); + } + // Make other streams from pool wait on main stream handle.wait_stream_pool_on_stream(); @@ -471,38 +486,16 @@ void brute_force_knn_impl( auto stream = handle.get_next_usable_stream(i); - // currently we don't support col_major inside tiled_brute_force_knn, because - // of limitattions of the pairwise_distance API: - // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have - // multiple options here (like rowMajorQuery/rowMajorIndex) - // 2) because of tiling, we need to be able to set a custom stride in the PW - // api, which isn't supported - // Instead, transpose the input matrices if they are passed as col-major. - // This also lets us use fusedL2KNN for col-major inputs - auto search = search_items; - rmm::device_uvector search_row_major(0, stream); - if (!rowMajorQuery) { - search_row_major.resize(n * D, stream); - raft::linalg::transpose(handle, search, search_row_major.data(), n, D, stream); - search = search_row_major.data(); - } - auto index = input[i]; - rmm::device_uvector index_row_major(0, stream); - if (!rowMajorIndex) { - index_row_major.resize(sizes[i] * D, stream); - raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, stream); - index = index_row_major.data(); - } - - if (k <= 64 && (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded)) { + if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + (metric == raft::distance::DistanceType::L2Unexpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded)) { fusedL2Knn(D, out_i_ptr, out_d_ptr, - index, - search, + input[i], + search_items, sizes[i], n, k, @@ -522,7 +515,7 @@ void brute_force_knn_impl( res_D, n * k, [p] __device__(float input) { return powf(fabsf(input), p); }, - userStream); + stream); } } else { switch (metric) { @@ -534,6 +527,18 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: + + auto index = input[i]; + rmm::device_uvector index_row_major(0, userStream); + if (!rowMajorIndex) { + index_row_major.resize(sizes[i] * D, userStream); + raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, userStream); + index = index_row_major.data(); + + // Make other streams from pool wait on main stream + handle.wait_stream_pool_on_stream(); + } + // cosine/correlation are handled by metric processor, use IP distance // for brute force knn call. auto tiled_metric = metric; From 28ebeefe3d49598cc4e77133430f30e5645f1fb7 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 13 Mar 2023 09:46:42 -0700 Subject: [PATCH 46/62] fix build for missing symbols --- cpp/src/nn/specializations/brute_force_knn_long_float_int.cu | 2 +- cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu | 2 +- .../nn/specializations/brute_force_knn_uint32_t_float_int.cu | 2 +- .../nn/specializations/brute_force_knn_uint32_t_float_uint.cu | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu index c5e06176f2..56f7c3cfa9 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu @@ -19,7 +19,7 @@ #include // TODO: Change this to proper specializations after FAISS is removed -#include +// #include namespace raft { namespace spatial { diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu index c4fe1b8aa8..6b1779107a 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu @@ -19,7 +19,7 @@ #include // TODO: Change this to proper specializations after FAISS is removed -#include +// #include namespace raft { namespace spatial { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu index d6ffe15fb1..94e3f35be5 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu @@ -19,7 +19,7 @@ #include // TODO: Change this to proper specializations after FAISS is removed -#include +// #include namespace raft { namespace spatial { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu index 6d56164e11..ab374918e9 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu @@ -19,7 +19,7 @@ #include // TODO: Change this to proper specializations after FAISS is removed -#include +// #include namespace raft { namespace spatial { From 65d7725e25e77c450f402cd2a77a153156d75400 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 13 Mar 2023 17:48:10 -0700 Subject: [PATCH 47/62] code review feedback --- .../raft/neighbors/detail/knn_brute_force.cuh | 2 +- .../brute_force_knn_long_float_int.cu | 5 +- .../brute_force_knn_long_float_uint.cu | 5 +- .../brute_force_knn_uint32_t_float_int.cu | 5 +- .../brute_force_knn_uint32_t_float_uint.cu | 5 +- cpp/test/neighbors/knn.cu | 6 +- cpp/test/neighbors/knn_utils.cuh | 13 +-- cpp/test/neighbors/tiled_knn.cu | 104 ++++++++++-------- 8 files changed, 73 insertions(+), 72 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 11929b536a..3398435491 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu index 56f7c3cfa9..2c21d1ec64 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu @@ -15,12 +15,9 @@ */ #include -#include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -// #include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu index 6b1779107a..7e6e7e80d0 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu @@ -15,12 +15,9 @@ */ #include -#include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -// #include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu index 94e3f35be5..e94c12d579 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu @@ -15,12 +15,9 @@ */ #include -#include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -// #include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu index ab374918e9..95cf8a1eb3 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu @@ -15,12 +15,9 @@ */ #include -#include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -// #include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index b80729d1ce..7976725c65 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -20,8 +20,12 @@ #include #include #include + +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + #if defined RAFT_NN_COMPILED -#include #include #endif diff --git a/cpp/test/neighbors/knn_utils.cuh b/cpp/test/neighbors/knn_utils.cuh index 2c4dad5c0b..ac34699ac5 100644 --- a/cpp/test/neighbors/knn_utils.cuh +++ b/cpp/test/neighbors/knn_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,12 @@ #pragma once +#include + #include "../test_utils.cuh" #include -#include -#include -#include -#include -#include -#include - -#include +#include namespace raft::spatial::knn { template diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index a44474eb04..fccb0bb8fe 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #if defined RAFT_NN_COMPILED @@ -30,6 +31,7 @@ #endif #include +#include #include @@ -72,17 +74,30 @@ class TiledKNNTest : public ::testing::TestWithParam { ref_indices_(params_.num_queries * params_.k, stream_), ref_distances_(params_.num_queries * params_.k, stream_) { - RAFT_CUDA_TRY(cudaMemsetAsync(database.data(), 0, database.size() * sizeof(T), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(search_queries.data(), 0, search_queries.size() * sizeof(T), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(raft_indices_.data(), 0, raft_indices_.size() * sizeof(int), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(raft_distances_.data(), 0, raft_distances_.size() * sizeof(T), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(ref_indices_.data(), 0, ref_indices_.size() * sizeof(int), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(ref_distances_.data(), 0, ref_distances_.size() * sizeof(T), stream_)); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(database.data(), params_.num_db_vecs, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(search_queries.data(), params_.num_queries, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(raft_indices_.data(), params_.num_queries, params_.k), + 0); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(raft_distances_.data(), params_.num_queries, params_.k), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(ref_indices_.data(), params_.num_queries, params_.k), + 0); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), + T{0.0}); } protected: @@ -116,34 +131,33 @@ class TiledKNNTest : public ::testing::TestWithParam { temp_dist = temp_row_major_dist.data(); } - using namespace raft::spatial; - knn::select_k(temp_dist, - nullptr, - num_queries, - num_db_vecs, - ref_distances_.data(), - ref_indices_.data(), - raft::distance::is_min_close(metric), - k_, - stream_); + raft::spatial::knn::detail::select_k(temp_dist, + nullptr, + num_queries, + num_db_vecs, + ref_distances_.data(), + ref_indices_.data(), + raft::distance::is_min_close(metric), + k_, + stream_); if ((params_.row_tiles == 0) && (params_.col_tiles == 0)) { std::vector input{database.data()}; std::vector sizes{static_cast(num_db_vecs)}; - raft::spatial::knn::brute_force_knn(handle_, - input, - sizes, - dim, - const_cast(search_queries.data()), - num_queries, - raft_indices_.data(), - raft_distances_.data(), - k_, - params_.row_major, - params_.row_major, - nullptr, - metric, - metric_arg); + neighbors::detail::brute_force_knn_impl(handle_, + input, + sizes, + dim, + const_cast(search_queries.data()), + num_queries, + raft_indices_.data(), + raft_distances_.data(), + k_, + params_.row_major, + params_.row_major, + nullptr, + metric, + metric_arg); } else { neighbors::detail::tiled_brute_force_knn(handle_, search_queries.data(), @@ -161,14 +175,14 @@ class TiledKNNTest : public ::testing::TestWithParam { } // verify. - ASSERT_TRUE(knn::devArrMatchKnnPair(ref_indices_.data(), - raft_indices_.data(), - ref_distances_.data(), - raft_distances_.data(), - num_queries, - k_, - float(0.001), - stream_)); + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(ref_indices_.data(), + raft_indices_.data(), + ref_distances_.data(), + raft_distances_.data(), + num_queries, + k_, + float(0.001), + stream_)); } void SetUp() override @@ -228,8 +242,8 @@ const std::vector random_inputs = { {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded, true}, // Test where passing column_tiles < K {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded, true}, - // Passing tile sizes of 0 means to use the public api (instead of the - // detail api). Note that we can only test col_major in the public api + // Passing tile sizes of 0 means to use brute_force_knn_impl (instead of the + // tiled_brute_force_knn api). {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, true}, {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, false}, {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::LpUnexpanded, true}, From e41ff884f241b51f09f2b93b5074d82187ca9ff1 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 13 Mar 2023 22:58:12 -0700 Subject: [PATCH 48/62] matrix::fill and linalg::map_offset --- .../raft/neighbors/detail/knn_brute_force.cuh | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 3398435491..cc86278183 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -27,7 +27,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -208,11 +210,11 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // if we have less than k items in the index, we should fill out the result // to indicate that we are missing items (and match behaviour in faiss) if (n < static_cast(k)) { - thrust::fill(handle.get_thrust_policy(), - distances, - distances + m * k, - std::numeric_limits::lowest()); - thrust::fill(handle.get_thrust_policy(), indices, indices + m * k, -1); + raft::matrix::fill(handle, + raft::make_device_matrix_view(distances, m, static_cast(k)), + std::numeric_limits::lowest()); + raft::matrix::fill( + handle, raft::make_device_matrix_view(indices, m, static_cast(k)), IndexType{-1}); } rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); @@ -247,25 +249,22 @@ void tiled_brute_force_knn(const raft::device_resources& handle, auto row_norms = search_norms.data() + i; auto col_norms = index_norms.data() + j; auto dist = temp_distances.data(); - auto count = thrust::make_counting_iterator(0); - thrust::for_each(handle.get_thrust_policy(), - count, - count + current_query_size * current_centroid_size, - [=] __device__(IndexType i) { - IndexType row = i / current_centroid_size, - col = i % current_centroid_size; - - auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; - - // due to numerical instability (especially around self-distance) - // the distances here could be slightly negative, which will - // cause NaN values in the subsequent sqrt. Clamp to 0 - dist[i] = val * (val >= 0.0001); - if (metric == raft::distance::DistanceType::L2SqrtExpanded) { - dist[i] = sqrt(dist[i]); - } - }); + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(dist, current_query_size * current_centroid_size), + [=] __device__(IndexType i) { + IndexType row = i / current_centroid_size, col = i % current_centroid_size; + + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + + // due to numerical instability (especially around self-distance) + // the distances here could be slightly negative, which will + // cause NaN values in the subsequent sqrt. Clamp to 0 + val = val * (val >= 0.0001); + if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(dist[i]); } + return val; + }); } detail::select_k(temp_distances.data(), From 5eb7d22948ae61ffae6fcd788f65902d3ed15143 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 13 Mar 2023 23:31:29 -0700 Subject: [PATCH 49/62] build fix --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index cc86278183..18eaffce9d 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -213,8 +213,11 @@ void tiled_brute_force_knn(const raft::device_resources& handle, raft::matrix::fill(handle, raft::make_device_matrix_view(distances, m, static_cast(k)), std::numeric_limits::lowest()); - raft::matrix::fill( - handle, raft::make_device_matrix_view(indices, m, static_cast(k)), IndexType{-1}); + + if constexpr (std::is_signed_v) { + raft::matrix::fill( + handle, raft::make_device_matrix_view(indices, m, static_cast(k)), IndexType{-1}); + } } rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); From e36d089625ff4672cbcb40ac34b9c096397e1a5f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Mar 2023 09:14:35 -0700 Subject: [PATCH 50/62] fix --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 18eaffce9d..bf0c4e91dd 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -265,7 +265,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // the distances here could be slightly negative, which will // cause NaN values in the subsequent sqrt. Clamp to 0 val = val * (val >= 0.0001); - if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(dist[i]); } + if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } return val; }); } From e8f9c55e112284764daefef3157db9a7b1aab19f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Mar 2023 11:47:51 -0700 Subject: [PATCH 51/62] move faiss_select into raft::neighbors namespace --- ci/checks/copyright.py | 2 +- .../detail/faiss_select/Comparators.cuh | 4 +- .../detail/faiss_select/DistanceUtils.h | 4 +- .../detail/faiss_select/MergeNetworkBlock.cuh | 8 ++-- .../detail/faiss_select/MergeNetworkUtils.cuh | 4 +- .../detail/faiss_select/MergeNetworkWarp.cuh | 8 ++-- .../detail/faiss_select/Select.cuh | 10 ++--- .../detail/faiss_select/StaticUtils.h | 4 +- .../faiss_select/key_value_block_select.cuh | 8 ++-- .../raft/neighbors/detail/knn_brute_force.cuh | 45 +++++++++---------- .../detail/selection_faiss.cuh | 13 ++---- .../raft/spatial/knn/detail/ball_cover.cuh | 3 +- .../knn/detail/ball_cover/registers.cuh | 45 +++++++------------ .../raft/spatial/knn/detail/fused_l2_knn.cuh | 7 ++- .../spatial/knn/detail/haversine_distance.cuh | 12 ++--- cpp/include/raft/spatial/knn/knn.cuh | 4 +- cpp/test/neighbors/selection.cu | 2 +- cpp/test/neighbors/tiled_knn.cu | 3 -- 18 files changed, 78 insertions(+), 108 deletions(-) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/Comparators.cuh (84%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/DistanceUtils.h (94%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/MergeNetworkBlock.cuh (97%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/MergeNetworkUtils.cuh (79%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/MergeNetworkWarp.cuh (98%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/Select.cuh (97%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/StaticUtils.h (91%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/faiss_select/key_value_block_select.cuh (96%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/selection_faiss.cuh (96%) diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index 43a4a186f8..a44314a6ce 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -37,7 +37,7 @@ re.compile(r"setup[.]cfg$"), re.compile(r"meta[.]yaml$") ] -ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"] +ExemptFiles = ["cpp/include/raft/neighbors/detail/faiss_select/"] # this will break starting at year 10000, which is probably OK :) CheckSimple = re.compile( diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh b/cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh similarity index 84% rename from cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh index 173c06af30..1a34d2f68c 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh @@ -10,7 +10,7 @@ #include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { template struct Comparator { @@ -26,4 +26,4 @@ struct Comparator { __device__ static inline bool gt(half a, half b) { return __hgt(a, b); } }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/DistanceUtils.h b/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h similarity index 94% rename from cpp/include/raft/spatial/knn/detail/faiss_select/DistanceUtils.h rename to cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h index 51b7955d5a..cd4a52e5df 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/DistanceUtils.h +++ b/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h @@ -7,7 +7,7 @@ #pragma once -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // If the inner size (dim) of the vectors is small, we want a larger query tile // size, like 1024 inline void chooseTileSize(size_t numQueries, @@ -49,4 +49,4 @@ inline void chooseTileSize(size_t numQueries, // tileCols is the remainder size tileCols = std::min(targetUsage / preferredTileRows, numCentroids); } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh similarity index 97% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh index d923b41ded..79e3f95be0 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh @@ -8,10 +8,10 @@ #pragma once #include -#include -#include +#include +#include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // Merge pairs of lists smaller than blockDim.x (NumThreads) template ::merge(listK, listV); } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh similarity index 79% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh index 2cb01f9199..78f794bff4 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh @@ -7,7 +7,7 @@ #pragma once -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { template inline __device__ void swap(bool swap, T& x, T& y) @@ -22,4 +22,4 @@ inline __device__ void assign(bool assign, T& x, T y) { x = assign ? y : x; } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh similarity index 98% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh index bce739b2d8..04f7f90aac 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh @@ -7,12 +7,12 @@ #pragma once -#include -#include +#include +#include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // // This file contains functions to: @@ -518,4 +518,4 @@ inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) BitonicSortStep::sort(k, v); } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh b/cpp/include/raft/neighbors/detail/faiss_select/Select.cuh similarity index 97% rename from cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/Select.cuh index e4faff7a6c..4aa7d68f54 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/Select.cuh @@ -7,14 +7,14 @@ #pragma once -#include -#include -#include +#include +#include +#include #include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // Specialization for block-wide monotonic merges producing a merge sort // since what we really want is a constexpr loop expansion @@ -552,4 +552,4 @@ struct WarpSelect { V threadV; }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h b/cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h similarity index 91% rename from cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h rename to cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h index bac051b68c..5a25c7a321 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h +++ b/cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h @@ -15,7 +15,7 @@ #define __device__ #endif -namespace raft::spatial::knn::detail::faiss_select::utils { +namespace raft::neighbors::detail::faiss_select::utils { template constexpr __host__ __device__ bool isPowerOf2(T v) @@ -45,4 +45,4 @@ static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, "nextHighestPower static_assert(nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, "nextHighestPowerOf2"); -} // namespace raft::spatial::knn::detail::faiss_select::utils +} // namespace raft::neighbors::detail::faiss_select::utils diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh b/cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh similarity index 96% rename from cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh index 617a26a243..ff06b7dca4 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh @@ -7,14 +7,14 @@ #pragma once -#include -#include +#include +#include // TODO: Need to think further about the impact (and new boundaries created) on the registers // because this will change the max k that can be processed. One solution might be to break // up k into multiple batches for larger k. -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // `Dir` true, produce largest values. // `Dir` false, produce smallest values. @@ -221,4 +221,4 @@ struct KeyValueBlockSelect { int kMinus1; }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index bf0c4e91dd..748f0894ca 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -30,11 +30,11 @@ #include #include #include -#include -#include +#include +#include +#include #include #include -#include #include #include @@ -164,8 +164,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, auto stream = handle.get_stream(); auto device_memory = handle.get_workspace_resource(); auto total_mem = device_memory->get_mem_info(stream).second; - raft::spatial::knn::detail::faiss_select::chooseTileSize( - m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); + faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); // for unittesting, its convenient to be able to put a max size on the tiles // so we can test the tiling logic without having to use huge inputs. @@ -270,15 +269,15 @@ void tiled_brute_force_knn(const raft::device_resources& handle, }); } - detail::select_k(temp_distances.data(), - nullptr, - current_query_size, - current_centroid_size, - distances + i * k, - indices + i * k, - select_min, - current_k, - stream); + select_k(temp_distances.data(), + nullptr, + current_query_size, + current_centroid_size, + distances + i * k, + indices + i * k, + select_min, + current_k, + stream); // if we're tiling over columns, we need to do a couple things to fix up // the output of select_k @@ -310,15 +309,15 @@ void tiled_brute_force_knn(const raft::device_resources& handle, if (tile_cols != n) { // select the actual top-k items here from the temporary output - detail::select_k(temp_out_distances.data(), - temp_out_indices.data(), - current_query_size, - temp_out_cols, - distances + i * k, - indices + i * k, - select_min, - k, - stream); + select_k(temp_out_distances.data(), + temp_out_indices.data(), + current_query_size, + temp_out_cols, + distances + i * k, + indices + i * k, + select_min, + k, + stream); } } } diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/neighbors/detail/selection_faiss.cuh similarity index 96% rename from cpp/include/raft/spatial/knn/detail/selection_faiss.cuh rename to cpp/include/raft/neighbors/detail/selection_faiss.cuh index ff27345766..5df42e94b9 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/neighbors/detail/selection_faiss.cuh @@ -19,12 +19,9 @@ #include #include -#include +#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { +namespace raft::neighbors::detail { template constexpr int kFaissMaxK() @@ -169,8 +166,4 @@ inline void select_k(const key_t* inK, else ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); } - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +}; // namespace raft::neighbors::detail diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 9d89967dd2..99d688e232 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -22,14 +22,13 @@ #include "ball_cover/common.cuh" #include "ball_cover/registers.cuh" #include "haversine_distance.cuh" -#include "selection_faiss.cuh" #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 394d27235b..f665368c41 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -19,13 +19,12 @@ #include "common.cuh" #include "../../ball_cover_types.hpp" -#include "../faiss_select/key_value_block_select.cuh" #include "../haversine_distance.cuh" -#include "../selection_faiss.cuh" #include #include +#include #include #include @@ -180,19 +179,14 @@ __global__ void compute_final_dists_registers(const value_t* X_index, local_x_ptr[j] = x_ptr[j]; } - faiss_select::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); const value_int n_k = Pow2::roundDown(k); value_int i = threadIdx.x; @@ -349,19 +343,14 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, } // Each warp works on 1 R - faiss_select::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; value_int n_dists_computed = 0; diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index f1f160a154..4e18a210d4 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -17,7 +17,7 @@ #include #include #include -#include +#include // TODO: Need to hide the PairwiseDistance class impl and expose to public API #include "processing.cuh" #include @@ -219,9 +219,8 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x constexpr auto identity = std::numeric_limits::max(); constexpr auto keyMax = std::numeric_limits::max(); constexpr auto Dir = false; - typedef faiss_select:: - WarpSelect, NumWarpQ, NumThreadQ, 32> - myWarpSelect; + using namespace raft::neighbors::detail::faiss_select; + typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( IdxT gridStrideY) { diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh index 7d361ba4fb..058e98da9f 100644 --- a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh @@ -22,7 +22,7 @@ #include #include -#include +#include namespace raft { namespace spatial { @@ -65,13 +65,9 @@ __global__ void haversine_knn_kernel(value_idx* out_inds, __shared__ value_t smemK[kNumWarps * warp_q]; __shared__ value_idx smemV[kNumWarps * warp_q]; - faiss_select:: - BlockSelect, warp_q, thread_q, tpb> - heap(std::numeric_limits::max(), - std::numeric_limits::max(), - smemK, - smemV, - k); + using namespace raft::neighbors::detail::faiss_select; + BlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), std::numeric_limits::max(), smemK, smemV, k); // Grid is exactly sized to rows available int limit = Pow2::roundDown(n_index_rows); diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index 727fb313ce..eef1131723 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -16,8 +16,6 @@ #pragma once -#include "detail/selection_faiss.cuh" - #include #include #include @@ -148,7 +146,7 @@ template switch (algo) { case SelectKAlgo::FAISS: - detail::select_k( + neighbors::detail::select_k( in_keys, in_values, n_inputs, input_len, out_keys, out_values, select_min, k, stream); break; diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 26e37e433f..25939f65c3 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -118,7 +118,7 @@ struct SelectInOutComputed { } break; case knn::SelectKAlgo::FAISS: - if (spec.k > raft::spatial::knn::detail::kFaissMaxK()) { + if (spec.k > raft::neighbors::detail::kFaissMaxK()) { not_supported = true; return; } diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index fccb0bb8fe..9a54a3e20f 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -30,9 +30,6 @@ #include #endif -#include -#include - #include #include From 9f211a0ebe6d2f3926f778a105699265f0b8f31c Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Mar 2023 13:03:16 -0700 Subject: [PATCH 52/62] move knn_merge parts to its own file --- cpp/include/raft/neighbors/brute_force.cuh | 1 - .../raft/neighbors/detail/knn_brute_force.cuh | 146 +-------------- .../raft/neighbors/detail/knn_merge_parts.cuh | 172 ++++++++++++++++++ cpp/include/raft/spatial/knn/knn.cuh | 1 + cpp/test/neighbors/tiled_knn.cu | 18 +- 5 files changed, 183 insertions(+), 155 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/knn_merge_parts.cuh diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 76e05f3234..4891cc5f8d 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -20,7 +20,6 @@ #include #include #include -#include namespace raft::neighbors::brute_force { diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 748f0894ca..b1ae07445c 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -42,103 +43,6 @@ namespace raft::neighbors::detail { using namespace raft::spatial::knn::detail; using namespace raft::spatial::knn; -template -__global__ void knn_merge_parts_kernel(value_t* inK, - value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - value_t initK, - value_idx initV, - int k, - value_idx* translations) -{ - constexpr int kNumWarps = tpb / WarpSize; - - __shared__ value_t smemK[kNumWarps * warp_q]; - __shared__ value_idx smemV[kNumWarps * warp_q]; - - /** - * Uses shared memory - */ - faiss_select:: - BlockSelect, warp_q, thread_q, tpb> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - int row = blockIdx.x; - int total_k = k * n_parts; - - int i = threadIdx.x; - - // Get starting pointers for cols in current thread - int part = i / k; - size_t row_idx = (row * k) + (part * n_samples * k); - - int col = i % k; - - value_t* inKStart = inK + (row_idx + col); - value_idx* inVStart = inV + (row_idx + col); - - int limit = Pow2::roundDown(total_k); - value_idx translation = 0; - - for (; i < limit; i += tpb) { - translation = translations[part]; - heap.add(*inKStart, (*inVStart) + translation); - - part = (i + tpb) / k; - row_idx = (row * k) + (part * n_samples * k); - - col = (i + tpb) % k; - - inKStart = inK + (row_idx + col); - inVStart = inV + (row_idx + col); - } - - // Handle last remainder fraction of a warp of elements - if (i < total_k) { - translation = translations[part]; - heap.addThreadQ(*inKStart, (*inVStart) + translation); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - outK[row * k + i] = smemK[i]; - outV[row * k + i] = smemV[i]; - } -} - -template -inline void knn_merge_parts_impl(value_t* inK, - value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t stream, - value_idx* translations) -{ - auto grid = dim3(n_samples); - - constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; - auto block = dim3(n_threads); - - auto kInit = std::numeric_limits::max(); - auto vInit = -1; - knn_merge_parts_kernel - <<>>( - inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - /** * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances @@ -322,54 +226,6 @@ void tiled_brute_force_knn(const raft::device_resources& handle, } } -/** - * @brief Merge knn distances and index matrix, which have been partitioned - * by row, into a single matrix with only the k-nearest neighbors. - * - * @param inK partitioned knn distance matrix - * @param inV partitioned knn index matrix - * @param outK merged knn distance matrix - * @param outV merged knn index matrix - * @param n_samples number of samples per partition - * @param n_parts number of partitions - * @param k number of neighbors per partition (also number of merged neighbors) - * @param stream CUDA stream to use - * @param translations mapping of index offsets for each partition - */ -template -inline void knn_merge_parts(value_t* inK, - value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t stream, - value_idx* translations) -{ - if (k == 1) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 32) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 64) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 128) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 256) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 512) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 1024) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); -} - /** * Search the kNN for the k-nearest neighbors of a set of query vectors * @param[in] input vector of device device memory array pointers to search diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh new file mode 100644 index 0000000000..e2b5c41fb0 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace raft::neighbors::detail { + +template +__global__ void knn_merge_parts_kernel(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + value_t initK, + value_idx initV, + int k, + value_idx* translations) +{ + constexpr int kNumWarps = tpb / WarpSize; + + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + /** + * Uses shared memory + */ + faiss_select:: + BlockSelect, warp_q, thread_q, tpb> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + int total_k = k * n_parts; + + int i = threadIdx.x; + + // Get starting pointers for cols in current thread + int part = i / k; + size_t row_idx = (row * k) + (part * n_samples * k); + + int col = i % k; + + value_t* inKStart = inK + (row_idx + col); + value_idx* inVStart = inV + (row_idx + col); + + int limit = Pow2::roundDown(total_k); + value_idx translation = 0; + + for (; i < limit; i += tpb) { + translation = translations[part]; + heap.add(*inKStart, (*inVStart) + translation); + + part = (i + tpb) / k; + row_idx = (row * k) + (part * n_samples * k); + + col = (i + tpb) % k; + + inKStart = inK + (row_idx + col); + inVStart = inV + (row_idx + col); + } + + // Handle last remainder fraction of a warp of elements + if (i < total_k) { + translation = translations[part]; + heap.addThreadQ(*inKStart, (*inVStart) + translation); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + outK[row * k + i] = smemK[i]; + outV[row * k + i] = smemV[i]; + } +} + +template +inline void knn_merge_parts_impl(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + auto grid = dim3(n_samples); + + constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; + auto block = dim3(n_threads); + + auto kInit = std::numeric_limits::max(); + auto vInit = -1; + knn_merge_parts_kernel + <<>>( + inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * @brief Merge knn distances and index matrix, which have been partitioned + * by row, into a single matrix with only the k-nearest neighbors. + * + * @param inK partitioned knn distance matrix + * @param inV partitioned knn index matrix + * @param outK merged knn distance matrix + * @param outV merged knn index matrix + * @param n_samples number of samples per partition + * @param n_parts number of partitions + * @param k number of neighbors per partition (also number of merged neighbors) + * @param stream CUDA stream to use + * @param translations mapping of index offsets for each partition + */ +template +inline void knn_merge_parts(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + if (k == 1) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 32) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 64) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 128) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 256) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 512) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 1024) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); +} +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index eef1131723..692d262043 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -21,6 +21,7 @@ #include #include #include +#include namespace raft::spatial::knn { diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 9a54a3e20f..4784f915f3 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -128,15 +128,15 @@ class TiledKNNTest : public ::testing::TestWithParam { temp_dist = temp_row_major_dist.data(); } - raft::spatial::knn::detail::select_k(temp_dist, - nullptr, - num_queries, - num_db_vecs, - ref_distances_.data(), - ref_indices_.data(), - raft::distance::is_min_close(metric), - k_, - stream_); + raft::neighbors::detail::select_k(temp_dist, + nullptr, + num_queries, + num_db_vecs, + ref_distances_.data(), + ref_indices_.data(), + raft::distance::is_min_close(metric), + k_, + stream_); if ((params_.row_tiles == 0) && (params_.col_tiles == 0)) { std::vector input{database.data()}; From 9593ae16b527c32d8b46e88fef946c8c0a46a999 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 15 Mar 2023 10:27:02 -0700 Subject: [PATCH 53/62] Use stream pool --- .../raft/neighbors/detail/knn_brute_force.cuh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index b1ae07445c..1a2a0b51ce 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -384,16 +384,16 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: + // Create a new handle with the current stream from the stream pool + raft::device_resources stream_pool_handle(handle); + raft::resource::set_cuda_stream(stream_pool_handle, stream); auto index = input[i]; - rmm::device_uvector index_row_major(0, userStream); + rmm::device_uvector index_row_major(0, stream); if (!rowMajorIndex) { - index_row_major.resize(sizes[i] * D, userStream); - raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, userStream); + index_row_major.resize(sizes[i] * D, stream); + raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, stream); index = index_row_major.data(); - - // Make other streams from pool wait on main stream - handle.wait_stream_pool_on_stream(); } // cosine/correlation are handled by metric processor, use IP distance @@ -404,7 +404,7 @@ void brute_force_knn_impl( tiled_metric = raft::distance::DistanceType::InnerProduct; } - tiled_brute_force_knn(handle, + tiled_brute_force_knn(stream_pool_handle, search, index, n, From 97753b0ecf99ed4728971aef35fcbbc3d33ea55f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 15 Mar 2023 13:39:26 -0700 Subject: [PATCH 54/62] use right handle for transpose --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 1a2a0b51ce..5e0c6e7da7 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -392,7 +392,8 @@ void brute_force_knn_impl( rmm::device_uvector index_row_major(0, stream); if (!rowMajorIndex) { index_row_major.resize(sizes[i] * D, stream); - raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, stream); + raft::linalg::transpose( + stream_pool_handle, index, index_row_major.data(), sizes[i], D, stream); index = index_row_major.data(); } From a534538176affbc16db4a62116d88e7c924785fc Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 15 Mar 2023 14:07:37 -0700 Subject: [PATCH 55/62] set blas stream --- cpp/include/raft/core/resource/cublas_handle.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp index 710fcc7e60..c24a23bd6d 100644 --- a/cpp/include/raft/core/resource/cublas_handle.hpp +++ b/cpp/include/raft/core/resource/cublas_handle.hpp @@ -71,7 +71,9 @@ inline cublasHandle_t get_cublas_handle(resources const& res) cudaStream_t stream = get_cuda_stream(res); res.add_resource_factory(std::make_shared(stream)); } - return *res.get_resource(resource_type::CUBLAS_HANDLE); + auto ret = *res.get_resource(resource_type::CUBLAS_HANDLE); + cublasSetStream(ret, get_cuda_stream(res)); + return ret; }; /** From 237b7e165b05bd46c0069100ea9c571ebe60ed60 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 15 Mar 2023 14:17:07 -0700 Subject: [PATCH 56/62] error handling --- cpp/include/raft/core/resource/cublas_handle.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp index c24a23bd6d..c8d8ee4c02 100644 --- a/cpp/include/raft/core/resource/cublas_handle.hpp +++ b/cpp/include/raft/core/resource/cublas_handle.hpp @@ -72,7 +72,7 @@ inline cublasHandle_t get_cublas_handle(resources const& res) res.add_resource_factory(std::make_shared(stream)); } auto ret = *res.get_resource(resource_type::CUBLAS_HANDLE); - cublasSetStream(ret, get_cuda_stream(res)); + RAFT_CUBLAS_TRY(cublasSetStream(ret, get_cuda_stream(res))); return ret; }; From 07290bcf77bcee648c75bef7a57c60018aaeea98 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 15 Mar 2023 20:34:53 -0700 Subject: [PATCH 57/62] try to isolate stream failure Try to figure out if failures are in transpose code, or in tiled_knn code --- .../raft/neighbors/detail/knn_brute_force.cuh | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 5e0c6e7da7..ca24b16386 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -384,19 +384,21 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: - // Create a new handle with the current stream from the stream pool - raft::device_resources stream_pool_handle(handle); - raft::resource::set_cuda_stream(stream_pool_handle, stream); - auto index = input[i]; rmm::device_uvector index_row_major(0, stream); if (!rowMajorIndex) { - index_row_major.resize(sizes[i] * D, stream); - raft::linalg::transpose( - stream_pool_handle, index, index_row_major.data(), sizes[i], D, stream); + index_row_major.resize(sizes[i] * D, userStream); + raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, userStream); index = index_row_major.data(); + + // Make other streams from pool wait on main stream + handle.wait_stream_pool_on_stream(); } + // Create a new handle with the current stream from the stream pool + raft::device_resources stream_pool_handle(handle); + raft::resource::set_cuda_stream(stream_pool_handle, stream); + // cosine/correlation are handled by metric processor, use IP distance // for brute force knn call. auto tiled_metric = metric; From 2671a0eaad642819d9482c8d6cc1dea565c17de8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Mar 2023 10:47:52 -0700 Subject: [PATCH 58/62] Move transpose code out of loop Trying to resolve stream failures here --- .../raft/neighbors/detail/knn_brute_force.cuh | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index ca24b16386..b685c947db 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -334,9 +334,25 @@ void brute_force_knn_impl( search = search_row_major.data(); } + // transpose into a temporary buffer if necessary + rmm::device_uvector index_row_major(0, userStream); + if (!rowMajorIndex) { + size_t total_size = 0; + for (auto size : sizes) { + total_size += size; + } + index_row_major.resize(total_size * D, userStream); + auto index = index_row_major.data(); + for (size_t i = 0; i < input.size(); i++) { + raft::linalg::transpose(handle, input[i], index, sizes[i], D, userStream); + index += sizes[i]; + } + } + // Make other streams from pool wait on main stream handle.wait_stream_pool_on_stream(); + size_t total_rows_processed = 0; for (size_t i = 0; i < input.size(); i++) { value_t* out_d_ptr = out_D + (i * k * n); IdxType* out_i_ptr = out_I + (i * k * n); @@ -385,14 +401,9 @@ void brute_force_knn_impl( break; default: auto index = input[i]; - rmm::device_uvector index_row_major(0, stream); if (!rowMajorIndex) { - index_row_major.resize(sizes[i] * D, userStream); - raft::linalg::transpose(handle, index, index_row_major.data(), sizes[i], D, userStream); - index = index_row_major.data(); - - // Make other streams from pool wait on main stream - handle.wait_stream_pool_on_stream(); + index = index_row_major.data() + total_rows_processed; + total_rows_processed += sizes[i]; } // Create a new handle with the current stream from the stream pool From 7e0bb9b6d291906684a59c9a09946e3fb5ff5019 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Mar 2023 11:55:18 -0700 Subject: [PATCH 59/62] fix --- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index b685c947db..335e4c332d 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -345,7 +345,7 @@ void brute_force_knn_impl( auto index = index_row_major.data(); for (size_t i = 0; i < input.size(); i++) { raft::linalg::transpose(handle, input[i], index, sizes[i], D, userStream); - index += sizes[i]; + index += sizes[i] * D; } } @@ -402,7 +402,7 @@ void brute_force_knn_impl( default: auto index = input[i]; if (!rowMajorIndex) { - index = index_row_major.data() + total_rows_processed; + index = index_row_major.data() + total_rows_processed * D; total_rows_processed += sizes[i]; } From b4c3284afbb1bd6b578c3154b7d7e2940baefe09 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Mar 2023 15:41:40 -0700 Subject: [PATCH 60/62] try transpose inside streampool again --- .../raft/neighbors/detail/knn_brute_force.cuh | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 335e4c332d..875fc3b37c 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -342,11 +342,6 @@ void brute_force_knn_impl( total_size += size; } index_row_major.resize(total_size * D, userStream); - auto index = index_row_major.data(); - for (size_t i = 0; i < input.size(); i++) { - raft::linalg::transpose(handle, input[i], index, sizes[i], D, userStream); - index += sizes[i] * D; - } } // Make other streams from pool wait on main stream @@ -400,16 +395,17 @@ void brute_force_knn_impl( haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); break; default: + // Create a new handle with the current stream from the stream pool + raft::device_resources stream_pool_handle(handle); + raft::resource::set_cuda_stream(stream_pool_handle, stream); + auto index = input[i]; if (!rowMajorIndex) { index = index_row_major.data() + total_rows_processed * D; total_rows_processed += sizes[i]; + raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream); } - // Create a new handle with the current stream from the stream pool - raft::device_resources stream_pool_handle(handle); - raft::resource::set_cuda_stream(stream_pool_handle, stream); - // cosine/correlation are handled by metric processor, use IP distance // for brute force knn call. auto tiled_metric = metric; From 92d82db893feca721e33babee7c600b065f1c71e Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Mar 2023 20:07:36 -0700 Subject: [PATCH 61/62] one more try with cublasSetStream --- cpp/include/raft/linalg/detail/transpose.cuh | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 9e7b236fed..161ae20112 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -37,6 +37,7 @@ void transpose(raft::device_resources const& handle, cudaStream_t stream) { cublasHandle_t cublas_h = handle.get_cublas_handle(); + RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); int out_n_rows = n_cols; int out_n_cols = n_rows; @@ -90,7 +91,11 @@ void transpose_row_major_impl( auto out_n_cols = in.extent(0); T constexpr kOne = 1; T constexpr kZero = 0; - CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), + + cublasHandle_t cublas_h = handle.get_cublas_handle(); + RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); + + CUBLAS_TRY(cublasgeam(cublas_h, CUBLAS_OP_T, CUBLAS_OP_N, out_n_cols, @@ -112,11 +117,14 @@ void transpose_col_major_impl( raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { - auto out_n_rows = in.extent(1); - auto out_n_cols = in.extent(0); - T constexpr kOne = 1; - T constexpr kZero = 0; - CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), + auto out_n_rows = in.extent(1); + auto out_n_cols = in.extent(0); + T constexpr kOne = 1; + T constexpr kZero = 0; + cublasHandle_t cublas_h = handle.get_cublas_handle(); + RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); + + CUBLAS_TRY(cublasgeam(cublas_h, CUBLAS_OP_T, CUBLAS_OP_N, out_n_rows, From 21a19531dd2f4b48e74d4e2b75e613464073d376 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 16 Mar 2023 21:30:14 -0700 Subject: [PATCH 62/62] fix --- cpp/include/raft/linalg/detail/transpose.cuh | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 161ae20112..05588bda9c 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -92,10 +92,7 @@ void transpose_row_major_impl( T constexpr kOne = 1; T constexpr kZero = 0; - cublasHandle_t cublas_h = handle.get_cublas_handle(); - RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); - - CUBLAS_TRY(cublasgeam(cublas_h, + CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, out_n_cols, @@ -117,14 +114,12 @@ void transpose_col_major_impl( raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { - auto out_n_rows = in.extent(1); - auto out_n_cols = in.extent(0); - T constexpr kOne = 1; - T constexpr kZero = 0; - cublasHandle_t cublas_h = handle.get_cublas_handle(); - RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); + auto out_n_rows = in.extent(1); + auto out_n_cols = in.extent(0); + T constexpr kOne = 1; + T constexpr kZero = 0; - CUBLAS_TRY(cublasgeam(cublas_h, + CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, out_n_rows,