From cc4a76ba19b4b01ca5cbf454365f5f48f4bf313c Mon Sep 17 00:00:00 2001
From: tsuki <12711693+enp1s0@users.noreply.github.com>
Date: Wed, 10 May 2023 23:24:37 +0900
Subject: [PATCH] CAGRA: Separate graph index sorting functionality from prune
 function (#1471)

# Changes

This PR separates the graph index sorting functionality from the CAGRA
pruning function and moves it into a new function, `cagra::sort_knn_graph`.
(Related issue: https://github.com/rapidsai/raft/issues/1446)

# Unit test

This PR adds a new unit test for the sorting function. The test uses a
dataset separate from the one used in the main CAGRA test, so that rounding
errors in the norm computation between two dataset vectors cannot change the
expected neighbor order. More details are in the source code.
https://github.com/enp1s0/raft/blob/ea6c449c260895e9125a591a4848eed06f5b72c4/cpp/test/neighbors/ann_cagra.cuh#L93-L96

# Issue

Close #1446

Authors:
  - tsuki (https://github.com/enp1s0)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/raft/pull/1471
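# API migration note

For callers of `cagra::prune`, the dataset argument is removed, and graphs
built by anything other than `cagra::build_knn_graph` must now be sorted
explicitly before pruning. A minimal before/after sketch (`res`, `dataset`,
`knn_graph`, and `pruned_graph` are placeholder variables borrowed from the
usage examples in the diff below, not new API):

```cpp
// Before this change: prune received the dataset and sorted the graph internally.
// cagra::prune(res, dataset, knn_graph.view(), pruned_graph.view());

// After this change: sort explicitly (not needed for graphs coming from
// cagra::build_knn_graph, which are already sorted), then prune.
cagra::sort_knn_graph(res, dataset.view(), knn_graph.view());
cagra::prune(res, knn_graph.view(), pruned_graph.view());
```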
---
 cpp/include/raft/neighbors/cagra.cuh          |  65 +++++--
 .../neighbors/detail/cagra/graph_core.cuh     | 131 +++++++++-----
 cpp/test/neighbors/ann_cagra.cuh              | 169 +++++++++++++++++-
 .../ann_cagra/test_float_uint32_t.cu          |   4 +
 .../ann_cagra/test_int8_t_uint32_t.cu         |   3 +
 .../ann_cagra/test_uint8_t_uint32_t.cu        |   4 +
 6 files changed, 306 insertions(+), 70 deletions(-)

diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh
index 90728efd70..87d370b54a 100644
--- a/cpp/include/raft/neighbors/cagra.cuh
+++ b/cpp/include/raft/neighbors/cagra.cuh
@@ -52,8 +52,8 @@ namespace raft::neighbors::experimental::cagra {
 * @code{.cpp}
 * using namespace raft::neighbors;
 * // use default index parameters
- * ivf_pq::index_params build_params;
- * ivf_pq::search_params search_params
+ * cagra::index_params build_params;
+ * cagra::search_params search_params;
 * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128);
 * // create knn graph
 * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
@@ -84,6 +84,49 @@ void build_knn_graph(raft::device_resources const& res,
   detail::build_knn_graph(res, dataset, knn_graph, refine_rate, build_params, search_params);
 }

+/**
+ * @brief Sort a KNN graph index.
+ * Preprocessing step for `cagra::prune`: If a KNN graph is not built using
+ * `cagra::build_knn_graph`, then it is necessary to call this function before calling
+ * `cagra::prune`. If the graph is built by `cagra::build_knn_graph`, it is already sorted and you
+ * do not need to call this function.
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::neighbors;
+ * cagra::index_params build_params;
+ * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128);
+ * // build KNN graph not using `cagra::build_knn_graph`
+ * // build(knn_graph, dataset, ...);
+ * // sort graph index
+ * sort_knn_graph(res, dataset.view(), knn_graph.view());
+ * // prune graph
+ * cagra::prune(res, knn_graph.view(), pruned_graph.view());
+ * // Construct an index from dataset and pruned knn_graph
+ * auto index = cagra::index(res, build_params.metric, dataset, pruned_graph.view());
+ * @endcode
+ *
+ * @tparam DataT type of the data in the source dataset
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ * @param[in] res raft resources
+ * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
+ * @param[in,out] knn_graph a matrix view (host or device) of the input knn graph [n_rows,
+ * knn_graph_degree]
+ */
+template <typename DataT,
+          typename IdxT,
+          typename d_accessor =
+            host_device_accessor<std::experimental::default_accessor<DataT>, memory_type::device>,
+          typename g_accessor =
+            host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
+void sort_knn_graph(raft::device_resources const& res,
+                    mdspan<const DataT, matrix_extent<IdxT>, row_major, d_accessor> dataset,
+                    mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph)
+{
+  detail::graph::sort_knn_graph(res, dataset, knn_graph);
+}
+
 /**
  * @brief Prune a KNN graph.
  *
@@ -91,27 +134,21 @@ void build_knn_graph(raft::device_resources const& res,
  *
  * See [cagra::build_knn_graph](#cagra::build_knn_graph) for usage example
  *
- * @tparam T data element type
  * @tparam IdxT type of the indices in the source dataset
  *
  * @param[in] res raft resources
- * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
 * @param[in] knn_graph a matrix view (host or device) of the input knn graph [n_rows,
 * knn_graph_degree]
 * @param[out] new_graph a host matrix view of the pruned knn graph [n_rows, graph_degree]
 */
-template <typename T,
-          typename IdxT = uint32_t,
-          typename d_accessor =
-            host_device_accessor<std::experimental::default_accessor<T>, memory_type::device>,
-          typename g_accessor =
-            host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
+template <typename IdxT = uint32_t,
+          typename g_accessor =
+            host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
 void prune(raft::device_resources const& res,
-           mdspan<const T, matrix_extent<IdxT>, row_major, d_accessor> dataset,
            mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
            raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
 {
-  detail::graph::prune(res, dataset, knn_graph, new_graph);
+  detail::graph::prune(res, knn_graph, new_graph);
 }

 /**
@@ -138,11 +175,11 @@ void prune(raft::device_resources const& res,
 * // create and fill the index from a [N, D] dataset
 * auto index = cagra::build(res, index_params, dataset);
 * // use default search parameters
- * ivf_pq::search_params search_params;
+ * cagra::search_params search_params;
 * // search K nearest neighbours
 * auto neighbors = raft::make_device_matrix(res, n_queries, k);
 * auto distances = raft::make_device_matrix(res, n_queries, k);
- * ivf_pq::search(res, search_params, index, queries, neighbors, distances);
+ * cagra::search(res, search_params, index, queries, neighbors, distances);
 * @endcode
 *
 * @tparam T data element type
@@ -178,7 +215,7 @@ index<T, IdxT> build(raft::device_resources const& res,
   auto cagra_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), params.graph_degree);

-  prune(res, dataset, knn_graph.view(), cagra_graph.view());
+  prune(res, knn_graph.view(), cagra_graph.view());

   // Construct an index from dataset and pruned knn graph.
   return index<T, IdxT>(res, params.metric, dataset, cagra_graph.view());
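The contract that `sort_knn_graph` establishes, and that `prune` now assumes, is
that each graph row lists its neighbors in non-decreasing distance from the
row's own vector. A standalone host-side restatement of that invariant
(illustrative only, not part of the patch; it mirrors the `CheckOrder` helper
added to the tests below):

```cpp
#include <cstddef>
#include <cstdint>

// Returns true iff every row of `graph` lists its neighbors in non-decreasing
// squared-L2 distance from the row's own vector in `dataset` (both row-major).
bool is_sorted_knn_graph(const float* dataset, const uint32_t* graph,
                         std::size_t n_rows, std::size_t dim, std::size_t degree)
{
  for (std::size_t i = 0; i < n_rows; i++) {
    const float* base = dataset + i * dim;
    double prev       = -1.0;
    for (std::size_t j = 0; j < degree; j++) {
      const float* nbr = dataset + static_cast<std::size_t>(graph[i * degree + j]) * dim;
      double dist      = 0.0;
      for (std::size_t l = 0; l < dim; l++) {
        const double diff = static_cast<double>(base[l]) - static_cast<double>(nbr[l]);
        dist += diff * diff;
      }
      if (dist < prev) { return false; }  // neighbor list out of order
      prev = dist;
    }
  }
  return true;
}
```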
diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
index 02055f2a4d..a08c83677b 100644
--- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
@@ -405,36 +405,24 @@ void shift_array(T* array, uint64_t num)
   }
 }

-/** Input arrays can be both host and device*/
-template <typename DATA_T,
-          typename IdxT = uint32_t,
+template <typename DataT,
+          typename IdxT,
           typename d_accessor =
-            host_device_accessor<std::experimental::default_accessor<DATA_T>, memory_type::device>,
+            host_device_accessor<std::experimental::default_accessor<DataT>, memory_type::device>,
           typename g_accessor =
-            host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
-void prune(raft::device_resources const& res,
-           mdspan<const DATA_T, matrix_extent<IdxT>, row_major, d_accessor> dataset,
-           mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
-           raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
+            host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
+void sort_knn_graph(raft::device_resources const& res,
+                    mdspan<const DataT, matrix_extent<IdxT>, row_major, d_accessor> dataset,
+                    mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph)
 {
-  RAFT_LOG_DEBUG(
-    "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
+  RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
+               "dataset and knn_graph are expected to have the same number of rows");
+  const uint32_t dataset_size = dataset.extent(0);
+  const uint32_t dataset_dim  = dataset.extent(1);
+  const DataT* dataset_ptr    = dataset.data_handle();

-  RAFT_EXPECTS(
-    dataset.extent(0) == knn_graph.extent(0) && knn_graph.extent(0) == new_graph.extent(0),
-    "Each input array is expected to have the same number of rows");
-  RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1),
-               "output graph cannot have more columns than input graph");
-  const uint32_t dataset_size        = dataset.extent(0);
-  const uint32_t dataset_dim         = dataset.extent(1);
-  const uint32_t input_graph_degree  = knn_graph.extent(1);
-  const uint32_t output_graph_degree = new_graph.extent(1);
-  const DATA_T* dataset_ptr          = dataset.data_handle();
-  uint32_t* input_graph_ptr          = (uint32_t*)knn_graph.data_handle();
-  uint32_t* output_graph_ptr         = new_graph.data_handle();
-  float scale = 1.0f / raft::spatial::knn::detail::utils::config<DATA_T>::kDivisor;
-  const std::size_t graph_size = dataset_size;
-  size_t array_size;
+  const uint32_t input_graph_degree = knn_graph.extent(1);
+  uint32_t* input_graph_ptr         = (uint32_t*)knn_graph.data_handle();

   // Setup GPUs
   int num_gpus = 0;
@@ -451,46 +439,48 @@ void prune(raft::device_resources const& res,
   }
   RAFT_CUDA_TRY(cudaSetDevice(0));

-  uint32_t graph_chunk_size     = graph_size;
-  uint32_t*** d_input_graph_ptr = NULL;  // [...][num_gpus][graph_chunk_size, input_graph_degree]
-  graph_chunk_size              = (graph_size + num_gpus - 1) / num_gpus;
+  const uint32_t graph_size       = knn_graph.extent(0);
+  uint32_t*** d_input_graph_ptr   = NULL;  // [...][num_gpus][graph_chunk_size, input_graph_degree]
+  const uint32_t graph_chunk_size = (graph_size + num_gpus - 1) / num_gpus;
   d_input_graph_ptr = mgpu_alloc<uint32_t>(num_gpus, graph_chunk_size, input_graph_degree);

-  uint32_t dataset_chunk_size = dataset_size;
-  DATA_T*** d_dataset_ptr     = NULL;  // [num_gpus+1][...][...]
-  dataset_chunk_size          = (dataset_size + num_gpus - 1) / num_gpus;
+  DataT*** d_dataset_ptr = NULL;  // [num_gpus+1][...][...]
+  const uint32_t dataset_chunk_size = (dataset_size + num_gpus - 1) / num_gpus;
   assert(dataset_chunk_size == graph_chunk_size);
-  d_dataset_ptr = mgpu_alloc<DATA_T>(num_gpus, dataset_chunk_size, dataset_dim);
+  d_dataset_ptr = mgpu_alloc<DataT>(num_gpus, dataset_chunk_size, dataset_dim);

-  mgpu_H2D<DATA_T>(
+  const float scale = 1.0f / raft::spatial::knn::detail::utils::config<DataT>::kDivisor;
+
+  mgpu_H2D<DataT>(
     d_dataset_ptr, dataset_ptr, num_gpus, dataset_size, dataset_chunk_size, dataset_dim);

-  //
-  // Sorting kNN graph
-  //
   double time_sort_start = cur_time();
   RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs ");
-  mgpu_H2D<uint32_t>(
-    d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree);
+  mgpu_H2D<uint32_t>(d_input_graph_ptr,
+                     input_graph_ptr,
+                     num_gpus,
+                     dataset_size,
+                     graph_chunk_size,
+                     input_graph_degree);
   void (*kernel_sort)(
-    DATA_T**, uint32_t, uint32_t, uint32_t, float, uint32_t**, uint32_t, uint32_t, uint32_t, int);
+    DataT**, uint32_t, uint32_t, uint32_t, float, uint32_t**, uint32_t, uint32_t, uint32_t, int);
   constexpr int numElementsPerThread = 4;
   dim3 threads_sort(1, 1, 1);
   if (input_graph_degree <= numElementsPerThread * 32) {
     constexpr int blockDim_x = 32;
-    kernel_sort    = kern_sort<DATA_T, blockDim_x, numElementsPerThread>;
+    kernel_sort    = kern_sort<DataT, blockDim_x, numElementsPerThread>;
     threads_sort.x = blockDim_x;
   } else if (input_graph_degree <= numElementsPerThread * 64) {
     constexpr int blockDim_x = 64;
-    kernel_sort    = kern_sort<DATA_T, blockDim_x, numElementsPerThread>;
+    kernel_sort    = kern_sort<DataT, blockDim_x, numElementsPerThread>;
     threads_sort.x = blockDim_x;
   } else if (input_graph_degree <= numElementsPerThread * 128) {
     constexpr int blockDim_x = 128;
-    kernel_sort    = kern_sort<DATA_T, blockDim_x, numElementsPerThread>;
+    kernel_sort    = kern_sort<DataT, blockDim_x, numElementsPerThread>;
     threads_sort.x = blockDim_x;
   } else if (input_graph_degree <= numElementsPerThread * 256) {
     constexpr int blockDim_x = 256;
-    kernel_sort    = kern_sort<DATA_T, blockDim_x, numElementsPerThread>;
+    kernel_sort    = kern_sort<DataT, blockDim_x, numElementsPerThread>;
     threads_sort.x = blockDim_x;
   } else {
     fprintf(stderr,
@@ -510,7 +500,7 @@ void prune(raft::device_resources const& res,
       dataset_dim,
       scale,
       d_input_graph_ptr[i_gpu],
-      graph_size,
+      dataset_size,
       graph_chunk_size,
       input_graph_degree,
       i_gpu);
@@ -518,13 +508,60 @@ void prune(raft::device_resources const& res,
   RAFT_CUDA_TRY(cudaSetDevice(0));
   RAFT_CUDA_TRY(cudaDeviceSynchronize());
   RAFT_LOG_DEBUG(".");
-  mgpu_D2H<uint32_t>(
-    d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree);
+  mgpu_D2H<uint32_t>(d_input_graph_ptr,
+                     input_graph_ptr,
+                     num_gpus,
+                     dataset_size,
+                     graph_chunk_size,
+                     input_graph_degree);
   RAFT_LOG_DEBUG("\n");
   double time_sort_end = cur_time();
   RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start);

-  mgpu_free<DATA_T>(d_dataset_ptr, num_gpus);
+  mgpu_free<DataT>(d_dataset_ptr, num_gpus);
+}
+
+/** Input arrays can be both host and device*/
+template <typename IdxT = uint32_t,
+          typename g_accessor =
+            host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
+void prune(raft::device_resources const& res,
+           mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
+           raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
+{
+  RAFT_LOG_DEBUG(
+    "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
+
+  RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0),
+               "Each input array is expected to have the same number of rows");
+  RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1),
+               "output graph cannot have more columns than input graph");
+  const uint32_t input_graph_degree  = knn_graph.extent(1);
+  const uint32_t output_graph_degree = new_graph.extent(1);
+  uint32_t* input_graph_ptr          = (uint32_t*)knn_graph.data_handle();
+  uint32_t* output_graph_ptr         = new_graph.data_handle();
+  const std::size_t graph_size       = new_graph.extent(0);
+  size_t array_size;
+
+  // Setup GPUs
+  int num_gpus = 0;
+  RAFT_CUDA_TRY(cudaGetDeviceCount(&num_gpus));
+  RAFT_LOG_DEBUG("# num_gpus: %d\n", num_gpus);
+  for (int self = 0; self < num_gpus; self++) {
+    RAFT_CUDA_TRY(cudaSetDevice(self));
+    for (int peer = 0; peer < num_gpus; peer++) {
+      if (self == peer) { continue; }
+      RAFT_CUDA_TRY(cudaDeviceEnablePeerAccess(peer, 0));
+    }
+  }
+  RAFT_CUDA_TRY(cudaSetDevice(0));
+
+  uint32_t graph_chunk_size     = graph_size;
+  uint32_t*** d_input_graph_ptr = NULL;  // [...][num_gpus][graph_chunk_size, input_graph_degree]
+  graph_chunk_size              = (graph_size + num_gpus - 1) / num_gpus;
+  d_input_graph_ptr = mgpu_alloc<uint32_t>(num_gpus, graph_chunk_size, input_graph_degree);

   // uint8_t* detour_count;  // [graph_size, input_graph_degree]
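In the sort path above, each thread handles `numElementsPerThread = 4` entries
of a graph row, so the launcher picks the smallest block width in
{32, 64, 128, 256} whose capacity `blockDim_x * 4` covers `input_graph_degree`;
degrees above 1024 fall through to the error branch. The selection rule,
restated as a standalone helper (illustrative only, not part of the patch):

```cpp
#include <cstdint>

// Smallest block width whose capacity (blockDim_x * numElementsPerThread)
// covers the graph degree; 0 signals an unsupported degree (> 256 * 4).
inline int sort_block_dim_x(std::uint32_t input_graph_degree)
{
  constexpr std::uint32_t numElementsPerThread = 4;
  for (std::uint32_t blockDim_x : {32u, 64u, 128u, 256u}) {
    if (input_graph_degree <= numElementsPerThread * blockDim_x) {
      return static_cast<int>(blockDim_x);
    }
  }
  return 0;  // degree too large for kern_sort
}
```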
diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh
index f9df1f724f..1096dc4fb0 100644
--- a/cpp/test/neighbors/ann_cagra.cuh
+++ b/cpp/test/neighbors/ann_cagra.cuh
@@ -40,6 +40,88 @@
 #include

 namespace raft::neighbors::experimental::cagra {
+namespace {
+// For sort_knn_graph test
+template <class IdxT>
+void RandomSuffle(raft::host_matrix_view<IdxT, IdxT> index)
+{
+  for (IdxT i = 0; i < index.extent(0); i++) {
+    uint64_t rand       = i;
+    IdxT* const row_ptr = index.data_handle() + i * index.extent(1);
+    for (unsigned j = 0; j < index.extent(1); j++) {
+      // Swap two indices at random
+      rand          = raft::neighbors::experimental::cagra::detail::device::xorshift64(rand);
+      const auto i0 = rand % index.extent(1);
+      rand          = raft::neighbors::experimental::cagra::detail::device::xorshift64(rand);
+      const auto i1 = rand % index.extent(1);
+
+      const auto tmp = row_ptr[i0];
+      row_ptr[i0]    = row_ptr[i1];
+      row_ptr[i1]    = tmp;
+    }
+  }
+}
+
+template <class DistanceT, class DatatT, class IdxT>
+testing::AssertionResult CheckOrder(raft::host_matrix_view<IdxT, IdxT> index_test,
+                                    raft::host_matrix_view<DatatT, IdxT> dataset)
+{
+  for (IdxT i = 0; i < index_test.extent(0); i++) {
+    const DatatT* const base_vec = dataset.data_handle() + i * dataset.extent(1);
+    const IdxT* const index_row  = index_test.data_handle() + i * index_test.extent(1);
+    DistanceT prev_distance      = 0;
+    for (unsigned j = 0; j < index_test.extent(1) - 1; j++) {
+      const DatatT* const target_vec = dataset.data_handle() + index_row[j] * dataset.extent(1);
+      DistanceT distance             = 0;
+      for (unsigned l = 0; l < dataset.extent(1); l++) {
+        const auto diff =
+          static_cast<DistanceT>(target_vec[l]) - static_cast<DistanceT>(base_vec[l]);
+        distance += diff * diff;
+      }
+      if (prev_distance > distance) {
+        return testing::AssertionFailure()
+               << "Wrong index order (row = " << i << ", neighbor_id = " << j
+               << "). (distance[neighbor_id-1] = " << prev_distance
+               << " should not be larger than distance[neighbor_id] = " << distance << ")";
+      }
+      prev_distance = distance;
+    }
+  }
+  return testing::AssertionSuccess();
+}
+
+// Generate a dataset that causes no rounding error in the norm computation between any two of its
+// vectors. When testing the CAGRA index sorting function, rounding errors can affect the norms and
+// alter the order of the index, so the test uses this dataset to stay exact. The generation
+// method is based on the error-free transformation (EFT) method.
+__global__ void GenerateRoundingErrorFreeDataset_kernel(float* const ptr,
+                                                        const uint32_t size,
+                                                        const uint32_t resolution)
+{
+  const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid >= size) { return; }
+
+  const float u32 = *reinterpret_cast<const uint32_t*>(ptr + tid);
+  ptr[tid]        = u32 / resolution;
+}
+
+void GenerateRoundingErrorFreeDataset(float* const ptr,
+                                      const uint32_t n_row,
+                                      const uint32_t dim,
+                                      raft::random::Rng& rng,
+                                      cudaStream_t cuda_stream)
+{
+  const uint32_t size       = n_row * dim;
+  const uint32_t block_size = 256;
+  const uint32_t grid_size  = (size + block_size - 1) / block_size;
+
+  const uint32_t resolution =
+    1u << static_cast<uint32_t>(std::floor((24 - std::log2(dim)) / 2));
+  rng.uniformInt(reinterpret_cast<uint32_t*>(ptr), size, 0u, resolution - 1, cuda_stream);
+
+  GenerateRoundingErrorFreeDataset_kernel<<<grid_size, block_size, 0, cuda_stream>>>(
+    ptr, size, resolution);
+}
+}  // namespace

 struct AnnCagraInputs {
   int n_queries;
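Why the construction above is rounding-error free: every coordinate is an
integer multiple of `1/resolution` in `[0, 1)`, so any squared distance,
scaled by `resolution^2`, is an integer bounded by `dim * (resolution - 1)^2`.
Choosing `resolution = 2^floor((24 - log2(dim)) / 2)` keeps that integer below
`2^24`, the range in which fp32 represents every integer exactly, so each
partial sum of the distance computation is exact. A standalone check of the
bound (illustrative only, not part of the patch):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

int main()
{
  for (unsigned dim : {8u, 64u, 128u, 192u, 256u, 512u, 1024u}) {
    const std::uint32_t resolution =
      1u << static_cast<std::uint32_t>(std::floor((24 - std::log2(dim)) / 2));
    // Largest scaled squared distance between two grid points:
    const std::uint64_t max_scaled =
      static_cast<std::uint64_t>(dim) * (resolution - 1) * (resolution - 1);
    assert(max_scaled < (1ull << 24));  // all intermediate sums exact in fp32
  }
  return 0;
}
```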
@@ -107,7 +189,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
       stream_);
     update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
     update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_);
-    handle_.sync_stream(stream_);
+    handle_.sync_stream();
   }

   {
@@ -153,7 +235,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {

     update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_);
     update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_);
-    handle_.sync_stream(stream_);
+    handle_.sync_stream();
   }
   // for (int i = 0; i < ps.n_queries; i++) {
   //   // std::cout << "query " << i << std::end;
@@ -194,18 +276,18 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
     std::cout << "Done.\nRunning rng" << std::endl;
     raft::random::Rng r(1234ULL);
     if constexpr (std::is_same<DataT, float>{}) {
-      r.uniform(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
-      r.uniform(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
+      r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
+      r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
     } else {
       r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
       r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_);
     }
-    handle_.sync_stream(stream_);
+    handle_.sync_stream();
   }

   void TearDown() override
   {
-    handle_.sync_stream(stream_);
+    handle_.sync_stream();
     database.resize(0, stream_);
     search_queries.resize(0, stream_);
   }
@@ -218,6 +300,75 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
   rmm::device_uvector<DataT> search_queries;
 };

+template <typename DistanceT, typename DataT, typename IdxT>
+class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
+ public:
+  AnnCagraSortTest()
+    : ps(::testing::TestWithParam<AnnCagraInputs>::GetParam()), database(0, handle_.get_stream())
+  {
+  }
+
+ protected:
+  void testCagraSort()
+  {
+    {
+      // Step 1: Build a sorted KNN graph by CAGRA knn build
+      auto database_view = raft::make_device_matrix_view<const DataT, IdxT>(
+        (const DataT*)database.data(), ps.n_rows, ps.dim);
+      auto database_host = raft::make_host_matrix<DataT, IdxT>(ps.n_rows, ps.dim);
+      raft::copy(
+        database_host.data_handle(), database.data(), database.size(), handle_.get_stream());
+      auto database_host_view = raft::make_host_matrix_view<const DataT, IdxT>(
+        (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
+
+      cagra::index_params index_params;
+      auto knn_graph =
+        raft::make_host_matrix<IdxT, IdxT>(ps.n_rows, index_params.intermediate_graph_degree);
+
+      if (ps.host_dataset) {
+        cagra::build_knn_graph(handle_, database_host_view, knn_graph.view());
+      } else {
+        cagra::build_knn_graph(handle_, database_view, knn_graph.view());
+      };
+
+      handle_.sync_stream();
+      ASSERT_TRUE(CheckOrder<DistanceT>(knn_graph.view(), database_host.view()));
+
+      RandomSuffle(knn_graph.view());
+
+      cagra::sort_knn_graph(handle_, database_view, knn_graph.view());
+      handle_.sync_stream();
+
+      ASSERT_TRUE(CheckOrder<DistanceT>(knn_graph.view(), database_host.view()));
+    }
+  }
+
+  void SetUp() override
+  {
+    std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl;
+    database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream());
+    std::cout << "Done.\nRunning rng" << std::endl;
+    raft::random::Rng r(1234ULL);
+    if constexpr (std::is_same<DataT, float>{}) {
+      GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream());
+    } else {
+      r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), handle_.get_stream());
+    }
+    handle_.sync_stream();
+  }
+
+  void TearDown() override
+  {
+    handle_.sync_stream();
+    database.resize(0, handle_.get_stream());
+  }
+
+ private:
+  raft::device_resources handle_;
+  AnnCagraInputs ps;
+  rmm::device_uvector<DataT> database;
+};
+
 inline std::vector<AnnCagraInputs> generate_inputs()
 {
   // Todo(tfeher): MULTI_CTA tests a bug, consider disabling that mode.
@@ -238,7 +389,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
   auto inputs2 =
     raft::util::itertools::product<AnnCagraInputs>({100},
                                                    {1000},
-                                                   {2, 4, 8, 64, 128, 196, 256, 512, 1024},  // dim
+                                                   {8, 64, 128, 192, 256, 512, 1024},  // dim
                                                    {16},
                                                    {search_algo::AUTO},
                                                    {10},
@@ -282,7 +433,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
   inputs2 =
     raft::util::itertools::product<AnnCagraInputs>({100},
                                                    {10000, 20000},
-                                                   {30},
+                                                   {32},
                                                    {10},
                                                    {search_algo::AUTO},
                                                    {10},
@@ -297,7 +448,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
   inputs2 =
     raft::util::itertools::product<AnnCagraInputs>({100},
                                                    {10000, 20000},
-                                                   {30},
+                                                   {32},
                                                    {10},
                                                    {search_algo::AUTO},
                                                    {10},
diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
index 1497a515d2..adb44a9264 100644
--- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
+++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
@@ -23,6 +23,10 @@ namespace raft::neighbors::experimental::cagra {
 typedef AnnCagraTest<float, float, std::uint32_t> AnnCagraTestF;
 TEST_P(AnnCagraTestF, AnnCagra) { this->testCagra(); }

+typedef AnnCagraSortTest<float, float, std::uint32_t> AnnCagraSortTestF;
+TEST_P(AnnCagraSortTestF, AnnCagraSort) { this->testCagraSort(); }
+
 INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF, ::testing::ValuesIn(inputs));
+INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF, ::testing::ValuesIn(inputs));

 }  // namespace raft::neighbors::experimental::cagra
diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
index f148ebc186..11c986c189 100644
--- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
+++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
@@ -22,7 +22,10 @@ namespace raft::neighbors::experimental::cagra {
 typedef AnnCagraTest<float, std::int8_t, std::uint32_t> AnnCagraTestI8;
 TEST_P(AnnCagraTestI8, AnnCagra) { this->testCagra(); }
+typedef AnnCagraSortTest<float, std::int8_t, std::uint32_t> AnnCagraSortTestI8;
+TEST_P(AnnCagraSortTestI8, AnnCagraSort) { this->testCagraSort(); }

 INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8, ::testing::ValuesIn(inputs));
+INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8, ::testing::ValuesIn(inputs));

 }  // namespace raft::neighbors::experimental::cagra
diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
index 087d7cec71..51d4feeed2 100644
--- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
+++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
@@ -23,6 +23,10 @@ namespace raft::neighbors::experimental::cagra {
 typedef AnnCagraTest<float, std::uint8_t, std::uint32_t> AnnCagraTestU8;
 TEST_P(AnnCagraTestU8, AnnCagra) { this->testCagra(); }

+typedef AnnCagraSortTest<float, std::uint8_t, std::uint32_t> AnnCagraSortTestU8;
+TEST_P(AnnCagraSortTestU8, AnnCagraSort) { this->testCagraSort(); }
+
 INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8, ::testing::ValuesIn(inputs));
+INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8, ::testing::ValuesIn(inputs));

 }  // namespace raft::neighbors::experimental::cagra
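For reference, the end-to-end flow after this patch, assembled from the usage
examples in the diff (a sketch only: `res` and `dataset` are placeholders, the
element/index types and the degrees 128/64 are illustrative, and graphs from
`build_knn_graph` do not need the explicit sort):

```cpp
using namespace raft::neighbors;

// 1. Build an intermediate KNN graph (already sorted by build_knn_graph).
auto knn_graph = raft::make_host_matrix<uint32_t, uint32_t>(dataset.extent(0), 128);
cagra::build_knn_graph(res, dataset, knn_graph.view());

// 1'. If the graph came from elsewhere, sort it first:
// cagra::sort_knn_graph(res, dataset.view(), knn_graph.view());

// 2. Prune to the final graph degree; the dataset is no longer an argument.
auto cagra_graph = raft::make_host_matrix<uint32_t, uint32_t>(dataset.extent(0), 64);
cagra::prune(res, knn_graph.view(), cagra_graph.view());

// 3. Wrap dataset + graph into an index and search as before.
cagra::index_params params;
auto index = cagra::index<float, uint32_t>(res, params.metric, dataset, cagra_graph.view());
```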