diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index 43a4a186f8..a44314a6ce 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -37,7 +37,7 @@ re.compile(r"setup[.]cfg$"), re.compile(r"meta[.]yaml$") ] -ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"] +ExemptFiles = ["cpp/include/raft/neighbors/detail/faiss_select/"] # this will break starting at year 10000, which is probably OK :) CheckSimple = re.compile( diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp index 710fcc7e60..c8d8ee4c02 100644 --- a/cpp/include/raft/core/resource/cublas_handle.hpp +++ b/cpp/include/raft/core/resource/cublas_handle.hpp @@ -71,7 +71,9 @@ inline cublasHandle_t get_cublas_handle(resources const& res) cudaStream_t stream = get_cuda_stream(res); res.add_resource_factory(std::make_shared(stream)); } - return *res.get_resource(resource_type::CUBLAS_HANDLE); + auto ret = *res.get_resource(resource_type::CUBLAS_HANDLE); + RAFT_CUBLAS_TRY(cublasSetStream(ret, get_cuda_stream(res))); + return ret; }; /** diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index f5ed68af4a..4060147f1d 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,6 +66,26 @@ enum DistanceType : unsigned short { Precomputed = 100 }; +/** + * Whether minimal distance corresponds to similar elements (using the given metric). + */ +inline bool is_min_close(DistanceType metric) +{ + bool select_min; + switch (metric) { + case DistanceType::InnerProduct: + case DistanceType::CosineExpanded: + case DistanceType::CorrelationExpanded: + // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger + // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 + // {perform_k_selection}) + select_min = false; + break; + default: select_min = true; + } + return select_min; +} + namespace kernels { enum KernelType { LINEAR, POLYNOMIAL, RBF, TANH }; diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 9e7b236fed..05588bda9c 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -37,6 +37,7 @@ void transpose(raft::device_resources const& handle, cudaStream_t stream) { cublasHandle_t cublas_h = handle.get_cublas_handle(); + RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); int out_n_rows = n_cols; int out_n_cols = n_rows; @@ -90,6 +91,7 @@ void transpose_row_major_impl( auto out_n_cols = in.extent(0); T constexpr kOne = 1; T constexpr kZero = 0; + CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, @@ -116,6 +118,7 @@ void transpose_col_major_impl( auto out_n_cols = in.extent(0); T constexpr kOne = 1; T constexpr kZero = 0; + CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index ac9d14ce17..4891cc5f8d 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -18,9 +18,8 @@ #include #include +#include #include -#include -#include namespace raft::neighbors::brute_force { @@ -96,15 +95,15 @@ inline void knn_merge_parts( "Number of columns in output indices and distances matrices must be equal to k"); auto n_parts = in_keys.extent(0) / n_samples; - spatial::knn::detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - out_values.data_handle(), - n_samples, - n_parts, - in_keys.extent(1), - handle.get_stream(), - translations.value_or(nullptr)); + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + n_samples, + n_parts, + in_keys.extent(1), + handle.get_stream(), + translations.value_or(nullptr)); } /** @@ -181,21 +180,21 @@ void knn(raft::device_resources const& handle, std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; - raft::spatial::knn::detail::brute_force_knn_impl(handle, - inputs, - sizes, - static_cast(index[0].extent(1)), - // TODO: This is unfortunate. Need to fix. - const_cast(search.data_handle()), - static_cast(search.extent(0)), - indices.data_handle(), - distances.data_handle(), - k, - rowMajorIndex, - rowMajorQuery, - trans_arg, - metric, - metric_arg.value_or(2.0f)); + raft::neighbors::detail::brute_force_knn_impl(handle, + inputs, + sizes, + static_cast(index[0].extent(1)), + // TODO: This is unfortunate. Need to fix. + const_cast(search.data_handle()), + static_cast(search.extent(0)), + indices.data_handle(), + distances.data_handle(), + k, + rowMajorIndex, + rowMajorQuery, + trans_arg, + metric, + metric_arg.value_or(2.0f)); } /** diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh b/cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh similarity index 84% rename from cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh index 173c06af30..1a34d2f68c 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh @@ -10,7 +10,7 @@ #include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { template struct Comparator { @@ -26,4 +26,4 @@ struct Comparator { __device__ static inline bool gt(half a, half b) { return __hgt(a, b); } }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h b/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h new file mode 100644 index 0000000000..cd4a52e5df --- /dev/null +++ b/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +namespace raft::neighbors::detail::faiss_select { +// If the inner size (dim) of the vectors is small, we want a larger query tile +// size, like 1024 +inline void chooseTileSize(size_t numQueries, + size_t numCentroids, + size_t dim, + size_t elementSize, + size_t totalMem, + size_t& tileRows, + size_t& tileCols) +{ + // The matrix multiplication should be large enough to be efficient, but if + // it is too large, we seem to lose efficiency as opposed to + // double-streaming. Each tile size here defines 1/2 of the memory use due + // to double streaming. We ignore available temporary memory, as that is + // adjusted independently by the user and can thus meet these requirements + // (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs, + // prefer 768 MB of usage. Otherwise, prefer 1 GB of usage. + size_t targetUsage = 0; + + if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) { + targetUsage = 512 * 1024 * 1024; + } else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) { + targetUsage = 768 * 1024 * 1024; + } else { + targetUsage = 1024 * 1024 * 1024; + } + + targetUsage /= 2 * elementSize; + + // 512 seems to be a batch size sweetspot for float32. + // If we are on float16, increase to 512. + // If the k size (vec dim) of the matrix multiplication is small (<= 32), + // increase to 1024. + size_t preferredTileRows = 512; + if (dim <= 32) { preferredTileRows = 1024; } + + tileRows = std::min(preferredTileRows, numQueries); + + // tileCols is the remainder size + tileCols = std::min(targetUsage / preferredTileRows, numCentroids); +} +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh similarity index 97% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh index d923b41ded..79e3f95be0 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh @@ -8,10 +8,10 @@ #pragma once #include -#include -#include +#include +#include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // Merge pairs of lists smaller than blockDim.x (NumThreads) template ::merge(listK, listV); } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh similarity index 79% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh index 2cb01f9199..78f794bff4 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh @@ -7,7 +7,7 @@ #pragma once -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { template inline __device__ void swap(bool swap, T& x, T& y) @@ -22,4 +22,4 @@ inline __device__ void assign(bool assign, T& x, T y) { x = assign ? y : x; } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh similarity index 98% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh index bce739b2d8..04f7f90aac 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh @@ -7,12 +7,12 @@ #pragma once -#include -#include +#include +#include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // // This file contains functions to: @@ -518,4 +518,4 @@ inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) BitonicSortStep::sort(k, v); } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh b/cpp/include/raft/neighbors/detail/faiss_select/Select.cuh similarity index 97% rename from cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/Select.cuh index e4faff7a6c..4aa7d68f54 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/Select.cuh @@ -7,14 +7,14 @@ #pragma once -#include -#include -#include +#include +#include +#include #include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // Specialization for block-wide monotonic merges producing a merge sort // since what we really want is a constexpr loop expansion @@ -552,4 +552,4 @@ struct WarpSelect { V threadV; }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h b/cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h similarity index 91% rename from cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h rename to cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h index bac051b68c..5a25c7a321 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h +++ b/cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h @@ -15,7 +15,7 @@ #define __device__ #endif -namespace raft::spatial::knn::detail::faiss_select::utils { +namespace raft::neighbors::detail::faiss_select::utils { template constexpr __host__ __device__ bool isPowerOf2(T v) @@ -45,4 +45,4 @@ static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, "nextHighestPower static_assert(nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, "nextHighestPowerOf2"); -} // namespace raft::spatial::knn::detail::faiss_select::utils +} // namespace raft::neighbors::detail::faiss_select::utils diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh b/cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh similarity index 96% rename from cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh index 617a26a243..ff06b7dca4 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh @@ -7,14 +7,14 @@ #pragma once -#include -#include +#include +#include // TODO: Need to think further about the impact (and new boundaries created) on the registers // because this will change the max k that can be processed. One solution might be to break // up k into multiple batches for larger k. -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // `Dir` true, produce largest values. // `Dir` false, produce smallest values. @@ -221,4 +221,4 @@ struct KeyValueBlockSelect { int kMinus1; }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index b2bfd18610..f657070df4 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -1244,26 +1244,6 @@ void search_impl(raft::device_resources const& handle, } } -/** - * Whether minimal distance corresponds to similar elements (using the given metric). - */ -inline bool is_min_close(distance::DistanceType metric) -{ - bool select_min; - switch (metric) { - case raft::distance::DistanceType::InnerProduct: - case raft::distance::DistanceType::CosineExpanded: - case raft::distance::DistanceType::CorrelationExpanded: - // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger - // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 - // {perform_k_selection}) - select_min = false; - break; - default: select_min = true; - } - return select_min; -} - /** See raft::neighbors::ivf_flat::search docs */ template inline void search(raft::device_resources const& handle, @@ -1295,7 +1275,7 @@ inline void search(raft::device_resources const& handle, n_queries, k, n_probes, - is_min_close(index.metric()), + raft::distance::is_min_close(index.metric()), neighbors, distances, mr); diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh new file mode 100644 index 0000000000..875fc3b37c --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -0,0 +1,455 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors::detail { +using namespace raft::spatial::knn::detail; +using namespace raft::spatial::knn; + +/** + * Calculates brute force knn, using a fixed memory budget + * by tiling over both the rows and columns of pairwise_distances + */ +template +void tiled_brute_force_knn(const raft::device_resources& handle, + const ElementType* search, // size (m ,d) + const ElementType* index, // size (n ,d) + size_t m, + size_t n, + size_t d, + int k, + ElementType* distances, // size (m, k) + IndexType* indices, // size (m, k) + raft::distance::DistanceType metric, + float metric_arg = 0.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0) +{ + // Figure out the number of rows/cols to tile for + size_t tile_rows = 0; + size_t tile_cols = 0; + auto stream = handle.get_stream(); + auto device_memory = handle.get_workspace_resource(); + auto total_mem = device_memory->get_mem_info(stream).second; + faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); + + // for unittesting, its convenient to be able to put a max size on the tiles + // so we can test the tiling logic without having to use huge inputs. + if (max_row_tile_size && (tile_rows > max_row_tile_size)) { tile_rows = max_row_tile_size; } + if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } + + // tile_cols must be at least k items + tile_cols = std::max(tile_cols, static_cast(k)); + + // stores pairwise distances for the current tile + rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); + + // calculate norms for L2 expanded distances - this lets us avoid calculating + // norms repeatedly per-tile, and just do once for the entire input + auto pairwise_metric = metric; + rmm::device_uvector search_norms(0, stream); + rmm::device_uvector index_norms(0, stream); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + search_norms.resize(m, stream); + index_norms.resize(n, stream); + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + raft::linalg::rowNorm( + index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + pairwise_metric = raft::distance::DistanceType::InnerProduct; + } + + // if we're tiling over columns, we need additional buffers for temporary output + // distances/indices + size_t num_col_tiles = raft::ceildiv(n, tile_cols); + size_t temp_out_cols = k * num_col_tiles; + + // the final column tile could have less than 'k' items in it + // in which case the number of columns here is too high in the temp output. + // adjust if necessary + auto last_col_tile_size = n % tile_cols; + if (last_col_tile_size && (last_col_tile_size < static_cast(k))) { + temp_out_cols -= k - last_col_tile_size; + } + + // if we have less than k items in the index, we should fill out the result + // to indicate that we are missing items (and match behaviour in faiss) + if (n < static_cast(k)) { + raft::matrix::fill(handle, + raft::make_device_matrix_view(distances, m, static_cast(k)), + std::numeric_limits::lowest()); + + if constexpr (std::is_signed_v) { + raft::matrix::fill( + handle, raft::make_device_matrix_view(indices, m, static_cast(k)), IndexType{-1}); + } + } + + rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); + rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); + + bool select_min = raft::distance::is_min_close(metric); + + for (size_t i = 0; i < m; i += tile_rows) { + size_t current_query_size = std::min(tile_rows, m - i); + + for (size_t j = 0; j < n; j += tile_cols) { + size_t current_centroid_size = std::min(tile_cols, n - j); + size_t current_k = std::min(current_centroid_size, static_cast(k)); + + // calculate the top-k elements for the current tile, by calculating the + // full pairwise distance for the tile - and then selecting the top-k from that + // note: we're using a int32 IndexType here on purpose in order to + // use the pairwise_distance specializations. Since the tile size will ensure + // that the total memory is < 1GB per tile, this will not cause any issues + distance::pairwise_distance(handle, + search + i * d, + index + j * d, + temp_distances.data(), + current_query_size, + current_centroid_size, + d, + pairwise_metric, + true, + metric_arg); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + auto row_norms = search_norms.data() + i; + auto col_norms = index_norms.data() + j; + auto dist = temp_distances.data(); + + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(dist, current_query_size * current_centroid_size), + [=] __device__(IndexType i) { + IndexType row = i / current_centroid_size, col = i % current_centroid_size; + + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + + // due to numerical instability (especially around self-distance) + // the distances here could be slightly negative, which will + // cause NaN values in the subsequent sqrt. Clamp to 0 + val = val * (val >= 0.0001); + if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } + return val; + }); + } + + select_k(temp_distances.data(), + nullptr, + current_query_size, + current_centroid_size, + distances + i * k, + indices + i * k, + select_min, + current_k, + stream); + + // if we're tiling over columns, we need to do a couple things to fix up + // the output of select_k + // 1. The column id's in the output are relative to the tile, so we need + // to adjust the column ids by adding the column the tile starts at (j) + // 2. select_k writes out output in a row-major format, which means we + // can't just concat the output of all the tiles and do a select_k on the + // concatenation. + // Fix both of these problems in a single pass here + if (tile_cols != n) { + const ElementType* in_distances = distances + i * k; + const IndexType* in_indices = indices + i * k; + ElementType* out_distances = temp_out_distances.data(); + IndexType* out_indices = temp_out_indices.data(); + + auto count = thrust::make_counting_iterator(0); + thrust::for_each(handle.get_thrust_policy(), + count, + count + current_query_size * current_k, + [=] __device__(IndexType i) { + IndexType row = i / current_k, col = i % current_k; + IndexType out_index = row * temp_out_cols + j * k / tile_cols + col; + + out_distances[out_index] = in_distances[i]; + out_indices[out_index] = in_indices[i] + j; + }); + } + } + + if (tile_cols != n) { + // select the actual top-k items here from the temporary output + select_k(temp_out_distances.data(), + temp_out_indices.data(), + current_query_size, + temp_out_cols, + distances + i * k, + indices + i * k, + select_min, + k, + stream); + } + } +} + +/** + * Search the kNN for the k-nearest neighbors of a set of query vectors + * @param[in] input vector of device device memory array pointers to search + * @param[in] sizes vector of memory sizes for each device array pointer in input + * @param[in] D number of cols in input and search_items + * @param[in] search_items set of vectors to query for neighbors + * @param[in] n number of items in search_items + * @param[out] res_I pointer to device memory for returning k nearest indices + * @param[out] res_D pointer to device memory for returning k nearest distances + * @param[in] k number of neighbors to query + * @param[in] userStream the main cuda stream to use + * @param[in] internalStreams optional when n_params > 0, the index partitions can be + * queried in parallel using these streams. Note that n_int_streams also + * has to be > 0 for these to be used and their cardinality does not need + * to correspond to n_parts. + * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the + * user stream will be used. + * @param[in] rowMajorIndex are the index arrays in row-major layout? + * @param[in] rowMajorQuery are the query array in row-major layout? + * @param[in] translations translation ids for indices when index rows represent + * non-contiguous partitions + * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) + * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm + */ +template +void brute_force_knn_impl( + raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + IntType D, + value_t* search_items, + IntType n, + IdxType* res_I, + value_t* res_D, + IntType k, + bool rowMajorIndex = true, + bool rowMajorQuery = true, + std::vector* translations = nullptr, + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, + float metricArg = 0) +{ + auto userStream = handle.get_stream(); + + ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); + + std::vector* id_ranges; + if (translations == nullptr) { + // If we don't have explicit translations + // for offsets of the indices, build them + // from the local partitions + id_ranges = new std::vector(); + IdxType total_n = 0; + for (size_t i = 0; i < input.size(); i++) { + id_ranges->push_back(total_n); + total_n += sizes[i]; + } + } else { + // otherwise, use the given translations + id_ranges = translations; + } + + // perform preprocessing + std::unique_ptr> query_metric_processor = + create_processor(metric, n, D, k, rowMajorQuery, userStream); + query_metric_processor->preprocess(search_items); + + std::vector>> metric_processors(input.size()); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i] = + create_processor(metric, sizes[i], D, k, rowMajorQuery, userStream); + metric_processors[i]->preprocess(input[i]); + } + + int device; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + + rmm::device_uvector trans(id_ranges->size(), userStream); + raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); + + rmm::device_uvector all_D(0, userStream); + rmm::device_uvector all_I(0, userStream); + + value_t* out_D = res_D; + IdxType* out_I = res_I; + + if (input.size() > 1) { + all_D.resize(input.size() * k * n, userStream); + all_I.resize(input.size() * k * n, userStream); + + out_D = all_D.data(); + out_I = all_I.data(); + } + + // currently we don't support col_major inside tiled_brute_force_knn, because + // of limitattions of the pairwise_distance API: + // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have + // multiple options here (like rowMajorQuery/rowMajorIndex) + // 2) because of tiling, we need to be able to set a custom stride in the PW + // api, which isn't supported + // Instead, transpose the input matrices if they are passed as col-major. + auto search = search_items; + rmm::device_uvector search_row_major(0, userStream); + if (!rowMajorQuery) { + search_row_major.resize(n * D, userStream); + raft::linalg::transpose(handle, search, search_row_major.data(), n, D, userStream); + search = search_row_major.data(); + } + + // transpose into a temporary buffer if necessary + rmm::device_uvector index_row_major(0, userStream); + if (!rowMajorIndex) { + size_t total_size = 0; + for (auto size : sizes) { + total_size += size; + } + index_row_major.resize(total_size * D, userStream); + } + + // Make other streams from pool wait on main stream + handle.wait_stream_pool_on_stream(); + + size_t total_rows_processed = 0; + for (size_t i = 0; i < input.size(); i++) { + value_t* out_d_ptr = out_D + (i * k * n); + IdxType* out_i_ptr = out_I + (i * k * n); + + auto stream = handle.get_next_usable_stream(i); + + if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + (metric == raft::distance::DistanceType::L2Unexpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded)) { + fusedL2Knn(D, + out_i_ptr, + out_d_ptr, + input[i], + search_items, + sizes[i], + n, + k, + rowMajorIndex, + rowMajorQuery, + stream, + metric); + + // Perform necessary post-processing + if (metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::LpUnexpanded) { + float p = 0.5; // standard l2 + if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; + raft::linalg::unaryOp( + res_D, + res_D, + n * k, + [p] __device__(float input) { return powf(fabsf(input), p); }, + stream); + } + } else { + switch (metric) { + case raft::distance::DistanceType::Haversine: + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); + break; + default: + // Create a new handle with the current stream from the stream pool + raft::device_resources stream_pool_handle(handle); + raft::resource::set_cuda_stream(stream_pool_handle, stream); + + auto index = input[i]; + if (!rowMajorIndex) { + index = index_row_major.data() + total_rows_processed * D; + total_rows_processed += sizes[i]; + raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream); + } + + // cosine/correlation are handled by metric processor, use IP distance + // for brute force knn call. + auto tiled_metric = metric; + if (metric == raft::distance::DistanceType::CosineExpanded || + metric == raft::distance::DistanceType::CorrelationExpanded) { + tiled_metric = raft::distance::DistanceType::InnerProduct; + } + + tiled_brute_force_knn(stream_pool_handle, + search, + index, + n, + sizes[i], + D, + k, + out_d_ptr, + out_i_ptr, + tiled_metric, + metricArg); + break; + } + } + + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + // Sync internal streams if used. We don't need to + // sync the user stream because we'll already have + // fully serial execution. + handle.sync_stream_pool(); + + if (input.size() > 1 || translations != nullptr) { + // This is necessary for proper index translations. If there are + // no translations or partitions to combine, it can be skipped. + knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); + } + + query_metric_processor->revert(search_items); + query_metric_processor->postprocess(out_D); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i]->revert(input[i]); + } + + if (translations == nullptr) delete id_ranges; +}; + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh new file mode 100644 index 0000000000..e2b5c41fb0 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace raft::neighbors::detail { + +template +__global__ void knn_merge_parts_kernel(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + value_t initK, + value_idx initV, + int k, + value_idx* translations) +{ + constexpr int kNumWarps = tpb / WarpSize; + + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + /** + * Uses shared memory + */ + faiss_select:: + BlockSelect, warp_q, thread_q, tpb> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + int total_k = k * n_parts; + + int i = threadIdx.x; + + // Get starting pointers for cols in current thread + int part = i / k; + size_t row_idx = (row * k) + (part * n_samples * k); + + int col = i % k; + + value_t* inKStart = inK + (row_idx + col); + value_idx* inVStart = inV + (row_idx + col); + + int limit = Pow2::roundDown(total_k); + value_idx translation = 0; + + for (; i < limit; i += tpb) { + translation = translations[part]; + heap.add(*inKStart, (*inVStart) + translation); + + part = (i + tpb) / k; + row_idx = (row * k) + (part * n_samples * k); + + col = (i + tpb) % k; + + inKStart = inK + (row_idx + col); + inVStart = inV + (row_idx + col); + } + + // Handle last remainder fraction of a warp of elements + if (i < total_k) { + translation = translations[part]; + heap.addThreadQ(*inKStart, (*inVStart) + translation); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + outK[row * k + i] = smemK[i]; + outV[row * k + i] = smemV[i]; + } +} + +template +inline void knn_merge_parts_impl(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + auto grid = dim3(n_samples); + + constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; + auto block = dim3(n_threads); + + auto kInit = std::numeric_limits::max(); + auto vInit = -1; + knn_merge_parts_kernel + <<>>( + inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * @brief Merge knn distances and index matrix, which have been partitioned + * by row, into a single matrix with only the k-nearest neighbors. + * + * @param inK partitioned knn distance matrix + * @param inV partitioned knn index matrix + * @param outK merged knn distance matrix + * @param outV merged knn index matrix + * @param n_samples number of samples per partition + * @param n_parts number of partitions + * @param k number of neighbors per partition (also number of merged neighbors) + * @param stream CUDA stream to use + * @param translations mapping of index offsets for each partition + */ +template +inline void knn_merge_parts(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + if (k == 1) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 32) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 64) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 128) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 256) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 512) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 1024) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); +} +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index b0aebe28b6..f244d5875c 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -128,7 +128,7 @@ void refine_device(raft::device_resources const& handle, refinement_index.metric(), 1, k, - raft::neighbors::ivf_flat::detail::is_min_close(metric), + raft::distance::is_min_close(metric), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/neighbors/detail/selection_faiss.cuh similarity index 95% rename from cpp/include/raft/spatial/knn/detail/selection_faiss.cuh rename to cpp/include/raft/neighbors/detail/selection_faiss.cuh index 5264f5d12e..5df42e94b9 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/neighbors/detail/selection_faiss.cuh @@ -16,16 +16,12 @@ #pragma once -#include #include #include -#include +#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { +namespace raft::neighbors::detail { template constexpr int kFaissMaxK() @@ -170,8 +166,4 @@ inline void select_k(const key_t* inK, else ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); } - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +}; // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/specializations/knn.cuh b/cpp/include/raft/neighbors/specializations/knn.cuh index b1cfa278d6..e0b64415fe 100644 --- a/cpp/include/raft/neighbors/specializations/knn.cuh +++ b/cpp/include/raft/neighbors/specializations/knn.cuh @@ -16,73 +16,55 @@ #pragma once +#include #include -namespace raft { -namespace spatial { -namespace knn { -extern template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); +namespace raft::spatial::knn { +#define RAFT_INST(IdxT, T, IntT) \ + extern template void brute_force_knn(raft::device_resources const& handle, \ + std::vector& input, \ + std::vector& sizes, \ + IntT D, \ + T* search_items, \ + IntT n, \ + IdxT* res_I, \ + T* res_D, \ + IntT k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + std::vector* translations, \ + distance::DistanceType metric, \ + float metric_arg); -extern template void brute_force_knn( - raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); +RAFT_INST(long, float, int); +RAFT_INST(long, float, unsigned int); +RAFT_INST(uint32_t, float, int); +RAFT_INST(uint32_t, float, unsigned int); +#undef RAFT_INST +}; // namespace raft::spatial::knn -extern template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -extern template void brute_force_knn( - raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft +// also define the detail api, which is used by raft::neighbors::brute_force +// (not doing the public api, since has extra template params on index_layout, matrix_index, +// search_layout etc - and isn't clear what the defaults here should be) +namespace raft::neighbors::detail { +#define RAFT_INST(IdxT, T, IntT) \ + extern template void brute_force_knn_impl(raft::device_resources const& handle, \ + std::vector& input, \ + std::vector& sizes, \ + IntT D, \ + T* search_items, \ + IntT n, \ + IdxT* res_I, \ + T* res_D, \ + IntT k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + std::vector* translations, \ + raft::distance::DistanceType metric, \ + float metricArg); +RAFT_INST(long, float, int); +RAFT_INST(long, float, unsigned int); +RAFT_INST(uint32_t, float, int); +RAFT_INST(uint32_t, float, unsigned int); +#undef RAFT_INST +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 7b3cf2d8f7..99d688e232 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -22,17 +22,16 @@ #include "ball_cover/common.cuh" #include "ball_cover/registers.cuh" #include "haversine_distance.cuh" -#include "knn_brute_force_faiss.cuh" -#include "selection_faiss.cuh" #include #include #include -#include +#include #include +#include #include #include @@ -178,23 +177,16 @@ void k_closest_landmarks(raft::device_resources const& handle, value_idx* R_knn_inds, value_t* R_knn_dists) { - // TODO: Add const to the brute-force knn inputs - std::vector input = {const_cast(index.get_R().data_handle())}; - std::vector sizes = {index.n_landmarks}; - - brute_force_knn_impl(handle, - input, - sizes, - index.n, - const_cast(query_pts), - n_query_pts, - R_knn_inds, - R_knn_dists, - k, - true, - true, - nullptr, - index.get_metric()); + std::vector> inputs = {index.get_R()}; + + raft::neighbors::brute_force::knn( + handle, + inputs, + make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)), + make_device_matrix_view(R_knn_inds, n_query_pts, k), + make_device_matrix_view(R_knn_dists, n_query_pts, k), + k, + index.get_metric()); } /** diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 394d27235b..f665368c41 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -19,13 +19,12 @@ #include "common.cuh" #include "../../ball_cover_types.hpp" -#include "../faiss_select/key_value_block_select.cuh" #include "../haversine_distance.cuh" -#include "../selection_faiss.cuh" #include #include +#include #include #include @@ -180,19 +179,14 @@ __global__ void compute_final_dists_registers(const value_t* X_index, local_x_ptr[j] = x_ptr[j]; } - faiss_select::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); const value_int n_k = Pow2::roundDown(k); value_int i = threadIdx.x; @@ -349,19 +343,14 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, } // Each warp works on 1 R - faiss_select::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; value_int n_dists_computed = 0; diff --git a/cpp/include/raft/spatial/knn/detail/common_faiss.h b/cpp/include/raft/spatial/knn/detail/common_faiss.h deleted file mode 100644 index 57076350f0..0000000000 --- a/cpp/include/raft/spatial/knn/detail/common_faiss.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -inline faiss::MetricType build_faiss_metric(raft::distance::DistanceType metric) -{ - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::CorrelationExpanded: - return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::L2Expanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2Unexpanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2SqrtExpanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2SqrtUnexpanded: return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L1: return faiss::MetricType::METRIC_L1; - case raft::distance::DistanceType::InnerProduct: return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::LpUnexpanded: return faiss::MetricType::METRIC_Lp; - case raft::distance::DistanceType::Linf: return faiss::MetricType::METRIC_Linf; - case raft::distance::DistanceType::Canberra: return faiss::MetricType::METRIC_Canberra; - case raft::distance::DistanceType::BrayCurtis: return faiss::MetricType::METRIC_BrayCurtis; - case raft::distance::DistanceType::JensenShannon: - return faiss::MetricType::METRIC_JensenShannon; - default: THROW("MetricType not supported: %d", metric); - } -} - -} // namespace detail -} // namespace knn -} // namespace spatial -} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index f1f160a154..4e18a210d4 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -17,7 +17,7 @@ #include #include #include -#include +#include // TODO: Need to hide the PairwiseDistance class impl and expose to public API #include "processing.cuh" #include @@ -219,9 +219,8 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x constexpr auto identity = std::numeric_limits::max(); constexpr auto keyMax = std::numeric_limits::max(); constexpr auto Dir = false; - typedef faiss_select:: - WarpSelect, NumWarpQ, NumThreadQ, 32> - myWarpSelect; + using namespace raft::neighbors::detail::faiss_select; + typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( IdxT gridStrideY) { diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh index 7d361ba4fb..058e98da9f 100644 --- a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh @@ -22,7 +22,7 @@ #include #include -#include +#include namespace raft { namespace spatial { @@ -65,13 +65,9 @@ __global__ void haversine_knn_kernel(value_idx* out_inds, __shared__ value_t smemK[kNumWarps * warp_q]; __shared__ value_idx smemV[kNumWarps * warp_q]; - faiss_select:: - BlockSelect, warp_q, thread_q, tpb> - heap(std::numeric_limits::max(), - std::numeric_limits::max(), - smemK, - smemV, - k); + using namespace raft::neighbors::detail::faiss_select; + BlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), std::numeric_limits::max(), smemK, smemV, k); // Grid is exactly sized to rows available int limit = Pow2::roundDown(n_index_rows); diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh deleted file mode 100644 index 2b2b5cee0c..0000000000 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ /dev/null @@ -1,397 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "fused_l2_knn.cuh" -#include "haversine_distance.cuh" -#include "processing.cuh" - -#include "common_faiss.h" - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template -__global__ void knn_merge_parts_kernel(value_t* inK, - value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - value_t initK, - value_idx initV, - int k, - value_idx* translations) -{ - constexpr int kNumWarps = tpb / WarpSize; - - __shared__ value_t smemK[kNumWarps * warp_q]; - __shared__ value_idx smemV[kNumWarps * warp_q]; - - /** - * Uses shared memory - */ - faiss_select:: - BlockSelect, warp_q, thread_q, tpb> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - int row = blockIdx.x; - int total_k = k * n_parts; - - int i = threadIdx.x; - - // Get starting pointers for cols in current thread - int part = i / k; - size_t row_idx = (row * k) + (part * n_samples * k); - - int col = i % k; - - value_t* inKStart = inK + (row_idx + col); - value_idx* inVStart = inV + (row_idx + col); - - int limit = Pow2::roundDown(total_k); - value_idx translation = 0; - - for (; i < limit; i += tpb) { - translation = translations[part]; - heap.add(*inKStart, (*inVStart) + translation); - - part = (i + tpb) / k; - row_idx = (row * k) + (part * n_samples * k); - - col = (i + tpb) % k; - - inKStart = inK + (row_idx + col); - inVStart = inV + (row_idx + col); - } - - // Handle last remainder fraction of a warp of elements - if (i < total_k) { - translation = translations[part]; - heap.addThreadQ(*inKStart, (*inVStart) + translation); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - outK[row * k + i] = smemK[i]; - outV[row * k + i] = smemV[i]; - } -} - -template -inline void knn_merge_parts_impl(value_t* inK, - value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t stream, - value_idx* translations) -{ - auto grid = dim3(n_samples); - - constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; - auto block = dim3(n_threads); - - auto kInit = std::numeric_limits::max(); - auto vInit = -1; - knn_merge_parts_kernel - <<>>( - inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * @brief Merge knn distances and index matrix, which have been partitioned - * by row, into a single matrix with only the k-nearest neighbors. - * - * @param inK partitioned knn distance matrix - * @param inV partitioned knn index matrix - * @param outK merged knn distance matrix - * @param outV merged knn index matrix - * @param n_samples number of samples per partition - * @param n_parts number of partitions - * @param k number of neighbors per partition (also number of merged neighbors) - * @param stream CUDA stream to use - * @param translations mapping of index offsets for each partition - */ -template -inline void knn_merge_parts(value_t* inK, - value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t stream, - value_idx* translations) -{ - if (k == 1) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 32) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 64) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 128) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 256) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 512) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 1024) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); -} - -/** - * Search the kNN for the k-nearest neighbors of a set of query vectors - * @param[in] input vector of device device memory array pointers to search - * @param[in] sizes vector of memory sizes for each device array pointer in input - * @param[in] D number of cols in input and search_items - * @param[in] search_items set of vectors to query for neighbors - * @param[in] n number of items in search_items - * @param[out] res_I pointer to device memory for returning k nearest indices - * @param[out] res_D pointer to device memory for returning k nearest distances - * @param[in] k number of neighbors to query - * @param[in] userStream the main cuda stream to use - * @param[in] internalStreams optional when n_params > 0, the index partitions can be - * queried in parallel using these streams. Note that n_int_streams also - * has to be > 0 for these to be used and their cardinality does not need - * to correspond to n_parts. - * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the - * user stream will be used. - * @param[in] rowMajorIndex are the index arrays in row-major layout? - * @param[in] rowMajorQuery are the query array in row-major layout? - * @param[in] translations translation ids for indices when index rows represent - * non-contiguous partitions - * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) - * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm - */ -template -void brute_force_knn_impl( - raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - IntType D, - value_t* search_items, - IntType n, - IdxType* res_I, - value_t* res_D, - IntType k, - bool rowMajorIndex = true, - bool rowMajorQuery = true, - std::vector* translations = nullptr, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0) -{ - auto userStream = handle.get_stream(); - - ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); - - std::vector* id_ranges; - if (translations == nullptr) { - // If we don't have explicit translations - // for offsets of the indices, build them - // from the local partitions - id_ranges = new std::vector(); - IdxType total_n = 0; - for (size_t i = 0; i < input.size(); i++) { - id_ranges->push_back(total_n); - total_n += sizes[i]; - } - } else { - // otherwise, use the given translations - id_ranges = translations; - } - - // perform preprocessing - std::unique_ptr> query_metric_processor = - create_processor(metric, n, D, k, rowMajorQuery, userStream); - query_metric_processor->preprocess(search_items); - - std::vector>> metric_processors(input.size()); - for (size_t i = 0; i < input.size(); i++) { - metric_processors[i] = - create_processor(metric, sizes[i], D, k, rowMajorQuery, userStream); - metric_processors[i]->preprocess(input[i]); - } - - int device; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - - rmm::device_uvector trans(id_ranges->size(), userStream); - raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); - - rmm::device_uvector all_D(0, userStream); - rmm::device_uvector all_I(0, userStream); - - value_t* out_D = res_D; - IdxType* out_I = res_I; - - if (input.size() > 1) { - all_D.resize(input.size() * k * n, userStream); - all_I.resize(input.size() * k * n, userStream); - - out_D = all_D.data(); - out_I = all_I.data(); - } - - // Make other streams from pool wait on main stream - handle.wait_stream_pool_on_stream(); - - for (size_t i = 0; i < input.size(); i++) { - value_t* out_d_ptr = out_D + (i * k * n); - IdxType* out_i_ptr = out_I + (i * k * n); - - auto stream = handle.get_next_usable_stream(i); - - if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && - (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded)) { - fusedL2Knn(D, - out_i_ptr, - out_d_ptr, - input[i], - search_items, - sizes[i], - n, - k, - rowMajorIndex, - rowMajorQuery, - stream, - metric); - } else { - switch (metric) { - case raft::distance::DistanceType::Haversine: - - ASSERT(D == 2, - "Haversine distance requires 2 dimensions " - "(latitude / longitude)."); - - haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); - break; - default: - faiss::MetricType m = build_faiss_metric(metric); - - raft::spatial::knn::RmmGpuResources gpu_res; - - gpu_res.noTempMemory(); - gpu_res.setDefaultStream(device, stream); - - faiss::gpu::GpuDistanceParams args; - args.metric = m; - args.metricArg = metricArg; - args.k = k; - args.dims = D; - args.vectors = input[i]; - args.vectorsRowMajor = rowMajorIndex; - args.numVectors = sizes[i]; - args.queries = search_items; - args.queriesRowMajor = rowMajorQuery; - args.numQueries = n; - args.outDistances = out_d_ptr; - args.outIndices = out_i_ptr; - args.outIndicesType = sizeof(IdxType) == 4 ? faiss::gpu::IndicesDataType::I32 - : faiss::gpu::IndicesDataType::I64; - - /** - * @todo: Until FAISS supports pluggable allocation strategies, - * we will not reap the benefits of the pool allocator for - * avoiding device-wide synchronizations from cudaMalloc/cudaFree - */ - bfKnn(&gpu_res, args); - } - } - - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - // Sync internal streams if used. We don't need to - // sync the user stream because we'll already have - // fully serial execution. - handle.sync_stream_pool(); - - if (input.size() > 1 || translations != nullptr) { - // This is necessary for proper index translations. If there are - // no translations or partitions to combine, it can be skipped. - knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); - } - - // Perform necessary post-processing - if (metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::LpUnexpanded) { - /** - * post-processing - */ - float p = 0.5; // standard l2 - if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; - raft::linalg::unaryOp( - res_D, - res_D, - n * k, - [p] __device__(float input) { return powf(fabsf(input), p); }, - userStream); - } - - query_metric_processor->revert(search_items); - query_metric_processor->postprocess(out_D); - for (size_t i = 0; i < input.size(); i++) { - metric_processors[i]->revert(input[i]); - } - - if (translations == nullptr) delete id_ranges; -}; - -} // namespace detail -} // namespace knn -} // namespace spatial -} // namespace raft diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index ca2c248392..692d262043 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -16,13 +16,12 @@ #pragma once -#include "detail/knn_brute_force_faiss.cuh" -#include "detail/selection_faiss.cuh" - #include #include #include #include +#include +#include namespace raft::spatial::knn { @@ -61,7 +60,7 @@ inline void knn_merge_parts(value_t* in_keys, cudaStream_t stream, idx_t* translations) { - detail::knn_merge_parts( + raft::neighbors::detail::knn_merge_parts( in_keys, in_values, out_keys, out_values, n_samples, n_parts, k, stream, translations); } @@ -148,7 +147,7 @@ template switch (algo) { case SelectKAlgo::FAISS: - detail::select_k( + neighbors::detail::select_k( in_keys, in_values, n_inputs, input_len, out_keys, out_values, select_min, k, stream); break; @@ -212,20 +211,20 @@ void brute_force_knn(raft::device_resources const& handle, { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); - detail::brute_force_knn_impl(handle, - input, - sizes, - D, - search_items, - n, - res_I, - res_D, - k, - rowMajorIndex, - rowMajorQuery, - translations, - metric, - metric_arg); + raft::neighbors::detail::brute_force_knn_impl(handle, + input, + sizes, + D, + search_items, + n, + res_I, + res_D, + k, + rowMajorIndex, + rowMajorQuery, + translations, + metric, + metric_arg); } } // namespace raft::spatial::knn diff --git a/cpp/include/raft/util/bitonic_sort.cuh b/cpp/include/raft/util/bitonic_sort.cuh index 5de464b4c7..e34708e332 100644 --- a/cpp/include/raft/util/bitonic_sort.cuh +++ b/cpp/include/raft/util/bitonic_sort.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/util/integer_utils.hpp b/cpp/include/raft/util/integer_utils.hpp index 3b0d9d44ae..e85086df42 100644 --- a/cpp/include/raft/util/integer_utils.hpp +++ b/cpp/include/raft/util/integer_utils.hpp @@ -1,7 +1,7 @@ /* * Copyright 2019 BlazingDB, Inc. * Copyright 2019 Eyal Rozenberg - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu index a74ce0c0c5..2c21d1ec64 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu @@ -15,11 +15,9 @@ */ #include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -#include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu index 576c74d9e0..7e6e7e80d0 100644 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu @@ -15,11 +15,9 @@ */ #include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -#include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu index cba2436a4d..e94c12d579 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu @@ -15,11 +15,9 @@ */ #include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -#include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu index 17c687ebce..95cf8a1eb3 100644 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu @@ -15,11 +15,9 @@ */ #include +#include #include -// TODO: Change this to proper specializations after FAISS is removed -#include - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ed6b44c71c..acfb470bd8 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -268,10 +268,11 @@ if(BUILD_TESTS) test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu test/neighbors/knn.cu test/neighbors/fused_l2_knn.cu + test/neighbors/tiled_knn.cu test/neighbors/haversine.cu test/neighbors/ball_cover.cu - test/neighbors/epsilon_neighborhood.cu test/neighbors/faiss_mr.cu + test/neighbors/epsilon_neighborhood.cu test/neighbors/refine.cu test/neighbors/selection.cu OPTIONAL diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 6e387a3bb1..6dcae8e34d 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -19,8 +19,8 @@ #include #include #include +#include #include -#include #include #if defined RAFT_NN_COMPILED #include @@ -112,24 +112,16 @@ void compute_bfknn(const raft::device_resources& handle, value_t* dists, int64_t* inds) { - std::vector input_vec = {const_cast(X1)}; - std::vector sizes_vec = {n_rows}; - - std::vector* translations = nullptr; - - raft::spatial::knn::detail::brute_force_knn_impl(handle, - input_vec, - sizes_vec, - d, - const_cast(X2), - n_query_rows, - inds, - dists, - k, - true, - true, - translations, - metric); + std::vector> input_vec = { + make_device_matrix_view(X1, n_rows, d)}; + + raft::neighbors::brute_force::knn(handle, + input_vec, + make_device_matrix_view(X2, n_query_rows, d), + make_device_matrix_view(inds, n_query_rows, k), + make_device_matrix_view(dists, n_query_rows, k), + k, + metric); } struct ToRadians { @@ -361,4 +353,4 @@ INSTANTIATE_TEST_CASE_P(BallCoverKNNQueryTest, TEST_P(BallCoverAllKNNTestF, Fit) { basicTest(); } TEST_P(BallCoverKNNQueryTestF, Fit) { basicTest(); } -} // namespace raft::neighbors::ball_cover \ No newline at end of file +} // namespace raft::neighbors::ball_cover diff --git a/cpp/test/neighbors/faiss_mr.cu b/cpp/test/neighbors/faiss_mr.cu index 5f0bcae933..89f012db0f 100644 --- a/cpp/test/neighbors/faiss_mr.cu +++ b/cpp/test/neighbors/faiss_mr.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index 61bee62235..a5fead8093 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -15,6 +15,7 @@ */ #include "../test_utils.cuh" +#include "./knn_utils.cuh" #include #include @@ -51,59 +52,6 @@ struct FusedL2KNNInputs { raft::distance::DistanceType metric_; }; -template -struct idx_dist_pair { - IdxT idx; - DistT dist; - compareDist eq_compare; - bool operator==(const idx_dist_pair& a) const - { - if (idx == a.idx) return true; - if (eq_compare(dist, a.dist)) return true; - return false; - } - idx_dist_pair(IdxT x, DistT y, compareDist op) : idx(x), dist(y), eq_compare(op) {} -}; - -template -testing::AssertionResult devArrMatchKnnPair(const T* expected_idx, - const T* actual_idx, - const DistT* expected_dist, - const DistT* actual_dist, - size_t rows, - size_t cols, - const DistT eps, - cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr exp_idx_h(new T[size]); - std::unique_ptr act_idx_h(new T[size]); - std::unique_ptr exp_dist_h(new DistT[size]); - std::unique_ptr act_dist_h(new DistT[size]); - raft::update_host(exp_idx_h.get(), expected_idx, size, stream); - raft::update_host(act_idx_h.get(), actual_idx, size, stream); - raft::update_host(exp_dist_h.get(), expected_dist, size, stream); - raft::update_host(act_dist_h.get(), actual_dist, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! - auto exp_idx = exp_idx_h.get()[idx]; - auto act_idx = act_idx_h.get()[idx]; - auto exp_dist = exp_dist_h.get()[idx]; - auto act_dist = act_dist_h.get()[idx]; - idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox(eps)); - idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox(eps)); - if (!(exp_kvp == act_kvp)) { - return testing::AssertionFailure() - << "actual=" << act_kvp.idx << "," << act_kvp.dist << "!=" - << "expected" << exp_kvp.idx << "," << exp_kvp.dist << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - template class FusedL2KNNTest : public ::testing::TestWithParam { public: diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index fd7a1a03aa..7976725c65 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -20,6 +20,11 @@ #include #include #include + +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + #if defined RAFT_NN_COMPILED #include #endif @@ -188,12 +193,12 @@ const std::vector inputs = { 2, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}}}; -typedef KNNTest KNNTestFint64_t; -TEST_P(KNNTestFint64_t, BruteForce) { this->testBruteForce(); } +typedef KNNTest KNNTestFint32_t; +TEST_P(KNNTestFint32_t, BruteForce) { this->testBruteForce(); } typedef KNNTest KNNTestFuint32_t; TEST_P(KNNTestFuint32_t, BruteForce) { this->testBruteForce(); } -INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint64_t, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint32_t, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFuint32_t, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors::brute_force diff --git a/cpp/test/neighbors/knn_utils.cuh b/cpp/test/neighbors/knn_utils.cuh new file mode 100644 index 0000000000..ac34699ac5 --- /dev/null +++ b/cpp/test/neighbors/knn_utils.cuh @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "../test_utils.cuh" +#include + +#include + +namespace raft::spatial::knn { +template +struct idx_dist_pair { + IdxT idx; + DistT dist; + compareDist eq_compare; + bool operator==(const idx_dist_pair& a) const + { + if (idx == a.idx) return true; + if (eq_compare(dist, a.dist)) return true; + return false; + } + idx_dist_pair(IdxT x, DistT y, compareDist op) : idx(x), dist(y), eq_compare(op) {} +}; + +template +testing::AssertionResult devArrMatchKnnPair(const T* expected_idx, + const T* actual_idx, + const DistT* expected_dist, + const DistT* actual_dist, + size_t rows, + size_t cols, + const DistT eps, + cudaStream_t stream = 0) +{ + size_t size = rows * cols; + std::unique_ptr exp_idx_h(new T[size]); + std::unique_ptr act_idx_h(new T[size]); + std::unique_ptr exp_dist_h(new DistT[size]); + std::unique_ptr act_dist_h(new DistT[size]); + raft::update_host(exp_idx_h.get(), expected_idx, size, stream); + raft::update_host(act_idx_h.get(), actual_idx, size, stream); + raft::update_host(exp_dist_h.get(), expected_dist, size, stream); + raft::update_host(act_dist_h.get(), actual_dist, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto exp_idx = exp_idx_h.get()[idx]; + auto act_idx = act_idx_h.get()[idx]; + auto exp_dist = exp_dist_h.get()[idx]; + auto act_dist = act_dist_h.get()[idx]; + idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox(eps)); + idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox(eps)); + if (!(exp_kvp == act_kvp)) { + return testing::AssertionFailure() + << "actual=" << act_kvp.idx << "," << act_kvp.dist << "!=" + << "expected" << exp_kvp.idx << "," << exp_kvp.dist << " @" << i << "," << j; + } + } + } + return testing::AssertionSuccess(); +} +} // namespace raft::spatial::knn diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 26e37e433f..25939f65c3 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -118,7 +118,7 @@ struct SelectInOutComputed { } break; case knn::SelectKAlgo::FAISS: - if (spec.k > raft::spatial::knn::detail::kFaissMaxK()) { + if (spec.k > raft::neighbors::detail::kFaissMaxK()) { not_supported = true; return; } diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu new file mode 100644 index 0000000000..4784f915f3 --- /dev/null +++ b/cpp/test/neighbors/tiled_knn.cu @@ -0,0 +1,254 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "./ann_utils.cuh" +#include "./knn_utils.cuh" + +#include +#include +#include +#include +#include +#include + +#if defined RAFT_NN_COMPILED +#include +#include +#endif + +#include + +#include + +#include +#include +#include + +namespace raft::neighbors::brute_force { +struct TiledKNNInputs { + int num_queries; + int num_db_vecs; + int dim; + int k; + int row_tiles; + int col_tiles; + raft::distance::DistanceType metric; + bool row_major; +}; + +std::ostream& operator<<(std::ostream& os, const TiledKNNInputs& input) +{ + return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs + << " dim:" << input.dim << " k:" << input.k << " row_tiles:" << input.row_tiles + << " col_tiles:" << input.col_tiles << " metric:" << print_metric{input.metric} + << " row_major:" << input.row_major; +} + +template +class TiledKNNTest : public ::testing::TestWithParam { + public: + TiledKNNTest() + : stream_(handle_.get_stream()), + params_(::testing::TestWithParam::GetParam()), + database(params_.num_db_vecs * params_.dim, stream_), + search_queries(params_.num_queries * params_.dim, stream_), + raft_indices_(params_.num_queries * params_.k, stream_), + raft_distances_(params_.num_queries * params_.k, stream_), + ref_indices_(params_.num_queries * params_.k, stream_), + ref_distances_(params_.num_queries * params_.k, stream_) + { + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(database.data(), params_.num_db_vecs, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(search_queries.data(), params_.num_queries, params_.dim), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(raft_indices_.data(), params_.num_queries, params_.k), + 0); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(raft_distances_.data(), params_.num_queries, params_.k), + T{0.0}); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(ref_indices_.data(), params_.num_queries, params_.k), + 0); + raft::matrix::fill( + handle_, + raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), + T{0.0}); + } + + protected: + void testBruteForce() + { + float metric_arg = 3.0; + + // calculate the naive knn, by calculating the full pairwise distances and doing a k-select + rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); + rmm::device_uvector workspace(0, stream_); + distance::pairwise_distance(handle_, + search_queries.data(), + database.data(), + temp_distances.data(), + num_queries, + num_db_vecs, + dim, + workspace, + metric, + params_.row_major, + metric_arg); + + // setting the 'isRowMajor' flag in the pairwise distances api, not only sets + // the inputs as colmajor - but also the output. this means we have to transpose in this + // case + auto temp_dist = temp_distances.data(); + rmm::device_uvector temp_row_major_dist(num_db_vecs * num_queries, stream_); + if (!params_.row_major) { + raft::linalg::transpose( + handle_, temp_dist, temp_row_major_dist.data(), num_queries, num_db_vecs, stream_); + temp_dist = temp_row_major_dist.data(); + } + + raft::neighbors::detail::select_k(temp_dist, + nullptr, + num_queries, + num_db_vecs, + ref_distances_.data(), + ref_indices_.data(), + raft::distance::is_min_close(metric), + k_, + stream_); + + if ((params_.row_tiles == 0) && (params_.col_tiles == 0)) { + std::vector input{database.data()}; + std::vector sizes{static_cast(num_db_vecs)}; + neighbors::detail::brute_force_knn_impl(handle_, + input, + sizes, + dim, + const_cast(search_queries.data()), + num_queries, + raft_indices_.data(), + raft_distances_.data(), + k_, + params_.row_major, + params_.row_major, + nullptr, + metric, + metric_arg); + } else { + neighbors::detail::tiled_brute_force_knn(handle_, + search_queries.data(), + database.data(), + num_queries, + num_db_vecs, + dim, + k_, + raft_distances_.data(), + raft_indices_.data(), + metric, + metric_arg, + params_.row_tiles, + params_.col_tiles); + } + + // verify. + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(ref_indices_.data(), + raft_indices_.data(), + ref_distances_.data(), + raft_distances_.data(), + num_queries, + k_, + float(0.001), + stream_)); + } + + void SetUp() override + { + num_queries = params_.num_queries; + num_db_vecs = params_.num_db_vecs; + dim = params_.dim; + k_ = params_.k; + metric = params_.metric; + + unsigned long long int seed = 1234ULL; + raft::random::RngState r(seed); + + // JensenShannon distance requires positive values + T min_val = metric == raft::distance::DistanceType::JensenShannon ? T(0.0) : T(-1.0); + uniform(handle_, r, database.data(), num_db_vecs * dim, min_val, T(1.0)); + uniform(handle_, r, search_queries.data(), num_queries * dim, min_val, T(1.0)); + } + + private: + raft::device_resources handle_; + cudaStream_t stream_ = 0; + TiledKNNInputs params_; + int num_queries; + int num_db_vecs; + int dim; + rmm::device_uvector database; + rmm::device_uvector search_queries; + rmm::device_uvector raft_indices_; + rmm::device_uvector raft_distances_; + rmm::device_uvector ref_indices_; + rmm::device_uvector ref_distances_; + int k_; + raft::distance::DistanceType metric; +}; + +const std::vector random_inputs = { + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Expanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Unexpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtUnexpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L1, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Linf, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::InnerProduct, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CorrelationExpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CosineExpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::LpUnexpanded, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon, true}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded, true}, + // BrayCurtis isn't currently supported by pairwise_distance api + // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::BrayCurtis}, + {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra, true}, + {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded, true}, + {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded, true}, + {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded, true}, + // Test where the final column tile has < K items: + {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded, true}, + // Test where passing column_tiles < K + {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded, true}, + // Passing tile sizes of 0 means to use brute_force_knn_impl (instead of the + // tiled_brute_force_knn api). + {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, true}, + {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, false}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::LpUnexpanded, true}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::InnerProduct, false}}; + +typedef TiledKNNTest TiledKNNTestF; +TEST_P(TiledKNNTestF, BruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(TiledKNNTest, TiledKNNTestF, ::testing::ValuesIn(random_inputs)); +} // namespace raft::neighbors::brute_force