diff --git a/README.md b/README.md index 34d66cbbc3..fa8b2b36e0 100755 --- a/README.md +++ b/README.md @@ -315,6 +315,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `solver`: Sparse solvers for optimization and approximation - `stats`: Moments, summary statistics, model performance measures - `util`: Various reusable tools and utilities for accelerated algorithm development + - `internal`: A private header-only component that hosts the code shared between benchmarks. - `scripts`: Helpful scripts for development - `src`: Compiled APIs and template specializations for the shared libraries - `test`: Googletests source code diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 784bbbb935..c6850b290f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -665,6 +665,13 @@ raft_export( distance distributed nn DOCUMENTATION doc_string NAMESPACE raft:: FINAL_CODE_BLOCK code_string ) +# ################################################################################################## +# * shared test/bench headers ------------------------------------------------ + +if(BUILD_TESTS OR BUILD_BENCH) + include(internal/CMakeLists.txt) +endif() + # ################################################################################################## # * build test executable ---------------------------------------------------- diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 6b985acfc3..b1ffc72ba9 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -30,9 +30,9 @@ function(ConfigureBench) target_link_libraries( ${BENCH_NAME} PRIVATE raft::raft + raft_internal $<$:raft::distance> $<$:raft::nn> - GTest::gtest benchmark::benchmark Threads::Threads $ diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu index 452a50ba50..3279c011cc 100644 --- a/cpp/bench/matrix/select_k.cu +++ b/cpp/bench/matrix/select_k.cu @@ -14,12 +14,7 @@ * limitations under the License. */ -/** - * TODO: reconsider how to organize shared test+bench files better - * Related Issue: https://github.com/rapidsai/raft/issues/1153 - * (although this header does not depend on any gtest headers) - */ -#include "../../test/matrix/select_k.cuh" +#include #include diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index a038905ace..cfce402968 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,16 @@ * limitations under the License. */ -#include +#include -#include +#include #include #include #include #include #include +#include #if defined RAFT_DISTANCE_COMPILED #include @@ -36,12 +37,10 @@ #include #include -#include "../../test/neighbors/refine_helper.cuh" - #include #include -using namespace raft::neighbors::detail; +using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/internal/CMakeLists.txt b/cpp/internal/CMakeLists.txt new file mode 100644 index 0000000000..4d5c585c01 --- /dev/null +++ b/cpp/internal/CMakeLists.txt @@ -0,0 +1,21 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +if(BUILD_TESTS OR BUILD_BENCH) + add_library(raft_internal INTERFACE) + target_include_directories( + raft_internal INTERFACE "$" + ) + target_compile_features(raft_internal INTERFACE cxx_std_17 $) +endif() diff --git a/cpp/test/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh similarity index 100% rename from cpp/test/matrix/select_k.cuh rename to cpp/internal/raft_internal/matrix/select_k.cuh diff --git a/cpp/internal/raft_internal/neighbors/naive_knn.cuh b/cpp/internal/raft_internal/neighbors/naive_knn.cuh new file mode 100644 index 0000000000..3ad055272b --- /dev/null +++ b/cpp/internal/raft_internal/neighbors/naive_knn.cuh @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace raft::neighbors { + +template +__global__ void naive_distance_kernel(EvalT* dist, + const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric) +{ + IdxT midx = IdxT(threadIdx.x) + IdxT(blockIdx.x) * IdxT(blockDim.x); + if (midx >= m) return; + IdxT grid_size = IdxT(blockDim.y) * IdxT(gridDim.y); + for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; nidx += grid_size) { + EvalT acc = EvalT(0); + for (IdxT i = 0; i < k; ++i) { + IdxT xidx = i + midx * k; + IdxT yidx = i + nidx * k; + auto xv = EvalT(x[xidx]); + auto yv = EvalT(y[yidx]); + switch (metric) { + case raft::distance::DistanceType::InnerProduct: { + acc += xv * yv; + } break; + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2Unexpanded: { + auto diff = xv - yv; + acc += diff * diff; + } break; + default: break; + } + } + switch (metric) { + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: { + acc = raft::sqrt(acc); + } break; + default: break; + } + dist[midx * n + nidx] = acc; + } +} + +/** + * Naive, but flexible bruteforce KNN search. + * + * TODO: either replace this with brute_force_knn or with distance+select_k + * when either distance or brute_force_knn support 8-bit int inputs. + */ +template +void naive_knn(EvalT* dist_topk, + IdxT* indices_topk, + const DataT* x, + const DataT* y, + size_t n_inputs, + size_t input_len, + size_t dim, + uint32_t k, + raft::distance::DistanceType type, + rmm::cuda_stream_view stream) +{ + rmm::mr::device_memory_resource* mr = nullptr; + auto pool_guard = raft::get_pool_memory_resource(mr, 1024 * 1024); + + dim3 block_dim(16, 32, 1); + // maximum reasonable grid size in `y` direction + auto grid_y = + static_cast(std::min(raft::ceildiv(input_len, block_dim.y), 32768)); + + // bound the memory used by this function + size_t max_batch_size = + std::min(n_inputs, raft::ceildiv(size_t(1) << size_t(27), input_len)); + rmm::device_uvector dist(max_batch_size * input_len, stream, mr); + + for (size_t offset = 0; offset < n_inputs; offset += max_batch_size) { + size_t batch_size = std::min(max_batch_size, n_inputs - offset); + dim3 grid_dim(raft::ceildiv(batch_size, block_dim.x), grid_y, 1); + + naive_distance_kernel<<>>( + dist.data(), x + offset * dim, y, batch_size, input_len, dim, type); + + matrix::detail::select_k(dist.data(), + nullptr, + batch_size, + input_len, + static_cast(k), + dist_topk + offset * k, + indices_topk + offset * k, + type != raft::distance::DistanceType::InnerProduct, + stream, + mr); + } + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); +} + +} // namespace raft::neighbors diff --git a/cpp/test/neighbors/refine_helper.cuh b/cpp/internal/raft_internal/neighbors/refine_helper.cuh similarity index 74% rename from cpp/test/neighbors/refine_helper.cuh rename to cpp/internal/raft_internal/neighbors/refine_helper.cuh index 3c69a8f5b7..1d8c5600bd 100644 --- a/cpp/test/neighbors/refine_helper.cuh +++ b/cpp/internal/raft_internal/neighbors/refine_helper.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,8 @@ */ #pragma once -#include "ann_utils.cuh" +#include + #include #include #include @@ -25,9 +26,9 @@ #include #include -#include +#include -namespace raft::neighbors::detail { +namespace raft::neighbors { template struct RefineInputs { @@ -66,16 +67,16 @@ class RefineHelper { { candidates = raft::make_device_matrix(handle_, p.n_queries, p.k0); rmm::device_uvector distances_tmp(p.n_queries * p.k0, stream_); - raft::neighbors::naiveBfKnn(distances_tmp.data(), - candidates.data_handle(), - queries.data_handle(), - dataset.data_handle(), - p.n_queries, - p.n_rows, - p.dim, - p.k0, - p.metric, - stream_); + naive_knn(distances_tmp.data(), + candidates.data_handle(), + queries.data_handle(), + dataset.data_handle(), + p.n_queries, + p.n_rows, + p.dim, + p.k0, + p.metric, + stream_); handle_.sync_stream(stream_); } @@ -98,16 +99,16 @@ class RefineHelper { { rmm::device_uvector distances_dev(p.n_queries * p.k, stream_); rmm::device_uvector indices_dev(p.n_queries * p.k, stream_); - raft::neighbors::naiveBfKnn(distances_dev.data(), - indices_dev.data(), - queries.data_handle(), - dataset.data_handle(), - p.n_queries, - p.n_rows, - p.dim, - p.k, - p.metric, - stream_); + naive_knn(distances_dev.data(), + indices_dev.data(), + queries.data_handle(), + dataset.data_handle(), + p.n_queries, + p.n_rows, + p.dim, + p.k, + p.metric, + stream_); true_refined_distances_host.resize(p.n_queries * p.k); true_refined_indices_host.resize(p.n_queries * p.k); raft::copy(true_refined_indices_host.data(), indices_dev.data(), indices_dev.size(), stream_); @@ -137,4 +138,4 @@ class RefineHelper { std::vector true_refined_indices_host; std::vector true_refined_distances_host; }; -} // namespace raft::neighbors::detail \ No newline at end of file +} // namespace raft::neighbors diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6c7ca11d86..8a9071fdd1 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -32,6 +32,7 @@ function(ConfigureTest) target_link_libraries( ${TEST_NAME} PRIVATE raft::raft + raft_internal $<$:raft::distance> $<$:raft::nn> GTest::gtest diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index cb92c15790..a8b5d60bb8 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,8 @@ */ #include "../test_utils.cuh" -#include "select_k.cuh" + +#include #include #include diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 46a80a2f56..232759a948 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -17,6 +17,8 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include + #include #include #include @@ -78,16 +80,16 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { { rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - naiveBfKnn(distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.metric, - stream_); + naive_knn(distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric, + stream_); update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); handle_.sync_stream(stream_); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 719f429f13..178078b297 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -18,6 +18,8 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include + #include #include #include @@ -158,16 +160,16 @@ class ivf_pq_test : public ::testing::TestWithParam { size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - naiveBfKnn(distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.index_params.metric, - stream_); + naive_knn(distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.index_params.metric, + stream_); distances_ref.resize(queries_size); update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); indices_ref.resize(queries_size); diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index f98f0fa771..4b07db32f4 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -103,100 +103,6 @@ inline auto operator<<(std::ostream& os, const print_metric& p) -> std::ostream& return os; } -template -__global__ void naive_distance_kernel(EvalT* dist, - const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - raft::distance::DistanceType metric) -{ - IdxT midx = IdxT(threadIdx.x) + IdxT(blockIdx.x) * IdxT(blockDim.x); - if (midx >= m) return; - IdxT grid_size = IdxT(blockDim.y) * IdxT(gridDim.y); - for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; nidx += grid_size) { - EvalT acc = EvalT(0); - for (IdxT i = 0; i < k; ++i) { - IdxT xidx = i + midx * k; - IdxT yidx = i + nidx * k; - auto xv = EvalT(x[xidx]); - auto yv = EvalT(y[yidx]); - switch (metric) { - case raft::distance::DistanceType::InnerProduct: { - acc += xv * yv; - } break; - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2Unexpanded: { - auto diff = xv - yv; - acc += diff * diff; - } break; - default: break; - } - } - switch (metric) { - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: { - acc = raft::sqrt(acc); - } break; - default: break; - } - dist[midx * n + nidx] = acc; - } -} - -/** - * TODO: either replace this with brute_force_knn or with distance+select_k - * when either distance or brute_force_knn support 8-bit int inputs. - */ -template -void naiveBfKnn(EvalT* dist_topk, - IdxT* indices_topk, - const DataT* x, - const DataT* y, - size_t n_inputs, - size_t input_len, - size_t dim, - uint32_t k, - raft::distance::DistanceType type, - rmm::cuda_stream_view stream) -{ - rmm::mr::device_memory_resource* mr = nullptr; - auto pool_guard = raft::get_pool_memory_resource(mr, 1024 * 1024); - - dim3 block_dim(16, 32, 1); - // maximum reasonable grid size in `y` direction - auto grid_y = - static_cast(std::min(raft::ceildiv(input_len, block_dim.y), 32768)); - - // bound the memory used by this function - size_t max_batch_size = - std::min(n_inputs, raft::ceildiv(size_t(1) << size_t(27), input_len)); - rmm::device_uvector dist(max_batch_size * input_len, stream, mr); - - for (size_t offset = 0; offset < n_inputs; offset += max_batch_size) { - size_t batch_size = std::min(max_batch_size, n_inputs - offset); - dim3 grid_dim(raft::ceildiv(batch_size, block_dim.x), grid_y, 1); - - naive_distance_kernel<<>>( - dist.data(), x + offset * dim, y, batch_size, input_len, dim, type); - - matrix::detail::select_k(dist.data(), - nullptr, - batch_size, - input_len, - static_cast(k), - dist_topk + offset * k, - indices_topk + offset * k, - type != raft::distance::DistanceType::InnerProduct, - stream, - mr); - } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); -} - template struct idx_dist_pair { IdxT idx; diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 674171e030..e2575f0f4e 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" -#include "refine_helper.cuh" +#include #include #include @@ -40,11 +40,11 @@ namespace raft::neighbors { template -class RefineTest : public ::testing::TestWithParam> { +class RefineTest : public ::testing::TestWithParam> { public: RefineTest() : stream_(handle_.get_stream()), - data(handle_, ::testing::TestWithParam>::GetParam()) + data(handle_, ::testing::TestWithParam>::GetParam()) { } @@ -104,11 +104,11 @@ class RefineTest : public ::testing::TestWithParam> { public: raft::handle_t handle_; rmm::cuda_stream_view stream_; - detail::RefineHelper data; + RefineHelper data; }; -const std::vector> inputs = - raft::util::itertools::product>( +const std::vector> inputs = + raft::util::itertools::product>( {137}, {1000}, {16},